/core.py

https://github.com/GunioRobot/pyresto · Python · 202 lines · 154 code · 44 blank · 4 comment · 35 complexity · f836b0850cffa8ac7ce4d57481700f46 MD5 · raw file

  1. # coding: utf-8
  2. import httplib
  3. import json
  4. import logging
  5. from urllib import quote
  6. __all__ = ('Model', 'Many', 'Foreign')
  7. logging.getLogger().setLevel(logging.DEBUG)
  8. class ModelBase(type):
  9. def __new__(cls, name, bases, attrs):
  10. if name == 'Model':
  11. return super(ModelBase, cls).__new__(cls, name, bases, attrs)
  12. new_class = type.__new__(cls, name, bases, attrs)
  13. if not hasattr(new_class, '_path'):
  14. new_class._path = '/%s/%%(id)s' % quote(name.lower())
  15. conn_class = httplib.HTTPSConnection if new_class._secure else httplib.HTTPConnection
  16. new_class._connection = conn_class(new_class._host)
  17. return new_class
  18. class WrappedList(list):
  19. def __init__(self, iterable, wrapper):
  20. super(self.__class__, self).__init__(iterable)
  21. self.__wrapper = wrapper
  22. @staticmethod
  23. def is_dict(obj):
  24. return isinstance(obj, dict)
  25. def __getitem__(self, key):
  26. item = super(self.__class__, self).__getitem__(key)
  27. should_wrap = self.is_dict(item) or isinstance(key, slice) and any(map(self.is_dict, item))
  28. if should_wrap:
  29. item = map(self.__wrapper, item) if isinstance(key, slice) \
  30. else self.__wrapper(item)
  31. self[key] = item
  32. return item
  33. def __getslice__(self, i, j):
  34. items = super(self.__class__, self).__getslice__(i, j)
  35. if any(map(self.is_dict, items)):
  36. items = map(self.__wrapper, items)
  37. self[i:j] = items
  38. return items
  39. def __iter__(self):
  40. iterator = super(self.__class__, self).__iter__()
  41. return (self.__wrapper(item) for item in iterator)
  42. class Relation(object):
  43. pass
  44. class Many(Relation):
  45. def __init__(self, model, path=None):
  46. self.__model = model
  47. self.__path = path or model._path
  48. self.__cache = {}
  49. def _with_owner(self, owner):
  50. def mapper(data):
  51. if isinstance(data, dict):
  52. instance = self.__model(**data)
  53. #set auto fetching true for man fields which usually contain a summary
  54. instance._auto_fetch = True
  55. instance._owner = owner
  56. return instance
  57. elif isinstance(data, self.__model):
  58. return data
  59. return mapper
  60. def __get__(self, instance, owner):
  61. if not instance:
  62. return self.__model
  63. if instance not in self.__cache:
  64. model = self.__model
  65. if not instance:
  66. return model
  67. path_params = instance._get_id_dict()
  68. if hasattr(instance, '_get_params'):
  69. path_params.update(instance._get_params)
  70. path = self.__path % path_params
  71. logging.debug('Call many path: %s' % path)
  72. data = model._rest_call(method='GET', url=path) or []
  73. self.__cache[instance] = WrappedList(data, self._with_owner(instance))
  74. return self.__cache[instance]
  75. class Foreign(Relation):
  76. def __init__(self, model, key_extractor=None):
  77. self.__model = model
  78. model_name = model.__name__.lower()
  79. model_pk = model._pk
  80. self.__key_extractor = key_extractor if key_extractor else \
  81. lambda x:{model_pk: getattr(x, '__' + model_name)[model_pk]}
  82. self.__cache = {}
  83. def __get__(self, instance, owner):
  84. if not instance:
  85. return self.__model
  86. if instance not in self.__cache:
  87. keys = instance._get_id_dict()
  88. keys.update(self.__key_extractor(instance))
  89. logging.debug('Keys dict for foreign acccess: %s', str(keys))
  90. pk = keys.pop(self.__model._pk)
  91. self.__cache[instance] = self.__model.get(pk, **keys)
  92. return self.__cache[instance]
  93. class Model(object):
  94. __metaclass__ = ModelBase
  95. _secure = True
  96. _continuator = lambda x, y:None
  97. _parser = staticmethod(json.loads)
  98. _fetched = False
  99. _get_params = dict()
  100. def __init__(self, **kwargs):
  101. self.__dict__.update(kwargs)
  102. cls = self.__class__
  103. overlaps = set(cls.__dict__) & set(kwargs)
  104. #logging.debug('Found overlaps: %s', str(overlaps))
  105. for item in overlaps:
  106. if issubclass(getattr(cls, item), Model):
  107. self.__dict__['__' + item] = self.__dict__.pop(item)
  108. def _get_id_dict(self):
  109. ids = {}
  110. owner = self
  111. while owner:
  112. ids[owner.__class__.__name__.lower()] = getattr(owner, owner._pk)
  113. owner = getattr(owner, '_owner', None)
  114. return ids
  115. @classmethod
  116. def _rest_call(cls, **kwargs):
  117. conn = cls._connection
  118. try:
  119. conn.request(**kwargs)
  120. response = conn.getresponse()
  121. except Exception as e:
  122. #should call conn.close() on any error to allow further calls to be made
  123. logging.debug('httplib error: %s', e.__class__.__name__)
  124. conn.close()
  125. return None
  126. logging.debug('Response code: %s', response.status)
  127. if response.status == 200:
  128. continuation_url = cls._continuator(response)
  129. encoding = response.getheader('content-type', '').split('charset=')
  130. encoding = encoding[1] if len(encoding) > 1 else 'utf-8'
  131. data = cls._parser(unicode(response.read(), encoding, 'replace'))
  132. if continuation_url:
  133. logging.debug('Found more at: %s', continuation_url)
  134. kwargs['url'] = continuation_url
  135. data += cls._rest_call(**kwargs)
  136. return data
  137. def __fetch(self):
  138. path = self._path % self.__dict__
  139. data = self._rest_call(method='GET', url=path)
  140. if data:
  141. self.__dict__.update(data)
  142. self._fetched = True
  143. def __getattr__(self, name):
  144. if self._fetched:
  145. raise AttributeError
  146. self.__fetch()
  147. return getattr(self, name)
  148. @classmethod
  149. def get(cls, id, **kwargs):
  150. kwargs[cls._pk] = id
  151. path = cls._path % kwargs
  152. data = cls._rest_call(method='GET', url=path)
  153. if not data:
  154. return
  155. instance = cls(**data)
  156. instance._get_params = kwargs
  157. instance._fetched = True
  158. return instance