/python/caffe/_caffe.cpp

https://github.com/songlu/caffe · C++ · 349 lines · 256 code · 57 blank · 36 comment · 26 complexity · af08d9faefb91cc75db27f5c855aed1b MD5 · raw file

  1. // Copyright 2014 BVLC and contributors.
  2. // pycaffe provides a wrapper of the caffe::Net class as well as some
  3. // caffe::Caffe functions so that one could easily call it from Python.
  4. // Note that for Python, we will simply use float as the data type.
  5. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
  6. #include "boost/python.hpp"
  7. #include "boost/python/suite/indexing/vector_indexing_suite.hpp"
  8. #include "numpy/arrayobject.h"
  9. // these need to be included after boost on OS X
  10. #include <string> // NOLINT(build/include_order)
  11. #include <vector> // NOLINT(build/include_order)
  12. #include <fstream> // NOLINT
  13. #include "caffe/caffe.hpp"
  14. // Temporary solution for numpy < 1.7 versions: old macro, no promises.
  15. // You're strongly advised to upgrade to >= 1.7.
  16. #ifndef NPY_ARRAY_C_CONTIGUOUS
  17. #define NPY_ARRAY_C_CONTIGUOUS NPY_C_CONTIGUOUS
  18. #define PyArray_SetBaseObject(arr, x) (PyArray_BASE(arr) = (x))
  19. #endif
  20. using namespace caffe; // NOLINT(build/namespaces)
  21. using boost::python::extract;
  22. using boost::python::len;
  23. using boost::python::list;
  24. using boost::python::object;
  25. using boost::python::handle;
  26. using boost::python::vector_indexing_suite;
  // For convenience, check that an input file can be opened, and raise an
  // exception that boost will send to Python if not. Caffe could still crash
  // later if the input file is disturbed before it is actually used, but this
  // saves frustration in most cases.
  //
  // filename: path of the file to probe.
  // Throws std::runtime_error("Could not open file <filename>") when the
  // stream is not in a good state after the open attempt.
  static void CheckFile(const std::string& filename) {
    // The ifstream releases the file when it goes out of scope (including on
    // the throwing path), so no explicit close() calls are required.
    std::ifstream f(filename.c_str());
    if (!f.good()) {
      throw std::runtime_error("Could not open file " + filename);
    }
  }
  39. // wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
  40. // to Python
  41. class CaffeBlob {
  42. public:
  43. CaffeBlob(const shared_ptr<Blob<float> > &blob, const string& name)
  44. : blob_(blob), name_(name) {}
  45. string name() const { return name_; }
  46. int num() const { return blob_->num(); }
  47. int channels() const { return blob_->channels(); }
  48. int height() const { return blob_->height(); }
  49. int width() const { return blob_->width(); }
  50. int count() const { return blob_->count(); }
  51. // this is here only to satisfy boost's vector_indexing_suite
  52. bool operator == (const CaffeBlob &other) {
  53. return this->blob_ == other.blob_;
  54. }
  55. protected:
  56. shared_ptr<Blob<float> > blob_;
  57. string name_;
  58. };
  59. // We need another wrapper (used as boost::python's HeldType) that receives a
  60. // self PyObject * which we can use as ndarray.base, so that data/diff memory
  61. // is not freed while still being used in Python.
  62. class CaffeBlobWrap : public CaffeBlob {
  63. public:
  64. CaffeBlobWrap(PyObject *p, const CaffeBlob &blob)
  65. : CaffeBlob(blob), self_(p) {}
  66. object get_data() {
  67. npy_intp dims[] = {num(), channels(), height(), width()};
  68. PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
  69. blob_->mutable_cpu_data());
  70. PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
  71. Py_INCREF(self_);
  72. handle<> h(obj);
  73. return object(h);
  74. }
  75. object get_diff() {
  76. npy_intp dims[] = {num(), channels(), height(), width()};
  77. PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
  78. blob_->mutable_cpu_diff());
  79. PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
  80. Py_INCREF(self_);
  81. handle<> h(obj);
  82. return object(h);
  83. }
  84. private:
  85. PyObject *self_;
  86. };
  87. class CaffeLayer {
  88. public:
  89. CaffeLayer(const shared_ptr<Layer<float> > &layer, const string &name)
  90. : layer_(layer), name_(name) {}
  91. string name() const { return name_; }
  92. vector<CaffeBlob> blobs() {
  93. vector<CaffeBlob> result;
  94. for (int i = 0; i < layer_->blobs().size(); ++i) {
  95. result.push_back(CaffeBlob(layer_->blobs()[i], name_));
  96. }
  97. return result;
  98. }
  99. // this is here only to satisfy boost's vector_indexing_suite
  100. bool operator == (const CaffeLayer &other) {
  101. return this->layer_ == other.layer_;
  102. }
  103. protected:
  104. shared_ptr<Layer<float> > layer_;
  105. string name_;
  106. };
  107. // A simple wrapper over CaffeNet that runs the forward process.
  108. struct CaffeNet {
  109. // For cases where parameters will be determined later by the Python user,
  110. // create a Net with unallocated parameters (which will not be zero-filled
  111. // when accessed).
  112. explicit CaffeNet(string param_file) {
  113. Init(param_file);
  114. }
  115. CaffeNet(string param_file, string pretrained_param_file) {
  116. Init(param_file);
  117. CheckFile(pretrained_param_file);
  118. net_->CopyTrainedLayersFrom(pretrained_param_file);
  119. }
  120. explicit CaffeNet(shared_ptr<Net<float> > net)
  121. : net_(net) {}
  122. void Init(string param_file) {
  123. CheckFile(param_file);
  124. net_.reset(new Net<float>(param_file));
  125. }
  126. virtual ~CaffeNet() {}
  127. // Generate Python exceptions for badly shaped or discontiguous arrays.
  128. inline void check_contiguous_array(PyArrayObject* arr, string name,
  129. int channels, int height, int width) {
  130. if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
  131. throw std::runtime_error(name + " must be C contiguous");
  132. }
  133. if (PyArray_NDIM(arr) != 4) {
  134. throw std::runtime_error(name + " must be 4-d");
  135. }
  136. if (PyArray_TYPE(arr) != NPY_FLOAT32) {
  137. throw std::runtime_error(name + " must be float32");
  138. }
  139. if (PyArray_DIMS(arr)[1] != channels) {
  140. throw std::runtime_error(name + " has wrong number of channels");
  141. }
  142. if (PyArray_DIMS(arr)[2] != height) {
  143. throw std::runtime_error(name + " has wrong height");
  144. }
  145. if (PyArray_DIMS(arr)[3] != width) {
  146. throw std::runtime_error(name + " has wrong width");
  147. }
  148. }
  149. void Forward() {
  150. net_->ForwardPrefilled();
  151. }
  152. void Backward() {
  153. net_->Backward();
  154. }
  155. void set_input_arrays(object data_obj, object labels_obj) {
  156. // check that this network has an input MemoryDataLayer
  157. shared_ptr<MemoryDataLayer<float> > md_layer =
  158. boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]);
  159. if (!md_layer) {
  160. throw std::runtime_error("set_input_arrays may only be called if the"
  161. " first layer is a MemoryDataLayer");
  162. }
  163. // check that we were passed appropriately-sized contiguous memory
  164. PyArrayObject* data_arr =
  165. reinterpret_cast<PyArrayObject*>(data_obj.ptr());
  166. PyArrayObject* labels_arr =
  167. reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
  168. check_contiguous_array(data_arr, "data array", md_layer->datum_channels(),
  169. md_layer->datum_height(), md_layer->datum_width());
  170. check_contiguous_array(labels_arr, "labels array", 1, 1, 1);
  171. if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
  172. throw std::runtime_error("data and labels must have the same first"
  173. " dimension");
  174. }
  175. if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
  176. throw std::runtime_error("first dimensions of input arrays must be a"
  177. " multiple of batch size");
  178. }
  179. // hold references
  180. input_data_ = data_obj;
  181. input_labels_ = labels_obj;
  182. md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)),
  183. static_cast<float*>(PyArray_DATA(labels_arr)),
  184. PyArray_DIMS(data_arr)[0]);
  185. }
  186. // The caffe::Caffe utility functions.
  187. void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
  188. void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
  189. void set_phase_train() { Caffe::set_phase(Caffe::TRAIN); }
  190. void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
  191. void set_device(int device_id) { Caffe::SetDevice(device_id); }
  192. vector<CaffeBlob> blobs() {
  193. vector<CaffeBlob> result;
  194. for (int i = 0; i < net_->blobs().size(); ++i) {
  195. result.push_back(CaffeBlob(net_->blobs()[i], net_->blob_names()[i]));
  196. }
  197. return result;
  198. }
  199. vector<CaffeLayer> layers() {
  200. vector<CaffeLayer> result;
  201. for (int i = 0; i < net_->layers().size(); ++i) {
  202. result.push_back(CaffeLayer(net_->layers()[i], net_->layer_names()[i]));
  203. }
  204. return result;
  205. }
  206. list inputs() {
  207. list input_blob_names;
  208. for (int i = 0; i < net_->input_blob_indices().size(); ++i) {
  209. input_blob_names.append(
  210. net_->blob_names()[net_->input_blob_indices()[i]]);
  211. }
  212. return input_blob_names;
  213. }
  214. list outputs() {
  215. list output_blob_names;
  216. for (int i = 0; i < net_->output_blob_indices().size(); ++i) {
  217. output_blob_names.append(
  218. net_->blob_names()[net_->output_blob_indices()[i]]);
  219. }
  220. return output_blob_names;
  221. }
  222. // The pointer to the internal caffe::Net instant.
  223. shared_ptr<Net<float> > net_;
  224. // if taking input from an ndarray, we need to hold references
  225. object input_data_;
  226. object input_labels_;
  227. };
  228. class CaffeSGDSolver {
  229. public:
  230. explicit CaffeSGDSolver(const string& param_file) {
  231. // as in CaffeNet, (as a convenience, not a guarantee), create a Python
  232. // exception if param_file can't be opened
  233. CheckFile(param_file);
  234. solver_.reset(new SGDSolver<float>(param_file));
  235. // we need to explicitly store the net wrapper, rather than constructing
  236. // it on the fly, so that it can hold references to Python objects
  237. net_.reset(new CaffeNet(solver_->net()));
  238. }
  239. shared_ptr<CaffeNet> net() { return net_; }
  240. void Solve() { return solver_->Solve(); }
  241. void SolveResume(const string& resume_file) {
  242. CheckFile(resume_file);
  243. return solver_->Solve(resume_file);
  244. }
  245. protected:
  246. shared_ptr<CaffeNet> net_;
  247. shared_ptr<SGDSolver<float> > solver_;
  248. };
  249. // The boost_python module definition.
  250. BOOST_PYTHON_MODULE(_caffe) {
  251. // below, we prepend an underscore to methods that will be replaced
  252. // in Python
  253. boost::python::class_<CaffeNet, shared_ptr<CaffeNet> >(
  254. "Net", boost::python::init<string, string>())
  255. .def(boost::python::init<string>())
  256. .def("_forward", &CaffeNet::Forward)
  257. .def("_backward", &CaffeNet::Backward)
  258. .def("set_mode_cpu", &CaffeNet::set_mode_cpu)
  259. .def("set_mode_gpu", &CaffeNet::set_mode_gpu)
  260. .def("set_phase_train", &CaffeNet::set_phase_train)
  261. .def("set_phase_test", &CaffeNet::set_phase_test)
  262. .def("set_device", &CaffeNet::set_device)
  263. .add_property("_blobs", &CaffeNet::blobs)
  264. .add_property("layers", &CaffeNet::layers)
  265. .add_property("inputs", &CaffeNet::inputs)
  266. .add_property("outputs", &CaffeNet::outputs)
  267. .def("_set_input_arrays", &CaffeNet::set_input_arrays);
  268. boost::python::class_<CaffeBlob, CaffeBlobWrap>(
  269. "Blob", boost::python::no_init)
  270. .add_property("name", &CaffeBlob::name)
  271. .add_property("num", &CaffeBlob::num)
  272. .add_property("channels", &CaffeBlob::channels)
  273. .add_property("height", &CaffeBlob::height)
  274. .add_property("width", &CaffeBlob::width)
  275. .add_property("count", &CaffeBlob::count)
  276. .add_property("data", &CaffeBlobWrap::get_data)
  277. .add_property("diff", &CaffeBlobWrap::get_diff);
  278. boost::python::class_<CaffeLayer>(
  279. "Layer", boost::python::no_init)
  280. .add_property("name", &CaffeLayer::name)
  281. .add_property("blobs", &CaffeLayer::blobs);
  282. boost::python::class_<CaffeSGDSolver, boost::noncopyable>(
  283. "SGDSolver", boost::python::init<string>())
  284. .add_property("net", &CaffeSGDSolver::net)
  285. .def("solve", &CaffeSGDSolver::Solve)
  286. .def("solve", &CaffeSGDSolver::SolveResume);
  287. boost::python::class_<vector<CaffeBlob> >("BlobVec")
  288. .def(vector_indexing_suite<vector<CaffeBlob>, true>());
  289. boost::python::class_<vector<CaffeLayer> >("LayerVec")
  290. .def(vector_indexing_suite<vector<CaffeLayer>, true>());
  291. import_array();
  292. }