PageRenderTime 43ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 1ms

/toolkits/graphical_models/deprecated/factors/discrete_domain.hpp

https://github.com/michaelkook/GraphLab-2
C++ Header | 313 lines | 206 code | 54 blank | 53 comment | 73 complexity | 9dff8d8dd2e13b2d4e463939e20e657a MD5 | raw file
Possible License(s): ISC, Apache-2.0
  1. /**
  2. * Copyright (c) 2009 Carnegie Mellon University.
  3. * All rights reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing,
  12. * software distributed under the License is distributed on an "AS
  13. * IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
  14. * express or implied. See the License for the specific language
  15. * governing permissions and limitations under the License.
  16. *
  17. * For more about this software visit:
  18. *
  19. * http://www.graphlab.ml.cmu.edu
  20. *
  21. */
  22. #ifndef DISCRETE_DOMAIN_HPP
  23. #define DISCRETE_DOMAIN_HPP
  24. #include <graphlab/logger/assertions.hpp>
  25. #include "discrete_variable.hpp"
  26. #include <graphlab/macros_def.hpp>
  /// Forward declaration of the assignment type: discrete_domain (below)
  /// names it as its assignment_type before the full definition is visible.
  template<size_t MAX_DIM>
  class discrete_assignment;
  29. /**
  30. * This class respresents a discrete discrete_domain over a set of variables.
  31. */
  32. template<size_t MAX_DIM>
  33. class discrete_domain {
  34. public:
  35. typedef discrete_assignment<MAX_DIM> assignment_type;
  36. //! Make an empy domain
  37. discrete_domain() : _num_vars(0) { }
  38. //! Make a single variable discrete_domain
  39. discrete_domain(const discrete_variable& v1) :
  40. _num_vars(1) {
  41. ASSERT_LE(_num_vars, MAX_DIM);
  42. _vars[0] = v1;
  43. }
  44. //! Make a two variable discrete_domain
  45. discrete_domain(const discrete_variable& v1, const discrete_variable& v2) :
  46. _num_vars(2) {
  47. ASSERT_LE(_num_vars, MAX_DIM);
  48. assert(v1 != v2);
  49. if(v1 < v2) {
  50. _vars[0] = v1;
  51. _vars[1] = v2;
  52. } else {
  53. _vars[0] = v2;
  54. _vars[1] = v1;
  55. }
  56. }
  57. //! Make a three variable discrete_domain
  58. discrete_domain(const discrete_variable& v1,
  59. const discrete_variable& v2,
  60. const discrete_variable& v3) :
  61. _num_vars(3) {
  62. ASSERT_LE(_num_vars, MAX_DIM);
  63. ASSERT_NE(v1, v2);
  64. ASSERT_NE(v2, v3);
  65. ASSERT_NE(v1, v3);
  66. if(v1 < v2 && v2 < v3) {
  67. _vars[0] = v1;
  68. _vars[1] = v2;
  69. _vars[2] = v3;
  70. } else if( v1 < v3 && v3 < v2) {
  71. _vars[0] = v1;
  72. _vars[1] = v3;
  73. _vars[2] = v2;
  74. } else if( v2 < v1 && v1 < v3) {
  75. _vars[0] = v2;
  76. _vars[1] = v1;
  77. _vars[2] = v3;
  78. } else if( v2 < v3 && v3 < v1) {
  79. _vars[0] = v2;
  80. _vars[1] = v3;
  81. _vars[2] = v1;
  82. } else if( v3 < v1 && v1 < v2) {
  83. _vars[0] = v3;
  84. _vars[1] = v1;
  85. _vars[2] = v2;
  86. } else if( v3 < v1 && v1 < v2) {
  87. _vars[0] = v3;
  88. _vars[1] = v1;
  89. _vars[2] = v2;
  90. } else { throw("Invalid Case!"); }
  91. }
  92. //! Make a discrete_domain from a vector of variables
  93. discrete_domain(const std::vector<discrete_variable>& variables) :
  94. _num_vars(variables.size()) {
  95. ASSERT_LE(_num_vars, MAX_DIM);
  96. for(size_t i = 0; i < _num_vars; ++i)
  97. _vars[i] = variables[i];
  98. std::sort(_vars, _vars + std::min(MAX_DIM, _num_vars) );
  99. }
  100. //! Make a discrete_domain from a set of variables
  101. discrete_domain(const std::set<discrete_variable>& variables) :
  102. _num_vars(variables.size()) {
  103. ASSERT_LE(_num_vars, MAX_DIM);
  104. size_t i = 0;
  105. foreach(const discrete_variable& var, variables) _vars[i++] = var;
  106. }
  107. discrete_domain& operator+=(const discrete_variable& var) {
  108. if(_vars[_num_vars - 1] < var) {
  109. _vars[_num_vars] = var;
  110. _num_vars++;
  111. return *this;
  112. }
  113. return operator+=(discrete_domain(var));
  114. }
  115. //! add the other discrete_domain to this discrete_domain
  116. discrete_domain operator+(const discrete_variable& var) const {
  117. discrete_domain dom = *this;
  118. return dom += var;
  119. }
  120. //! add the discrete_domain to this discrete_domain
  121. discrete_domain& operator+=(const discrete_domain& other) {
  122. if(other.num_vars() == 0) return *this;
  123. discrete_domain backup = *this;
  124. _num_vars = 0;
  125. for(size_t i = 0, j = 0;
  126. i < backup.num_vars() || j < other.num_vars(); ) {
  127. ASSERT_LE(_num_vars, MAX_DIM);
  128. // Both
  129. if(i < backup.num_vars() && j < other.num_vars()
  130. && _num_vars < MAX_DIM) {
  131. if(backup.var(i) < other.var(j))
  132. _vars[_num_vars++] = backup.var(i++);
  133. else if(other.var(j) < backup.var(i))
  134. _vars[_num_vars++] = other.var(j++);
  135. else { _vars[_num_vars++] = backup.var(i++); j++; }
  136. } else if(i < backup.num_vars() && _num_vars < MAX_DIM) {
  137. _vars[_num_vars++] = backup.var(i++);
  138. } else if(j < other.num_vars() && _num_vars < MAX_DIM) {
  139. _vars[_num_vars++] = other.var(j++);
  140. } else {
  141. // Unreachable
  142. throw("Unreachable case in domain operator+=");
  143. }
  144. }
  145. return *this;
  146. }
  147. //! add the other discrete_domain to this discrete_domain
  148. discrete_domain operator+(const discrete_domain& other) const {
  149. discrete_domain dom = *this;
  150. return dom += other;
  151. }
  152. //! subtract the other discrete_domain from this discrete_domain
  153. discrete_domain& operator-=(const discrete_domain& other) {
  154. if(other.num_vars() == 0) return *this;
  155. size_t tmp_num_vars = 0;
  156. for(size_t i = 0, j = 0; i < _num_vars; ++i ) {
  157. // advance the other index
  158. for( ; j < other._num_vars && _vars[i].id() > other._vars[j].id(); ++j);
  159. if(!(j < other._num_vars && _vars[i].id() == other._vars[j].id())) {
  160. _vars[tmp_num_vars++] = _vars[i];
  161. }
  162. }
  163. _num_vars = tmp_num_vars;
  164. return *this;
  165. }
  166. //! subtract the other discrete_domain from this discrete_domain
  167. discrete_domain operator-(const discrete_domain& other) const {
  168. discrete_domain dom = *this;
  169. return dom -= other;
  170. }
  171. discrete_domain intersect(const discrete_domain& other) const {
  172. discrete_domain new_dom;
  173. new_dom._num_vars = 0;
  174. for(size_t i = 0, j = 0;
  175. i < num_vars() && j < other.num_vars(); ) {
  176. if(_vars[i] == other._vars[j]) {
  177. // new discrete_domain gets the variable
  178. new_dom._vars[new_dom._num_vars] = _vars[i];
  179. // Everyone advances
  180. new_dom._num_vars++; i++; j++;
  181. } else {
  182. // otherwise increment one of the variables
  183. if(_vars[i] < other._vars[j]) i++; else j++;
  184. }
  185. }
  186. return new_dom;
  187. }
  188. //! Get the number of variables
  189. size_t num_vars() const { return _num_vars; }
  190. //! Get the ith variable
  191. const discrete_variable& var(size_t index) const {
  192. ASSERT_LT(index, _num_vars);
  193. return _vars[index];
  194. }
  195. /** get the index of the variable or returns number of variables
  196. if the index is not found */
  197. size_t var_location(size_t var_id) const {
  198. size_t location = _num_vars;
  199. for(size_t i = 0; i < _num_vars && !(location < _num_vars); ++i) {
  200. if(_vars[i].id() == var_id) location = i;
  201. }
  202. return location;
  203. }
  204. //! determine the number of assignments
  205. size_t size() const {
  206. size_t sum = 0;
  207. if(_num_vars > 0) {
  208. sum = 1;
  209. for(size_t i = 0; i < _num_vars; ++i) {
  210. // Require variables to be sorted order
  211. if(i > 0) ASSERT_LT( _vars[ i-1], _vars[i] );
  212. // and have positive arity
  213. ASSERT_GT(_vars[i].size(), 0);
  214. sum *= _vars[i].size();
  215. }
  216. }
  217. return sum;
  218. }
  219. //! test whether two discrete_domains are equal
  220. bool operator==(const discrete_domain& other) const {
  221. if( num_vars() != other.num_vars() ) return false;
  222. for(size_t i = 0; i < num_vars(); ++i) {
  223. if(var(i) != other.var(i)) return false;
  224. }
  225. return true;
  226. }
  227. //! test whether two discrete_domains are not equal
  228. bool operator!=(const discrete_domain& other) const {
  229. return !(*this == other);
  230. }
  231. //! Get the first assignment in the discrete_domain
  232. assignment_type begin() const;
  233. //! Get the second assignment in the discrete_domain
  234. assignment_type end() const;
  235. void load(graphlab::iarchive& arc) {
  236. arc >> _num_vars;
  237. ASSERT_LE(_num_vars, MAX_DIM);
  238. for(size_t i = 0; i < _num_vars; ++i) arc >> _vars[i];
  239. }
  240. void save(graphlab::oarchive& arc) const {
  241. arc << _num_vars;
  242. for(size_t i = 0; i < _num_vars; ++i) arc << _vars[i];
  243. }
  244. private:
  245. size_t _num_vars;
  246. discrete_variable _vars[MAX_DIM];
  247. };
  248. template<size_t MAX_DIM>
  249. std::ostream& operator<<(std::ostream& out,
  250. const discrete_domain<MAX_DIM>& dom) {
  251. out << "{";
  252. for(size_t i = 0; i < dom.num_vars(); ++i) {
  253. out << dom.var(i);
  254. if( i < dom.num_vars()-1 ) out << ", ";
  255. }
  256. return out << "}";
  257. }
  258. #include <graphlab/macros_undef.hpp>
  259. #endif