2
from twisted.internet import defer
3
from common import NS, getNS, MP
4
import keys, transport, service
6
# Server side of the SSH user-authentication service: handles
# USERAUTH_REQUEST packets, dispatches them to auth_<method> handlers and
# replies through self.transport.sendPacket with USERAUTH_SUCCESS,
# USERAUTH_FAILURE or USERAUTH_PK_OK.
# NOTE(review): this copy of the file has lost its indentation, has bare
# integer lines between statements (they look like leftover line numbers),
# and those numbers skip values, so some statements are probably missing.
# The comments below describe only what is visible and flag the gaps.
class SSHUserAuthServer(service.SSHService):
8
# Rebound at the bottom of the module to the shared `messages` table.
protocolMessages = None # set later
9
# Method names; joined with ',' and sent in every USERAUTH_FAILURE packet.
supportedAuthentications = ('publickey','password')
10
# Methods that have succeeded so far (_cbGoodAuth appends to it).
# NOTE(review): this is a class-level list mutated through self, and no
# __init__ is visible for this class, so it looks shared between all
# instances -- confirm whether that is intended.
authenticatedWith = []
12
# Look up the handler method named auth_<kind> for an authentication attempt;
# f is None when no such method exists.
# NOTE(review): nothing after the getattr is visible (the numbering jumps from
# 15 to 20). The call to f and the return value that ssh_USERAUTH_REQUEST
# inspects are presumably in the missing lines -- TODO confirm.
def tryAuth(self, kind, user, data):
13
print 'trying auth %s for %s' % (kind, user)
14
#print 'with data: %s' % repr(data)
15
f= getattr(self,'auth_%s'%kind, None)
20
# Handle a USERAUTH_REQUEST packet. getNS(packet, 3) is unpacked into the
# user name, the name of the service to start after authentication, the
# method name and the remaining bytes. nextService is stored on the instance,
# tryAuth is called, and _cbGoodAuth / _cbBadAuth are attached to the result;
# _cbGoodAuth additionally receives `method` via callbackArgs.
# NOTE(review): the bodies of the `r<0` and non-Deferred checks are not
# visible, so how a plain (non-Deferred) result is turned into success or
# failure cannot be determined from here.
def ssh_USERAUTH_REQUEST(self, packet):
21
user, nextService, method, rest = getNS(packet, 3)
22
self.nextService = nextService
23
r = self.tryAuth(method, user, rest)
24
if r<0: # sent a different packet type back
26
if type(r) != type(defer.Deferred()):
28
r = defer.succeed(None)
31
r.addCallbacks(self._cbGoodAuth, self._cbBadAuth, callbackArgs = (method,))
33
# Callback for a successful method: records it in authenticatedWith.
# Two replies are visible below: USERAUTH_SUCCESS followed by switching the
# transport to a new instance of the requested service, and a
# USERAUTH_FAILURE whose trailing byte is '\xff' (the client reads that byte
# as its `partial` flag).
# NOTE(review): the condition choosing between the two replies is not visible
# (numbers 35 and 38 are skipped); presumably an if/else on whether
# authentication is complete -- TODO confirm.
def _cbGoodAuth(self, foo, method):
34
self.authenticatedWith.append(method)
36
self.transport.sendPacket(MSG_USERAUTH_SUCCESS, '')
37
self.transport.setService(self.transport.factory.services[self.nextService]())
39
self.transport.sendPacket(MSG_USERAUTH_FAILURE, NS(','.join(self.supportedAuthentications))+'\xff')
41
# Errback for a failed method: sends USERAUTH_FAILURE listing the supported
# methods, with a trailing '\x00' byte (partial flag clear).
def _cbBadAuth(self, foo):
42
self.transport.sendPacket(MSG_USERAUTH_FAILURE, NS(','.join(self.supportedAuthentications))+'\x00')
44
# Handler for the 'publickey' method. The first byte of the payload is kept
# in self.hasSigType because verifySignatureFor re-encodes it into the signed
# data. The rest is unpacked into the algorithm name, the key blob and the
# remaining bytes (from which the signature is read with getNS(rest)[0]).
# When only the key is acceptable, USERAUTH_PK_OK is sent back echoing
# packet[1:].
# NOTE(review): the bodies and return statements of both `if` lines are not
# visible, nor is any branching on hasSig -- TODO confirm against a complete
# copy.
def auth_publickey(self, user, packet):
45
hasSig = ord(packet[0])
46
self.hasSigType = hasSig # protocol impl.s differ in this
47
algName, blob, rest = getNS(packet[1:], 2)
50
if self.isValidKeyFor(user, blob) and self.verifySignatureFor(user, blob, getNS(rest)[0]):
54
if self.isValidKeyFor(user, blob):
55
self.transport.sendPacket(MSG_USERAUTH_PK_OK, packet[1:])
59
# Rebuild the byte string the client is expected to have signed (session id,
# the USERAUTH_REQUEST message byte, user, next service, 'publickey', the
# has-signature byte, key type and key blob) and return the result of
# keys.verifySignature for the given signature.
# The client builds the same layout in SSHUserAuthClient.ssh_USERAUTH_PK_OK.
def verifySignatureFor(self, user, blob, signature):
60
pubKey = keys.getPublicKeyObject(data = blob)
61
b = NS(self.transport.sessionID) + chr(MSG_USERAUTH_REQUEST) + \
62
NS(user) + NS(self.nextService) + NS('publickey') + chr(self.hasSigType) + \
63
NS(keys.objectType(pubKey)) + NS(blob)
64
return keys.verifySignature(pubKey, signature, b)
68
# NOTE(review): the return below follows another return with no `def` line in
# between (numbers 65-67 are skipped); it presumably belongs to a separate
# method whose header is missing. It is true once at least one method has
# been recorded in authenticatedWith.
# overwrite on the client side
70
return len(self.authenticatedWith)>0
72
# Check whether pubKey appears in the user's authorized_keys or
# authorized_keys2 file under ~user/.ssh/: the second whitespace-separated
# field of a line is base64-decoded and compared with pubKey.
# NOTE(review): the loop that binds `l` and the return statements are not
# visible. `os` and `base64` are not among the imports visible at the top of
# this file. The file object from open() is never closed explicitly, and
# `file` shadows the Python 2 builtin.
# NOTE(review): `user` appears to come from the client's request and is
# interpolated into the path unchecked -- confirm it is validated upstream.
def isValidKeyFor(self, user, pubKey):
73
home = os.path.expanduser('~%s/.ssh/' % user)
74
for file in ['authorized_keys', 'authorized_keys2']:
75
if os.path.exists(home+file):
76
lines = open(home+file).readlines()
78
if base64.decodestring(l.split()[1])==pubKey:
83
# Client side of the SSH user-authentication service: sends USERAUTH_REQUEST
# packets for each method the server offers and, on USERAUTH_SUCCESS, hands
# the transport over to the service instance given to the constructor.
# NOTE(review): this copy of the file has lost its indentation, has bare
# integer lines between statements (they look like leftover line numbers),
# and those numbers skip values, so some statements are probably missing.
# The comments below describe only what is visible and flag the gaps.
class SSHUserAuthClient(service.SSHService):
85
# Rebound at the bottom of the module to the shared `messages` table.
protocolMessages = None # set later
86
# instance: the service to start once authenticated; its `name` attribute is
# sent as the service name in every request.
# NOTE(review): self.user is read by askForAuth, ssh_USERAUTH_PK_OK and
# getPassword but no assignment is visible here (number 87 is skipped);
# presumably `self.user = user` -- TODO confirm.
def __init__(self, user, instance):
88
self.instance = instance
89
# Methods already completed; ssh_USERAUTH_FAILURE skips these.
self.authenticatedWith = []
90
# Markers for key sources already offered ('file' is the only one used here).
self.triedPublicKeys = []
92
# Start by requesting the 'none' method with no extra data.
# NOTE(review): presumably this provokes a USERAUTH_FAILURE that lists the
# methods the server supports -- verify against the protocol spec.
def serviceStarted(self):
93
self.askForAuth('none', '')
95
# Send a USERAUTH_REQUEST made of the user name, the target service name, the
# method name `kind` and the method-specific extraData.
# NOTE(review): self.lastAuth is read by ssh_USERAUTH_FAILURE and
# ssh_USERAUTH_PK_OK but never assigned in the visible code (number 96 is
# skipped); presumably it is set to `kind` here -- TODO confirm.
def askForAuth(self, kind, extraData):
97
self.transport.sendPacket(MSG_USERAUTH_REQUEST, NS(self.user) + \
98
NS(self.instance.name) + NS(kind) + extraData)
99
# Look up the handler method named auth_<kind>; f is None when there is none.
# NOTE(review): the call to f and the return value tested by
# ssh_USERAUTH_FAILURE are not visible (numbers 101-103 are skipped).
def tryAuth(self, kind):
100
f= getattr(self,'auth_%s'%kind, None)
104
# Authentication finished: switch the transport to the wrapped service.
def ssh_USERAUTH_SUCCESS(self, packet):
105
self.transport.setService(self.instance)
107
# Handle USERAUTH_FAILURE: the packet is unpacked into the comma-separated
# list of methods that can continue and a one-byte `partial` flag. Each
# offered method not already in authenticatedWith is tried via tryAuth.
# NOTE(review): the append of lastAuth is presumably guarded by
# `if partial:` and the loop presumably stops at the first method that was
# started, but the guard and the loop body are not visible -- TODO confirm.
def ssh_USERAUTH_FAILURE(self, packet):
108
canContinue, partial = getNS(packet)
109
canContinue = canContinue.split(',')
111
partial = ord(partial)
113
self.authenticatedWith.append(self.lastAuth)
114
for method in canContinue:
115
if method not in self.authenticatedWith and self.tryAuth(method):
118
# Handle message number 60. MSG_USERAUTH_PK_OK and
# MSG_USERAUTH_PASSWD_CHANGEREQ share that number in this module, so the
# handler branches on the method last requested.
# publickey: load the key pair from ~/.ssh/id_rsa(.pub), sign the session id
# plus the request fields, and resend the request with the has-signature
# byte set to '\xff' and the signature appended.
# password: read the prompt and language from the packet, ask for the old
# password and send old and new passwords with the flag byte '\xff'.
# NOTE(review): `np` is used but never assigned in the visible code (number
# 131 is skipped); presumably a second getpass call for the new password.
def ssh_USERAUTH_PK_OK(self, packet):
119
if self.lastAuth == 'publickey':
121
privateKey = keys.getPrivateKeyObject(os.path.expanduser('~/.ssh/id_rsa'))
122
publicKey = keys.getPublicKeyString(os.path.expanduser('~/.ssh/id_rsa.pub'))
123
# Same field layout that SSHUserAuthServer.verifySignatureFor rebuilds.
b = NS(self.transport.sessionID) + chr(MSG_USERAUTH_REQUEST) + \
124
NS(self.user) + NS(self.instance.name) + NS('publickey') + '\xff' + \
125
NS('ssh-rsa') + NS(publicKey)
126
self.askForAuth('publickey', '\xff' + NS('ssh-rsa') + NS(publicKey) + \
127
NS(keys.signData(privateKey, b)))
128
elif self.lastAuth == 'password':
129
prompt, language, rest = getNS(packet, 2)
130
op = getpass('Old Password: ')
132
self.askForAuth('password', '\xff'+NS(op)+NS(np))
134
# Offer the public key from ~/.ssh/id_rsa.pub without a signature (flag byte
# '\x00'), at most once: 'file' is recorded in triedPublicKeys first.
# Only the existence of the private key file ~/.ssh/id_rsa is checked.
# NOTE(review): the return values are not visible (numbers 139-140 skipped).
def auth_publickey(self):
135
if os.path.exists(os.path.expanduser('~/.ssh/id_rsa')) and not 'file' in self.triedPublicKeys:
136
self.triedPublicKeys.append('file')
137
self.askForAuth('publickey', '\x00' + NS('ssh-rsa') + \
138
NS(keys.getPublicKeyString(os.path.expanduser('~/.ssh/id_rsa.pub'))))
141
# Obtain a password through getPassword() and send it from _cbPassword once
# the Deferred fires.
# NOTE(review): the return value is not visible (numbers 144-145 skipped).
def auth_password(self):
142
d = self.getPassword()
143
d.addCallback(self._cbPassword)
146
# Send the password request; the leading '\x00' byte distinguishes it from
# the change-password form sent by ssh_USERAUTH_PK_OK, which uses '\xff'.
def _cbPassword(self, password):
147
self.askForAuth('password', '\x00'+NS(password))
149
# Return a Deferred that has already fired with the password read by the
# module-level getpass().
# NOTE(review): as visible, `prompt` is overwritten unconditionally; a guard
# such as `if not prompt:` is presumably missing (number 150 is skipped).
def getPassword(self, prompt = None):
151
prompt = 'Password for %s: ' % self.user
152
return defer.succeed(getpass(prompt))
154
def getpass(prompt = "Password: "):
    """Prompt on the terminal for a password with echo disabled and return it.

    The terminal attributes of standard input are saved, the ECHO local flag
    is cleared while the line is read, and the saved attributes are restored
    afterwards.

    prompt -- text shown to the user before reading.

    Returns the line that was entered, without the trailing newline.

    Raises termios.error if standard input is not a terminal, and whatever
    raw_input raises (for example EOFError or KeyboardInterrupt); the
    terminal settings are restored before such an exception propagates.
    """
    # Imported locally so the function does not rely on module-level imports
    # of these two names.
    import sys
    import termios

    fd = sys.stdin.fileno()
    old = termios.tcgetattr(fd)
    new = termios.tcgetattr(fd)
    new[3] = new[3] & ~termios.ECHO # lflags
    try:
        termios.tcsetattr(fd, termios.TCSADRAIN, new)
        passwd = raw_input(prompt)
    finally:
        # Always put the terminal back, otherwise an interrupted prompt
        # leaves the user's shell with echo switched off.
        termios.tcsetattr(fd, termios.TCSADRAIN, old)
    # Callers (getPassword, ssh_USERAUTH_PK_OK) use the returned value.
    return passwd
