~elementary-os/elementaryos/os-patch-onboard-trusty

« back to all changes in this revision

Viewing changes to Onboard/pypredict/lm/lm_dynamic_kn.h

  • Committer: RabbitBot
  • Date: 2014-08-31 20:00:45 UTC
  • Revision ID: rabbitbot@elementaryos.org-20140831200045-guqqu1s80isrm103
Initial import, version 1.0.0-0ubuntu4

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
This program is free software: you can redistribute it and/or modify
 
3
it under the terms of the GNU General Public License as published by
 
4
the Free Software Foundation, either version 3 of the License, or
 
5
(at your option) any later version.
 
6
 
 
7
This program is distributed in the hope that it will be useful,
 
8
but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
GNU General Public License for more details.
 
11
 
 
12
You should have received a copy of the GNU General Public License
 
13
along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
14
 
 
15
Author: marmuta <marmvta@gmail.com>
 
16
*/
 
17
 
 
18
#ifndef LM_DYNAMIC_KN_H
 
19
#define LM_DYNAMIC_KN_H
 
20
 
 
21
#include <assert.h>
 
22
#include "lm_dynamic.h"
 
23
 
 
24
#pragma pack(2)
 
25
 
 
26
 
 
27
//------------------------------------------------------------------------
 
28
// BeforeLastNodeKN - second to last node of the ngram trie, bigram for order 3
 
29
//------------------------------------------------------------------------
 
30
template <class TBASE>
 
31
class BeforeLastNodeKNBase : public TBASE
 
32
{
 
33
    public:
 
34
        BeforeLastNodeKNBase(WordId wid = (WordId)-1)
 
35
        : TBASE(wid)
 
36
        {
 
37
            N1pxr = 0;
 
38
        }
 
39
        int get_N1pxr() {return N1pxr;}
 
40
 
 
41
    public:
 
42
        uint32_t N1pxr;    // number of word types wid-n+1 that precede wid-n+2..wid in the training data
 
43
};
 
44
 
 
45
//------------------------------------------------------------------------
 
46
// TrieNodeKN - node for all lower levels of the ngram trie, unigrams for order 3
 
47
//------------------------------------------------------------------------
 
48
template <class TBASE>
 
49
class TrieNodeKNBase : public TBASE
 
50
{
 
51
    public:
 
52
        TrieNodeKNBase(WordId wid = (WordId)-1)
 
53
        : TBASE(wid)
 
54
        {
 
55
            clear();
 
56
        }
 
57
 
 
58
        void clear()
 
59
        {
 
60
            N1pxr = 0;
 
61
            N1pxrx = 0;
 
62
            TBASE::clear();
 
63
        }
 
64
 
 
65
        int get_N1pxrx()
 
66
        {
 
67
            return N1pxrx;
 
68
        }
 
69
 
 
70
    public:
 
71
        // Nomenclature:
 
72
        // N1p: number of word types with count>=1 (1p=one plus)
 
73
        // x: word, free running variable over all word types wi
 
74
        // r: remainder, remaining part of the full ngram
 
75
        uint32_t N1pxr;    // number of word types wi-n+1 that precede
 
76
                           // wi-n+2..wi in the training data
 
77
        uint32_t N1pxrx;   // number of permutations around center part
 
78
};
 
79
 
 
80
//------------------------------------------------------------------------
 
81
// NGramTrieKN - root node of the ngram trie
 
82
//------------------------------------------------------------------------
 
83
template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
 
84
class NGramTrieKN : public NGramTrie<TNODE, TBEFORELASTNODE, TLASTNODE>
 
85
 
 
86
{
 
87
    private:
 
88
        typedef NGramTrie<TNODE, TBEFORELASTNODE, TLASTNODE> Base;
 
89
 
 
90
    public:
 
91
        NGramTrieKN(WordId wid = (WordId)-1)
 
92
        : Base(wid)
 
93
        {
 
94
        }
 
95
 
 
96
        int increment_node_count(BaseNode* node, const WordId* wids, int n,
 
97
                                  int increment);
 
98
 
 
99
        int get_N1pxr(BaseNode* node, int level);
 
100
        int get_N1pxrx(BaseNode* node, int level);
 
101
 
 
102
        void get_probs_kneser_ney_i(const std::vector<WordId>& history,
 
103
                                    const std::vector<WordId>& words,
 
104
                                    std::vector<double>& vp,
 
105
                                    int num_word_types,
 
106
                                    const std::vector<double>& Ds);
 
107
};
 
