/db/handler/db.go

https://github.com/micro/services · Go · 345 lines · 297 code · 42 blank · 6 comment · 82 complexity · f8dee4934f7ab658083bfa758b3f362a MD5 · raw file

  1. package handler
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "regexp"
  7. "strings"
  8. "time"
  9. "github.com/google/uuid"
  10. "github.com/micro/micro/v3/service/errors"
  11. "github.com/micro/micro/v3/service/logger"
  12. db "github.com/micro/services/db/proto"
  13. gorm2 "github.com/micro/services/pkg/gorm"
  14. "github.com/micro/services/pkg/tenant"
  15. "github.com/patrickmn/go-cache"
  16. "google.golang.org/protobuf/types/known/structpb"
  17. "gorm.io/datatypes"
  18. "gorm.io/gorm"
  19. )
  20. const idKey = "id"
  21. const stmt = "create table if not exists %v(id text not null, data jsonb, primary key(id)); alter table %v add created_at timestamptz; alter table %v add updated_at timestamptz"
  22. const truncateStmt = `truncate table "%v"`
  23. var re = regexp.MustCompile("^[a-zA-Z0-9_]*$")
  24. var c = cache.New(5*time.Minute, 10*time.Minute)
  25. type Record struct {
  26. ID string
  27. Data datatypes.JSON `json:"data"`
  28. // private field, ignored from gorm
  29. table string `gorm:"-"`
  30. CreatedAt time.Time
  31. UpdatedAt time.Time
  32. }
  33. type Db struct {
  34. gorm2.Helper
  35. }
  36. func correctFieldName(s string) string {
  37. switch s {
  38. // top level fields can stay top level
  39. case "id": // "created_at", "updated_at", <-- these are not special fields for now
  40. return s
  41. }
  42. if !strings.Contains(s, ".") {
  43. return fmt.Sprintf("data ->> '%v'", s)
  44. }
  45. paths := strings.Split(s, ".")
  46. ret := "data"
  47. for _, path := range paths {
  48. ret += fmt.Sprintf(" ->> '%v'", path)
  49. }
  50. return ret
  51. }
  52. func (e *Db) tableName(ctx context.Context, t string) (string, error) {
  53. tenantId, ok := tenant.FromContext(ctx)
  54. if !ok {
  55. tenantId = "micro"
  56. }
  57. if t == "" {
  58. t = "default"
  59. }
  60. t = strings.ToLower(t)
  61. t = strings.Replace(t, "-", "_", -1)
  62. tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
  63. tableName := tenantId + "_" + t
  64. if !re.Match([]byte(tableName)) {
  65. return "", fmt.Errorf("table name %v is invalid", t)
  66. }
  67. return tableName, nil
  68. }
  69. // Call is a single request handler called via client.Call or the generated client code
  70. func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateResponse) error {
  71. if len(req.Record.AsMap()) == 0 {
  72. return errors.BadRequest("db.create", "missing record")
  73. }
  74. tableName, err := e.tableName(ctx, req.Table)
  75. if err != nil {
  76. return err
  77. }
  78. logger.Infof("Inserting into table '%v'", tableName)
  79. db, err := e.GetDBConn(ctx)
  80. if err != nil {
  81. return err
  82. }
  83. _, ok := c.Get(tableName)
  84. if !ok {
  85. logger.Infof("Creating table '%v'", tableName)
  86. db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
  87. c.Set(tableName, true, 0)
  88. }
  89. m := req.Record.AsMap()
  90. if _, ok := m[idKey].(string); !ok {
  91. m[idKey] = uuid.New().String()
  92. }
  93. bs, _ := json.Marshal(m)
  94. err = db.Table(tableName).Create(&Record{
  95. ID: m[idKey].(string),
  96. Data: bs,
  97. }).Error
  98. if err != nil {
  99. return err
  100. }
  101. // set the response id
  102. rsp.Id = m[idKey].(string)
  103. return nil
  104. }
  105. func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateResponse) error {
  106. if len(req.Record.AsMap()) == 0 {
  107. return errors.BadRequest("db.update", "missing record")
  108. }
  109. tableName, err := e.tableName(ctx, req.Table)
  110. if err != nil {
  111. return err
  112. }
  113. logger.Infof("Updating table '%v'", tableName)
  114. db, err := e.GetDBConn(ctx)
  115. if err != nil {
  116. return err
  117. }
  118. m := req.Record.AsMap()
  119. // where ID is specified do a single update record update
  120. id := req.Id
  121. if v, ok := m[idKey].(string); ok && id == "" {
  122. id = v
  123. }
  124. // if the id is blank then check the data
  125. if len(req.Id) == 0 {
  126. var ok bool
  127. id, ok = m[idKey].(string)
  128. if !ok {
  129. return fmt.Errorf("update failed: missing id")
  130. }
  131. }
  132. return db.Transaction(func(tx *gorm.DB) error {
  133. rec := []Record{}
  134. err = tx.Table(tableName).Where("id = ?", id).Find(&rec).Error
  135. if err != nil {
  136. return err
  137. }
  138. if len(rec) == 0 {
  139. return fmt.Errorf("update failed: not found")
  140. }
  141. old := map[string]interface{}{}
  142. err = json.Unmarshal(rec[0].Data, &old)
  143. if err != nil {
  144. return err
  145. }
  146. for k, v := range m {
  147. old[k] = v
  148. }
  149. bs, _ := json.Marshal(old)
  150. return tx.Table(tableName).Save(&Record{
  151. ID: id,
  152. Data: bs,
  153. }).Error
  154. })
  155. }
  156. func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse) error {
  157. recs := []Record{}
  158. queries, err := Parse(req.Query)
  159. if err != nil {
  160. return err
  161. }
  162. tableName, err := e.tableName(ctx, req.Table)
  163. if err != nil {
  164. return err
  165. }
  166. db, err := e.GetDBConn(ctx)
  167. if err != nil {
  168. return err
  169. }
  170. _, ok := c.Get(tableName)
  171. if !ok {
  172. logger.Infof("Creating table '%v'", tableName)
  173. db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
  174. c.Set(tableName, true, 0)
  175. }
  176. if req.Limit > 1000 {
  177. return errors.BadRequest("db.read", fmt.Sprintf("limit over 1000 is invalid, you specified %v", req.Limit))
  178. }
  179. if req.Limit == 0 {
  180. req.Limit = 25
  181. }
  182. db = db.Table(tableName)
  183. if req.Id != "" {
  184. logger.Infof("Query by id: %v", req.Id)
  185. db = db.Where("id = ?", req.Id)
  186. } else {
  187. for _, query := range queries {
  188. logger.Infof("Query field: %v, op: %v, type: %v", query.Field, query.Op, query.Value)
  189. typ := "text"
  190. switch query.Value.(type) {
  191. case int64:
  192. typ = "int"
  193. case bool:
  194. typ = "boolean"
  195. }
  196. op := ""
  197. switch query.Op {
  198. case itemEquals:
  199. op = "="
  200. case itemGreaterThan:
  201. op = ">"
  202. case itemGreaterThanEquals:
  203. op = ">="
  204. case itemLessThan:
  205. op = "<"
  206. case itemLessThanEquals:
  207. op = "<="
  208. case itemNotEquals:
  209. op = "!="
  210. }
  211. queryField := correctFieldName(query.Field)
  212. db = db.Where(fmt.Sprintf("(%v)::%v %v ?", queryField, typ, op), query.Value)
  213. }
  214. }
  215. orderField := "created_at"
  216. if req.OrderBy != "" {
  217. orderField = req.OrderBy
  218. }
  219. orderField = correctFieldName(orderField)
  220. ordering := "asc"
  221. if req.Order != "" {
  222. switch strings.ToLower(req.Order) {
  223. case "asc":
  224. ordering = "asc"
  225. case "", "desc":
  226. ordering = "desc"
  227. default:
  228. return errors.BadRequest("db.read", "invalid ordering: "+req.Order)
  229. }
  230. }
  231. db = db.Order(orderField + " " + ordering).Offset(int(req.Offset)).Limit(int(req.Limit))
  232. err = db.Find(&recs).Error
  233. if err != nil {
  234. return err
  235. }
  236. rsp.Records = []*structpb.Struct{}
  237. for _, rec := range recs {
  238. m, err := rec.Data.MarshalJSON()
  239. if err != nil {
  240. return err
  241. }
  242. ma := map[string]interface{}{}
  243. json.Unmarshal(m, &ma)
  244. ma[idKey] = rec.ID
  245. m, _ = json.Marshal(ma)
  246. s := &structpb.Struct{}
  247. err = s.UnmarshalJSON(m)
  248. if err != nil {
  249. return err
  250. }
  251. rsp.Records = append(rsp.Records, s)
  252. }
  253. return nil
  254. }
  255. func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteResponse) error {
  256. if len(req.Id) == 0 {
  257. return errors.BadRequest("db.delete", "missing id")
  258. }
  259. tableName, err := e.tableName(ctx, req.Table)
  260. if err != nil {
  261. return err
  262. }
  263. logger.Infof("Deleting from table '%v'", tableName)
  264. db, err := e.GetDBConn(ctx)
  265. if err != nil {
  266. return err
  267. }
  268. return db.Table(tableName).Delete(Record{
  269. ID: req.Id,
  270. }).Error
  271. }
  272. func (e *Db) Truncate(ctx context.Context, req *db.TruncateRequest, rsp *db.TruncateResponse) error {
  273. tableName, err := e.tableName(ctx, req.Table)
  274. if err != nil {
  275. return err
  276. }
  277. logger.Infof("Truncating table '%v'", tableName)
  278. db, err := e.GetDBConn(ctx)
  279. if err != nil {
  280. return err
  281. }
  282. return db.Exec(fmt.Sprintf(truncateStmt, tableName)).Error
  283. }
  284. func (e *Db) Count(ctx context.Context, req *db.CountRequest, rsp *db.CountResponse) error {
  285. if req.Table == "" {
  286. req.Table = "default"
  287. }
  288. tableName, err := e.tableName(ctx, req.Table)
  289. if err != nil {
  290. return err
  291. }
  292. db, err := e.GetDBConn(ctx)
  293. if err != nil {
  294. return err
  295. }
  296. var a int64
  297. err = db.Table(tableName).Model(Record{}).Count(&a).Error
  298. if err != nil {
  299. return err
  300. }
  301. rsp.Count = int32(a)
  302. return nil
  303. }