/tags/R2001-10-10/octave-forge/main/image/conv2.cc

# · C++ · 301 lines · 205 code · 39 blank · 57 comment · 57 complexity · 6b4eb40fb1d684608ea16f208b504e59 MD5 · raw file

  1. /*
  2. * conv2: 2D convolution for octave
  3. *
  4. * Copyright (C) 1999 Andy Adler
  5. * This code has no warranty whatsoever.
  6. * Do what you like with this code as long as you
  7. * leave this copyright in place.
  8. *
  9. * $Id: conv2.cc 2 2001-10-10 19:54:49Z pkienzle $
  10. ## 2000-05-17: Paul Kienzle
  11. ## * change argument to vector conversion to work for 2.1 series octave
  12. ## as well as 2.0 series
  13. ## 2001-02-05: Paul Kienzle
  14. ## * accept complex arguments
  15. */
  16. #include <octave/oct.h>
  17. #define MAX(a,b) ((a) > (b) ? (a) : (b))
  18. #define SHAPE_FULL 1
  19. #define SHAPE_SAME 2
  20. #define SHAPE_VALID 3
/*
 * Non-template prototypes matching the separable conv2 template for
 * double and Complex elements.  They are declared only when the
 * configuration macro CXX_NEW_FRIEND_TEMPLATE_DECL is not defined.
 * NOTE(review): presumably a workaround for compilers with older
 * template declaration rules; the matching explicit instantiations are
 * at the end of this file -- confirm against the Octave build config.
 */
#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
extern MArray2<double>
conv2 (MArray<double>&, MArray<double>&, MArray2<double>&, int);
extern MArray2<Complex>
conv2 (MArray<Complex>&, MArray<Complex>&, MArray2<Complex>&, int);
#endif
  27. template <class T>
  28. MArray2<T>
  29. conv2 (MArray<T>& R, MArray<T>& C, MArray2<T>& A, int ishape)
  30. {
  31. int Rn= R.length();
  32. int Cm= C.length();
  33. int Am = A.rows();
  34. int An = A.columns();
  35. /*
  36. * Here we calculate the size of the output matrix,
  37. * in order to stay Matlab compatible, it is based
  38. * on the third parameter if its separable, and the
  39. * first if it's not
  40. */
  41. int outM, outN, edgM, edgN;
  42. if ( ishape == SHAPE_FULL ) {
  43. outM= Am + Cm - 1;
  44. outN= An + Rn - 1;
  45. edgM= Cm - 1;
  46. edgN= Rn - 1;
  47. } else if ( ishape == SHAPE_SAME ) {
  48. outM= Am;
  49. outN= An;
  50. // Matlab seems to arbitrarily choose this convention for
  51. // 'same' with even length R, C
  52. edgM= ( Cm - 1) /2;
  53. edgN= ( Rn - 1) /2;
  54. } else if ( ishape == SHAPE_VALID ) {
  55. outM= Am - Cm + 1;
  56. outN= An - Rn + 1;
  57. edgM= edgN= 0;
  58. }
  59. // printf("A(%d,%d) C(%d) R(%d) O(%d,%d) E(%d,%d)\n",
  60. // Am,An, Cm,Rn, outM, outN, edgM, edgN);
  61. MArray2<T> O(outM,outN);
  62. /*
  63. * T accumulated the 1-D conv for each row, before calculating
  64. * the convolution in the other direction
  65. * There is no efficiency advantage to doing it in either direction
  66. * first
  67. */
  68. MArray<T> X( An );
  69. for( int oi=0; oi < outM; oi++ ) {
  70. for( int oj=0; oj < An; oj++ ) {
  71. T sum=0;
  72. int ci= Cm - 1 - MAX(0, edgM-oi);
  73. int ai= MAX(0, oi-edgM) ;
  74. const T* Ad= A.data() + ai + Am*oj;
  75. const T* Cd= C.data() + ci;
  76. for( ; ci >= 0 && ai < Am;
  77. ci--, Cd--, ai++, Ad++) {
  78. sum+= (*Ad) * (*Cd);
  79. } // for( int ci=
  80. X(oj)= sum;
  81. } // for( int oj=0
  82. for( int oj=0; oj < outN; oj++ ) {
  83. T sum=0;
  84. int rj= Rn - 1 - MAX(0, edgN-oj);
  85. int aj= MAX(0, oj-edgN) ;
  86. const T* Xd= X.data() + aj;
  87. const T* Rd= R.data() + rj;
  88. for( ; rj >= 0 && aj < An;
  89. rj--, Rd--, aj++, Xd++) {
  90. sum+= (*Xd) * (*Rd);
  91. } //for( int rj=
  92. O(oi,oj)= sum;
  93. } // for( int oj=0
  94. } // for( int oi=0
  95. return O;
  96. }
