~vamsi-krishnak/+junk/libBipartiteMatch

« back to all changes in this revision

Viewing changes to Bipartite_bcup.c

  • Committer: Vamsi Kundeti
  • Date: 2008-12-18 04:00:39 UTC
  • Revision ID: vamsi.krishnak@gmail.com-20081218040039-h1w0l0dln1v9x3cj
libBipartiteMatch initial import

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 * May 20,2008. Vamsi Kundeti.
 
3
 * Implementing unweighted and weighted bipartite matching algorithms.
 
4
 */
 
5
#include<stdlib.h>
 
6
#include<math.h>
 
7
#include<assert.h>
 
8
#include "Bipartite.h"
 
9
/*This implements the MC21 perfect matching*/
 
10
template <class RowStarts, class ColIndices, class Indices>
 
11
match_size_t FindPerfectMatch(Indices match, RowStarts row_starts, 
 
12
                                                          ColIndices col_indices, std::size_t n, 
 
13
                                                          std::size_t nnz, Indices col_marker,
 
14
                                                          Indices P, Indices Minv){
 
15
        unsigned int match_size=0;
 
16
        assert(G && M);
 
17
        //std::size_t n = G->order;
 
18
        //double *nnz=G->nnz;
 
19
        //std::size_t *colind=G->colind;colind--;
 
20
        //std::size_t *rowptr=G->rowptr;rowptr--;
 
21
        //std::size_t *m = M->m;m--;
 
22
        //BitArray_t *col_marker = CreateBitArray(n);
 
23
        //BitArray_t *matched_row = CreateBitArray(n);
 
24
        //std::size_t *p = (std::size_t *) malloc(sizeof(std::size_t)*n);p--;
 
25
        //std::size_t *m_inv = (std::size_t *) malloc(sizeof(std::size_t)*n);m_inv--;
 
26
        std::size_t i,j,k,i1,jend;std::size_t m_inv_prev;char aug_extend=0;
 
27
        match_size=0;
 
28
        for(i=0;i<n;i++){
 
29
                Minv[i] = n+1;
 
30
        }
 
31
        /*for(j=1;j<=n;j++){
 
32
                if(m[j]){
 
33
                        m_inv[m[j]]=j;
 
34
                        match_size++;
 
35
                }
 
36
        }*/
 
37
        for(i=1;i<=n;i++){
 
38
                /*Augmenting path at any unmatched node*/
 
39
                if(m_inv[i]){
 
40
                        continue;
 
41
                }
 
42
                i1 = i; p[i1]=0; jend=0;
 
43
                //ResetAllBits(col_marker);
 
44
                while(i1 && !jend){
 
45
                        std::size_t end = row_ptr [i1+1] ;
 
46
                        for(std::size_t k = row_ptr [i1] ; k < end ; k++){
 
47
                                j = colind[k];
 
48
                                if(!m[j]){
 
49
                                        jend=j;
 
50
                                        break;
 
51
                                }
 
52
                        }
 
53
                        if(jend){
 
54
                                /*Found an unmatched node*/
 
55
                                break;
 
56
                        }
 
57
                        bool aug_extend=false;
 
58
                        for(k=0;k<(rowptr[i1+1]-rowptr[i1]) && !jend;k++){
 
59
                                j = colind[rowptr[i1]+k];
 
60
                                if(!CheckBit(col_marker,j)){
 
61
                                        p[m[j]]=i1;
 
62
                                        //m_inv[m[j]]=j;
 
63
                                        SetBit(col_marker,j);
 
64
                                        i1 = m[j];
 
65
                                        aug_extend = true ;
 
66
                                        break;
 
67
                                }
 
68
                        }
 
69
                        if(!aug_extend){
 
70
                                /*Unable to find a unmarked matched node*/
 
71
                                i1 = p[i1];
 
72
                        }
 
73
                }
 
74
                if(i1){
 
75
                        /*Augmenting path is found so augment*/
 
76
                        j = jend; 
 
77
                        printf("Augmenting the path {");
 
78
                        while(i1){
 
79
                                m_inv_prev = m_inv[i1];
 
80
                                m[j] = i1; m_inv[i1] = j;
 
81
                                printf("(%u,%u)",i1,j);
 
82
                                j = m_inv_prev;
 
83
                                i1 = p[i1];
 
84
                        }
 
85
                        printf("}\n");
 
86
                        match_size++;
 
87
                        printf("Match Size: %u\n",match_size);
 
88
                }else{
 
89
                        /*the matching is maximum*/
 
90
                        break;
 
91
                }
 
92
        }
 
93
        //FreeBitArray(col_marker); FreeBitArray(matched_row);
 
94
        //free(++m_inv); free(++p);
 
95
        return match_size;
 
96
}
 
