PageRenderTime 47ms CodeModel.GetById 17ms RepoModel.GetById 0ms app.codeStats 0ms

/uppsrc/MySql/MySql.cpp

http://upp-mirror.googlecode.com/
C++ | 494 lines | 443 code | 40 blank | 11 comment | 100 complexity | 8d59d60c5b00bf70bc3798aa18936761 MD5 | raw file
Possible License(s): LGPL-2.1, GPL-2.0, BSD-2-Clause, BSD-3-Clause, LGPL-3.0, GPL-3.0
  1. #include "MySql.h"
  2. #ifndef flagNOMYSQL
  3. NAMESPACE_UPP
  4. class MySqlConnection : public SqlConnection {
  5. protected:
  6. virtual void SetParam(int i, const Value& r);
  7. virtual bool Execute();
  8. virtual int GetRowsProcessed() const;
  9. virtual bool Fetch();
  10. virtual void GetColumn(int i, Ref f) const;
  11. virtual void Cancel();
  12. virtual Value GetInsertedId() const;
  13. virtual SqlSession& GetSession() const;
  14. virtual String GetUser() const;
  15. virtual String ToString() const;
  16. private:
  17. MySqlSession& session;
  18. MYSQL *mysql;
  19. Vector<String> param;
  20. MYSQL_RES *result;
  21. MYSQL_ROW row;
  22. unsigned long *len;
  23. int rows;
  24. int lastid;
  25. Buffer<bool> convert;
  26. String MakeQuery() const;
  27. void FreeResult();
  28. String EscapeString(const String& v);
  29. public:
  30. MySqlConnection(MySqlSession& session, MYSQL *mysql);
  31. virtual ~MySqlConnection() { Cancel(); }
  32. };
  33. bool MySqlSession::IsOpen() const { return mysql; }
  // Maps an empty C string to NULL so the MySQL C API falls back to its
  // defaults; NULL and non-empty strings are returned unchanged.
  static const char *sEmpNull(const char *s) {
      if(s == NULL || *s != '\0')
          return s;
      return NULL;
  }
  37. bool MySqlSession::Connect(const char *user, const char *password, const char *database,
  38. const char *host, int port, const char *socket) {
  39. mysql = mysql_init((MYSQL*) 0);
  40. if(mysql && mysql_real_connect(mysql, sEmpNull(host), sEmpNull(user),
  41. sEmpNull(password), sEmpNull(database), port,
  42. sEmpNull(socket), 0)) {
  43. Sql sql(*this);
  44. username = sql.Select("substring_index(USER(),'@',1)");
  45. mysql_set_character_set(mysql, "utf8");
  46. sql.Execute("SET NAMES 'utf8'");
  47. sql.Execute("SET CHARACTER SET utf8");
  48. return true;
  49. }
  50. Close();
  51. return false;
  52. }
  53. inline static const char *EmpNull(const String& s)
  54. {
  55. return *s ? (const char *)s : 0;
  56. }
  57. bool MySqlSession::Open(const char *connect) {
  58. String user, pwd, socket;
  59. String database = Null;
  60. String host = Null;
  61. int port = MYSQL_PORT;
  62. level = 0;
  63. const char *p = connect, *b;
  64. for(b = p; *p && *p != '/' && *p != '@'; p++)
  65. ;
  66. user = String(b, p);
  67. if(*p == '/')
  68. {
  69. for(b = ++p; *p && *p != '@'; p++)
  70. ;
  71. pwd = String(b, p);
  72. }
  73. if(*p == '@')
  74. {
  75. for(b = ++p; *p && *p != '/' && *p != ':' && *p != ','; p++)
  76. ;
  77. if(*p == '/' || *p == 0)
  78. {
  79. database = String(b, p);
  80. if(*p)
  81. p++;
  82. b = p;
  83. }
  84. while(*p && *p != ':' && *p != ',')
  85. p++;
  86. host = String(b, p);
  87. if(*p == ':')
  88. { // port
  89. if(!IsDigit(*++p))
  90. throw Exc("Port number expected.");
  91. port = stou(p, &p);
  92. }
  93. if(*p == ',') // socket
  94. socket = p + 1;
  95. }
  96. return Connect(EmpNull(user), EmpNull(pwd),
  97. EmpNull(database), EmpNull(host), port, EmpNull(socket));
  98. }
  99. void MySqlSession::Close() {
  100. SessionClose();
  101. if(mysql) {
  102. mysql_close(mysql);
  103. mysql = NULL;
  104. level = 0;
  105. }
  106. }
  107. void MySqlSession::Begin()
  108. {
  109. static const char btrans[] = "start transaction";
  110. if(trace)
  111. *trace << btrans << ";\n";
  112. if(mysql_query(mysql, btrans))
  113. SetError(mysql_error(mysql), btrans);
  114. level++;
  115. }
  116. void MySqlSession::Commit()
  117. {
  118. static const char ctrans[] = "commit";
  119. if(trace)
  120. *trace << ctrans << ";\n";
  121. if(mysql_query(mysql, ctrans))
  122. SetError(mysql_error(mysql), ctrans);
  123. level--;
  124. }
  125. void MySqlSession::Rollback()
  126. {
  127. static const char rtrans[] = "rollback";
  128. if(trace)
  129. *trace << rtrans << ";\n";
  130. if(mysql_query(mysql, rtrans))
  131. SetError(mysql_error(mysql), rtrans);
  132. if(level > 0) level--;
  133. }
  134. int MySqlSession::GetTransactionLevel() const
  135. {
  136. return level;
  137. }
  138. static Vector<String> FetchList(Sql& cursor, bool upper = false)
  139. {
  140. Vector<String> out;
  141. String s;
  142. while(cursor.Fetch(s))
  143. out.Add(upper ? ToUpper(s) : s);
  144. return out;
  145. }
  146. Vector<String> MySqlSession::EnumUsers()
  147. {
  148. Vector<String> out;
  149. Sql cursor(*this);
  150. if(Select(SqlId("USER")).From(SqlId("MYSQL.USER")).Execute(cursor))
  151. out = FetchList(cursor);
  152. return out;
  153. }
  154. Vector<String> MySqlSession::EnumDatabases()
  155. {
  156. Vector<String> out;
  157. Sql cursor(*this);
  158. if(cursor.Execute("show databases"))
  159. out = FetchList(cursor); // 06-09-12 cxl: was false; In Linux, names are case sensitive
  160. return out;
  161. }
  162. Vector<String> MySqlSession::EnumTables(String database)
  163. {
  164. Vector<String> out;
  165. Sql cursor(*this);
  166. if(cursor.Execute("show tables from " + database))
  167. out = FetchList(cursor); // 06-09-12 cxl: was false; In Linux, names are case sensitive
  168. return out;
  169. }
  170. SqlConnection *MySqlSession::CreateConnection() {
  171. return new MySqlConnection(*this, mysql);
  172. }
  173. MySqlConnection::MySqlConnection(MySqlSession& session, MYSQL *mysql)
  174. : session(session), mysql(mysql) {
  175. result = NULL;
  176. lastid = 0;
  177. }
  178. String MySqlConnection::EscapeString(const String& v)
  179. {
  180. StringBuffer b(v.GetLength() * 2 + 3);
  181. char *q = b;
  182. *q = '\"';
  183. int n = mysql_real_escape_string(mysql, q + 1, v, v.GetLength());
  184. q[1 + n] = '\"';
  185. b.SetCount(2 + n); //TODO - check this fix
  186. return b;
  187. }
  188. void MySqlConnection::SetParam(int i, const Value& r) {
  189. String p;
  190. if(IsNull(r))
  191. p = "NULL";
  192. else
  193. switch(r.GetType()) {
  194. case 34:
  195. p = EscapeString(SqlRaw(r));
  196. break;
  197. case WSTRING_V:
  198. case STRING_V:
  199. p = EscapeString(ToCharset(CHARSET_UTF8, r));
  200. break;
  201. case BOOL_V:
  202. case INT_V:
  203. p = Format("%d", int(r));
  204. break;
  205. case DOUBLE_V:
  206. p = FormatDouble(double(r), 20);
  207. break;
  208. case DATE_V: {
  209. Date d = r;
  210. p = Format("\'%04d-%02d-%02d\'", d.year, d.month, d.day);
  211. }
  212. break;
  213. case TIME_V: {
  214. Time t = r;
  215. p = Format("\'%04d-%02d-%02d %02d:%02d:%02d\'",
  216. t.year, t.month, t.day, t.hour, t.minute, t.second);
  217. }
  218. break;
  219. default:
  220. NEVER();
  221. }
  222. param.At(i, p);
  223. }
  224. bool MySqlConnection::Execute() {
  225. String query;
  226. int pi = 0;
  227. const char *s = statement;
  228. while(s < statement.End())
  229. if(*s == '\'' || *s == '\"')
  230. s = MySqlReadString(s, query);
  231. else {
  232. if(*s == '?')
  233. query.Cat(param[pi++]);
  234. else
  235. query.Cat(*s);
  236. s++;
  237. }
  238. Cancel();
  239. /* Stream *trace = session.GetTrace();
  240. dword time;
  241. if(session.IsTraceTime())
  242. time = GetTickCount();*/
  243. if(mysql_query(mysql, query)) {
  244. session.SetError(mysql_error(mysql), query);
  245. return false;
  246. }
  247. result = mysql_store_result(mysql);
  248. rows = (int)mysql_affected_rows(mysql);
  249. if(result) {
  250. DDUMP(rows);
  251. int fields = mysql_num_fields(result);
  252. info.SetCount(fields);
  253. convert.Alloc(fields, false);
  254. for(int i = 0; i < fields; i++) {
  255. MYSQL_FIELD *field = mysql_fetch_field_direct(result, i);
  256. SqlColumnInfo& f = info[i];
  257. f.name = field->name;
  258. switch(field->type) {
  259. case FIELD_TYPE_TINY:
  260. case FIELD_TYPE_SHORT:
  261. case FIELD_TYPE_LONG:
  262. case FIELD_TYPE_INT24:
  263. f.type = INT_V;
  264. break;
  265. case FIELD_TYPE_LONGLONG:
  266. case FIELD_TYPE_DECIMAL:
  267. case FIELD_TYPE_FLOAT:
  268. case FIELD_TYPE_DOUBLE:
  269. f.type = DOUBLE_V;
  270. break;
  271. case FIELD_TYPE_DATE:
  272. f.type = DATE_V;
  273. break;
  274. case FIELD_TYPE_DATETIME:
  275. case FIELD_TYPE_TIMESTAMP:
  276. f.type = TIME_V;
  277. break;
  278. case FIELD_TYPE_VAR_STRING:
  279. case FIELD_TYPE_STRING:
  280. convert[i] = true;
  281. default:
  282. f.type = STRING_V;
  283. break;
  284. }
  285. f.width = field->length;
  286. f.scale = f.precision = 0;
  287. }
  288. }
  289. else {
  290. lastid = (int)mysql_insert_id(mysql);
  291. if(lastid) {
  292. SqlColumnInfo& f = info.Add();
  293. f.width = f.scale = f.precision = 0;
  294. f.binary = false;
  295. f.type = DOUBLE_V;
  296. f.name = "LAST_INSERT_ID";
  297. rows = 1;
  298. }
  299. }
  300. return true;
  301. }
  302. int MySqlConnection::GetRowsProcessed() const {
  303. return rows;
  304. }
  305. Value MySqlConnection::GetInsertedId() const
  306. {
  307. return lastid;
  308. }
  309. bool MySqlConnection::Fetch() {
  310. if(result) {
  311. row = mysql_fetch_row(result);
  312. if(row) {
  313. len = mysql_fetch_lengths(result);
  314. return true;
  315. }
  316. }
  317. else
  318. if(lastid && rows > 0) {
  319. rows--;
  320. return true;
  321. }
  322. FreeResult();
  323. return false;
  324. }
