211
214
* PermutohedralLattice::filter(...) does all the work. *
213
216
******************************************************************/
217
template <int D, int VD>
214
218
class PermutohedralLattice
219
/* Filters given image against a reference image.
220
* im : image to be filtered.
221
* ref : reference image whose edges are to be respected.
223
static Image filter(Image im, Image ref)
228
PermutohedralLattice lattice(ref.channels, im.channels+1, im.width*im.height*im.frames);
230
// Splat into the lattice
231
//printf("Splatting...\n");
233
float *imPtr = im(0, 0, 0);
234
float *refPtr = ref(0, 0, 0);
235
for (int t = 0; t < im.frames; t++)
237
for (int y = 0; y < im.height; y++)
239
for (int x = 0; x < im.width; x++)
241
lattice.splat(refPtr, imPtr);
242
refPtr += ref.channels;
243
imPtr += im.channels;
249
//printf("Blurring...\n");
252
// Slice from the lattice
253
//printf("Slicing...\n");
255
Image out(im.width, im.height, im.frames, im.channels);
257
lattice.beginSlice();
258
float *outPtr = out(0, 0, 0);
259
for (int t = 0; t < im.frames; t++)
261
for (int y = 0; y < im.height; y++)
263
for (int x = 0; x < im.width; x++)
265
lattice.slice(outPtr);
266
outPtr += out.channels;
276
222
* d_ : dimensionality of key vectors
277
223
* vd_ : dimensionality of value vectors
278
224
* nData_ : number of points in the input
280
PermutohedralLattice(int d_, int vd_, int nData_) :
281
d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_)
226
PermutohedralLattice(int nData_, int nThreads_=1) :
227
nData(nData_), nThreads(nThreads_)
284
230
// Allocate storage for various arrays
285
elevated = new float[d+1];
286
scaleFactor = new float[d];
231
float *scaleFactorTmp = new float[D];
232
int *canonicalTmp = new int[(D+1)*(D+1)];
288
greedy = new short[d+1];
289
rank = new char[d+1];
290
barycentric = new float[d+2];
291
replay = new ReplayEntry[nData*(d+1)];
293
canonical = new short[(d+1)*(d+1)];
294
key = new short[d+1];
234
replay = new ReplayEntry[nData*(D+1)];
296
236
// compute the coordinates of the canonical simplex, in which
297
237
// the difference between a contained point and the zero
298
238
// remainder vertex is always in ascending order. (See pg.4 of paper.)
299
for (int i = 0; i <= d; i++)
239
for (int i = 0; i <= D; i++)
301
for (int j = 0; j <= d-i; j++)
302
canonical[i*(d+1)+j] = i;
303
for (int j = d-i+1; j <= d; j++)
304
canonical[i*(d+1)+j] = i - (d+1);
241
for (int j = 0; j <= D-i; j++)
242
canonicalTmp[i*(D+1)+j] = i;
243
for (int j = D-i+1; j <= D; j++)
244
canonicalTmp[i*(D+1)+j] = i - (D+1);
246
canonical = canonicalTmp;
307
248
// Compute parts of the rotation matrix E. (See pg.4-5 of paper.)
308
for (int i = 0; i < d; i++)
249
for (int i = 0; i < D; i++)
310
251
// the diagonal entries for normalization
311
scaleFactor[i] = 1.0f/(sqrtf((float)(i+1)*(i+2)));
252
scaleFactorTmp[i] = 1.0f/(sqrtf((float)(i+1)*(i+2)));
313
254
/* We presume that the user would like to do a Gaussian blur of standard deviation
314
255
* 1 in each dimension (or a total variance of d, summed over dimensions.)
323
264
* So we need to scale the space by (d+1)sqrt(2/3).
325
scaleFactor[i] *= (d+1)*sqrtf(2.0/3);
266
scaleFactorTmp[i] *= (D+1)*sqrtf(2.0/3);
268
scaleFactor = scaleFactorTmp;
270
hashTables = new HashTablePermutohedral<D,VD>[nThreads];
330
274
~PermutohedralLattice()
332
276
delete[] scaleFactor;
336
delete[] barycentric;
338
278
delete[] canonical;
343
283
/* Performs splatting with given position and value vectors */
344
void splat(float *position, float *value)
284
void splat(float *position, float *value, int replay_index, int thread_index=0)
289
float barycentric[D+2];
347
292
// first rotate position into the (d+1)-dimensional hyperplane
348
elevated[d] = -d*position[d-1]*scaleFactor[d-1];
349
for (int i = d-1; i > 0; i--)
293
elevated[D] = -D*position[D-1]*scaleFactor[D-1];
294
for (int i = D-1; i > 0; i--)
350
295
elevated[i] = (elevated[i+1] -
351
296
i*position[i-1]*scaleFactor[i-1] +
352
297
(i+2)*position[i]*scaleFactor[i]);
353
298
elevated[0] = elevated[1] + 2*position[0]*scaleFactor[0];
355
300
// prepare to find the closest lattice points
356
float scale = 1.0f/(d+1);
357
char * myrank = rank;
358
short * mygreedy = greedy;
301
float scale = 1.0f/(D+1);
360
303
// greedily search for the closest zero-colored lattice point
362
for (int i = 0; i <= d; i++)
305
for (int i = 0; i <= D; i++)
364
307
float v = elevated[i]*scale;
365
float up = ceilf(v)*(d+1);
366
float down = floorf(v)*(d+1);
368
if (up - elevated[i] < elevated[i] - down) mygreedy[i] = (short)up;
369
else mygreedy[i] = (short)down;
308
float up = ceilf(v)*(D+1);
309
float down = floorf(v)*(D+1);
311
if (up - elevated[i] < elevated[i] - down) greedy[i] = up;
312
else greedy[i] = down;
375
318
// rank differential to find the permutation between this simplex and the canonical one.
376
319
// (See pg. 3-4 in paper.)
377
memset(myrank, 0, sizeof(char)*(d+1));
378
for (int i = 0; i < d; i++)
379
for (int j = i+1; j <= d; j++)
380
if (elevated[i] - mygreedy[i] < elevated[j] - mygreedy[j]) myrank[i]++;
320
memset(rank, 0, sizeof rank);
321
for (int i = 0; i < D; i++)
322
for (int j = i+1; j <= D; j++)
323
if (elevated[i] - greedy[i] < elevated[j] - greedy[j]) rank[i]++;
385
328
// sum too large - the point is off the hyperplane.
386
329
// need to bring down the ones with the smallest differential
387
for (int i = 0; i <= d; i++)
330
for (int i = 0; i <= D; i++)
389
if (myrank[i] >= d + 1 - sum)
332
if (rank[i] >= D + 1 - sum)
392
myrank[i] += sum - (d+1);
335
rank[i] += sum - (D+1);
398
341
else if (sum < 0)
400
343
// sum too small - the point is off the hyperplane
401
344
// need to bring up the ones with largest differential
402
for (int i = 0; i <= d; i++)
345
for (int i = 0; i <= D; i++)
404
if (myrank[i] < -sum)
407
myrank[i] += (d+1) + sum;
350
rank[i] += (D+1) + sum;
414
357
// Compute barycentric coordinates (See pg.10 of paper.)
415
memset(barycentric, 0, sizeof(float)*(d+2));
416
for (int i = 0; i <= d; i++)
358
memset(barycentric, 0, sizeof barycentric);
359
for (int i = 0; i <= D; i++)
418
barycentric[d-myrank[i]] += (elevated[i] - mygreedy[i]) * scale;
419
barycentric[d+1-myrank[i]] -= (elevated[i] - mygreedy[i]) * scale;
361
barycentric[D-rank[i]] += (elevated[i] - greedy[i]) * scale;
362
barycentric[D+1-rank[i]] -= (elevated[i] - greedy[i]) * scale;
421
barycentric[0] += 1.0f + barycentric[d+1];
364
barycentric[0] += 1.0f + barycentric[D+1];
423
366
// Splat the value into each vertex of the simplex, with barycentric weights.
424
for (int remainder = 0; remainder <= d; remainder++)
367
for (int remainder = 0; remainder <= D; remainder++)
426
369
// Compute the location of the lattice point explicitly (all but the last coordinate - it's redundant because they sum to zero)
427
for (int i = 0; i < d; i++)
428
key[i] = mygreedy[i] + canonical[remainder*(d+1) + myrank[i]];
370
for (int i = 0; i < D; i++)
371
key[i] = greedy[i] + canonical[remainder*(D+1) + rank[i]];
430
373
// Retrieve pointer to the value at this vertex.
431
float * val = hashTable.lookup(key, true);
374
float * val = hashTables[thread_index].lookup(key, true);
433
376
// Accumulate values with barycentric weight.
434
for (int i = 0; i < vd; i++)
377
for (int i = 0; i < VD; i++)
435
378
val[i] += barycentric[remainder]*value[i];
437
380
// Record this interaction to use later when slicing
438
replay[nReplay].offset = val - hashTable.getValues();
439
replay[nReplay].weight = barycentric[remainder];
381
replay[replay_index*(D+1)+remainder].table = thread_index;
382
replay[replay_index*(D+1)+remainder].offset = val - hashTables[thread_index].getValues();
383
replay[replay_index*(D+1)+remainder].weight = barycentric[remainder];
445
// Prepare for slicing
387
/* Merge the multiple threads' hash tables into the totals. */
388
void merge_splat_threads(void)
393
/* Merge the multiple hash tables into one, creating an offset remap table. */
394
int *offset_remap[nThreads];
395
for (int i = 1; i < nThreads; i++)
397
const short *oldKeys = hashTables[i].getKeys();
398
const float *oldVals = hashTables[i].getValues();
399
const int filled = hashTables[i].size();
400
offset_remap[i] = new int[filled];
401
for (int j = 0; j < filled; j++)
403
float *val = hashTables[0].lookup(oldKeys+j*D, true);
404
const float *oldVal = oldVals + j*VD;
405
for (int k = 0; k < VD; k++)
407
offset_remap[i][j] = val - hashTables[0].getValues();
411
/* Rewrite the offsets in the replay structure from the above generated table. */
412
for (int i = 0; i < nData*(D+1); i++)
413
if (replay[i].table > 0)
414
replay[i].offset = offset_remap[replay[i].table][replay[i].offset/VD];
416
for (int i = 1; i < nThreads; i++)
417
delete[] offset_remap[i];
451
420
/* Performs slicing out of position vectors. Note that the barycentric weights and the simplex
452
421
* containing each position vector were calculated and stored in the splatting step.
453
422
* We may reuse this to accelerate the algorithm. (See pg. 6 in paper.)
455
void slice(float *col)
424
void slice(float *col, int replay_index)
457
float *base = hashTable.getValues();
458
for (int j = 0; j < vd; j++) col[j] = 0;
459
for (int i = 0; i <= d; i++)
426
float *base = hashTables[0].getValues();
427
for (int j = 0; j < VD; j++) col[j] = 0;
428
for (int i = 0; i <= D; i++)
461
ReplayEntry r = replay[nReplay++];
462
for (int j = 0; j < vd; j++)
430
ReplayEntry r = replay[replay_index*(D+1)+i];
431
for (int j = 0; j < VD; j++)
464
433
col[j] += r.weight*base[r.offset + j];
472
441
// Prepare arrays
473
short *neighbor1 = new short[d+1];
474
short *neighbor2 = new short[d+1];
475
float *newValue = new float[vd*hashTable.size()];
476
float *oldValue = hashTable.getValues();
442
float *newValue = new float[VD*hashTables[0].size()];
443
float *oldValue = hashTables[0].getValues();
477
444
float *hashTableBase = oldValue;
479
float *zero = new float[vd];
480
for (int k = 0; k < vd; k++) zero[k] = 0;
447
for (int k = 0; k < VD; k++) zero[k] = 0;
482
449
// For each of d+1 axes,
483
for (int j = 0; j <= d; j++)
450
for (int j = 0; j <= D; j++)
453
#pragma omp parallel for shared(j, oldValue, newValue, hashTableBase, zero)
485
455
// For each vertex in the lattice,
486
for (int i = 0; i < hashTable.size(); i++) // blur point i in dimension j
456
for (int i = 0; i < hashTables[0].size(); i++) // blur point i in dimension j
488
short *key = hashTable.getKeys() + i*(d); // keys to current vertex
489
for (int k = 0; k < d; k++)
458
const short *key = hashTables[0].getKeys() + i*(D); // keys to current vertex
459
short neighbor1[D+1];
460
short neighbor2[D+1];
461
for (int k = 0; k < D; k++)
491
463
neighbor1[k] = key[k] + 1;
492
464
neighbor2[k] = key[k] - 1;
494
neighbor1[j] = key[j] - d;
495
neighbor2[j] = key[j] + d; // keys to the neighbors along the given axis.
466
neighbor1[j] = key[j] - D;
467
neighbor2[j] = key[j] + D; // keys to the neighbors along the given axis.
497
float *oldVal = oldValue + i*vd;
498
float *newVal = newValue + i*vd;
469
float *oldVal = oldValue + i*VD;
470
float *newVal = newValue + i*VD;
500
472
float *vm1, *vp1;
502
vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor
474
vm1 = hashTables[0].lookup(neighbor1, false); // look up first neighbor
503
475
if (vm1) vm1 = vm1 - hashTableBase + oldValue;
506
vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor
478
vp1 = hashTables[0].lookup(neighbor2, false); // look up second neighbor
507
479
if (vp1) vp1 = vp1 - hashTableBase + oldValue;
510
482
// Mix values of the three vertices
511
for (int k = 0; k < vd; k++)
483
for (int k = 0; k < VD; k++)
512
484
newVal[k] = (0.25f*vm1[k] + 0.5f*oldVal[k] + 0.25f*vp1[k]);
514
486
float *tmp = newValue;