/go/mysql/mysql.go

https://code.google.com/p/vitess/ · Go · 273 lines · 222 code · 35 blank · 16 comment · 54 complexity · bcc399cb2490412f5bbc920f8692e9e1 MD5 · raw file

  1. // Copyright 2012, Google Inc. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mysql
  5. /*
  6. #cgo pkg-config: gomysql
  7. #include <stdlib.h>
  8. #include "vtmysql.h"
  9. */
  10. import "C"
  11. import (
  12. "fmt"
  13. "unsafe"
  14. "code.google.com/p/vitess/go/hack"
  15. "code.google.com/p/vitess/go/mysql/proto"
  16. "code.google.com/p/vitess/go/relog"
  17. "code.google.com/p/vitess/go/sqltypes"
  18. )
  19. const (
  20. // NOTE(szopa): maxSize used to be 1 << 30, but that causes
  21. // compiler errors in some situations.
  22. maxSize = 1 << 20
  23. )
  24. func init() {
  25. // This needs to be called before threads begin to spawn.
  26. C.vt_library_init()
  27. }
  28. type SqlError struct {
  29. Num int
  30. Message string
  31. Query string
  32. }
  33. func NewSqlError(number int, format string, args ...interface{}) *SqlError {
  34. return &SqlError{Num: number, Message: fmt.Sprintf(format, args...)}
  35. }
  36. func (se *SqlError) Error() string {
  37. if se.Query == "" {
  38. return fmt.Sprintf("%v (errno %v)", se.Message, se.Num)
  39. }
  40. return fmt.Sprintf("%v (errno %v) during query: %s", se.Message, se.Num, se.Query)
  41. }
  42. func (se *SqlError) Number() int {
  43. return se.Num
  44. }
  45. func handleError(err *error) {
  46. if x := recover(); x != nil {
  47. terr := x.(*SqlError)
  48. *err = terr
  49. }
  50. }
  51. type ConnectionParams struct {
  52. Host string `json:"host"`
  53. Port int `json:"port"`
  54. Uname string `json:"uname"`
  55. Pass string `json:"pass"`
  56. Dbname string `json:"dbname"`
  57. UnixSocket string `json:"unix_socket"`
  58. Charset string `json:"charset"`
  59. Flags uint64 `json:"flags"`
  60. // the following flags are only used for 'Change Master' command
  61. // for now (along with flags |= 2048 for CLIENT_SSL)
  62. SslCa string `json:"ssl_ca"`
  63. SslCaPath string `json:"ssl_ca_path"`
  64. SslCert string `json:"ssl_cert"`
  65. SslKey string `json:"ssl_key"`
  66. }
  67. func (c *ConnectionParams) EnableMultiStatements() {
  68. c.Flags |= C.CLIENT_MULTI_STATEMENTS
  69. }
  70. func (c *ConnectionParams) SslEnabled() bool {
  71. return (c.Flags & C.CLIENT_SSL) != 0
  72. }
  73. func (c ConnectionParams) Redacted() interface{} {
  74. c.Pass = relog.Redact(c.Pass)
  75. return c
  76. }
  77. type Connection struct {
  78. c C.VT_CONN
  79. }
  80. func Connect(params ConnectionParams) (conn *Connection, err error) {
  81. defer handleError(&err)
  82. host := C.CString(params.Host)
  83. defer cfree(host)
  84. port := C.uint(params.Port)
  85. uname := C.CString(params.Uname)
  86. defer cfree(uname)
  87. pass := C.CString(params.Pass)
  88. defer cfree(pass)
  89. dbname := C.CString(params.Dbname)
  90. defer cfree(dbname)
  91. unix_socket := C.CString(params.UnixSocket)
  92. defer cfree(unix_socket)
  93. charset := C.CString(params.Charset)
  94. defer cfree(charset)
  95. flags := C.ulong(params.Flags)
  96. conn = &Connection{}
  97. if C.vt_connect(&conn.c, host, uname, pass, dbname, port, unix_socket, charset, flags) != 0 {
  98. defer conn.Close()
  99. return nil, conn.lastError("")
  100. }
  101. return conn, nil
  102. }
  103. func (conn *Connection) Close() {
  104. C.vt_close(&conn.c)
  105. }
  106. func (conn *Connection) IsClosed() bool {
  107. return conn.c.mysql == nil
  108. }
  109. func (conn *Connection) ExecuteFetch(query string, maxrows int, wantfields bool) (qr *proto.QueryResult, err error) {
  110. if conn.IsClosed() {
  111. return nil, NewSqlError(2006, "Connection is closed")
  112. }
  113. if C.vt_execute(&conn.c, (*C.char)(hack.StringPointer(query)), C.ulong(len(query)), 0) != 0 {
  114. return nil, conn.lastError(query)
  115. }
  116. defer conn.CloseResult()
  117. qr = &proto.QueryResult{}
  118. qr.RowsAffected = uint64(conn.c.affected_rows)
  119. qr.InsertId = uint64(conn.c.insert_id)
  120. if conn.c.num_fields == 0 {
  121. return qr, nil
  122. }
  123. if qr.RowsAffected > uint64(maxrows) {
  124. return nil, &SqlError{0, fmt.Sprintf("Row count exceeded %d", maxrows), string(query)}
  125. }
  126. if wantfields {
  127. qr.Fields = conn.Fields()
  128. }
  129. qr.Rows, err = conn.fetchAll()
  130. return qr, err
  131. }
  132. // when using ExecuteStreamFetch, use FetchNext on the Connection until it returns nil or error
  133. func (conn *Connection) ExecuteStreamFetch(query string) (err error) {
  134. if conn.IsClosed() {
  135. return NewSqlError(2006, "Connection is closed")
  136. }
  137. if C.vt_execute(&conn.c, (*C.char)(hack.StringPointer(query)), C.ulong(len(query)), 1) != 0 {
  138. return conn.lastError(query)
  139. }
  140. return nil
  141. }
  142. func (conn *Connection) Fields() (fields []proto.Field) {
  143. nfields := int(conn.c.num_fields)
  144. if nfields == 0 {
  145. return nil
  146. }
  147. cfields := (*[maxSize]C.MYSQL_FIELD)(unsafe.Pointer(conn.c.fields))
  148. totalLength := uint64(0)
  149. for i := 0; i < nfields; i++ {
  150. totalLength += uint64(cfields[i].name_length)
  151. }
  152. fields = make([]proto.Field, nfields)
  153. for i := 0; i < nfields; i++ {
  154. length := cfields[i].name_length
  155. fname := (*[maxSize]byte)(unsafe.Pointer(cfields[i].name))[:length]
  156. fields[i].Name = string(fname)
  157. fields[i].Type = int64(cfields[i]._type)
  158. }
  159. return fields
  160. }
  161. func (conn *Connection) fetchAll() (rows [][]sqltypes.Value, err error) {
  162. rowCount := int(conn.c.affected_rows)
  163. if rowCount == 0 {
  164. return nil, nil
  165. }
  166. rows = make([][]sqltypes.Value, rowCount)
  167. for i := 0; i < rowCount; i++ {
  168. rows[i], err = conn.FetchNext()
  169. if err != nil {
  170. return nil, err
  171. }
  172. }
  173. return rows, nil
  174. }
  175. func (conn *Connection) FetchNext() (row []sqltypes.Value, err error) {
  176. vtrow := C.vt_fetch_next(&conn.c)
  177. if vtrow.has_error != 0 {
  178. return nil, conn.lastError("")
  179. }
  180. rowPtr := (*[maxSize]*[maxSize]byte)(unsafe.Pointer(vtrow.mysql_row))
  181. if rowPtr == nil {
  182. return nil, nil
  183. }
  184. colCount := int(conn.c.num_fields)
  185. cfields := (*[maxSize]C.MYSQL_FIELD)(unsafe.Pointer(conn.c.fields))
  186. row = make([]sqltypes.Value, colCount)
  187. lengths := (*[maxSize]uint64)(unsafe.Pointer(vtrow.lengths))
  188. totalLength := uint64(0)
  189. for i := 0; i < colCount; i++ {
  190. totalLength += lengths[i]
  191. }
  192. arena := make([]byte, 0, int(totalLength))
  193. for i := 0; i < colCount; i++ {
  194. colLength := lengths[i]
  195. colPtr := rowPtr[i]
  196. if colPtr == nil {
  197. continue
  198. }
  199. start := len(arena)
  200. arena = append(arena, colPtr[:colLength]...)
  201. row[i] = BuildValue(arena[start:start+int(colLength)], cfields[i]._type)
  202. }
  203. return row, nil
  204. }
  205. func (conn *Connection) CloseResult() {
  206. C.vt_close_result(&conn.c)
  207. }
  208. func (conn *Connection) Id() int64 {
  209. if conn.c.mysql == nil {
  210. return 0
  211. }
  212. return int64(C.vt_thread_id(&conn.c))
  213. }
  214. func (conn *Connection) lastError(query string) error {
  215. if err := C.vt_error(&conn.c); *err != 0 {
  216. return &SqlError{Num: int(C.vt_errno(&conn.c)), Message: C.GoString(err), Query: query}
  217. }
  218. return &SqlError{0, "Dummy", string(query)}
  219. }
  220. func BuildValue(bytes []byte, fieldType uint32) sqltypes.Value {
  221. switch fieldType {
  222. case C.MYSQL_TYPE_DECIMAL, C.MYSQL_TYPE_FLOAT, C.MYSQL_TYPE_DOUBLE, C.MYSQL_TYPE_NEWDECIMAL:
  223. return sqltypes.MakeFractional(bytes)
  224. case C.MYSQL_TYPE_TIMESTAMP:
  225. return sqltypes.MakeString(bytes)
  226. }
  227. // The below condition represents the following list of values:
  228. // C.MYSQL_TYPE_TINY, C.MYSQL_TYPE_SHORT, C.MYSQL_TYPE_LONG, C.MYSQL_TYPE_LONGLONG, C.MYSQL_TYPE_INT24, C.MYSQL_TYPE_YEAR:
  229. if fieldType <= C.MYSQL_TYPE_INT24 || fieldType == C.MYSQL_TYPE_YEAR {
  230. return sqltypes.MakeNumeric(bytes)
  231. }
  232. return sqltypes.MakeString(bytes)
  233. }
  234. func cfree(str *C.char) {
  235. if str != nil {
  236. C.free(unsafe.Pointer(str))
  237. }
  238. }