/internal/statements/query.go

https://github.com/xormplus/xorm · Go · 534 lines · 462 code · 61 blank · 11 comment · 241 complexity · a22ef6ea5a25c118b3cbe1c4d579a4e6 MD5 · raw file

  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package statements
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "regexp"
  10. "strings"
  11. "github.com/xormplus/builder"
  12. "github.com/xormplus/xorm/core"
  13. "github.com/xormplus/xorm/dialects"
  14. "github.com/xormplus/xorm/internal/utils"
  15. "github.com/xormplus/xorm/schemas"
  16. )
  17. func (statement *Statement) genSelectSql(dialect dialects.Dialect, rownumber string) string {
  18. var sql = statement.RawSQL
  19. var orderBys = statement.OrderStr
  20. pLimitN := statement.LimitN
  21. if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
  22. if statement.Start > 0 {
  23. sql = fmt.Sprintf("%v LIMIT %v OFFSET %v", sql, statement.LimitN, statement.Start)
  24. if pLimitN != nil {
  25. sql = fmt.Sprintf("%v LIMIT %v OFFSET %v", sql, *pLimitN, statement.Start)
  26. } else {
  27. sql = fmt.Sprintf("%v LIMIT 0 OFFSET %v", sql, *pLimitN)
  28. }
  29. } else if pLimitN != nil {
  30. sql = fmt.Sprintf("%v LIMIT %v", sql, statement.LimitN)
  31. }
  32. } else if dialect.URI().DBType == schemas.ORACLE {
  33. if statement.Start != 0 || pLimitN != nil {
  34. sql = fmt.Sprintf("SELECT aat.* FROM (SELECT at.*,ROWNUM %v FROM (%v) at WHERE ROWNUM <= %d) aat WHERE %v > %d",
  35. rownumber, sql, statement.Start+*pLimitN, rownumber, statement.Start)
  36. }
  37. } else {
  38. keepSelect := false
  39. var fullQuery string
  40. if statement.Start > 0 {
  41. fullQuery = fmt.Sprintf("SELECT sq.* FROM (SELECT ROW_NUMBER() OVER (ORDER BY %v) AS %v,", orderBys, rownumber)
  42. } else if pLimitN != nil {
  43. fullQuery = fmt.Sprintf("SELECT TOP %d", *pLimitN)
  44. } else {
  45. keepSelect = true
  46. }
  47. if !keepSelect {
  48. expr := `^\s*SELECT\s*`
  49. reg, err := regexp.Compile(expr)
  50. if err != nil {
  51. fmt.Println(err)
  52. }
  53. sql = strings.ToUpper(sql)
  54. if reg.MatchString(sql) {
  55. str := reg.FindAllString(sql, -1)
  56. fullQuery = fmt.Sprintf("%v %v", fullQuery, sql[len(str[0]):])
  57. }
  58. }
  59. if statement.Start > 0 {
  60. // T-SQL offset starts with 1, not like MySQL with 0;
  61. if pLimitN != nil {
  62. fullQuery = fmt.Sprintf("%v) AS sq WHERE %v BETWEEN %d AND %d", fullQuery, rownumber, statement.Start+1, statement.Start+*pLimitN)
  63. } else {
  64. fullQuery = fmt.Sprintf("%v) AS sq WHERE %v >= %d", fullQuery, rownumber, statement.Start+1)
  65. }
  66. } else {
  67. fullQuery = fmt.Sprintf("%v ORDER BY %v", fullQuery, orderBys)
  68. }
  69. if keepSelect {
  70. if len(orderBys) > 0 {
  71. sql = fmt.Sprintf("%v ORDER BY %v", sql, orderBys)
  72. }
  73. } else {
  74. sql = fullQuery
  75. }
  76. }
  77. return sql
  78. }
  79. func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
  80. if len(sqlOrArgs) > 0 {
  81. return statement.ConvertSQLOrArgs(sqlOrArgs...)
  82. }
  83. if statement.RawSQL != "" {
  84. var dialect = statement.dialect
  85. rownumber := "xorm" + utils.NewShortUUID().String()
  86. sql := statement.genSelectSql(dialect, rownumber)
  87. params := statement.RawParams
  88. i := len(params)
  89. // var result []map[string]interface{}
  90. // var err error
  91. if i == 1 {
  92. vv := reflect.ValueOf(params[0])
  93. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  94. return sql, params, nil
  95. } else {
  96. sqlStr1, param, _ := core.MapToSlice(sql, params[0])
  97. return sqlStr1, param, nil
  98. }
  99. } else {
  100. return sql, params, nil
  101. }
  102. // return session.statement.RawSQL, session.statement.RawParams, nil
  103. }
  104. if len(statement.TableName()) <= 0 {
  105. return "", nil, ErrTableNotFound
  106. }
  107. var columnStr = statement.ColumnStr()
  108. if len(statement.SelectStr) > 0 {
  109. columnStr = statement.SelectStr
  110. } else {
  111. if statement.JoinStr == "" {
  112. if columnStr == "" {
  113. if statement.GroupByStr != "" {
  114. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  115. } else {
  116. columnStr = statement.genColumnStr()
  117. }
  118. }
  119. } else {
  120. if columnStr == "" {
  121. if statement.GroupByStr != "" {
  122. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  123. } else {
  124. columnStr = "*"
  125. }
  126. }
  127. }
  128. if columnStr == "" {
  129. columnStr = "*"
  130. }
  131. }
  132. if err := statement.ProcessIDParam(); err != nil {
  133. return "", nil, err
  134. }
  135. sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
  136. if err != nil {
  137. return "", nil, err
  138. }
  139. args := append(statement.joinArgs, condArgs...)
  140. // for mssql and use limit
  141. qs := strings.Count(sqlStr, "?")
  142. if len(args)*2 == qs {
  143. args = append(args, args...)
  144. }
  145. return sqlStr, args, nil
  146. }
  147. func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  148. if statement.RawSQL != "" {
  149. return statement.GenRawSQL(), statement.RawParams, nil
  150. }
  151. statement.SetRefBean(bean)
  152. var sumStrs = make([]string, 0, len(columns))
  153. for _, colName := range columns {
  154. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  155. colName = statement.quote(colName)
  156. } else {
  157. colName = statement.ReplaceQuote(colName)
  158. }
  159. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  160. }
  161. sumSelect := strings.Join(sumStrs, ", ")
  162. if err := statement.mergeConds(bean); err != nil {
  163. return "", nil, err
  164. }
  165. sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true)
  166. if err != nil {
  167. return "", nil, err
  168. }
  169. return sqlStr, append(statement.joinArgs, condArgs...), nil
  170. }
  171. func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) {
  172. v := rValue(bean)
  173. isStruct := v.Kind() == reflect.Struct
  174. if isStruct {
  175. statement.SetRefBean(bean)
  176. }
  177. var columnStr = statement.ColumnStr()
  178. if len(statement.SelectStr) > 0 {
  179. columnStr = statement.SelectStr
  180. } else {
  181. // TODO: always generate column names, not use * even if join
  182. if len(statement.JoinStr) == 0 {
  183. if len(columnStr) == 0 {
  184. if len(statement.GroupByStr) > 0 {
  185. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  186. } else {
  187. columnStr = statement.genColumnStr()
  188. }
  189. }
  190. } else {
  191. if len(columnStr) == 0 {
  192. if len(statement.GroupByStr) > 0 {
  193. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  194. }
  195. }
  196. }
  197. }
  198. if len(columnStr) == 0 {
  199. columnStr = "*"
  200. }
  201. if isStruct {
  202. if err := statement.mergeConds(bean); err != nil {
  203. return "", nil, err
  204. }
  205. } else {
  206. if err := statement.ProcessIDParam(); err != nil {
  207. return "", nil, err
  208. }
  209. }
  210. sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
  211. if err != nil {
  212. return "", nil, err
  213. }
  214. return sqlStr, append(statement.joinArgs, condArgs...), nil
  215. }
  216. // GenCountSQL generates the SQL for counting
  217. func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
  218. if statement.RawSQL != "" {
  219. return statement.GenRawSQL(), statement.RawParams, nil
  220. }
  221. var condArgs []interface{}
  222. var err error
  223. if len(beans) > 0 {
  224. statement.SetRefBean(beans[0])
  225. if err := statement.mergeConds(beans[0]); err != nil {
  226. return "", nil, err
  227. }
  228. }
  229. var selectSQL = statement.SelectStr
  230. if len(selectSQL) <= 0 {
  231. if statement.IsDistinct {
  232. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
  233. } else if statement.ColumnStr() != "" {
  234. selectSQL = fmt.Sprintf("count(%s)", statement.ColumnStr())
  235. } else {
  236. selectSQL = "count(*)"
  237. }
  238. }
  239. sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false)
  240. if err != nil {
  241. return "", nil, err
  242. }
  243. return sqlStr, append(statement.joinArgs, condArgs...), nil
  244. }
  245. func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) {
  246. var (
  247. distinct string
  248. dialect = statement.dialect
  249. quote = statement.quote
  250. fromStr = " FROM "
  251. top, mssqlCondi, whereStr string
  252. )
  253. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  254. distinct = "DISTINCT "
  255. }
  256. condSQL, condArgs, err := statement.GenCondSQL(statement.cond)
  257. if err != nil {
  258. return "", nil, err
  259. }
  260. if len(condSQL) > 0 {
  261. whereStr = " WHERE " + condSQL
  262. }
  263. if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
  264. fromStr += statement.TableName()
  265. } else {
  266. fromStr += quote(statement.TableName())
  267. }
  268. if statement.TableAlias != "" {
  269. if dialect.URI().DBType == schemas.ORACLE {
  270. fromStr += " " + quote(statement.TableAlias)
  271. } else {
  272. fromStr += " AS " + quote(statement.TableAlias)
  273. }
  274. }
  275. if statement.JoinStr != "" {
  276. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  277. }
  278. pLimitN := statement.LimitN
  279. if dialect.URI().DBType == schemas.MSSQL {
  280. if pLimitN != nil {
  281. LimitNValue := *pLimitN
  282. top = fmt.Sprintf("TOP %d ", LimitNValue)
  283. }
  284. if statement.Start > 0 {
  285. var column string
  286. if len(statement.RefTable.PKColumns()) == 0 {
  287. for _, index := range statement.RefTable.Indexes {
  288. if len(index.Cols) == 1 {
  289. column = index.Cols[0]
  290. break
  291. }
  292. }
  293. if len(column) == 0 {
  294. column = statement.RefTable.ColumnsSeq()[0]
  295. }
  296. } else {
  297. column = statement.RefTable.PKColumns()[0].Name
  298. }
  299. if statement.needTableName() {
  300. if len(statement.TableAlias) > 0 {
  301. column = statement.TableAlias + "." + column
  302. } else {
  303. column = statement.TableName() + "." + column
  304. }
  305. }
  306. var orderStr string
  307. if needOrderBy && len(statement.OrderStr) > 0 {
  308. orderStr = " ORDER BY " + statement.OrderStr
  309. }
  310. var groupStr string
  311. if len(statement.GroupByStr) > 0 {
  312. groupStr = " GROUP BY " + statement.GroupByStr
  313. }
  314. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  315. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  316. }
  317. }
  318. var buf strings.Builder
  319. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  320. if len(mssqlCondi) > 0 {
  321. if len(whereStr) > 0 {
  322. fmt.Fprint(&buf, " AND ", mssqlCondi)
  323. } else {
  324. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  325. }
  326. }
  327. if statement.GroupByStr != "" {
  328. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  329. }
  330. if statement.HavingStr != "" {
  331. fmt.Fprint(&buf, " ", statement.HavingStr)
  332. }
  333. if needOrderBy && statement.OrderStr != "" {
  334. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  335. }
  336. if needLimit {
  337. if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
  338. if statement.Start > 0 {
  339. if pLimitN != nil {
  340. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
  341. } else {
  342. fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start)
  343. }
  344. } else if pLimitN != nil {
  345. fmt.Fprint(&buf, " LIMIT ", *pLimitN)
  346. }
  347. } else if dialect.URI().DBType == schemas.ORACLE {
  348. if statement.Start != 0 || pLimitN != nil {
  349. oldString := buf.String()
  350. buf.Reset()
  351. rawColStr := columnStr
  352. if rawColStr == "*" {
  353. rawColStr = "at.*"
  354. }
  355. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  356. columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
  357. }
  358. }
  359. }
  360. if statement.IsForUpdate {
  361. return dialect.ForUpdateSQL(buf.String()), condArgs, nil
  362. }
  363. return buf.String(), condArgs, nil
  364. }
  365. func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
  366. if statement.RawSQL != "" {
  367. return statement.GenRawSQL(), statement.RawParams, nil
  368. }
  369. var sqlStr string
  370. var args []interface{}
  371. var joinStr string
  372. var err error
  373. if len(bean) == 0 {
  374. tableName := statement.TableName()
  375. if len(tableName) <= 0 {
  376. return "", nil, ErrTableNotFound
  377. }
  378. tableName = statement.quote(tableName)
  379. if len(statement.JoinStr) > 0 {
  380. joinStr = statement.JoinStr
  381. }
  382. if statement.Conds().IsValid() {
  383. condSQL, condArgs, err := statement.GenCondSQL(statement.Conds())
  384. if err != nil {
  385. return "", nil, err
  386. }
  387. if statement.dialect.URI().DBType == schemas.MSSQL {
  388. sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
  389. } else if statement.dialect.URI().DBType == schemas.ORACLE {
  390. sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
  391. } else {
  392. sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
  393. }
  394. args = condArgs
  395. } else {
  396. if statement.dialect.URI().DBType == schemas.MSSQL {
  397. sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
  398. } else if statement.dialect.URI().DBType == schemas.ORACLE {
  399. sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
  400. } else {
  401. sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)
  402. }
  403. args = []interface{}{}
  404. }
  405. } else {
  406. beanValue := reflect.ValueOf(bean[0])
  407. if beanValue.Kind() != reflect.Ptr {
  408. return "", nil, errors.New("needs a pointer")
  409. }
  410. if beanValue.Elem().Kind() == reflect.Struct {
  411. if err := statement.SetRefBean(bean[0]); err != nil {
  412. return "", nil, err
  413. }
  414. }
  415. if len(statement.TableName()) <= 0 {
  416. return "", nil, ErrTableNotFound
  417. }
  418. statement.Limit(1)
  419. sqlStr, args, err = statement.GenGetSQL(bean[0])
  420. if err != nil {
  421. return "", nil, err
  422. }
  423. }
  424. return sqlStr, args, nil
  425. }
  426. func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
  427. if statement.RawSQL != "" {
  428. return statement.GenRawSQL(), statement.RawParams, nil
  429. }
  430. var sqlStr string
  431. var args []interface{}
  432. var err error
  433. if len(statement.TableName()) <= 0 {
  434. return "", nil, ErrTableNotFound
  435. }
  436. var columnStr = statement.ColumnStr()
  437. if len(statement.SelectStr) > 0 {
  438. columnStr = statement.SelectStr
  439. } else {
  440. if statement.JoinStr == "" {
  441. if columnStr == "" {
  442. if statement.GroupByStr != "" {
  443. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  444. } else {
  445. columnStr = statement.genColumnStr()
  446. }
  447. }
  448. } else {
  449. if columnStr == "" {
  450. if statement.GroupByStr != "" {
  451. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  452. } else {
  453. columnStr = "*"
  454. }
  455. }
  456. }
  457. if columnStr == "" {
  458. columnStr = "*"
  459. }
  460. }
  461. statement.cond = statement.cond.And(autoCond)
  462. sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
  463. if err != nil {
  464. return "", nil, err
  465. }
  466. args = append(statement.joinArgs, condArgs...)
  467. // for mssql and use limit
  468. qs := strings.Count(sqlStr, "?")
  469. if len(args)*2 == qs {
  470. args = append(args, args...)
  471. }
  472. return sqlStr, args, nil
  473. }