PageRenderTime 58ms CodeModel.GetById 20ms app.highlight 32ms RepoModel.GetById 1ms app.codeStats 0ms

/vendor/labix.org/v2/mgo/socket.go

http://github.com/bradfitz/camlistore
Go | 655 lines | 517 code | 77 blank | 61 comment | 100 complexity | 1b311b1161293507ddca9e87e915e904 MD5 | raw file
  1// mgo - MongoDB driver for Go
  2//
  3// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
  4//
  5// All rights reserved.
  6//
  7// Redistribution and use in source and binary forms, with or without
  8// modification, are permitted provided that the following conditions are met:
  9//
 10// 1. Redistributions of source code must retain the above copyright notice, this
 11//    list of conditions and the following disclaimer.
 12// 2. Redistributions in binary form must reproduce the above copyright notice,
 13//    this list of conditions and the following disclaimer in the documentation
 14//    and/or other materials provided with the distribution.
 15//
 16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 17// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 18// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 19// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 20// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 21// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 22// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 23// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 24// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 26
 27package mgo
 28
 29import (
 30	"errors"
 31	"labix.org/v2/mgo/bson"
 32	"net"
 33	"sync"
 34	"time"
 35)
 36
 37type replyFunc func(err error, reply *replyOp, docNum int, docData []byte)
 38
 39type mongoSocket struct {
 40	sync.Mutex
 41	server        *mongoServer // nil when cached
 42	conn          net.Conn
 43	timeout       time.Duration
 44	addr          string // For debugging only.
 45	nextRequestId uint32
 46	replyFuncs    map[uint32]replyFunc
 47	references    int
 48	auth          []authInfo
 49	logout        []authInfo
 50	cachedNonce   string
 51	gotNonce      sync.Cond
 52	dead          error
 53	serverInfo    *mongoServerInfo
 54}
 55
 56type queryOpFlags uint32
 57
 58const (
 59	_ queryOpFlags = 1 << iota
 60	flagTailable
 61	flagSlaveOk
 62	flagLogReplay
 63	flagNoCursorTimeout
 64	flagAwaitData
 65)
 66
 67type queryOp struct {
 68	collection string
 69	query      interface{}
 70	skip       int32
 71	limit      int32
 72	selector   interface{}
 73	flags      queryOpFlags
 74	replyFunc  replyFunc
 75
 76	options    queryWrapper
 77	hasOptions bool
 78	serverTags []bson.D
 79}
 80
 81type queryWrapper struct {
 82	Query          interface{} "$query"
 83	OrderBy        interface{} "$orderby,omitempty"
 84	Hint           interface{} "$hint,omitempty"
 85	Explain        bool        "$explain,omitempty"
 86	Snapshot       bool        "$snapshot,omitempty"
 87	ReadPreference bson.D      "$readPreference,omitempty"
 88}
 89
 90func (op *queryOp) finalQuery(socket *mongoSocket) interface{} {
 91	if op.flags&flagSlaveOk != 0 && len(op.serverTags) > 0 && socket.ServerInfo().Mongos {
 92		op.hasOptions = true
 93		op.options.ReadPreference = bson.D{{"mode", "secondaryPreferred"}, {"tags", op.serverTags}}
 94	}
 95	if op.hasOptions {
 96		if op.query == nil {
 97			var empty bson.D
 98			op.options.Query = empty
 99		} else {
100			op.options.Query = op.query
101		}
102		debugf("final query is %#v\n", &op.options)
103		return &op.options
104	}
105	return op.query
106}
107
108type getMoreOp struct {
109	collection string
110	limit      int32
111	cursorId   int64
112	replyFunc  replyFunc
113}
114
115type replyOp struct {
116	flags     uint32
117	cursorId  int64
118	firstDoc  int32
119	replyDocs int32
120}
121
122type insertOp struct {
123	collection string        // "database.collection"
124	documents  []interface{} // One or more documents to insert
125}
126
127type updateOp struct {
128	collection string // "database.collection"
129	selector   interface{}
130	update     interface{}
131	flags      uint32
132}
133
134type deleteOp struct {
135	collection string // "database.collection"
136	selector   interface{}
137	flags      uint32
138}
139
140type killCursorsOp struct {
141	cursorIds []int64
142}
143
144type requestInfo struct {
145	bufferPos int
146	replyFunc replyFunc
147}
148
149func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket {
150	socket := &mongoSocket{
151		conn:       conn,
152		addr:       server.Addr,
153		server:     server,
154		replyFuncs: make(map[uint32]replyFunc),
155	}
156	socket.gotNonce.L = &socket.Mutex
157	if err := socket.InitialAcquire(server.Info(), timeout); err != nil {
158		panic("newSocket: InitialAcquire returned error: " + err.Error())
159	}
160	stats.socketsAlive(+1)
161	debugf("Socket %p to %s: initialized", socket, socket.addr)
162	socket.resetNonce()
163	go socket.readLoop()
164	return socket
165}
166
167// Server returns the server that the socket is associated with.
168// It returns nil while the socket is cached in its respective server.
169func (socket *mongoSocket) Server() *mongoServer {
170	socket.Lock()
171	server := socket.server
172	socket.Unlock()
173	return server
174}
175
176// ServerInfo returns details for the server at the time the socket
177// was initially acquired.
178func (socket *mongoSocket) ServerInfo() *mongoServerInfo {
179	socket.Lock()
180	serverInfo := socket.serverInfo
181	socket.Unlock()
182	return serverInfo
183}
184
185// InitialAcquire obtains the first reference to the socket, either
186// right after the connection is made or once a recycled socket is
187// being put back in use.
188func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error {
189	socket.Lock()
190	if socket.references > 0 {
191		panic("Socket acquired out of cache with references")
192	}
193	if socket.dead != nil {
194		socket.Unlock()
195		return socket.dead
196	}
197	socket.references++
198	socket.serverInfo = serverInfo
199	socket.timeout = timeout
200	stats.socketsInUse(+1)
201	stats.socketRefs(+1)
202	socket.Unlock()
203	return nil
204}
205
206// Acquire obtains an additional reference to the socket.
207// The socket will only be recycled when it's released as many
208// times as it's been acquired.
209func (socket *mongoSocket) Acquire() (info *mongoServerInfo) {
210	socket.Lock()
211	if socket.references == 0 {
212		panic("Socket got non-initial acquire with references == 0")
213	}
214	// We'll track references to dead sockets as well.
215	// Caller is still supposed to release the socket.
216	socket.references++
217	stats.socketRefs(+1)
218	serverInfo := socket.serverInfo
219	socket.Unlock()
220	return serverInfo
221}
222
223// Release decrements a socket reference. The socket will be
224// recycled once its released as many times as it's been acquired.
225func (socket *mongoSocket) Release() {
226	socket.Lock()
227	if socket.references == 0 {
228		panic("socket.Release() with references == 0")
229	}
230	socket.references--
231	stats.socketRefs(-1)
232	if socket.references == 0 {
233		stats.socketsInUse(-1)
234		server := socket.server
235		socket.Unlock()
236		socket.LogoutAll()
237		// If the socket is dead server is nil.
238		if server != nil {
239			server.RecycleSocket(socket)
240		}
241	} else {
242		socket.Unlock()
243	}
244}
245
246// SetTimeout changes the timeout used on socket operations.
247func (socket *mongoSocket) SetTimeout(d time.Duration) {
248	socket.Lock()
249	socket.timeout = d
250	socket.Unlock()
251}
252
// deadlineType selects which connection deadline(s) updateDeadline
// refreshes. The values are bit flags and may be combined with |.
type deadlineType int