// Character offsets of a MySQL date/time text value:
// 0123456789012345678
// YYYY-MM-DD HH:MM:SS
  327. static Date sDate(const char *s) {
  328. return Date(atoi(s), atoi(s + 5), atoi(s + 8));
  329. }
  330. void MySqlConnection::GetColumn(int i, Ref f) const {
  331. if(lastid) {
  332. f = lastid;
  333. return;
  334. }
  335. const char *s = row[i];
  336. if(s == NULL)
  337. f = Null;
  338. else {
  339. switch(info[i].type) {
  340. case INT_V:
  341. f = atoi(s);
  342. break;
  343. case DOUBLE_V:
  344. f = ScanDouble(s, NULL, true);
  345. break;
  346. case DATE_V:
  347. f = Value(sDate(s));
  348. break;
  349. case TIME_V: {
  350. Time t = ToTime(sDate(s));
  351. t.hour = atoi(s + 11);
  352. t.minute = atoi(s + 14);
  353. t.second = atoi(s + 17);
  354. f = Value(t);
  355. }
  356. break;
  357. default:
  358. if(convert[i])
  359. f = Value(ToCharset(CHARSET_DEFAULT, String(s, len[i]), CHARSET_UTF8));
  360. else
  361. f = Value(String(s, len[i]));
  362. break;
  363. }
  364. }
  365. }
  366. void MySqlConnection::FreeResult() {
  367. lastid = 0;
  368. if(result) {
  369. mysql_free_result(result);
  370. result = NULL;
  371. }
  372. }
  373. void MySqlConnection::Cancel() {
  374. param.Clear();
  375. info.Clear();
  376. rows = 0;
  377. FreeResult();
  378. }
  379. SqlSession& MySqlConnection::GetSession() const { return session; }
  380. String MySqlConnection::GetUser() const { return session.GetUser(); }
  381. String MySqlConnection::ToString() const { return statement; }
  382. String MySqlTextType(int n) {
  383. return n < 256 ? Format("varchar(%d)", n) : String("text");
  384. }
  385. const char *MySqlReadString(const char *s, String& stmt) {
  386. stmt.Cat(*s);
  387. int c = *s++;
  388. for(;;) {
  389. if(*s == '\0') break;
  390. else
  391. if(*s == '\'' && s[1] == '\'') {
  392. stmt.Cat('\'');
  393. s += 2;
  394. }
  395. // else
  396. // if(*s == '\"' && s[1] == '\"') {
  397. // stmt.Cat('\"');
  398. // s += 2;
  399. // }
  400. else
  401. if(*s == c) {
  402. stmt.Cat(c);
  403. s++;
  404. break;
  405. }
  406. else
  407. if(*s == '\\') {
  408. stmt.Cat('\\');
  409. if(*++s)
  410. stmt.Cat(*s++);
  411. }
  412. else
  413. stmt.Cat(*s++);
  414. }
  415. return s;
  416. }
  417. bool MySqlPerformScript(const String& txt, StatementExecutor& se, Gate2<int, int> progress_canceled) {
  418. const char *text = txt;
  419. for(;;) {
  420. String stmt;
  421. while(*text <= 32 && *text > 0) text++;
  422. if(*text == '\0') break;
  423. for(;;) {
  424. if(*text == '\0')
  425. break;
  426. if(*text == ';')
  427. break;
  428. else
  429. if(*text == '\'')
  430. text = MySqlReadString(text, stmt);
  431. else
  432. if(*text == '\"')
  433. text = MySqlReadString(text, stmt);
  434. else
  435. stmt.Cat(*text++);
  436. }
  437. if(progress_canceled(text - txt.Begin(), txt.GetLength()))
  438. return false;
  439. if(!se.Execute(stmt))
  440. return false;
  441. if(*text) text++;
  442. }
  443. return true;
  444. }
  445. bool MySqlUpdateSchema(const SqlSchema& sch, int i, StatementExecutor& se) {
  446. if(sch.UpdateNormalFile(i)) {
  447. MySqlPerformScript(LoadFile(sch.NormalFileName(i + 1)), se);
  448. sch.UpdateNormalFile(i + 1);
  449. return MySqlPerformScript(sch.Script(i), se);
  450. }
  451. return true;
  452. }
  453. END_UPP_NAMESPACE
  454. #endif