108
 
 
109
// Add increment to node->count and incrementally update kneser-ney counts
 
110
template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
 
111
int NGramTrieKN<TNODE, TBEFORELASTNODE, TLASTNODE>::
 
112
    increment_node_count(BaseNode* node, const WordId* wids, int n,
 
113
                         int increment)
 
114
{
 
115
    // only the first time for each ngram
 
116
    if (increment && node->count == 0)
 
117
    {
 
118
        // get/add node for ngram (wids) excluding predecessor
 
119
        // ex: ngram = ["We", "saw"] -> wxr = ["saw"] with predecessor "We"
 
120
        // Predecessors exist for unigrams or greater, predecessor of unigrams
 
121
        // are all unigrams. In that case use the root to store N1pxr.
 
122
        std::vector<WordId> wxr(wids+1, wids+n);
 
123
        BaseNode *nd = this->add_node(wxr);
 
124
        if (!nd)
 
125
            return -1;
 
126
        ((TBEFORELASTNODE*)nd)->N1pxr++; // count number of word types wid-n+1
 
127
                                         // that precede wid-n+2..wid in the
 
128
                                         // training data
 
129
 
 
130
        // get/add node for ngram (wids) excluding predecessor and successor
 
131
        // ex: ngram = ["We", "saw", "whales"] -> wxrx = ["saw"]
 
132
        //     with predecessor "We" and successor "whales"
 
133
        // Predecessors and successors exist for bigrams or greater. wxrx is
 
134
        // an empty vector for bigrams. In that case use the root to store N1pxrx.
 
135
        if (n >= 2)
 
136
        {
 
137
            std::vector<WordId> wxrx(wids+1, wids+n-1);
 
138
            BaseNode* nd = this->add_node(wxrx);
 
139
            if (!nd)
 
140
                return -1;
 
141
            ((TNODE*)nd)->N1pxrx++;  // count number of word types wid-n+1 that precede wid-n+2..wid in the training data
 
142
        }
 
143
    }
 
144
 
 
145
    return Base::increment_node_count(node, wids, n, increment);
 
146
}
 
147
 
 
148
template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
 
149
int NGramTrieKN<TNODE, TBEFORELASTNODE, TLASTNODE>::
 
150
    get_N1pxr(BaseNode* node, int level)
 
151
{
 
152
    if (level == this->order)
 
153
        return 0;
 
154
    if (level == this->order - 1)
 
155
        return static_cast<TBEFORELASTNODE*>(node)->N1pxr;
 
156
    return static_cast<TNODE*>(node)->N1pxr;
 
157
}
 
158
 
 
159
template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
 
160
int NGramTrieKN<TNODE, TBEFORELASTNODE, TLASTNODE>::
 
161
    get_N1pxrx(BaseNode* node, int level)
 
162
{
 
163
    if (level == this->order)
 
164
        return 0;
 
165
    if (level == this->order - 1)
 
166
        return 0;
 
167
    return static_cast<TNODE*>(node)->get_N1pxrx();
 
168
}
 
169
 
 
170
// kneser-ney smoothed probabilities
 
171
template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
 
172
void NGramTrieKN<TNODE, TBEFORELASTNODE, TLASTNODE>::
 
173
     get_probs_kneser_ney_i(const std::vector<WordId>& history,
 
174
                            const std::vector<WordId>& words,
 
175
                            std::vector<double>& vp,
 
176
                            int num_word_types,
 
177
                            const std::vector<double>& Ds)
 
