/_examples/signalhandling/main.go

https://github.com/c-bata/goptuna · Go · 104 lines · 86 code · 13 blank · 5 comment · 23 complexity · f38c98885c37caea54b27c0fd675076d MD5 · raw file

  1. package main
  2. import (
  3. "context"
  4. "log"
  5. "math"
  6. "os"
  7. "os/exec"
  8. "os/signal"
  9. "runtime"
  10. "sync"
  11. "syscall"
  12. "gorm.io/driver/sqlite"
  13. "gorm.io/gorm"
  14. "github.com/c-bata/goptuna"
  15. "github.com/c-bata/goptuna/rdb.v2"
  16. )
  17. func objective(trial goptuna.Trial) (float64, error) {
  18. ctx := trial.GetContext()
  19. x1, _ := trial.SuggestFloat("x1", -10, 10)
  20. x2, _ := trial.SuggestFloat("x2", -10, 10)
  21. cmd := exec.CommandContext(ctx, "sleep", "1")
  22. err := cmd.Run()
  23. if err != nil {
  24. return -1, err
  25. }
  26. return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
  27. }
  28. func main() {
  29. db, err := gorm.Open(sqlite.Open("db.sqlite3"), &gorm.Config{})
  30. if err != nil {
  31. log.Fatal("failed to open database:", err)
  32. }
  33. if sqlDB, err := db.DB(); err != nil {
  34. log.Fatal("failed to get sql.DB:", err)
  35. } else {
  36. sqlDB.SetMaxOpenConns(1)
  37. }
  38. err = rdb.RunAutoMigrate(db)
  39. if err != nil {
  40. log.Fatal("failed to run auto migrate:", err)
  41. }
  42. // create a study
  43. study, err := goptuna.CreateStudy(
  44. "goptuna-example",
  45. goptuna.StudyOptionStorage(rdb.NewStorage(db)),
  46. goptuna.StudyOptionDirection(goptuna.StudyDirectionMinimize),
  47. )
  48. if err != nil {
  49. log.Fatal("failed to create a study:", err)
  50. }
  51. // create a context with cancel function
  52. ctx, cancel := context.WithCancel(context.Background())
  53. defer cancel()
  54. study.WithContext(ctx)
  55. // set signal handler
  56. sigch := make(chan os.Signal, 1)
  57. defer close(sigch)
  58. signal.Notify(sigch, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
  59. var wg sync.WaitGroup
  60. wg.Add(1)
  61. go func() {
  62. defer wg.Done()
  63. sig, ok := <-sigch
  64. if !ok {
  65. return
  66. }
  67. cancel()
  68. log.Println("Catch a kill signal:", sig.String())
  69. }()
  70. // run optimize with multiple goroutine workers
  71. concurrency := runtime.NumCPU() - 1
  72. if concurrency == 0 {
  73. concurrency = 1
  74. }
  75. for i := 0; i < concurrency; i++ {
  76. wg.Add(1)
  77. go func() {
  78. defer wg.Done()
  79. err = study.Optimize(objective, 100/concurrency)
  80. if err != nil {
  81. log.Fatal("Optimize error:", err)
  82. }
  83. }()
  84. }
  85. wg.Wait()
  86. // print best hyper-parameters and the result
  87. v, _ := study.GetBestValue()
  88. params, _ := study.GetBestParams()
  89. log.Printf("Best evaluation=%f (x1=%f, x2=%f)",
  90. v, params["x1"].(float64), params["x2"].(float64))
  91. }