PageRenderTime 44ms CodeModel.GetById 16ms RepoModel.GetById 0ms app.codeStats 0ms

/src/main/clojure/core/matrix.clj

https://github.com/bmabey/core.matrix
Clojure | 1045 lines | 844 code | 133 blank | 68 comment | 35 complexity | 3e5fb6ce82e18d2559129fe04691351d MD5 | raw file
  1. (ns core.matrix
  2. (:use core.matrix.utils)
  3. (:require [core.matrix.impl double-array ndarray persistent-vector wrappers])
  4. (:require [core.matrix.impl sequence]) ;; TODO: figure out if we want this?
  5. (:require [core.matrix.multimethods :as mm])
  6. (:require [core.matrix.protocols :as mp])
  7. (:require [core.matrix.implementations :as imp])
  8. (:require [core.matrix.impl.mathsops :as mops]))
  9. ;; ==================================================================================
  10. ;; core.matrix API namespace
  11. ;;
  12. ;; This is the public API for core.matrix
  13. ;;
  14. ;; General handling of operations is as follows:
  15. ;;
  16. ;; 1. user calls a public API function defined in core.matrix
  17. ;; 2. core.matrix function delegates to a protocol for the appropriate function
  18. ;; with protocols as defined in the core.matrix.protocols namespace. In most cases
  19. ;; core.matrix will try to delegate as quickly as possible to the implementation.
  20. ;; 3. The underlying matrix implementation implements the protocol to handle the API
  21. ;; function call
  22. ;; 4. It's up to the implementation to decide what to do then
  23. ;; 5. If the implementation does not understand one or more parameters, then it is
  24. ;; expected to call the multimethod version in core.matrix.multimethods as this
  25. ;; will allow an alternative implementation to be found via multiple dispatch
  26. ;;
  27. ;; ==================================================================================
  28. (set! *warn-on-reflection* true)
  29. (set! *unchecked-math* true)
  30. ;; =============================================================
  31. ;; matrix construction functions
  32. (declare current-implementation)
  33. (declare current-implementation-object)
  34. (def ^:dynamic *matrix-implementation* :persistent-vector)
  35. (defn matrix
  36. "Constructs a matrix from the given data.
  37. The data may be in one of the following forms:
  38. - Nested sequences, e.g. Clojure vectors
  39. - A valid existing matrix
  40. If implementation is not specified, uses the current matrix library as specified
  41. in *matrix-implementation*"
  42. ([data]
  43. (if-let [m (current-implementation-object)]
  44. (mp/construct-matrix m data)
  45. (error "No core.matrix implementation available")))
  46. ([implementation data]
  47. (mp/construct-matrix (imp/get-canonical-object implementation) data)))
  48. (defn array
  49. "Constructs a new n-dimensional array from the given data.
  50. The data may be in one of the following forms:
  51. - Nested sequences, e.g. Clojure vectors
  52. - A valid existing array
  53. If implementation is not specified, uses the current matrix library as specified
  54. in *matrix-implementation*"
  55. ([data]
  56. (if-let [m (current-implementation-object)]
  57. (mp/construct-matrix m data)
  58. (error "No core.matrix implementation available")))
  59. ([implementation data]
  60. (mp/construct-matrix (imp/get-canonical-object implementation) data)))
  61. (defn new-vector
  62. "Constructs a new zero-filled vector with the given length"
  63. ([length]
  64. (if-let [m (current-implementation-object)]
  65. (mp/new-vector m length)
  66. (error "No core.matrix implementation available")))
  67. ([length implementation]
  68. (mp/new-vector (imp/get-canonical-object implementation) length)))
  69. (defn new-matrix
  70. "Constructs a new zero-filled matrix with the given dimensions"
  71. ([rows columns]
  72. (if-let [ik (current-implementation)]
  73. (mp/new-matrix (imp/get-canonical-object ik) rows columns)
  74. (error "No core.matrix implementation available"))))
  75. (defn new-array
  76. "Creates a new array with the given dimensions. "
  77. ([length] (new-vector length))
  78. ([rows columns] (new-matrix rows columns))
  79. ([dim-1 dim-2 & more-dim]
  80. (if-let [ik (current-implementation)]
  81. (mp/new-matrix-nd (imp/get-canonical-object ik) (cons dim-1 (cons dim-2 more-dim)))
  82. (error "No core.matrix implementation available"))))
  83. (defn row-matrix
  84. "Constucts a row matrix with the given values. The returned matrix is a 2D 1xN row matrix."
  85. ([data]
  86. (if-let [ik (current-implementation)]
  87. (mp/construct-matrix (imp/get-canonical-object ik) (vector data))
  88. (error "No core.matrix implementation available")))
  89. ([implementation data]
  90. (mp/construct-matrix (imp/get-canonical-object implementation) (vector data))))
  91. (defn column-matrix
  92. "Constucts a column matrix with the given values. The returned matrix is a 2D Nx1 column matrix."
  93. ([data]
  94. (if-let [ik (current-implementation)]
  95. (mp/construct-matrix (imp/get-canonical-object ik) (map vector data))
  96. (error "No core.matrix implementation available")))
  97. ([implementation data]
  98. (mp/construct-matrix (imp/get-canonical-object implementation) (map vector data))))
  99. (defn identity-matrix
  100. "Constructs a 2D identity matrix with the given number or rows"
  101. ([dims]
  102. (mp/identity-matrix (current-implementation-object) dims))
  103. ([implementation dims]
  104. (mp/identity-matrix (imp/get-canonical-object implementation) dims)))
  105. (defn diagonal-matrix
  106. "Constructs a 2D diagonal matrix with the given values on the main diagonal.
  107. diagonal-values may be a vector or any Clojure sequence of values."
  108. ([diagonal-values]
  109. (mp/diagonal-matrix (current-implementation-object) diagonal-values))
  110. ([implementation diagonal-values]
  111. (mp/diagonal-matrix (imp/get-canonical-object implementation) diagonal-values)))
  112. (defn compute-matrix
  113. "Creates a matrix with the specified shape, and each element specified by (f i j k...)
  114. Where i, j, k... are the index positions of each element in the matrix"
  115. ([shape f]
  116. (compute-matrix (current-implementation-object) shape f))
  117. ([implementation shape f]
  118. (let [m (imp/get-canonical-object implementation)]
  119. (TODO))))
  120. ;; ======================================
  121. ;; matrix assignment and copying
  122. (defn assign!
  123. "Assigns a value to a matrix.
  124. Returns the mutated matrix"
  125. ([m a]
  126. (mp/assign! m a)
  127. m))
  128. (defn clone
  129. "Constructs a clone of the matrix, using the same implementation. This function is intended to
  130. allow safe defensive copying of matrices / vectors.
  131. Guarantees that:
  132. 1. Mutating the returned matrix will not modify any other matrix (defensive copy)
  133. 2. The return matrix will be mutable, if the implementation supports mutable matrices.
  134. A matrix implementation which only provides immutable matrices may safely return the same matrix."
  135. ([m]
  136. (mp/clone m)))
  137. (defn to-nested-vectors
  138. "Converts an array to nested vectors.
  139. The depth of nesting is equal to the dimensionality of the array."
  140. ([m]
  141. (mp/convert-to-nested-vectors m)))
  142. ;; ==============================
  143. ;; Matrix predicates and querying
  144. (defn array?
  145. "Returns true if the parameter is an N-dimensional array, for any N>=1"
  146. ([m]
  147. (> (mp/dimensionality m) 0)))
  148. (defn matrix?
  149. "Returns true if parameter is a valid matrix (dimensionality == 2)"
  150. ([m]
  151. (== (mp/dimensionality m) 2)))
  152. (defn vec?
  153. "Returns true if the parameter is a vector"
  154. ([m]
  155. (mp/is-vector? m)))
  156. (defn scalar?
  157. "Returns true if the parameter is a scalar (zero dimensionality, acceptable as matrix value)."
  158. ([m]
  159. (mp/is-scalar? m)))
  160. (defn element-type
  161. "Returns the class of elements in the array."
  162. ([m]
  163. (mp/element-type m)))
  164. (defn dimensionality
  165. ;; TODO: alternative names to consider: order, tensor-rank?
  166. "Returns the dimensionality (number of array dimensions) of a matrix / array"
  167. ([m]
  168. (mp/dimensionality m)))
  169. (defn row-count
  170. "Returns the number of rows in a matrix (must be 1D or more)"
  171. ([m]
  172. (mp/dimension-count m 0)))
  173. (defn column-count
  174. "Returns the number of columns in a matrix (must be 2D or more)"
  175. ([m]
  176. (mp/dimension-count m 1)))
  177. (defn dimension-count
  178. "Returns the size of the specified dimension in a matrix."
  179. ([m dim]
  180. (mp/dimension-count m dim)))
  181. (defn square?
  182. "Returns true if matrix is square (2D with same number of rows and columns)"
  183. ([m]
  184. (and
  185. (== 2 (mp/dimensionality m))
  186. (== (mp/dimension-count m 0) (mp/dimension-count m 1)))))
  187. (defn row-matrix?
  188. "Returns true if a matrix is a row-matrix (i.e. has exactly one row)"
  189. ([m]
  190. (and (== (mp/dimensionality m) 2)
  191. (== 1 (mp/dimension-count m 0)))))
  192. (defn column-matrix?
  193. "Returns true if a matrix is a column-matrix (i.e. has exactly one column)"
  194. ([m]
  195. (and (== (mp/dimensionality m) 2)
  196. (== 1 (mp/dimension-count m 1)))))
  197. (defn shape
  198. "Returns the shape of a matrix, i.e. the dimension sizes for all dimensions.
  199. Result may be a sequence or Java array, to allow implemenations flexibility to return
  200. their own internal representation of matrix shape.
  201. You are guaranteed however that you can call `seq` on this to get a sequence of dimension sizes."
  202. ([m]
  203. (mp/get-shape m)))
  204. (defn mutable?
  205. "Returns true if the matrix is mutable, i.e. supports setting of values"
  206. ([m]
  207. (and (satisfies? mp/PIndexedSetting m) (mp/is-mutable? m))))
  208. (defn supports-dimensionality?
  209. "Returns true if the implementation for a given matrix supports a specific dimensionality, i.e.
  210. can create and manipulate matrices with the given number of dimensions"
  211. ([m dimension-count]
  212. (mp/supports-dimensionality? m dimension-count)))
  213. (defn- broadcast-shape*
  214. ([a b]
  215. (cond
  216. (nil? a) (or b '())
  217. (nil? b) a
  218. (== 1 (first a)) (broadcast-shape* (first b) (next a) (next b))
  219. (== 1 (first b)) (broadcast-shape* (first a) (next a) (next b))
  220. (== (first a) (first b)) (broadcast-shape* (first a) (next a) (next b))
  221. :else nil))
  222. ([prefix a b]
  223. (if (or a b)
  224. (let [r (broadcast-shape* a b)]
  225. (if r (cons prefix r) nil))
  226. (cons prefix nil))))
  227. (defn broadcast-shape
  228. "Returns the smallest compatible shape that shapes a and b can both broadcast to.
  229. Returns nil if this is not possible (i.e. the shapes are incompatible).
  230. Returns an empty list if both shape sequences are empty (i.e. represent scalars)"
  231. ([a b]
  232. (let [a (seq (reverse a))
  233. b (seq (reverse b))
  234. r (broadcast-shape* a b)]
  235. (if r (reverse r) nil))))
  236. ;; =======================================
  237. ;; Conversions
  238. (defn to-double-array
  239. "Returns a double array containing the values of m in row-major order.
  240. If want-copy is true, will guarantee a new double array (defensive copy).
  241. If want-copy is false, will return the internal array used by m, or nil if not supported
  242. by the implementation.
  243. If want copy is not sepcified, will return either a copy or the internal array"
  244. ([m]
  245. (mp/to-double-array m))
  246. ([m want-copy?]
  247. (let [arr (mp/as-double-array m)]
  248. (if want-copy?
  249. (if arr (copy-double-array arr) (mp/to-double-array m))
  250. arr))))
  251. ;; =======================================
  252. ;; matrix access
  253. (defn mget
  254. "Gets a scalar value from a matrix at a specified position. Supports any number of matrix dimensions."
  255. ([m]
  256. (if (mp/is-scalar? m)
  257. m
  258. (error "Can't mget from a non-scalar value without indexes")))
  259. ([m x]
  260. (mp/get-1d m x))
  261. ([m x y]
  262. (mp/get-2d m x y))
  263. ([m x y & more]
  264. (mp/get-nd m (cons x (cons y more)))))
  265. (defn mset!
  266. "Sets a scalar value in a matrix at a specified position. Supports any number of matrix dimensions.
  267. Will throw an error if the matrix is not mutable."
  268. ([m v]
  269. (if (mp/is-scalar? m)
  270. (error "Can't set a scalar value!")
  271. (error "Can't mset! without indexes on array of dimensionality: " (dimensionality m))))
  272. ([m x v]
  273. (mp/set-1d m x v))
  274. ([m x y v]
  275. (mp/set-2d m x y v))
  276. ([m x y z & more]
  277. (mp/set-nd m (cons x (cons y (cons z (butlast more)))) (last more))))
  278. (defn get-row
  279. "Gets a row of a 2D matrix.
  280. May return a mutable view if supported by the implementation."
  281. ([m x]
  282. (mp/get-row m x)))
  283. (defn get-column
  284. "Gets a column of a 2D matrix.
  285. May return a mutable view if supported by the implementation."
  286. ([m y]
  287. (mp/get-column m y)))
  288. (defn coerce
  289. "Coerces param to a format usable by a specific matrix implementation.
  290. If param is already in a format deemed usable by the implementation, returns it unchanged."
  291. ([m param]
  292. (or
  293. (mp/coerce-param m param)
  294. (mp/coerce-param m (mp/convert-to-nested-vectors param)))))
  295. ;; =====================================
  296. ;; matrix slicing and views
  297. (defn sub-matrix
  298. "Gets a view of a submatrix, for a set of index-ranges.
  299. Index ranges should be a sequence of [start, length] pairs.
  300. Index ranges can be nil (gets the whole range) "
  301. ([m index-ranges]
  302. (TODO)))
  303. (defn sub-vector
  304. "Gets a view of part of a vector. The view maintains a reference to the original,
  305. so can be used to modify the original vector if it is mutable."
  306. ([m start length]
  307. (TODO)))
  308. (defn slice
  309. "Gets a view of a slice of a matrix along a specific dimension.
  310. The returned matrix will have one less dimension.
  311. Slicing a 1D vector will return a scalar.
  312. Slicing on the first dimension (dimension 0) is likely to perform better
  313. for many matrix implementations."
  314. ([m index]
  315. (mp/get-slice m 0 index))
  316. ([m dimension index]
  317. (mp/get-slice m dimension index)))
  318. (defn slices
  319. "Gets a lazy sequence of slices of a matrix. If dimension is supplied, slices along a given dimension,
  320. otherwise slices along the first dimension."
  321. ([m]
  322. (mp/get-major-slice-seq m))
  323. ([m dimension]
  324. (map #(mp/get-slice m dimension %) (range (mp/dimension-count m dimension)))))
  325. (defn main-diagonal
  326. "Returns the main diagonal of a matrix or general array, as a vector"
  327. ([m]
  328. (mp/main-diagonal m)))
  329. (defn rotate
  330. "Rotates an array along specified dimensions"
  331. ([m dimension shift-amount]
  332. (TODO))
  333. ([m [shifts]]
  334. (TODO)))
  335. ;; ====================================
  336. ;; structural change operations
  337. (defn broadcast
  338. "Broadcasts a matrix to a specified shape"
  339. ([m shape]
  340. (mp/broadcast m shape)))
  341. (defn transpose
  342. "Transposes a 2D matrix"
  343. ([m]
  344. (mp/transpose m)))
  345. (defn transpose!
  346. "Transposes a square 2D matrix in-place"
  347. ([m]
  348. ;; TODO: implement with a proper protocol
  349. (assign! m (transpose m))))
  350. (defn reshape
  351. "Changes the shape of a matrix to the specified new shape. shape can be any sequence of dimension sizes.
  352. Preserves the row-major order of matrix elements."
  353. ([m shape]
  354. (mp/reshape m shape)))
  355. ;; ======================================
  356. ;; matrix comparisons
  357. (defn equals
  358. "Returns true if two matrices are numerically equal."
  359. ([a b]
  360. (mp/matrix-equals a b))
  361. ([a b epsilon]
  362. (TODO)))
  363. ;; ======================================
  364. ;; matrix maths / operations
  365. (defn mul
  366. "Performs matrix multiplication with matrices, vectors or scalars"
  367. ([a] a)
  368. ([a b]
  369. (cond
  370. (scalar? b) (if (scalar? a) (* a b) (mp/scale a b))
  371. (scalar? a) (mp/pre-scale b a)
  372. :else (mp/matrix-multiply a b)))
  373. ([a b & more]
  374. (reduce mul (mul a b) more)))
  375. (defn emul
  376. "Performs element-wise matrix multiplication. Matrices must be same size."
  377. ([a] a)
  378. ([a b]
  379. (mp/element-multiply a b))
  380. ([a b & more]
  381. (reduce mp/element-multiply (mp/element-multiply a b) more)))
  382. (defn emul!
  383. "Performs in-place element-wise matrix multiplication."
  384. ([a] a)
  385. ([a b]
  386. (TODO))
  387. ([a b & more]
  388. (TODO)))
  389. (defn transform
  390. "Transforms a given vector, returning a new vector"
  391. ([m v] (mp/vector-transform m v)))
  392. (defn transform!
  393. "Transforms a given vector in place"
  394. ([m v] (mp/vector-transform! m v)))
  395. (defn add
  396. "Performs element-wise matrix addition on one or more matrices."
  397. ([a] a)
  398. ([a b]
  399. (mp/matrix-add a b))
  400. ([a b & more]
  401. (reduce mp/matrix-add (mp/matrix-add a b) more)))
  402. (defn sub
  403. "Performs element-wise matrix subtraction on one or more matrices."
  404. ([a] a)
  405. ([a b]
  406. (mp/matrix-sub a b))
  407. ([a b & more]
  408. (reduce mp/matrix-sub (mp/matrix-sub a b) more)))
  409. (defn scale
  410. "Scales a matrix by a scalar factor"
  411. ([m factor]
  412. (mp/scale m factor)))
  413. (defn scale!
  414. "Scales a matrix by a scalar factor (in place)"
  415. ([m factor]
  416. (mp/scale! m factor)))
  417. (defn normalise
  418. "Normalises a matrix (scales to unit length)"
  419. ([m]
  420. (mp/normalise m)))
  421. (defn normalise!
  422. "Normalises a matrix in-place (scales to unit length).
  423. Returns the modified vector."
  424. ([m]
  425. (mp/normalise! m)))
  426. (defn dot
  427. "Computes the dot product (inner product) of two vectors"
  428. ([a b]
  429. (mp/vector-dot a b)))
  430. (defn det
  431. "Calculates the determinant of a matrix"
  432. ([a]
  433. (mp/determinant a)))
  434. (defn trace
  435. "Calculates the trace of a matrix (sum of elements on main diagonal)"
  436. ([a]
  437. (mp/trace a)))
  438. (defn length
  439. "Calculates the length (magnitude) of a vector"
  440. ([m]
  441. (mp/length m)))
  442. (defn length-squared
  443. "Calculates the squared length (squared magnitude) of a vector"
  444. ([m]
  445. (mp/length-squared m)))
  446. (defn sum
  447. "Calculates the sum of all the elements"
  448. [m]
  449. (mp/sum m))
  450. ;; create all unary maths operators
  451. (eval
  452. `(do ~@(map (fn [[name func]]
  453. `(defn ~name
  454. ([~'m]
  455. (~(symbol "core.matrix.protocols" (str name)) ~'m)))) mops/maths-ops)
  456. ~@(map (fn [[name func]]
  457. `(defn ~(symbol (str name "!"))
  458. ([~'m]
  459. (~(symbol "core.matrix.protocols" (str name "!")) ~'m)))) mops/maths-ops))
  460. )
  461. ;; ====================================
  462. ;; functional operations
  463. (defn ecount
  464. "Returns the total count of elements in an array"
  465. ([m]
  466. (cond
  467. (array? m) (reduce * 1 (shape m))
  468. :else (count m))))
  469. (defn eseq
  470. "Returns all elements of an array as a sequence in row-major order"
  471. ([m]
  472. (mp/element-seq m)))
  473. (defn ereduce
  474. "Element-wise reduce on all elements of an array."
  475. ([f m]
  476. (mp/element-reduce m f))
  477. ([f init m]
  478. (mp/element-reduce m f init)))
  479. (defn emap
  480. "Element-wise map over all elements of one or more arrays.
  481. Returns a new array of the same type and shape."
  482. ([f m]
  483. (mp/element-map m f))
  484. ([f m a]
  485. (mp/element-map m f a))
  486. ([f m a & more]
  487. (mp/element-map m f a more)))
  488. (defn emap!
  489. "Element-wise map over all elements of one or more arrays.
  490. Performs in-place modification of the first array argument."
  491. ([f m]
  492. (mp/element-map! m f))
  493. ([f m a]
  494. (mp/element-map! m f a))
  495. ([f m a & more]
  496. (mp/element-map! m f a more)))
  497. (defn index-seq-for-shape [sh]
  498. "Returns a sequence of all possible index vectors for a given shape, in row-major order"
  499. (let [gen (fn gen [prefix rem]
  500. (if rem
  501. (let [nrem (next rem)]
  502. (mapcat #(gen (conj prefix %) nrem) (range (first rem))))
  503. (list prefix)))]
  504. (gen [] (seq sh))))
  505. (defn index-seq [m]
  506. "Returns a sequence of all possible index vectors in a matrix, in row-major order"
  507. (index-seq-for-shape (shape m)))
  508. ;; ============================================================
  509. ;; Default implementations
  510. ;; - default behaviour for java.lang.Number scalars
  511. ;; - for stuff we don't recognise (java.lang.Object) we should try to
  512. ;; implement in terms of simpler operations, on assumption that
  513. ;; we have fallen through to the default implementation
  514. ;; default implementation for matrix ops
  515. (extend-protocol mp/PIndexedAccess
  516. java.util.List
  517. (get-1d [m x]
  518. (.get m (int x)))
  519. (get-2d [m x y]
  520. (mp/get-1d (.get m (int x)) y))
  521. (get-nd [m indexes]
  522. (if-let [s (seq indexes)]
  523. (mp/get-nd (.get m (int (first s))) (next s))
  524. m))
  525. java.lang.Object
  526. (get-1d [m x] (mp/get-nd m [x]))
  527. (get-2d [m x y] (mp/get-nd m [x y]))
  528. (get-nd [m indexes]
  529. (if (seq indexes)
  530. (error "Indexed get failed, not defined for:" (class m))
  531. (if (scalar? m) m
  532. (error "Not a scalar, cannot do zero dimensional get")))))
  533. (extend-protocol mp/PVectorOps
  534. java.lang.Number
  535. (vector-dot [a b] (* a b))
  536. (length [a] (double a))
  537. (length-squared [a] (Math/sqrt (double a)))
  538. (normalise [a]
  539. (let [a (double a)]
  540. (cond
  541. (> a 0.0) 1.0
  542. (< a 0.0) -1.0
  543. :else 0.0)))
  544. java.lang.Object
  545. (vector-dot [a b])
  546. (length [a]
  547. (Math/sqrt (double (mp/length-squared a))))
  548. (length-squared [a]
  549. (ereduce (fn [r x] (+ r (* x x))) 0 a))
  550. (normalise [a]
  551. (scale a (/ 1.0 (Math/sqrt (double (mp/length-squared a)))))))
  552. (extend-protocol mp/PMutableVectorOps
  553. java.lang.Object
  554. (normalise! [a]
  555. (scale! a (/ 1.0 (Math/sqrt (double (mp/length-squared a)))))))
  556. (extend-protocol mp/PAssignment
  557. java.lang.Object
  558. (assign! [m x]
  559. (cond
  560. (mp/is-vector? x)
  561. (dotimes [i (row-count m)]
  562. (mset! m i (mget x i)))
  563. (array? x)
  564. (doall (map (fn [a b] (mp/assign! a b))
  565. (slices m)
  566. (slices x)))
  567. (.isArray (class x))
  568. (mp/assign-array! m x)
  569. :else
  570. (error "Can't assign to a non-matrix object: " (class m))))
  571. (assign-array!
  572. ([m arr]
  573. (let [alen (long (count arr))]
  574. (if (mp/is-vector? m)
  575. (dotimes [i alen]
  576. (mp/set-1d m i (nth arr i)))
  577. (mp/assign-array! m arr 0 alen))))
  578. ([m arr start length]
  579. (let [length (long length)
  580. start (long start)]
  581. (if (mp/is-vector? m)
  582. (dotimes [i length]
  583. (mp/set-1d m i (nth arr (+ start i))))
  584. (let [ss (seq (slices m))
  585. skip (long (if ss (ecount (first (slices m))) 0))]
  586. (doseq-indexed [s ss i]
  587. (mp/assign-array! s arr (* skip i) skip))))))))
  588. (extend-protocol mp/PMatrixCloning
  589. java.lang.Cloneable
  590. (clone [m]
  591. (.invoke ^java.lang.reflect.Method (.getDeclaredMethod (class m) "clone" nil) m nil))
  592. java.lang.Object
  593. (clone [m]
  594. (coerce m (coerce [] m))))
  595. (extend-protocol mp/PDimensionInfo
  596. nil
  597. (dimensionality [m] 0)
  598. (is-scalar? [m] true)
  599. (is-vector? [m] false)
  600. (get-shape [m] [])
  601. (dimension-count [m i] (error "cannot get dimension count from nil"))
  602. java.lang.Number
  603. (dimensionality [m] 0)
  604. (is-scalar? [m] true)
  605. (is-vector? [m] false)
  606. (get-shape [m] [])
  607. (dimension-count [m i] (error "java.lang.Number has zero dimensionality, cannot get dimension count"))
  608. java.lang.Object
  609. (dimensionality [m] 0)
  610. (is-vector? [m] (== 1 (mp/dimensionality m)))
  611. (is-scalar? [m] false)
  612. (get-shape [m] (for [i (range (mp/dimensionality m))] (mp/dimension-count m i)))
  613. (dimension-count [m i] (error "Can't determine count of dimension " i " on Object: " (class m))))
  614. ;; generic versions of matrix ops
  615. (extend-protocol mp/PMatrixOps
  616. java.lang.Object
  617. (trace [m]
  618. (when-not (square? m) (error "Can't compute trace of non-square matrix"))
  619. (let [dims (long (row-count m))]
  620. (loop [i 0 res 0.0]
  621. (if (>= i dims)
  622. res
  623. (recur (inc i) (+ res (double (mp/get-2d m i i))))))))
  624. (negate [m]
  625. (mp/scale m -1.0))
  626. (length-squared [m]
  627. (ereduce #(+ %1 (* %2 *2)) 0.0 m))
  628. (length [m]
  629. (Math/sqrt (mp/length-squared m)))
  630. (transpose [m]
  631. (case (long (dimensionality m))
  632. 0 m
  633. 1 m
  634. 2 (coerce m (vec (apply map vector (map #(coerce [] %) (slices m)))))
  635. (error "Don't know how to transpose matrix of dimensionality: " m))))
  636. ;; matrix multiply
  637. (extend-protocol mp/PMatrixMultiply
  638. java.lang.Number
  639. (element-multiply [m a]
  640. (clojure.core/* m a))
  641. (matrix-multiply [m a]
  642. (cond
  643. (number? a) (* m a)
  644. (matrix? a) (mp/pre-scale a m)
  645. :else (error "Don't know how to multiply number with: " (class a))))
  646. java.lang.Object
  647. (matrix-multiply [m a]
  648. (coerce m (mp/matrix-multiply (coerce [] m) (coerce [] a))))
  649. (element-multiply [m a]
  650. (emap clojure.core/* m a)))
  651. ;; matrix element summation
  652. (extend-protocol mp/PSummable
  653. java.lang.Number
  654. (sum [a] a)
  655. java.lang.Object
  656. (sum [a]
  657. (mp/element-reduce a +)))
  658. ;; matrix element summation
  659. (extend-protocol mp/PTypeInfo
  660. java.lang.Number
  661. (element-type [a] (class a))
  662. java.lang.Object
  663. (element-type [a]
  664. (if (mp/is-scalar? a)
  665. (class a)
  666. (class (first (eseq a))))))
  667. ;; general transformation of a vector
  668. (extend-protocol mp/PVectorTransform
  669. clojure.lang.IFn
  670. (vector-transform [m a]
  671. (m a))
  672. (vector-transform! [m a]
  673. (assign! a (m a)))
  674. java.lang.Object
  675. (vector-transform [m a]
  676. (cond
  677. (matrix? m) (mul m a)
  678. :else (error "Don't know how to transform using: " (class m))))
  679. (vector-transform! [m a]
  680. (assign! a (mp/vector-transform m a))))
  681. ;; matrix scaling
  682. (extend-protocol mp/PMatrixScaling
  683. java.lang.Number
  684. (scale [m a]
  685. (if (number? a)
  686. (* m a)
  687. (mp/pre-scale a m)))
  688. (pre-scale [m a]
  689. (if (number? a)
  690. (* a m)
  691. (mp/scale a m)))
  692. java.lang.Object
  693. (scale [m a]
  694. (emap #(* % a) m))
  695. (pre-scale [m a]
  696. (emap (partial * a) m)))
  697. (extend-protocol mp/PMatrixMutableScaling
  698. java.lang.Number
  699. (scale! [m a]
  700. (error "Can't scale! a numeric value: " m))
  701. (pre-scale! [m a]
  702. (error "Can't pre-scale! a numeric value: " m))
  703. java.lang.Object
  704. (scale! [m a]
  705. (emap! #(* % a) m))
  706. (pre-scale! [m a]
  707. (emap! (partial * a) m)))
  708. (extend-protocol mp/PMatrixAdd
  709. ;; matrix add for scalars
  710. java.lang.Number
  711. (matrix-add [m a]
  712. (if (number? a) (+ m a) (error "Can't add scalar number to a matrix")))
  713. (matrix-sub [m a]
  714. (if (number? a) (- m a) (error "Can't a matrix from a scalar number")))
  715. ;; default impelementation - assume we can use emap?
  716. java.lang.Object
  717. (matrix-add [m a]
  718. (emap + m a))
  719. (matrix-sub [m a]
  720. (emap - m a)))
  721. ;; equality checking
  722. (extend-protocol mp/PMatrixEquality
  723. java.lang.Number
  724. (matrix-equals [a b]
  725. (== a b))
  726. java.lang.Object
  727. (matrix-equals [a b]
  728. (not (some false? (map == (mp/element-seq a) (mp/element-seq b))))))
  729. (extend-protocol mp/PDoubleArrayOutput
  730. java.lang.Number
  731. (to-double-array [m] (aset (double-array 1) 0 (double m)))
  732. (as-double-array [m] nil)
  733. java.lang.Object
  734. (to-double-array [m]
  735. (double-array (eseq m)))
  736. (as-double-array [m] nil))
  737. ;; functional operations
  738. (extend-protocol mp/PFunctionalOperations
  739. java.lang.Number
  740. (element-seq [m]
  741. (list m))
  742. (element-map
  743. ([m f]
  744. (f m))
  745. ([m f a]
  746. (f m a))
  747. ([m f a more]
  748. (apply f m a more)))
  749. (element-map!
  750. ([m f]
  751. (error "java.lang.Number instance is not mutable!"))
  752. ([m f a]
  753. (error "java.lang.Number instance is not mutable!"))
  754. ([m f a more]
  755. (error "java.lang.Number instance is not mutable!")))
  756. (element-reduce
  757. ([m f]
  758. m)
  759. ([m f init]
  760. (f init m)))
  761. java.lang.Object
  762. (element-seq [m]
  763. (cond
  764. (array? m) (mapcat mp/element-seq (slices m))
  765. :else (seq m)))
  766. (element-map
  767. ([m f]
  768. (coerce m (mp/element-map (mp/convert-to-nested-vectors m) f)))
  769. ([m f a]
  770. (coerce m (mp/element-map (mp/convert-to-nested-vectors m) f a)))
  771. ([m f a more]
  772. (coerce m (mp/element-map (mp/convert-to-nested-vectors m) f a more))))
  773. (element-map!
  774. ([m f]
  775. (assign! m (mp/element-map m f)))
  776. ([m f a]
  777. (assign! m (mp/element-map m f a)))
  778. ([m f a more]
  779. (assign! m (mp/element-map m f a more))))
  780. (element-reduce
  781. ([m f]
  782. (coerce m (mp/element-reduce (mp/convert-to-nested-vectors m) f)))
  783. ([m f init]
  784. (coerce m (mp/element-reduce (mp/convert-to-nested-vectors m) f init))))
  785. nil
  786. (element-seq [m] nil)
  787. (element-map
  788. ([m f] nil)
  789. ([m f a] nil)
  790. ([m f a more] nil))
  791. (element-map!
  792. ([m f] nil)
  793. ([m f a] nil)
  794. ([m f a more] nil))
  795. (element-reduce
  796. ([m f] (f))
  797. ([m f init] init)))
  798. ;; TODO: return a view object by default for matrix slices
  799. (extend-protocol mp/PMatrixSlices
  800. java.lang.Object
  801. (get-row [m i]
  802. (mp/get-major-slice m i))
  803. (get-column [m i]
  804. (mp/get-slice m 1 i))
  805. (get-major-slice [m i]
  806. (coerce m ((coerce [] m) i)))
  807. (get-slice [m dimension i]
  808. (coerce m (mp/get-slice (coerce [] m) dimension i))))
  809. (extend-protocol mp/PSliceView
  810. java.lang.Object
  811. ;; default implementation uses a lightweight wrapper object
  812. (get-major-slice-view [m i] (core.matrix.impl.wrappers/wrap-slice m i)))
  813. (extend-protocol mp/PSliceSeq
  814. java.lang.Object
  815. (get-major-slice-seq [m]
  816. (let [sc (try (mp/dimension-count m 0) (catch Throwable t (error "No dimensionality for getting slices: " (class m))))]
  817. (if (== 1 (dimensionality m))
  818. (for [i (range sc)] (mp/get-1d m i))
  819. (map #(mp/get-major-slice m %) (range sc))))))
  820. ;; attempt conversion to nested vectors
  821. (extend-protocol mp/PConversion
  822. java.lang.Number
  823. (convert-to-nested-vectors [m]
  824. ;; we accept a scalar as a "nested vector" for these purposes?
  825. m)
  826. java.lang.Object
  827. (convert-to-nested-vectors [m]
  828. (cond
  829. (scalar? m) m
  830. (mp/is-vector? m)
  831. (mapv #(mget m %) (range (row-count m)))
  832. (array? m)
  833. (mapv mp/convert-to-nested-vectors (slices m))
  834. (sequential? m)
  835. (mapv mp/convert-to-nested-vectors m)
  836. (seq? m)
  837. (mapv mp/convert-to-nested-vectors m)
  838. :default
  839. (error "Can't work out how to convert to nested vectors: " (class m) " = " m))))
  840. (extend-protocol mp/PReshaping
  841. java.lang.Number
  842. (reshape [m shape]
  843. (compute-matrix shape (constantly m)))
  844. java.lang.Object
  845. (reshape [m shape]
  846. (let [partition-shape (fn partition-shape [es shape]
  847. (if-let [s (seq shape)]
  848. (let [ns (next s)
  849. plen (reduce * 1 ns)]
  850. (map #(partition-shape % ns) (partition plen es)))
  851. (first es)))]
  852. (if-let [shape (seq shape)]
  853. (let [fs (long (first shape))
  854. parts (partition-shape (mp/element-seq m) shape)]
  855. (when-not (<= fs (count parts))
  856. (error "Reshape not possible: insufficient elements for shape: " shape " have: " (seq parts)))
  857. (array m (take fs parts)))
  858. (first (mp/element-seq m))))))
  859. (extend-protocol mp/PCoercion
  860. java.lang.Object
  861. (coerce-param [m param]
  862. (mp/construct-matrix m (mp/convert-to-nested-vectors param))))
  863. ;; define standard Java maths functions for numbers
  864. (eval
  865. `(extend-protocol mp/PMathsFunctions
  866. java.lang.Number
  867. ~@(map (fn [[name func]]
  868. `(~name [~'m] (double (~func (double ~'m)))))
  869. mops/maths-ops)
  870. java.lang.Object
  871. ~@(map (fn [[name func]]
  872. `(~name [~'m] (emap #(double (~func (double %))) ~'m)))
  873. mops/maths-ops)
  874. ~@(map (fn [[name func]]
  875. `(~(symbol (str name "!")) [~'m] (emap! #(double (~func (double %))) ~'m)))
  876. mops/maths-ops)))
  877. (extend-protocol mp/PMatrixSubComponents
  878. java.lang.Object
  879. (main-diagonal [m]
  880. (let [sh (shape m)
  881. rank (count sh)
  882. dims (first sh)]
  883. (if-not (reduce = sh) (error "Not a square array!"))
  884. (matrix m (for [i (range dims)] (apply mget m (repeat rank i)))))))
  885. (extend-protocol mp/PSpecialisedConstructors
  886. java.lang.Object
  887. (identity-matrix [m dims]
  888. (diagonal-matrix (repeat dims 1.0)))
  889. (diagonal-matrix [m diagonal-values]
  890. (let [dims (count diagonal-values)
  891. diagonal-values (coerce [] diagonal-values)
  892. zs (vec (repeat dims 0.0))
  893. dm (vec (for [i (range dims)]
  894. (assoc zs i (nth diagonal-values i))))]
  895. (coerce m dm))))
  896. ;; =======================================================
  897. ;; default multimethod implementations
  898. (defmethod mm/mul :default [x y]
  899. (error "Don't know how to multiply " (class x) " with " (class y)))
  900. ;; =========================================================
  901. ;; Final implementation setup
  902. (defn current-implementation
  903. "Gets the currently active matrix implementation"
  904. ([] core.matrix/*matrix-implementation*))
  905. (defn current-implementation-object
  906. "Gets the currently active matrix implementation"
  907. ([] (imp/get-canonical-object (current-implementation))))
  908. (defn set-current-implementation
  909. "Sets the currently active matrix implementation"
  910. ([m]
  911. (alter-var-root (var core.matrix/*matrix-implementation*)
  912. (fn [_] (imp/get-implementation-key m)))))