PageRenderTime 1550ms CodeModel.GetById 7ms RepoModel.GetById 1ms app.codeStats 0ms

/vendor/github.com/astaxie/beego/orm/db_alias.go

https://gitlab.com/e0/harbor
Go | 297 lines | 222 code | 40 blank | 35 comment | 46 complexity | c3faef5e45588741ae1fcd3c94bd6b82 MD5 | raw file
  1. // Copyright 2014 beego Author. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package orm
  15. import (
  16. "database/sql"
  17. "fmt"
  18. "reflect"
  19. "sync"
  20. "time"
  21. )
  22. // DriverType database driver constant int.
  23. type DriverType int
  24. // Enum the Database driver
  25. const (
  26. _ DriverType = iota // int enum type
  27. DRMySQL // mysql
  28. DRSqlite // sqlite
  29. DROracle // oracle
  30. DRPostgres // pgsql
  31. DRTiDB // TiDB
  32. )
  33. // database driver string.
  34. type driver string
  35. // get type constant int of current driver..
  36. func (d driver) Type() DriverType {
  37. a, _ := dataBaseCache.get(string(d))
  38. return a.Driver
  39. }
  40. // get name of current driver
  41. func (d driver) Name() string {
  42. return string(d)
  43. }
  44. // check driver iis implemented Driver interface or not.
  45. var _ Driver = new(driver)
  46. var (
  47. dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
  48. drivers = map[string]DriverType{
  49. "mysql": DRMySQL,
  50. "postgres": DRPostgres,
  51. "sqlite3": DRSqlite,
  52. "tidb": DRTiDB,
  53. }
  54. dbBasers = map[DriverType]dbBaser{
  55. DRMySQL: newdbBaseMysql(),
  56. DRSqlite: newdbBaseSqlite(),
  57. DROracle: newdbBaseOracle(),
  58. DRPostgres: newdbBasePostgres(),
  59. DRTiDB: newdbBaseTidb(),
  60. }
  61. )
  62. // database alias cacher.
  63. type _dbCache struct {
  64. mux sync.RWMutex
  65. cache map[string]*alias
  66. }
  67. // add database alias with original name.
  68. func (ac *_dbCache) add(name string, al *alias) (added bool) {
  69. ac.mux.Lock()
  70. defer ac.mux.Unlock()
  71. if _, ok := ac.cache[name]; ok == false {
  72. ac.cache[name] = al
  73. added = true
  74. }
  75. return
  76. }
  77. // get database alias if cached.
  78. func (ac *_dbCache) get(name string) (al *alias, ok bool) {
  79. ac.mux.RLock()
  80. defer ac.mux.RUnlock()
  81. al, ok = ac.cache[name]
  82. return
  83. }
  84. // get default alias.
  85. func (ac *_dbCache) getDefault() (al *alias) {
  86. al, _ = ac.get("default")
  87. return
  88. }
