/src/dl4clj/datasets/new_datasets.clj

https://github.com/engagor/dl4clj · Clojure · 48 lines · 37 code · 7 blank · 4 comment · 1 complexity · e51fdc9863f00639d14ad4ef2ca4ddaa MD5 · raw file

  1. (ns dl4clj.datasets.new-datasets
  2. (:import [org.nd4j.linalg.dataset DataSet]
  3. [org.nd4j.linalg.dataset MultiDataSet])
  4. (:require [dl4clj.utils :refer [contains-many? obj-or-code?]]
  5. [clojure.core.match :refer [match]]
  6. [nd4clj.linalg.factory.nd4j :refer [vec-or-matrix->indarray]]))
  7. (defn new-ds
  8. "Creates a DataSet object with the specified input and output.
  9. if they are not supplied, creates a new empty DataSet object
  10. :input (vec, matrix or INDArray), the input to a model
  11. :output (vec, matrix or INDArray), the targets/labels for the supplied input
  12. :as-code? (boolean), return the dl4j obj or the code for creating it
  13. see: http://nd4j.org/doc/org/nd4j/linalg/dataset/DataSet.html"
  14. [& {:keys [input output as-code?]
  15. :or {as-code? true}
  16. :as opts}]
  17. (let [code (if (contains-many? opts :input :output)
  18. `(DataSet. (vec-or-matrix->indarray ~input) (vec-or-matrix->indarray ~output))
  19. `(DataSet.))]
  20. (obj-or-code? as-code? code)))
  21. (defn new-multi-ds
  22. "a dataset that contains multiple datasets
  23. see: http://nd4j.org/doc/org/nd4j/linalg/dataset/MultiDataSet.html"
  24. ;; come back and beef up this doc string
  25. ;; also ensure the arrays are INDArrays of INDarrays
  26. ;; bc the constructor can accept these as just a single INDArray
  27. ;; make sure this is documented
  28. [& {:keys [features labels features-mask labels-mask as-code?]
  29. :or {as-code? true}
  30. :as opts}]
  31. (let [f `(vec-or-matrix->indarray ~features)
  32. l `(vec-or-matrix->indarray ~labels)
  33. code (match [opts]
  34. [{:features _ :labels _ :features-mask _ :labels-mask _}]
  35. `(MultiDataSet. ~f ~l (vec-or-matrix->indarray ~features-mask)
  36. (vec-or-matrix->indarray ~labels-mask))
  37. [{:features _ :labels _}]
  38. `(MultiDataSet. ~f ~l)
  39. :else
  40. `(MultiDataSet.))]
  41. (obj-or-code? as-code? code)))