~paparazzi-uav/paparazzi/v5.0-manual

« back to all changes in this revision

Viewing changes to sw/ext/opencv_bebop/opencv/samples/python/digits.py

  • Committer: Paparazzi buildbot
  • Date: 2016-05-18 15:00:29 UTC
  • Revision ID: felix.ruess+docbot@gmail.com-20160518150029-e8lgzi5kvb4p7un9
Manual import commit 4b8bbb730080dac23cf816b98908dacfabe2a8ec from v5.0 branch.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/env python
 
2
 
 
3
'''
 
4
SVM and KNearest digit recognition.
 
5
 
 
6
Sample loads a dataset of handwritten digits from '../data/digits.png'.
 
7
Then it trains a SVM and KNearest classifiers on it and evaluates
 
8
their accuracy.
 
9
 
 
10
Following preprocessing is applied to the dataset:
 
11
 - Moment-based image deskew (see deskew())
 
12
 - Digit images are split into 4 10x10 cells and 16-bin
 
13
   histogram of oriented gradients is computed for each
 
14
   cell
 
15
 - Transform histograms to space with Hellinger metric (see [1] (RootSIFT))
 
16
 
 
17
 
 
18
[1] R. Arandjelovic, A. Zisserman
 
19
    "Three things everyone should know to improve object retrieval"
 
20
    http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf
 
21
 
 
22
Usage:
 
23
   digits.py
 
24
'''
 
25
 
 
26
 
 
27
# Python 2/3 compatibility
 
28
from __future__ import print_function
 
29
 
 
30
# built-in modules
 
31
from multiprocessing.pool import ThreadPool
 
32
 
 
33
import cv2
 
34
 
 
35
import numpy as np
 
36
from numpy.linalg import norm
 
37
 
 
38
# local modules
 
39
from common import clock, mosaic
 
40
 
 
41
 
 
42
 
 
43
SZ = 20 # size of each digit is SZ x SZ
 
44
CLASS_N = 10
 
45
DIGITS_FN = '../data/digits.png'
 
46
 
 
47
def split2d(img, cell_size, flatten=True):
 
48
    h, w = img.shape[:2]
 
49
    sx, sy = cell_size
 
50
    cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
 
51
    cells = np.array(cells)
 
52
    if flatten:
 
53
        cells = cells.reshape(-1, sy, sx)
 
54
    return cells
 
55
 
 
56
def load_digits(fn):
 
57
    print('loading "%s" ...' % fn)
 
58
    digits_img = cv2.imread(fn, 0)
 
59
    digits = split2d(digits_img, (SZ, SZ))
 
60
    labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
 
61
    return digits, labels
 
62
 
 
63
def deskew(img):
 
64
    m = cv2.moments(img)
 
65
    if abs(m['mu02']) < 1e-2:
 
66
        return img.copy()
 
67
    skew = m['mu11']/m['mu02']
 
68
    M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
 
69
    img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
 
70
    return img
 
71
 
 
72
class StatModel(object):
 
73
    def load(self, fn):
 
74
        self.model.load(fn)  # Known bug: https://github.com/Itseez/opencv/issues/4969
 
75
    def save(self, fn):
 
76
        self.model.save(fn)
 
77
 
 
78
class KNearest(StatModel):
 
79
    def __init__(self, k = 3):
 
80
        self.k = k
 
81
        self.model = cv2.ml.KNearest_create()
 
82
 
 
83
    def train(self, samples, responses):
 
84
        self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
 
85
 
 
86
    def predict(self, samples):
 
87
        retval, results, neigh_resp, dists = self.model.findNearest(samples, self.k)
 
88
        return results.ravel()
 
89
 
 
90
class SVM(StatModel):
 
91
    def __init__(self, C = 1, gamma = 0.5):
 
92
        self.model = cv2.ml.SVM_create()
 
93
        self.model.setGamma(gamma)
 
94
        self.model.setC(C)
 
95
        self.model.setKernel(cv2.ml.SVM_RBF)
 
96
        self.model.setType(cv2.ml.SVM_C_SVC)
 
97
 
 
98
    def train(self, samples, responses):
 
99
        self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
 
100
 
 
101
    def predict(self, samples):
 
102
        return self.model.predict(samples)[1].ravel()
 
103
 
 
104
 
 
105
def evaluate_model(model, digits, samples, labels):
 
106
    resp = model.predict(samples)
 
107
    err = (labels != resp).mean()
 
108
    print('error: %.2f %%' % (err*100))
 
109
 
 
110
    confusion = np.zeros((10, 10), np.int32)
 
111
    for i, j in zip(labels, resp):
 
112
        confusion[i, j] += 1
 
113
    print('confusion matrix:')
 
114
    print(confusion)
 
115
    print()
 
116
 
 
117
    vis = []
 
118
    for img, flag in zip(digits, resp == labels):
 
119
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
120
        if not flag:
 
121
            img[...,:2] = 0
 
122
        vis.append(img)
 
123
    return mosaic(25, vis)
 
124
 
 
125
def preprocess_simple(digits):
 
126
    return np.float32(digits).reshape(-1, SZ*SZ) / 255.0
 
127
 
 
128
def preprocess_hog(digits):
 
129
    samples = []
 
130
    for img in digits:
 
131
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
 
132
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
 
133
        mag, ang = cv2.cartToPolar(gx, gy)
 
134
        bin_n = 16
 
135
        bin = np.int32(bin_n*ang/(2*np.pi))
 
136
        bin_cells = bin[:10,:10], bin[10:,:10], bin[:10,10:], bin[10:,10:]
 
137
        mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
 
138
        hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
 
139
        hist = np.hstack(hists)
 
140
 
 
141
        # transform to Hellinger kernel
 
142
        eps = 1e-7
 
143
        hist /= hist.sum() + eps
 
144
        hist = np.sqrt(hist)
 
145
        hist /= norm(hist) + eps
 
146
 
 
147
        samples.append(hist)
 
148
    return np.float32(samples)
 
149
 
 
150
 
 
151
if __name__ == '__main__':
 
152
    print(__doc__)
 
153
 
 
154
    digits, labels = load_digits(DIGITS_FN)
 
155
 
 
156
    print('preprocessing...')
 
157
    # shuffle digits
 
158
    rand = np.random.RandomState(321)
 
159
    shuffle = rand.permutation(len(digits))
 
160
    digits, labels = digits[shuffle], labels[shuffle]
 
161
 
 
162
    digits2 = list(map(deskew, digits))
 
163
    samples = preprocess_hog(digits2)
 
164
 
 
165
    train_n = int(0.9*len(samples))
 
166
    cv2.imshow('test set', mosaic(25, digits[train_n:]))
 
167
    digits_train, digits_test = np.split(digits2, [train_n])
 
168
    samples_train, samples_test = np.split(samples, [train_n])
 
169
    labels_train, labels_test = np.split(labels, [train_n])
 
170
 
 
171
 
 
172
    print('training KNearest...')
 
173
    model = KNearest(k=4)
 
174
    model.train(samples_train, labels_train)
 
175
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
 
176
    cv2.imshow('KNearest test', vis)
 
177
 
 
178
    print('training SVM...')
 
179
    model = SVM(C=2.67, gamma=5.383)
 
180
    model.train(samples_train, labels_train)
 
181
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
 
182
    cv2.imshow('SVM test', vis)
 
183
    print('saving SVM as "digits_svm.dat"...')
 
184
    model.save('digits_svm.dat')
 
185
 
 
186
    cv2.waitKey(0)