diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 2510de748de..4dcdc3dc20b 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -27,6 +27,10 @@ class Solver { virtual void Solve(const char* resume_file = NULL); inline void Solve(const string resume_file) { Solve(resume_file.c_str()); } void Step(int iters); + // The Restore function implements how one should restore the solver to a + // previously snapshotted state. You should implement the RestoreSolverState() + // function that restores the state from a SolverState protocol buffer. + void Restore(const char* resume_file); virtual ~Solver() {} inline shared_ptr > net() { return net_; } inline const vector > >& test_nets() { @@ -46,10 +50,6 @@ class Solver { void TestAll(); void Test(const int test_net_id = 0); virtual void SnapshotSolverState(SolverState* state) = 0; - // The Restore function implements how one should restore the solver to a - // previously snapshotted state. You should implement the RestoreSolverState() - // function that restores the state from a SolverState protocol buffer. - void Restore(const char* resume_file); virtual void RestoreSolverState(const SolverState& state) = 0; void DisplayOutputBlobs(const int net_id); diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index bfea0de661b..dff7f627016 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -261,7 +261,8 @@ BOOST_PYTHON_MODULE(_caffe) { .add_property("iter", &Solver::iter) .def("solve", static_cast::*)(const char*)>( &Solver::Solve), SolveOverloads()) - .def("step", &Solver::Step); + .def("step", &Solver::Step) + .def("restore", &Solver::Restore); bp::class_, bp::bases >, shared_ptr >, boost::noncopyable>(