diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index a28d8cb897e..4bcf0fcd0fa 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -80,6 +80,10 @@ class Solver { virtual void on_start() = 0; virtual void on_gradients_ready() = 0; + virtual void on_test_start(int test_net_id) {} + virtual void on_test_end(int test_net_id) {} + virtual void on_test_iter_start(int test_net_id, int iter) {} + template friend class Solver; }; diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 044269371ad..c3a10dedc9e 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -331,11 +331,20 @@ void Solver::Test(const int test_net_id) { << ", Testing net (#" << test_net_id << ")"; CHECK_NOTNULL(test_nets_[test_net_id].get())-> ShareTrainedLayersWith(net_.get()); + + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_test_start(test_net_id); + } + vector test_score; vector test_score_output_id; const shared_ptr >& test_net = test_nets_[test_net_id]; Dtype loss = 0; for (int i = 0; i < param_.test_iter(test_net_id); ++i) { + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_test_iter_start(test_net_id, i); + } + SolverAction::Enum request = GetRequestedAction(); // Check to see if stoppage of testing/training has been requested. while (request != SolverAction::NONE) { @@ -375,6 +384,11 @@ void Solver::Test(const int test_net_id) { } } } + + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_test_end(test_net_id); + } + if (requested_early_exit_) { LOG(INFO) << "Test interrupted."; return;