/toolkits/graphical_models/deprecated/factors/discrete_domain.hpp
C++ Header | 313 lines | 206 code | 54 blank | 53 comment | 73 complexity | 9dff8d8dd2e13b2d4e463939e20e657a MD5 | raw file
Possible License(s): ISC, Apache-2.0
- /**
- * Copyright (c) 2009 Carnegie Mellon University.
- * All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an "AS
- * IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied. See the License for the specific language
- * governing permissions and limitations under the License.
- *
- * For more about this software visit:
- *
- * http://www.graphlab.ml.cmu.edu
- *
- */
- #ifndef DISCRETE_DOMAIN_HPP
- #define DISCRETE_DOMAIN_HPP
- #include <graphlab/logger/assertions.hpp>
- #include "discrete_variable.hpp"
- #include <graphlab/macros_def.hpp>
- //! Forward declaration of the assignment type used by discrete_domain
- template<size_t MAX_DIM> class discrete_assignment;
- /**
- * This class represents a discrete domain over a set of variables.
- */
- template<size_t MAX_DIM>
- class discrete_domain {
- public:
- typedef discrete_assignment<MAX_DIM> assignment_type;
- //! Make an empy domain
- discrete_domain() : _num_vars(0) { }
- //! Make a single variable discrete_domain
- discrete_domain(const discrete_variable& v1) :
- _num_vars(1) {
- ASSERT_LE(_num_vars, MAX_DIM);
- _vars[0] = v1;
- }
- //! Make a two variable discrete_domain
- discrete_domain(const discrete_variable& v1, const discrete_variable& v2) :
- _num_vars(2) {
- ASSERT_LE(_num_vars, MAX_DIM);
- assert(v1 != v2);
- if(v1 < v2) {
- _vars[0] = v1;
- _vars[1] = v2;
- } else {
- _vars[0] = v2;
- _vars[1] = v1;
- }
- }
- //! Make a three variable discrete_domain
- discrete_domain(const discrete_variable& v1,
- const discrete_variable& v2,
- const discrete_variable& v3) :
- _num_vars(3) {
- ASSERT_LE(_num_vars, MAX_DIM);
- ASSERT_NE(v1, v2);
- ASSERT_NE(v2, v3);
- ASSERT_NE(v1, v3);
-
- if(v1 < v2 && v2 < v3) {
- _vars[0] = v1;
- _vars[1] = v2;
- _vars[2] = v3;
- } else if( v1 < v3 && v3 < v2) {
- _vars[0] = v1;
- _vars[1] = v3;
- _vars[2] = v2;
- } else if( v2 < v1 && v1 < v3) {
- _vars[0] = v2;
- _vars[1] = v1;
- _vars[2] = v3;
- } else if( v2 < v3 && v3 < v1) {
- _vars[0] = v2;
- _vars[1] = v3;
- _vars[2] = v1;
- } else if( v3 < v1 && v1 < v2) {
- _vars[0] = v3;
- _vars[1] = v1;
- _vars[2] = v2;
- } else if( v3 < v1 && v1 < v2) {
- _vars[0] = v3;
- _vars[1] = v1;
- _vars[2] = v2;
- } else { throw("Invalid Case!"); }
- }
- //! Make a discrete_domain from a vector of variables
- discrete_domain(const std::vector<discrete_variable>& variables) :
- _num_vars(variables.size()) {
- ASSERT_LE(_num_vars, MAX_DIM);
- for(size_t i = 0; i < _num_vars; ++i)
- _vars[i] = variables[i];
- std::sort(_vars, _vars + std::min(MAX_DIM, _num_vars) );
- }
- //! Make a discrete_domain from a set of variables
- discrete_domain(const std::set<discrete_variable>& variables) :
- _num_vars(variables.size()) {
- ASSERT_LE(_num_vars, MAX_DIM);
- size_t i = 0;
- foreach(const discrete_variable& var, variables) _vars[i++] = var;
- }
-
- discrete_domain& operator+=(const discrete_variable& var) {
- if(_vars[_num_vars - 1] < var) {
- _vars[_num_vars] = var;
- _num_vars++;
- return *this;
- }
- return operator+=(discrete_domain(var));
- }
- //! add the other discrete_domain to this discrete_domain
- discrete_domain operator+(const discrete_variable& var) const {
- discrete_domain dom = *this;
- return dom += var;
- }
- //! add the discrete_domain to this discrete_domain
- discrete_domain& operator+=(const discrete_domain& other) {
- if(other.num_vars() == 0) return *this;
- discrete_domain backup = *this;
- _num_vars = 0;
- for(size_t i = 0, j = 0;
- i < backup.num_vars() || j < other.num_vars(); ) {
- ASSERT_LE(_num_vars, MAX_DIM);
- // Both
- if(i < backup.num_vars() && j < other.num_vars()
- && _num_vars < MAX_DIM) {
- if(backup.var(i) < other.var(j))
- _vars[_num_vars++] = backup.var(i++);
- else if(other.var(j) < backup.var(i))
- _vars[_num_vars++] = other.var(j++);
- else { _vars[_num_vars++] = backup.var(i++); j++; }
- } else if(i < backup.num_vars() && _num_vars < MAX_DIM) {
- _vars[_num_vars++] = backup.var(i++);
- } else if(j < other.num_vars() && _num_vars < MAX_DIM) {
- _vars[_num_vars++] = other.var(j++);
- } else {
- // Unreachable
- throw("Unreachable case in domain operator+=");
- }
- }
- return *this;
- }
-
- //! add the other discrete_domain to this discrete_domain
- discrete_domain operator+(const discrete_domain& other) const {
- discrete_domain dom = *this;
- return dom += other;
- }
-
- //! subtract the other discrete_domain from this discrete_domain
- discrete_domain& operator-=(const discrete_domain& other) {
- if(other.num_vars() == 0) return *this;
-
- size_t tmp_num_vars = 0;
- for(size_t i = 0, j = 0; i < _num_vars; ++i ) {
- // advance the other index
- for( ; j < other._num_vars && _vars[i].id() > other._vars[j].id(); ++j);
- if(!(j < other._num_vars && _vars[i].id() == other._vars[j].id())) {
- _vars[tmp_num_vars++] = _vars[i];
- }
- }
- _num_vars = tmp_num_vars;
- return *this;
- }
- //! subtract the other discrete_domain from this discrete_domain
- discrete_domain operator-(const discrete_domain& other) const {
- discrete_domain dom = *this;
- return dom -= other;
- }
- discrete_domain intersect(const discrete_domain& other) const {
- discrete_domain new_dom;
- new_dom._num_vars = 0;
- for(size_t i = 0, j = 0;
- i < num_vars() && j < other.num_vars(); ) {
- if(_vars[i] == other._vars[j]) {
- // new discrete_domain gets the variable
- new_dom._vars[new_dom._num_vars] = _vars[i];
- // Everyone advances
- new_dom._num_vars++; i++; j++;
- } else {
- // otherwise increment one of the variables
- if(_vars[i] < other._vars[j]) i++; else j++;
- }
- }
- return new_dom;
- }
-
- //! Get the number of variables
- size_t num_vars() const { return _num_vars; }
- //! Get the ith variable
- const discrete_variable& var(size_t index) const {
- ASSERT_LT(index, _num_vars);
- return _vars[index];
- }
- /** get the index of the variable or returns number of variables
- if the index is not found */
- size_t var_location(size_t var_id) const {
- size_t location = _num_vars;
- for(size_t i = 0; i < _num_vars && !(location < _num_vars); ++i) {
- if(_vars[i].id() == var_id) location = i;
- }
- return location;
- }
-
- //! determine the number of assignments
- size_t size() const {
- size_t sum = 0;
- if(_num_vars > 0) {
- sum = 1;
- for(size_t i = 0; i < _num_vars; ++i) {
- // Require variables to be sorted order
- if(i > 0) ASSERT_LT( _vars[ i-1], _vars[i] );
- // and have positive arity
- ASSERT_GT(_vars[i].size(), 0);
- sum *= _vars[i].size();
- }
- }
- return sum;
- }
- //! test whether two discrete_domains are equal
- bool operator==(const discrete_domain& other) const {
- if( num_vars() != other.num_vars() ) return false;
- for(size_t i = 0; i < num_vars(); ++i) {
- if(var(i) != other.var(i)) return false;
- }
- return true;
- }
-
- //! test whether two discrete_domains are not equal
- bool operator!=(const discrete_domain& other) const {
- return !(*this == other);
- }
- //! Get the first assignment in the discrete_domain
- assignment_type begin() const;
- //! Get the second assignment in the discrete_domain
- assignment_type end() const;
- void load(graphlab::iarchive& arc) {
- arc >> _num_vars;
- ASSERT_LE(_num_vars, MAX_DIM);
- for(size_t i = 0; i < _num_vars; ++i) arc >> _vars[i];
- }
-
- void save(graphlab::oarchive& arc) const {
- arc << _num_vars;
- for(size_t i = 0; i < _num_vars; ++i) arc << _vars[i];
- }
- private:
- size_t _num_vars;
- discrete_variable _vars[MAX_DIM];
- };
- template<size_t MAX_DIM>
- std::ostream& operator<<(std::ostream& out,
- const discrete_domain<MAX_DIM>& dom) {
- out << "{";
- for(size_t i = 0; i < dom.num_vars(); ++i) {
- out << dom.var(i);
- if( i < dom.num_vars()-1 ) out << ", ";
- }
- return out << "}";
- }
- #include <graphlab/macros_undef.hpp>
- #endif