/*
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.

Author: marmuta <marmvta@gmail.com>
*/
#ifndef LM_DYNAMIC_KN_H
#define LM_DYNAMIC_KN_H

#include "lm_dynamic.h"

//------------------------------------------------------------------------
// BeforeLastNodeKN - second to last node of the ngram trie, bigram for order 3
//------------------------------------------------------------------------
template <class TBASE>
class BeforeLastNodeKNBase : public TBASE
{
    public:
        BeforeLastNodeKNBase(WordId wid = (WordId)-1)
        : TBASE(wid), N1pxr(0)
        {}

        int get_N1pxr() {return N1pxr;}

    public:
        uint32_t N1pxr;   // number of word types wid-n+1 that precede
                          // wid-n+2..wid in the training data
};

//------------------------------------------------------------------------
// TrieNodeKN - node for all lower levels of the ngram trie, unigrams for order 3
//------------------------------------------------------------------------
template <class TBASE>
class TrieNodeKNBase : public TBASE
{
    public:
        TrieNodeKNBase(WordId wid = (WordId)-1)
        : TBASE(wid), N1pxr(0), N1pxrx(0)
        {}

        int get_N1pxrx() {return N1pxrx;}

    public:
        // Naming scheme:
        // N1p: number of word types with count>=1 (1p = "one plus")
        // x:   free running variable over all word types wi
        // r:   remainder, the remaining part of the full ngram
        uint32_t N1pxr;   // number of word types wi-n+1 that precede
                          // wi-n+2..wi in the training data
        uint32_t N1pxrx;  // number of permutations around the center part
};
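
// For illustration (not part of the model itself): with order 3 and the
// training lines "we saw whales" and "we saw gulls", the unigram node for
// "saw" ends up with N1pxr = 1 (only the word type "we" ever precedes
// "saw") and N1pxrx = 2 (two distinct trigrams, "we saw whales" and
// "we saw gulls", have "saw" as their center part).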

//------------------------------------------------------------------------
// NGramTrieKN - root node of the ngram trie
//------------------------------------------------------------------------
template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
class NGramTrieKN : public NGramTrie<TNODE, TBEFORELASTNODE, TLASTNODE>
{
    public:
        typedef NGramTrie<TNODE, TBEFORELASTNODE, TLASTNODE> Base;

        NGramTrieKN(WordId wid = (WordId)-1)
        : Base(wid)
        {}

        int increment_node_count(BaseNode* node, const WordId* wids, int n,
                                 int increment);
        int get_N1pxr(BaseNode* node, int level);
        int get_N1pxrx(BaseNode* node, int level);

        void get_probs_kneser_ney_i(const std::vector<WordId>& history,
                                    const std::vector<WordId>& words,
                                    std::vector<double>& vp,
                                    int num_word_types,
                                    const std::vector<double>& Ds);
};

// Add increment to node->count and incrementally update the Kneser-Ney counts
template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
int NGramTrieKN<TNODE, TBEFORELASTNODE, TLASTNODE>::
    increment_node_count(BaseNode* node, const WordId* wids, int n,
                         int increment)
{
    // Kneser-Ney counts change only the first time each ngram is seen.
    if (increment && node->count == 0)
    {
        // get/add node for ngram (wids) excluding the predecessor
        // ex: ngram = ["We", "saw"] -> wxr = ["saw"] with predecessor "We"
        // Predecessors exist for unigrams and above; the predecessors of
        // unigrams are all unigrams. In that case use the root to store N1pxr.
        std::vector<WordId> wxr(wids+1, wids+n);
        BaseNode* nd = this->add_node(wxr);
        if (!nd)
            return -1;
        ((TBEFORELASTNODE*)nd)->N1pxr++; // count number of word types wid-n+1
                                         // that precede wid-n+2..wid in the
                                         // training data

        // get/add node for ngram (wids) excluding predecessor and successor
        // ex: ngram = ["We", "saw", "whales"] -> wxrx = ["saw"]
        //     with predecessor "We" and successor "whales"
        // Predecessors and successors exist for bigrams and above. wxrx is
        // an empty vector for bigrams; in that case use the root to store
        // N1pxrx.
        if (n >= 2)
        {
            std::vector<WordId> wxrx(wids+1, wids+n-1);
            BaseNode* nd = this->add_node(wxrx);
            if (!nd)
                return -1;
            ((TNODE*)nd)->N1pxrx++; // count number of permutations around
                                    // the center part wid-n+2..wid-1
        }
    }

    return Base::increment_node_count(node, wids, n, increment);
}
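
// For illustration: counting ["We", "saw", "whales"] for the first time,
// besides incrementing the trigram's own count in the base class, bumps
// N1pxr of node ["saw", "whales"] ("We" is a new predecessor) and N1pxrx
// of node ["saw"] (a new permutation "We" .. "whales" around it).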

template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
int NGramTrieKN<TNODE, TBEFORELASTNODE, TLASTNODE>::
    get_N1pxr(BaseNode* node, int level)
{
    if (level == this->order)
        return 0;  // last nodes don't store N1pxr
    if (level == this->order - 1)
        return static_cast<TBEFORELASTNODE*>(node)->N1pxr;
    return static_cast<TNODE*>(node)->N1pxr;
}

template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
int NGramTrieKN<TNODE, TBEFORELASTNODE, TLASTNODE>::
    get_N1pxrx(BaseNode* node, int level)
{
    if (level == this->order)
        return 0;  // last nodes don't store N1pxrx
    if (level == this->order - 1)
        return 0;  // before-last nodes don't store N1pxrx either
    return static_cast<TNODE*>(node)->get_N1pxrx();
}
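
// The probability computation below is interpolated Kneser-Ney, evaluated
// iteratively from order 0 (uniform) up to order n:
//
//   p(w | h) = max(c - D, 0) / s  +  (D * N1prx(h) / s) * p(w | h')
//
// where h' drops the first word of h and D = Ds[j] is the per-order
// discount. At the highest order c is the raw ngram count and s the total
// count of the history; at lower orders c = N1pxr (continuation count)
// and s = N1pxrx.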

// Kneser-Ney smoothed probabilities
template <class TNODE, class TBEFORELASTNODE, class TLASTNODE>
void NGramTrieKN<TNODE, TBEFORELASTNODE, TLASTNODE>::
    get_probs_kneser_ney_i(const std::vector<WordId>& history,
                           const std::vector<WordId>& words,
                           std::vector<double>& vp,
                           int num_word_types,
                           const std::vector<double>& Ds)
{
    // only a fixed history size is allowed; don't remove unknown words
    // from the history, mark them with UNKNOWN_WORD_ID instead.
    ASSERT((int)history.size() == this->order-1);

    int i, j;
    int n = history.size() + 1;
    int size = words.size();        // number of candidate words
    std::vector<int32_t> vc(size);  // vector of counts, reused for orders 1..n

    // order 0: uniform distribution over all word types
    fill(vp.begin(), vp.end(), 1.0/num_word_types);

    // orders 1..n
    for(j=0; j<n; j++)
    {
        std::vector<WordId> h(history.begin()+(n-j-1), history.end()); // tmp history
        BaseNode* hnode = this->get_node(h);
        if (!hnode)
            break;

        int N1prx = this->get_N1prx(hnode, j); // number of word types following the history
        if (!N1prx) // break early, don't reset probabilities to 0
            break;  // for unknown histories

        // orders 1..n-1
        if (j < n-1)
        {
            // Exclude children without predecessors from the count of
            // successors. This corrects normalization errors for the case
            // that the language model wasn't trained from a single
            // continuous stream of tokens, i.e. some tokens don't have
            // successors. This happens by default with the predefined
            // control words <unk>, <s>, ..., but can also happen when
            // incrementally adding text fragments to a language model.
            int num_children = this->get_num_children(hnode, j);
            for(i=0; i<num_children; i++)
            {
                // children here may be of type TrieNode or BeforeLastNode;
                // play safe and cast to the latter.
                TBEFORELASTNODE* child = static_cast<TBEFORELASTNODE*>
                                         (this->get_child_at(hnode, j, i));

                if (child->get_N1pxr() == 0) // no predecessors?
                    N1prx--; // exclude it from the count of successors
            }
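
            // For instance, the sentence-start marker <s> is usually never
            // preceded by another token, so a child node for <s> is
            // excluded from N1prx here.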

            // number of permutations around history h
            int N1pxrx = get_N1pxrx(hnode, j);
            if (N1pxrx)
            {
                // get the number of word types seen to precede history h
                if (h.size() == 0) // empty history?
                {
                    // We're at the root, which has many children: all
                    // unigrams, to be exact. So the number of child nodes
                    // is >= the number of candidate words.
                    // Luckily a child's word_id can be looked up directly
                    // in the unigrams because those are always sorted by
                    // word_id as well. -> take that shortcut for the root.
                    for(i=0; i<size; i++)
                    {
                        //printf("%d %d %d %d %d\n", size, j, i, words[i], (int)ngrams.children.size());
                        TNODE* node = static_cast<TNODE*>(this->children[words[i]]);
                        vc[i] = node->N1pxr;
                    }
                }
                else
                {
                    // We're at some level > 0, where there are very likely
                    // far fewer child nodes than candidate words; everything
                    // from bigrams up has in all likelihood only a few
                    // children. -> Turn the algorithm around and search for
                    // the child nodes among the candidate words.
                    fill(vc.begin(), vc.end(), 0);
                    int num_children = this->get_num_children(hnode, j);
                    for(i=0; i<num_children; i++)
                    {
                        // children here may be of type TrieNode or
                        // BeforeLastNode; play safe and cast to the latter.
                        TBEFORELASTNODE* child = static_cast<TBEFORELASTNODE*>
                                             (this->get_child_at(hnode, j, i));

                        // the candidate words have to be sorted by word_id
                        int index = binsearch(words, child->word_id);
                        if (index >= 0)
                            vc[index] = child->N1pxr;
                    }
                }

                double D = Ds[j];
                double l1 = D / float(N1pxrx) * N1prx; // normalization factor
                for(i=0; i<size; i++)
                {
                    double a = vc[i] - D;
                    if (a < 0)
                        a = 0;
                    vp[i] = a / N1pxrx + l1 * vp[i];
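                    // Interpolation step for the lower orders:
                    // vp[i] = max(N1pxr - D, 0) / N1pxrx
                    //         + (D * N1prx / N1pxrx) * vp[i],
                    // i.e. the discounted continuation probability plus the
                    // weighted estimate of the next lower order.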
                }
            }
        }
        // order n
        else
        {
            // total number of occurrences of the history
            int cs = this->sum_child_counts(hnode, j);
            if (cs)
            {
                // get the ngram counts
                fill(vc.begin(), vc.end(), 0);
                int num_children = this->get_num_children(hnode, j);
                for(i=0; i<num_children; i++)
                {
                    BaseNode* child = this->get_child_at(hnode, j, i);
                    int index = binsearch(words, child->word_id); // candidate words have to be sorted by word_id
                    if (index >= 0)
                        vc[index] = child->get_count();
                }

                double D = Ds[j];
                double l1 = D / float(cs) * N1prx; // normalization factor
                for(i=0; i<size; i++)
                {
                    double a = vc[i] - D;
                    if (a < 0)
                        a = 0;
                    vp[i] = a / float(cs) + l1 * vp[i];
                }
            }
        }
    }
}

//------------------------------------------------------------------------
// DynamicModelKN - dynamically updatable language model with Kneser-Ney support
//------------------------------------------------------------------------
template <class TNGRAMS>
class _DynamicModelKN : public _DynamicModel<TNGRAMS>
{
    public:
        typedef _DynamicModel<TNGRAMS> Base;

        static const Smoothing DEFAULT_SMOOTHING = KNESER_NEY_I;

        _DynamicModelKN()
        {
            this->smoothing = DEFAULT_SMOOTHING;
        }

        virtual std::vector<Smoothing> get_smoothings()
        {
            std::vector<Smoothing> smoothings = Base::get_smoothings();
            smoothings.push_back(KNESER_NEY_I);
            return smoothings;
        }

        virtual void get_node_values(BaseNode* node, int level,
                                     std::vector<int>& values)
        {
            Base::get_node_values(node, level, values);
            values.push_back(this->ngrams.get_N1pxrx(node, level));
            values.push_back(this->ngrams.get_N1pxr(node, level));
        }

        virtual void get_probs(const std::vector<WordId>& history,
                               const std::vector<WordId>& words,
                               std::vector<double>& probabilities);

        virtual int increment_node_count(BaseNode* node, const WordId* wids,
                                         int n, int increment)
        {return this->ngrams.increment_node_count(node, wids, n, increment);}
};

typedef _DynamicModelKN<NGramTrieKN<TrieNode<TrieNodeKNBase<BaseNode> >,
                                    BeforeLastNode<BeforeLastNodeKNBase<BaseNode>,
                                                   LastNode<BaseNode> >,
                                    LastNode<BaseNode> > > DynamicModelKN;
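
// Minimal usage sketch (illustrative only; in practice the word ids come
// from the dictionary handling in lm_dynamic.h, and the candidate ids must
// be sorted ascending for the binsearch above):
//
//   DynamicModelKN model;
//   std::vector<WordId> history;      // e.g. the ids of ["we", "saw"]
//   std::vector<WordId> words;        // candidate word ids, sorted by id
//   std::vector<double> probabilities(words.size());
//   model.get_probs(history, words, probabilities);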

// Calculate a vector of probabilities for the ngrams formed
// from history + word[i], for all i.
// input:  a constant history and a vector of candidate words
// output: a vector of probabilities, one value per candidate word
template <class TNGRAMS>
void _DynamicModelKN<TNGRAMS>::get_probs(const std::vector<WordId>& history,
                                         const std::vector<WordId>& words,
                                         std::vector<double>& probabilities)
{
    // pad/cut the history so that it is always of length order-1
    int n = std::min((int)history.size(), this->order-1);
    std::vector<WordId> h(this->order-1, UNKNOWN_WORD_ID);
    copy_backward(history.end()-n, history.end(), h.end());
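    // Example: with order 3, history ["we", "saw", "whales"] is cut to
    // h = ["saw", "whales"], while history ["whales"] is padded to
    // h = [UNKNOWN_WORD_ID, "whales"].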

    switch(this->smoothing)
    {
        case KNESER_NEY_I:
            this->ngrams.get_probs_kneser_ney_i(h, words, probabilities,
                                                this->get_num_word_types(),
                                                this->Ds);
            break;

        default:
            Base::get_probs(history, words, probabilities);
            break;
    }
}

#endif