178
{
 
179
    // only fixed history size allowed; don't remove unknown words
 
180
    // from the history, mark them with UNKNOWN_WORD_ID instead.
 
181
    ASSERT((int)history.size() == order-1);
 
182
 
 
183
    int i,j;
 
184
    int n = history.size() + 1;
 
185
    int size = words.size();   // number of candidate words
 
186
    std::vector<int32_t> vc(size);  // vector of counts, reused for order 1..n
 
187
 
 
188
    // order 0
 
189
    vp.resize(size);
 
190
    fill(vp.begin(), vp.end(), 1.0/num_word_types); // uniform distribution
 
191
 
 
192
    // order 1..n
 
193
    for(j=0; j<n; j++)
 
194
    {
 
195
        std::vector<WordId> h(history.begin()+(n-j-1), history.end()); // tmp history
 
196
        BaseNode* hnode = this->get_node(h);
 
197
        if (hnode)
 
198
        {
 
199
            int N1prx = this->get_N1prx(hnode, j);   // number of word types following the history
 
200
            if (!N1prx)  // break early, don't reset probabilities to 0
 
201
                break;   // for unknown histories
 
202
 
 
203
            // orders 1..n-1
 
204
            if (j < n-1)
 
205
            {
 
206
                // Exclude children without predecessor from the count of
 
207
                // successors. This corrects normalization errors for the case
 
208
                // that the language model wasn't trained from a single
 
209
                // continous stream of tokens, i.e. some tokens don't have
 
210
                // successors. This happenes by default with the predefined
 
211
                // control words <unk>, <s>, ..., but can also happen when
 
212
                // incrementally adding text fragments to a language model.
 
213
                int num_children = this->get_num_children(hnode, j);
 
214
                for(i=0; i<num_children; i++)
 
215
                {
 
216
                    // children here may be of type TrieNode or BeforeLastNode,
 
217
                    // play safe and cast to the latter.
 
218
                    TBEFORELASTNODE* child = static_cast<TBEFORELASTNODE*>
 
219
                                    (this->get_child_at(hnode, j, i));
 
220
 
 
221
                    if (child->get_N1pxr() == 0)  // no predecessors?
 
222
                        N1prx--;  // exclude it from the count of successors
 
223
                }
 
224
 
 
225
                // number of permutations around history h
 
226
                int N1pxrx = get_N1pxrx(hnode, j);
 
227
                if (N1pxrx)
 
228
                {
 
229
                    // get number of word types seen to precede history h
 
230
                    if (h.size() == 0) // empty history?
 
231
                    {
 
232
                        // We're at the root and there are many children, all
 
233
                        // unigrams to be accurate. So the number of child nodes
 
234
                        // is >= the number of candidate words.
 
235
                        // Luckily a childs word_id can be directly looked up
 
236
                        // in the unigrams because they are always sorted by word_id
 
237
                        // as well. -> take that shortcut for root.
 
238
                        for(i=0; i<size; i++)
 
239
                        {
 
240
                            //printf("%d %d %d %d %d\n", size, j, i, words[i], (int)ngrams.children.size());
 
241
                            TNODE* node = static_cast<TNODE*>(this->children[words[i]]);
 
242
                            vc[i] = node->N1pxr;
 
243
                        }
 
244
                    }
 
245
                    else
 
246
                    {
 
247
                        // We're at some level > 0 and very likely there are much
 
248
                        // less child nodes than candidate words. E.g. everything
 
249
                        // from bigrams up has in all likelihood only few children.
 
250
                        // -> Turn the algorithm around and search the child nodes
 
251
                        // in the candidate words.
 
252
                        fill(vc.begin(), vc.end(), 0);
 
253
                        int num_children = this->get_num_children(hnode, j);
 
254
                        for(i=0; i<num_children; i++)
 
255
                        {
 
256
                            // children here may be of type TrieNode or BeforeLastNode,
 
257
                            // play safe and cast to the latter.
 
258
                            TBEFORELASTNODE* child = static_cast<TBEFORELASTNODE*>
 
259
                                            (this->get_child_at(hnode, j, i));
 
260
 
 
261
                            // word_indices have to be sorted by index
 
262
                            int index = binsearch(words, child->word_id);
 
263
                            if (index != -1)
 
264
                                vc[index] = child->N1pxr;
 
265
                        }
 
266
                    }
 
267
 
 
268
                    double D = Ds[j];
 
269
                    double l1 = D / float(N1pxrx) * N1prx; // normalization factor
 
270
                                                           // 1 - lambda
 
271
                    for(i=0; i<size; i++)
 
272
                    {
 
273
                        double a = vc[i] - D;
 
274
                        if (a < 0)
 
275
                            a = 0;
 
276
                        vp[i] = a / N1pxrx + l1 * vp[i];
 
277
                    }
 
278
                }
 
279
 
 
280
            }
 
281
            // order n
 
282
            else
 
283
            {
 
284
                // total number of occurences of the history
 
285
                int cs = this->sum_child_counts(hnode, j);
 
286
                if (cs)
 
287
                {
 
288
                    // get ngram counts
 
289
                    fill(vc.begin(), vc.end(), 0);
 
290
                    int num_children = this->get_num_children(hnode, j);
 
291
                    for(i=0; i<num_children; i++)
 
292
                    {
 
293
                        BaseNode* child = this->get_child_at(hnode, j, i);
 
294
                        int index = binsearch(words, child->word_id); // word_indices have to be sorted by index
 
295
                        if (index >= 0)
 
296
                            vc[index] = child->get_count();
 
297
                    }
 
298
 
 
299
                    double D = Ds[j];
 
300
                    double l1 = D / float(cs) * N1prx; // normalization factor
 
301
                                                           // 1 - lambda
 
302
                    for(i=0; i<size; i++)
 
303
                    {
 
304
                        double a = vc[i] - D;
 
305
                        if (a < 0)
 
306
                            a = 0;
 
307
                        vp[i] = a / float(cs) + l1 * vp[i];
 
308
                    }
 
309
                }
 
310
            }
 
311
        }
 
312
    }
 
313
}
 
