PageRenderTime 431ms CodeModel.GetById 101ms app.highlight 214ms RepoModel.GetById 108ms app.codeStats 0ms

/sqlautocode/declarative.py

https://code.google.com/p/sqlautocode/
Python | 467 lines | 398 code | 44 blank | 25 comment | 87 complexity | f870f66308470b71140c69c8fcd18faa MD5 | raw file
  1import sys, re, inspect, operator
  2import logging
  3from util import emit, name2label, plural, singular
  4try:
  5    from cStringIO import StringIO
  6except ImportError:
  7    from StringIO import StringIO
  8
  9import sqlalchemy
 10from sqlalchemy import exc, and_
 11from sqlalchemy import MetaData, ForeignKeyConstraint
 12from sqlalchemy.ext.declarative import declarative_base
 13try:
 14    from sqlalchemy.ext.declarative import _deferred_relationship
 15except ImportError:
 16    #SA 0.5 support
 17    from sqlalchemy.ext.declarative import _deferred_relation as _deferred_relationship
 18    
 19from sqlalchemy.orm import relation, backref, class_mapper, Mapper
 20
 21try:
 22    #SA 0.5 support
 23    from sqlalchemy.orm import RelationProperty
 24except ImportError:
 25    #SA 0.7 support
 26    try:
 27        from sqlalchemy.orm.properties import RelationshipProperty, RelationProperty
 28    except ImportError:
 29        RelationProperty = None
 30
 31
 32import config
 33import constants
 34from formatter import _repr_coltype_as, foreignkeyconstraint_repr
 35
 36log = logging.getLogger('saac.decl')
 37log.setLevel(logging.DEBUG)
 38handler = logging.StreamHandler()
 39formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
 40handler.setFormatter(formatter)
 41log.addHandler(handler)
 42
 43def by_name(a, b):
 44    if a.name>b.name:
 45        return 1
 46    return -1
 47def by__name__(a, b):
 48    if a.__name__ > b.__name__:
 49        return 1
 50    return -1
 51
 52def column_repr(self):
 53
 54    kwarg = []
 55    if self.key != self.name:
 56        kwarg.append( 'key')
 57
 58    if hasattr(self, 'primary_key') and self.primary_key:
 59        self.primary_key = True
 60        kwarg.append( 'primary_key')
 61
 62    if not self.nullable:
 63        kwarg.append( 'nullable')
 64    if self.onupdate:
 65        kwarg.append( 'onupdate')
 66    if self.default:
 67        kwarg.append( 'default')
 68    ks = ', '.join('%s=%r' % (k, getattr(self, k)) for k in kwarg)
 69
 70    name = self.name
 71
 72    if not hasattr(config, 'options') and self.config.options.generictypes:
 73        coltype = repr(self.type)
 74    elif type(self.type).__module__ == 'sqlalchemy.types':
 75        coltype = repr(self.type)
 76    else:
 77        # Try to 'cast' this column type to a cross-platform type
 78        # from sqlalchemy.types, dropping any database-specific type
 79        # arguments.
 80        for base in type(self.type).__mro__:
 81            if (base.__module__ == 'sqlalchemy.types' and
 82                base.__name__ in sqlalchemy.__all__):
 83                coltype = _repr_coltype_as(self.type, base)
 84                break
 85        # FIXME: if a dialect has a non-standard type that does not
 86        # derive from an ANSI type, there's no choice but to ignore
 87        # generic-types and output the exact type. However, import
 88        # headers have already been output and lack the required
 89        # dialect import.
 90        else:
 91            coltype = repr(self.type)
 92
 93    data = {'name': self.name,
 94            'type': coltype,
 95            'constraints': ', '.join(["ForeignKey('%s')"%cn.target_fullname for cn in self.foreign_keys]),
 96            'args': ks and ks or '',
 97            }
 98
 99    if data['constraints']:
100        if data['constraints']: data['constraints'] = ', ' + data['constraints']
101    if data['args']:
102        if data['args']: data['args'] = ', ' + data['args']
103
104    return constants.COLUMN % data
105
class ModelFactory(object):
    """Reflect a database and build SQLAlchemy declarative models for it.

    ``repr()`` of an instance is the generated Python module source: a
    header, ``Table(...)`` definitions for many-to-many and primary-key-less
    tables, then one declarative class per model.
    """

    def __init__(self, config):
        """Reflect the database described by *config*.

        :param config: object exposing ``engine`` and optionally ``schema``
            (one schema name, or a list/tuple of names).  ``options``,
            ``example`` and ``interactive`` are read later by other methods.
        """
        self.config = config
        self.used_model_names = []
        self.used_table_names = []
        schema = getattr(self.config, 'schema', None)
        self._metadata = MetaData(bind=config.engine)
        self._foreign_keys = {}
        self.schemas = None
        if schema:
            if isinstance(schema, (list, tuple)):
                self.schemas = schema
            else:
                self.schemas = (schema, )
            for schema_name in self.schemas:
                log.info('Reflecting database... schema:%s' % schema_name)
                self._metadata.reflect(schema=schema_name)
        else:
            log.info('Reflecting database...')
            self._metadata.reflect()

        self.DeclarativeBase = declarative_base(metadata=self._metadata)

    def _table_repr(self, table):
        """Return source text for a ``Table(...)`` definition of *table*."""
        s = "Table(u'%s', metadata,\n" % table.name
        for column in table.c:
            s += "    %s,\n" % column_repr(column)
        if table.schema:
            s += "    schema='%s'\n" % table.schema
        s += ")"
        return s

    def __repr__(self):
        """Return the generated module source for the reflected database."""
        tables = self.get_many_to_many_tables()
        # A many-to-many table without a primary key is in both lists; emit
        # its Table definition only once.
        for table in self.get_tables_with_no_pks():
            if table not in tables:
                tables.append(table)
        models = self.models

        s = StringIO()
        engine = self.config.engine
        if not isinstance(engine, basestring):
            engine = str(engine.url)
        s.write(constants.HEADER_DECL % engine)
        if 'postgres' in engine:
            s.write(constants.PG_IMPORT)

        self.used_table_names = []
        self.used_model_names = []
        visible_tables = self.tables
        for table in tables:
            if table not in visible_tables:
                continue
            table_name = self.find_new_name(table.name, self.used_table_names)
            self.used_table_names.append(table_name)
            s.write('%s = %s\n\n' % (table_name, self._table_repr(table)))

        for model in models:
            s.write(model.__repr__())
            s.write("\n\n")

        # The example/interactive snippets are built around the first model,
        # so they can only be written when at least one model exists.
        if models:
            first = models[0].__name__
            if self.config.example or self.config.interactive:
                s.write(constants.EXAMPLE_DECL % (first, first))
            if self.config.interactive:
                s.write(constants.INTERACTIVE % ([model.__name__ for model in models], first))
        return s.getvalue()

    @property
    def tables(self):
        """Reflected tables, restricted to ``config.options.tables`` if set."""
        wanted = self.config.options.tables
        if wanted:
            names = set(self._metadata.tables.keys()).intersection(set(wanted))
            return [self._metadata.tables[t] for t in names]
        return self._metadata.tables.values()

    @property
    def table_names(self):
        """Names of the tables returned by ``tables``."""
        return [t.name for t in self.tables]

    @property
    def models(self):
        """Model classes for all non-pure-join tables, sorted by class name.

        Computed once and cached on the instance.  Tables for which the
        declarative mapping raises ``exc.ArgumentError`` are skipped with a
        warning.
        """
        if hasattr(self, '_models'):
            return self._models
        self.used_model_names = []
        self.used_table_names = []
        self._models = []
        for table in self.get_non_many_to_many_tables():
            try:
                self._models.append(self.create_model(table))
            except exc.ArgumentError:
                log.warning("Table with name %s ha no primary key. No ORM class created" % table.name)
        self._models.sort(key=operator.attrgetter('__name__'))
        return self._models

    def get_tables_with_no_pks(self):
        """Return the non-pure-join tables that have no primary-key column."""
        return [table for table in self.get_non_many_to_many_tables()
                if not [c for c in table.columns if c.primary_key]]

    def model_table_lookup(self):
        """Return (and cache) a ``{table name: model class name}`` mapping."""
        if not hasattr(self, '_model_table_lookup'):
            self._model_table_lookup = dict((m.__table__.name, m.__name__) for m in self.models)
        return self._model_table_lookup

    def find_new_name(self, prefix, used, i=0):
        """Return a name derived from *prefix* that is not in *used*.

        Candidates are ``prefix`` (only when *i* is 0), then ``prefix<i>``,
        ``prefix<i+1>``, ...  The number is always appended to the original
        *prefix*, never to a previously rejected candidate.
        """
        candidate = prefix if i == 0 else "%s%d" % (prefix, i)
        while candidate in used:
            i += 1
            candidate = "%s%d" % (prefix, i)
        return candidate

    def create_model(self, table):
        """Build a declarative model class mapped to *table*.

        The class is created on ``self.DeclarativeBase``, renamed to a unique
        singular class name, registered under that name, and given relations
        for single-column foreign keys, composite foreign keys and related
        many-to-many join tables.  Its classmethod ``__repr__`` renders the
        class as source code.
        """
        # partially borrowed from Jorge Vargas' code
        # http://dpaste.org/V6YS/
        log.debug('Creating Model from table: %s' % table.name)

        model_name = self.find_new_name(singular(name2label(table.name)), self.used_model_names)
        self.used_model_names.append(model_name)
        is_many_to_many_table = self.is_many_to_many_table(table)
        table_name = self.find_new_name(table.name, self.used_table_names)
        self.used_table_names.append(table_name)

        mtl = self.model_table_lookup

        def join_repr(join, lookup):
            """Render a binary ``left == right`` join as 'Left.col==Right.col'.

            Each side is prefixed with its model class name from *lookup*, or
            with '<table>.c' when the table has no model.
            """
            def side(col):
                owner = lookup.get(col.table.name, '%s.c' % col.table.name)
                return '%s.%s' % (owner, col.name)
            return '%s==%s' % (side(join.left), side(join.right))

        class Temporal(self.DeclarativeBase):
            __table__ = table

            @classmethod
            def _relation_repr(cls, rel):
                """Return ``<key> = relation('<Target>', ...)`` source for *rel*."""
                target = rel.argument
                if target and inspect.isfunction(target):
                    target = target()
                if isinstance(target, Mapper):
                    target = target.class_
                target = target.__name__
                lookup = mtl()

                primaryjoin = ''
                if rel.primaryjoin is not None and hasattr(rel.primaryjoin, 'right'):
                    primaryjoin = ", primaryjoin='%s'" % join_repr(rel.primaryjoin, lookup)
                elif hasattr(rel, '_as_string'):
                    primaryjoin = ', primaryjoin="%s"' % rel._as_string

                secondary = ''
                secondaryjoin = ''
                if rel.secondary is not None:
                    secondary = ", secondary=%s" % rel.secondary.name
                    secondaryjoin = ", secondaryjoin='%s'" % join_repr(rel.secondaryjoin, lookup)
                # Backrefs are not rendered.
                return "%s = relation('%s'%s%s%s)" % (rel.key, target, primaryjoin,
                                                      secondary, secondaryjoin)

            @classmethod
            def __repr__(cls):
                """Render this model class as declarative source code.

                Returns '' (after logging the traceback) if rendering fails.
                """
                log.debug('repring class with name %s' % cls.__name__)
                try:
                    mapper = None
                    try:
                        mapper = class_mapper(cls)
                    except exc.InvalidRequestError:
                        log.warning("A proper mapper could not be generated for the class %s, no relations will be created" % model_name)
                    s = "class " + model_name + '(DeclarativeBase):\n'
                    if is_many_to_many_table:
                        s += "    __table__ = %s\n\n" % table_name
                    else:
                        s += "    __tablename__ = '%s'\n\n" % table_name
                        if hasattr(cls, '__table_args__'):
                            s += "    __table_args__ = %s\n\n" % (cls.__table_args__,)
                        s += "    #column definitions\n"
                        for column in sorted(cls.__table__.c, key=operator.attrgetter('name')):
                            s += "    %s = %s\n" % (column.name, column_repr(column))
                    s += "\n    #relation definitions\n"
                    # RelationProperty is the relation property class on every
                    # import path at the top of the module (an alias of
                    # RelationshipProperty on newer SA); it is None when neither
                    # import succeeded, in which case relations are skipped.
                    if mapper and RelationProperty:
                        for prop in mapper.iterate_properties:
                            if isinstance(prop, RelationProperty):
                                s += '    %s\n' % cls._relation_repr(prop)
                    return s
                except Exception:
                    log.error("Could not generate class for: %s" % cls.__name__)
                    from traceback import format_exc
                    log.error(format_exc())
                    return ''

        # Rename the class so it reprs and registers under model_name.
        Temporal.__name__ = model_name

        # Blank table args; __repr__ renders them as-is.
        Temporal.__table_args__ = {}

        # Trick SA's model registry into thinking the model has the right name.
        if model_name != 'Temporal':
            Temporal._decl_class_registry[model_name] = Temporal._decl_class_registry['Temporal']
            del Temporal._decl_class_registry['Temporal']

        visible_tables = self.tables

        # Relations for single-column foreign keys.
        for column, fk in self.get_single_foreign_keys_by_column(table).items():
            related_table = fk.column.table
            if related_table not in visible_tables:
                continue
            log.info('    Adding <primary> foreign key for:%s' % related_table.name)
            rel = relation(singular(name2label(related_table.name, related_table.schema)),
                           primaryjoin=column == fk.column)
            setattr(Temporal, related_table.name, _deferred_relationship(Temporal, rel))

        # Relations for composite (multi-column) foreign-key constraints.
        for constraint in table.constraints:
            if not isinstance(constraint, ForeignKeyConstraint) or len(constraint.elements) <= 1:
                continue
            related_table = constraint.elements[0].column.table
            related_classname = singular(name2label(related_table.name, related_table.schema))
            primary_join = "and_(%s)" % ', '.join(
                "%s.%s==%s.%s" % (model_name, k.parent.name, related_classname, k.column.name)
                for k in constraint.elements)
            rel = relation(related_classname, primaryjoin=primary_join)
            # Kept so _relation_repr can render the join text verbatim.
            rel._as_string = primary_join
            setattr(Temporal, related_table.name, rel)

        # Many-to-many relations through join tables that reference this table.
        for join_table in self.get_related_many_to_many_tables(table.name):
            if join_table not in visible_tables:
                continue
            primary_column = [c for c in join_table.columns
                              if c.foreign_keys and list(c.foreign_keys)[0].column.table == table][0]
            for column in join_table.columns:
                if not column.foreign_keys:
                    continue
                key = list(column.foreign_keys)[0]
                if key.column.table is table:
                    continue
                related_column = key.column
                related_table = related_column.table
                if related_table not in visible_tables:
                    continue
                log.info('    Adding <secondary> foreign key(%s) for:%s' % (key, related_table.name))
                rel = relation(singular(name2label(related_table.name, related_table.schema)),
                               secondary=join_table,
                               primaryjoin=list(primary_column.foreign_keys)[0].column == primary_column,
                               secondaryjoin=column == related_column)
                setattr(Temporal, plural(related_table.name), _deferred_relationship(Temporal, rel))
                # Only the first foreign key to another table is used.
                break

        return Temporal

    def get_table(self, name):
        """(name) -> sqlalchemy.schema.Table
        get the table definition with the given table name

        When schemas are configured, each schema is tried in turn, qualifying
        *name* as '<schema>.<name>' unless the schema is empty or *name* is
        already qualified with it.  Falls back to the plain *name* lookup,
        which raises KeyError if the table is unknown.
        """
        tables = self._metadata.tables
        if self.schemas:
            for schema in self.schemas:
                if schema and not name.startswith(schema + '.'):
                    qualified = '.'.join((schema, name))
                else:
                    qualified = name
                table = tables.get(qualified, None)
                if table is not None:
                    return table
        return tables[name]

    def get_single_foreign_keys_by_column(self, table):
        """Return ``{parent column: fk}`` for related tables referenced by exactly one FK."""
        keys_by_column = {}
        for keys in self.get_foreign_keys(table).values():
            if len(keys) == 1:
                keys_by_column[keys[0].parent] = keys[0]
        return keys_by_column

    def get_composite_foreign_keys(self, table):
        """Return the FK groups of *table* whose related table is referenced more than once."""
        return [keys for keys in self.get_foreign_keys(table).values() if len(keys) > 1]

    def get_foreign_keys(self, table):
        """Return (and cache) *table*'s foreign keys grouped by referenced table.

        Grouping by table is presumably needed because of a problem in the SA
        reflection algorithm (per the original author).
        """
        if table in self._foreign_keys:
            return self._foreign_keys[table]
        grouped_fks = {}
        for key in table.foreign_keys:
            grouped_fks.setdefault(key.column.table, []).append(key)
        self._foreign_keys[table] = grouped_fks
        return grouped_fks

    def is_many_to_many_table(self, table):
        """True if *table* has single-FK references to at least two tables."""
        return len(self.get_single_foreign_keys_by_column(table)) >= 2

    def is_only_many_to_many_table(self, table):
        """True if *table* is a pure join table: two single FKs and two columns."""
        return len(self.get_single_foreign_keys_by_column(table)) == 2 and len(table.c) == 2

    def get_many_to_many_tables(self):
        """Return a new list of all reflected many-to-many tables, sorted by name."""
        if not hasattr(self, '_many_to_many_tables'):
            self._many_to_many_tables = [t for t in self._metadata.tables.values()
                                         if self.is_many_to_many_table(t)]
        return sorted(self._many_to_many_tables, key=operator.attrgetter('name'))

    def get_non_many_to_many_tables(self):
        """Return the visible tables that are not pure join tables, sorted by name."""
        tables = [t for t in self.tables if not self.is_only_many_to_many_table(t)]
        return sorted(tables, key=operator.attrgetter('name'))

    def get_related_many_to_many_tables(self, table_name):
        """Return many-to-many tables with a foreign key into *table_name*, sorted by name."""
        src_table = self.get_table(table_name)
        related = []
        for table in self.get_many_to_many_tables():
            for column in table.columns:
                if column.foreign_keys and list(column.foreign_keys)[0].column.table is src_table:
                    related.append(table)
                    break
        return sorted(related, key=operator.attrgetter('name'))