PageRenderTime 53ms CodeModel.GetById 9ms app.highlight 39ms RepoModel.GetById 1ms app.codeStats 0ms

/mgo/socket.go

https://bitbucket.org/zuko_uno/unisearch
Go | 529 lines | 411 code | 63 blank | 55 comment | 86 complexity | 26e0bf4daff057c68ff22b6eb126c019 MD5 | raw file
  1// mgo - MongoDB driver for Go
  2// 
  3// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
  4// 
  5// All rights reserved.
  6//
  7// Redistribution and use in source and binary forms, with or without
  8// modification, are permitted provided that the following conditions are met: 
  9// 
 10// 1. Redistributions of source code must retain the above copyright notice, this
 11//    list of conditions and the following disclaimer. 
 12// 2. Redistributions in binary form must reproduce the above copyright notice,
 13//    this list of conditions and the following disclaimer in the documentation
 14//    and/or other materials provided with the distribution. 
 15// 
 16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 17// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 18// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 19// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 20// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 21// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 22// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 23// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 24// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 26
 27package mgo
 28
import (
	"errors"
	"io"
	"net"
	"sync"

	"unisearch/mgo/bson"
)
 35
// replyFunc is the callback registered for a request that expects a reply.
// readLoop calls it with a nil err once per returned document (docNum is the
// document's index within the reply, docData its raw BSON bytes), or once
// with docNum -1 and nil docData when the reply carries no documents.
// When the socket is dead, kill and Query call it with the error, a nil
// reply, docNum -1 and nil docData.
type replyFunc func(err error, reply *replyOp, docNum int, docData []byte)
 37
// mongoSocket wraps one TCP connection to a server. The embedded mutex
// guards the mutable fields; a readLoop goroutine started by newSocket
// reads replies from conn.
//
//   - server is cleared by kill; Release and kill both tolerate a nil value.
//   - conn is assigned once in newSocket and never changed (readLoop relies
//     on this to read it without locking).
//   - nextRequestId is the request-id counter advanced by Query; id 0 is
//     reserved for requests that expect no response.
//   - replyFuncs maps in-flight request ids to their callbacks; kill swaps
//     it for an empty map and notifies every pending callback.
//   - references counts InitialAcquire/Acquire calls not yet matched by
//     Release; the socket is recycled when it drops to zero.
//   - auth, logout and cachedNonce are authentication state managed outside
//     this file (NOTE(review): presumably by the auth code — confirm).
//   - gotNonce is a condition variable bound to the embedded mutex in
//     newSocket.
//   - dead is nil while the socket is usable and holds the kill reason
//     afterwards.
type mongoSocket struct {
	sync.Mutex
	server        *mongoServer // nil when cached
	conn          *net.TCPConn
	addr          string // For debugging only.
	nextRequestId uint32
	replyFuncs    map[uint32]replyFunc
	references    int
	auth          []authInfo
	logout        []authInfo
	cachedNonce   string
	gotNonce      sync.Cond
	dead          error
}
 52
// queryOp is an OP_QUERY request (opcode 2004). Query serializes it as
// flags, collection name, skip, limit, the query document and, only when
// non-nil, the selector document. If replyFunc is set, Query assigns the
// request an id and registers replyFunc to receive the reply.
type queryOp struct {
	collection string
	query      interface{}
	skip       int32
	limit      int32
	selector   interface{}
	flags      uint32
	replyFunc  replyFunc
}
 62
// getMoreOp is an OP_GET_MORE request (opcode 2005). Query serializes it as
// a reserved zero int32, the collection name, limit and cursorId. If
// replyFunc is set, it is registered to receive the reply.
type getMoreOp struct {
	collection string
	limit      int32
	cursorId   int64
	replyFunc  replyFunc
}
 69
// replyOp holds the fixed OP_REPLY fields that readLoop decodes after the
// 16-byte message header: flags, cursorId, firstDoc and replyDocs, read from
// byte offsets 16, 20, 28 and 32 of the message. replyDocs is the number of
// documents readLoop then reads from the connection.
type replyOp struct {
	flags     uint32
	cursorId  int64
	firstDoc  int32
	replyDocs int32
}
 76
// insertOp is an OP_INSERT request (opcode 2002). Query serializes it as a
// reserved zero int32, the collection name, and then each document BSON
// encoded in order. No reply is expected.
type insertOp struct {
	collection string        // "database.collection"
	documents  []interface{} // One or more documents to insert
}
 81
// updateOp is an OP_UPDATE request (opcode 2001). Query serializes it as a
// reserved zero int32, the collection name, flags, the selector document and
// the update document. No reply is expected.
type updateOp struct {
	collection string // "database.collection"
	selector   interface{}
	update     interface{}
	flags      uint32
}
 88
