3
/////////////////////////////////////////////////////////////////////////////
5
// Copyright (c) 1999-2002 David Ward
7
/////////////////////////////////////////////////////////////////////////////
11
#include "PPMLanguageModel.h"
13
using namespace Dasher;
16
// static TCHAR debug[256];
17
// Shorthand for unsigned long; used for the integer probability-share
// arithmetic in CPPMLanguageModel::GetProbs.
typedef unsigned long ulong;
19
////////////////////////////////////////////////////////////////////////
20
/// PPMnode definitions
21
////////////////////////////////////////////////////////////////////////
23
// Look for a child of this node that carries symbol `sym`.
//   sym: symbol value compared against a candidate node's `symbol` field.
// Returns a CPPMnode pointer; presumably the matching child, or null when
// no child matches -- TODO confirm.
CPPMLanguageModel::CPPMnode *CPPMLanguageModel::CPPMnode::find_symbol(int sym)
24
// see if symbol is a child of node
26
// printf("finding symbol %d at node %d\n",sym,node->id);
27
// The search starts at this node's `child` pointer.
CPPMnode *found=child;
29
if (found->symbol==sym)
37
// Add symbol `sym` beneath this node.
//   sym:    symbol looked up with find_symbol() and used to construct a new
//           CPPMnode (presumably only when the lookup finds nothing -- TODO confirm).
//   update: pointer to a flag; a non-zero value selects the
//           "update exclusions" branch.
// Returns a CPPMnode pointer; presumably the existing or newly created
// child for `sym` -- TODO confirm.
CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::add_symbol_to_node(int sym,int *update)
39
CPPMnode *born,*search;
40
// Check whether a child for this symbol already exists.
search=find_symbol(sym);
42
born = new CPPMnode(sym);
48
if (*update) { // perform update exclusions
58
/////////////////////////////////////////////////////////////////////
59
// CPPMLanguageModel defs
60
/////////////////////////////////////////////////////////////////////
62
// Construct the PPM language model.
//   _alphabet, _normalization: forwarded unchanged to the CLanguageModel base.
// Allocates the trie root and the root context with `new`.
CPPMLanguageModel::CPPMLanguageModel(CAlphabet *_alphabet,int _normalization)
63
: CLanguageModel(_alphabet,_normalization)
65
// The root node is built with symbol -1 (presumably "no symbol" -- TODO confirm).
root=new CPPMnode(-1);
66
// Root context refers to the root node; second argument is 0
// (presumably the context order -- TODO confirm).
m_rootcontext=new CPPMContext(root,0);
70
// Destructor: tears down the trie starting from `root` using an explicit
// std::stack of pending nodes instead of recursion.
CPPMLanguageModel::~CPPMLanguageModel()
75
// A non-recursive node deletion algorithm using a stack
76
std::stack<CPPMnode*> deletenodes;
77
// Seed the work stack with the trie root.
deletenodes.push(root);
78
while (!deletenodes.empty())
80
CPPMnode* temp = deletenodes.top();
89
// Queue the node's child so that it is visited as well.
deletenodes.push(temp->child);
102
// Compute the probability distribution over model symbols for `context`.
//   context: downcast with static_cast to CPPMContext; its `head` node is the
//            starting point of the computation.
//   probs:   output; resized to GetNumberModelChars() and filled with
//            unsigned integer shares, one per symbol.
//   norm:    presumably the total integer mass to distribute (it replaces the
//            commented-out CLanguageModel::normalization() read) -- TODO confirm.
// Returns bool; the meaning of the result is not evident here -- TODO confirm.
bool CPPMLanguageModel::GetProbs(CContext *context,vector<unsigned int> &probs, int norm)
103
// get the probability distribution at the context
105
// seems like we have to have this hack for VC++
106
CPPMContext *ppmcontext=static_cast<CPPMContext *> (context);
109
int modelchars=GetNumberModelChars();
110
// int norm=CLanguageModel::normalization();
111
probs.resize( GetNumberModelChars() );
112
for( vector<unsigned int>::iterator it( probs.begin() ); it != probs.end(); ++it )
115
// One flag per symbol; a flag is set below once that symbol has been given
// its share, so that the symbol is skipped on later passes.
vector<bool> exclusions( probs.size() );
116
for( vector<bool>::iterator it( exclusions.begin() ); it != exclusions.end(); ++it )
119
// valid[i] caches isRealSymbol(i); only valid symbols are counted or given
// probability in the code below.
vector<bool> valid( probs.size() );
120
for( int i(0); i < valid.size(); ++i )
121
valid[i] = isRealSymbol( i );
129
// Start from the node the context currently points at.
temp=ppmcontext->head;
141
// Counting pass: accumulate `total` over symbols that are neither excluded
// nor invalid, with special handling of the control and space symbols.
if (!exclusions[sym] && valid[sym]) {
142
if( sym == GetControlSymbol() ) {
145
else if( sym == GetSpaceSymbol() ) {
146
total=total+s->count;
148
// The control symbol's count is derived from this node's count with a
// fixed factor of 0.4.
controlcount = 0.4 * s->count; // FIXME - and here
150
if( controlcount < 1 )
153
// -1 appears to mean "no control symbol configured" -- TODO confirm.
if( GetControlSymbol() != -1 )
154
total = total + controlcount;
158
total=total+s->count;
164
size_of_slice=tospend;
167
// Allocation pass: hand each eligible symbol an integer share of
// size_of_slice and mark it as excluded afterwards.
if (!exclusions[s->symbol] && valid[s->symbol]) {
168
// exclusions[s->symbol]=1;
169
if( s->symbol == GetControlSymbol() ) {
172
else if( s->symbol == GetSpaceSymbol() ) {
173
// Share is roughly size_of_slice * (2*count - 1) / (2*total); the divisions
// are performed first, in integer arithmetic, so the result truncates.
ulong p=((size_of_slice/2)/ulong(total))*(2*s->count-1);
176
exclusions[s->symbol]=1;
177
if( GetControlSymbol() != -1 )
178
if( !exclusions[GetControlSymbol()] ) {
179
// Same formula, using the derived control count.
ulong p=((size_of_slice/2)/ulong(total))*(2*controlcount-1);
180
probs[GetControlSymbol()]+=p;
182
exclusions[GetControlSymbol()]=1;
186
ulong p=((size_of_slice/2)/ulong(total))*(2*s->count-1);
189
exclusions[s->symbol]=1;
197
// Usprintf(debug,TEXT("Norm %u tospend %u\n"),Norm,tospend);
198
// DebugOutput(debug);
200
size_of_slice=tospend;
202
// First loop inspects valid symbols whose probability is still zero
// (presumably counting them into symbolsleft -- TODO confirm); the second
// loop gives each of them an equal integer share of size_of_slice.
for (sym=0;sym<modelchars;sym++)
203
if (!probs[sym] && valid[sym])
205
for (sym=0;sym<modelchars;sym++)
206
if (!probs[sym] && valid[sym]) {
207
ulong p=size_of_slice/symbolsleft;
212
// distribute what's left evenly
215
// int current_symbol(0);
216
// while( tospend > 0 )
218
// if( valid[current_symbol] ) {
219
// probs[current_symbol] += 1;
224
// if( current_symbol == modelchars )
225
// current_symbol = 0;
228
int valid_char_count(0);
230
for( int i(0); i < modelchars; ++i )
235
for (int i(0);i<modelchars;++i)
237
// NOTE(review): divides by valid_char_count -- confirm it cannot be zero here.
ulong p=tospend/(valid_char_count);
243
// ulong p=tospend/(modelchars-sym);
248
// Usprintf(debug,TEXT("finaltospend %u\n"),tospend);
249
// DebugOutput(debug);
251
// `exclusions` is a local vector<bool> in this function, so the manual
// frees kept below as commented-out history are not needed.
// free(exclusions); // !!!
252
// !!! NB by IAM: p577 Stroustrup 3rd Edition: "Allocating an object using new and deleting it using free() is asking for trouble"
253
// delete[] exclusions;
258
// Add `symbol` to the model at `context` and advance the context.
//   context: modified in place; its `head` is replaced by the node returned
//            from add_symbol_to_node().
//   symbol:  symbol to add; 0 and values >= GetNumberModelChars() are
//            special-cased by the guard at the top.
void CPPMLanguageModel::AddSymbol(CPPMLanguageModel::CPPMContext &context,int symbol)
259
// add symbol to the context
260
// creates new nodes, updates counts
261
// and leaves 'context' at the new context
264
// Guard for symbol 0 and out-of-range symbols (presumably an early
// return -- TODO confirm).
if (symbol==0 || symbol>=GetNumberModelChars())
267
CPPMnode *vineptr,*temp;
270
// Remember the old head's vine link before the head is replaced.
temp=context.head->vine;
271
context.head=context.head->add_symbol_to_node(symbol,&updatecnt);
272
vineptr=context.head;
276
// Add the symbol under the vine node as well and link the previously
// returned node's `vine` to the result, then step along that link.
vineptr->vine=temp->add_symbol_to_node(symbol,&updatecnt);
277
vineptr=vineptr->vine;
281
// When the context order exceeds MAX_ORDER, move the head one vine link on.
if (context.order>MAX_ORDER){
282
context.head=context.head->vine;
288
// Update the context with symbol 'Symbol'.
//   Context: downcast with static_cast to CPPMContext and modified in place.
//   Symbol:  symbol searched for with find_symbol() at the context head and,
//            failing that, at nodes reached through successive `vine` links.
289
void CPPMLanguageModel::EnterSymbol(CContext* Context, modelchar Symbol)
291
CPPMLanguageModel::CPPMContext& context = * static_cast<CPPMContext *> (Context);
294
CPPMnode *temp=context.head;
296
// Walk the vine chain until the head becomes null.
while (context.head) {
297
find =context.head->find_symbol(Symbol);
301
// Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
302
// DebugOutput(debug);
306
context.head=context.head->vine;
309
// Handles the case where the walk ran off the end of the vine chain.
if (context.head==0) {
317
// Learn `Symbol` in `Context`: downcasts the context to CPPMContext and
// passes it, together with the symbol, to AddSymbol().
void CPPMLanguageModel::LearnSymbol(CContext* Context, modelchar Symbol)
319
CPPMLanguageModel::CPPMContext& context = * static_cast<CPPMContext *> (Context);
320
AddSymbol(context, Symbol);
324
// Diagnostic helper: print one symbol to stdout with printf.
// Values <= 32 or >= 127 are printed numerically as "<N>"; the "%c" form is
// presumably used for every other value -- TODO confirm.
// NOTE(review): the threshold here is 32 whereas dumpString uses 31, so a
// space is shown as "<32>" by this function -- confirm that is intended.
void CPPMLanguageModel::dumpSymbol(int symbol)
326
if ((symbol <= 32) || (symbol >= 127))
327
printf( "<%d>", symbol );
329
printf( "%c", symbol );
333
// Diagnostic helper: print part of a character buffer to stdout.
//   str: buffer to print from.
//   pos: index of the first position visited.
//   len: number of positions visited (the loop runs over [pos, pos+len)).
void CPPMLanguageModel::dumpString( char *str, int pos, int len )
334
// Dump the string STR starting at position POS
338
for (p = pos; p<pos+len; p++) {
340
// Values <= 31 or >= 127 are printed numerically as "<N>".
if ((cc <= 31) || (cc >= 127))
341
printf( "<%d>", cc );
348
// Diagnostic display of the PPM trie from node `t` and deeper.
//   t: node to display.
//   d: depth of `t`; also the number of characters of dumpTrieStr printed
//      and the index at which the current symbol is stored in dumpTrieStr.
// NOTE(review): the text is formatted into `debug`, but the file-level
// declaration of `debug` near the top of this file is commented out and every
// DebugOutput call is disabled -- presumably `debug` is declared in a header;
// confirm.
void CPPMLanguageModel::dumpTrie( CPPMLanguageModel::CPPMnode *t, int d )
349
// diagnostic display of the PPM trie from node t and deeper
356
Usprintf( debug,TEXT("%5d %7x "), d, t );
357
// TODO: Uncomment this once the headers are sorted out
358
//DebugOutput(debug);
// NOTE(review): `t` is a pointer, so an ordered comparison against 0 is
// suspect -- confirm what this test is meant to detect.
359
if (t < 0) // pointer to input
362
Usprintf(debug,TEXT( " %3d %5d %7x %7x %7x <"), t->symbol,t->count, t->vine, t->child, t->next );
363
// TODO: Uncomment this once the headers are sorted out
364
//DebugOutput(debug);
367
// Print the first d characters of dumpTrieStr.
dumpString( dumpTrieStr, 0, d );
368
Usprintf( debug,TEXT(">\n") );
369
// TODO: Uncomment this once the headers are sorted out
370
//DebugOutput(debug);
376
// Record the current symbol at this depth.
dumpTrieStr [d] = sym;
385
// Diagnostic display of the whole PPM trie.
// Each banner/header line is formatted into `debug` with Usprintf; the
// DebugOutput calls that would emit them are all commented out, so the
// formatted text is not output by the statements shown here.
void CPPMLanguageModel::dump()
386
// diagnostic display of the whole PPM trie
391
Usprintf(debug,TEXT( "Dump of Trie : \n" ));
392
// TODO: Uncomment this once the headers are sorted out
393
//DebugOutput(debug);
394
Usprintf(debug,TEXT( "---------------\n" ));
395
// TODO: Uncomment this once the headers are sorted out
396
//DebugOutput(debug);
// Column headings for the per-node rows.
397
Usprintf( debug,TEXT( "depth node symbol count vine child next context\n") );
398
// TODO: Uncomment this once the headers are sorted out
399
//DebugOutput(debug);
401
Usprintf( debug,TEXT( "---------------\n" ));
402
// TODO: Uncomment this once the headers are sorted out
403
//DebugOutput(debug);
404
Usprintf(debug,TEXT( "\n" ));
405
// TODO: Uncomment this once the headers are sorted out
406
//DebugOutput(debug);