/sqlautocode/declarative.py

https://code.google.com/p/sqlautocode/ · Python · 467 lines · 356 code · 72 blank · 39 comment · 118 complexity · f870f66308470b71140c69c8fcd18faa MD5 · raw file

  1. import sys, re, inspect, operator
  2. import logging
  3. from util import emit, name2label, plural, singular
  4. try:
  5. from cStringIO import StringIO
  6. except ImportError:
  7. from StringIO import StringIO
  8. import sqlalchemy
  9. from sqlalchemy import exc, and_
  10. from sqlalchemy import MetaData, ForeignKeyConstraint
  11. from sqlalchemy.ext.declarative import declarative_base
  12. try:
  13. from sqlalchemy.ext.declarative import _deferred_relationship
  14. except ImportError:
  15. #SA 0.5 support
  16. from sqlalchemy.ext.declarative import _deferred_relation as _deferred_relationship
  17. from sqlalchemy.orm import relation, backref, class_mapper, Mapper
  18. try:
  19. #SA 0.5 support
  20. from sqlalchemy.orm import RelationProperty
  21. except ImportError:
  22. #SA 0.7 support
  23. try:
  24. from sqlalchemy.orm.properties import RelationshipProperty, RelationProperty
  25. except ImportError:
  26. RelationProperty = None
  27. import config
  28. import constants
  29. from formatter import _repr_coltype_as, foreignkeyconstraint_repr
  30. log = logging.getLogger('saac.decl')
  31. log.setLevel(logging.DEBUG)
  32. handler = logging.StreamHandler()
  33. formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
  34. handler.setFormatter(formatter)
  35. log.addHandler(handler)
  36. def by_name(a, b):
  37. if a.name>b.name:
  38. return 1
  39. return -1
  40. def by__name__(a, b):
  41. if a.__name__ > b.__name__:
  42. return 1
  43. return -1
  44. def column_repr(self):
  45. kwarg = []
  46. if self.key != self.name:
  47. kwarg.append( 'key')
  48. if hasattr(self, 'primary_key') and self.primary_key:
  49. self.primary_key = True
  50. kwarg.append( 'primary_key')
  51. if not self.nullable:
  52. kwarg.append( 'nullable')
  53. if self.onupdate:
  54. kwarg.append( 'onupdate')
  55. if self.default:
  56. kwarg.append( 'default')
  57. ks = ', '.join('%s=%r' % (k, getattr(self, k)) for k in kwarg)
  58. name = self.name
  59. if not hasattr(config, 'options') and self.config.options.generictypes:
  60. coltype = repr(self.type)
  61. elif type(self.type).__module__ == 'sqlalchemy.types':
  62. coltype = repr(self.type)
  63. else:
  64. # Try to 'cast' this column type to a cross-platform type
  65. # from sqlalchemy.types, dropping any database-specific type
  66. # arguments.
  67. for base in type(self.type).__mro__:
  68. if (base.__module__ == 'sqlalchemy.types' and
  69. base.__name__ in sqlalchemy.__all__):
  70. coltype = _repr_coltype_as(self.type, base)
  71. break
  72. # FIXME: if a dialect has a non-standard type that does not
  73. # derive from an ANSI type, there's no choice but to ignore
  74. # generic-types and output the exact type. However, import
  75. # headers have already been output and lack the required
  76. # dialect import.
  77. else:
  78. coltype = repr(self.type)
  79. data = {'name': self.name,
  80. 'type': coltype,
  81. 'constraints': ', '.join(["ForeignKey('%s')"%cn.target_fullname for cn in self.foreign_keys]),
  82. 'args': ks and ks or '',
  83. }
  84. if data['constraints']:
  85. if data['constraints']: data['constraints'] = ', ' + data['constraints']
  86. if data['args']:
  87. if data['args']: data['args'] = ', ' + data['args']
  88. return constants.COLUMN % data
  89. class ModelFactory(object):
  90. def __init__(self, config):
  91. self.config = config
  92. self.used_model_names = []
  93. self.used_table_names = []
  94. schema = getattr(self.config, 'schema', None)
  95. self._metadata = MetaData(bind=config.engine)
  96. self._foreign_keys = {}
  97. kw = {}
  98. self.schemas = None
  99. if schema:
  100. if isinstance(schema, (list, tuple)):
  101. self.schemas = schema
  102. else:
  103. self.schemas = (schema, )
  104. for schema in self.schemas:
  105. log.info('Reflecting database... schema:%s'%schema)
  106. self._metadata.reflect(schema=schema)
  107. else:
  108. log.info('Reflecting database...')
  109. self._metadata.reflect()
  110. self.DeclarativeBase = declarative_base(metadata=self._metadata)
  111. def _table_repr(self, table):
  112. s = "Table(u'%s', metadata,\n"%(table.name)
  113. for column in table.c:
  114. s += " %s,\n"%column_repr(column)
  115. if table.schema:
  116. s +=" schema='%s'\n"%table.schema
  117. s+=")"
  118. return s
  119. def __repr__(self):
  120. tables = self.get_many_to_many_tables()
  121. tables.extend(self.get_tables_with_no_pks())
  122. models = self.models
  123. s = StringIO()
  124. engine = self.config.engine
  125. if not isinstance(engine, basestring):
  126. engine = str(engine.url)
  127. s.write(constants.HEADER_DECL%engine)
  128. if 'postgres' in engine:
  129. s.write(constants.PG_IMPORT)
  130. self.used_table_names = []
  131. self.used_model_names = []
  132. for table in tables:
  133. if table not in self.tables:
  134. continue
  135. table_name = self.find_new_name(table.name, self.used_table_names)
  136. self.used_table_names.append(table_name)
  137. s.write('%s = %s\n\n'%(table_name, self._table_repr(table)))
  138. for model in models:
  139. s.write(model.__repr__())
  140. s.write("\n\n")
  141. if self.config.example or self.config.interactive:
  142. s.write(constants.EXAMPLE_DECL%(models[0].__name__,models[0].__name__))
  143. if self.config.interactive:
  144. s.write(constants.INTERACTIVE%([model.__name__ for model in models], models[0].__name__))
  145. return s.getvalue()
  146. @property
  147. def tables(self):
  148. if self.config.options.tables:
  149. tables = set(self.config.options.tables)
  150. return [self._metadata.tables[t] for t in set(self._metadata.tables.keys()).intersection(tables)]
  151. return self._metadata.tables.values()
  152. @property
  153. def table_names(self):
  154. return [t.name for t in self.tables]
  155. @property
  156. def models(self):
  157. if hasattr(self, '_models'):
  158. return self._models
  159. self.used_model_names = []
  160. self.used_table_names = []
  161. self._models = []
  162. for table in self.get_non_many_to_many_tables():
  163. try:
  164. self._models.append(self.create_model(table))
  165. except exc.ArgumentError:
  166. log.warning("Table with name %s ha no primary key. No ORM class created"%table.name)
  167. self._models.sort(by__name__)
  168. return self._models
  169. def get_tables_with_no_pks(self):
  170. r = []
  171. for table in self.get_non_many_to_many_tables():
  172. if not [c for c in table.columns if c.primary_key]:
  173. r.append(table)
  174. return r
  175. def model_table_lookup(self):
  176. if hasattr(self, '_model_table_lookup'):
  177. return self._model_table_lookup
  178. self._model_table_lookup = dict(((m.__table__.name, m.__name__) for m in self.models))
  179. return self._model_table_lookup
  180. def find_new_name(self, prefix, used, i=0):
  181. if i!=0:
  182. prefix = "%s%d"%(prefix, i)
  183. if prefix in used:
  184. prefix = prefix
  185. return self.find_new_name(prefix, used, i+1)
  186. return prefix
  187. def create_model(self, table):
  188. #partially borrowed from Jorge Vargas' code
  189. #http://dpaste.org/V6YS/
  190. log.debug('Creating Model from table: %s'%table.name)
  191. model_name = self.find_new_name(singular(name2label(table.name)), self.used_model_names)
  192. self.used_model_names.append(model_name)
  193. is_many_to_many_table = self.is_many_to_many_table(table)
  194. table_name = self.find_new_name(table.name, self.used_table_names)
  195. self.used_table_names.append(table_name)
  196. mtl = self.model_table_lookup
  197. class Temporal(self.DeclarativeBase):
  198. __table__ = table
  199. @classmethod
  200. def _relation_repr(cls, rel):
  201. target = rel.argument
  202. if target and inspect.isfunction(target):
  203. target = target()
  204. if isinstance(target, Mapper):
  205. target = target.class_
  206. target = target.__name__
  207. primaryjoin=''
  208. lookup = mtl()
  209. if rel.primaryjoin is not None and hasattr(rel.primaryjoin, 'right'):
  210. right_lookup = lookup.get(rel.primaryjoin.right.table.name, '%s.c'%rel.primaryjoin.right.table.name)
  211. left_lookup = lookup.get(rel.primaryjoin.left.table.name, '%s.c'%rel.primaryjoin.left.table.name)
  212. primaryjoin = ", primaryjoin='%s.%s==%s.%s'"%(left_lookup,
  213. rel.primaryjoin.left.name,
  214. right_lookup,
  215. rel.primaryjoin.right.name)
  216. elif hasattr(rel, '_as_string'):
  217. primaryjoin=', primaryjoin="%s"'%rel._as_string
  218. secondary = ''
  219. secondaryjoin = ''
  220. if rel.secondary is not None:
  221. secondary = ", secondary=%s"%rel.secondary.name
  222. right_lookup = lookup.get(rel.secondaryjoin.right.table.name, '%s.c'%rel.secondaryjoin.right.table.name)
  223. left_lookup = lookup.get(rel.secondaryjoin.left.table.name, '%s.c'%rel.secondaryjoin.left.table.name)
  224. secondaryjoin = ", secondaryjoin='%s.%s==%s.%s'"%(left_lookup,
  225. rel.secondaryjoin.left.name,
  226. right_lookup,
  227. rel.secondaryjoin.right.name)
  228. backref=''
  229. # if rel.backref:
  230. # backref=", backref='%s'"%rel.backref.key
  231. return "%s = relation('%s'%s%s%s%s)"%(rel.key, target, primaryjoin, secondary, secondaryjoin, backref)
  232. @classmethod
  233. def __repr__(cls):
  234. log.debug('repring class with name %s'%cls.__name__)
  235. try:
  236. mapper = None
  237. try:
  238. mapper = class_mapper(cls)
  239. except exc.InvalidRequestError:
  240. log.warn("A proper mapper could not be generated for the class %s, no relations will be created"%model_name)
  241. s = ""
  242. s += "class "+model_name+'(DeclarativeBase):\n'
  243. if is_many_to_many_table:
  244. s += " __table__ = %s\n\n"%table_name
  245. else:
  246. s += " __tablename__ = '%s'\n\n"%table_name
  247. if hasattr(cls, '__table_args__'):
  248. #if cls.__table_args__[0]:
  249. #for fkc in cls.__table_args__[0]:
  250. # fkc.__class__.__repr__ = foreignkeyconstraint_repr
  251. # break
  252. s+=" __table_args__ = %s\n\n"%cls.__table_args__
  253. s += " #column definitions\n"
  254. for column in sorted(cls.__table__.c, by_name):
  255. s += " %s = %s\n"%(column.name, column_repr(column))
  256. s += "\n #relation definitions\n"
  257. ess = s
  258. # this is only required in SA 0.5
  259. if mapper and RelationProperty:
  260. for prop in mapper.iterate_properties:
  261. if isinstance(prop, RelationshipProperty):
  262. s+=' %s\n'%cls._relation_repr(prop)
  263. return s
  264. except Exception, e:
  265. log.error("Could not generate class for: %s"%cls.__name__)
  266. from traceback import format_exc
  267. log.error(format_exc())
  268. return ''
  269. #hack the class to have the right classname
  270. Temporal.__name__ = model_name
  271. #set up some blank table args
  272. Temporal.__table_args__ = {}
  273. #add in the schema
  274. #if self.config.schema:
  275. #Temporal.__table_args__[1]['schema'] = table.schema
  276. #trick sa's model registry to think the model is the correct name
  277. if model_name != 'Temporal':
  278. Temporal._decl_class_registry[model_name] = Temporal._decl_class_registry['Temporal']
  279. del Temporal._decl_class_registry['Temporal']
  280. #add in single relations
  281. fks = self.get_single_foreign_keys_by_column(table)
  282. for column, fk in fks.iteritems():
  283. related_table = fk.column.table
  284. if related_table not in self.tables:
  285. continue
  286. log.info(' Adding <primary> foreign key for:%s'%related_table.name)
  287. backref_name = plural(table_name)
  288. rel = relation(singular(name2label(related_table.name, related_table.schema)),
  289. primaryjoin=column==fk.column)#, backref=backref_name)
  290. setattr(Temporal, related_table.name, _deferred_relationship(Temporal, rel))
  291. #add in the relations for the composites
  292. for constraint in table.constraints:
  293. if isinstance(constraint, ForeignKeyConstraint):
  294. if len(constraint.elements) >1:
  295. related_table = constraint.elements[0].column.table
  296. related_classname = singular(name2label(related_table.name, related_table.schema))
  297. primary_join = "and_(%s)"%', '.join(["%s.%s==%s.%s"%(model_name,
  298. k.parent.name,
  299. related_classname,
  300. k.column.name)
  301. for k in constraint.elements])
  302. rel = relation(related_classname,
  303. primaryjoin=primary_join
  304. # foreign_keys=[k.parent for k in constraint.elements]
  305. )
  306. rel._as_string = primary_join
  307. setattr(Temporal, related_table.name, rel) # _deferred_relationship(Temporal, rel))
  308. #add in many-to-many relations
  309. for join_table in self.get_related_many_to_many_tables(table.name):
  310. if join_table not in self.tables:
  311. continue
  312. primary_column = [c for c in join_table.columns if c.foreign_keys and list(c.foreign_keys)[0].column.table==table][0]
  313. for column in join_table.columns:
  314. if column.foreign_keys:
  315. key = list(column.foreign_keys)[0]
  316. if key.column.table is not table:
  317. related_column = related_table = list(column.foreign_keys)[0].column
  318. related_table = related_column.table
  319. if related_table not in self.tables:
  320. continue
  321. log.info(' Adding <secondary> foreign key(%s) for:%s'%(key, related_table.name))
  322. setattr(Temporal, plural(related_table.name), _deferred_relationship(Temporal,
  323. relation(singular(name2label(related_table.name,
  324. related_table.schema)),
  325. secondary=join_table,
  326. primaryjoin=list(primary_column.foreign_keys)[0].column==primary_column,
  327. secondaryjoin=column==related_column
  328. )))
  329. break;
  330. return Temporal
  331. def get_table(self, name):
  332. """(name) -> sqlalchemy.schema.Table
  333. get the table definition with the given table name
  334. """
  335. if self.schemas:
  336. for schema in self.schemas:
  337. if schema and not name.startswith(schema):
  338. new_name = '.'.join((schema, name))
  339. table = self._metadata.tables.get(new_name, None)
  340. if table is not None:
  341. return table
  342. return self._metadata.tables[name]
  343. def get_single_foreign_keys_by_column(self, table):
  344. keys_by_column = {}
  345. fks = self.get_foreign_keys(table)
  346. for table, keys in fks.iteritems():
  347. if len(keys) == 1:
  348. fk = keys[0]
  349. keys_by_column[fk.parent] = fk
  350. return keys_by_column
  351. def get_composite_foreign_keys(self, table):
  352. l = []
  353. fks = self.get_foreign_keys(table)
  354. for table, keys in fks.iteritems():
  355. if len(keys)>1:
  356. l.append(keys)
  357. return l
  358. def get_foreign_keys(self, table):
  359. if table in self._foreign_keys:
  360. return self._foreign_keys[table]
  361. fks = table.foreign_keys
  362. #group fks by table. I think this is needed because of a problem in the sa reflection alg.
  363. grouped_fks = {}
  364. for key in fks:
  365. grouped_fks.setdefault(key.column.table, []).append(key)
  366. self._foreign_keys[table] = grouped_fks
  367. return grouped_fks
  368. # fks = {}
  369. # for column in table.columns:
  370. # if len(column.foreign_keys)>0:
  371. # fks.setdefault(column.name, []).extend(column.foreign_keys)
  372. # return fks
  373. def is_many_to_many_table(self, table):
  374. fks = self.get_single_foreign_keys_by_column(table).values()
  375. return len(fks) >= 2
  376. def is_only_many_to_many_table(self, table):
  377. return len(self.get_single_foreign_keys_by_column(table)) == 2 and len(table.c) == 2
  378. def get_many_to_many_tables(self):
  379. if not hasattr(self, '_many_to_many_tables'):
  380. self._many_to_many_tables = [table for table in self._metadata.tables.values() if self.is_many_to_many_table(table)]
  381. return sorted(self._many_to_many_tables, by_name)
  382. def get_non_many_to_many_tables(self):
  383. tables = [table for table in self.tables if not(self.is_only_many_to_many_table(table))]
  384. return sorted(tables, by_name)
  385. def get_related_many_to_many_tables(self, table_name):
  386. tables = []
  387. src_table = self.get_table(table_name)
  388. for table in self.get_many_to_many_tables():
  389. for column in table.columns:
  390. if column.foreign_keys:
  391. key = list(column.foreign_keys)[0]
  392. if key.column.table is src_table:
  393. tables.append(table)
  394. break
  395. return sorted(tables, by_name)