PageRenderTime 135ms CodeModel.GetById 112ms app.highlight 19ms RepoModel.GetById 1ms app.codeStats 0ms

/go/mysql/mysql.go

https://code.google.com/p/vitess/
Go | 273 lines | 222 code | 35 blank | 16 comment | 45 complexity | bcc399cb2490412f5bbc920f8692e9e1 MD5 | raw file
  1// Copyright 2012, Google Inc. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package mysql
  6
  7/*
  8#cgo pkg-config: gomysql
  9#include <stdlib.h>
 10#include "vtmysql.h"
 11*/
 12import "C"
 13
 14import (
 15	"fmt"
 16	"unsafe"
 17
 18	"code.google.com/p/vitess/go/hack"
 19	"code.google.com/p/vitess/go/mysql/proto"
 20	"code.google.com/p/vitess/go/relog"
 21	"code.google.com/p/vitess/go/sqltypes"
 22)
 23
 24const (
 25	// NOTE(szopa): maxSize used to be 1 << 30, but that causes
 26	// compiler errors in some situations.
 27	maxSize = 1 << 20
 28)
 29
// init initializes the MySQL client library through vt_library_init when the
// package is loaded.
func init() {
	// This needs to be called before threads begin to spawn.
	C.vt_library_init()
}
 34
 35type SqlError struct {
 36	Num     int
 37	Message string
 38	Query   string
 39}
 40
 41func NewSqlError(number int, format string, args ...interface{}) *SqlError {
 42	return &SqlError{Num: number, Message: fmt.Sprintf(format, args...)}
 43}
 44
 45func (se *SqlError) Error() string {
 46	if se.Query == "" {
 47		return fmt.Sprintf("%v (errno %v)", se.Message, se.Num)
 48	}
 49	return fmt.Sprintf("%v (errno %v) during query: %s", se.Message, se.Num, se.Query)
 50}
 51
 52func (se *SqlError) Number() int {
 53	return se.Num
 54}
 55
 56func handleError(err *error) {
 57	if x := recover(); x != nil {
 58		terr := x.(*SqlError)
 59		*err = terr
 60	}
 61}
 62
// ConnectionParams holds the settings Connect uses to open a MySQL
// connection. The JSON tags define its serialized form.
type ConnectionParams struct {
	Host  string `json:"host"`
	Port  int    `json:"port"`
	Uname string `json:"uname"`
	// Pass is the password; Redacted masks it.
	Pass       string `json:"pass"`
	Dbname     string `json:"dbname"`
	UnixSocket string `json:"unix_socket"`
	Charset    string `json:"charset"`
	// Flags are client flags passed to vt_connect (see
	// EnableMultiStatements and SslEnabled).
	Flags uint64 `json:"flags"`

	// The following fields are only used for the 'Change Master' command
	// for now (along with Flags |= 2048 for CLIENT_SSL).
	SslCa     string `json:"ssl_ca"`
	SslCaPath string `json:"ssl_ca_path"`
	SslCert   string `json:"ssl_cert"`
	SslKey    string `json:"ssl_key"`
}
 80
 81func (c *ConnectionParams) EnableMultiStatements() {
 82	c.Flags |= C.CLIENT_MULTI_STATEMENTS
 83}
 84
 85func (c *ConnectionParams) SslEnabled() bool {
 86	return (c.Flags & C.CLIENT_SSL) != 0
 87}
 88
 89func (c ConnectionParams) Redacted() interface{} {
 90	c.Pass = relog.Redact(c.Pass)
 91	return c
 92}
 93
