~dustin-spy/twisted/dustin

« back to all changes in this revision

Viewing changes to twisted/secsh/keys.py

  • Committer: z3p
  • Date: 2002-07-17 14:44:36 UTC
  • Revision ID: vcs-imports@canonical.com-20020717144436-6dce525e73d836e2
moving secsh to conch.
Conch: The Twisted Shell

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# base library imports
2
 
import base64
3
 
import os.path
4
 
import string
5
 
import sha
6
 
 
7
 
# external library imports
8
 
from Crypto.PublicKey import RSA, DSA
9
 
from Crypto import Util
10
 
 
11
 
# sibling imports
12
 
import asn1, common
13
 
 
14
 
class BadKeyError(Exception):
15
 
    """
16
 
    raised when a key isn't what we expected from it.
17
 
 
18
 
    XXX: we really need to check for bad keys    
19
 
    """
20
 
 
21
 
def getPublicKeyString(filename, line=0):
22
 
    lines = open(filename).readlines()
23
 
    data = lines[line]
24
 
    fileKind, fileData, desc = data.split()
25
 
#    if fileKind != kind:
26
 
#        raise BadKeyError, 'key should be %s but instead is %s' % (kind, fileKind)
27
 
    return base64.decodestring(fileData)
28
 
 
29
 
def getPublicKeyObject(filename=None, line = 0, data = ''):
30
 
    if not filename:
31
 
        publicKey = data
32
 
    else:
33
 
        publicKey = getPublicKeyString(filename, line)
34
 
    keyKind, rest = common.getNS(publicKey)
35
 
    if keyKind == 'ssh-rsa':
36
 
        e, rest = common.getMP(rest)
37
 
        n, rest = common.getMP(rest)
38
 
        return RSA.construct((n,e))
39
 
    elif keyKind == 'ssh-dss':
40
 
        p, rest = common.getMP(rest)
41
 
        q, rest = common.getMP(rest)
42
 
        g, rest = common.getMP(rest)
43
 
        y, rest = common.getMP(rest)
44
 
        return DSA.construct((y,g,p,q))
45
 
 
46
 
def getPrivateKeyObject(filename):
47
 
    data = open(filename).readlines()
48
 
    kind = data[0][11:14]
49
 
    keyData = base64.decodestring(''.join(data[1:-1]))
50
 
    decodedKey = asn1.parse(keyData)
51
 
    if kind == 'RSA':
52
 
        return RSA.construct(decodedKey[1:6])
53
 
    elif kind == 'DSA':
54
 
        p,q,g,y,x = decodedKey[1:6]
55
 
        return DSA.construct((y,g,p,q,x))
56
 
 
57
 
def objectType(obj):
58
 
    keyDataMapping = {
59
 
        ('n', 'e', 'd', 'p','q'):'ssh-rsa',
60
 
        ('y', 'g', 'p', 'q', 'x'):'ssh-dss'
61
 
    }
62
 
    return keyDataMapping[tuple(obj.keydata)]
63
 
 
64
 
def pkcs1Pad(data, lMod):
65
 
    lenPad = lMod - 2 - len(data)
66
 
    return '\x01' + ('\xff'*lenPad) + '\x00' + data
67
 
 
68
 
def pkcs1Digest(data, lMod):
69
 
    digest = sha.new(data).digest()
70
 
    return pkcs1Pad(ID_SHA1 + digest, lMod)
71
 
 
72
 
def lenSig(obj):
73
 
    return obj.size()/8
74
 
 
75
 
def signData(obj, data):
76
 
    mapping = {
77
 
        'ssh-rsa':signData_rsa,
78
 
        'ssh-dss':signData_dsa
79
 
    }
80
 
    objType = objectType(obj)
81
 
    return common.NS(objType) + mapping[objType](obj,data)
82
 
 
83
 
def signData_rsa(obj, data):
84
 
    sigData = pkcs1Digest(data, lenSig(obj))
85
 
    sig = obj.sign(sigData, '')[0] 
86
 
    return common.NS(Util.number.long_to_bytes(sig)) # get around adding the \x00 byte
87
 
 
88
 
def signData_dsa(obj, data):
89
 
    sigData = sha.new(data).digest()
90
 
    randData = open('/dev/random').read(19)
91
 
    sig = obj.sign(sigData, randData)
92
 
    return common.NS(''.join(map(Util.number.long_to_bytes, sig)))
93
 
 
94
 
def verifySignature(obj, sig, data):
95
 
    mapping = {
96
 
        'ssh-rsa':verifySignature_rsa,
97
 
        'ssh-dss':verifySignature_dsa,
98
 
    }
99
 
    objType = objectType(obj)
100
 
    sigType, sigData = common.getNS(sig)
101
 
    assert objType == sigType, 'object and signature are not of same type'
102
 
    return mapping[objType](obj, sigData, data)
103
 
 
104
 
def verifySignature_rsa(obj, sig, data):
105
 
    sigTuple = [common.getMP(sig)[0]]
106
 
    return obj.verify(pkcs1Digest(data, lenSig(obj)), sigTuple)
107
 
 
108
 
def verifySignature_dsa(obj, sig, data):
109
 
    sig = common.getNS(sig)[0]
110
 
    l = len(sig)/2
111
 
    sigTuple = map(Util.number.bytes_to_long, [sig[:l],sig[l:]])
112
 
    return obj.verify(sha.new(data).digest(), sigTuple)
113
 
 
114
 
def printKey(obj):
115
 
    print '%s %s (%s bits)' % (objectType(obj),
116
 
                               obj.hasprivate() and 'Private Key' or 'Public Key',
117
 
                               obj.size())
118
 
    for k in obj.keydata:
119
 
        if hasattr(obj, k):
120
 
            print 'attr', k
121
 
            by = common.MP(getattr(obj,k))[4:]
122
 
            while by:
123
 
                m = by[:15]
124
 
                by = by[15:]
125
 
                o = ''
126
 
                for c in m:
127
 
                    o=o+'%02x:'%ord(c)
128
 
                if len(m)<15:
129
 
                    o=o[:-1]
130
 
                print '\t'+o
131
 
 
132
 
ID_SHA1 = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
133
 
 
134
 
def test():
135
 
    testData = 'this is the testData'
136
 
    try:
137
 
        rsaKey = getPrivateKeyObject(os.path.expanduser('~/.ssh/id_rsa'))
138
 
        rsaPub = getPublicKeyObject(os.path.expanduser('~/.ssh/id_rsa.pub'))
139
 
    except IOError:
140
 
        print 'passing on rsa test, no rsa key'
141
 
    else:
142
 
        signature = signData(rsaKey, testData)
143
 
        assert verifySignature(rsaKey, signature, testData)
144
 
        assert verifySignature(rsaPub, signature, testData)
145
 
        print 'rsa is ok'
146
 
    try:
147
 
        dsaKey = getPrivateKeyObject(os.path.expanduser('~/.ssh/id_dsa'))
148
 
        dsaPub = getPublicKeyObject(os.path.expanduser('~/.ssh/id_dsa.pub'))
149
 
    except IOError:
150
 
        print 'passing on dsa test, no dsa key'
151
 
    else:
152
 
        signature = signData(dsaKey, testData)
153
 
        assert verifySignature(dsaKey, signature, testData), 'dsa is not ok'
154
 
        assert verifySignature(dsaPub, signature, testData)
155
 
        print 'dsa is ok'
156
 
if __name__=='__main__': test()
157
 
 
158
 
 
159
 
 
160