#### /matrix/arithmetic.go

Go | 205 lines | 152 code | 24 blank | 29 comment | 42 complexity | 3c60dab8ef2063c122dad789d006c667 MD5 | raw file
```  1// Copyright 2009 The GoMatrix Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
4
5package matrix
6
import (
	"math"
	"runtime"
	"sync"
)
8
9/*
10Finds the sum of two matrices.
11*/
12func Sum(A MatrixRO, Bs ...MatrixRO) (C *DenseMatrix) {
13	C = MakeDenseCopy(A)
14	var err error
15	for _, B := range Bs {
17		if err != nil {
18			break
19		}
20	}
21	if err != nil {
22		C = nil
23	}
24	return
25}
26
27/*
28Finds the difference between two matrices.
29*/
30func Difference(A, B MatrixRO) (C *DenseMatrix) {
31	C = MakeDenseCopy(A)
32	err := C.Subtract(MakeDenseCopy(B))
33	if err != nil {
34		C = nil
35	}
36	return
37}
38
39/*
40Finds the Product of two matrices.
41*/
42func Product(A MatrixRO, Bs ...MatrixRO) (C *DenseMatrix) {
43	C = MakeDenseCopy(A)
44
45	for _, B := range Bs {
46		Cm, err := C.Times(B)
47		if err != nil {
48			return
49		}
50		C = Cm.(*DenseMatrix)
51	}
52
53	return
54}
55
56func Transpose(A MatrixRO) (B Matrix) {
57	switch Am := A.(type) {
58	case *DenseMatrix:
59		B = Am.Transpose()
60		return
61	case *SparseMatrix:
62		B = Am.Transpose()
63		return
64	}
65	B = A.DenseMatrix().Transpose()
66	return
67}
68
69func Inverse(A MatrixRO) (B Matrix) {
70	var err error
71	switch Am := A.(type) {
72	case *DenseMatrix:
73		B, err = Am.Inverse()
74		if err != nil {
75			panic(err)
76		}
77		return
78	}
79	B, err = A.DenseMatrix().Inverse()
80	if err != nil {
81		panic(err)
82	}
83	return
84}
85
86/*
87The Kronecker product. (http://en.wikipedia.org/wiki/Kronecker_product)
88*/
89func Kronecker(A, B MatrixRO) (C *DenseMatrix) {
90	ars, acs := A.Rows(), A.Cols()
91	brs, bcs := B.Rows(), B.Cols()
92	C = Zeros(ars*brs, acs*bcs)
93	for i := 0; i < ars; i++ {
94		for j := 0; j < acs; j++ {
95			Cij := C.GetMatrix(i*brs, j*bcs, brs, bcs)
96			Cij.SetMatrix(0, 0, Scaled(B, A.Get(i, j)))
97		}
98	}
99	return
100}
101
102func Vectorize(Am MatrixRO) (V *DenseMatrix) {
103	elems := Am.DenseMatrix().Transpose().Array()
104	V = MakeDenseMatrix(elems, Am.Rows()*Am.Cols(), 1)
105	return
106}
107
108func Unvectorize(V MatrixRO, rows, cols int) (A *DenseMatrix) {
109	A = MakeDenseMatrix(V.DenseMatrix().Array(), cols, rows).Transpose()
110	return
111}
112
113/*
114Uses a number of goroutines to do the dot products necessary
115for the matrix multiplication in parallel.
116*/
117func ParallelProduct(A, B MatrixRO) (C *DenseMatrix) {
118	if A.Cols() != B.Rows() {
119		return nil
120	}
121
122	C = Zeros(A.Rows(), B.Cols())
123
124	in := make(chan int)
125	quit := make(chan bool)
126
127	dotRowCol := func() {
128		for {
129			select {
130			case i := <-in:
131				sums := make([]float64, B.Cols())
132				for k := 0; k < A.Cols(); k++ {
133					for j := 0; j < B.Cols(); j++ {
134						sums[j] += A.Get(i, k) * B.Get(k, j)
135					}
136				}
137				for j := 0; j < B.Cols(); j++ {
138					C.Set(i, j, sums[j])
139				}
140			case <-quit:
141				return
142			}
143		}
144	}
145
147
148	for i := 0; i < threads; i++ {
149		go dotRowCol()
150	}
151
152	for i := 0; i < A.Rows(); i++ {
153		in <- i
154	}
155
156	for i := 0; i < threads; i++ {
157		quit <- true
158	}
159
160	return
161}
162
163/*
164Scales a matrix by a scalar.
165*/
166func Scaled(A MatrixRO, f float64) (B *DenseMatrix) {
167	B = MakeDenseCopy(A)
168	B.Scale(f)
169	return
170}
171
172/*
173Tests the element-wise equality of the two matrices.
174*/
175func Equals(A, B MatrixRO) bool {
176	if A.Rows() != B.Rows() || A.Cols() != B.Cols() {
177		return false
178	}
179	for i := 0; i < A.Rows(); i++ {
180		for j := 0; j < A.Cols(); j++ {
181			if A.Get(i, j) != B.Get(i, j) {
182				return false
183			}
184		}
185	}
186	return true
187}
188
189/*
190Tests to see if the difference between two matrices,
191element-wise, exceeds ε.
192*/
193func ApproxEquals(A, B MatrixRO, ε float64) bool {
194	if A.Rows() != B.Rows() || A.Cols() != B.Cols() {
195		return false
196	}
197	for i := 0; i < A.Rows(); i++ {
198		for j := 0; j < A.Cols(); j++ {
199			if math.Abs(A.Get(i, j)-B.Get(i, j)) > ε {
200				return false
201			}
202		}
203	}
204	return true
205}
```