// Connection wraps a single VT_CONN handle from the vtmysql C layer and is
// obtained from Connect. IsClosed treats a nil c.mysql pointer as closed.
type Connection struct {
	c C.VT_CONN
}
 97
 98func Connect(params ConnectionParams) (conn *Connection, err error) {
 99	defer handleError(&err)
100
101	host := C.CString(params.Host)
102	defer cfree(host)
103	port := C.uint(params.Port)
104	uname := C.CString(params.Uname)
105	defer cfree(uname)
106	pass := C.CString(params.Pass)
107	defer cfree(pass)
108	dbname := C.CString(params.Dbname)
109	defer cfree(dbname)
110	unix_socket := C.CString(params.UnixSocket)
111	defer cfree(unix_socket)
112	charset := C.CString(params.Charset)
113	defer cfree(charset)
114	flags := C.ulong(params.Flags)
115
116	conn = &Connection{}
117	if C.vt_connect(&conn.c, host, uname, pass, dbname, port, unix_socket, charset, flags) != 0 {
118		defer conn.Close()
119		return nil, conn.lastError("")
120	}
121	return conn, nil
122}
123
124func (conn *Connection) Close() {
125	C.vt_close(&conn.c)
126}
127
128func (conn *Connection) IsClosed() bool {
129	return conn.c.mysql == nil
130}
131
132func (conn *Connection) ExecuteFetch(query string, maxrows int, wantfields bool) (qr *proto.QueryResult, err error) {
133	if conn.IsClosed() {
134		return nil, NewSqlError(2006, "Connection is closed")
135	}
136
137	if C.vt_execute(&conn.c, (*C.char)(hack.StringPointer(query)), C.ulong(len(query)), 0) != 0 {
138		return nil, conn.lastError(query)
139	}
140	defer conn.CloseResult()
141
142	qr = &proto.QueryResult{}
143	qr.RowsAffected = uint64(conn.c.affected_rows)
144	qr.InsertId = uint64(conn.c.insert_id)
145	if conn.c.num_fields == 0 {
146		return qr, nil
147	}
148
149	if qr.RowsAffected > uint64(maxrows) {
150		return nil, &SqlError{0, fmt.Sprintf("Row count exceeded %d", maxrows), string(query)}
151	}
152	if wantfields {
153		qr.Fields = conn.Fields()
154	}
155	qr.Rows, err = conn.fetchAll()
156	return qr, err
157}
158
159// when using ExecuteStreamFetch, use FetchNext on the Connection until it returns nil or error
160func (conn *Connection) ExecuteStreamFetch(query string) (err error) {
161	if conn.IsClosed() {
162		return NewSqlError(2006, "Connection is closed")
163	}
164	if C.vt_execute(&conn.c, (*C.char)(hack.StringPointer(query)), C.ulong(len(query)), 1) != 0 {
165		return conn.lastError(query)
166	}
167	return nil
168}
169
170func (conn *Connection) Fields() (fields []proto.Field) {
171	nfields := int(conn.c.num_fields)
172	if nfields == 0 {
173		return nil
174	}
175	cfields := (*[maxSize]C.MYSQL_FIELD)(unsafe.Pointer(conn.c.fields))
176	totalLength := uint64(0)
177	for i := 0; i < nfields; i++ {
178		totalLength += uint64(cfields[i].name_length)
179	}
180	fields = make([]proto.Field, nfields)
181	for i := 0; i < nfields; i++ {
182		length := cfields[i].name_length
183		fname := (*[maxSize]byte)(unsafe.Pointer(cfields[i].name))[:length]
184		fields[i].Name = string(fname)
185		fields[i].Type = int64(cfields[i]._type)
186	}
187	return fields
188}
189
190func (conn *Connection) fetchAll() (rows [][]sqltypes.Value, err error) {
191	rowCount := int(conn.c.affected_rows)
192	if rowCount == 0 {
193		return nil, nil
194	}
195	rows = make([][]sqltypes.Value, rowCount)
196	for i := 0; i < rowCount; i++ {
197		rows[i], err = conn.FetchNext()
198		if err != nil {
199			return nil, err
200		}
201	}
202	return rows, nil
203}
204
205func (conn *Connection) FetchNext() (row []sqltypes.Value, err error) {
206	vtrow := C.vt_fetch_next(&conn.c)
207	if vtrow.has_error != 0 {
208		return nil, conn.lastError("")
209	}
210	rowPtr := (*[maxSize]*[maxSize]byte)(unsafe.Pointer(vtrow.mysql_row))
211	if rowPtr == nil {
212		return nil, nil
213	}
214	colCount := int(conn.c.num_fields)
215	cfields := (*[maxSize]C.MYSQL_FIELD)(unsafe.Pointer(conn.c.fields))
216	row = make([]sqltypes.Value, colCount)
217	lengths := (*[maxSize]uint64)(unsafe.Pointer(vtrow.lengths))
218	totalLength := uint64(0)
219	for i := 0; i < colCount; i++ {
220		totalLength += lengths[i]
221	}
222	arena := make([]byte, 0, int(totalLength))
223	for i := 0; i < colCount; i++ {
224		colLength := lengths[i]
225		colPtr := rowPtr[i]
226		if colPtr == nil {
227			continue
228		}
229		start := len(arena)
230		arena = append(arena, colPtr[:colLength]...)
231		row[i] = BuildValue(arena[start:start+int(colLength)], cfields[i]._type)
232	}
233	return row, nil
234}
235
236func (conn *Connection) CloseResult() {
237	C.vt_close_result(&conn.c)
238}
239
240func (conn *Connection) Id() int64 {
241	if conn.c.mysql == nil {
242		return 0
243	}
244	return int64(C.vt_thread_id(&conn.c))
245}
246
247func (conn *Connection) lastError(query string) error {
248	if err := C.vt_error(&conn.c); *err != 0 {
249		return &SqlError{Num: int(C.vt_errno(&conn.c)), Message: C.GoString(err), Query: query}
250	}
251	return &SqlError{0, "Dummy", string(query)}
252}
253
254func BuildValue(bytes []byte, fieldType uint32) sqltypes.Value {
255	switch fieldType {
256	case C.MYSQL_TYPE_DECIMAL, C.MYSQL_TYPE_FLOAT, C.MYSQL_TYPE_DOUBLE, C.MYSQL_TYPE_NEWDECIMAL:
257		return sqltypes.MakeFractional(bytes)
258	case C.MYSQL_TYPE_TIMESTAMP:
259		return sqltypes.MakeString(bytes)
260	}
261	// The below condition represents the following list of values:
262	// C.MYSQL_TYPE_TINY, C.MYSQL_TYPE_SHORT, C.MYSQL_TYPE_LONG, C.MYSQL_TYPE_LONGLONG, C.MYSQL_TYPE_INT24, C.MYSQL_TYPE_YEAR:
263	if fieldType <= C.MYSQL_TYPE_INT24 || fieldType == C.MYSQL_TYPE_YEAR {
264		return sqltypes.MakeNumeric(bytes)
265	}
266	return sqltypes.MakeString(bytes)
267}
268
269func cfree(str *C.char) {
270	if str != nil {
271		C.free(unsafe.Pointer(str))
272	}
273}