PageRenderTime 1307ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/orm/db_alias.go

https://gitlab.com/kumarsiva07/beego
Go | 298 lines | 223 code | 40 blank | 35 comment | 46 complexity | 384a52d7ae4d0abf94404597c2a4da23 MD5 | raw file
  1. // Copyright 2014 beego Author. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package orm
  15. import (
  16. "database/sql"
  17. "fmt"
  18. "reflect"
  19. "sync"
  20. "time"
  21. )
  // DriverType identifies which database family a registered driver talks to.
  type DriverType int

  // Supported database families. Numbering starts at 1, so the zero value of
  // DriverType does not correspond to any of them.
  const (
  	DRMySQL    DriverType = iota + 1 // mysql
  	DRSqlite                         // sqlite
  	DROracle                         // oracle
  	DRPostgres                       // pgsql
  	DRTiDB                           // TiDB
  )
  33. // database driver string.
  34. type driver string
  35. // get type constant int of current driver..
  36. func (d driver) Type() DriverType {
  37. a, _ := dataBaseCache.get(string(d))
  38. return a.Driver
  39. }
  40. // get name of current driver
  41. func (d driver) Name() string {
  42. return string(d)
  43. }
  44. // check driver iis implemented Driver interface or not.
  45. var _ Driver = new(driver)
  46. var (
  47. dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
  48. drivers = map[string]DriverType{
  49. "mysql": DRMySQL,
  50. "postgres": DRPostgres,
  51. "sqlite3": DRSqlite,
  52. "tidb": DRTiDB,
  53. "oracle": DROracle,
  54. }
  55. dbBasers = map[DriverType]dbBaser{
  56. DRMySQL: newdbBaseMysql(),
  57. DRSqlite: newdbBaseSqlite(),
  58. DROracle: newdbBaseOracle(),
  59. DRPostgres: newdbBasePostgres(),
  60. DRTiDB: newdbBaseTidb(),
  61. }
  62. )
  63. // database alias cacher.
  64. type _dbCache struct {
  65. mux sync.RWMutex
  66. cache map[string]*alias
  67. }
  68. // add database alias with original name.
  69. func (ac *_dbCache) add(name string, al *alias) (added bool) {
  70. ac.mux.Lock()
  71. defer ac.mux.Unlock()
  72. if _, ok := ac.cache[name]; ok == false {
  73. ac.cache[name] = al
  74. added = true
  75. }
  76. return
  77. }
  78. // get database alias if cached.
  79. func (ac *_dbCache) get(name string) (al *alias, ok bool) {
  80. ac.mux.RLock()
  81. defer ac.mux.RUnlock()
  82. al, ok = ac.cache[name]
  83. return
  84. }
  85. // get default alias.
  86. func (ac *_dbCache) getDefault() (al *alias) {
  87. al, _ = ac.get("default")
  88. return
  89. }
  90. type alias struct {
  91. Name string
  92. Driver DriverType
  93. DriverName string
  94. DataSource string
  95. MaxIdleConns int
  96. MaxOpenConns int
  97. DB *sql.DB
  98. DbBaser dbBaser
  99. TZ *time.Location
  100. Engine string
  101. }
  102. func detectTZ(al *alias) {
  103. // orm timezone system match database
  104. // default use Local
  105. al.TZ = time.Local
  106. if al.DriverName == "sphinx" {
  107. return
  108. }
  109. switch al.Driver {
  110. case DRMySQL:
  111. row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
  112. var tz string
  113. row.Scan(&tz)
  114. if len(tz) >= 8 {
  115. if tz[0] != '-' {
  116. tz = "+" + tz
  117. }
  118. t, err := time.Parse("-07:00:00", tz)
  119. if err == nil {
  120. al.TZ = t.Location()
  121. } else {
  122. DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
  123. }
  124. }
  125. // get default engine from current database
  126. row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
  127. var engine string
  128. var tx bool
  129. row.Scan(&engine, &tx)
  130. if engine != "" {
  131. al.Engine = engine
  132. } else {
  133. al.Engine = "INNODB"
  134. }
  135. case DRSqlite, DROracle:
  136. al.TZ = time.UTC
  137. case DRPostgres:
  138. row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
  139. var tz string
  140. row.Scan(&tz)
  141. loc, err := time.LoadLocation(tz)
  142. if err == nil {
  143. al.TZ = loc
  144. } else {
  145. DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
  146. }
  147. }
  148. }
  149. func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
  150. al := new(alias)
  151. al.Name = aliasName
  152. al.DriverName = driverName
  153. al.DB = db
  154. if dr, ok := drivers[driverName]; ok {
  155. al.DbBaser = dbBasers[dr]
  156. al.Driver = dr
  157. } else {
  158. return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
  159. }
  160. err := db.Ping()
  161. if err != nil {
  162. return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
  163. }
  164. if dataBaseCache.add(aliasName, al) == false {
  165. return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
  166. }
  167. return al, nil
  168. }
  169. // AddAliasWthDB add a aliasName for the drivename
  170. func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
  171. _, err := addAliasWthDB(aliasName, driverName, db)
  172. return err
  173. }
  174. // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
  175. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
  176. var (
  177. err error
  178. db *sql.DB
  179. al *alias
  180. )
  181. db, err = sql.Open(driverName, dataSource)
  182. if err != nil {
  183. err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
  184. goto end
  185. }
  186. al, err = addAliasWthDB(aliasName, driverName, db)
  187. if err != nil {
  188. goto end
  189. }
  190. al.DataSource = dataSource
  191. detectTZ(al)
  192. for i, v := range params {
  193. switch i {
  194. case 0:
  195. SetMaxIdleConns(al.Name, v)
  196. case 1:
  197. SetMaxOpenConns(al.Name, v)
  198. }
  199. }
  200. end:
  201. if err != nil {
  202. if db != nil {
  203. db.Close()
  204. }
  205. DebugLog.Println(err.Error())
  206. }
  207. return err
  208. }
  209. // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
  210. func RegisterDriver(driverName string, typ DriverType) error {
  211. if t, ok := drivers[driverName]; ok == false {
  212. drivers[driverName] = typ
  213. } else {
  214. if t != typ {
  215. return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
  216. }
  217. }
  218. return nil
  219. }
  220. // SetDataBaseTZ Change the database default used timezone
  221. func SetDataBaseTZ(aliasName string, tz *time.Location) error {
  222. if al, ok := dataBaseCache.get(aliasName); ok {
  223. al.TZ = tz
  224. } else {
  225. return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
  226. }
  227. return nil
  228. }
  229. // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
  230. func SetMaxIdleConns(aliasName string, maxIdleConns int) {
  231. al := getDbAlias(aliasName)
  232. al.MaxIdleConns = maxIdleConns
  233. al.DB.SetMaxIdleConns(maxIdleConns)
  234. }
  235. // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
  236. func SetMaxOpenConns(aliasName string, maxOpenConns int) {
  237. al := getDbAlias(aliasName)
  238. al.MaxOpenConns = maxOpenConns
  239. // for tip go 1.2
  240. if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
  241. fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
  242. }
  243. }
  244. // GetDB Get *sql.DB from registered database by db alias name.
  245. // Use "default" as alias name if you not set.
  246. func GetDB(aliasNames ...string) (*sql.DB, error) {
  247. var name string
  248. if len(aliasNames) > 0 {
  249. name = aliasNames[0]
  250. } else {
  251. name = "default"
  252. }
  253. al, ok := dataBaseCache.get(name)
  254. if ok {
  255. return al.DB, nil
  256. }
  257. return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
  258. }