~ubuntu-branches/ubuntu/warty/dasher/warty

« back to all changes in this revision

Viewing changes to Src/DasherCore/PPMLanguageModel.cpp

  • Committer: Bazaar Package Importer
  • Author(s): Matthew Garrett
  • Date: 2003-06-05 11:10:04 UTC
  • Revision ID: james.westby@ubuntu.com-20030605111004-kqiutbrlvs7td9ic
Tags: upstream-3.2.10
Import upstream version 3.2.10

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
// PPMLanguageModel.cpp
 
2
//
 
3
/////////////////////////////////////////////////////////////////////////////
 
4
//
 
5
// Copyright (c) 1999-2002 David Ward
 
6
//
 
7
/////////////////////////////////////////////////////////////////////////////
 
8
 
 
9
#include <math.h>
 
10
#include <stack>
 
11
#include "PPMLanguageModel.h"
 
12
 
 
13
using namespace Dasher;
 
14
using namespace std;
 
15
 
 
16
// static TCHAR debug[256];
 
17
typedef unsigned long ulong;
 
18
 
 
19
////////////////////////////////////////////////////////////////////////
 
20
/// PPMnode definitions 
 
21
////////////////////////////////////////////////////////////////////////
 
22
 
 
23
CPPMLanguageModel::CPPMnode *CPPMLanguageModel::CPPMnode::find_symbol(int sym)
 
24
// see if symbol is a child of node
 
25
{
 
26
        //  printf("finding symbol %d at node %d\n",sym,node->id);
 
27
        CPPMnode *found=child;
 
28
        while (found) {
 
29
                if (found->symbol==sym)
 
30
                        return found;
 
31
                found=found->next;
 
32
        }
 
33
        return 0;
 
34
}
 
35
 
 
36
 
 
37
CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::add_symbol_to_node(int sym,int *update)
 
38
{
 
39
        CPPMnode *born,*search;
 
40
        search=find_symbol(sym);
 
41
        if (!search) {
 
42
                born = new CPPMnode(sym);
 
43
                born->next=child;
 
44
                child=born;
 
45
                //   node->count=1;
 
46
                return born;            
 
47
        } else {
 
48
                if (*update) {   // perform update exclusions
 
49
                        search->count++;
 
50
                        *update=0;
 
51
                }
 
52
                return search;
 
53
        }
 
54
        
 
55
}
 
56
 
 
57
 
 
58
/////////////////////////////////////////////////////////////////////
 
59
// CPPMLanguageModel defs
 
60
/////////////////////////////////////////////////////////////////////
 
61
 
 
62
// Build an empty model: the trie starts as a single root node created with
// symbol -1, and the root context is created from that node with a second
// argument of 0 (presumably the context order -- confirm against CPPMContext).
// Both objects are released again in the destructor.
CPPMLanguageModel::CPPMLanguageModel(CAlphabet *_alphabet,int _normalization)
        : CLanguageModel(_alphabet,_normalization)
{
        CPPMnode *const trieRoot = new CPPMnode(-1);
        root = trieRoot;
        m_rootcontext = new CPPMContext(trieRoot, 0);
}
 
68
 
 
69
 
 
70
CPPMLanguageModel::~CPPMLanguageModel()
 
71
{
 
72
 
 
73
        delete m_rootcontext;
 
74
 
 
75
        // A non-recursive node deletion algorithm using a stack
 
76
        std::stack<CPPMnode*> deletenodes;
 
77
        deletenodes.push(root);
 
78
        while (!deletenodes.empty())
 
79
        {
 
80
                CPPMnode* temp = deletenodes.top();
 
81
                deletenodes.pop();
 
82
                CPPMnode* next;
 
83
                do      
 
84
                {
 
85
                        next = temp->next;
 
86
 
 
87
                        // push the child
 
88
                        if (temp->child)
 
89
                                deletenodes.push(temp->child);
 
90
                        
 
91
                        delete temp;
 
92
 
 
93
                        temp=next;
 
94
                        
 
95
                } while (temp !=0); 
 
96
 
 
97
        }
 
98
 
 
99
}
 
