/sqlautocode/declarative.py
Python | 467 lines | 398 code | 44 blank | 25 comment | 87 complexity | f870f66308470b71140c69c8fcd18faa MD5 | raw file
import sys, re, inspect, operator
import logging
from util import emit, name2label, plural, singular
try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO

import sqlalchemy
from sqlalchemy import exc, and_
from sqlalchemy import MetaData, ForeignKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
try:
    from sqlalchemy.ext.declarative import _deferred_relationship
except ImportError:
    # SA 0.5 support
    from sqlalchemy.ext.declarative import _deferred_relation as _deferred_relationship

from sqlalchemy.orm import relation, backref, class_mapper, Mapper

try:
    # SA 0.5 support
    from sqlalchemy.orm import RelationProperty
except ImportError:
    # SA 0.7 support
    try:
        from sqlalchemy.orm.properties import RelationshipProperty, RelationProperty
    except ImportError:
        # Neither name is importable: generated classes will not list their
        # relation definitions (see the RelationProperty check in __repr__).
        RelationProperty = None


import config
import constants
from formatter import _repr_coltype_as, foreignkeyconstraint_repr

log = logging.getLogger('saac.decl')
log.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
log.addHandler(handler)

# Sort key for tables and columns (both expose ``.name``).
_name_key = operator.attrgetter('name')


def by_name(a, b):
    """cmp-style comparator ordering objects by their ``name`` attribute.

    Returns -1, 0 or 1, so equal names compare as equal.
    """
    return (a.name > b.name) - (a.name < b.name)


def by__name__(a, b):
    """cmp-style comparator ordering classes by their ``__name__``.

    Returns -1, 0 or 1, so equal names compare as equal.
    """
    return (a.__name__ > b.__name__) - (a.__name__ < b.__name__)


def column_repr(self):
    """Render a reflected Column as source text via ``constants.COLUMN``.

    Written as a free function taking the column as ``self``.  The template
    is filled with the keys ``name``, ``type``, ``constraints`` (a
    ``", ForeignKey('...')"`` list, or '') and ``args`` (a ``", k=v"`` list
    of the non-default keyword arguments, or '').
    """
    # Only emit keyword arguments that differ from Column's defaults.
    kwargs = []
    if self.key != self.name:
        kwargs.append(('key', self.key))
    if getattr(self, 'primary_key', False):
        # Normalised to a literal True in the generated source.
        kwargs.append(('primary_key', True))
    if not self.nullable:
        kwargs.append(('nullable', self.nullable))
    if self.onupdate:
        kwargs.append(('onupdate', self.onupdate))
    if self.default:
        kwargs.append(('default', self.default))
    args = ', '.join('%s=%r' % pair for pair in kwargs)

    # NOTE(review): config.options.generictypes is not consulted here.  The
    # declarative header only imports the PostgreSQL dialect, so emitting
    # dialect-specific types would produce unimportable code for other
    # databases.  Types are therefore always cast to a generic
    # sqlalchemy.types base when one exists.
    if type(self.type).__module__ == 'sqlalchemy.types':
        coltype = repr(self.type)
    else:
        # Try to 'cast' this column type to a cross-platform type from
        # sqlalchemy.types, dropping any database-specific type arguments.
        for base in type(self.type).__mro__:
            if (base.__module__ == 'sqlalchemy.types' and
                    base.__name__ in sqlalchemy.__all__):
                coltype = _repr_coltype_as(self.type, base)
                break
        else:
            # FIXME: a dialect type that does not derive from an ANSI type is
            # emitted verbatim; the import header lacks its dialect import.
            coltype = repr(self.type)

    constraints = ', '.join(["ForeignKey('%s')" % fk.target_fullname
                             for fk in self.foreign_keys])
    data = {'name': self.name,
            'type': coltype,
            'constraints': constraints and ', ' + constraints,
            'args': args and ', ' + args,
            }
    return constants.COLUMN % data