/*
 * Non-template prototypes matching the general (two-matrix) conv2
 * template for double and Complex elements.  They are declared only
 * when CXX_NEW_FRIEND_TEMPLATE_DECL is not defined.
 * NOTE(review): presumably a workaround for compilers with older
 * template declaration rules; confirm against the Octave build config.
 */
#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
extern MArray2<double>
conv2 (MArray2<double>&, MArray2<double>&, int);
extern MArray2<Complex>
conv2 (MArray2<Complex>&, MArray2<Complex>&, int);
#endif
  103. template <class T>
  104. MArray2<T>
  105. conv2 (MArray2<T>&A, MArray2<T>&B, int ishape)
  106. {
  107. /* Convolution works fastest if we choose the A matrix to be
  108. * the largest.
  109. *
  110. * Here we calculate the size of the output matrix,
  111. * in order to stay Matlab compatible, it is based
  112. * on the third parameter if its separable, and the
  113. * first if it's not
  114. *
  115. * NOTE in order to be Matlab compatible, we give
  116. * wrong sizes for 'valid' if the smallest matrix is first
  117. */
  118. int Am = A.rows();
  119. int An = A.columns();
  120. int Bm = B.rows();
  121. int Bn = B.columns();
  122. int outM, outN, edgM, edgN;
  123. if ( ishape == SHAPE_FULL ) {
  124. outM= Am + Bm - 1;
  125. outN= An + Bn - 1;
  126. edgM= Bm - 1;
  127. edgN= Bn - 1;
  128. } else if ( ishape == SHAPE_SAME ) {
  129. outM= Am;
  130. outN= An;
  131. // Matlab seems to arbitrarily choose this convention for
  132. // 'same' with even length R, C
  133. edgM= ( Bm - 1) /2;
  134. edgN= ( Bn - 1) /2;
  135. } else if ( ishape == SHAPE_VALID ) {
  136. outM= Am - Bm + 1;
  137. outN= An - Bn + 1;
  138. edgM= edgN= 0;
  139. }
  140. // printf("A(%d,%d) B(%d,%d) O(%d,%d) E(%d,%d)\n",
  141. // Am,An, Bm,Bn, outM, outN, edgM, edgN);
  142. MArray2<T> O(outM,outN);
  143. for( int oi=0; oi < outM; oi++ ) {
  144. for( int oj=0; oj < outN; oj++ ) {
  145. T sum=0;
  146. for( int bj= Bn - 1 - MAX(0, edgN-oj),
  147. aj= MAX(0, oj-edgN);
  148. bj >= 0 && aj < An;
  149. bj--, aj++) {
  150. int bi= Bm - 1 - MAX(0, edgM-oi);
  151. int ai= MAX(0, oi-edgM);
  152. const T* Ad= A.data() + ai + Am*aj;
  153. const T* Bd= B.data() + bi + Bm*bj;
  154. for( ; bi >= 0 && ai < Am;
  155. bi--, Bd--, ai++, Ad++) {
  156. sum+= (*Ad) * (*Bd);
  157. /*
  158. * It seems to be about 2.5 times faster to use pointers than
  159. * to do this
  160. * sum+= A(ai,aj) * B(bi,bj);
  161. */
  162. } // for( int bi=
  163. } //for( int bj=
  164. O(oi,oj)= sum;
  165. } // for( int oj=
  166. } // for( int oi=
  167. return O;
  168. }
  169. DEFUN_DLD (conv2, args, ,
  170. "[...] = conv2 (...)
  171. CONV2: do 2 dimensional convolution
  172. c= conv2(a,b) -> same as c= conv2(a,b,'full')
  173. c= conv2(a,b,shape) returns 2-D convolution of a and b
  174. where the size of c is given by
  175. shape= 'full' -> returns full 2-D convolution
  176. shape= 'same' -> same size as a. 'central' part of convolution
  177. shape= 'valid' -> only parts which do not include zero-padded edges
  178. c= conv2(a,b,shape) returns 2-D convolution of a and b
  179. c= conv2(v1,v2,a) -> same as c= conv2(v1,v2,a,'full')
  180. c= conv2(v1,v2,a,shape) returns convolution of a by vector v1
  181. in the column direction and vector v2 in the row direction ")
  182. {
  183. octave_value_list retval;
  184. octave_value tmp;
  185. int nargin = args.length ();
  186. string shape= "full";
  187. bool separable= false;
  188. int ishape;
  189. if (nargin < 2 ) {
  190. print_usage ("conv2");
  191. return retval;
  192. } else if (nargin == 3) {
  193. if ( args(2).is_string() )
  194. shape= args(2).string_value();
  195. else
  196. separable= true;
  197. } else if (nargin >= 4) {
  198. separable= true;
  199. shape= args(3).string_value();
  200. }
  201. if ( shape == "full" ) ishape = SHAPE_FULL;
  202. else if ( shape == "same" ) ishape = SHAPE_SAME;
  203. else if ( shape == "valid" ) ishape = SHAPE_VALID;
  204. else { // if ( shape
  205. error("Shape type not valid");
  206. print_usage ("conv2");
  207. return retval;
  208. }
  209. if (separable) {
  210. /*
  211. * Check that the first two parameters are vectors
  212. * if we're doing separable
  213. */
  214. if ( !( 1== args(0).rows() || 1== args(0).columns() ) ||
  215. !( 1== args(1).rows() || 1== args(1).columns() ) ) {
  216. print_usage ("conv2");
  217. return retval;
  218. }
  219. if (args(0).is_complex_type() || args(1).is_complex_type()
  220. || args(2).is_complex_type()) {
  221. ComplexColumnVector v1 (args(0).complex_vector_value());
  222. ComplexColumnVector v2 (args(1).complex_vector_value());
  223. ComplexMatrix a (args(2).complex_matrix_value());
  224. ComplexMatrix c(conv2(v1, v2, a, ishape));
  225. retval(0) = c;
  226. } else {
  227. ColumnVector v1 (args(0).vector_value());
  228. ColumnVector v2 (args(1).vector_value());
  229. Matrix a (args(2).matrix_value());
  230. Matrix c(conv2(v1, v2, a, ishape));
  231. retval(0) = c;
  232. }
  233. } else { // if (separable)
  234. if (args(0).is_complex_type() || args(1).is_complex_type()) {
  235. ComplexMatrix a (args(0).complex_matrix_value());
  236. ComplexMatrix b (args(1).complex_matrix_value());
  237. ComplexMatrix c(conv2(a, b, ishape));
  238. retval(0) = c;
  239. } else {
  240. Matrix a (args(0).matrix_value());
  241. Matrix b (args(1).matrix_value());
  242. Matrix c(conv2(a, b, ishape));
  243. retval(0) = c;
  244. }
  245. } // if (separable)
  246. return retval;
  247. }
/*
 * Explicit instantiations of both conv2 templates (separable and
 * general) for the two element types the DEFUN_DLD wrapper uses:
 * double (Matrix/ColumnVector) and Complex (ComplexMatrix/
 * ComplexColumnVector).
 */
template MArray2<double>
conv2 (MArray<double>&, MArray<double>&, MArray2<double>&, int);
template MArray2<double>
conv2 (MArray2<double>&, MArray2<double>&, int);
template MArray2<Complex>
conv2 (MArray<Complex>&, MArray<Complex>&, MArray2<Complex>&, int);
template MArray2<Complex>
conv2 (MArray2<Complex>&, MArray2<Complex>&, int);