314
#pragma pack()
 
315
 
 
316
 
 
317
//------------------------------------------------------------------------
 
318
// DynamicModelKN - dynamically updatable language model with kneser-ney support
 
319
//------------------------------------------------------------------------
 
320
template <class TNGRAMS>
 
321
class _DynamicModelKN : public _DynamicModel<TNGRAMS>
 
322
{
 
323
    public:
 
324
        typedef _DynamicModel<TNGRAMS> Base;
 
325
 
 
326
        static const Smoothing DEFAULT_SMOOTHING = KNESER_NEY_I;
 
327
 
 
328
    public:
 
329
        _DynamicModelKN()
 
330
        {
 
331
            this->smoothing = DEFAULT_SMOOTHING;
 
332
        }
 
333
 
 
334
        virtual std::vector<Smoothing> get_smoothings()
 
335
        {
 
336
            std::vector<Smoothing> smoothings = Base::get_smoothings();
 
337
            smoothings.push_back(KNESER_NEY_I);
 
338
            return smoothings;
 
339
        }
 
340
 
 
341
        virtual void get_node_values(BaseNode* node, int level,
 
342
                                    std::vector<int>& values)
 
343
        {
 
344
            Base::get_node_values(node, level, values);
 
345
            values.push_back(this->ngrams.get_N1pxrx(node, level));
 
346
            values.push_back(this->ngrams.get_N1pxr(node, level));
 
347
        }
 
348
 
 
349
    protected:
 
350
        virtual void get_probs(const std::vector<WordId>& history,
 
351
                                    const std::vector<WordId>& words,
 
352
                                    std::vector<double>& probabilities);
 
353
 
 
354
    private:
 
355
        virtual int increment_node_count(BaseNode* node, const WordId* wids,
 
356
                                         int n, int increment)
 
357
        {return this->ngrams.increment_node_count(node, wids, n, increment);}
 
358
};
 
359
 
 
360
typedef _DynamicModelKN<NGramTrieKN<TrieNode<TrieNodeKNBase<BaseNode> >,
 
361
                                  BeforeLastNode<BeforeLastNodeKNBase<BaseNode>,
 
362
                                                 LastNode<BaseNode> >,
 
363
                                  LastNode<BaseNode> > > DynamicModelKN;
 
364
 
 
365
// Calculate a vector of probabilities for the ngrams formed
 
366
// by history + word[i], for all i.
 
367
// input:  constant history and a vector of candidate words
 
368
// output: vector of probabilities, one value per candidate word
 
369
template <class TNGRAMS>
 
370
void _DynamicModelKN<TNGRAMS>::get_probs(const std::vector<WordId>& history,
 
371
                                         const std::vector<WordId>& words,
 
372
                                         std::vector<double>& probabilities)
 
373
{
 
374
    // pad/cut history so it's always of length order-1
 
375
    int n = std::min((int)history.size(), this->order-1);
 
376
    std::vector<WordId> h(this->order-1, UNKNOWN_WORD_ID);
 
377
    copy_backward(history.end()-n, history.end(), h.end());
 
378
 
 
379
    switch(this->smoothing)
 
380
    {
 
381
        case KNESER_NEY_I:
 
382
            this->ngrams.get_probs_kneser_ney_i(h, words, probabilities,
 
383
                                          this->get_num_word_types(), this->Ds);
 
384
            break;
 
385
 
 
386
        default:
 
387
            Base::get_probs(history, words, probabilities);
 
388
            break;
 
389
    }
 
390
}
 
391
 
 
392
#endif