class ModelFactory(object):
    """Reflect a database and generate SQLAlchemy declarative source for it.

    ``repr(factory)`` returns the generated module text.  It contains a
    header, plain ``Table`` definitions for many-to-many tables and for
    tables without a primary key, and then one declarative class per
    remaining table.
    """

    def __init__(self, config):
        """Reflect the database described by ``config``.

        ``config.engine`` is bound to a new MetaData.  When ``config.schema``
        is set (a single name, or a list/tuple of names), each schema is
        reflected in turn; otherwise the default schema is reflected.
        """
        self.config = config
        self.used_model_names = []
        self.used_table_names = []
        schema = getattr(self.config, 'schema', None)
        self._metadata = MetaData(bind=config.engine)
        # Cache for get_foreign_keys: table -> {referenced table: [fks]}.
        self._foreign_keys = {}
        self.schemas = None
        if schema:
            if isinstance(schema, (list, tuple)):
                self.schemas = schema
            else:
                self.schemas = (schema, )
            for schema_name in self.schemas:
                log.info('Reflecting database... schema:%s', schema_name)
                self._metadata.reflect(schema=schema_name)
        else:
            log.info('Reflecting database...')
            self._metadata.reflect()

        self.DeclarativeBase = declarative_base(metadata=self._metadata)

    def _table_repr(self, table):
        """Return source text for a plain ``Table(...)`` definition of ``table``."""
        parts = ["Table(u'%s', metadata,\n" % table.name]
        parts.extend("    %s,\n" % column_repr(column) for column in table.c)
        if table.schema:
            parts.append("    schema='%s'\n" % table.schema)
        parts.append(")")
        return ''.join(parts)

    def __repr__(self):
        """Return the complete generated module source."""
        tables = self.get_many_to_many_tables()
        tables.extend(self.get_tables_with_no_pks())
        models = self.models

        s = StringIO()
        engine = self.config.engine
        if not isinstance(engine, basestring):
            engine = str(engine.url)
        s.write(constants.HEADER_DECL % engine)
        if 'postgres' in engine:
            s.write(constants.PG_IMPORT)

        # Table variable names are re-derived from scratch for this output.
        self.used_table_names = []
        self.used_model_names = []
        selected = self.tables
        for table in tables:
            if table not in selected:
                continue
            table_name = self.find_new_name(table.name, self.used_table_names)
            self.used_table_names.append(table_name)
            s.write('%s = %s\n\n' % (table_name, self._table_repr(table)))

        for model in models:
            # Temporal.__repr__ is a classmethod rendering the class source.
            s.write(model.__repr__())
            s.write("\n\n")

        # The example/interactive snippets reference the first model, so
        # they are skipped when no model could be generated.
        if models:
            first = models[0].__name__
            if self.config.example or self.config.interactive:
                s.write(constants.EXAMPLE_DECL % (first, first))
            if self.config.interactive:
                s.write(constants.INTERACTIVE % ([model.__name__ for model in models], first))
        return s.getvalue()

    @property
    def tables(self):
        """Reflected tables, restricted to ``config.options.tables`` when set."""
        wanted = self.config.options.tables
        if wanted:
            wanted = set(wanted)
            return [self._metadata.tables[t]
                    for t in set(self._metadata.tables.keys()).intersection(wanted)]
        return self._metadata.tables.values()

    @property
    def table_names(self):
        """Names of the tables returned by ``tables``."""
        return [t.name for t in self.tables]

    @property
    def models(self):
        """Declarative classes for all non-pure-join tables, sorted by name.

        Computed once and cached on ``self._models``.  Tables whose mapping
        raises ``sqlalchemy.exc.ArgumentError`` are skipped with a warning.
        """
        if hasattr(self, '_models'):
            return self._models
        self.used_model_names = []
        self.used_table_names = []
        self._models = []
        for table in self.get_non_many_to_many_tables():
            try:
                self._models.append(self.create_model(table))
            except exc.ArgumentError:
                log.warning("Table with name %s has no primary key. No ORM class created",
                            table.name)
        self._models.sort(key=operator.attrgetter('__name__'))
        return self._models

    def get_tables_with_no_pks(self):
        """Non-pure-join tables that have no primary-key column."""
        return [table for table in self.get_non_many_to_many_tables()
                if not any(c.primary_key for c in table.columns)]

    def model_table_lookup(self):
        """Return (and cache) a ``{table name: model class name}`` mapping."""
        if hasattr(self, '_model_table_lookup'):
            return self._model_table_lookup
        self._model_table_lookup = dict((m.__table__.name, m.__name__) for m in self.models)
        return self._model_table_lookup

    def find_new_name(self, prefix, used, i=0):
        """Return ``prefix`` (or ``prefix`` + number) not present in ``used``.

        With ``i == 0`` the bare prefix is tried first.  Otherwise numbered
        candidates ``prefix<i>``, ``prefix<i+1>``, ... are tried in order.
        The number is always appended to the original prefix, so with
        ``foo`` and ``foo1`` taken the result is ``foo2``.
        """
        candidate = prefix if i == 0 else "%s%d" % (prefix, i)
        while candidate in used:
            i += 1
            candidate = "%s%d" % (prefix, i)
        return candidate

    def create_model(self, table):
        """Build a declarative class mapped to ``table``.

        The class is renamed to a unique model name derived from the table
        name.  Its ``__repr__`` (a classmethod) renders its Python source.
        Relations are attached for single-column foreign keys, composite
        foreign keys and many-to-many join tables.  Exceptions raised by the
        declarative base while mapping propagate; ``models`` treats
        ``ArgumentError`` as "no primary key".
        """
        # partially borrowed from Jorge Vargas' code
        # http://dpaste.org/V6YS/
        log.debug('Creating Model from table: %s', table.name)

        model_name = self.find_new_name(singular(name2label(table.name)), self.used_model_names)
        self.used_model_names.append(model_name)
        is_many_to_many_table = self.is_many_to_many_table(table)
        table_name = self.find_new_name(table.name, self.used_table_names)
        self.used_table_names.append(table_name)

        # Called lazily at repr time, once every model exists.
        mtl = self.model_table_lookup

        class Temporal(self.DeclarativeBase):
            __table__ = table

            @classmethod
            def _relation_repr(cls, rel):
                """Render relation property ``rel`` as a ``relation(...)`` line."""
                target = rel.argument
                if target and inspect.isfunction(target):
                    target = target()
                if isinstance(target, Mapper):
                    target = target.class_
                target = target.__name__

                lookup = mtl()

                def qualified(col):
                    # Prefer the model class name.  Fall back to the Table's
                    # column collection ("<table>.c") for unmapped tables.
                    tname = col.table.name
                    return '%s.%s' % (lookup.get(tname, '%s.c' % tname), col.name)

                primaryjoin = ''
                pj = rel.primaryjoin
                if pj is not None and hasattr(pj, 'right'):
                    primaryjoin = ", primaryjoin='%s==%s'" % (qualified(pj.left),
                                                             qualified(pj.right))
                elif hasattr(rel, '_as_string'):
                    # Composite joins carry their pre-rendered and_(...) text.
                    primaryjoin = ', primaryjoin="%s"' % rel._as_string

                secondary = ''
                secondaryjoin = ''
                if rel.secondary is not None:
                    secondary = ", secondary=%s" % rel.secondary.name
                    sj = rel.secondaryjoin
                    secondaryjoin = ", secondaryjoin='%s==%s'" % (qualified(sj.left),
                                                                 qualified(sj.right))
                # Backrefs are intentionally not emitted.
                return "%s = relation('%s'%s%s%s)" % (rel.key, target, primaryjoin,
                                                     secondary, secondaryjoin)

            @classmethod
            def __repr__(cls):
                """Return the Python source of this model class ('' on failure)."""
                log.debug('repring class with name %s', cls.__name__)
                try:
                    mapper = None
                    try:
                        mapper = class_mapper(cls)
                    except exc.InvalidRequestError:
                        log.warning("A proper mapper could not be generated for the class %s, "
                                    "no relations will be created", model_name)
                    out = ["class " + model_name + '(DeclarativeBase):\n']
                    if is_many_to_many_table:
                        # Presumably refers to the module-level Table written by
                        # ModelFactory.__repr__ for many-to-many tables.
                        out.append("    __table__ = %s\n\n" % table_name)
                    else:
                        out.append("    __tablename__ = '%s'\n\n" % table_name)
                    if hasattr(cls, '__table_args__'):
                        # Wrapped in a 1-tuple so tuple-valued table args are
                        # rendered, not consumed as format arguments.
                        out.append("    __table_args__ = %s\n\n" % (cls.__table_args__,))
                    out.append("    #column definitions\n")
                    for column in sorted(cls.__table__.c, key=_name_key):
                        out.append("    %s = %s\n" % (column.name, column_repr(column)))
                    out.append("\n    #relation definitions\n")
                    # RelationProperty is the name bound on every supported SA
                    # import path (an alias of RelationshipProperty in 0.6/0.7).
                    if mapper is not None and RelationProperty is not None:
                        for prop in mapper.iterate_properties:
                            if isinstance(prop, RelationProperty):
                                out.append('    %s\n' % cls._relation_repr(prop))
                    return ''.join(out)
                except Exception:
                    log.exception("Could not generate class for: %s", cls.__name__)
                    return ''

        # hack the class to have the right classname
        Temporal.__name__ = model_name

        # set up some blank table args
        Temporal.__table_args__ = {}

        # trick sa's model registry to think the model is the correct name
        if model_name != 'Temporal':
            registry = Temporal._decl_class_registry
            registry[model_name] = registry['Temporal']
            del registry['Temporal']

        self._add_single_relations(Temporal, table)
        self._add_composite_relations(Temporal, table, model_name)
        self._add_many_to_many_relations(Temporal, table)

        return Temporal

    def _add_single_relations(self, model, table):
        """Attach one relation per single-column foreign key of ``table``.

        Targets outside ``self.tables`` are skipped.  The attribute is named
        after the referenced table.
        """
        selected = self.tables
        for column, fk in self.get_single_foreign_keys_by_column(table).iteritems():
            related_table = fk.column.table
            if related_table not in selected:
                continue
            log.info('    Adding <primary> foreign key for:%s', related_table.name)
            rel = relation(singular(name2label(related_table.name, related_table.schema)),
                           primaryjoin=column == fk.column)
            setattr(model, related_table.name, _deferred_relationship(model, rel))

    def _add_composite_relations(self, model, table, model_name):
        """Attach one relation per multi-column ForeignKeyConstraint of ``table``.

        The join condition is built as an ``and_(...)`` string.  It is also
        stored on the relation as ``_as_string`` so ``_relation_repr`` can
        render it.
        """
        for constraint in table.constraints:
            if not isinstance(constraint, ForeignKeyConstraint):
                continue
            if len(constraint.elements) <= 1:
                continue
            related_table = constraint.elements[0].column.table
            related_classname = singular(name2label(related_table.name, related_table.schema))
            primary_join = "and_(%s)" % ', '.join(["%s.%s==%s.%s" % (model_name,
                                                                     k.parent.name,
                                                                     related_classname,
                                                                     k.column.name)
                                                   for k in constraint.elements])
            rel = relation(related_classname, primaryjoin=primary_join)
            rel._as_string = primary_join
            setattr(model, related_table.name, rel)

    def _add_many_to_many_relations(self, model, table):
        """Attach a ``secondary=`` relation for each join table referencing ``table``.

        For each selected join table, the first foreign-key column pointing
        at a different selected table becomes the far side of the relation.
        The attribute is named with the plural of that table's name.
        """
        selected = self.tables
        for join_table in self.get_related_many_to_many_tables(table.name):
            if join_table not in selected:
                continue
            primary_column = None
            for c in join_table.columns:
                if c.foreign_keys and list(c.foreign_keys)[0].column.table == table:
                    primary_column = c
                    break
            if primary_column is None:
                # The join table was matched to a same-named table in another
                # schema (see get_table); nothing here references ``table``.
                continue

            for column in join_table.columns:
                if not column.foreign_keys:
                    continue
                key = list(column.foreign_keys)[0]
                if key.column.table is table:
                    continue
                related_column = key.column
                related_table = related_column.table
                if related_table not in selected:
                    continue
                log.info('    Adding <secondary> foreign key(%s) for:%s', key, related_table.name)
                rel = relation(singular(name2label(related_table.name, related_table.schema)),
                               secondary=join_table,
                               primaryjoin=list(primary_column.foreign_keys)[0].column == primary_column,
                               secondaryjoin=column == related_column)
                setattr(model, plural(related_table.name), _deferred_relationship(model, rel))
                break

    def get_table(self, name):
        """(name) -> sqlalchemy.schema.Table
        get the table definition with the given table name

        When schemas were reflected, ``<schema>.<name>`` is tried for each
        schema first; otherwise (or if none match) ``name`` is looked up
        directly, raising KeyError if absent.
        """
        if self.schemas:
            for schema in self.schemas:
                if schema and not name.startswith(schema):
                    table = self._metadata.tables.get('.'.join((schema, name)), None)
                    if table is not None:
                        return table
        return self._metadata.tables[name]

    def get_single_foreign_keys_by_column(self, table):
        """Map ``{local column: fk}`` for referenced tables joined by exactly one fk."""
        keys_by_column = {}
        for _target, keys in self.get_foreign_keys(table).iteritems():
            if len(keys) == 1:
                fk = keys[0]
                keys_by_column[fk.parent] = fk
        return keys_by_column

    def get_composite_foreign_keys(self, table):
        """List the fk groups of ``table`` that reference one table via several keys."""
        return [keys for keys in self.get_foreign_keys(table).values() if len(keys) > 1]

    def get_foreign_keys(self, table):
        """Group ``table``'s foreign keys by referenced table (cached per table)."""
        if table in self._foreign_keys:
            return self._foreign_keys[table]

        # group fks by table. I think this is needed because of a problem in
        # the sa reflection alg.
        grouped_fks = {}
        for key in table.foreign_keys:
            grouped_fks.setdefault(key.column.table, []).append(key)

        self._foreign_keys[table] = grouped_fks
        return grouped_fks

    def is_many_to_many_table(self, table):
        """True when ``table`` has single-column fks to at least two tables."""
        return len(self.get_single_foreign_keys_by_column(table)) >= 2

    def is_only_many_to_many_table(self, table):
        """True for a pure join table: exactly two single fks and two columns."""
        return len(self.get_single_foreign_keys_by_column(table)) == 2 and len(table.c) == 2

    def get_many_to_many_tables(self):
        """All reflected many-to-many tables, sorted by name (a fresh list)."""
        if not hasattr(self, '_many_to_many_tables'):
            self._many_to_many_tables = [table for table in self._metadata.tables.values()
                                         if self.is_many_to_many_table(table)]
        return sorted(self._many_to_many_tables, key=_name_key)

    def get_non_many_to_many_tables(self):
        """Selected tables that are not pure join tables, sorted by name."""
        tables = [table for table in self.tables
                  if not self.is_only_many_to_many_table(table)]
        return sorted(tables, key=_name_key)

    def get_related_many_to_many_tables(self, table_name):
        """Many-to-many tables with a foreign key pointing at ``table_name``.

        Only the first foreign key of each column is inspected.
        """
        src_table = self.get_table(table_name)
        tables = [table for table in self.get_many_to_many_tables()
                  if any(list(column.foreign_keys)[0].column.table is src_table
                         for column in table.columns if column.foreign_keys)]
        return sorted(tables, key=_name_key)