PageRenderTime 26ms CodeModel.GetById 1ms app.highlight 20ms RepoModel.GetById 1ms app.codeStats 0ms

/tools/genome_diversity/cdblib.py

https://bitbucket.org/cistrome/cistrome-harvard/
Python | 230 lines | 218 code | 3 blank | 9 comment | 3 complexity | 755bdbfdc604256b5329cc01d487b875 MD5 | raw file
  1#!/usr/bin/env python2.5
  2
  3'''
  4Manipulate DJB's Constant Databases. These are 2 level disk-based hash tables
  5that efficiently handle many keys, while remaining space-efficient.
  6
  7    http://cr.yp.to/cdb.html
  8
  9When generated databases are only used with Python code, consider using hash()
 10rather than djb_hash() for a tidy speedup.
 11'''
 12
 13from _struct import Struct
 14from itertools import chain
 15
 16
 17def py_djb_hash(s):
 18    '''Return the value of DJB's hash function for the given 8-bit string.'''
 19    h = 5381
 20    for c in s:
 21        h = (((h << 5) + h) ^ ord(c)) & 0xffffffff
 22    return h
 23
# Prefer djb_hash from the optional _cdblib extension module when it can be
# imported (presumably a compiled implementation of the same hash -- confirm
# it matches py_djb_hash bit for bit, since databases written with one must be
# readable with the other). Otherwise fall back to the pure-Python version.
try:
    from _cdblib import djb_hash
except ImportError:
    djb_hash = py_djb_hash
 28
 29read_2_le4 = Struct('<LL').unpack
 30write_2_le4 = Struct('<LL').pack
 31
 32
 33class Reader(object):
 34    '''A dictionary-like object for reading a Constant Database accessed
 35    through a string or string-like sequence, such as mmap.mmap().'''
 36
 37    def __init__(self, data, hashfn=djb_hash):
 38        '''Create an instance reading from a sequence and using hashfn to hash
 39        keys.'''
 40        if len(data) < 2048:
 41            raise IOError('CDB too small')
 42
 43        self.data = data
 44        self.hashfn = hashfn
 45
 46        self.index = [read_2_le4(data[i:i+8]) for i in xrange(0, 2048, 8)]
 47        self.table_start = min(p[0] for p in self.index)
 48        # Assume load load factor is 0.5 like official CDB.
 49        self.length = sum(p[1] >> 1 for p in self.index)
 50
 51    def iteritems(self):
 52        '''Like dict.iteritems(). Items are returned in insertion order.'''
 53        pos = 2048
 54        while pos < self.table_start:
 55            klen, dlen = read_2_le4(self.data[pos:pos+8])
 56            pos += 8
 57
 58            key = self.data[pos:pos+klen]
 59            pos += klen
 60
 61            data = self.data[pos:pos+dlen]
 62            pos += dlen
 63
 64            yield key, data
 65
 66    def items(self):
 67        '''Like dict.items().'''
 68        return list(self.iteritems())
 69
 70    def iterkeys(self):
 71        '''Like dict.iterkeys().'''
 72        return (p[0] for p in self.iteritems())
 73    __iter__ = iterkeys
 74
 75    def itervalues(self):
 76        '''Like dict.itervalues().'''
 77        return (p[1] for p in self.iteritems())
 78
 79    def keys(self):
 80        '''Like dict.keys().'''
 81        return [p[0] for p in self.iteritems()]
 82
 83    def values(self):
 84        '''Like dict.values().'''
 85        return [p[1] for p in self.iteritems()]
 86
 87    def __getitem__(self, key):
 88        '''Like dict.__getitem__().'''
 89        value = self.get(key)
 90        if value is None:
 91            raise KeyError(key)
 92        return value
 93
 94    def has_key(self, key):
 95        '''Return True if key exists in the database.'''
 96        return self.get(key) is not None
 97    __contains__ = has_key
 98
 99    def __len__(self):
