~dustin-spy/twisted/dustin

« back to all changes in this revision

Viewing changes to twisted/conch/keys.py

  • Committer: z3p
  • Date: 2002-07-17 14:44:36 UTC
  • Revision ID: vcs-imports@canonical.com-20020717144436-6dce525e73d836e2
moving secsh to conch.
Conch: The Twisted Shell

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# base library imports
 
2
import base64
 
3
import os.path
 
4
import string
 
5
import sha
 
6
 
 
7
# external library imports
 
8
from Crypto.PublicKey import RSA, DSA
 
9
from Crypto import Util
 
10
 
 
11
# sibling imports
 
12
import asn1, common
 
13
 
 
14
class BadKeyError(Exception):
 
15
    """
 
16
    raised when a key isn't what we expected from it.
 
17
 
 
18
    XXX: we really need to check for bad keys    
 
19
    """
 
20
 
 
21
def getPublicKeyString(filename, line=0):
 
22
    lines = open(filename).readlines()
 
23
    data = lines[line]
 
24
    fileKind, fileData, desc = data.split()
 
25
#    if fileKind != kind:
 
26
#        raise BadKeyError, 'key should be %s but instead is %s' % (kind, fileKind)
 
27
    return base64.decodestring(fileData)
 
28
 
 
29
def getPublicKeyObject(filename=None, line = 0, data = ''):
 
30
    if not filename:
 
31
        publicKey = data
 
32
    else:
 
33
        publicKey = getPublicKeyString(filename, line)
 
34
    keyKind, rest = common.getNS(publicKey)
 
35
    if keyKind == 'ssh-rsa':
 
36
        e, rest = common.getMP(rest)
 
37
        n, rest = common.getMP(rest)
 
38
        return RSA.construct((n,e))
 
39
    elif keyKind == 'ssh-dss':
 
40
        p, rest = common.getMP(rest)
 
41
        q, rest = common.getMP(rest)
 
42
        g, rest = common.getMP(rest)
 
43
        y, rest = common.getMP(rest)
 
44
        return DSA.construct((y,g,p,q))
 
45
 
 
46
def getPrivateKeyObject(filename):
 
47
    data = open(filename).readlines()
 
48
    kind = data[0][11:14]
 
49
    keyData = base64.decodestring(''.join(data[1:-1]))
 
50
    decodedKey = asn1.parse(keyData)
 
51
    if kind == 'RSA':
 
52
        return RSA.construct(decodedKey[1:6])
 
53
    elif kind == 'DSA':
 
54
        p,q,g,y,x = decodedKey[1:6]
 
55
        return DSA.construct((y,g,p,q,x))
 
56
 
 
57
def objectType(obj):
 
58
    keyDataMapping = {
 
59
        ('n', 'e', 'd', 'p','q'):'ssh-rsa',
 
60
        ('y', 'g', 'p', 'q', 'x'):'ssh-dss'
 
61
    }
 
62
    return keyDataMapping[tuple(obj.keydata)]
 
63
 
 
64
def pkcs1Pad(data, lMod):
 
65
    lenPad = lMod - 2 - len(data)
 
66
    return '\x01' + ('\xff'*lenPad) + '\x00' + data
 
67
 
 
68
def pkcs1Digest(data, lMod):
 
69
    digest = sha.new(data).digest()
 
70
    return pkcs1Pad(ID_SHA1 + digest, lMod)
 
71
 
 
72
def lenSig(obj):
 
73
    return obj.size()/8
 
74
 
 
75
def signData(obj, data):
 
76
    mapping = {
 
77
        'ssh-rsa':signData_rsa,
 
78
        'ssh-dss':signData_dsa
 
79
    }
 
80
    objType = objectType(obj)
 
81
    return common.NS(objType) + mapping[objType](obj,data)
 
82
 
 
83
def signData_rsa(obj, data):
 
84
    sigData = pkcs1Digest(data, lenSig(obj))
 
85
    sig = obj.sign(sigData, '')[0] 
 
86
    return common.NS(Util.number.long_to_bytes(sig)) # get around adding the \x00 byte
 
87
 
 
88
def signData_dsa(obj, data):
 
89
    sigData = sha.new(data).digest()
 
90
    randData = open('/dev/random').read(19)
 
91
    sig = obj.sign(sigData, randData)
 
92
    return common.NS(''.join(map(Util.number.long_to_bytes, sig)))
 
93
 
 
94
def verifySignature(obj, sig, data):
 
95
    mapping = {
 
96
        'ssh-rsa':verifySignature_rsa,
 
97
        'ssh-dss':verifySignature_dsa,
 
98
    }
 
99
    objType = objectType(obj)
 
100
    sigType, sigData = common.getNS(sig)
 
101
    assert objType == sigType, 'object and signature are not of same type'
 
102
    return mapping[objType](obj, sigData, data)
 
103
 
 
104
def verifySignature_rsa(obj, sig, data):
 
105
    sigTuple = [common.getMP(sig)[0]]
 
106
    return obj.verify(pkcs1Digest(data, lenSig(obj)), sigTuple)
 
107
 
 
108
def verifySignature_dsa(obj, sig, data):
 
109
    sig = common.getNS(sig)[0]
 
110
    l = len(sig)/2
 
111
    sigTuple = map(Util.number.bytes_to_long, [sig[:l],sig[l:]])
 
112
    return obj.verify(sha.new(data).digest(), sigTuple)
 
113
 
 
114
def printKey(obj):
 
115
    print '%s %s (%s bits)' % (objectType(obj),
 
116
                               obj.hasprivate() and 'Private Key' or 'Public Key',
 
117
                               obj.size())
 
118
    for k in obj.keydata:
 
119
        if hasattr(obj, k):
 
120
            print 'attr', k
 
121
            by = common.MP(getattr(obj,k))[4:]
 
122
            while by:
 
123
                m = by[:15]
 
124
                by = by[15:]
 
125
                o = ''
 
126
                for c in m:
 
127
                    o=o+'%02x:'%ord(c)
 
128
                if len(m)<15:
 
129
                    o=o[:-1]
 
130
                print '\t'+o
 
131
 
 
132
ID_SHA1 = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
 
133
 
 
134
def test():
 
135
    testData = 'this is the testData'
 
136
    try:
 
137
        rsaKey = getPrivateKeyObject(os.path.expanduser('~/.ssh/id_rsa'))
 
138
        rsaPub = getPublicKeyObject(os.path.expanduser('~/.ssh/id_rsa.pub'))
 
139
    except IOError:
 
140
        print 'passing on rsa test, no rsa key'
 
141
    else:
 
142
        signature = signData(rsaKey, testData)
 
143
        assert verifySignature(rsaKey, signature, testData)
 
144
        assert verifySignature(rsaPub, signature, testData)
 
145
        print 'rsa is ok'
 
146
    try:
 
147
        dsaKey = getPrivateKeyObject(os.path.expanduser('~/.ssh/id_dsa'))
 
148
        dsaPub = getPublicKeyObject(os.path.expanduser('~/.ssh/id_dsa.pub'))
 
149
    except IOError:
 
150
        print 'passing on dsa test, no dsa key'
 
151
    else:
 
152
        signature = signData(dsaKey, testData)
 
153
        assert verifySignature(dsaKey, signature, testData), 'dsa is not ok'
 
154
        assert verifySignature(dsaPub, signature, testData)
 
155
        print 'dsa is ok'
 
156
if __name__=='__main__': test()
 
157
 
 
158
 
 
159
 
 
160