PageRenderTime 60ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/mercurial/wireproto.py

https://bitbucket.org/mirror/mercurial/
Python | 863 lines | 739 code | 28 blank | 96 comment | 11 complexity | 3901550b13bd8b030ff58ca2e6d631cd MD5 | raw file
Possible License(s): GPL-2.0
  1. # wireproto.py - generic wire protocol support functions
  2. #
  3. # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
  4. #
  5. # This software may be used and distributed according to the terms of the
  6. # GNU General Public License version 2 or any later version.
  7. import urllib, tempfile, os, sys
  8. from i18n import _
  9. from node import bin, hex
  10. import changegroup as changegroupmod, bundle2, pushkey as pushkeymod
  11. import peer, error, encoding, util, store, exchange
  12. class abstractserverproto(object):
  13. """abstract class that summarizes the protocol API
  14. Used as reference and documentation.
  15. """
  16. def getargs(self, args):
  17. """return the value for arguments in <args>
  18. returns a list of values (same order as <args>)"""
  19. raise NotImplementedError()
  20. def getfile(self, fp):
  21. """write the whole content of a file into a file like object
  22. The file is in the form::
  23. (<chunk-size>\n<chunk>)+0\n
  24. chunk size is the ascii version of the int.
  25. """
  26. raise NotImplementedError()
  27. def redirect(self):
  28. """may setup interception for stdout and stderr
  29. See also the `restore` method."""
  30. raise NotImplementedError()
  31. # If the `redirect` function does install interception, the `restore`
  32. # function MUST be defined. If interception is not used, this function
  33. # MUST NOT be defined.
  34. #
  35. # left commented here on purpose
  36. #
  37. #def restore(self):
  38. # """reinstall previous stdout and stderr and return intercepted stdout
  39. # """
  40. # raise NotImplementedError()
  41. def groupchunks(self, cg):
  42. """return 4096 chunks from a changegroup object
  43. Some protocols may have compressed the contents."""
  44. raise NotImplementedError()
  45. # abstract batching support
  46. class future(object):
  47. '''placeholder for a value to be set later'''
  48. def set(self, value):
  49. if util.safehasattr(self, 'value'):
  50. raise error.RepoError("future is already set")
  51. self.value = value
  52. class batcher(object):
  53. '''base class for batches of commands submittable in a single request
  54. All methods invoked on instances of this class are simply queued and
  55. return a a future for the result. Once you call submit(), all the queued
  56. calls are performed and the results set in their respective futures.
  57. '''
  58. def __init__(self):
  59. self.calls = []
  60. def __getattr__(self, name):
  61. def call(*args, **opts):
  62. resref = future()
  63. self.calls.append((name, args, opts, resref,))
  64. return resref
  65. return call
  66. def submit(self):
  67. pass
  68. class localbatch(batcher):
  69. '''performs the queued calls directly'''
  70. def __init__(self, local):
  71. batcher.__init__(self)
  72. self.local = local
  73. def submit(self):
  74. for name, args, opts, resref in self.calls:
  75. resref.set(getattr(self.local, name)(*args, **opts))
  76. class remotebatch(batcher):
  77. '''batches the queued calls; uses as few roundtrips as possible'''
  78. def __init__(self, remote):
  79. '''remote must support _submitbatch(encbatch) and
  80. _submitone(op, encargs)'''
  81. batcher.__init__(self)
  82. self.remote = remote
  83. def submit(self):
  84. req, rsp = [], []
  85. for name, args, opts, resref in self.calls:
  86. mtd = getattr(self.remote, name)
  87. batchablefn = getattr(mtd, 'batchable', None)
  88. if batchablefn is not None:
  89. batchable = batchablefn(mtd.im_self, *args, **opts)
  90. encargsorres, encresref = batchable.next()
  91. if encresref:
  92. req.append((name, encargsorres,))
  93. rsp.append((batchable, encresref, resref,))
  94. else:
  95. resref.set(encargsorres)
  96. else:
  97. if req:
  98. self._submitreq(req, rsp)
  99. req, rsp = [], []
  100. resref.set(mtd(*args, **opts))
  101. if req:
  102. self._submitreq(req, rsp)
  103. def _submitreq(self, req, rsp):
  104. encresults = self.remote._submitbatch(req)
  105. for encres, r in zip(encresults, rsp):
  106. batchable, encresref, resref = r
  107. encresref.set(encres)
  108. resref.set(batchable.next())
  109. def batchable(f):
  110. '''annotation for batchable methods
  111. Such methods must implement a coroutine as follows:
  112. @batchable
  113. def sample(self, one, two=None):
  114. # Handle locally computable results first:
  115. if not one:
  116. yield "a local result", None
  117. # Build list of encoded arguments suitable for your wire protocol:
  118. encargs = [('one', encode(one),), ('two', encode(two),)]
  119. # Create future for injection of encoded result:
  120. encresref = future()
  121. # Return encoded arguments and future:
  122. yield encargs, encresref
  123. # Assuming the future to be filled with the result from the batched
  124. # request now. Decode it:
  125. yield decode(encresref.value)
  126. The decorator returns a function which wraps this coroutine as a plain
  127. method, but adds the original method as an attribute called "batchable",
  128. which is used by remotebatch to split the call into separate encoding and
  129. decoding phases.
  130. '''
  131. def plain(*args, **opts):
  132. batchable = f(*args, **opts)
  133. encargsorres, encresref = batchable.next()
  134. if not encresref:
  135. return encargsorres # a local result in this case
  136. self = args[0]
  137. encresref.set(self._submitone(f.func_name, encargsorres))
  138. return batchable.next()
  139. setattr(plain, 'batchable', f)
  140. return plain
  141. # list of nodes encoding / decoding
  142. def decodelist(l, sep=' '):
  143. if l:
  144. return map(bin, l.split(sep))
  145. return []
  146. def encodelist(l, sep=' '):
  147. return sep.join(map(hex, l))
  148. # batched call argument encoding
  149. def escapearg(plain):
  150. return (plain
  151. .replace(':', '::')
  152. .replace(',', ':,')
  153. .replace(';', ':;')
  154. .replace('=', ':='))
  155. def unescapearg(escaped):
  156. return (escaped
  157. .replace(':=', '=')
  158. .replace(':;', ';')
  159. .replace(':,', ',')
  160. .replace('::', ':'))
  161. # mapping of options accepted by getbundle and their types
  162. #
  163. # Meant to be extended by extensions. It is extensions responsibility to ensure
  164. # such options are properly processed in exchange.getbundle.
  165. #
  166. # supported types are:
  167. #
  168. # :nodes: list of binary nodes
  169. # :csv: list of comma-separated values
  170. # :plain: string with no transformation needed.
  171. gboptsmap = {'heads': 'nodes',
  172. 'common': 'nodes',
  173. 'bundlecaps': 'csv',
  174. 'listkeys': 'csv'}
  175. # client side
  176. class wirepeer(peer.peerrepository):
  177. def batch(self):
  178. return remotebatch(self)
  179. def _submitbatch(self, req):
  180. cmds = []
  181. for op, argsdict in req:
  182. args = ','.join('%s=%s' % p for p in argsdict.iteritems())
  183. cmds.append('%s %s' % (op, args))
  184. rsp = self._call("batch", cmds=';'.join(cmds))
  185. return rsp.split(';')
  186. def _submitone(self, op, args):
  187. return self._call(op, **args)
  188. @batchable
  189. def lookup(self, key):
  190. self.requirecap('lookup', _('look up remote revision'))
  191. f = future()
  192. yield {'key': encoding.fromlocal(key)}, f
  193. d = f.value
  194. success, data = d[:-1].split(" ", 1)
  195. if int(success):
  196. yield bin(data)
  197. self._abort(error.RepoError(data))
  198. @batchable
  199. def heads(self):
  200. f = future()
  201. yield {}, f
  202. d = f.value
  203. try:
  204. yield decodelist(d[:-1])
  205. except ValueError:
  206. self._abort(error.ResponseError(_("unexpected response:"), d))
  207. @batchable
  208. def known(self, nodes):
  209. f = future()
  210. yield {'nodes': encodelist(nodes)}, f
  211. d = f.value
  212. try:
  213. yield [bool(int(f)) for f in d]
  214. except ValueError:
  215. self._abort(error.ResponseError(_("unexpected response:"), d))
  216. @batchable
  217. def branchmap(self):
  218. f = future()
  219. yield {}, f
  220. d = f.value
  221. try:
  222. branchmap = {}
  223. for branchpart in d.splitlines():
  224. branchname, branchheads = branchpart.split(' ', 1)
  225. branchname = encoding.tolocal(urllib.unquote(branchname))
  226. branchheads = decodelist(branchheads)
  227. branchmap[branchname] = branchheads
  228. yield branchmap
  229. except TypeError:
  230. self._abort(error.ResponseError(_("unexpected response:"), d))
  231. def branches(self, nodes):
  232. n = encodelist(nodes)
  233. d = self._call("branches", nodes=n)
  234. try:
  235. br = [tuple(decodelist(b)) for b in d.splitlines()]
  236. return br
  237. except ValueError:
  238. self._abort(error.ResponseError(_("unexpected response:"), d))
  239. def between(self, pairs):
  240. batch = 8 # avoid giant requests
  241. r = []
  242. for i in xrange(0, len(pairs), batch):
  243. n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
  244. d = self._call("between", pairs=n)
  245. try:
  246. r.extend(l and decodelist(l) or [] for l in d.splitlines())
  247. except ValueError:
  248. self._abort(error.ResponseError(_("unexpected response:"), d))
  249. return r
  250. @batchable
  251. def pushkey(self, namespace, key, old, new):
  252. if not self.capable('pushkey'):
  253. yield False, None
  254. f = future()
  255. self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
  256. yield {'namespace': encoding.fromlocal(namespace),
  257. 'key': encoding.fromlocal(key),
  258. 'old': encoding.fromlocal(old),
  259. 'new': encoding.fromlocal(new)}, f
  260. d = f.value
  261. d, output = d.split('\n', 1)
  262. try:
  263. d = bool(int(d))
  264. except ValueError:
  265. raise error.ResponseError(
  266. _('push failed (unexpected response):'), d)
  267. for l in output.splitlines(True):
  268. self.ui.status(_('remote: '), l)
  269. yield d
  270. @batchable
  271. def listkeys(self, namespace):
  272. if not self.capable('pushkey'):
  273. yield {}, None
  274. f = future()
  275. self.ui.debug('preparing listkeys for "%s"\n' % namespace)
  276. yield {'namespace': encoding.fromlocal(namespace)}, f
  277. d = f.value
  278. yield pushkeymod.decodekeys(d)
  279. def stream_out(self):
  280. return self._callstream('stream_out')
  281. def changegroup(self, nodes, kind):
  282. n = encodelist(nodes)
  283. f = self._callcompressable("changegroup", roots=n)
  284. return changegroupmod.unbundle10(f, 'UN')
  285. def changegroupsubset(self, bases, heads, kind):
  286. self.requirecap('changegroupsubset', _('look up remote changes'))
  287. bases = encodelist(bases)
  288. heads = encodelist(heads)
  289. f = self._callcompressable("changegroupsubset",
  290. bases=bases, heads=heads)
  291. return changegroupmod.unbundle10(f, 'UN')
  292. def getbundle(self, source, **kwargs):
  293. self.requirecap('getbundle', _('look up remote changes'))
  294. opts = {}
  295. for key, value in kwargs.iteritems():
  296. if value is None:
  297. continue
  298. keytype = gboptsmap.get(key)
  299. if keytype is None:
  300. assert False, 'unexpected'
  301. elif keytype == 'nodes':
  302. value = encodelist(value)
  303. elif keytype == 'csv':
  304. value = ','.join(value)
  305. elif keytype != 'plain':
  306. raise KeyError('unknown getbundle option type %s'
  307. % keytype)
  308. opts[key] = value
  309. f = self._callcompressable("getbundle", **opts)
  310. bundlecaps = kwargs.get('bundlecaps')
  311. if bundlecaps is not None and 'HG2X' in bundlecaps:
  312. return bundle2.unbundle20(self.ui, f)
  313. else:
  314. return changegroupmod.unbundle10(f, 'UN')
  315. def unbundle(self, cg, heads, source):
  316. '''Send cg (a readable file-like object representing the
  317. changegroup to push, typically a chunkbuffer object) to the
  318. remote server as a bundle.
  319. When pushing a bundle10 stream, return an integer indicating the
  320. result of the push (see localrepository.addchangegroup()).
  321. When pushing a bundle20 stream, return a bundle20 stream.'''
  322. if heads != ['force'] and self.capable('unbundlehash'):
  323. heads = encodelist(['hashed',
  324. util.sha1(''.join(sorted(heads))).digest()])
  325. else:
  326. heads = encodelist(heads)
  327. if util.safehasattr(cg, 'deltaheader'):
  328. # this a bundle10, do the old style call sequence
  329. ret, output = self._callpush("unbundle", cg, heads=heads)
  330. if ret == "":
  331. raise error.ResponseError(
  332. _('push failed:'), output)
  333. try:
  334. ret = int(ret)
  335. except ValueError:
  336. raise error.ResponseError(
  337. _('push failed (unexpected response):'), ret)
  338. for l in output.splitlines(True):
  339. self.ui.status(_('remote: '), l)
  340. else:
  341. # bundle2 push. Send a stream, fetch a stream.
  342. stream = self._calltwowaystream('unbundle', cg, heads=heads)
  343. ret = bundle2.unbundle20(self.ui, stream)
  344. return ret
  345. def debugwireargs(self, one, two, three=None, four=None, five=None):
  346. # don't pass optional arguments left at their default value
  347. opts = {}
  348. if three is not None:
  349. opts['three'] = three
  350. if four is not None:
  351. opts['four'] = four
  352. return self._call('debugwireargs', one=one, two=two, **opts)
  353. def _call(self, cmd, **args):
  354. """execute <cmd> on the server
  355. The command is expected to return a simple string.
  356. returns the server reply as a string."""
  357. raise NotImplementedError()
  358. def _callstream(self, cmd, **args):
  359. """execute <cmd> on the server
  360. The command is expected to return a stream.
  361. returns the server reply as a file like object."""
  362. raise NotImplementedError()
  363. def _callcompressable(self, cmd, **args):
  364. """execute <cmd> on the server
  365. The command is expected to return a stream.
  366. The stream may have been compressed in some implementations. This
  367. function takes care of the decompression. This is the only difference
  368. with _callstream.
  369. returns the server reply as a file like object.
  370. """
  371. raise NotImplementedError()
  372. def _callpush(self, cmd, fp, **args):
  373. """execute a <cmd> on server
  374. The command is expected to be related to a push. Push has a special
  375. return method.
  376. returns the server reply as a (ret, output) tuple. ret is either
  377. empty (error) or a stringified int.
  378. """
  379. raise NotImplementedError()
  380. def _calltwowaystream(self, cmd, fp, **args):
  381. """execute <cmd> on server
  382. The command will send a stream to the server and get a stream in reply.
  383. """
  384. raise NotImplementedError()
  385. def _abort(self, exception):
  386. """clearly abort the wire protocol connection and raise the exception
  387. """
  388. raise NotImplementedError()
  389. # server side
  390. # wire protocol command can either return a string or one of these classes.
  391. class streamres(object):
  392. """wireproto reply: binary stream
  393. The call was successful and the result is a stream.
  394. Iterate on the `self.gen` attribute to retrieve chunks.
  395. """
  396. def __init__(self, gen):
  397. self.gen = gen
  398. class pushres(object):
  399. """wireproto reply: success with simple integer return
  400. The call was successful and returned an integer contained in `self.res`.
  401. """
  402. def __init__(self, res):
  403. self.res = res
  404. class pusherr(object):
  405. """wireproto reply: failure
  406. The call failed. The `self.res` attribute contains the error message.
  407. """
  408. def __init__(self, res):
  409. self.res = res
  410. class ooberror(object):
  411. """wireproto reply: failure of a batch of operation
  412. Something failed during a batch call. The error message is stored in
  413. `self.message`.
  414. """
  415. def __init__(self, message):
  416. self.message = message
  417. def dispatch(repo, proto, command):
  418. repo = repo.filtered("served")
  419. func, spec = commands[command]
  420. args = proto.getargs(spec)
  421. return func(repo, proto, *args)
  422. def options(cmd, keys, others):
  423. opts = {}
  424. for k in keys:
  425. if k in others:
  426. opts[k] = others[k]
  427. del others[k]
  428. if others:
  429. sys.stderr.write("warning: %s ignored unexpected arguments %s\n"
  430. % (cmd, ",".join(others)))
  431. return opts
  432. # list of commands
  433. commands = {}
  434. def wireprotocommand(name, args=''):
  435. """decorator for wire protocol command"""
  436. def register(func):
  437. commands[name] = (func, args)
  438. return func
  439. return register
  440. @wireprotocommand('batch', 'cmds *')
  441. def batch(repo, proto, cmds, others):
  442. repo = repo.filtered("served")
  443. res = []
  444. for pair in cmds.split(';'):
  445. op, args = pair.split(' ', 1)
  446. vals = {}
  447. for a in args.split(','):
  448. if a:
  449. n, v = a.split('=')
  450. vals[n] = unescapearg(v)
  451. func, spec = commands[op]
  452. if spec:
  453. keys = spec.split()
  454. data = {}
  455. for k in keys:
  456. if k == '*':
  457. star = {}
  458. for key in vals.keys():
  459. if key not in keys:
  460. star[key] = vals[key]
  461. data['*'] = star
  462. else:
  463. data[k] = vals[k]
  464. result = func(repo, proto, *[data[k] for k in keys])
  465. else:
  466. result = func(repo, proto)
  467. if isinstance(result, ooberror):
  468. return result
  469. res.append(escapearg(result))
  470. return ';'.join(res)
  471. @wireprotocommand('between', 'pairs')
  472. def between(repo, proto, pairs):
  473. pairs = [decodelist(p, '-') for p in pairs.split(" ")]
  474. r = []
  475. for b in repo.between(pairs):
  476. r.append(encodelist(b) + "\n")
  477. return "".join(r)
  478. @wireprotocommand('branchmap')
  479. def branchmap(repo, proto):
  480. branchmap = repo.branchmap()
  481. heads = []
  482. for branch, nodes in branchmap.iteritems():
  483. branchname = urllib.quote(encoding.fromlocal(branch))
  484. branchnodes = encodelist(nodes)
  485. heads.append('%s %s' % (branchname, branchnodes))
  486. return '\n'.join(heads)
  487. @wireprotocommand('branches', 'nodes')
  488. def branches(repo, proto, nodes):
  489. nodes = decodelist(nodes)
  490. r = []
  491. for b in repo.branches(nodes):
  492. r.append(encodelist(b) + "\n")
  493. return "".join(r)
  494. wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
  495. 'known', 'getbundle', 'unbundlehash', 'batch']
  496. def _capabilities(repo, proto):
  497. """return a list of capabilities for a repo
  498. This function exists to allow extensions to easily wrap capabilities
  499. computation
  500. - returns a lists: easy to alter
  501. - change done here will be propagated to both `capabilities` and `hello`
  502. command without any other action needed.
  503. """
  504. # copy to prevent modification of the global list
  505. caps = list(wireprotocaps)
  506. if _allowstream(repo.ui):
  507. if repo.ui.configbool('server', 'preferuncompressed', False):
  508. caps.append('stream-preferred')
  509. requiredformats = repo.requirements & repo.supportedformats
  510. # if our local revlogs are just revlogv1, add 'stream' cap
  511. if not requiredformats - set(('revlogv1',)):
  512. caps.append('stream')
  513. # otherwise, add 'streamreqs' detailing our local revlog format
  514. else:
  515. caps.append('streamreqs=%s' % ','.join(requiredformats))
  516. if repo.ui.configbool('experimental', 'bundle2-exp', False):
  517. capsblob = bundle2.encodecaps(repo.bundle2caps)
  518. caps.append('bundle2-exp=' + urllib.quote(capsblob))
  519. caps.append('unbundle=%s' % ','.join(changegroupmod.bundlepriority))
  520. caps.append('httpheader=1024')
  521. return caps
  522. # If you are writing an extension and consider wrapping this function. Wrap
  523. # `_capabilities` instead.
  524. @wireprotocommand('capabilities')
  525. def capabilities(repo, proto):
  526. return ' '.join(_capabilities(repo, proto))
  527. @wireprotocommand('changegroup', 'roots')
  528. def changegroup(repo, proto, roots):
  529. nodes = decodelist(roots)
  530. cg = changegroupmod.changegroup(repo, nodes, 'serve')
  531. return streamres(proto.groupchunks(cg))
  532. @wireprotocommand('changegroupsubset', 'bases heads')
  533. def changegroupsubset(repo, proto, bases, heads):
  534. bases = decodelist(bases)
  535. heads = decodelist(heads)
  536. cg = changegroupmod.changegroupsubset(repo, bases, heads, 'serve')
  537. return streamres(proto.groupchunks(cg))
  538. @wireprotocommand('debugwireargs', 'one two *')
  539. def debugwireargs(repo, proto, one, two, others):
  540. # only accept optional args from the known set
  541. opts = options('debugwireargs', ['three', 'four'], others)
  542. return repo.debugwireargs(one, two, **opts)
  543. # List of options accepted by getbundle.
  544. #
  545. # Meant to be extended by extensions. It is the extension's responsibility to
  546. # ensure such options are properly processed in exchange.getbundle.
  547. gboptslist = ['heads', 'common', 'bundlecaps']
  548. @wireprotocommand('getbundle', '*')
  549. def getbundle(repo, proto, others):
  550. opts = options('getbundle', gboptsmap.keys(), others)
  551. for k, v in opts.iteritems():
  552. keytype = gboptsmap[k]
  553. if keytype == 'nodes':
  554. opts[k] = decodelist(v)
  555. elif keytype == 'csv':
  556. opts[k] = set(v.split(','))
  557. elif keytype != 'plain':
  558. raise KeyError('unknown getbundle option type %s'
  559. % keytype)
  560. cg = exchange.getbundle(repo, 'serve', **opts)
  561. return streamres(proto.groupchunks(cg))
  562. @wireprotocommand('heads')
  563. def heads(repo, proto):
  564. h = repo.heads()
  565. return encodelist(h) + "\n"
  566. @wireprotocommand('hello')
  567. def hello(repo, proto):
  568. '''the hello command returns a set of lines describing various
  569. interesting things about the server, in an RFC822-like format.
  570. Currently the only one defined is "capabilities", which
  571. consists of a line in the form:
  572. capabilities: space separated list of tokens
  573. '''
  574. return "capabilities: %s\n" % (capabilities(repo, proto))
  575. @wireprotocommand('listkeys', 'namespace')
  576. def listkeys(repo, proto, namespace):
  577. d = repo.listkeys(encoding.tolocal(namespace)).items()
  578. return pushkeymod.encodekeys(d)
  579. @wireprotocommand('lookup', 'key')
  580. def lookup(repo, proto, key):
  581. try:
  582. k = encoding.tolocal(key)
  583. c = repo[k]
  584. r = c.hex()
  585. success = 1
  586. except Exception, inst:
  587. r = str(inst)
  588. success = 0
  589. return "%s %s\n" % (success, r)
  590. @wireprotocommand('known', 'nodes *')
  591. def known(repo, proto, nodes, others):
  592. return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
  593. @wireprotocommand('pushkey', 'namespace key old new')
  594. def pushkey(repo, proto, namespace, key, old, new):
  595. # compatibility with pre-1.8 clients which were accidentally
  596. # sending raw binary nodes rather than utf-8-encoded hex
  597. if len(new) == 20 and new.encode('string-escape') != new:
  598. # looks like it could be a binary node
  599. try:
  600. new.decode('utf-8')
  601. new = encoding.tolocal(new) # but cleanly decodes as UTF-8
  602. except UnicodeDecodeError:
  603. pass # binary, leave unmodified
  604. else:
  605. new = encoding.tolocal(new) # normal path
  606. if util.safehasattr(proto, 'restore'):
  607. proto.redirect()
  608. try:
  609. r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
  610. encoding.tolocal(old), new) or False
  611. except util.Abort:
  612. r = False
  613. output = proto.restore()
  614. return '%s\n%s' % (int(r), output)
  615. r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
  616. encoding.tolocal(old), new)
  617. return '%s\n' % int(r)
  618. def _allowstream(ui):
  619. return ui.configbool('server', 'uncompressed', True, untrusted=True)
  620. def _walkstreamfiles(repo):
  621. # this is it's own function so extensions can override it
  622. return repo.store.walk()
  623. @wireprotocommand('stream_out')
  624. def stream(repo, proto):
  625. '''If the server supports streaming clone, it advertises the "stream"
  626. capability with a value representing the version and flags of the repo
  627. it is serving. Client checks to see if it understands the format.
  628. The format is simple: the server writes out a line with the amount
  629. of files, then the total amount of bytes to be transferred (separated
  630. by a space). Then, for each file, the server first writes the filename
  631. and file size (separated by the null character), then the file contents.
  632. '''
  633. if not _allowstream(repo.ui):
  634. return '1\n'
  635. entries = []
  636. total_bytes = 0
  637. try:
  638. # get consistent snapshot of repo, lock during scan
  639. lock = repo.lock()
  640. try:
  641. repo.ui.debug('scanning\n')
  642. for name, ename, size in _walkstreamfiles(repo):
  643. if size:
  644. entries.append((name, size))
  645. total_bytes += size
  646. finally:
  647. lock.release()
  648. except error.LockError:
  649. return '2\n' # error: 2
  650. def streamer(repo, entries, total):
  651. '''stream out all metadata files in repository.'''
  652. yield '0\n' # success
  653. repo.ui.debug('%d files, %d bytes to transfer\n' %
  654. (len(entries), total_bytes))
  655. yield '%d %d\n' % (len(entries), total_bytes)
  656. sopener = repo.sopener
  657. oldaudit = sopener.mustaudit
  658. debugflag = repo.ui.debugflag
  659. sopener.mustaudit = False
  660. try:
  661. for name, size in entries:
  662. if debugflag:
  663. repo.ui.debug('sending %s (%d bytes)\n' % (name, size))
  664. # partially encode name over the wire for backwards compat
  665. yield '%s\0%d\n' % (store.encodedir(name), size)
  666. if size <= 65536:
  667. fp = sopener(name)
  668. try:
  669. data = fp.read(size)
  670. finally:
  671. fp.close()
  672. yield data
  673. else:
  674. for chunk in util.filechunkiter(sopener(name), limit=size):
  675. yield chunk
  676. # replace with "finally:" when support for python 2.4 has been dropped
  677. except Exception:
  678. sopener.mustaudit = oldaudit
  679. raise
  680. sopener.mustaudit = oldaudit
  681. return streamres(streamer(repo, entries, total_bytes))
