
/bangkokhotel/lib/python2.5/site-packages/whoosh/codec/whoosh2.py

https://bitbucket.org/luisrodriguez/bangkokhotel
   1# Copyright 2011 Matt Chaput. All rights reserved.
   2#
   3# Redistribution and use in source and binary forms, with or without
   4# modification, are permitted provided that the following conditions are met:
   5#
   6#    1. Redistributions of source code must retain the above copyright notice,
   7#       this list of conditions and the following disclaimer.
   8#
   9#    2. Redistributions in binary form must reproduce the above copyright
  10#       notice, this list of conditions and the following disclaimer in the
  11#       documentation and/or other materials provided with the distribution.
  12#
  13# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
  14# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
  15# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
  16# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
  17# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  18# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
  19# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  20# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  21# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
  22# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  23#
  24# The views and conclusions contained in the software and documentation are
  25# those of the authors and should not be interpreted as representing official
  26# policies, either expressed or implied, of Matt Chaput.
  27
  28
  29from array import array
  30from collections import defaultdict
  31from struct import Struct
  32
  33from whoosh.compat import (loads, dumps, xrange, iteritems, itervalues, b,
  34                           bytes_type, string_type, integer_types)
  35from whoosh.codec import base
  36from whoosh.codec.base import (minimize_ids, deminimize_ids, minimize_weights,
  37                               deminimize_weights, minimize_values,
  38                               deminimize_values)
  39from whoosh.filedb.fileindex import TOC, clean_files
  40from whoosh.filedb.filetables import CodedOrderedWriter, CodedOrderedReader
  41from whoosh.matching import ListMatcher
  42from whoosh.reading import TermNotFound
  43from whoosh.store import Storage
  44from whoosh.support.dawg import GraphWriter, GraphReader
  45from whoosh.system import (pack_ushort, pack_long, unpack_ushort, unpack_long,
  46                           _INT_SIZE, _LONG_SIZE)
  47from whoosh.util import byte_to_length, length_to_byte, utf8encode, utf8decode
  48
  49
  50# Standard codec top-level object
  51
  52class W2Codec(base.Codec):
  53    TERMS_EXT = ".trm"  # Term index
  54    POSTS_EXT = ".pst"  # Term postings
  55    DAWG_EXT = ".dag"  # Spelling graph file
  56    LENGTHS_EXT = ".fln"  # Field lengths file
  57    VECTOR_EXT = ".vec"  # Vector index
  58    VPOSTS_EXT = ".vps"  # Vector postings
  59    STORED_EXT = ".sto"  # Stored fields file
  60
  61    def __init__(self, blocklimit=128, compression=3, loadlengths=False,
  62                 inlinelimit=1):
  63        self.blocklimit = blocklimit
  64        self.compression = compression
  65        self.loadlengths = loadlengths
  66        self.inlinelimit = inlinelimit
  67
  68    # Per-document value writer
  69    def per_document_writer(self, storage, segment):
  70        return W2PerDocWriter(storage, segment, blocklimit=self.blocklimit,
  71                              compression=self.compression)
  72
  73    # Inverted index writer
  74    def field_writer(self, storage, segment):
  75        return W2FieldWriter(storage, segment, blocklimit=self.blocklimit,
  76                             compression=self.compression,
  77                             inlinelimit=self.inlinelimit)
  78
  79    # Readers
  80
  81    def terms_reader(self, storage, segment):
  82        tifile = segment.open_file(storage, self.TERMS_EXT)
  83        postfile = segment.open_file(storage, self.POSTS_EXT)
  84        return W2TermsReader(tifile, postfile)
  85
  86    def lengths_reader(self, storage, segment):
  87        flfile = segment.open_file(storage, self.LENGTHS_EXT)
  88        doccount = segment.doc_count_all()
  89
  90        # Check the first byte of the file to see if it's an old format
  91        firstbyte = flfile.read(1)
  92        flfile.seek(0)
  93        if firstbyte != b("~"):
  94            from whoosh.codec.legacy import load_old_lengths
  95            lengths = load_old_lengths(InMemoryLengths(), flfile, doccount)
  96        elif self.loadlengths:
  97            lengths = InMemoryLengths.from_file(flfile, doccount)
  98        else:
  99            lengths = OnDiskLengths(flfile, doccount)
 100        return lengths
 101
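        # Editorial note on lengths_reader() above (not in the original source):
        # the b"~" check keys off LengthsBase.magic (b"~LN1", defined further
        # down in this module). A lengths file that does not start with that
        # byte is assumed to be in the older, pre-magic format and is routed
        # through whoosh.codec.legacy.load_old_lengths().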
 102    def vector_reader(self, storage, segment):
 103        vifile = segment.open_file(storage, self.VECTOR_EXT)
 104        postfile = segment.open_file(storage, self.VPOSTS_EXT)
 105        return W2VectorReader(vifile, postfile)
 106
 107    def stored_fields_reader(self, storage, segment):
 108        sffile = segment.open_file(storage, self.STORED_EXT)
 109        return StoredFieldReader(sffile)
 110
 111    def graph_reader(self, storage, segment):
 112        dawgfile = segment.open_file(storage, self.DAWG_EXT)
 113        return GraphReader(dawgfile)
 114
 115    # Segments and generations
 116
 117    def new_segment(self, storage, indexname):
 118        return W2Segment(indexname)
 119
 120    def commit_toc(self, storage, indexname, schema, segments, generation,
 121                   clean=True):
 122        toc = TOC(schema, segments, generation)
 123        toc.write(storage, indexname)
 124        # Delete leftover files
 125        if clean:
 126            clean_files(storage, indexname, generation, segments)
 127
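    # Rough usage sketch (an editorial illustration, not part of the original
    # module): W2Codec is a factory that pairs writers and readers over the
    # same storage and segment. Assuming a `storage` and a `segment` are
    # already in hand, the pieces above are typically combined like this:
    #
    #     codec = W2Codec(blocklimit=128, compression=3)
    #     perdoc = codec.per_document_writer(storage, segment)  # stored fields, lengths, vectors
    #     fields = codec.field_writer(storage, segment)         # terms index + postings
    #     # ... add documents/terms through the writers, then close them ...
    #     terms = codec.terms_reader(storage, segment)
    #     lengths = codec.lengths_reader(storage, segment)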
 128
 129# Per-document value writer
 130
 131class W2PerDocWriter(base.PerDocumentWriter):
 132    def __init__(self, storage, segment, blocklimit=128, compression=3):
 133        if not isinstance(blocklimit, int):
 134            raise ValueError("blocklimit must be an int, not %r" % blocklimit)
 135        self.storage = storage
 136        self.segment = segment
 137        self.blocklimit = blocklimit
 138        self.compression = compression
 139        self.doccount = 0
 140
 141        sffile = segment.create_file(storage, W2Codec.STORED_EXT)
 142        self.stored = StoredFieldWriter(sffile)
 143        self.storedfields = None
 144
 145        self.lengths = InMemoryLengths()
 146
 147        # We'll wait to create the vector files until someone actually tries
 148        # to add a vector
 149        self.vindex = self.vpostfile = None
 150
 151    def _make_vector_files(self):
 152        vifile = self.segment.create_file(self.storage, W2Codec.VECTOR_EXT)
 153        self.vindex = VectorWriter(vifile)
 154        self.vpostfile = self.segment.create_file(self.storage,
 155                                                  W2Codec.VPOSTS_EXT)
 156
 157    def start_doc(self, docnum):
 158        self.docnum = docnum
 159        self.storedfields = {}
 160        self.doccount = max(self.doccount, docnum + 1)
 161
 162    def add_field(self, fieldname, fieldobj, value, length):
 163        if length:
 164            self.lengths.add(self.docnum, fieldname, length)
 165        if value is not None:
 166            self.storedfields[fieldname] = value
 167
 168    def _new_block(self, vformat):
 169        postingsize = vformat.posting_size
 170        return W2Block(postingsize, stringids=True)
 171
 172    def add_vector_items(self, fieldname, fieldobj, items):
 173        if self.vindex is None:
 174            self._make_vector_files()
 175
 176        # items = (text, freq, weight, valuestring) ...
 177        postfile = self.vpostfile
 178        blocklimit = self.blocklimit
 179        block = self._new_block(fieldobj.vector)
 180
 181        startoffset = postfile.tell()
 182        postfile.write(block.magic)  # Magic number
 183        blockcount = 0
 184        postfile.write_uint(0)  # Placeholder for block count
 185
 186        countdown = blocklimit
 187        for text, _, weight, valuestring in items:
 188            block.add(text, weight, valuestring)
 189            countdown -= 1
 190            if countdown == 0:
 191                block.to_file(postfile, compression=self.compression)
 192                block = self._new_block(fieldobj.vector)
 193                blockcount += 1
 194                countdown = blocklimit
 195        # If there are leftover items in the current block, write them out
 196        if block:
 197            block.to_file(postfile, compression=self.compression)
 198            blockcount += 1
 199
 200        # Seek back to the start of this list of posting blocks and write the
 201        # number of blocks
 202        postfile.flush()
 203        here = postfile.tell()
 204        postfile.seek(startoffset + 4)
 205        postfile.write_uint(blockcount)
 206        postfile.seek(here)
 207
 208        # Add to the index
 209        self.vindex.add((self.docnum, fieldname), startoffset)
 210
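        # Editorial sketch of the layout add_vector_items() produces (derived
        # from the code above, not part of the original file): each vector
        # posting list on disk looks like
        #
        #     magic (4 bytes) | blockcount (uint, back-patched) | block 0 | block 1 | ...
        #
        # The uint written right after the magic number starts out as 0 and is
        # overwritten with the real block count once every block has been
        # written.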
 211    def add_vector_matcher(self, fieldname, fieldobj, vmatcher):
 212        def readitems():
 213            while vmatcher.is_active():
 214                text = vmatcher.id()
 215                weight = vmatcher.weight()
 216                valuestring = vmatcher.value()
 217                yield (text, None, weight, valuestring)
 218                vmatcher.next()
 219        self.add_vector_items(fieldname, fieldobj, readitems())
 220
 221    def finish_doc(self):
 222        self.stored.add(self.storedfields)
 223        self.storedfields = None
 224
 225    def lengths_reader(self):
 226        return self.lengths
 227
 228    def close(self):
 229        if self.storedfields is not None:
 230            self.stored.add(self.storedfields)
 231        self.stored.close()
 232        flfile = self.segment.create_file(self.storage, W2Codec.LENGTHS_EXT)
 233        self.lengths.to_file(flfile, self.doccount)
 234        if self.vindex:
 235            self.vindex.close()
 236            self.vpostfile.close()
 237
 238
 239# Inverted index writer
 240
 241class W2FieldWriter(base.FieldWriter):
 242    def __init__(self, storage, segment, blocklimit=128, compression=3,
 243                 inlinelimit=1):
 244        assert isinstance(storage, Storage)
 245        assert isinstance(segment, base.Segment)
 246        assert isinstance(blocklimit, int)
 247        assert isinstance(compression, int)
 248        assert isinstance(inlinelimit, int)
 249
 250        self.storage = storage
 251        self.segment = segment
 252        self.fieldname = None
 253        self.text = None
 254        self.field = None
 255        self.format = None
 256        self.spelling = False
 257
 258        tifile = segment.create_file(storage, W2Codec.TERMS_EXT)
 259        self.termsindex = TermIndexWriter(tifile)
 260        self.postfile = segment.create_file(storage, W2Codec.POSTS_EXT)
 261
 262        # We'll wait to create the DAWG builder until someone actually adds
 263        # a spelled field
 264        self.dawg = None
 265
 266        self.blocklimit = blocklimit
 267        self.compression = compression
 268        self.inlinelimit = inlinelimit
 269        self.block = None
 270        self.terminfo = None
 271        self._infield = False
 272
 273    def _make_dawg_files(self):
 274        dawgfile = self.segment.create_file(self.storage, W2Codec.DAWG_EXT)
 275        self.dawg = GraphWriter(dawgfile)
 276
 277    def _new_block(self):
 278        return W2Block(self.format.posting_size)
 279
 280    def _reset_block(self):
 281        self.block = self._new_block()
 282
 283    def _write_block(self):
 284        self.terminfo.add_block(self.block)
 285        self.block.to_file(self.postfile, compression=self.compression)
 286        self._reset_block()
 287        self.blockcount += 1
 288
 289    def _start_blocklist(self):
 290        postfile = self.postfile
 291        self._reset_block()
 292
 293        # Magic number
 294        self.startoffset = postfile.tell()
 295        postfile.write(W2Block.magic)
 296        # Placeholder for block count
 297        self.blockcount = 0
 298        postfile.write_uint(0)
 299
 300    def start_field(self, fieldname, fieldobj):
 301        self.fieldname = fieldname
 302        self.field = fieldobj
 303        self.format = fieldobj.format
 304        self.spelling = fieldobj.spelling and not fieldobj.separate_spelling()
 305        self._dawgfield = False
 306        if self.spelling or fieldobj.separate_spelling():
 307            if self.dawg is None:
 308                self._make_dawg_files()
 309            self.dawg.start_field(fieldname)
 310            self._dawgfield = True
 311        self._infield = True
 312
 313    def start_term(self, text):
 314        if self.block is not None:
 315            raise Exception("Called start_term in a block")
 316        self.text = text
 317        self.terminfo = base.FileTermInfo()
 318        if self.spelling:
 319            self.dawg.insert(text)
 320        self._start_blocklist()
 321
 322    def add(self, docnum, weight, valuestring, length):
 323        self.block.add(docnum, weight, valuestring, length)
 324        if len(self.block) > self.blocklimit:
 325            self._write_block()
 326
 327    def add_spell_word(self, fieldname, text):
 328        if self.dawg is None:
 329            self._make_dawg_files()
 330        self.dawg.insert(text)
 331
 332    def finish_term(self):
 333        block = self.block
 334        if block is None:
 335            raise Exception("Called finish_term when not in a block")
 336
 337        terminfo = self.terminfo
 338        if self.blockcount < 1 and block and len(block) < self.inlinelimit:
 339            # Inline the single block
 340            terminfo.add_block(block)
 341            vals = None if not block.values else tuple(block.values)
 342            postings = (tuple(block.ids), tuple(block.weights), vals)
 343        else:
 344            if block:
 345                # Write the current unfinished block to disk
 346                self._write_block()
 347
 348            # Seek back to the start of this list of posting blocks and write
 349            # the number of blocks
 350            postfile = self.postfile
 351            postfile.flush()
 352            here = postfile.tell()
 353            postfile.seek(self.startoffset + 4)
 354            postfile.write_uint(self.blockcount)
 355            postfile.seek(here)
 356
 357            self.block = None
 358            postings = self.startoffset
 359
 360        self.block = None
 361        terminfo.postings = postings
 362        self.termsindex.add((self.fieldname, self.text), terminfo)
 363
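        # Editorial note: as finish_term() shows, terminfo.postings ends up in
        # one of two shapes, which W2TermsReader.matcher() later distinguishes
        # by type:
        #
        #     postings = (ids_tuple, weights_tuple, values_or_None)  # inlined short list
        #     postings = startoffset  # int offset into the postings (.pst) file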
 364    def finish_field(self):
 365        if not self._infield:
 366            raise Exception("Called finish_field before start_field")
 367        self._infield = False
 368
 369        if self._dawgfield:
 370            self.dawg.finish_field()
 371            self._dawgfield = False
 372
 373    def close(self):
 374        self.termsindex.close()
 375        self.postfile.close()
 376        if self.dawg is not None:
 377            self.dawg.close()
 378
 379
 380# Matcher
 381
 382class PostingMatcher(base.BlockPostingMatcher):
 383    def __init__(self, postfile, startoffset, fmt, scorer=None, term=None,
 384                 stringids=False):
 385        self.postfile = postfile
 386        self.startoffset = startoffset
 387        self.format = fmt
 388        self.scorer = scorer
 389        self._term = term
 390        self.stringids = stringids
 391
 392        postfile.seek(startoffset)
 393        magic = postfile.read(4)
 394        if magic != W2Block.magic:
 395            from whoosh.codec.legacy import old_block_type
 396            self.blockclass = old_block_type(magic)
 397        else:
 398            self.blockclass = W2Block
 399
 400        self.blockcount = postfile.read_uint()
 401        self.baseoffset = postfile.tell()
 402
 403        self._active = True
 404        self.currentblock = -1
 405        self._next_block()
 406
 407    def is_active(self):
 408        return self._active
 409
 410    def _read_block(self, offset):
 411        pf = self.postfile
 412        pf.seek(offset)
 413        return self.blockclass.from_file(pf, self.format.posting_size,
 414                                         stringids=self.stringids)
 415
 416    def _consume_block(self):
 417        self.block.read_ids()
 418        self.block.read_weights()
 419        self.i = 0
 420
 421    def _next_block(self, consume=True):
 422        if not (self.currentblock < self.blockcount):
 423            raise Exception("No next block")
 424
 425        self.currentblock += 1
 426        if self.currentblock == self.blockcount:
 427            self._active = False
 428            return
 429
 430        if self.currentblock == 0:
 431            pos = self.baseoffset
 432        else:
 433            pos = self.block.nextoffset
 434
 435        self.block = self._read_block(pos)
 436        if consume:
 437            self._consume_block()
 438
 439    def _skip_to_block(self, targetfn):
 440        skipped = 0
 441        while self._active and targetfn():
 442            self._next_block(consume=False)
 443            skipped += 1
 444
 445        if self._active:
 446            self._consume_block()
 447
 448        return skipped
 449
 450    def score(self):
 451        return self.scorer.score(self)
 452
 453
 454# Tables
 455
 456# Term index
 457
 458class TermIndexWriter(CodedOrderedWriter):
 459    def __init__(self, dbfile):
 460        super(TermIndexWriter, self).__init__(dbfile)
 461        self.fieldcounter = 0
 462        self.fieldmap = {}
 463
 464    def keycoder(self, key):
 465        # Encode term
 466        fieldmap = self.fieldmap
 467        fieldname, text = key
 468
 469        if fieldname in fieldmap:
 470            fieldnum = fieldmap[fieldname]
 471        else:
 472            fieldnum = self.fieldcounter
 473            fieldmap[fieldname] = fieldnum
 474            self.fieldcounter += 1
 475
 476        key = pack_ushort(fieldnum) + utf8encode(text)[0]
 477        return key
 478
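        # Editorial example of the key encoding in keycoder() above (a sketch;
        # it assumes "title" was the first field seen and so got field number 0):
        #
        #     keycoder(("title", u"hello"))
        #     ->  pack_ushort(0) + utf8encode(u"hello")[0]
        #     ->  b"\x00\x00hello"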
 479    def valuecoder(self, terminfo):
 480        return terminfo.to_string()
 481
 482    def close(self):
 483        self._write_hashes()
 484        dbfile = self.dbfile
 485
 486        dbfile.write_uint(len(self.index))
 487        for n in self.index:
 488            dbfile.write_long(n)
 489        dbfile.write_pickle(self.fieldmap)
 490
 491        self._write_directory()
 492        self.dbfile.close()
 493
 494
 495class PostingIndexBase(CodedOrderedReader):
 496    # Shared base class for terms index and vector index readers
 497    def __init__(self, dbfile, postfile):
 498        CodedOrderedReader.__init__(self, dbfile)
 499        self.postfile = postfile
 500
 501        dbfile.seek(self.indexbase + self.length * _LONG_SIZE)
 502        self.fieldmap = dbfile.read_pickle()
 503        self.names = [None] * len(self.fieldmap)
 504        for name, num in iteritems(self.fieldmap):
 505            self.names[num] = name
 506
 507    def close(self):
 508        CodedOrderedReader.close(self)
 509        self.postfile.close()
 510
 511
 512class W2TermsReader(PostingIndexBase):
 513    # Implements whoosh.codec.base.TermsReader
 514
 515    def terminfo(self, fieldname, text):
 516        return self[fieldname, text]
 517
 518    def matcher(self, fieldname, text, format_, scorer=None):
 519        # Note this does not filter out deleted documents; a higher level is
 520        # expected to wrap this matcher to eliminate deleted docs
 521        pf = self.postfile
 522        term = (fieldname, text)
 523        try:
 524            terminfo = self[term]
 525        except KeyError:
 526            raise TermNotFound("No term %s:%r" % (fieldname, text))
 527
 528        p = terminfo.postings
 529        if isinstance(p, integer_types):
 530            # terminfo.postings is an offset into the posting file
 531            pr = PostingMatcher(pf, p, format_, scorer=scorer, term=term)
 532        else:
 533            # terminfo.postings is an inlined tuple of (ids, weights, values)
 534            docids, weights, values = p
 535            pr = ListMatcher(docids, weights, values, format_, scorer=scorer,
 536                             term=term, terminfo=terminfo)
 537        return pr
 538
 539    def keycoder(self, key):
 540        fieldname, text = key
 541        fnum = self.fieldmap.get(fieldname, 65535)
 542        return pack_ushort(fnum) + utf8encode(text)[0]
 543
 544    def keydecoder(self, v):
 545        assert isinstance(v, bytes_type)
 546        return (self.names[unpack_ushort(v[:2])[0]], utf8decode(v[2:])[0])
 547
 548    def valuedecoder(self, v):
 549        assert isinstance(v, bytes_type)
 550        return base.FileTermInfo.from_string(v)
 551
 552    def frequency(self, key):
 553        datapos = self.range_for_key(key)[0]
 554        return base.FileTermInfo.read_weight(self.dbfile, datapos)
 555
 556    def doc_frequency(self, key):
 557        datapos = self.range_for_key(key)[0]
 558        return base.FileTermInfo.read_doc_freq(self.dbfile, datapos)
 559
 560
 561# Vectors
 562
 563# docnum, fieldnum
 564_vectorkey_struct = Struct("!IH")
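    # Editorial example (not in the original source): a vector key packs the
    # document number and the field number into six big-endian bytes:
    #
    #     _vectorkey_struct.pack(5, 0)  ->  b"\x00\x00\x00\x05\x00\x00"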
 565
 566
 567class VectorWriter(TermIndexWriter):
 568    def keycoder(self, key):
 569        fieldmap = self.fieldmap
 570        docnum, fieldname = key
 571
 572        if fieldname in fieldmap:
 573            fieldnum = fieldmap[fieldname]
 574        else:
 575            fieldnum = self.fieldcounter
 576            fieldmap[fieldname] = fieldnum
 577            self.fieldcounter += 1
 578
 579        return _vectorkey_struct.pack(docnum, fieldnum)
 580
 581    def valuecoder(self, offset):
 582        return pack_long(offset)
 583
 584
 585class W2VectorReader(PostingIndexBase):
 586    # Implements whoosh.codec.base.VectorReader
 587
 588    def matcher(self, docnum, fieldname, format_):
 589        pf = self.postfile
 590        offset = self[(docnum, fieldname)]
 591        pr = PostingMatcher(pf, offset, format_, stringids=True)
 592        return pr
 593
 594    def keycoder(self, key):
 595        return _vectorkey_struct.pack(key[0], self.fieldmap[key[1]])
 596
 597    def keydecoder(self, v):
 598        docnum, fieldnum = _vectorkey_struct.unpack(v)
 599        return (docnum, self.names[fieldnum])
 600
 601    def valuedecoder(self, v):
 602        return unpack_long(v)[0]
 603
 604
 605# Field lengths
 606
 607class LengthsBase(base.LengthsReader):
 608    magic = b("~LN1")
 609
 610    def __init__(self):
 611        self.starts = {}
 612        self.totals = {}
 613        self.minlens = {}
 614        self.maxlens = {}
 615
 616    def _read_header(self, dbfile, doccount):
 617        first = dbfile.read(4)  # Magic
 618        assert first == self.magic
 619        version = dbfile.read_int()  # Version number
 620        assert version == 1
 621
 622        dc = dbfile.read_uint()  # Number of documents saved
 623        if doccount is None:
 624            doccount = dc
 625        assert dc == doccount, "read=%s argument=%s" % (dc, doccount)
 626        self._count = doccount
 627
 628        fieldcount = dbfile.read_ushort()  # Number of fields
 629        # Read per-field info
 630        for i in xrange(fieldcount):
 631            fieldname = dbfile.read_string().decode('utf-8')
 632            self.totals[fieldname] = dbfile.read_long()
 633            self.minlens[fieldname] = byte_to_length(dbfile.read_byte())
 634            self.maxlens[fieldname] = byte_to_length(dbfile.read_byte())
 635            self.starts[fieldname] = i * doccount
 636
 637        # Add header length to per-field offsets
 638        eoh = dbfile.tell()  # End of header
 639        for fieldname in self.starts:
 640            self.starts[fieldname] += eoh
 641
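        # Editorial sketch of the lengths file layout parsed by _read_header()
        # above (and written by InMemoryLengths.to_file() further down):
        #
        #     b"~LN1" | version (int) | doccount (uint) | fieldcount (ushort)
        #     per field: name (string) | total (long) | min byte | max byte
        #     per field: an array of doccount bytes, one length byte per document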
 642    def doc_count_all(self):
 643        return self._count
 644
 645    def field_length(self, fieldname):
 646        return self.totals.get(fieldname, 0)
 647
 648    def min_field_length(self, fieldname):
 649        return self.minlens.get(fieldname, 0)
 650
 651    def max_field_length(self, fieldname):
 652        return self.maxlens.get(fieldname, 0)
 653
 654
 655class InMemoryLengths(LengthsBase):
 656    def __init__(self):
 657        LengthsBase.__init__(self)
 658        self.totals = defaultdict(int)
 659        self.lengths = {}
 660        self._count = 0
 661
 662    # IO
 663
 664    def to_file(self, dbfile, doccount):
 665        self._pad_arrays(doccount)
 666        fieldnames = list(self.lengths.keys())
 667
 668        dbfile.write(self.magic)
 669        dbfile.write_int(1)  # Format version number
 670        dbfile.write_uint(doccount)  # Number of documents
 671        dbfile.write_ushort(len(self.lengths))  # Number of fields
 672
 673        # Write per-field info
 674        for fieldname in fieldnames:
 675            dbfile.write_string(fieldname.encode('utf-8'))  # Fieldname
 676            dbfile.write_long(self.field_length(fieldname))
 677            dbfile.write_byte(length_to_byte(self.min_field_length(fieldname)))
 678            dbfile.write_byte(length_to_byte(self.max_field_length(fieldname)))
 679
 680        # Write byte arrays
 681        for fieldname in fieldnames:
 682            dbfile.write_array(self.lengths[fieldname])
 683        dbfile.close()
 684
 685    @classmethod
 686    def from_file(cls, dbfile, doccount=None):
 687        obj = cls()
 688        obj._read_header(dbfile, doccount)
 689        for fieldname, start in iteritems(obj.starts):
 690            obj.lengths[fieldname] = dbfile.get_array(start, "B", obj._count)
 691        dbfile.close()
 692        return obj
 693
 694    # Get
 695
 696    def doc_field_length(self, docnum, fieldname, default=0):
 697        try:
 698            arry = self.lengths[fieldname]
 699        except KeyError:
 700            return default
 701        if docnum >= len(arry):
 702            return default
 703        return byte_to_length(arry[docnum])
 704
 705    # Min/max cache setup -- not meant to be called while adding
 706
 707    def _minmax(self, fieldname, op, cache):
 708        if fieldname in cache:
 709            return cache[fieldname]
 710        else:
 711            ls = self.lengths[fieldname]
 712            if ls:
 713                result = byte_to_length(op(ls))
 714            else:
 715                result = 0
 716            cache[fieldname] = result
 717            return result
 718
 719    def min_field_length(self, fieldname):
 720        return self._minmax(fieldname, min, self.minlens)
 721
 722    def max_field_length(self, fieldname):
 723        return self._minmax(fieldname, max, self.maxlens)
 724
 725    # Add
 726
 727    def _create_field(self, fieldname, docnum):
 728        dc = max(self._count, docnum + 1)
 729        self.lengths[fieldname] = array("B", (0 for _ in xrange(dc)))
 730        self._count = dc
 731
 732    def _pad_arrays(self, doccount):
 733        # Pad out arrays to full length
 734        for fieldname in self.lengths.keys():
 735            arry = self.lengths[fieldname]
 736            if len(arry) < doccount:
 737                for _ in xrange(doccount - len(arry)):
 738                    arry.append(0)
 739        self._count = doccount
 740
 741    def add(self, docnum, fieldname, length):
 742        lengths = self.lengths
 743        if length:
 744            if fieldname not in lengths:
 745                self._create_field(fieldname, docnum)
 746
 747            arry = self.lengths[fieldname]
 748            count = docnum + 1
 749            if len(arry) < count:
 750                for _ in xrange(count - len(arry)):
 751                    arry.append(0)
 752            if count > self._count:
 753                self._count = count
 754            byte = length_to_byte(length)
 755            arry[docnum] = byte
 756            self.totals[fieldname] += length
 757
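        # Editorial note: add() stores each length lossily as a single byte via
        # length_to_byte()/byte_to_length(), so doc_field_length() may return
        # only an approximation of the value originally passed in (the byte
        # encoding is not exact for larger lengths).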
 758    def add_other(self, other):
 759        lengths = self.lengths
 760        totals = self.totals
 761        doccount = self._count
 762        for fname in other.lengths:
 763            if fname not in lengths:
 764                lengths[fname] = array("B")
 765        self._pad_arrays(doccount)
 766
 767        for fname in other.lengths:
 768            lengths[fname].extend(other.lengths[fname])
 769        self._count = doccount + other._count
 770        self._pad_arrays(self._count)
 771
 772        for fname in other.totals:
 773            totals[fname] += other.totals[fname]
 774
 775
 776class OnDiskLengths(LengthsBase):
 777    def __init__(self, dbfile, doccount=None):
 778        LengthsBase.__init__(self)
 779        self.dbfile = dbfile
 780        self._read_header(dbfile, doccount)
 781
 782    def doc_field_length(self, docnum, fieldname, default=0):
 783        try:
 784            start = self.starts[fieldname]
 785        except KeyError:
 786            return default
 787        return byte_to_length(self.dbfile.get_byte(start + docnum))
 788
 789    def close(self):
 790        self.dbfile.close()
 791
 792
 793# Stored fields
 794
 795_stored_pointer_struct = Struct("!qI")  # offset, length
 796stored_pointer_size = _stored_pointer_struct.size
 797pack_stored_pointer = _stored_pointer_struct.pack
 798unpack_stored_pointer = _stored_pointer_struct.unpack
 799
 800
 801class StoredFieldWriter(object):
 802    def __init__(self, dbfile):
 803        self.dbfile = dbfile
 804        self.length = 0
 805        self.directory = []
 806
 807        self.dbfile.write_long(0)
 808        self.dbfile.write_uint(0)
 809
 810        self.names = []
 811        self.name_map = {}
 812
 813    def add(self, vdict):
 814        f = self.dbfile
 815        names = self.names
 816        name_map = self.name_map
 817
 818        vlist = [None] * len(names)
 819        for k, v in iteritems(vdict):
 820            if k in name_map:
 821                vlist[name_map[k]] = v
 822            else:
 823                name_map[k] = len(names)
 824                names.append(k)
 825                vlist.append(v)
 826
 827        vstring = dumps(tuple(vlist), -1)[2:-1]
 828        self.length += 1
 829        self.directory.append(pack_stored_pointer(f.tell(), len(vstring)))
 830        f.write(vstring)
 831
 832    def add_reader(self, sfreader):
 833        add = self.add
 834        for vdict in sfreader:
 835            add(vdict)
 836
 837    def close(self):
 838        f = self.dbfile
 839        dirpos = f.tell()
 840        f.write_pickle(self.names)
 841        for pair in self.directory:
 842            f.write(pair)
 843        f.flush()
 844        f.seek(0)
 845        f.write_long(dirpos)
 846        f.write_uint(self.length)
 847        f.close()
 848
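    # Editorial sketch of the stored-fields file produced by StoredFieldWriter
    # above and consumed by StoredFieldReader below:
    #
    #     dirpos (long) | doc count (uint) | pickled value tuple per document...
    #     ...then, at dirpos: pickled list of field names | directory of
    #     (offset, length) pointers, one per document (_stored_pointer_struct)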
 849
 850class StoredFieldReader(object):
 851    def __init__(self, dbfile):
 852        self.dbfile = dbfile
 853
 854        dbfile.seek(0)
 855        dirpos = dbfile.read_long()
 856        self.length = dbfile.read_uint()
 857        self.basepos = dbfile.tell()
 858
 859        dbfile.seek(dirpos)
 860
 861        nameobj = dbfile.read_pickle()
 862        if isinstance(nameobj, dict):
 863            # Previous versions stored the list of names as a map of names to
 864            # positions... it seemed to make sense at the time...
 865            self.names = [None] * len(nameobj)
 866            for name, pos in iteritems(nameobj):
 867                self.names[pos] = name
 868        else:
 869            self.names = nameobj
 870        self.directory_offset = dbfile.tell()
 871
 872    def close(self):
 873        self.dbfile.close()
 874
 875    def __iter__(self):
 876        dbfile = self.dbfile
 877        names = self.names
 878        lengths = array("I")
 879
 880        dbfile.seek(self.directory_offset)
 881        for i in xrange(self.length):
 882            dbfile.seek(_LONG_SIZE, 1)
 883            lengths.append(dbfile.read_uint())
 884
 885        dbfile.seek(self.basepos)
 886        for length in lengths:
 887            vlist = loads(dbfile.read(length) + b("."))
 888            vdict = dict((names[i], vlist[i]) for i in xrange(len(vlist))
 889                     if vlist[i] is not None)
 890            yield vdict
 891
 892    def __getitem__(self, num):
 893        if num > self.length - 1:
 894            raise IndexError("Tried to get document %s, file has %s"
 895                             % (num, self.length))
 896
 897        dbfile = self.dbfile
 898        start = self.directory_offset + num * stored_pointer_size
 899        dbfile.seek(start)
 900        ptr = dbfile.read(stored_pointer_size)
 901        if len(ptr) != stored_pointer_size:
 902            raise Exception("Error reading %r @%s %s < %s"
 903                            % (dbfile, start, len(ptr), stored_pointer_size))
 904        position, length = unpack_stored_pointer(ptr)
 905        dbfile.seek(position)
 906        vlist = loads(dbfile.read(length) + b("."))
 907
 908        names = self.names
 909        # Recreate a dictionary by putting the field names and values back
 910        # together by position. We can't just use dict(zip(...)) because we
 911        # want to filter out the None values.
 912        vdict = dict((names[i], vlist[i]) for i in xrange(len(vlist))
 913                     if vlist[i] is not None)
 914        return vdict
 915
 916
 917# Segment object
 918
 919class W2Segment(base.Segment):
 920    def __init__(self, indexname, doccount=0, segid=None, deleted=None):
 921        """
 922        :param indexname: The name of the index this segment belongs to.
 923        :param doccount: The maximum document number in the segment.
 924        :param segid: A unique identifier for the segment; a random id is
 925            generated if one is not given.
 926        :param deleted: A set of deleted document numbers, or None if no
 927            deleted documents exist in this segment.
 928        """
 929
 930        assert isinstance(indexname, string_type)
 931        self.indexname = indexname
 932        assert isinstance(doccount, integer_types)
 933        self.doccount = doccount
 934        self.segid = self._random_id() if segid is None else segid
 935        self.deleted = deleted
 936        self.compound = False
 937
 938    def codec(self, **kwargs):
 939        return W2Codec(**kwargs)
 940
 941    def doc_count_all(self):
 942        return self.doccount
 943
 944    def doc_count(self):
 945        return self.doccount - self.deleted_count()
 946
 947    def has_deletions(self):
 948        return self.deleted_count() > 0
 949
 950    def deleted_count(self):
 951        if self.deleted is None:
 952            return 0
 953        return len(self.deleted)
 954
 955    def delete_document(self, docnum, delete=True):
 956        if delete:
 957            if self.deleted is None:
 958                self.deleted = set()
 959            self.deleted.add(docnum)
 960        elif self.deleted is not None and docnum in self.deleted:
 961            self.deleted.remove(docnum)
 962
 963    def is_deleted(self, docnum):
 964        if self.deleted is None:
 965            return False
 966        return docnum in self.deleted
 967
 968
 969# Posting blocks
 970
 971class W2Block(base.BlockBase):
 972    magic = b("Blk3")
 973
 974    infokeys = ("count", "maxid", "maxweight", "minlength", "maxlength",
 975                "idcode", "compression", "idslen", "weightslen")
 976
 977    def to_file(self, postfile, compression=3):
 978        ids = self.ids
 979        idcode, idstring = minimize_ids(ids, self.stringids, compression)
 980        wtstring = minimize_weights(self.weights, compression)
 981        vstring = minimize_values(self.postingsize, self.values, compression)
 982
 983        info = (len(ids), ids[-1], self.maxweight,
 984                length_to_byte(self.minlength), length_to_byte(self.maxlength),
 985                idcode, compression, len(idstring), len(wtstring))
 986        infostring = dumps(info, -1)
 987
 988        # Offset to next block
 989        postfile.write_uint(len(infostring) + len(idstring) + len(wtstring)
 990                            + len(vstring))
 991        # Block contents
 992        postfile.write(infostring)
 993        postfile.write(idstring)
 994        postfile.write(wtstring)
 995        postfile.write(vstring)
 996
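        # Editorial sketch of a single block as written by to_file() above and
        # re-read by from_file() below:
        #
        #     delta to next block (uint) | pickled info tuple | minimized ids
        #     | minimized weights | minimized values
        #
        # The info tuple carries the fields named in `infokeys`, so from_file()
        # restores them with setattr() and leaves the id/weight/value strings
        # to be decoded lazily by read_ids()/read_weights()/read_values().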
 997    @classmethod
 998    def from_file(cls, postfile, postingsize, stringids=False):
 999        block = cls(postingsize, stringids=stringids)
1000        block.postfile = postfile
1001
1002        delta = postfile.read_uint()
1003        block.nextoffset = postfile.tell() + delta
1004        info = postfile.read_pickle()
1005        block.dataoffset = postfile.tell()
1006
1007        for key, value in zip(cls.infokeys, info):
1008            if key in ("minlength", "maxlength"):
1009                value = byte_to_length(value)
1010            setattr(block, key, value)
1011
1012        return block
1013
1014    def read_ids(self):
1015        offset = self.dataoffset
1016        self.postfile.seek(offset)
1017        idstring = self.postfile.read(self.idslen)
1018        ids = deminimize_ids(self.idcode, self.count, idstring,
1019                             self.compression)
1020        self.ids = ids
1021        return ids
1022
1023    def read_weights(self):
1024        if self.weightslen == 0:
1025            weights = [1.0] * self.count
1026        else:
1027            offset = self.dataoffset + self.idslen
1028            self.postfile.seek(offset)
1029            wtstring = self.postfile.read(self.weightslen)
1030            weights = deminimize_weights(self.count, wtstring,
1031                                         self.compression)
1032        self.weights = weights
1033        return weights
1034
1035    def read_values(self):
1036        postingsize = self.postingsize
1037        if postingsize == 0:
1038            values = [None] * self.count
1039        else:
1040            offset = self.dataoffset + self.idslen + self.weightslen
1041            self.postfile.seek(offset)
1042            vstring = self.postfile.read(self.nextoffset - offset)
1043            values = deminimize_values(postingsize, self.count, vstring,
1044                                       self.compression)
1045        self.values = values
1046        return values
1047
1048