PageRenderTime 43ms CodeModel.GetById 11ms RepoModel.GetById 0ms app.codeStats 0ms

/desktop/core/ext-py/Twisted/twisted/conch/client/unix.py

https://github.com/jcrobak/hue
Python | 396 lines | 319 code | 62 blank | 15 comment | 28 complexity | 651567edd40eabde07bb96ece54444dc MD5 | raw file
  1. # Copyright (c) 2001-2004 Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. #
  4. from twisted.conch.error import ConchError
  5. from twisted.conch.ssh import channel, connection
  6. from twisted.internet import defer, protocol, reactor
  7. from twisted.python import log
  8. from twisted.spread import banana
  9. import os, stat, pickle
  10. import types # this is for evil
  11. class SSHUnixClientFactory(protocol.ClientFactory):
  12. # noisy = 1
  13. def __init__(self, d, options, userAuthObject):
  14. self.d = d
  15. self.options = options
  16. self.userAuthObject = userAuthObject
  17. def clientConnectionLost(self, connector, reason):
  18. if self.options['reconnect']:
  19. connector.connect()
  20. #log.err(reason)
  21. if not self.d: return
  22. d = self.d
  23. self.d = None
  24. d.errback(reason)
  25. def clientConnectionFailed(self, connector, reason):
  26. #try:
  27. # os.unlink(connector.transport.addr)
  28. #except:
  29. # pass
  30. #log.err(reason)
  31. if not self.d: return
  32. d = self.d
  33. self.d = None
  34. d.errback(reason)
  35. #reactor.connectTCP(options['host'], options['port'], SSHClientFactory())
  36. def startedConnecting(self, connector):
  37. fd = connector.transport.fileno()
  38. stats = os.fstat(fd)
  39. try:
  40. filestats = os.stat(connector.transport.addr)
  41. except:
  42. connector.stopConnecting()
  43. return
  44. if stat.S_IMODE(filestats[0]) != 0600:
  45. log.msg("socket mode is not 0600: %s" % oct(stat.S_IMODE(stats[0])))
  46. elif filestats[4] != os.getuid():
  47. log.msg("socket not owned by us: %s" % stats[4])
  48. elif filestats[5] != os.getgid():
  49. log.msg("socket not owned by our group: %s" % stats[5])
  50. # XXX reenable this when i can fix it for cygwin
  51. #elif filestats[-3:] != stats[-3:]:
  52. # log.msg("socket doesn't have same create times")
  53. else:
  54. log.msg('conecting OK')
  55. return
  56. connector.stopConnecting()
  57. def buildProtocol(self, addr):
  58. # here comes the EVIL
  59. obj = self.userAuthObject.instance
  60. bases = []
  61. for base in obj.__class__.__bases__:
  62. if base == connection.SSHConnection:
  63. bases.append(SSHUnixClientProtocol)
  64. else:
  65. bases.append(base)
  66. newClass = types.ClassType(obj.__class__.__name__, tuple(bases), obj.__class__.__dict__)
  67. obj.__class__ = newClass
  68. SSHUnixClientProtocol.__init__(obj)
  69. log.msg('returning %s' % obj)
  70. if self.d:
  71. d = self.d
  72. self.d = None
  73. d.callback(None)
  74. return obj
  75. class SSHUnixServerFactory(protocol.Factory):
  76. def __init__(self, conn):
  77. self.conn = conn
  78. def buildProtocol(self, addr):
  79. return SSHUnixServerProtocol(self.conn)
  80. class SSHUnixProtocol(banana.Banana):
  81. knownDialects = ['none']
  82. def __init__(self):
  83. banana.Banana.__init__(self)
  84. self.deferredQueue = []
  85. self.deferreds = {}
  86. self.deferredID = 0
  87. def connectionMade(self):
  88. log.msg('connection made %s' % self)
  89. banana.Banana.connectionMade(self)
  90. def expressionReceived(self, lst):
  91. vocabName = lst[0]
  92. fn = "msg_%s" % vocabName
  93. func = getattr(self, fn)
  94. func(lst[1:])
  95. def sendMessage(self, vocabName, *tup):
  96. self.sendEncoded([vocabName] + list(tup))
  97. def returnDeferredLocal(self):
  98. d = defer.Deferred()
  99. self.deferredQueue.append(d)
  100. return d
  101. def returnDeferredWire(self, d):
  102. di = self.deferredID
  103. self.deferredID += 1
  104. self.sendMessage('returnDeferred', di)
  105. d.addCallback(self._cbDeferred, di)
  106. d.addErrback(self._ebDeferred, di)
  107. def _cbDeferred(self, result, di):
  108. self.sendMessage('callbackDeferred', di, pickle.dumps(result))
  109. def _ebDeferred(self, reason, di):
  110. self.sendMessage('errbackDeferred', di, pickle.dumps(reason))
  111. def msg_returnDeferred(self, lst):
  112. deferredID = lst[0]
  113. self.deferreds[deferredID] = self.deferredQueue.pop(0)
  114. def msg_callbackDeferred(self, lst):
  115. deferredID, result = lst
  116. d = self.deferreds[deferredID]
  117. del self.deferreds[deferredID]
  118. d.callback(pickle.loads(result))
  119. def msg_errbackDeferred(self, lst):
  120. deferredID, result = lst
  121. d = self.deferreds[deferredID]
  122. del self.deferreds[deferredID]
  123. d.errback(pickle.loads(result))
  124. class SSHUnixClientProtocol(SSHUnixProtocol):
  125. def __init__(self):
  126. SSHUnixProtocol.__init__(self)
  127. self.isClient = 1
  128. self.channelQueue = []
  129. self.channels = {}
  130. def logPrefix(self):
  131. return "SSHUnixClientProtocol (%i) on %s" % (id(self), self.transport.logPrefix())
  132. def connectionReady(self):
  133. log.msg('connection ready')
  134. self.serviceStarted()
  135. def connectionLost(self, reason):
  136. self.serviceStopped()
  137. def requestRemoteForwarding(self, remotePort, hostport):
  138. self.sendMessage('requestRemoteForwarding', remotePort, hostport)
  139. def cancelRemoteForwarding(self, remotePort):
  140. self.sendMessage('cancelRemoteForwarding', remotePort)
  141. def sendGlobalRequest(self, request, data, wantReply = 0):
  142. self.sendMessage('sendGlobalRequest', request, data, wantReply)
  143. if wantReply:
  144. return self.returnDeferredLocal()
  145. def openChannel(self, channel, extra = ''):
  146. self.channelQueue.append(channel)
  147. channel.conn = self
  148. self.sendMessage('openChannel', channel.name,
  149. channel.localWindowSize,
  150. channel.localMaxPacket, extra)
  151. def sendRequest(self, channel, requestType, data, wantReply = 0):
  152. self.sendMessage('sendRequest', channel.id, requestType, data, wantReply)
  153. if wantReply:
  154. return self.returnDeferredLocal()
  155. def adjustWindow(self, channel, bytesToAdd):
  156. self.sendMessage('adjustWindow', channel.id, bytesToAdd)
  157. def sendData(self, channel, data):
  158. self.sendMessage('sendData', channel.id, data)
  159. def sendExtendedData(self, channel, dataType, data):
  160. self.sendMessage('sendExtendedData', channel.id, data)
  161. def sendEOF(self, channel):
  162. self.sendMessage('sendEOF', channel.id)
  163. def sendClose(self, channel):
  164. self.sendMessage('sendClose', channel.id)
  165. def msg_channelID(self, lst):
  166. channelID = lst[0]
  167. self.channels[channelID] = self.channelQueue.pop(0)
  168. self.channels[channelID].id = channelID
  169. def msg_channelOpen(self, lst):
  170. channelID, remoteWindow, remoteMax, specificData = lst
  171. channel = self.channels[channelID]
  172. channel.remoteWindowLeft = remoteWindow
  173. channel.remoteMaxPacket = remoteMax
  174. channel.channelOpen(specificData)
  175. def msg_openFailed(self, lst):
  176. channelID, reason = lst
  177. self.channels[channelID].openFailed(pickle.loads(reason))
  178. del self.channels[channelID]
  179. def msg_addWindowBytes(self, lst):
  180. channelID, bytes = lst
  181. self.channels[channelID].addWindowBytes(bytes)
  182. def msg_requestReceived(self, lst):
  183. channelID, requestType, data = lst
  184. d = defer.maybeDeferred(self.channels[channelID].requestReceived, requestType, data)
  185. self.returnDeferredWire(d)
  186. def msg_dataReceived(self, lst):
  187. channelID, data = lst
  188. self.channels[channelID].dataReceived(data)
  189. def msg_extReceived(self, lst):
  190. channelID, dataType, data = lst
  191. self.channels[channelID].extReceived(dataType, data)
  192. def msg_eofReceived(self, lst):
  193. channelID = lst[0]
  194. self.channels[channelID].eofReceived()
  195. def msg_closeReceived(self, lst):
  196. channelID = lst[0]
  197. channel = self.channels[channelID]
  198. channel.remoteClosed = 1
  199. channel.closeReceived()
  200. def msg_closed(self, lst):
  201. channelID = lst[0]
  202. channel = self.channels[channelID]
  203. self.channelClosed(channel)
  204. def channelClosed(self, channel):
  205. channel.localClosed = channel.remoteClosed = 1
  206. del self.channels[channel.id]
  207. log.callWithLogger(channel, channel.closed)
  208. # just in case the user doesn't override
  209. def serviceStarted(self):
  210. pass
  211. def serviceStopped(self):
  212. pass
  213. class SSHUnixServerProtocol(SSHUnixProtocol):
  214. def __init__(self, conn):
  215. SSHUnixProtocol.__init__(self)
  216. self.isClient = 0
  217. self.conn = conn
  218. def connectionLost(self, reason):
  219. for channel in self.conn.channels.values():
  220. if isinstance(channel, SSHUnixChannel) and channel.unix == self:
  221. log.msg('forcibly closing %s' % channel)
  222. try:
  223. self.conn.sendClose(channel)
  224. except:
  225. pass
  226. def haveChannel(self, channelID):
  227. return self.conn.channels.has_key(channelID)
  228. def getChannel(self, channelID):
  229. channel = self.conn.channels[channelID]
  230. if not isinstance(channel, SSHUnixChannel):
  231. raise ConchError('nice try bub')
  232. return channel
  233. def msg_requestRemoteForwarding(self, lst):
  234. remotePort, hostport = lst
  235. hostport = tuple(hostport)
  236. self.conn.requestRemoteForwarding(remotePort, hostport)
  237. def msg_cancelRemoteForwarding(self, lst):
  238. [remotePort] = lst
  239. self.conn.cancelRemoteForwarding(remotePort)
  240. def msg_sendGlobalRequest(self, lst):
  241. requestName, data, wantReply = lst
  242. d = self.conn.sendGlobalRequest(requestName, data, wantReply)
  243. if wantReply:
  244. self.returnDeferredWire(d)
  245. def msg_openChannel(self, lst):
  246. name, windowSize, maxPacket, extra = lst
  247. channel = SSHUnixChannel(self, name, windowSize, maxPacket)
  248. self.conn.openChannel(channel, extra)
  249. self.sendMessage('channelID', channel.id)
  250. def msg_sendRequest(self, lst):
  251. cn, requestType, data, wantReply = lst
  252. if not self.haveChannel(cn):
  253. if wantReply:
  254. self.returnDeferredWire(defer.fail(ConchError("no channel")))
  255. channel = self.getChannel(cn)
  256. d = self.conn.sendRequest(channel, requestType, data, wantReply)
  257. if wantReply:
  258. self.returnDeferredWire(d)
  259. def msg_adjustWindow(self, lst):
  260. cn, bytesToAdd = lst
  261. if not self.haveChannel(cn): return
  262. channel = self.getChannel(cn)
  263. self.conn.adjustWindow(channel, bytesToAdd)
  264. def msg_sendData(self, lst):
  265. cn, data = lst
  266. if not self.haveChannel(cn): return
  267. channel = self.getChannel(cn)
  268. self.conn.sendData(channel, data)
  269. def msg_sendExtended(self, lst):
  270. cn, dataType, data = lst
  271. if not self.haveChannel(cn): return
  272. channel = self.getChannel(cn)
  273. self.conn.sendExtendedData(channel, dataType, data)
  274. def msg_sendEOF(self, lst):
  275. (cn, ) = lst
  276. if not self.haveChannel(cn): return
  277. channel = self.getChannel(cn)
  278. self.conn.sendEOF(channel)
  279. def msg_sendClose(self, lst):
  280. (cn, ) = lst
  281. if not self.haveChannel(cn): return
  282. channel = self.getChannel(cn)
  283. self.conn.sendClose(channel)
  284. class SSHUnixChannel(channel.SSHChannel):
  285. def __init__(self, unix, name, windowSize, maxPacket):
  286. channel.SSHChannel.__init__(self, windowSize, maxPacket, conn = unix.conn)
  287. self.unix = unix
  288. self.name = name
  289. def channelOpen(self, specificData):
  290. self.unix.sendMessage('channelOpen', self.id, self.remoteWindowLeft,
  291. self.remoteMaxPacket, specificData)
  292. def openFailed(self, reason):
  293. self.unix.sendMessage('openFailed', self.id, pickle.dumps(reason))
  294. def addWindowBytes(self, bytes):
  295. self.unix.sendMessage('addWindowBytes', self.id, bytes)
  296. def dataReceived(self, data):
  297. self.unix.sendMessage('dataReceived', self.id, data)
  298. def requestReceived(self, reqType, data):
  299. self.unix.sendMessage('requestReceived', self.id, reqType, data)
  300. return self.unix.returnDeferredLocal()
  301. def extReceived(self, dataType, data):
  302. self.unix.sendMessage('extReceived', self.id, dataType, data)
  303. def eofReceived(self):
  304. self.unix.sendMessage('eofReceived', self.id)
  305. def closeReceived(self):
  306. self.unix.sendMessage('closeReceived', self.id)
  307. def closed(self):
  308. self.unix.sendMessage('closed', self.id)
  309. def connect(host, port, options, verifyHostKey, userAuthObject):
  310. if options['nocache']:
  311. return defer.fail(ConchError('not using connection caching'))
  312. d = defer.Deferred()
  313. filename = os.path.expanduser("~/.conch-%s-%s-%i" % (userAuthObject.user, host, port))
  314. factory = SSHUnixClientFactory(d, options, userAuthObject)
  315. reactor.connectUNIX(filename, factory, timeout=2, checkPID=1)
  316. return d