@wireprotocommand('unbundle', 'heads')
def unbundle(repo, proto, heads):
    """apply a bundle pushed by the client

    `heads` is the encoded list of heads sent by the client; it is decoded
    and checked with exchange.check_heads before anything is read. The
    bundle payload is copied via proto.getfile into a temporary file (always
    closed and removed afterwards), then given to exchange.unbundle.

    Returns:
    - a `streamres` when exchange.unbundle returns a bundle2-like object
      (one with `addpart`), or when an error is reported as a bundle2 part;
    - a `pushres` wrapping the exchange.unbundle result on the bundle10
      path, or `pushres(0)` after a plain Abort (message written to stderr);
    - a `pusherr` when a PushRaced error occurs outside of bundle2.
    """
    their_heads = decodelist(heads)

    try:
        proto.redirect()

        exchange.check_heads(repo, their_heads, 'preparing changes')

        # write bundle data to temporary file because it can be big
        fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
        fp = os.fdopen(fd, 'wb+')
        r = 0
        try:
            proto.getfile(fp)
            # rewind so readbundle sees the data just written
            fp.seek(0)
            gen = exchange.readbundle(repo.ui, fp, None)
            r = exchange.unbundle(repo, gen, their_heads, 'serve',
                                  proto._client())
            if util.safehasattr(r, 'addpart'):
                # The return value looks streamable: we are in the bundle2
                # case and should return a stream.
                return streamres(r.getchunks())
            return pushres(r)

        finally:
            fp.close()
            os.unlink(tempname)
    except error.BundleValueError, exc:
        # report the unsupported content to the client as a bundle2 error part
        bundler = bundle2.bundle20(repo.ui)
        errpart = bundler.newpart('B2X:ERROR:UNSUPPORTEDCONTENT')
        if exc.parttype is not None:
            errpart.addparam('parttype', exc.parttype)
        if exc.params:
            errpart.addparam('params', '\0'.join(exc.params))
        return streamres(bundler.getchunks())
    except util.Abort, inst:
        # The old code we moved used sys.stderr directly.
        # We did not change it, to minimise code change.
        # This needs to be moved to something proper.
        # Feel free to do it.
        if getattr(inst, 'duringunbundle2', False):
            # abort raised while processing a bundle2: report it as a
            # bundle2 error part (with the hint, if any)
            bundler = bundle2.bundle20(repo.ui)
            manargs = [('message', str(inst))]
            advargs = []
            if inst.hint is not None:
                advargs.append(('hint', inst.hint))
            bundler.addpart(bundle2.bundlepart('B2X:ERROR:ABORT',
                                               manargs, advargs))
            return streamres(bundler.getchunks())
        else:
            sys.stderr.write("abort: %s\n" % inst)
            return pushres(0)
    except error.PushRaced, exc:
        if getattr(exc, 'duringunbundle2', False):
            bundler = bundle2.bundle20(repo.ui)
            bundler.newpart('B2X:ERROR:PUSHRACED', [('message', str(exc))])
            return streamres(bundler.getchunks())
        else:
            return pusherr(str(exc))