100
 
 
101
 
 
102
bool CPPMLanguageModel::GetProbs(CContext *context,vector<unsigned int> &probs, int norm)
 
103
        // get the probability distribution at the context
 
104
{
 
105
        // seems like we have to have this hack for VC++
 
106
        CPPMContext *ppmcontext=static_cast<CPPMContext *> (context);
 
107
        
 
108
        
 
109
        int modelchars=GetNumberModelChars();
 
110
        //      int norm=CLanguageModel::normalization();
 
111
        probs.resize( GetNumberModelChars() );
 
112
        for( vector<unsigned int>::iterator it( probs.begin() ); it != probs.end(); ++it )
 
113
          *it = 0;
 
114
 
 
115
        vector<bool> exclusions( probs.size() );
 
116
        for( vector<bool>::iterator it( exclusions.begin() ); it != exclusions.end(); ++it )
 
117
          *it = false;
 
118
 
 
119
        vector<bool> valid( probs.size() );
 
120
        for( int i(0); i < valid.size(); ++i )
 
121
          valid[i] = isRealSymbol( i );
 
122
        
 
123
        CPPMnode *temp,*s; 
 
124
        //      int loop,total;
 
125
        int sym; 
 
126
        ulong spent=0; 
 
127
        ulong size_of_slice;
 
128
        ulong tospend=norm;
 
129
        temp=ppmcontext->head;
 
130
 
 
131
        int total;
 
132
 
 
133
        while (temp!=0) {
 
134
 
 
135
          int controlcount;
 
136
 
 
137
                total=0;
 
138
                s=temp->child;
 
139
                while (s) {
 
140
                  sym=s->symbol; 
 
141
                  if (!exclusions[sym] && valid[sym]) {
 
142
                    if( sym == GetControlSymbol() ) {
 
143
                      // Do nothing
 
144
                    } 
 
145
                    else if( sym == GetSpaceSymbol() ) {
 
146
                      total=total+s->count;
 
147
                      
 
148
                      controlcount = 0.4 * s->count; // FIXME - and here
 
149
                      
 
150
                      if( controlcount < 1 )
 
151
                        controlcount = 1;
 
152
 
 
153
                      if( GetControlSymbol() != -1 )
 
154
                        total = total + controlcount;
 
155
 
 
156
                    }
 
157
                    else {
 
158
                      total=total+s->count;
 
159
                    }
 
160
                  }
 
161
                  s=s->next;
 
162
                }
 
163
                if (total) {
 
164
                  size_of_slice=tospend;
 
165
                  s=temp->child;
 
166
                  while (s) {
 
167
                    if (!exclusions[s->symbol] && valid[s->symbol]) {
 
168
                      //                      exclusions[s->symbol]=1; 
 
169
                      if( s->symbol == GetControlSymbol() ) {
 
170
                        // Do nothing
 
171
                      } 
 
172
                      else if( s->symbol == GetSpaceSymbol() ) {
 
173
                        ulong p=((size_of_slice/2)/ulong(total))*(2*s->count-1);
 
174
                        probs[s->symbol]+=p;
 
175
                        tospend-=p;
 
176
                        exclusions[s->symbol]=1;
 
177
                        if( GetControlSymbol() != -1 )
 
178
                          if( !exclusions[GetControlSymbol()] ) {
 
179
                            ulong p=((size_of_slice/2)/ulong(total))*(2*controlcount-1);
 
180
                            probs[GetControlSymbol()]+=p;
 
181
                            tospend-=p;
 
182
                            exclusions[GetControlSymbol()]=1;
 
183
                          }
 
184
                      }
 
185
                      else {
 
186
                        ulong p=((size_of_slice/2)/ulong(total))*(2*s->count-1);
 
187
                        probs[s->symbol]+=p;
 
188
                        tospend-=p;     
 
189
                        exclusions[s->symbol]=1;
 
190
                      }
 
191
                    }
 
192
                    s=s->next;
 
193
                  }
 
194
                }
 
195
                temp = temp->vine;
 
196
        }
 
197
        //      Usprintf(debug,TEXT("Norm %u tospend %u\n"),Norm,tospend);
 
198
        //      DebugOutput(debug);
 
199
        
 
200
        size_of_slice=tospend;
 
201
        int symbolsleft=0;
 
202
        for (sym=0;sym<modelchars;sym++)
 
203
          if (!probs[sym] && valid[sym])
 
204
            symbolsleft++;
 
205
        for (sym=0;sym<modelchars;sym++) 
 
206
          if (!probs[sym] && valid[sym]) {
 
207
            ulong p=size_of_slice/symbolsleft;
 
208
            probs[sym]+=p;
 
209
            tospend-=p;
 
210
          }
 
211
        
 
212
                        // distribute what's left evenly        
 
213
                //tospend+=uniform;
 
214
 
 
215
//      int current_symbol(0);
 
216
//      while( tospend > 0 )
 
217
//        {
 
218
//          if( valid[current_symbol] ) {
 
219
//              probs[current_symbol] += 1;
 
220
//              tospend -= 1;
 
221
//          }
 
222
 
 
223
//          ++current_symbol;
 
224
//          if( current_symbol == modelchars )
 
225
//            current_symbol = 0;
 
226
//        }
 
227
 
 
228
        int valid_char_count(0);
 
229
 
 
230
        for( int i(0); i < modelchars; ++i )
 
231
          if( valid[i] ) 
 
232
            ++valid_char_count;
 
233
          
 
234
        
 
235
        for (int i(0);i<modelchars;++i) 
 
236
          if( valid[i] ) {
 
237
            ulong p=tospend/(valid_char_count);
 
238
            probs[i] +=p;
 
239
            --valid_char_count;
 
240
            tospend -=p;
 
241
          }
 
242
//                        {
 
243
//                              ulong p=tospend/(modelchars-sym);
 
244
//                              probs[sym]+=p;
 
245
//                              tospend-=p;
 
246
//                        }
 
247
//                      }
 
248
                        //      Usprintf(debug,TEXT("finaltospend %u\n"),tospend);
 
249
                        //      DebugOutput(debug);
 
250
                        
 
251
                        // free(exclusions); // !!!
 
252
                        // !!! NB by IAM: p577 Stroustrup 3rd Edition: "Allocating an object using new and deleting it using free() is asking for trouble"
 
253
        //              delete[] exclusions;
 
254
                        return true;
 
255
}
 
