/go/lib/infra/modules/db/sqlite.go

https://github.com/netsec-ethz/scion · Go · 122 lines · 90 code · 7 blank · 25 comment · 34 complexity · 2b7f371caf20a20fbc2544855a70ebc9 MD5 · raw file

  1. // Copyright 2018 ETH Zurich, Anapaya Systems
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package db
  15. import (
  16. "database/sql"
  17. "fmt"
  18. "net/url"
  19. "github.com/scionproto/scion/go/lib/serrors"
  20. )
  21. // NewSqlite returns a new SQLite backend opening a database at the given path. If
  22. // no database exists a new database is be created. If the schema version of the
  23. // stored database is different from schemaVersion, an error is returned.
  24. func NewSqlite(path string, schema string, schemaVersion int) (*sql.DB, error) {
  25. var err error
  26. if path == "" {
  27. return nil, serrors.New("Empty path not allowed for sqlite")
  28. }
  29. db, err := open(path)
  30. if err != nil {
  31. return nil, err
  32. }
  33. // On future errors, close the sql database before exiting
  34. defer func() {
  35. if err != nil {
  36. db.Close()
  37. }
  38. }()
  39. // prevent weird errors. (see https://stackoverflow.com/a/35805826)
  40. db.SetMaxOpenConns(1)
  41. // Check the schema version and set up new DB if necessary.
  42. var existingVersion int
  43. err = db.QueryRow("PRAGMA user_version;").Scan(&existingVersion)
  44. if err != nil {
  45. return nil, serrors.WrapStr("Failed to check schema version", err,
  46. "path", path)
  47. }
  48. if existingVersion == 0 {
  49. if err = setup(db, schema, schemaVersion, path); err != nil {
  50. return nil, err
  51. }
  52. } else if existingVersion != schemaVersion {
  53. return nil, serrors.New("Database schema version mismatch",
  54. "expected", schemaVersion, "have", existingVersion, "path", path)
  55. }
  56. return db, nil
  57. }
  58. func open(path string) (*sql.DB, error) {
  59. var err error
  60. u, err := url.Parse(path)
  61. if err != nil {
  62. return nil, serrors.WrapStr("invalid connection path", err, "path", path)
  63. }
  64. q := u.Query()
  65. // Add foreign_key parameter to path to enable foreign key support.
  66. q.Set("_foreign_keys", "1")
  67. // prevent weird errors. (see https://stackoverflow.com/a/35805826)
  68. q.Set("_journal_mode", "WAL")
  69. u.RawQuery = q.Encode()
  70. path = u.String()
  71. db, err := sql.Open("sqlite3", path)
  72. if err != nil {
  73. return nil, serrors.WrapStr("Couldn't open SQLite database", err, "path", path)
  74. }
  75. // On future errors, close the sql database before exiting
  76. defer func() {
  77. if err != nil {
  78. db.Close()
  79. }
  80. }()
  81. // Make sure DB is reachable
  82. if err = db.Ping(); err != nil {
  83. return nil, serrors.WrapStr("Initial DB ping failed, connection broken?", err,
  84. "path", path)
  85. }
  86. // Ensure foreign keys are supported and enabled.
  87. var enabled bool
  88. err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
  89. if err == sql.ErrNoRows {
  90. return nil, serrors.WrapStr("Foreign keys not supported", err,
  91. "path", path)
  92. }
  93. if err != nil {
  94. return nil, serrors.WrapStr("Failed to check for foreign key support", err,
  95. "path", path)
  96. }
  97. if !enabled {
  98. db.Close()
  99. return nil, serrors.New("Failed to enable foreign key support",
  100. "path", path)
  101. }
  102. return db, nil
  103. }
  104. func setup(db *sql.DB, schema string, schemaVersion int, path string) error {
  105. _, err := db.Exec(schema)
  106. if err != nil {
  107. return serrors.WrapStr("Failed to set up SQLite database", err, "path", path)
  108. }
  109. // Write schema version to database.
  110. _, err = db.Exec(fmt.Sprintf("PRAGMA user_version = %d", schemaVersion))
  111. if err != nil {
  112. return serrors.WrapStr("Failed to write schema version", err, "path", path)
  113. }
  114. return nil
  115. }