97
/*MC64: Weighted Bipartite matching, MAKE SURE that M is *a 'extreme' 
 
98
 *matching w.r.t nnz values in G->nnz, if you are not passing an empty
 
99
 *match.
 
100
 */
 
101
match_size_t WeightedMatching(Match_t *M,SparseGraph *G){
 
102
        double *C = G->nnz;C--;std::size_t *m = M->m;m--;
 
103
        std::size_t n = G->order;std::size_t i,j,i1,jend,k,m_inv_prev;
 
104
        double curr_shortest_path=0;double curr_aug_path=HUGE_VAL;
 
105
        double *C1 = (double *)malloc(sizeof(double)*(G->nnz_size));C1--;
 
106
        double *dist = (double *)malloc(sizeof(double)*n);dist--;
 
107
        double *u = (double *)malloc(sizeof(double)*n);u--;
 
108
        double *v = (double *)malloc(sizeof(double)*n);v--;
 
109
        std::size_t *m_inv = (std::size_t *)malloc(sizeof(std::size_t)*n);m_inv--;
 
110
        std::size_t *rowptr = G->rowptr; rowptr--;
 
111
        std::size_t *colind = G->colind; colind--;
 
112
        std::size_t *p = (std::size_t *) malloc(sizeof(std::size_t)*n);p--;
 
113
        unsigned int match_size=0;double dist1; std::size_t itrace;
 
114
        BitArray_t *col_marker = CreateBitArray(n);
 
115
        BitArray_t *heap_marker = CreateBitArray(n);
 
116
 
 
117
        assert(C1 && dist && u && v && p);
 
118
 
 
119
        /*If the matching is not M has to be an extreme matching*/
 
120
        for(i=1;i<=G->nnz_size;i++){
 
121
                if(i<=n){
 
122
                        m_inv[i] = 0;
 
123
                        u[i]=0; v[i]=0;
 
124
                }
 
125
                C1[i] = C[i];
 
126
        }
 
127
        for(j=1;j<=n;j++){
 
128
                if(m[j]){
 
129
                        m_inv[m[j]]=j;
 
130
                }
 
131
                dist[j] = HUGE_VAL; 
 
132
        }
 
133
 
 
134
        for(i=1;i<=n;i++){
 
135
                if(m_inv[i]){
 
136
                        continue;
 
137
                }
 
138
                /*
 
139
                 *Aim is to find a value for jend such that the path
 
140
                 *from i-->jend is the shortest
 
141
                 */
 
142
                i1 = i; p[i1] = 0; jend=0; 
 
143
                ResetAllBits(col_marker);
 
144
                ResetAllBits(heap_marker);
 
145
                curr_shortest_path=0;curr_aug_path=HUGE_VAL;
 
146
                while(curr_aug_path > curr_shortest_path){
 
147
                        for(k=0;k<(rowptr[i1+1]-rowptr[i1]);k++){
 
148
                                j = colind[rowptr[i1]+k];
 
149
                                if(CheckBit(col_marker,j)){
 
150
                                        continue;
 
151
                                }
 
152
                                dist1 = curr_shortest_path + C1[rowptr[i1]+k];
 
153
                                /*Prune any dist1's > curr_aug_path, since
 
154
                                 *all the costs>0 
 
155
                                 */
 
156
                                if(dist1 < curr_aug_path){
 
157
                                        if(!m[j]){
 
158
                                                /*we need itrace because, the last i1 which
 
159
                                                 *we explore may not actually give the shortest 
 
160
                                                 *augmenting path.*/
 
161
                                                jend = j; itrace = i1;
 
162
                                                curr_aug_path = dist1;
 
163
                                        }else if(dist1 < dist[j]){ /*Update the dist*/
 
164
                                                dist[j] = dist1; p[m[j]] = i1;
 
165
                                                SetBit(heap_marker,j);
 
166
                                        }
 
167
                                }
 
168
                        }
 
169
                        /*We now have a heap of matched cols, so pick the min*/
 
170
                        j = SimplePickMin(heap_marker,dist,n);
 
171
                        if(j){
 
172
                                curr_shortest_path = dist[j]; 
 
173
                                UnsetBit(heap_marker,j); SetBit(col_marker,j);
 
174
                                i1 = m[j];
 
175
                        }else{
 
176
                                break;
 
177
                        }
 
178
                }
 
179
                if(jend){ /*We found a shortest augmenting path*/
 
180
                        j=jend;
 
181
                        std::size_t itrace_prev;
 
182
                        //printf("Shortest augmenting Path {");
 
183
                        match_size = 0;
 
184
                        while(itrace){
 
185
                                m_inv_prev = m_inv[itrace];
 
186
                                m[j] = itrace; m_inv[itrace]=j;
 
187
                                //printf("(%u,%u)",itrace,j);
 
188
                                j=m_inv_prev;
 
189
                                itrace_prev = itrace;
 
190
                                itrace = p[itrace];
 
191
                                if(itrace){
 
192
                                //      printf("(%u,%u)",itrace_prev,m_inv_prev);
 
193
                                }
 
194
                                match_size++;
 
195
                        }
 
196
                        //printf("}\n");
 
197
                        /*Update the cost with new match m*/
 
198
                        for(j=1;j<=n;j++){
 
199
                                if(CheckBit(col_marker,j)){
 
200
                                        u[j] = u[j]+dist[j]-curr_aug_path;
 
201
                                        /*Reset the dist values*/
 
202
                                }
 
203
                        }
 
204
                        for(i1=1;i1<=n;i1++){
 
205
                                if(!m_inv[i1]) continue;
 
206
                                j = m_inv[i1];
 
207
                                std::size_t end = rowptr [i1 + 1] ;
 
208
                                for(std::size_t k = rowptr[i1];k < end ; k++){
 
209
                                        if(colind[k] == j){
 
210
                                                v[i1] = C[k] - u[j];
 
211
                                                assert (C[k] - u[j] - v[i1] == 0.0) ;
 
212
                                                break ;
 
213
                                        }
 
214
                                }
 
215
                        }
 
216
                        /*Update the cost*/
 
217
                        for(i1=1;i1<=n;i1++){
 
218
                                for(k=0;k<(rowptr[i1+1]-rowptr[i1]);k++){
 
219
                                        j = colind[rowptr[i1]+k];
 
220
                                        C1[rowptr[i1]+k] = C[rowptr[i1]+k]-u[j]-v[i1];
 
221
                                }
 
222
                                /*The index should be j rather than i1 but just 
 
223
                                 *avoiding another loop*/
 
224
                                dist[i1] = HUGE_VAL;
 
225
                        }
 
226
                }
 
227
        }
 
228
        /*TODO: Freeup the stuff*/
 
229
}
 