256
 
 
257
 
 
258
void CPPMLanguageModel::AddSymbol(CPPMLanguageModel::CPPMContext &context,int symbol)
 
259
        // add symbol to the context
 
260
        // creates new nodes, updates counts
 
261
        // and leaves 'context' at the new context
 
262
{
 
263
        // sanity check
 
264
        if (symbol==0 || symbol>=GetNumberModelChars())
 
265
                return;
 
266
        
 
267
        CPPMnode *vineptr,*temp;
 
268
        int updatecnt=1;
 
269
        
 
270
        temp=context.head->vine;
 
271
        context.head=context.head->add_symbol_to_node(symbol,&updatecnt);
 
272
        vineptr=context.head;
 
273
        context.order++;
 
274
        
 
275
        while (temp!=0) {
 
276
                vineptr->vine=temp->add_symbol_to_node(symbol,&updatecnt);    
 
277
                vineptr=vineptr->vine;
 
278
                temp=temp->vine;
 
279
        }
 
280
        vineptr->vine=root;
 
281
        if (context.order>MAX_ORDER){
 
282
                context.head=context.head->vine;
 
283
                context.order--;
 
284
        }
 
285
}
 
286
 
 
287
 
 
288
// update context with symbol 'Symbol'
 
289
void CPPMLanguageModel::EnterSymbol(CContext* Context, modelchar Symbol)
 
290
{
 
291
        CPPMLanguageModel::CPPMContext& context = * static_cast<CPPMContext *> (Context);
 
292
        
 
293
        CPPMnode *find;
 
294
        CPPMnode *temp=context.head;
 
295
        
 
296
        while (context.head) {
 
297
                find =context.head->find_symbol(Symbol);
 
298
                if (find) {
 
299
                        context.order++;
 
300
                        context.head=find;
 
301
                        //      Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
 
302
                        //      DebugOutput(debug);
 
303
                        return;
 
304
                }
 
305
                context.order--;
 
306
                context.head=context.head->vine;
 
307
        }
 
308
        
 
309
        if (context.head==0) {
 
310
                context.head=root;
 
311
                context.order=0;
 
312
        }
 
313
        
 
314
}
 
