1
// Copyright 2012 The Go Authors. All rights reserved.
2
// Use of this source code is governed by a BSD-style
3
// license that can be found in the LICENSE file.
5
// This file implements the Socialist Millionaires Protocol as described in
6
// http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol
7
// specification is required in order to understand this code and, where
8
// possible, the variable names in the code match up with the spec.
20
type smpFailure string
22
func (s smpFailure) Error() string {
26
var smpFailureError = smpFailure("otr: SMP protocol failed")
27
var smpSecretMissingError = smpFailure("otr: mutual secret needed")
38
type smpState struct {
40
a2, a3, b2, b3, pb, qb *big.Int
43
g3b, papb, qaqb, ra *big.Int
49
func (c *Conversation) startSMP(question string) (tlvs []tlv) {
50
if c.smp.state != smpState1 {
51
tlvs = append(tlvs, c.generateSMPAbort())
53
tlvs = append(tlvs, c.generateSMP1(question))
55
c.smp.state = smpState2
59
func (c *Conversation) resetSMP() {
60
c.smp.state = smpState1
65
func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) {
70
if c.smp.state != smpState1 {
75
case tlvTypeSMP1WithQuestion:
76
// We preprocess this into a SMP1 message.
77
nulPos := bytes.IndexByte(data, 0)
79
err = errors.New("otr: SMP message with question didn't contain a NUL byte")
82
c.smp.question = string(data[:nulPos])
83
data = data[nulPos+1:]
86
numMPIs, data, ok := getU32(data)
87
if !ok || numMPIs > 20 {
88
err = errors.New("otr: corrupt SMP message")
92
mpis := make([]*big.Int, numMPIs)
95
mpis[i], data, ok = getMPI(data)
97
err = errors.New("otr: corrupt SMP message")
103
case tlvTypeSMP1, tlvTypeSMP1WithQuestion:
104
if c.smp.state != smpState1 {
106
out = c.generateSMPAbort()
109
if c.smp.secret == nil {
110
err = smpSecretMissingError
113
if err = c.processSMP1(mpis); err != nil {
116
c.smp.state = smpState3
117
out = c.generateSMP2()
119
if c.smp.state != smpState2 {
121
out = c.generateSMPAbort()
124
if out, err = c.processSMP2(mpis); err != nil {
125
out = c.generateSMPAbort()
128
c.smp.state = smpState4
130
if c.smp.state != smpState3 {
132
out = c.generateSMPAbort()
135
if out, err = c.processSMP3(mpis); err != nil {
138
c.smp.state = smpState1
142
if c.smp.state != smpState4 {
144
out = c.generateSMPAbort()
147
if err = c.processSMP4(mpis); err != nil {
148
out = c.generateSMPAbort()
151
c.smp.state = smpState1
155
panic("unknown SMP message")
161
func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) {
163
h.Write([]byte{smpVersion})
165
h.Write(c.PrivateKey.PublicKey.Fingerprint())
166
h.Write(c.TheirPublicKey.Fingerprint())
168
h.Write(c.TheirPublicKey.Fingerprint())
169
h.Write(c.PrivateKey.PublicKey.Fingerprint())
172
h.Write(mutualSecret)
173
c.smp.secret = new(big.Int).SetBytes(h.Sum(nil))
176
func (c *Conversation) generateSMP1(question string) tlv {
178
c.smp.a2 = c.randMPI(randBuf[:])
179
c.smp.a3 = c.randMPI(randBuf[:])
180
g2a := new(big.Int).Exp(g, c.smp.a2, p)
181
g3a := new(big.Int).Exp(g, c.smp.a3, p)
184
r2 := c.randMPI(randBuf[:])
185
r := new(big.Int).Exp(g, r2, p)
186
c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r))
187
d2 := new(big.Int).Mul(c.smp.a2, c2)
194
r3 := c.randMPI(randBuf[:])
196
c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r))
197
d3 := new(big.Int).Mul(c.smp.a3, c3)
205
if len(question) > 0 {
206
ret.typ = tlvTypeSMP1WithQuestion
207
ret.data = append(ret.data, question...)
208
ret.data = append(ret.data, 0)
210
ret.typ = tlvTypeSMP1
212
ret.data = appendU32(ret.data, 6)
213
ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3)
217
func (c *Conversation) processSMP1(mpis []*big.Int) error {
219
return errors.New("otr: incorrect number of arguments in SMP1 message")
229
r := new(big.Int).Exp(g, d2, p)
230
s := new(big.Int).Exp(g2a, c2, p)
233
t := new(big.Int).SetBytes(hashMPIs(h, 1, r))
235
return errors.New("otr: ZKP c2 incorrect in SMP1 message")
241
t.SetBytes(hashMPIs(h, 2, r))
243
return errors.New("otr: ZKP c3 incorrect in SMP1 message")
251
func (c *Conversation) generateSMP2() tlv {
253
b2 := c.randMPI(randBuf[:])
254
c.smp.b3 = c.randMPI(randBuf[:])
255
r2 := c.randMPI(randBuf[:])
256
r3 := c.randMPI(randBuf[:])
257
r4 := c.randMPI(randBuf[:])
258
r5 := c.randMPI(randBuf[:])
259
r6 := c.randMPI(randBuf[:])
261
g2b := new(big.Int).Exp(g, b2, p)
262
g3b := new(big.Int).Exp(g, c.smp.b3, p)
264
r := new(big.Int).Exp(g, r2, p)
266
c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r))
267
d2 := new(big.Int).Mul(b2, c2)
275
c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r))
276
d3 := new(big.Int).Mul(c.smp.b3, c3)
283
c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p)
284
c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p)
285
c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p)
286
c.smp.qb = new(big.Int).Exp(g, r4, p)
287
r.Exp(c.smp.g2, c.smp.secret, p)
288
c.smp.qb.Mul(c.smp.qb, r)
289
c.smp.qb.Mod(c.smp.qb, p)
292
s.Exp(c.smp.g2, r6, p)
296
r.Exp(c.smp.g3, r5, p)
297
cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s))
299
// D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q
303
d5 := new(big.Int).Mod(r, q)
308
s.Mul(c.smp.secret, cp)
310
d6 := new(big.Int).Mod(r, q)
316
ret.typ = tlvTypeSMP2
317
ret.data = appendU32(ret.data, 11)
318
ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6)
322
func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) {
324
err = errors.New("otr: incorrect number of arguments in SMP2 message")
340
r := new(big.Int).Exp(g, d2, p)
341
s := new(big.Int).Exp(g2b, c2, p)
344
s.SetBytes(hashMPIs(h, 3, r))
346
err = errors.New("otr: ZKP c2 failed in SMP2 message")
354
s.SetBytes(hashMPIs(h, 4, r))
356
err = errors.New("otr: ZKP c3 failed in SMP2 message")
360
c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p)
361
c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p)
364
s.Exp(c.smp.g2, d6, p)
370
s.Exp(c.smp.g3, d5, p)
371
t := new(big.Int).Exp(pb, cp, p)
374
t.SetBytes(hashMPIs(h, 5, s, r))
376
err = errors.New("otr: ZKP cP failed in SMP2 message")
381
r4 := c.randMPI(randBuf[:])
382
r5 := c.randMPI(randBuf[:])
383
r6 := c.randMPI(randBuf[:])
384
r7 := c.randMPI(randBuf[:])
386
pa := new(big.Int).Exp(c.smp.g3, r4, p)
387
r.Exp(c.smp.g2, c.smp.secret, p)
388
qa := new(big.Int).Exp(g, r4, p)
393
s.Exp(c.smp.g2, r6, p)
397
s.Exp(c.smp.g3, r5, p)
398
cp.SetBytes(hashMPIs(h, 6, s, r))
401
d5 = new(big.Int).Sub(r5, r)
407
r.Mul(c.smp.secret, cp)
408
d6 = new(big.Int).Sub(r6, r)
415
qaqb := new(big.Int).Mul(qa, r)
418
ra := new(big.Int).Exp(qaqb, c.smp.a3, p)
421
cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r))
424
d7 := new(big.Int).Sub(r7, r)
434
c.smp.papb = new(big.Int).Mul(pa, r)
435
c.smp.papb.Mod(c.smp.papb, p)
438
out.typ = tlvTypeSMP3
439
out.data = appendU32(out.data, 8)
440
out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7)
444
func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) {
446
err = errors.New("otr: incorrect number of arguments in SMP3 message")
459
r := new(big.Int).Exp(g, d5, p)
460
s := new(big.Int).Exp(c.smp.g2, d6, p)
466
s.Exp(c.smp.g3, d5, p)
467
t := new(big.Int).Exp(pa, cp, p)
470
t.SetBytes(hashMPIs(h, 6, s, r))
472
err = errors.New("otr: ZKP cP failed in SMP3 message")
476
r.ModInverse(c.smp.qb, p)
477
qaqb := new(big.Int).Mul(qa, r)
486
t.Exp(c.smp.g3a, cr, p)
489
t.SetBytes(hashMPIs(h, 7, s, r))
491
err = errors.New("otr: ZKP cR failed in SMP3 message")
496
r7 := c.randMPI(randBuf[:])
497
rb := new(big.Int).Exp(qaqb, c.smp.b3, p)
501
cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r))
504
d7 = new(big.Int).Sub(r7, r)
510
out.typ = tlvTypeSMP4
511
out.data = appendU32(out.data, 3)
512
out.data = appendMPIs(out.data, rb, cr, d7)
514
r.ModInverse(c.smp.pb, p)
517
s.Exp(ra, c.smp.b3, p)
519
err = smpFailureError
525
func (c *Conversation) processSMP4(mpis []*big.Int) error {
527
return errors.New("otr: incorrect number of arguments in SMP4 message")
534
r := new(big.Int).Exp(c.smp.qaqb, d7, p)
535
s := new(big.Int).Exp(rb, cr, p)
540
t := new(big.Int).Exp(c.smp.g3b, cr, p)
543
t.SetBytes(hashMPIs(h, 8, s, r))
545
return errors.New("otr: ZKP cR failed in SMP4 message")
548
r.Exp(rb, c.smp.a3, p)
549
if r.Cmp(c.smp.papb) != 0 {
550
return smpFailureError
556
func (c *Conversation) generateSMPAbort() tlv {
557
return tlv{typ: tlvTypeSMPAbort}
560
func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte {
567
h.Write([]byte{magic})
568
for _, mpi := range mpis {
569
h.Write(appendMPI(nil, mpi))