From 1ce3380f172336cadaa649a6e077a42a246a534d Mon Sep 17 00:00:00 2001 From: Mohamed Omran Date: Sat, 20 Sep 2014 19:01:28 +0200 Subject: [PATCH 1/3] Implement AdaDelta; add test cases; add mnist examples --- examples/mnist/lenet_adadelta_solver.prototxt | 22 +++ .../mnist_autoencoder_solver_adadelta.prototxt | 17 ++ examples/mnist/train_mnist_autoencoder_adadelta.sh | 4 + include/caffe/solver.hpp | 23 +++ src/caffe/proto/caffe.proto | 1 + src/caffe/solver.cpp | 199 +++++++++++++++++++++ src/caffe/test/test_gradient_based_solver.cpp | 100 ++++++++++- 7 files changed, 364 insertions(+), 2 deletions(-) create mode 100644 examples/mnist/lenet_adadelta_solver.prototxt create mode 100644 examples/mnist/mnist_autoencoder_solver_adadelta.prototxt create mode 100755 examples/mnist/train_mnist_autoencoder_adadelta.sh diff --git a/examples/mnist/lenet_adadelta_solver.prototxt b/examples/mnist/lenet_adadelta_solver.prototxt new file mode 100644 index 00000000000..b77b451d56a --- /dev/null +++ b/examples/mnist/lenet_adadelta_solver.prototxt @@ -0,0 +1,22 @@ +# The train/test net protocol buffer definition +net: "examples/mnist/lenet_train_test.prototxt" +# test_iter specifies how many forward passes the test should carry out. +# In the case of MNIST, we have test batch size 100 and 100 test iterations, +# covering the full 10,000 testing images. +test_iter: 100 +# Carry out testing every 500 training iterations. +test_interval: 500 +# The base learning rate, momentum and the weight decay of the network. +momentum: 0.95 +weight_decay: 0.0005 +# Display every 100 iterations +display: 100 +# The maximum number of iterations +max_iter: 10000 +# snapshot intermediate results +snapshot: 5000 +snapshot_prefix: "examples/mnist/lenet_adadelta" +# solver mode: CPU or GPU +solver_mode: GPU +solver_type: ADADELTA +delta: 1e-6 diff --git a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt new file mode 100644 index 00000000000..cc4f0bbb4a7 --- /dev/null +++ b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt @@ -0,0 +1,17 @@ +net: "examples/mnist/mnist_autoencoder.prototxt" +test_state: { stage: 'test-on-train' } +test_iter: 500 +test_state: { stage: 'test-on-test' } +test_iter: 100 +test_interval: 500 +test_compute_loss: true +momentum: 0.95 +display: 100 +max_iter: 65000 +weight_decay: 0.0005 +snapshot: 10000 +snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train" +# solver mode: CPU or GPU +solver_mode: GPU +solver_type: ADADELTA +delta: 1e-8 diff --git a/examples/mnist/train_mnist_autoencoder_adadelta.sh b/examples/mnist/train_mnist_autoencoder_adadelta.sh new file mode 100755 index 00000000000..4be0ebddedc --- /dev/null +++ b/examples/mnist/train_mnist_autoencoder_adadelta.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +./build/tools/caffe train \ + --solver=examples/mnist/mnist_autoencoder_solver_adadelta.prototxt diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index fbade9389ff..4b408380119 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -159,6 +159,27 @@ class RMSPropSolver : public SGDSolver { }; template +class AdaDeltaSolver : public SGDSolver { + public: + explicit AdaDeltaSolver(const SolverParameter& param) + : SGDSolver(param) { constructor_sanity_check(); } + explicit AdaDeltaSolver(const string& param_file) + : SGDSolver(param_file) { constructor_sanity_check(); } + + protected: + virtual void PreSolve(); + virtual void ComputeUpdateValue(); + void constructor_sanity_check() { + CHECK_EQ(0, this->param_.base_lr()) + << "Learning rate cannot be used with AdaDelta."; + CHECK_EQ("", this->param_.lr_policy()) + << "Learning rate policy cannot be applied to AdaDelta."; + } + + DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); +}; + +template Solver* GetSolver(const SolverParameter& param) { SolverParameter_SolverType type = param.solver_type(); @@ -171,6 +192,8 @@ Solver* GetSolver(const SolverParameter& param) { return new AdaGradSolver(param); case SolverParameter_SolverType_RMSPROP: return new RMSPropSolver(param); + case SolverParameter_SolverType_ADADELTA: + return new AdaDeltaSolver(param); default: LOG(FATAL) << "Unknown SolverType: " << type; } diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 89f14595ba6..7cfcaa8bac7 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -215,6 +215,7 @@ message SolverParameter { NESTEROV = 1; ADAGRAD = 2; RMSPROP = 3; + ADADELTA = 4; } optional SolverType solver_type = 30 [default = SGD]; // numerical stability for AdaGrad diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 54e085a63e5..d8749a1b939 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -934,10 +934,209 @@ void RMSPropSolver::ComputeUpdateValue(int param_id, Dtype rate) { } } +template +void AdaDeltaSolver::PreSolve() { + // Initialize the history + vector > >& net_params = this->net_->params(); + this->history_.clear(); + this->update_.clear(); + this->temp_.clear(); + for (int i = 0; i < net_params.size(); ++i) { + const Blob* net_param = net_params[i].get(); + this->history_.push_back(shared_ptr >(new Blob( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); + this->update_.push_back(shared_ptr >(new Blob( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); + this->temp_.push_back(shared_ptr >(new Blob( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); + } + for (int i = 0; i < net_params.size(); ++i) { + const Blob* net_param = net_params[i].get(); + this->history_.push_back(shared_ptr >(new Blob( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); + } +} + +template +void AdaDeltaSolver::ComputeUpdateValue() { + vector > >& net_params = this->net_->params(); + vector& net_params_weight_decay = this->net_->params_weight_decay(); + Dtype delta = this->param_.delta(); + Dtype momentum = this->param_.momentum(); + Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); + size_t update_history_offset = net_params.size(); + switch (Caffe::mode()) { + case Caffe::CPU: + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history of gradients + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->cpu_data(), momentum, + this->history_[param_id]->mutable_cpu_data()); + + // add delta to history to guard against dividing by zero later + caffe_set(net_params[param_id]->count(), delta, + this->temp_[param_id]->mutable_cpu_data()); + + caffe_add(net_params[param_id]->count(), + this->temp_[param_id]->cpu_data(), + this->history_[update_history_offset + param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add(net_params[param_id]->count(), + this->temp_[param_id]->cpu_data(), + this->history_[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + + // divide history of updates by history of gradients + caffe_div(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + this->temp_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // jointly compute the RMS of both for update and gradient history + caffe_powx(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + // compute the update + caffe_mul(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), + this->update_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + + // compute square of update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history of updates + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->cpu_data(), momentum, + this->history_[update_history_offset + param_id]->mutable_cpu_data()); + } + break; + case Caffe::GPU: +#ifndef CPU_ONLY + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history of gradients + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->gpu_data(), momentum, + this->history_[param_id]->mutable_gpu_data()); + + // add delta to history to guard against dividing by zero later + caffe_gpu_set(net_params[param_id]->count(), delta, + this->temp_[param_id]->mutable_gpu_data()); + + caffe_gpu_add(net_params[param_id]->count(), + this->temp_[param_id]->gpu_data(), + this->history_[update_history_offset + param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add(net_params[param_id]->count(), + this->temp_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + + // divide history of updates by history of gradients + caffe_gpu_div(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->temp_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // jointly compute the RMS of both for update and gradient history + caffe_gpu_powx(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + // compute the update and copy to net_diff + caffe_gpu_mul(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + + // compute square of update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history of updates + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->gpu_data(), momentum, + this->history_[update_history_offset + param_id]->mutable_gpu_data()); + } +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + INSTANTIATE_CLASS(Solver); INSTANTIATE_CLASS(SGDSolver); INSTANTIATE_CLASS(NesterovSolver); INSTANTIATE_CLASS(AdaGradSolver); INSTANTIATE_CLASS(RMSPropSolver); +INSTANTIATE_CLASS(AdaDeltaSolver); } // namespace caffe diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index eaa7a759b9b..db89e285a9f 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -64,7 +64,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { } InitSolver(param); delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD || - solver_type() == SolverParameter_SolverType_RMSPROP) ? + solver_type() == SolverParameter_SolverType_RMSPROP || + solver_type() == SolverParameter_SolverType_ADADELTA) ? param.delta() : 0; } @@ -164,6 +165,10 @@ class GradientBasedSolverTest : public MultiDeviceTest { " bottom: 'targets' " " } " "} "; + if (learning_rate != 0) { + proto << "base_lr: " << learning_rate << " "; + proto << "lr_policy: 'fixed' "; + } if (weight_decay != 0) { proto << "weight_decay: " << weight_decay << " "; } @@ -266,7 +271,11 @@ class GradientBasedSolverTest : public MultiDeviceTest { ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); // Finally, compute update. const vector > >& history = solver_->history(); - ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias + if (solver_type() != SolverParameter_SolverType_ADADELTA) { + ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias + } else { + ASSERT_EQ(4, history.size()); // additional blobs for update history + } Dtype update_value = learning_rate * grad; const Dtype history_value = (i == D) ? history[1]->cpu_data()[0] : history[0]->cpu_data()[i]; @@ -289,6 +298,19 @@ class GradientBasedSolverTest : public MultiDeviceTest { + grad * grad * (1 - rms_decay)) + delta_; } break; + case SolverParameter_SolverType_ADADELTA: + { + const Dtype update_history_value = (i == D) ? + history[3]->cpu_data()[0] : history[2]->cpu_data()[i]; + const Dtype weighted_gradient_average = + momentum * history_value + (1 - momentum) * (grad * grad); + update_value = grad * std::sqrt((update_history_value + delta_) / + (weighted_gradient_average + delta_)); + // not actually needed, just here for illustrative purposes + // const Dtype weighted_update_average = + // momentum * update_history_value + (1 - momentum) * (update_value); + break; + } default: LOG(FATAL) << "Unknown solver type: " << solver_type(); } @@ -981,4 +1003,78 @@ TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) { } } +template +class AdaDeltaSolverTest : public GradientBasedSolverTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + this->solver_.reset(new AdaDeltaSolver(param)); + } + + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_ADADELTA; + } +}; + +TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices); + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.0; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.0; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.95; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); +} + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.95; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.95; + const int kNumIters = 500; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 500; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + } // namespace caffe From 4c58741ce2e031b61aef53914128801e6edd673d Mon Sep 17 00:00:00 2001 From: Kevin Bache Date: Thu, 19 Mar 2015 15:56:51 -0700 Subject: [PATCH 2/3] Updated AdaDelta for modern Caffe; reduced iterations on multi-iter tests --- .../mnist_autoencoder_solver_adadelta.prototxt | 2 +- include/caffe/solver.hpp | 6 ++-- src/caffe/solver.cpp | 32 ++++++---------------- src/caffe/test/test_gradient_based_solver.cpp | 4 +-- 4 files changed, 15 insertions(+), 29 deletions(-) diff --git a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt index cc4f0bbb4a7..4e43468a71f 100644 --- a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt @@ -6,6 +6,7 @@ test_iter: 100 test_interval: 500 test_compute_loss: true momentum: 0.95 +delta: 1e-8 display: 100 max_iter: 65000 weight_decay: 0.0005 @@ -14,4 +15,3 @@ snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train" # solver mode: CPU or GPU solver_mode: GPU solver_type: ADADELTA -delta: 1e-8 diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 4b408380119..495cd4f159e 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -82,12 +82,12 @@ class SGDSolver : public Solver { const vector > >& history() { return history_; } protected: - void PreSolve(); Dtype GetLearningRate(); virtual void ApplyUpdate(); virtual void Normalize(int param_id); virtual void Regularize(int param_id); virtual void ComputeUpdateValue(int param_id, Dtype rate); + virtual void PreSolve(); virtual void ClipGradients(); virtual void SnapshotSolverState(const string& model_filename); virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); @@ -162,9 +162,9 @@ template class AdaDeltaSolver : public SGDSolver { public: explicit AdaDeltaSolver(const SolverParameter& param) - : SGDSolver(param) { constructor_sanity_check(); } + : SGDSolver(param) { PreSolve(); constructor_sanity_check(); } explicit AdaDeltaSolver(const string& param_file) - : SGDSolver(param_file) { constructor_sanity_check(); } + : SGDSolver(param_file) { PreSolve(); constructor_sanity_check(); } protected: virtual void PreSolve(); diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index d8749a1b939..34a290ffe3d 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -936,35 +936,21 @@ void RMSPropSolver::ComputeUpdateValue(int param_id, Dtype rate) { template void AdaDeltaSolver::PreSolve() { - // Initialize the history - vector > >& net_params = this->net_->params(); - this->history_.clear(); - this->update_.clear(); - this->temp_.clear(); - for (int i = 0; i < net_params.size(); ++i) { - const Blob* net_param = net_params[i].get(); - this->history_.push_back(shared_ptr >(new Blob( - net_param->num(), net_param->channels(), net_param->height(), - net_param->width()))); - this->update_.push_back(shared_ptr >(new Blob( - net_param->num(), net_param->channels(), net_param->height(), - net_param->width()))); - this->temp_.push_back(shared_ptr >(new Blob( - net_param->num(), net_param->channels(), net_param->height(), - net_param->width()))); - } + // Add the extra history entries for AdaDelta after those from + // SGDSolver::PreSolve + const vector > >& net_params = this->net_->params(); for (int i = 0; i < net_params.size(); ++i) { - const Blob* net_param = net_params[i].get(); - this->history_.push_back(shared_ptr >(new Blob( - net_param->num(), net_param->channels(), net_param->height(), - net_param->width()))); + const vector& shape = net_params[i]->shape(); + this->history_.push_back( + shared_ptr >(new Blob(shape))); } } template void AdaDeltaSolver::ComputeUpdateValue() { - vector > >& net_params = this->net_->params(); - vector& net_params_weight_decay = this->net_->params_weight_decay(); + const vector > >& net_params = this->net_->params(); + const vector& net_params_weight_decay = + this->net_->params_weight_decay(); Dtype delta = this->param_.delta(); Dtype momentum = this->param_.momentum(); Dtype weight_decay = this->param_.weight_decay(); diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index db89e285a9f..277aa3a5c8e 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -1060,7 +1060,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { const Dtype kLearningRate = 0.0; const Dtype kWeightDecay = 0.0; const Dtype kMomentum = 0.95; - const int kNumIters = 500; + const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } @@ -1071,7 +1071,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) { const Dtype kLearningRate = 0.0; const Dtype kWeightDecay = 0.1; const Dtype kMomentum = 0.95; - const int kNumIters = 500; + const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } From f2e523e479b89902b644f3a8bb2ac51a6dc28eee Mon Sep 17 00:00:00 2001 From: Matthias Plappert Date: Sat, 18 Jul 2015 18:46:51 +0200 Subject: [PATCH 3/3] Clean up and modernize AdaDelta code; add learning rate support; add additional test cases --- examples/mnist/lenet_adadelta_solver.prototxt | 2 + .../mnist_autoencoder_solver_adadelta.prototxt | 2 + include/caffe/solver.hpp | 16 +- src/caffe/solver.cpp | 274 +++++++++------------ src/caffe/test/test_gradient_based_solver.cpp | 211 ++++++++++------ 5 files changed, 260 insertions(+), 245 deletions(-) diff --git a/examples/mnist/lenet_adadelta_solver.prototxt b/examples/mnist/lenet_adadelta_solver.prototxt index b77b451d56a..776d1e06139 100644 --- a/examples/mnist/lenet_adadelta_solver.prototxt +++ b/examples/mnist/lenet_adadelta_solver.prototxt @@ -7,6 +7,8 @@ test_iter: 100 # Carry out testing every 500 training iterations. test_interval: 500 # The base learning rate, momentum and the weight decay of the network. +base_lr: 1.0 +lr_policy: "fixed" momentum: 0.95 weight_decay: 0.0005 # Display every 100 iterations diff --git a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt index 4e43468a71f..065647df31b 100644 --- a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt @@ -5,6 +5,8 @@ test_state: { stage: 'test-on-test' } test_iter: 100 test_interval: 500 test_compute_loss: true +base_lr: 1.0 +lr_policy: "fixed" momentum: 0.95 delta: 1e-8 display: 100 diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 495cd4f159e..5fefd01e549 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -82,12 +82,12 @@ class SGDSolver : public Solver { const vector > >& history() { return history_; } protected: + void PreSolve(); Dtype GetLearningRate(); virtual void ApplyUpdate(); virtual void Normalize(int param_id); virtual void Regularize(int param_id); virtual void ComputeUpdateValue(int param_id, Dtype rate); - virtual void PreSolve(); virtual void ClipGradients(); virtual void SnapshotSolverState(const string& model_filename); virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); @@ -162,19 +162,13 @@ template class AdaDeltaSolver : public SGDSolver { public: explicit AdaDeltaSolver(const SolverParameter& param) - : SGDSolver(param) { PreSolve(); constructor_sanity_check(); } + : SGDSolver(param) { AdaDeltaPreSolve(); } explicit AdaDeltaSolver(const string& param_file) - : SGDSolver(param_file) { PreSolve(); constructor_sanity_check(); } + : SGDSolver(param_file) { AdaDeltaPreSolve(); } protected: - virtual void PreSolve(); - virtual void ComputeUpdateValue(); - void constructor_sanity_check() { - CHECK_EQ(0, this->param_.base_lr()) - << "Learning rate cannot be used with AdaDelta."; - CHECK_EQ("", this->param_.lr_policy()) - << "Learning rate policy cannot be applied to AdaDelta."; - } + void AdaDeltaPreSolve(); + virtual void ComputeUpdateValue(int param_id, Dtype rate); DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); }; diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 34a290ffe3d..78902ca0ebc 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -935,10 +935,10 @@ void RMSPropSolver::ComputeUpdateValue(int param_id, Dtype rate) { } template -void AdaDeltaSolver::PreSolve() { +void AdaDeltaSolver::AdaDeltaPreSolve() { // Add the extra history entries for AdaDelta after those from // SGDSolver::PreSolve - const vector > >& net_params = this->net_->params(); + const vector*>& net_params = this->net_->learnable_params(); for (int i = 0; i < net_params.size(); ++i) { const vector& shape = net_params[i]->shape(); this->history_.push_back( @@ -947,172 +947,134 @@ void AdaDeltaSolver::PreSolve() { } template -void AdaDeltaSolver::ComputeUpdateValue() { - const vector > >& net_params = this->net_->params(); - const vector& net_params_weight_decay = - this->net_->params_weight_decay(); +void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); Dtype delta = this->param_.delta(); Dtype momentum = this->param_.momentum(); - Dtype weight_decay = this->param_.weight_decay(); - string regularization_type = this->param_.regularization_type(); + Dtype local_rate = rate * net_params_lr[param_id]; size_t update_history_offset = net_params.size(); switch (Caffe::mode()) { - case Caffe::CPU: - for (int param_id = 0; param_id < net_params.size(); ++param_id) { - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else if (regularization_type == "L1") { - caffe_cpu_sign(net_params[param_id]->count(), - net_params[param_id]->cpu_data(), - this->temp_[param_id]->mutable_cpu_data()); - caffe_axpy(net_params[param_id]->count(), - local_decay, - this->temp_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } + case Caffe::CPU: { + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); - // compute square of gradient in update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history of gradients - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->cpu_data(), momentum, - this->history_[param_id]->mutable_cpu_data()); - - // add delta to history to guard against dividing by zero later - caffe_set(net_params[param_id]->count(), delta, - this->temp_[param_id]->mutable_cpu_data()); - - caffe_add(net_params[param_id]->count(), - this->temp_[param_id]->cpu_data(), - this->history_[update_history_offset + param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - caffe_add(net_params[param_id]->count(), - this->temp_[param_id]->cpu_data(), - this->history_[param_id]->cpu_data(), - this->temp_[param_id]->mutable_cpu_data()); - - // divide history of updates by history of gradients - caffe_div(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), - this->temp_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // jointly compute the RMS of both for update and gradient history - caffe_powx(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_cpu_data()); - - // compute the update - caffe_mul(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), - this->update_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - - // compute square of update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history of updates - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->cpu_data(), momentum, - this->history_[update_history_offset + param_id]->mutable_cpu_data()); - } + // update history of gradients + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->cpu_data(), momentum, + this->history_[param_id]->mutable_cpu_data()); + + // add delta to history to guard against dividing by zero later + caffe_set(net_params[param_id]->count(), delta, + this->temp_[param_id]->mutable_cpu_data()); + + caffe_add(net_params[param_id]->count(), + this->temp_[param_id]->cpu_data(), + this->history_[update_history_offset + param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add(net_params[param_id]->count(), + this->temp_[param_id]->cpu_data(), + this->history_[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + + // divide history of updates by history of gradients + caffe_div(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + this->temp_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // jointly compute the RMS of both for update and gradient history + caffe_powx(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + // compute the update + caffe_mul(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), + this->update_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + + // compute square of update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history of updates + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->cpu_data(), momentum, + this->history_[update_history_offset + param_id]->mutable_cpu_data()); + + // apply learning rate + caffe_cpu_scale(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), + net_params[param_id]->mutable_cpu_diff()); break; - case Caffe::GPU: + } + case Caffe::GPU: { #ifndef CPU_ONLY - for (int param_id = 0; param_id < net_params.size(); ++param_id) { - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else if (regularization_type == "L1") { - caffe_gpu_sign(net_params[param_id]->count(), - net_params[param_id]->gpu_data(), - this->temp_[param_id]->mutable_gpu_data()); - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - this->temp_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); - // compute square of gradient in update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history of gradients - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->gpu_data(), momentum, - this->history_[param_id]->mutable_gpu_data()); - - // add delta to history to guard against dividing by zero later - caffe_gpu_set(net_params[param_id]->count(), delta, - this->temp_[param_id]->mutable_gpu_data()); - - caffe_gpu_add(net_params[param_id]->count(), - this->temp_[param_id]->gpu_data(), - this->history_[update_history_offset + param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_add(net_params[param_id]->count(), - this->temp_[param_id]->gpu_data(), - this->history_[param_id]->gpu_data(), - this->temp_[param_id]->mutable_gpu_data()); - - // divide history of updates by history of gradients - caffe_gpu_div(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - this->temp_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - // jointly compute the RMS of both for update and gradient history - caffe_gpu_powx(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_gpu_data()); - - // compute the update and copy to net_diff - caffe_gpu_mul(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), - this->update_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - - // compute square of update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history of updates - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->gpu_data(), momentum, - this->history_[update_history_offset + param_id]->mutable_gpu_data()); - } + // update history of gradients + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->gpu_data(), momentum, + this->history_[param_id]->mutable_gpu_data()); + + // add delta to history to guard against dividing by zero later + caffe_gpu_set(net_params[param_id]->count(), delta, + this->temp_[param_id]->mutable_gpu_data()); + + caffe_gpu_add(net_params[param_id]->count(), + this->temp_[param_id]->gpu_data(), + this->history_[update_history_offset + param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add(net_params[param_id]->count(), + this->temp_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + + // divide history of updates by history of gradients + caffe_gpu_div(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->temp_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // jointly compute the RMS of both for update and gradient history + caffe_gpu_powx(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + // compute the update and copy to net_diff + caffe_gpu_mul(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + + // compute square of update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history of updates + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->gpu_data(), momentum, + this->history_[update_history_offset + param_id]->mutable_gpu_data()); + + // apply learning rate + caffe_gpu_scale(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), + net_params[param_id]->mutable_gpu_diff()); #else NO_GPU; #endif break; + } default: LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); } diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 277aa3a5c8e..c97d4ede3b3 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -165,10 +165,6 @@ class GradientBasedSolverTest : public MultiDeviceTest { " bottom: 'targets' " " } " "} "; - if (learning_rate != 0) { - proto << "base_lr: " << learning_rate << " "; - proto << "lr_policy: 'fixed' "; - } if (weight_decay != 0) { proto << "weight_decay: " << weight_decay << " "; } @@ -898,6 +894,139 @@ TYPED_TEST(NesterovSolverTest, TestSnapshotShare) { } template +class AdaDeltaSolverTest : public GradientBasedSolverTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + this->solver_.reset(new AdaDeltaSolver(param)); + } + + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_ADADELTA; + } +}; + +TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices); + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.95; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); +} + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.95; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdaDeltaSolverTest, + TestAdaDeltaLeastSquaresUpdateWithEverythingShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + this->share_ = true; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + const int kIterSize = 2; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} + +TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + const int kIterSize = 2; + this->share_ = true; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} + +TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + for (int i = 1; i <= kNumIters; ++i) { + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + this->share_ = true; + for (int i = 1; i <= kNumIters; ++i) { + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +template class RMSPropSolverTest : public GradientBasedSolverTest { typedef typename TypeParam::Dtype Dtype; @@ -1003,78 +1132,4 @@ TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) { } } -template -class AdaDeltaSolverTest : public GradientBasedSolverTest { - typedef typename TypeParam::Dtype Dtype; - - protected: - virtual void InitSolver(const SolverParameter& param) { - this->solver_.reset(new AdaDeltaSolver(param)); - } - - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_ADADELTA; - } -}; - -TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices); - -TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.0; - this->TestLeastSquaresUpdate(kLearningRate); -} - -TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.0; - const Dtype kWeightDecay = 0.5; - const Dtype kMomentum = 0.95; - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); -} - -TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.0; - const Dtype kWeightDecay = 0.0; - const Dtype kMomentum = 0.5; - const int kNumIters = 1; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); - } -} - -TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.0; - const Dtype kWeightDecay = 0.0; - const Dtype kMomentum = 0.95; - const int kNumIters = 1; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); - } -} - -TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.0; - const Dtype kWeightDecay = 0.0; - const Dtype kMomentum = 0.95; - const int kNumIters = 4; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); - } -} - -TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.0; - const Dtype kWeightDecay = 0.1; - const Dtype kMomentum = 0.95; - const int kNumIters = 4; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); - } -} - } // namespace caffe