const (
	readDeadline  deadlineType = 1 << iota // 1
	writeDeadline                          // 2
)
259
260func (socket *mongoSocket) updateDeadline(which deadlineType) {
261	var when time.Time
262	if socket.timeout > 0 {
263		when = time.Now().Add(socket.timeout)
264	}
265	whichstr := ""
266	switch which {
267	case readDeadline | writeDeadline:
268		whichstr = "read/write"
269		socket.conn.SetDeadline(when)
270	case readDeadline:
271		whichstr = "read"
272		socket.conn.SetReadDeadline(when)
273	case writeDeadline:
274		whichstr = "write"
275		socket.conn.SetWriteDeadline(when)
276	default:
277		panic("invalid parameter to updateDeadline")
278	}
279	debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when)
280}
281
282// Close terminates the socket use.
283func (socket *mongoSocket) Close() {
284	socket.kill(errors.New("Closed explicitly"), false)
285}
286
287func (socket *mongoSocket) kill(err error, abend bool) {
288	socket.Lock()
289	if socket.dead != nil {
290		debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error())
291		socket.Unlock()
292		return
293	}
294	logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend)
295	socket.dead = err
296	socket.conn.Close()
297	stats.socketsAlive(-1)
298	replyFuncs := socket.replyFuncs
299	socket.replyFuncs = make(map[uint32]replyFunc)
300	server := socket.server
301	socket.server = nil
302	socket.Unlock()
303	for _, f := range replyFuncs {
304		logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error())
305		f(err, nil, -1, nil)
306	}
307	if abend {
308		server.AbendSocket(socket)
309	}
310}
311
312func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) {
313	var mutex sync.Mutex
314	var replyData []byte
315	var replyErr error
316	mutex.Lock()
317	op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
318		replyData = docData
319		replyErr = err
320		mutex.Unlock()
321	}
322	err = socket.Query(op)
323	if err != nil {
324		return nil, err
325	}
326	mutex.Lock() // Wait.
327	if replyErr != nil {
328		return nil, replyErr
329	}
330	return replyData, nil
331}
332
333func (socket *mongoSocket) Query(ops ...interface{}) (err error) {
334
335	if lops := socket.flushLogout(); len(lops) > 0 {
336		ops = append(lops, ops...)
337	}
338
339	buf := make([]byte, 0, 256)
340
341	// Serialize operations synchronously to avoid interrupting
342	// other goroutines while we can't really be sending data.
343	// Also, record id positions so that we can compute request
344	// ids at once later with the lock already held.
345	requests := make([]requestInfo, len(ops))
346	requestCount := 0
347
348	for _, op := range ops {
349		debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op)
350		start := len(buf)
351		var replyFunc replyFunc
352		switch op := op.(type) {
353
354		case *updateOp:
355			buf = addHeader(buf, 2001)
356			buf = addInt32(buf, 0) // Reserved
357			buf = addCString(buf, op.collection)
358			buf = addInt32(buf, int32(op.flags))
359			debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector)
360			buf, err = addBSON(buf, op.selector)
361			if err != nil {
362				return err
363			}
364			debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.update)
365			buf, err = addBSON(buf, op.update)
366			if err != nil {
367				return err
368			}
369
370		case *insertOp:
371			buf = addHeader(buf, 2002)
372			buf = addInt32(buf, 0) // Reserved
373			buf = addCString(buf, op.collection)
374			for _, doc := range op.documents {
375				debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc)
376				buf, err = addBSON(buf, doc)
377				if err != nil {
378					return err
379				}
380			}
381
382		case *queryOp:
383			buf = addHeader(buf, 2004)
384			buf = addInt32(buf, int32(op.flags))
385			buf = addCString(buf, op.collection)
386			buf = addInt32(buf, op.skip)
387			buf = addInt32(buf, op.limit)
388			buf, err = addBSON(buf, op.finalQuery(socket))
389			if err != nil {
390				return err
391			}
392			if op.selector != nil {
393				buf, err = addBSON(buf, op.selector)
394				if err != nil {
395					return err
396				}
397			}
398			replyFunc = op.replyFunc
399
400		case *getMoreOp:
401			buf = addHeader(buf, 2005)
402			buf = addInt32(buf, 0) // Reserved
403			buf = addCString(buf, op.collection)
404			buf = addInt32(buf, op.limit)
405			buf = addInt64(buf, op.cursorId)
406			replyFunc = op.replyFunc
407
408		case *deleteOp:
409			buf = addHeader(buf, 2006)
410			buf = addInt32(buf, 0) // Reserved
411			buf = addCString(buf, op.collection)
412			buf = addInt32(buf, int32(op.flags))
413			debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector)
414			buf, err = addBSON(buf, op.selector)
415			if err != nil {
416				return err
417			}
418
419		case *killCursorsOp:
420			buf = addHeader(buf, 2007)
421			buf = addInt32(buf, 0) // Reserved
422			buf = addInt32(buf, int32(len(op.cursorIds)))
423			for _, cursorId := range op.cursorIds {
424				buf = addInt64(buf, cursorId)
425			}
426
427		default:
428			panic("Internal error: unknown operation type")
429		}
430
431		setInt32(buf, start, int32(len(buf)-start))
432
433		if replyFunc != nil {
434			request := &requests[requestCount]
435			request.replyFunc = replyFunc
436			request.bufferPos = start
437			requestCount++
438		}
439	}
440
441	// Buffer is ready for the pipe.  Lock, allocate ids, and enqueue.
442
443	socket.Lock()
444	if socket.dead != nil {
445		socket.Unlock()
446		debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error())
447		// XXX This seems necessary in case the session is closed concurrently
448		// with a query being performed, but it's not yet tested:
449		for i := 0; i != requestCount; i++ {
450			request := &requests[i]
451			if request.replyFunc != nil {
452				request.replyFunc(socket.dead, nil, -1, nil)
453			}
454		}
455		return socket.dead
456	}
457
458	wasWaiting := len(socket.replyFuncs) > 0
459
460	// Reserve id 0 for requests which should have no responses.
461	requestId := socket.nextRequestId + 1
462	if requestId == 0 {
463		requestId++
464	}
465	socket.nextRequestId = requestId + uint32(requestCount)
466	for i := 0; i != requestCount; i++ {
467		request := &requests[i]
468		setInt32(buf, request.bufferPos+4, int32(requestId))
469		socket.replyFuncs[requestId] = request.replyFunc
470		requestId++
471	}
472
473	debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf))
474	stats.sentOps(len(ops))
475
476	socket.updateDeadline(writeDeadline)
477	_, err = socket.conn.Write(buf)
478	if !wasWaiting && requestCount > 0 {
479		socket.updateDeadline(readDeadline)
480	}
481	socket.Unlock()
482	return err
483}
484
// fill reads from r until b is completely filled.
//
// It returns nil as soon as b is full, even when the Read that filled it
// also returned an error. The io.Reader contract allows a Read to return
// data together with an error such as io.EOF, and those bytes are still
// valid. If r reports an error before b is full, that error is returned
// unchanged.
func fill(r net.Conn, b []byte) error {
	n := 0
	for n < len(b) {
		m, err := r.Read(b[n:])
		n += m
		if err != nil {
			if n >= len(b) {
				return nil
			}
			return err
		}
	}
	return nil
}
495
496// Estimated minimum cost per socket: 1 goroutine + memory for the largest
497// document ever seen.
498func (socket *mongoSocket) readLoop() {
499	p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields
500	s := make([]byte, 4)
501	conn := socket.conn // No locking, conn never changes.
502	for {
503		// XXX Handle timeouts, , etc
504		err := fill(conn, p)
505		if err != nil {
506			socket.kill(err, true)
507			return
508		}
509
510		totalLen := getInt32(p, 0)
511		responseTo := getInt32(p, 8)
512		opCode := getInt32(p, 12)
513
514		// Don't use socket.server.Addr here.  socket is not
515		// locked and socket.server may go away.
516		debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen)
517
518		_ = totalLen
519
520		if opCode != 1 {
521			socket.kill(errors.New("opcode != 1, corrupted data?"), true)
522			return
523		}
524
525		reply := replyOp{
526			flags:     uint32(getInt32(p, 16)),
527			cursorId:  getInt64(p, 20),
528			firstDoc:  getInt32(p, 28),
529			replyDocs: getInt32(p, 32),
530		}
531
532		stats.receivedOps(+1)
533		stats.receivedDocs(int(reply.replyDocs))
534
535		socket.Lock()
536		replyFunc, replyFuncFound := socket.replyFuncs[uint32(responseTo)]
537		socket.Unlock()
538
539		if replyFunc != nil && reply.replyDocs == 0 {
540			replyFunc(nil, &reply, -1, nil)
541		} else {
542			for i := 0; i != int(reply.replyDocs); i++ {
543				err := fill(conn, s)
544				if err != nil {
545					socket.kill(err, true)
546					return
547				}
548
549				b := make([]byte, int(getInt32(s, 0)))
550
551				// copy(b, s) in an efficient way.
552				b[0] = s[0]
553				b[1] = s[1]
554				b[2] = s[2]
555				b[3] = s[3]
556
557				err = fill(conn, b[4:])
558				if err != nil {
559					socket.kill(err, true)
560					return
561				}
562
563				if globalDebug && globalLogger != nil {
564					m := bson.M{}
565					if err := bson.Unmarshal(b, m); err == nil {
566						debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m)
567					}
568				}
569
570				if replyFunc != nil {
571					replyFunc(nil, &reply, i, b)
572				}
573
574				// XXX Do bound checking against totalLen.
575			}
576		}
577
578		// Only remove replyFunc after iteration, so that kill() will see it.
579		socket.Lock()
580		if replyFuncFound {
581			delete(socket.replyFuncs, uint32(responseTo))
582		}
583		if len(socket.replyFuncs) == 0 {
584			// Nothing else to read for now. Disable deadline.
585			socket.conn.SetReadDeadline(time.Time{})
586		} else {
587			socket.updateDeadline(readDeadline)
588		}
589		socket.Unlock()
590
591		// XXX Do bound checking against totalLen.
592	}
593}
594
// emptyHeader is a zeroed 16-byte message header placeholder. The fields
// are messageLength, requestID, responseTo and opCode, each an int32.
var emptyHeader = make([]byte, 16)

