/orm/db_alias.go

https://github.com/assad2008/beego · Go · 220 lines · 180 code · 31 blank · 9 comment · 34 complexity · e21edaa4dc361fdb53bb17260079c6ff MD5 · raw file

  1. package orm
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "os"
  6. "reflect"
  7. "sync"
  8. "time"
  9. )
  10. type DriverType int
  11. const (
  12. _ DriverType = iota
  13. DR_MySQL
  14. DR_Sqlite
  15. DR_Oracle
  16. DR_Postgres
  17. )
  18. type driver string
  19. func (d driver) Type() DriverType {
  20. a, _ := dataBaseCache.get(string(d))
  21. return a.Driver
  22. }
  23. func (d driver) Name() string {
  24. return string(d)
  25. }
  26. var _ Driver = new(driver)
  27. var (
  28. dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
  29. drivers = map[string]DriverType{
  30. "mysql": DR_MySQL,
  31. "postgres": DR_Postgres,
  32. "sqlite3": DR_Sqlite,
  33. }
  34. dbBasers = map[DriverType]dbBaser{
  35. DR_MySQL: newdbBaseMysql(),
  36. DR_Sqlite: newdbBaseSqlite(),
  37. DR_Oracle: newdbBaseMysql(),
  38. DR_Postgres: newdbBasePostgres(),
  39. }
  40. )
  41. type _dbCache struct {
  42. mux sync.RWMutex
  43. cache map[string]*alias
  44. }
  45. func (ac *_dbCache) add(name string, al *alias) (added bool) {
  46. ac.mux.Lock()
  47. defer ac.mux.Unlock()
  48. if _, ok := ac.cache[name]; ok == false {
  49. ac.cache[name] = al
  50. added = true
  51. }
  52. return
  53. }
  54. func (ac *_dbCache) get(name string) (al *alias, ok bool) {
  55. ac.mux.RLock()
  56. defer ac.mux.RUnlock()
  57. al, ok = ac.cache[name]
  58. return
  59. }
  60. func (ac *_dbCache) getDefault() (al *alias) {
  61. al, _ = ac.get("default")
  62. return
  63. }
  64. type alias struct {
  65. Name string
  66. Driver DriverType
  67. DriverName string
  68. DataSource string
  69. MaxIdleConns int
  70. MaxOpenConns int
  71. DB *sql.DB
  72. DbBaser dbBaser
  73. TZ *time.Location
  74. Engine string
  75. }
  76. // Setting the database connect params. Use the database driver self dataSource args.
  77. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
  78. al := new(alias)
  79. al.Name = aliasName
  80. al.DriverName = driverName
  81. al.DataSource = dataSource
  82. var (
  83. err error
  84. )
  85. if dr, ok := drivers[driverName]; ok {
  86. al.DbBaser = dbBasers[dr]
  87. al.Driver = dr
  88. } else {
  89. err = fmt.Errorf("driver name `%s` have not registered", driverName)
  90. goto end
  91. }
  92. if dataBaseCache.add(aliasName, al) == false {
  93. err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
  94. goto end
  95. }
  96. al.DB, err = sql.Open(driverName, dataSource)
  97. if err != nil {
  98. err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
  99. goto end
  100. }
  101. // orm timezone system match database
  102. // default use Local
  103. al.TZ = time.Local
  104. switch al.Driver {
  105. case DR_MySQL:
  106. row := al.DB.QueryRow("SELECT @@session.time_zone")
  107. var tz string
  108. row.Scan(&tz)
  109. if tz != "SYSTEM" {
  110. t, err := time.Parse("-07:00", tz)
  111. if err == nil {
  112. al.TZ = t.Location()
  113. }
  114. }
  115. // get default engine from current database
  116. row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
  117. var engine string
  118. var tx bool
  119. row.Scan(&engine, &tx)
  120. if engine != "" {
  121. al.Engine = engine
  122. } else {
  123. engine = "INNODB"
  124. }
  125. case DR_Sqlite:
  126. al.TZ = time.UTC
  127. case DR_Postgres:
  128. row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
  129. var tz string
  130. row.Scan(&tz)
  131. loc, err := time.LoadLocation(tz)
  132. if err == nil {
  133. al.TZ = loc
  134. }
  135. }
  136. for i, v := range params {
  137. switch i {
  138. case 0:
  139. SetMaxIdleConns(al.Name, v)
  140. case 1:
  141. SetMaxOpenConns(al.Name, v)
  142. }
  143. }
  144. err = al.DB.Ping()
  145. if err != nil {
  146. err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
  147. goto end
  148. }
  149. end:
  150. if err != nil {
  151. fmt.Println(err.Error())
  152. os.Exit(2)
  153. }
  154. }
  155. // Register a database driver use specify driver name, this can be definition the driver is which database type.
  156. func RegisterDriver(driverName string, typ DriverType) {
  157. if t, ok := drivers[driverName]; ok == false {
  158. drivers[driverName] = typ
  159. } else {
  160. if t != typ {
  161. fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName)
  162. os.Exit(2)
  163. }
  164. }
  165. }
  166. // Change the database default used timezone
  167. func SetDataBaseTZ(aliasName string, tz *time.Location) {
  168. if al, ok := dataBaseCache.get(aliasName); ok {
  169. al.TZ = tz
  170. } else {
  171. fmt.Sprintf("DataBase name `%s` not registered\n", aliasName)
  172. os.Exit(2)
  173. }
  174. }
  175. // Change the max idle conns for *sql.DB, use specify database alias name
  176. func SetMaxIdleConns(aliasName string, maxIdleConns int) {
  177. al := getDbAlias(aliasName)
  178. al.MaxIdleConns = maxIdleConns
  179. al.DB.SetMaxIdleConns(maxIdleConns)
  180. }
  181. // Change the max open conns for *sql.DB, use specify database alias name
  182. func SetMaxOpenConns(aliasName string, maxOpenConns int) {
  183. al := getDbAlias(aliasName)
  184. al.MaxOpenConns = maxOpenConns
  185. // for tip go 1.2
  186. if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
  187. fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
  188. }
  189. }