230
std::size_t SimplePickMin(BitArray_t *bit_heap,double *dist,std::size_t n){
 
231
        std::size_t min_j=0;std::size_t j; 
 
232
        double curr_min = HUGE_VAL; 
 
233
        for(j=1;j<=n;j++){
 
234
                if(CheckBit(bit_heap,j) && dist[j]<curr_min){
 
235
                        min_j = j;
 
236
                        curr_min = dist[j];
 
237
                }
 
238
        }
 
239
        return min_j;
 
240
}
 
241
 
 
242
BitArray_t* CreateBitArray(unsigned int size){
 
243
        /* change it to O(n) representation */
 
244
        div_t d = div(size,SIZE_OF_BYTE_IN_BITS);
 
245
        unsigned int len;
 
246
        BitArray_t *bits = (BitArray_t *)malloc(sizeof(BitArray_t)*1);
 
247
        assert(bits);
 
248
        bits->size = (d.rem > 0)?(d.quot+1):d.quot;
 
249
        bits->ba = (char *)malloc(sizeof(char)*(bits->size));
 
250
        assert(bits->ba);
 
251
        memset(bits->ba,'\0',bits->size);
 
252
        return bits;
 
253
}
 
254
inline void SetBit(BitArray_t *b,unsigned int bitno){
 
255
        div_t d;
 
256
        assert(bitno>=1 && bitno<=(b->size*SIZE_OF_BYTE_IN_BITS));
 
257
        d = div((bitno-1),SIZE_OF_BYTE_IN_BITS);
 
258
        (b->ba)[d.quot] = (b->ba)[d.quot]|(1<<(d.rem));
 
259
}
 
260
inline void UnsetBit(BitArray_t *b,unsigned int bitno){
 
261
        div_t d;
 
262
        assert(bitno>=1 && bitno<=(b->size*SIZE_OF_BYTE_IN_BITS));
 
263
        d = div((bitno-1),SIZE_OF_BYTE_IN_BITS);
 
264
        (b->ba)[d.quot] = (b->ba)[d.quot]^(1<<(d.rem));
 
265
}
 
