/neatx/lib/auth.py

http://neatx.googlecode.com/ · Python · 249 lines · 138 code · 55 blank · 56 comment · 17 complexity · 58d67d11372d1c956c5df7331b5fff51 MD5 · raw file

  1. #
  2. #
  3. # Copyright (C) 2009 Google Inc.
  4. #
  5. # This program is free software; you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation; either version 2 of the License, or
  8. # (at your option) any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful, but
  11. # WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  13. # General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program; if not, write to the Free Software
  17. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
  18. # 02110-1301, USA.
  19. """Module for authentication"""
  20. import logging
  21. import os
  22. import pexpect
  23. import re
  24. from cStringIO import StringIO
  25. from neatx import constants
  26. from neatx import errors
  27. from neatx import utils
  28. class _AuthBase(object):
  29. def __init__(self, cfg,
  30. stdout_fileno=constants.STDOUT_FILENO,
  31. stdin_fileno=constants.STDIN_FILENO):
  32. self._cfg = cfg
  33. self._stdout_fileno = stdout_fileno
  34. self._stdin_fileno = stdin_fileno
  35. def AuthenticateAndRun(self, username, password, args):
  36. raise NotImplementedError()
  37. class _ExpectAuthBase(_AuthBase):
  38. def AuthenticateAndRun(self, username, password, args):
  39. logging.debug("Authenticating as '%s', running %r", username, args)
  40. all_args = [self._GetTtySetupPath()] + self.GetCommand(username, args)
  41. logging.debug("Auth command %r", all_args)
  42. # Avoid NLS issues by unsetting LC_*, and setting LANG=C
  43. env = os.environ.copy()
  44. env["LANG"] = "C"
  45. for key in env.keys():
  46. if key.startswith('LC_'):
  47. del env[key]
  48. # Using variables instead of hardcoded indexes
  49. patterns = []
  50. password_prompt_idx = self._AddPattern(patterns,
  51. self.GetPasswordPrompt())
  52. nx_idx = self._AddPattern(patterns, re.compile("^NX> ", re.M))
  53. # Start child process
  54. # TODO: Timeout in configuration and/or per auth method
  55. child = pexpect.spawn(all_args[0], args=all_args[1:], env=env,
  56. timeout=30)
  57. buf = StringIO()
  58. nxbuf = StringIO()
  59. auth_successful = False
  60. try:
  61. while True:
  62. idx = child.expect(patterns)
  63. # Store all output seen before the match
  64. buf.write(child.before)
  65. # Store the matched output
  66. buf.write(child.after)
  67. if idx == password_prompt_idx:
  68. self._Send(child, password + os.linesep)
  69. # Wait for end of password prompt
  70. child.expect(os.linesep)
  71. # TODO: Timeout for programs not printing NX prompt within X seconds
  72. elif idx == nx_idx:
  73. # Program was started
  74. auth_successful = True
  75. nxbuf.write(child.after)
  76. nxbuf.write(child.buffer)
  77. break
  78. else:
  79. raise AssertionError("Invalid index")
  80. except pexpect.EOF:
  81. buf.write(child.before)
  82. except pexpect.TIMEOUT:
  83. buf.write(child.before)
  84. logging.debug("Authentication timed out (output=%r)", buf.getvalue())
  85. raise errors.AuthTimeoutError()
  86. if not auth_successful:
  87. raise errors.AuthFailedError(("Authentication failed (output=%r, "
  88. "exitstatus=%s, signum=%s)") %
  89. (utils.NormalizeSpace(buf.getvalue()),
  90. child.exitstatus, child.signalstatus))
  91. # Write protocol buffer contents to stdout
  92. os.write(self._stdout_fileno, nxbuf.getvalue())
  93. utils.SetCloseOnExecFlag(child.fileno(), False)
  94. utils.SetCloseOnExecFlag(self._stdin_fileno, False)
  95. utils.SetCloseOnExecFlag(self._stdout_fileno, False)
  96. cpargs = [self._GetFdCopyPath(),
  97. "%s:%s" % (child.fileno(), self._stdout_fileno),
  98. "%s:%s" % (self._stdin_fileno, child.fileno())]
  99. # Run fdcopy to copy data between file descriptors
  100. ret = os.spawnve(os.P_WAIT, cpargs[0], cpargs, env)
  101. (exitcode, signum) = utils.GetExitcodeSignal(ret)
  102. logging.debug("fdcopy exited (exitstatus=%s, signum=%s)",
  103. exitcode, signum)
  104. # Discard anything left in buffer
  105. child.read()
  106. def _CheckChild():
  107. if child.isalive():
  108. raise utils.RetryAgain()
  109. logging.info("Waiting for authenticated program to finish")
  110. try:
  111. utils.Retry(_CheckChild, 0.5, 1.1, 5.0, 30)
  112. except utils.RetryTimeout:
  113. logging.error("Timeout while waiting for authenticated program "
  114. "to finish")
  115. child.close()
  116. logging.debug(("Authenticated program finished (exitstatus=%s, "
  117. "signalstatus=%s)"), child.exitstatus, child.signalstatus)
  118. def _GetFdCopyPath(self):
  119. return constants.FDCOPY
  120. def _GetTtySetupPath(self):
  121. return constants.TTYSETUP
  122. @staticmethod
  123. def _Send(child, text):
  124. """Write password to child program.
  125. """
  126. # child.send may not write everything in one go
  127. pos = 0
  128. while True:
  129. pos += child.send(text[pos:])
  130. if pos >= len(text):
  131. break
  132. @staticmethod
  133. def _AddPattern(patterns, pattern):
  134. """Adds pattern to list and returns new index.
  135. """
  136. patterns.append(pattern)
  137. return len(patterns) - 1
  138. class SuAuth(_ExpectAuthBase):
  139. def GetCommand(self, username, args):
  140. cmd = " && ".join([
  141. # Change to home directory
  142. "cd",
  143. # Run command
  144. utils.ShellQuoteArgs(args)
  145. ])
  146. return [self._cfg.su, username, "-c", cmd]
  147. def GetPasswordPrompt(self):
  148. return re.compile(r"^(\S+\s)?Password:\s*", re.I | re.M)
  149. class SshAuth(_ExpectAuthBase):
  150. def GetCommand(self, username, args):
  151. # TODO: Allow for per-user hostname. A very flexible way would be to run an
  152. # external script (e.g. "/.../userhost $username"), and let it print the
  153. # target hostname on stdout. If the hostname is an absolute path it could
  154. # be used as the script.
  155. host = self._cfg.auth_ssh_host
  156. port = self._cfg.auth_ssh_port
  157. options = [
  158. "-oNumberOfPasswordPrompts=1",
  159. "-oPreferredAuthentications=password",
  160. "-oEscapeChar=none",
  161. "-oCompression=no",
  162. # Always trust host keys
  163. "-oStrictHostKeyChecking=no",
  164. # Don't try to write a known_hosts file
  165. "-oUserKnownHostsFile=/dev/null",
  166. ]
  167. cmd = utils.ShellQuoteArgs(args)
  168. return ([self._cfg.ssh, "-2", "-x", "-l", username, "-p", str(port)] +
  169. options + [host, "--", cmd])
  170. def GetPasswordPrompt(self):
  171. return re.compile(r"^.*@.*\s+password:\s*", re.I | re.M)
  172. _AUTH_METHOD_MAP = {
  173. constants.AUTH_METHOD_SU: SuAuth,
  174. constants.AUTH_METHOD_SSH: SshAuth,
  175. }
  176. def GetAuthenticator(cfg, _method_map=_AUTH_METHOD_MAP):
  177. """Returns the authenticator for an authentication method.
  178. @type cfg: L{config.Config}
  179. @param cfg: Configuration object
  180. @rtype: class
  181. @return: Authentication class
  182. @raise errors.UnknownAuthMethod: Raised when an unknown authentication method
  183. is requested
  184. """
  185. method = cfg.auth_method
  186. try:
  187. cls = _method_map[method]
  188. except KeyError:
  189. raise errors.UnknownAuthMethod("Unknown authentication method %r" % method)
  190. return cls(cfg)