100        '''Return the number of records in the database.'''
101        return self.length
102
103    def gets(self, key):
104        '''Yield values for key in insertion order.'''
105        # Truncate to 32 bits and remove sign.
106        h = self.hashfn(key) & 0xffffffff
107        start, nslots = self.index[h & 0xff]
108
109        if nslots:
110            end = start + (nslots << 3)
111            slot_off = start + (((h >> 8) % nslots) << 3)
112
113            for pos in chain(xrange(slot_off, end, 8),
114                             xrange(start, slot_off, 8)):
115                rec_h, rec_pos = read_2_le4(self.data[pos:pos+8])
116
117                if not rec_h:
118                    break
119                elif rec_h == h:
120                    klen, dlen = read_2_le4(self.data[rec_pos:rec_pos+8])
121                    rec_pos += 8
122
123                    if self.data[rec_pos:rec_pos+klen] == key:
124                        rec_pos += klen
125                        yield self.data[rec_pos:rec_pos+dlen]
126
127    def get(self, key, default=None):
128        '''Get the first value for key, returning default if missing.'''
129        # Avoid exception catch when handling default case; much faster.
130        return chain(self.gets(key), (default,)).next()
131
132    def getint(self, key, default=None, base=0):
133        '''Get the first value for key converted it to an int, returning
134        default if missing.'''
135        value = self.get(key, default)
136        if value is not default:
137            return int(value, base)
138        return value
139
140    def getints(self, key, base=0):
141        '''Yield values for key in insertion order after converting to int.'''
142        return (int(v, base) for v in self.gets(key))
143
144    def getstring(self, key, default=None, encoding='utf-8'):
145        '''Get the first value for key decoded as unicode, returning default if
146        not found.'''
147        value = self.get(key, default)
148        if value is not default:
149            return value.decode(encoding)
150        return value
151
152    def getstrings(self, key, encoding='utf-8'):
153        '''Yield values for key in insertion order after decoding as
154        unicode.'''
155        return (v.decode(encoding) for v in self.gets(key))
156
157
class Writer(object):
    '''Object for building new Constant Databases, and writing them to a
    seekable file-like object.

    Records are appended to fp as they are put(); the 256 hash tables and the
    2048-byte index are written only by finalize(), which must be called once
    after the last record has been added.'''

    def __init__(self, fp, hashfn=djb_hash):
        '''Create an instance writing to a file-like object, using hashfn to
        hash keys. Readers of the result must use the same hashfn.'''
        self.fp = fp
        self.hashfn = hashfn

        # Reserve room for the index; finalize() seeks back and fills it in.
        fp.write('\x00' * 2048)
        # One list of (hash, record position) pairs per hash table, selected
        # by the low 8 bits of the hash.
        self._unordered = [[] for i in xrange(256)]

    def put(self, key, value=''):
        '''Write a string key/value pair to the output file.'''
        assert type(key) is str and type(value) is str

        pos = self.fp.tell()
        # Record layout: (key length, value length) header, key, value.
        self.fp.write(write_2_le4(len(key), len(value)))
        self.fp.write(key)
        self.fp.write(value)

        # Truncate to 32 bits and remove sign.
        h = self.hashfn(key) & 0xffffffff
        self._unordered[h & 0xff].append((h, pos))

    def puts(self, key, values):
        '''Write more than one value for the same key to the output file.
        Equivalent to calling put() in a loop.'''
        for value in values:
            self.put(key, value)

    def putint(self, key, value):
        '''Write an integer as a base-10 string associated with the given key
        to the output file.'''
        self.put(key, str(value))

    def putints(self, key, values):
        '''Write zero or more integers for the same key to the output file.
        Equivalent to calling putint() in a loop.'''
        self.puts(key, (str(value) for value in values))

    def putstring(self, key, value, encoding='utf-8'):
        '''Write a unicode string associated with the given key to the output
        file after encoding it as UTF-8 or the given encoding. value must be
        a unicode object, since unicode.encode is called unbound on it.'''
        self.put(key, unicode.encode(value, encoding))

    def putstrings(self, key, values, encoding='utf-8'):
        '''Write zero or more unicode strings to the output file. Equivalent to
        calling putstring() in a loop.'''
        self.puts(key, (unicode.encode(value, encoding) for value in values))

    def finalize(self):
        '''Write the final hash tables to the output file, and write out its
        index. The output file remains open upon return.

        Each table gets twice as many slots as it has records (load factor
        0.5) and is filled by linear probing from slot (hash >> 8) % length.
        self.fp is cleared afterwards, so a second finalize() or a later put()
        fails instead of corrupting the file.'''
        index = []
        for tbl in self._unordered:
            length = len(tbl) << 1
            ordered = [(0, 0)] * length
            for pair in tbl:
                where = (pair[0] >> 8) % length
                for i in chain(xrange(where, length), xrange(0, where)):
                    # A slot is free while its record position is still 0;
                    # real positions are >= 2048 because __init__ writes the
                    # index placeholder first. Testing the hash instead let a
                    # later record overwrite one whose key hashes to 0
                    # (e.g. CPython's builtin hash('') is 0).
                    if not ordered[i][1]:
                        ordered[i] = pair
                        break

            index.append((self.fp.tell(), length))
            for pair in ordered:
                self.fp.write(write_2_le4(*pair))

        self.fp.seek(0)
        for pair in index:
            self.fp.write(write_2_le4(*pair))
        self.fp = None # prevent double finalize()