266
inline unsigned char CheckBit(BitArray_t* b,unsigned int bitno){
 
267
        div_t d;
 
268
        assert(bitno>=1 && bitno<=(b->size*SIZE_OF_BYTE_IN_BITS));
 
269
        d = div(bitno-1,SIZE_OF_BYTE_IN_BITS);
 
270
        return ((b->ba)[d.quot] & (1<<(d.rem)));
 
271
}
 
272
inline void ResetAllBits(BitArray_t *b){
 
273
        memset(b->ba,'\0',b->size);
 
274
}
 
275
void FreeBitArray(BitArray_t *b){
 
276
        free(b->ba);
 
277
        free(b);
 
278
}
 
279
/*std::size_t with value 0 is considered an empty node*/
 
280
Match_t * CreateEmptyMatch(unsigned int n){
 
281
        unsigned int i;
 
282
        Match_t *ma = (Match_t *)malloc(sizeof(Match_t)*1);
 
283
        ma->m = (std::size_t *) malloc(sizeof(std::size_t)*n);
 
284
        for(i=0;i<n;i++){
 
285
                ma->m[i] = 0;
 
286
        }
 
287
        return ma;
 
288
}
 
289
 
 
290
void LogTransform(SparseGraph *G){
 
291
        std::size_t *rowptr = G->rowptr; rowptr--;
 
292
        std::size_t *colind = G->colind; colind--;
 
293
        double *val = G->nnz;val--;
 
294
        std::size_t n = G->order; std::size_t k;
 
295
        std::size_t nnz_size = G->nnz_size;
 
296
        double *max_c = (double *)malloc(sizeof(double)*n);max_c--;
 
297
        for(k=1;k<=n;k++){
 
298
                /*Find the maximum in the column*/
 
299
                max_c[k] = 0;
 
300
        }
 
301
        //use absolute value of the elements
 
302
        for(k=1;k<=nnz_size;k++){
 
303
                if(std::abs(val[k]) > max_c[colind[k]]){
 
304
                        max_c[colind[k]] = std::abs(val[k]);
 
305
                }
 
306
        }
 
307
#if 0
 
308
        for(k=1;k<=n;k++){
 
309
                printf("max in col %u is %lf \n",k,max_c[k]);
 
310
        }
 
311
#endif
 
312
 
 
313
        /*Update*/
 
314
        for(k=1;k<=nnz_size;k++){
 
315
                //if val[k] = 0, handle the case
 
316
                val[k] = log(max_c[colind[k]]) - log(fabs(val[k]));
 
317
        }
 
318
        free(++max_c);
 
319
}
 
320
 
 
321
extern int *perm;
 
322
 
 
323
 
 
324
#ifndef NO_UNIT_TEST 
 
325
int main(int argc,char** argv){
 
326
        FILE *ptr = NULL; unsigned int i;
 
327
        SparseGraph *G = NULL;
 
328
        Match_t *m = NULL;
 
329
        if(argc<2){
 
330
                fprintf(stderr,"./a.out <sparse_file>");
 
331
                exit(1);
 
332
        }
 
333
        ptr = fopen(argv[1],"r");
 
334
        assert(ptr);
 
335
        G = ReadSparseGraph(ptr);
 
336
        m = CreateEmptyMatch(G->order);
 
337
        if(argc > 2 && argv[2]){
 
338
                FILE *mfile = fopen(argv[2],"r");
 
339
                std::size_t count=0;
 
340
                std::size_t mid;
 
341
                std::size_t *m1 = m->m; m1--;
 
342
                while(fscanf(mfile,"%u",&mid)==1){
 
343
                        count++;
 
344
                        m1[count] = mid;
 
345
                }
 
346
        }
 
347
        LogTransform(G);
 
348
        WeightedMatching(m,G);
 
349
        /*Now call the MC64*/
 
350
        MC64Driver(argc,argv);
 
351
 
 
352
        printf("Verifying the permutation\n");
 
353
 
 
354
        for(i=0;i<G->order;i++){
 
355
        //      printf("%d %d\n",(i+1),(m->m)[i]);
 
356
                printf("%u(%d) ",(m->m)[i],perm[i]);
 
357
                if(perm[i]==(int)(m->m)[i]){
 
358
                        continue;
 
359
                }else{
 
360
                        printf("FAILED\n");
 
361
                        exit(1);
 
362
                }
 
363
        }
 
364
        printf("\nSUCCESS\n");
 
365
        printf("\n");
 
366
}
 
367
#endif