/orm/db_alias.go

https://gitlab.com/chenggangschool/beego · Go · 278 lines · 219 code · 39 blank · 20 comment · 47 complexity · 2299df4c0ca13efbd5108f7aa46b92d6 MD5 · raw file

  1. package orm
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "reflect"
  6. "sync"
  7. "time"
  8. )
  // DriverType identifies which database dialect a registered driver speaks.
type DriverType int

// Supported driver types. Numbering starts at 1, so the zero value of
// DriverType is left unnamed and does not equal any of these constants.
const (
	DR_MySQL    DriverType = iota + 1 // mysql
	DR_Sqlite                         // sqlite
	DR_Oracle                         // oracle
	DR_Postgres                       // pgsql
)
  18. // database driver string.
  19. type driver string
  20. // get type constant int of current driver..
  21. func (d driver) Type() DriverType {
  22. a, _ := dataBaseCache.get(string(d))
  23. return a.Driver
  24. }
  25. // get name of current driver
  26. func (d driver) Name() string {
  27. return string(d)
  28. }
  29. // check driver iis implemented Driver interface or not.
  30. var _ Driver = new(driver)
  31. var (
  32. dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
  33. drivers = map[string]DriverType{
  34. "mysql": DR_MySQL,
  35. "postgres": DR_Postgres,
  36. "sqlite3": DR_Sqlite,
  37. }
  38. dbBasers = map[DriverType]dbBaser{
  39. DR_MySQL: newdbBaseMysql(),
  40. DR_Sqlite: newdbBaseSqlite(),
  41. DR_Oracle: newdbBaseMysql(),
  42. DR_Postgres: newdbBasePostgres(),
  43. }
  44. )
  45. // database alias cacher.
  46. type _dbCache struct {
  47. mux sync.RWMutex
  48. cache map[string]*alias
  49. }
  50. // add database alias with original name.
  51. func (ac *_dbCache) add(name string, al *alias) (added bool) {
  52. ac.mux.Lock()
  53. defer ac.mux.Unlock()
  54. if _, ok := ac.cache[name]; ok == false {
  55. ac.cache[name] = al
  56. added = true
  57. }
  58. return
  59. }
  60. // get database alias if cached.
  61. func (ac *_dbCache) get(name string) (al *alias, ok bool) {
  62. ac.mux.RLock()
  63. defer ac.mux.RUnlock()
  64. al, ok = ac.cache[name]
  65. return
  66. }
  67. // get default alias.
  68. func (ac *_dbCache) getDefault() (al *alias) {
  69. al, _ = ac.get("default")
  70. return
  71. }
  72. type alias struct {
  73. Name string
  74. Driver DriverType
  75. DriverName string
  76. DataSource string
  77. MaxIdleConns int
  78. MaxOpenConns int
  79. DB *sql.DB
  80. DbBaser dbBaser
  81. TZ *time.Location
  82. Engine string
  83. }
  84. func detectTZ(al *alias) {
  85. // orm timezone system match database
  86. // default use Local
  87. al.TZ = time.Local
  88. if al.DriverName == "sphinx" {
  89. return
  90. }
  91. switch al.Driver {
  92. case DR_MySQL:
  93. row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
  94. var tz string
  95. row.Scan(&tz)
  96. if len(tz) >= 8 {
  97. if tz[0] != '-' {
  98. tz = "+" + tz
  99. }
  100. t, err := time.Parse("-07:00:00", tz)
  101. if err == nil {
  102. al.TZ = t.Location()
  103. } else {
  104. DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
  105. }
  106. }
  107. // get default engine from current database
  108. row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
  109. var engine string
  110. var tx bool
  111. row.Scan(&engine, &tx)
  112. if engine != "" {
  113. al.Engine = engine
  114. } else {
  115. engine = "INNODB"
  116. }
  117. case DR_Sqlite:
  118. al.TZ = time.UTC
  119. case DR_Postgres:
  120. row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
  121. var tz string
  122. row.Scan(&tz)
  123. loc, err := time.LoadLocation(tz)
  124. if err == nil {
  125. al.TZ = loc
  126. } else {
  127. DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
  128. }
  129. }
  130. }
  131. func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
  132. al := new(alias)
  133. al.Name = aliasName
  134. al.DriverName = driverName
  135. al.DB = db
  136. if dr, ok := drivers[driverName]; ok {
  137. al.DbBaser = dbBasers[dr]
  138. al.Driver = dr
  139. } else {
  140. return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
  141. }
  142. err := db.Ping()
  143. if err != nil {
  144. return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
  145. }
  146. if dataBaseCache.add(aliasName, al) == false {
  147. return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
  148. }
  149. return al, nil
  150. }
  151. func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
  152. _, err := addAliasWthDB(aliasName, driverName, db)
  153. return err
  154. }
  155. // Setting the database connect params. Use the database driver self dataSource args.
  156. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
  157. var (
  158. err error
  159. db *sql.DB
  160. al *alias
  161. )
  162. db, err = sql.Open(driverName, dataSource)
  163. if err != nil {
  164. err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
  165. goto end
  166. }
  167. al, err = addAliasWthDB(aliasName, driverName, db)
  168. if err != nil {
  169. goto end
  170. }
  171. al.DataSource = dataSource
  172. detectTZ(al)
  173. for i, v := range params {
  174. switch i {
  175. case 0:
  176. SetMaxIdleConns(al.Name, v)
  177. case 1:
  178. SetMaxOpenConns(al.Name, v)
  179. }
  180. }
  181. end:
  182. if err != nil {
  183. if db != nil {
  184. db.Close()
  185. }
  186. DebugLog.Println(err.Error())
  187. }
  188. return err
  189. }
  190. // Register a database driver use specify driver name, this can be definition the driver is which database type.
  191. func RegisterDriver(driverName string, typ DriverType) error {
  192. if t, ok := drivers[driverName]; ok == false {
  193. drivers[driverName] = typ
  194. } else {
  195. if t != typ {
  196. return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
  197. }
  198. }
  199. return nil
  200. }
  201. // Change the database default used timezone
  202. func SetDataBaseTZ(aliasName string, tz *time.Location) error {
  203. if al, ok := dataBaseCache.get(aliasName); ok {
  204. al.TZ = tz
  205. } else {
  206. return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
  207. }
  208. return nil
  209. }
  210. // Change the max idle conns for *sql.DB, use specify database alias name
  211. func SetMaxIdleConns(aliasName string, maxIdleConns int) {
  212. al := getDbAlias(aliasName)
  213. al.MaxIdleConns = maxIdleConns
  214. al.DB.SetMaxIdleConns(maxIdleConns)
  215. }
  216. // Change the max open conns for *sql.DB, use specify database alias name
  217. func SetMaxOpenConns(aliasName string, maxOpenConns int) {
  218. al := getDbAlias(aliasName)
  219. al.MaxOpenConns = maxOpenConns
  220. // for tip go 1.2
  221. if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
  222. fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
  223. }
  224. }
  225. // Get *sql.DB from registered database by db alias name.
  226. // Use "default" as alias name if you not set.
  227. func GetDB(aliasNames ...string) (*sql.DB, error) {
  228. var name string
  229. if len(aliasNames) > 0 {
  230. name = aliasNames[0]
  231. } else {
  232. name = "default"
  233. }
  234. if al, ok := dataBaseCache.get(name); ok {
  235. return al.DB, nil
  236. } else {
  237. return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
  238. }
  239. }