// alias describes one registered database connection and the settings the
// ORM uses for it.
type alias struct {
	// Name is the alias name the connection was registered under.
	Name string
	// Driver is the DriverType resolved from DriverName via the drivers map.
	Driver DriverType
	// DriverName is the database/sql driver name.
	DriverName string
	// DataSource is the DSN passed to sql.Open; only RegisterDataBase sets it.
	DataSource string
	// MaxIdleConns is the last value applied through SetMaxIdleConns.
	MaxIdleConns int
	// MaxOpenConns is the last value applied through SetMaxOpenConns.
	MaxOpenConns int
	// DB is the underlying connection pool.
	DB *sql.DB
	// DbBaser is the dialect implementation looked up in dbBasers.
	DbBaser dbBaser
	// TZ is the timezone used for this alias; set by detectTZ or SetDataBaseTZ.
	TZ *time.Location
	// Engine is the default MySQL storage engine found by detectTZ; it is
	// left empty for other drivers.
	Engine string
}
  101. func detectTZ(al *alias) {
  102. // orm timezone system match database
  103. // default use Local
  104. al.TZ = time.Local
  105. if al.DriverName == "sphinx" {
  106. return
  107. }
  108. switch al.Driver {
  109. case DRMySQL:
  110. row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
  111. var tz string
  112. row.Scan(&tz)
  113. if len(tz) >= 8 {
  114. if tz[0] != '-' {
  115. tz = "+" + tz
  116. }
  117. t, err := time.Parse("-07:00:00", tz)
  118. if err == nil {
  119. al.TZ = t.Location()
  120. } else {
  121. DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
  122. }
  123. }
  124. // get default engine from current database
  125. row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
  126. var engine string
  127. var tx bool
  128. row.Scan(&engine, &tx)
  129. if engine != "" {
  130. al.Engine = engine
  131. } else {
  132. al.Engine = "INNODB"
  133. }
  134. case DRSqlite:
  135. al.TZ = time.UTC
  136. case DRPostgres:
  137. row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
  138. var tz string
  139. row.Scan(&tz)
  140. loc, err := time.LoadLocation(tz)
  141. if err == nil {
  142. al.TZ = loc
  143. } else {
  144. DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
  145. }
  146. }
  147. }
  148. func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
  149. al := new(alias)
  150. al.Name = aliasName
  151. al.DriverName = driverName
  152. al.DB = db
  153. if dr, ok := drivers[driverName]; ok {
  154. al.DbBaser = dbBasers[dr]
  155. al.Driver = dr
  156. } else {
  157. return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
  158. }
  159. err := db.Ping()
  160. if err != nil {
  161. return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
  162. }
  163. if dataBaseCache.add(aliasName, al) == false {
  164. return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
  165. }
  166. return al, nil
  167. }
  168. // AddAliasWthDB add a aliasName for the drivename
  169. func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
  170. _, err := addAliasWthDB(aliasName, driverName, db)
  171. return err
  172. }
  173. // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
  174. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
  175. var (
  176. err error
  177. db *sql.DB
  178. al *alias
  179. )
  180. db, err = sql.Open(driverName, dataSource)
  181. if err != nil {
  182. err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
  183. goto end
  184. }
  185. al, err = addAliasWthDB(aliasName, driverName, db)
  186. if err != nil {
  187. goto end
  188. }
  189. al.DataSource = dataSource
  190. detectTZ(al)
  191. for i, v := range params {
  192. switch i {
  193. case 0:
  194. SetMaxIdleConns(al.Name, v)
  195. case 1:
  196. SetMaxOpenConns(al.Name, v)
  197. }
  198. }
  199. end:
  200. if err != nil {
  201. if db != nil {
  202. db.Close()
  203. }
  204. DebugLog.Println(err.Error())
  205. }
  206. return err
  207. }
  208. // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
  209. func RegisterDriver(driverName string, typ DriverType) error {
  210. if t, ok := drivers[driverName]; ok == false {
  211. drivers[driverName] = typ
  212. } else {
  213. if t != typ {
  214. return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
  215. }
  216. }
  217. return nil
  218. }
  219. // SetDataBaseTZ Change the database default used timezone
  220. func SetDataBaseTZ(aliasName string, tz *time.Location) error {
  221. if al, ok := dataBaseCache.get(aliasName); ok {
  222. al.TZ = tz
  223. } else {
  224. return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
  225. }
  226. return nil
  227. }
  228. // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
  229. func SetMaxIdleConns(aliasName string, maxIdleConns int) {
  230. al := getDbAlias(aliasName)
  231. al.MaxIdleConns = maxIdleConns
  232. al.DB.SetMaxIdleConns(maxIdleConns)
  233. }
  234. // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
  235. func SetMaxOpenConns(aliasName string, maxOpenConns int) {
  236. al := getDbAlias(aliasName)
  237. al.MaxOpenConns = maxOpenConns
  238. // for tip go 1.2
  239. if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
  240. fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
  241. }
  242. }
  243. // GetDB Get *sql.DB from registered database by db alias name.
  244. // Use "default" as alias name if you not set.
  245. func GetDB(aliasNames ...string) (*sql.DB, error) {
  246. var name string
  247. if len(aliasNames) > 0 {
  248. name = aliasNames[0]
  249. } else {
  250. name = "default"
  251. }
  252. al, ok := dataBaseCache.get(name)
  253. if ok {
  254. return al.DB, nil
  255. }
  256. return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
  257. }