// addHeader appends a message header to b with only the opCode field set.
// Query patches in the length and request id afterwards. Only the low 16
// bits of opcode are written, which is enough for the opcodes used here.
func addHeader(b []byte, opcode int) []byte {
	start := len(b)
	b = append(b, emptyHeader...)
	code := b[start+12 : start+14]
	code[0], code[1] = byte(opcode), byte(opcode>>8)
	return b
}
605
// addInt32 appends i to b as four little-endian bytes.
func addInt32(b []byte, i int32) []byte {
	u := uint32(i)
	for shift := uint(0); shift < 32; shift += 8 {
		b = append(b, byte(u>>shift))
	}
	return b
}
609
// addInt64 appends i to b as eight little-endian bytes.
func addInt64(b []byte, i int64) []byte {
	u := uint64(i)
	for shift := uint(0); shift < 64; shift += 8 {
		b = append(b, byte(u>>shift))
	}
	return b
}
614
// addCString appends s to b followed by a terminating NUL byte.
// NOTE: s is not checked for embedded NUL bytes.
func addCString(b []byte, s string) []byte {
	return append(append(b, s...), 0)
}
620
621func addBSON(b []byte, doc interface{}) ([]byte, error) {
622	if doc == nil {
623		return append(b, 5, 0, 0, 0, 0), nil
624	}
625	data, err := bson.Marshal(doc)
626	if err != nil {
627		return b, err
628	}
629	return append(b, data...), nil
630}
631
// setInt32 overwrites b[pos:pos+4] with i in little-endian byte order.
func setInt32(b []byte, pos int, i int32) {
	u := uint32(i)
	for k := 0; k < 4; k++ {
		b[pos+k] = byte(u >> (8 * uint(k)))
	}
}
638
// getInt32 decodes the little-endian int32 stored at b[pos:pos+4].
func getInt32(b []byte, pos int) int32 {
	var u uint32
	for k := 0; k < 4; k++ {
		u |= uint32(b[pos+k]) << (8 * uint(k))
	}
	return int32(u)
}
645
// getInt64 decodes the little-endian int64 stored at b[pos:pos+8].
func getInt64(b []byte, pos int) int64 {
	var u uint64
	for k := 0; k < 8; k++ {
		u |= uint64(b[pos+k]) << (8 * uint(k))
	}
	return int64(u)
}