315
 
 
316
 
 
317
void CPPMLanguageModel::LearnSymbol(CContext* Context, modelchar Symbol)
 
318
{
 
319
        CPPMLanguageModel::CPPMContext& context = * static_cast<CPPMContext *> (Context);
 
320
        AddSymbol(context, Symbol);
 
321
}
 
322
 
 
323
 
 
324
void CPPMLanguageModel::dumpSymbol(int symbol)
 
325
{
 
326
        if ((symbol <= 32) || (symbol >= 127))
 
327
                printf( "<%d>", symbol );
 
328
        else
 
329
                printf( "%c", symbol );
 
330
}
 
331
 
 
332
 
 
333
void CPPMLanguageModel::dumpString( char *str, int pos, int len )
 
334
        // Dump the string STR starting at position POS
 
335
{
 
336
        char cc;
 
337
        int p;
 
338
        for (p = pos; p<pos+len; p++) {
 
339
                cc = str [p];
 
340
                if ((cc <= 31) || (cc >= 127))
 
341
                        printf( "<%d>", cc );
 
342
                else
 
343
                        printf( "%c", cc );
 
344
        }
 
345
}
 
346
 
 
347
 
 
348
void CPPMLanguageModel::dumpTrie( CPPMLanguageModel::CPPMnode *t, int d )
 
349
        // diagnostic display of the PPM trie from node t and deeper
 
350
{
 
351
//TODO
 
352
/*
 
353
        dchar debug[256];
 
354
        int sym;
 
355
        CPPMnode *s;
 
356
        Usprintf( debug,TEXT("%5d %7x "), d, t );
 
357
        //TODO: Uncomment this when headers sort out
 
358
        //DebugOutput(debug);
 
359
        if (t < 0) // pointer to input
 
360
                printf( "                     <" );
 
361
        else {
 
362
                Usprintf(debug,TEXT( " %3d %5d %7x  %7x  %7x    <"), t->symbol,t->count, t->vine, t->child, t->next );
 
363
                //TODO: Uncomment this when headers sort out
 
364
                //DebugOutput(debug);
 
365
        }
 
366
        
 
367
        dumpString( dumpTrieStr, 0, d );
 
368
        Usprintf( debug,TEXT(">\n") );
 
369
        //TODO: Uncomment this when headers sort out
 
370
        //DebugOutput(debug);
 
371
        if (t != 0) {
 
372
                s = t->child;
 
373
                while (s != 0) {
 
374
                        sym =s->symbol;
 
375
                        
 
376
                        dumpTrieStr [d] = sym;
 
377
                        dumpTrie( s, d+1 );
 
378
                        s = s->next;
 
379
                }
 
380
        }
 
381
*/
 
382
}
 
383
 
 
384
 
 
385
void CPPMLanguageModel::dump()
 
386
        // diagnostic display of the whole PPM trie
 
387
{
 
388
// TODO:
 
389
/*
 
390
        dchar debug[256];
 
391
        Usprintf(debug,TEXT(  "Dump of Trie : \n" ));
 
392
        //TODO: Uncomment this when headers sort out
 
393
        //DebugOutput(debug);
 
394
        Usprintf(debug,TEXT(   "---------------\n" ));
 
395
        //TODO: Uncomment this when headers sort out
 
396
        //DebugOutput(debug);
 
397
        Usprintf( debug,TEXT(  "depth node     symbol count  vine   child      next   context\n") );
 
398
        //TODO: Uncomment this when headers sort out
 
399
        //DebugOutput(debug);
 
400
        dumpTrie( root, 0 );
 
401
        Usprintf( debug,TEXT(  "---------------\n" ));
 
402
        //TODO: Uncomment this when headers sort out
 
403
        //DebugOutput(debug);
 
404
        Usprintf(debug,TEXT( "\n" ));
 
405
        //TODO: Uncomment this when headers sort out
 
406
        //DebugOutput(debug);
 
407
*/
 
408
}