2
2
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
4
4
* Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
6
6
* The contents of this file are subject to the terms of either the GNU Lesser
7
7
* General Public License Version 2.1 only ("LGPL") or the Common Development and
8
8
* Distribution License ("CDDL")(collectively, the "License"). You may not use this
9
9
* file except in compliance with the License. You can obtain a copy of the CDDL at
10
10
* http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
11
* http://www.opensource.org/licenses/lgpl-license.php. See the License for the
11
* http://www.opensource.org/licenses/lgpl-license.php. See the License for the
12
12
* specific language governing permissions and limitations under the License. When
13
13
* distributing the software, include this License Header Notice in each file and
14
14
* include the full text of the License in the License file as well as the
15
15
* following notice:
17
17
* NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
19
19
* For Covered Software in this distribution, this License shall be governed by the
64
TNodeInfo(double distance=0.0, int pos=0, bool children=0) : d(distance)
65
{ idx = pos; child = (children==0)?0:1; }
67
bool operator< (const TNodeInfo& r) const
68
{ return ((child ^ r.child) == 0)?(d < r.d):(child == 0); }
70
bool operator==(const TNodeInfo& r) const
71
{ return (child == r.child && d == r.d); }
64
TNodeInfo(double distance = 0.0, int pos = 0, bool children =
67
idx = pos; child = (children == 0) ? 0 : 1;
71
operator<(const TNodeInfo& r) const
73
return ((child ^ r.child) == 0) ? (d < r.d) : (child == 0);
77
operator==(const TNodeInfo& r) const
79
return(child == r.child && d == r.d);
74
83
class CSlmPruner : public CSIMSlm {
76
85
CSlmPruner() : CSIMSlm(), cut(NULL)
80
{ if (cut) delete [] cut; }
91
if (cut) delete [] cut;
82
94
void SetCut(int* nCut);
83
95
void SetReserve(int* nReserve);
95
107
double cache_PA, cache_PB;
98
void CSlmPruner::Prune()
100
113
printf("Erasing items using Entropy distance"); fflush(stdout);
101
for (int lvl=N; lvl>0; --lvl)
114
for (int lvl = N; lvl > 0; --lvl)
103
116
printf("\n"); fflush(stdout);
106
void CSlmPruner::Write(const char* filename)
120
CSlmPruner::Write(const char* filename)
108
122
FILE* out = fopen(filename, "wb");
109
123
fwrite(&N, sizeof(N), 1, out);
110
124
fwrite(&bUseLogPr, sizeof(bUseLogPr), 1, out);
111
fwrite(sz, sizeof(int), N+1, out);
112
for (int i=0; i<N; ++i) {
125
fwrite(sz, sizeof(int), N + 1, out);
126
for (int i = 0; i < N; ++i) {
113
127
fwrite(level[i], sizeof(TNode), sz[i], out);
115
129
fwrite(level[N], sizeof(TLeaf), sz[N], out);
119
void CSlmPruner::SetReserve(int* nReserve)
134
CSlmPruner::SetReserve(int* nReserve)
136
cut = new int [N + 1];
123
for (int lvl=1; lvl<=N; ++lvl) {
138
for (int lvl = 1; lvl <= N; ++lvl) {
124
139
cut[lvl] = sz[lvl] - 1 - nReserve[lvl];
125
140
if (cut[lvl] < 0) cut[lvl] = 0;
129
void CSlmPruner::SetCut(int* nCut)
145
CSlmPruner::SetCut(int* nCut)
147
cut = new int [N + 1];
133
for (int lvl=1; lvl<=N; ++lvl)
149
for (int lvl = 1; lvl <= N; ++lvl)
134
150
cut[lvl] = nCut[lvl];
137
153
template <class chIterator>
138
int CutLevel(CSIMSlm::TNode* pfirst, CSIMSlm::TNode* plast, chIterator chfirst, chIterator chlast, bool bUseLogPr)
155
CutLevel(CSIMSlm::TNode* pfirst,
156
CSIMSlm::TNode* plast,
140
int idxfirst, idxchk;
141
chIterator chchk = chfirst;
142
for (idxfirst=idxchk=0; chchk != chlast; ++chchk, ++idxchk) {
161
int idxfirst, idxchk;
162
chIterator chchk = chfirst;
163
for (idxfirst = idxchk = 0; chchk != chlast; ++chchk, ++idxchk) {
143
164
//cut item whoese pr == 1.0; and not psuedo tail
144
if (chchk->pr != ((bUseLogPr)?0.0:1.0) || (chchk+1) == chlast) {
165
if (chchk->pr != ((bUseLogPr) ? 0.0 : 1.0) || (chchk + 1) == chlast) {
145
166
if (idxfirst < idxchk) *chfirst = *chchk;
146
167
while (pfirst != plast && pfirst->child <= idxchk)
147
168
pfirst++->child = idxfirst;
155
void CSlmPruner::PruneLevel(int lvl)
177
CSlmPruner::PruneLevel(int lvl)
157
179
cache_level = cache_idx = -1;
159
181
if (cut[lvl] <= 0) {
160
printf("\n Level %d (%d items), no need to cut as your command!", lvl, sz[lvl]-1); fflush(stdout);
182
printf("\n Level %d (%d items), no need to cut as your command!",
184
sz[lvl] - 1); fflush(stdout);
164
printf("\n Level %d (%d items), allocating...", lvl, sz[lvl]-1); fflush(stdout);
188
printf("\n Level %d (%d items), allocating...", lvl, sz[lvl] - 1); fflush(
166
191
int n = sz[lvl] - 1; //do not count last psuedo tail
167
if (cut[lvl] >= n) cut[lvl] = n-1;
192
if (cut[lvl] >= n) cut[lvl] = n - 1;
168
193
TNodeInfo* pbuf = new TNodeInfo[n];
169
194
TSIMWordId hw[16]; // it should be lvl+1, yet some compiler do not support it
170
195
int idx[16]; // it should be lvl+1, yet some compiler do not support it
172
197
printf(", Calculating..."); fflush(stdout);
173
for (int i=0; i <=lvl; ++i)
198
for (int i = 0; i <= lvl; ++i)
175
200
while (idx[lvl] < n) {
177
hw[lvl] = (((TLeaf*)level[lvl])+idx[lvl])->id;
202
hw[lvl] = (((TLeaf*)level[lvl]) + idx[lvl])->id;
179
hw[lvl] = (((TNode*)level[lvl])+idx[lvl])->id;
204
hw[lvl] = (((TNode*)level[lvl]) + idx[lvl])->id;
181
for (int j=lvl-1; j >= 0; --j) {
182
TNode* pnode = ((TNode*)level[j])+idx[j];
183
for (; (pnode+1)->child <= idx[j+1]; ++pnode, ++idx[j])
206
for (int j = lvl - 1; j >= 0; --j) {
207
TNode* pnode = ((TNode*)level[j]) + idx[j];
208
for (; (pnode + 1)->child <= idx[j + 1]; ++pnode, ++idx[j])
185
210
hw[j] = pnode->id;
187
212
bool has_child = false;
189
214
TNode* pn = ((TNode*)level[lvl]) + idx[lvl];
190
if ((pn+1)->child > pn->child)
215
if ((pn + 1)->child > pn->child)
191
216
has_child = true;
193
pbuf[idx[lvl]].child = (has_child)?1:0;
218
pbuf[idx[lvl]].child = (has_child) ? 1 : 0;
194
219
pbuf[idx[lvl]].idx = idx[lvl];
196
221
pbuf[idx[lvl]].d = CalcDistance(lvl, idx, hw);
199
224
printf(", sorting...");
200
std::make_heap(pbuf, pbuf+n);
201
std::sort_heap(pbuf, pbuf+n);
225
std::make_heap(pbuf, pbuf + n);
226
std::sort_heap(pbuf, pbuf + n);
204
229
// because pr in model can not be 1.0, so we use this to mark a item to be prune
205
for (TNodeInfo* pinfo = pbuf; k < cut[lvl] && pinfo->child == 0; ++k, ++pinfo) {
230
for (TNodeInfo* pinfo = pbuf;
231
k < cut[lvl] && pinfo->child == 0;
208
(((TLeaf*)level[lvl]) + pinfo->idx)->pr = 0.0; // -log(1.0)
235
(((TLeaf*)level[lvl]) + pinfo->idx)->pr = 0.0; // -log(1.0)
210
237
(((TLeaf*)level[lvl]) + pinfo->idx)->pr = 1.0;
213
(((TNode*)level[lvl]) + pinfo->idx)->pr = 0.0; // -log(1.0)
240
(((TNode*)level[lvl]) + pinfo->idx)->pr = 0.0; // -log(1.0)
215
(((TNode*)level[lvl]) + pinfo->idx)->pr = 1.0; // -log(1.0)
242
(((TNode*)level[lvl]) + pinfo->idx)->pr = 1.0; // -log(1.0)
218
245
printf("(cut %d items), build parent ptr...", k); fflush(stdout);
220
k = CutLevel((TNode*)level[lvl-1], ((TNode*)level[lvl-1])+sz[lvl-1], (TLeaf*)level[lvl], ((TLeaf*)level[lvl])+sz[lvl], bUseLogPr);
248
CutLevel((TNode*)level[lvl - 1],
249
((TNode*)level[lvl - 1]) + sz[lvl - 1],
251
((TLeaf*)level[lvl]) + sz[lvl],
222
k = CutLevel((TNode*)level[lvl-1], ((TNode*)level[lvl-1])+sz[lvl-1], (TNode*)level[lvl], ((TNode*)level[lvl])+sz[lvl], bUseLogPr);
255
CutLevel((TNode*)level[lvl - 1],
256
((TNode*)level[lvl - 1]) + sz[lvl - 1],
258
((TNode*)level[lvl]) + sz[lvl],
224
261
sz[lvl] = k; //k is new size
238
281
sumnext += exp(-double(chh->pr));
240
283
sumnext += double(chh->pr);
241
words[lvl+1] = chh->id;
242
sum += pruner->getPr(lvl, words+2);
284
words[lvl + 1] = chh->id;
285
sum += pruner->getPr(lvl, words + 2);
244
287
assert(sumnext >= 0.0 && sumnext < 1.0);
245
288
assert(sum >= 0.0 && sum < 1.0);
246
return (1.0-sumnext)/(1.0-sum);
289
return (1.0 - sumnext) / (1.0 - sum);
249
void CSlmPruner::CalcBOW()
293
CSlmPruner::CalcBOW()
251
295
printf("\nUpdating back-off weight"); fflush(stdout);
252
for (int lvl=0; lvl < N; ++lvl) {
296
for (int lvl = 0; lvl < N; ++lvl) {
253
297
printf("\n Level %d...", lvl); fflush(stdout);
254
298
TNode* base[16]; //it should be lvl+1, yet some compiler do not support it
255
299
int idx[16]; //it should be lvl+1, yet some compiler do not support it
256
for (int i=0; i <= lvl; ++i) {
300
for (int i = 0; i <= lvl; ++i) {
257
301
base[i] = (TNode*)level[i];
260
304
TSIMWordId words[17]; //it should be lvl+2, yet some compiler do not support it
261
for (int lsz = sz[lvl]-1; idx[lvl] < lsz; ++idx[lvl]) {
305
for (int lsz = sz[lvl] - 1; idx[lvl] < lsz; ++idx[lvl]) {
262
306
words[lvl] = base[lvl][idx[lvl]].id;
263
for (int k=lvl-1; k >= 0; --k) {
264
while (base[k][idx[k]+1].child <= idx[k+1])
307
for (int k = lvl - 1; k >= 0; --k) {
308
while (base[k][idx[k] + 1].child <= idx[k + 1])
266
310
words[k] = base[k][idx[k]].id;
268
312
TNode & node = base[lvl][idx[lvl]];
269
TNode & nodenext = *((&node)+1);
313
TNode & nodenext = *((&node) + 1);
271
315
double bow = 1.0;
273
TLeaf* ch = (TLeaf*)level[lvl+1];
274
bow = CalcNodeBow(this, lvl, words, &(ch[node.child]), &(ch[nodenext.child]), bUseLogPr);
317
TLeaf* ch = (TLeaf*)level[lvl + 1];
319
CalcNodeBow(this, lvl, words, &(ch[node.child]),
320
&(ch[nodenext.child]), bUseLogPr);
276
TNode* ch = (TNode*)level[lvl+1];
277
bow = CalcNodeBow(this, lvl, words, &(ch[node.child]), &(ch[nodenext.child]), bUseLogPr);
322
TNode* ch = (TNode*)level[lvl + 1];
324
CalcNodeBow(this, lvl, words, &(ch[node.child]),
325
&(ch[nodenext.child]), bUseLogPr);
280
328
node.bow = PR_TYPE(-log(bow));
285
333
printf("\n"); fflush(stdout);
288
double CSlmPruner::CalcDistance(int lvl, int* idx, TSIMWordId* hw)
337
CSlmPruner::CalcDistance(int lvl, int* idx, TSIMWordId* hw)
290
339
double PA, PB, PHW, PH_W, PH, BOW, _BOW, pr, p_r;
291
340
TSIMWordId w = hw[lvl];
294
TNode* parent = ((TNode*)level[lvl-1])+idx[lvl-1];
343
TNode* parent = ((TNode*)level[lvl - 1]) + idx[lvl - 1];
296
345
BOW = exp(-double(parent->bow)); //Fix original bug to use the BOW directly
298
347
BOW = double(parent->bow);
300
for (int i=1; i < lvl; ++i)
301
PH *= getPr(i, hw+1+(lvl-i));
302
assert(PH <= 1.0 && PH >0.0);
349
for (int i = 1; i < lvl; ++i)
350
PH *= getPr(i, hw + 1 + (lvl - i));
351
assert(PH <= 1.0 && PH > 0.0);
306
PHW = exp(-((((TLeaf*)level[lvl])+idx[lvl])->pr));
355
PHW = exp(-((((TLeaf*)level[lvl]) + idx[lvl])->pr));
308
PHW = ((((TLeaf*)level[lvl])+idx[lvl])->pr);
309
assert(w == (((TLeaf*)level[lvl])+idx[lvl])->id);
357
PHW = ((((TLeaf*)level[lvl]) + idx[lvl])->pr);
358
assert(w == (((TLeaf*)level[lvl]) + idx[lvl])->id);
312
PHW = exp(-((((TNode*)level[lvl])+idx[lvl])->pr));
361
PHW = exp(-((((TNode*)level[lvl]) + idx[lvl])->pr));
314
PHW = ((((TNode*)level[lvl])+idx[lvl])->pr);
315
assert(w == (((TNode*)level[lvl])+idx[lvl])->id);
363
PHW = ((((TNode*)level[lvl]) + idx[lvl])->pr);
364
assert(w == (((TNode*)level[lvl]) + idx[lvl])->id);
318
PH_W = getPr(lvl-1, hw+2);
366
PH_W = getPr(lvl - 1, hw + 2);
319
367
assert(PHW > 0.0 && PHW < 1.0);
320
368
assert(PH_W > 0.0 && PH_W < 1.0);
322
if (cache_level != lvl-1 || cache_idx != idx[lvl-1]) {
324
cache_idx = idx[lvl-1];
370
if (cache_level != lvl - 1 || cache_idx != idx[lvl - 1]) {
371
cache_level = lvl - 1;
372
cache_idx = idx[lvl - 1];
325
373
cache_PA = cache_PB = 1.0;
326
for (int h=parent->child, t = (parent+1)->child; h<t; ++h) {
374
for (int h = parent->child, t = (parent + 1)->child; h < t; ++h) {
330
pr = exp(-((((TLeaf*)level[lvl])+h)->pr));
378
pr = exp(-((((TLeaf*)level[lvl]) + h)->pr));
332
pr = ((((TLeaf*)level[lvl])+h)->pr);
333
id = (((TLeaf*)level[lvl])+h)->id;
380
pr = ((((TLeaf*)level[lvl]) + h)->pr);
381
id = (((TLeaf*)level[lvl]) + h)->id;
337
pr = exp(-((((TNode*)level[lvl])+h)->pr));
384
pr = exp(-((((TNode*)level[lvl]) + h)->pr));
339
pr = ((((TNode*)level[lvl])+h)->pr);
340
id = (((TNode*)level[lvl])+h)->id;
386
pr = ((((TNode*)level[lvl]) + h)->pr);
387
id = (((TNode*)level[lvl]) + h)->id;
343
389
assert(pr > 0.0 && pr < 1.0);
347
p_r = getPr(lvl-1, hw+2); // Fix bug from pr = getPr(lvl-1, hw+1)
393
p_r = getPr(lvl - 1, hw + 2); // Fix bug from pr = getPr(lvl-1, hw+1)
348
394
assert(p_r > 0.0 && p_r < 1.0);
351
397
assert(cache_PA > -0.01 && cache_PB > -0.01);
352
398
if (cache_PA < 0.00001 || cache_PB < 0.00001) {
353
printf("\n precision problem on %d gram:", lvl-1);
354
for (int i=1; i < lvl; ++i) printf("%d ", idx[i]);
399
printf("\n precision problem on %d gram:", lvl - 1);
400
for (int i = 1; i < lvl; ++i) printf("%d ", idx[i]);
356
402
if (cache_PA < 0.00001) {
357
403
printf("{1.0 - sigma p(w|h)} ==> 0.00001");
369
_BOW = (PA+PHW) / (PB+PH_W); // Fix bug from "(1.0-PA+PHW)/(1.0-PB+PH_W);"
415
_BOW = (PA + PHW) / (PB + PH_W); // Fix bug from "(1.0-PA+PHW)/(1.0-PB+PH_W);"
371
417
assert(BOW > 0.0);
372
418
assert(_BOW > 0.0);
373
assert(PA+PHW < 1.01); // %1 error rate
374
assert(PB+PH_W < 1.01); // %1 error rate
419
assert(PA + PHW < 1.01); // %1 error rate
420
assert(PB + PH_W < 1.01); // %1 error rate
377
423
* PH = P(h), PHW = P(w|h), PH_W = P(w|h'), _BOW = bow'(h) (the new bow)
378
424
* BOW = bow(h) (the original bow), PA = sum_{w_i:C(w_i,h)=0} P(w_i|h),
379
425
* PB = sum_{w_i:C(w_i,h)=0} P(w_i|h')
381
return -(PH * (PHW * (log(PH_W)+log(_BOW)-log(PHW)) + PA * (log(_BOW)-log(BOW)) ));
429
(log(PH_W) + log(_BOW) - log(PHW)) + PA * (log(_BOW) - log(BOW))));
386
435
printf("Usage:\n");
387
436
printf(" slmprune input_slm result_slm [R|C] num1 num2...\n");
388
437
printf("\nDescription:\n");
390
440
This program uses entropy-based method to prune the size of back-off \n\
391
441
language model 'input_slm' to a specific size and write to 'result_slm'. \n\
392
442
the third parameter [R|C] means the following numbers is the number for\n\