167
# Message numbers used by the user-authentication service.
MSG_USERAUTH_REQUEST = 50
168
MSG_USERAUTH_FAILURE = 51
169
MSG_USERAUTH_SUCCESS = 52
170
MSG_USERAUTH_BANNER = 53
171
# The next two constants deliberately or accidentally share the value 60;
# SSHUserAuthClient.ssh_USERAUTH_PK_OK tells them apart by the method that
# was last requested.
MSG_USERAUTH_PASSWD_CHANGEREQ = 60
172
MSG_USERAUTH_PK_OK = 60
176
# Build a number -> name table from the attributes of `userauth` and install
# it as protocolMessages on both service classes.
# NOTE(review): neither `messages` nor `userauth` is defined in the visible
# lines (numbers 173-175 are skipped); presumably an empty dict and an import
# of this module -- TODO confirm.
# NOTE(review): no filter on `v` is visible (number 177 is skipped); as shown,
# every attribute of `userauth` would be used as a key.
for v in dir(userauth):
178
# Keys are the numeric values, so two names with the same number collide and
# the one assigned last wins (relevant for the two constants equal to 60).
messages[getattr(userauth,v)] = v # doesn't handle doubles
180
SSHUserAuthServer.protocolMessages = messages
181
SSHUserAuthClient.protocolMessages = messages