diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 8d52785ac6e..51f8d495c37 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -60,6 +60,11 @@ class Solver { // RestoreSolverStateFrom___ protected methods. You should implement these // methods to restore the state from the appropriate snapshot type. void Restore(const char* resume_file); + // The Solver::Snapshot function implements the basic snapshotting utility + // that stores the learned net. You should implement the SnapshotSolverState() + // function that produces a SolverState protocol buffer that needs to be + // written to disk together with the learned net. + void Snapshot(); virtual ~Solver() {} inline const SolverParameter& param() const { return param_; } inline shared_ptr > net() { return net_; } @@ -87,11 +92,6 @@ class Solver { protected: // Make and apply the update value for the current iteration. virtual void ApplyUpdate() = 0; - // The Solver::Snapshot function implements the basic snapshotting utility - // that stores the learned net. You should implement the SnapshotSolverState() - // function that produces a SolverState protocol buffer that needs to be - // written to disk together with the learned net. - void Snapshot(); string SnapshotFilename(const string extension); string SnapshotToBinaryProto(); string SnapshotToHDF5(); diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index ccd5776ac40..6c2ccaa5794 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -286,7 +286,8 @@ BOOST_PYTHON_MODULE(_caffe) { .def("solve", static_cast::*)(const char*)>( &Solver::Solve), SolveOverloads()) .def("step", &Solver::Step) - .def("restore", &Solver::Restore); + .def("restore", &Solver::Restore) + .def("snapshot", &Solver::Snapshot); bp::class_, bp::bases >, shared_ptr >, boost::noncopyable>( diff --git a/python/caffe/test/test_solver.py b/python/caffe/test/test_solver.py index 9cfc10d29a9..f618fded8cd 100644 --- a/python/caffe/test/test_solver.py +++ b/python/caffe/test/test_solver.py @@ -16,7 +16,8 @@ def setUp(self): f.write("""net: '""" + net_f + """' test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9 weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75 - display: 100 max_iter: 100 snapshot_after_train: false""") + display: 100 max_iter: 100 snapshot_after_train: false + snapshot_prefix: "model" """) f.close() self.solver = caffe.SGDSolver(f.name) # also make sure get_solver runs @@ -51,3 +52,11 @@ def test_net_memory(self): total += p.data.sum() + p.diff.sum() for bl in six.itervalues(net.blobs): total += bl.data.sum() + bl.diff.sum() + + def test_snapshot(self): + self.solver.snapshot() + # Check that these files exist and then remove them + files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate'] + for fn in files: + assert os.path.isfile(fn) + os.remove(fn)