// deleteOp is an OP_DELETE request (opcode 2006). Query serializes it as a
// reserved zero int32, the collection name, flags and the selector document.
// No reply is expected.
type deleteOp struct {
	collection string // "database.collection"
	selector   interface{}
	flags      uint32
}
 94
// requestInfo is Query's bookkeeping for one serialized request that
// expects a reply. bufferPos is the offset of the request's message header
// in the outgoing buffer; the request id is patched in at bufferPos+4 once
// ids are allocated. replyFunc is the callback to register under that id.
type requestInfo struct {
	bufferPos int
	replyFunc replyFunc
}
 99
100func newSocket(server *mongoServer, conn *net.TCPConn) *mongoSocket {
101	socket := &mongoSocket{conn: conn, addr: server.Addr}
102	socket.gotNonce.L = &socket.Mutex
103	socket.replyFuncs = make(map[uint32]replyFunc)
104	socket.server = server
105	if err := socket.InitialAcquire(); err != nil {
106		panic("newSocket: InitialAcquire returned error: " + err.Error())
107	}
108	stats.socketsAlive(+1)
109	debugf("Socket %p to %s: initialized", socket, socket.addr)
110	socket.resetNonce()
111	go socket.readLoop()
112	return socket
113}
114
115// InitialAcquire obtains the first reference to the socket, either
116// right after the connection is made or once a recycled socket is
117// being put back in use.
118func (socket *mongoSocket) InitialAcquire() error {
119	socket.Lock()
120	if socket.references > 0 {
121		panic("Socket acquired out of cache with references")
122	}
123	if socket.dead != nil {
124		socket.Unlock()
125		return socket.dead
126	}
127	socket.references++
128	stats.socketsInUse(+1)
129	stats.socketRefs(+1)
130	socket.Unlock()
131	return nil
132}
133
134// Acquire obtains an additional reference to the socket.
135// The socket will only be recycled when it's released as many
136// times as it's been acquired.
137func (socket *mongoSocket) Acquire() (isMaster bool) {
138	socket.Lock()
139	if socket.references == 0 {
140		panic("Socket got non-initial acquire with references == 0")
141	}
142	socket.references++
143	stats.socketRefs(+1)
144	// We'll track references to dead sockets as well.
145	// Caller is still supposed to release the socket.
146	if socket.dead == nil {
147		isMaster = socket.server.IsMaster()
148	}
149	socket.Unlock()
150	return isMaster
151}
152
153// Release decrements a socket reference. The socket will be
154// recycled once its released as many times as it's been acquired.
155func (socket *mongoSocket) Release() {
156	socket.Lock()
157	if socket.references == 0 {
158		panic("socket.Release() with references == 0")
159	}
160	socket.references--
161	stats.socketRefs(-1)
162	if socket.references == 0 {
163		stats.socketsInUse(-1)
164		server := socket.server
165		socket.Unlock()
166		socket.LogoutAll()
167		// If the socket is dead server is nil.
168		if server != nil {
169			server.RecycleSocket(socket)
170		}
171	} else {
172		socket.Unlock()
173	}
174}
175
176// Close terminates the socket use.
177func (socket *mongoSocket) Close() {
178	socket.kill(errors.New("Closed explicitly"), false)
179}
180
181func (socket *mongoSocket) kill(err error, abend bool) {
182	socket.Lock()
183	if socket.dead != nil {
184		debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error())
185		socket.Unlock()
186		return
187	}
188	logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend)
189	socket.dead = err
190	socket.conn.Close()
191	stats.socketsAlive(-1)
192	replyFuncs := socket.replyFuncs
193	socket.replyFuncs = make(map[uint32]replyFunc)
194	server := socket.server
195	socket.server = nil
196	socket.Unlock()
197	for _, f := range replyFuncs {
198		logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error())
199		f(err, nil, -1, nil)
200	}
201	if abend {
202		server.AbendSocket(socket)
203	}
204}
205
206func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) {
207	var mutex sync.Mutex
208	var replyData []byte
209	var replyErr error
210	mutex.Lock()
211	op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
212		replyData = docData
213		replyErr = err
214		mutex.Unlock()
215	}
216	err = socket.Query(op)
217	if err != nil {
218		return nil, err
219	}
220	mutex.Lock() // Wait.
221	if replyErr != nil {
222		return nil, replyErr
223	}
224	return replyData, nil
225}
226
227func (socket *mongoSocket) Query(ops ...interface{}) (err error) {
228
229	if lops := socket.flushLogout(); len(lops) > 0 {
230		ops = append(lops, ops...)
231	}
232
233	buf := make([]byte, 0, 256)
234
235	// Serialize operations synchronously to avoid interrupting
236	// other goroutines while we can't really be sending data.
237	// Also, record id positions so that we can compute request
238	// ids at once later with the lock already held.
239	requests := make([]requestInfo, len(ops))
240	requestCount := 0
241
242	for _, op := range ops {
243		debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op)
244		start := len(buf)
245		var replyFunc replyFunc
246		switch op := op.(type) {
247
248		case *updateOp:
249			buf = addHeader(buf, 2001)
250			buf = addInt32(buf, 0) // Reserved
251			buf = addCString(buf, op.collection)
252			buf = addInt32(buf, int32(op.flags))
253			debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector)
254			buf, err = addBSON(buf, op.selector)
255			if err != nil {
256				return err
257			}
258			debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.update)
259			buf, err = addBSON(buf, op.update)
260			if err != nil {
261				return err
262			}
263
264		case *insertOp:
265			buf = addHeader(buf, 2002)
266			buf = addInt32(buf, 0) // Reserved
267			buf = addCString(buf, op.collection)
268			for _, doc := range op.documents {
269				debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc)
270				buf, err = addBSON(buf, doc)
271				if err != nil {
272					return err
273				}
274			}
275
276		case *queryOp:
277			buf = addHeader(buf, 2004)
278			buf = addInt32(buf, int32(op.flags))
279			buf = addCString(buf, op.collection)
280			buf = addInt32(buf, op.skip)
281			buf = addInt32(buf, op.limit)
282			buf, err = addBSON(buf, op.query)
283			if err != nil {
284				return err
285			}
286			if op.selector != nil {
287				buf, err = addBSON(buf, op.selector)
288				if err != nil {
289					return err
290				}
291			}
292			replyFunc = op.replyFunc
293
294		case *getMoreOp:
295			buf = addHeader(buf, 2005)
296			buf = addInt32(buf, 0) // Reserved
297			buf = addCString(buf, op.collection)
298			buf = addInt32(buf, op.limit)
299			buf = addInt64(buf, op.cursorId)
300			replyFunc = op.replyFunc
301
302		case *deleteOp:
303			buf = addHeader(buf, 2006)
304			buf = addInt32(buf, 0) // Reserved
305			buf = addCString(buf, op.collection)
306			buf = addInt32(buf, int32(op.flags))
307			debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector)
308			buf, err = addBSON(buf, op.selector)
309			if err != nil {
310				return err
311			}
312
313		default:
314			panic("Internal error: unknown operation type")
315		}
316
317		setInt32(buf, start, int32(len(buf)-start))
318
319		if replyFunc != nil {
320			request := &requests[requestCount]
321			request.replyFunc = replyFunc
322			request.bufferPos = start
323			requestCount++
324		}
325	}
326
327	// Buffer is ready for the pipe.  Lock, allocate ids, and enqueue.
328
329	socket.Lock()
330	if socket.dead != nil {
331		socket.Unlock()
332		debug("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error())
333		// XXX This seems necessary in case the session is closed concurrently
334		// with a query being performed, but it's not yet tested:
335		for i := 0; i != requestCount; i++ {
336			request := &requests[i]
337			if request.replyFunc != nil {
338				request.replyFunc(socket.dead, nil, -1, nil)
339			}
340		}
341		return socket.dead
342	}
343
344	// Reserve id 0 for requests which should have no responses.
345	requestId := socket.nextRequestId + 1
346	if requestId == 0 {
347		requestId++
348	}
349	socket.nextRequestId = requestId + uint32(requestCount)
350	for i := 0; i != requestCount; i++ {
351		request := &requests[i]
352		setInt32(buf, request.bufferPos+4, int32(requestId))
353		socket.replyFuncs[requestId] = request.replyFunc
354		requestId++
355	}
356
357	debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf))
358	stats.sentOps(len(ops))
359
360	_, err = socket.conn.Write(buf)
361	socket.Unlock()
362	return err
363}
364
// fill reads from r until b is completely filled. It returns nil when every
// byte of b was read, even if r reported an error (such as io.EOF) together
// with the final bytes. It returns io.EOF if nothing was read, and
// io.ErrUnexpectedEOF if the stream ended part-way through b.
//
// r is an io.Reader rather than *net.TCPConn so that any stream can be
// used; existing callers passing a *net.TCPConn are unaffected.
func fill(r io.Reader, b []byte) error {
	_, err := io.ReadFull(r, b)
	return err
}
375
376// Estimated minimum cost per socket: 1 goroutine + memory for the largest
377// document ever seen.
378func (socket *mongoSocket) readLoop() {
379	p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields
380	s := make([]byte, 4)
381	conn := socket.conn // No locking, conn never changes.
382	for {
383		// XXX Handle timeouts, , etc
384		err := fill(conn, p)
385		if err != nil {
386			socket.kill(err, true)
387			return
388		}
389
390		totalLen := getInt32(p, 0)
391		responseTo := getInt32(p, 8)
392		opCode := getInt32(p, 12)
393
394		// Don't use socket.server.Addr here.  socket is not
395		// locked and socket.server may go away.
396		debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen)
397
398		_ = totalLen
399
400		if opCode != 1 {
401			socket.kill(errors.New("opcode != 1, corrupted data?"), true)
402			return
403		}
404
405		reply := replyOp{
406			flags:     uint32(getInt32(p, 16)),
407			cursorId:  getInt64(p, 20),
408			firstDoc:  getInt32(p, 28),
409			replyDocs: getInt32(p, 32),
410		}
411
412		stats.receivedOps(+1)
413		stats.receivedDocs(int(reply.replyDocs))
414
415		socket.Lock()
416		replyFunc, replyFuncFound := socket.replyFuncs[uint32(responseTo)]
417		socket.Unlock()
418
419		if replyFunc != nil && reply.replyDocs == 0 {
420			replyFunc(nil, &reply, -1, nil)
421		} else {
422			for i := 0; i != int(reply.replyDocs); i++ {
423				err := fill(conn, s)
424				if err != nil {
425					socket.kill(err, true)
426					return
427				}
428
429				b := make([]byte, int(getInt32(s, 0)))
430
431				// copy(b, s) in an efficient way.
432				b[0] = s[0]
433				b[1] = s[1]
434				b[2] = s[2]
435				b[3] = s[3]
436
437				err = fill(conn, b[4:])
438				if err != nil {
439					socket.kill(err, true)
440					return
441				}
442
443				if globalDebug && globalLogger != nil {
444					m := bson.M{}
445					if err := bson.Unmarshal(b, m); err == nil {
446						debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m)
447					}
448				}
449
450				if replyFunc != nil {
451					replyFunc(nil, &reply, i, b)
452				}
453
454				// XXX Do bound checking against totalLen.
455			}
456		}
457
458		// Only remove replyFunc after iteration, so that kill() will see it.
459		socket.Lock()
460		if replyFuncFound {
461			delete(socket.replyFuncs, uint32(responseTo))
462		}
463		socket.Unlock()
464
465		// XXX Do bound checking against totalLen.
466	}
467}
468
// emptyHeader is a zeroed 16-byte message header template.
var emptyHeader = make([]byte, 16)
470
// addHeader appends a 16-byte message header for opcode to b and returns the
// extended slice. messageLength is left zero for Query to patch via setInt32
// once the message is complete. requestID is left zero and is patched only for
// requests that expect a reply. responseTo is always zero for requests.
// The opcode is written as a full little-endian int32; the previous version
// kept only its low 16 bits.
func addHeader(b []byte, opcode int) []byte {
	return append(b,
		0, 0, 0, 0, // messageLength
		0, 0, 0, 0, // requestID
		0, 0, 0, 0, // responseTo
		byte(opcode), byte(opcode>>8), byte(opcode>>16), byte(opcode>>24))
}
479
// addInt32 appends the 4-byte little-endian encoding of i to b and returns
// the extended slice.
func addInt32(b []byte, i int32) []byte {
	u := uint32(i)
	for shift := uint(0); shift < 32; shift += 8 {
		b = append(b, byte(u>>shift))
	}
	return b
}
483
// addInt64 appends the 8-byte little-endian encoding of i to b and returns
// the extended slice.
func addInt64(b []byte, i int64) []byte {
	u := uint64(i)
	for shift := uint(0); shift < 64; shift += 8 {
		b = append(b, byte(u>>shift))
	}
	return b
}
488
// addCString appends s followed by a terminating NUL byte to b and returns
// the extended slice.
func addCString(b []byte, s string) []byte {
	return append(append(b, s...), 0)
}
494
495func addBSON(b []byte, doc interface{}) ([]byte, error) {
496	if doc == nil {
497		return append(b, 5, 0, 0, 0, 0), nil
498	}
499	data, err := bson.Marshal(doc)
500	if err != nil {
501		return b, err
502	}
503	return append(b, data...), nil
504}
505
// setInt32 overwrites b[pos:pos+4] with the little-endian encoding of i,
// writing the lowest byte first.
func setInt32(b []byte, pos int, i int32) {
	for k := 0; k < 4; k++ {
		b[pos+k] = byte(i >> (8 * uint(k)))
	}
}
512
// getInt32 decodes the little-endian int32 stored at b[pos:pos+4].
func getInt32(b []byte, pos int) int32 {
	var v uint32
	for k := 0; k < 4; k++ {
		v |= uint32(b[pos+k]) << (8 * uint(k))
	}
	return int32(v)
}
519
// getInt64 decodes the little-endian int64 stored at b[pos:pos+8].
func getInt64(b []byte, pos int) int64 {
	var v uint64
	for k := 0; k < 8; k++ {
		v |= uint64(b[pos+k]) << (8 * uint(k))
	}
	return int64(v)
}