From 57f498ccbfe79481daabc41677e7d4680d51311e Mon Sep 17 00:00:00 2001 From: Noiredd Date: Mon, 3 Jul 2017 12:43:47 +0200 Subject: [PATCH] Optimized accuracy calculation --- src/caffe/layers/accuracy_layer.cpp | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp index 4eddbb5c850..e64b85612a5 100644 --- a/src/caffe/layers/accuracy_layer.cpp +++ b/src/caffe/layers/accuracy_layer.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -69,22 +70,33 @@ void AccuracyLayer::Forward_cpu(const vector*>& bottom, if (top.size() > 1) ++nums_buffer_.mutable_cpu_data()[label_value]; DCHECK_GE(label_value, 0); DCHECK_LT(label_value, num_labels); - // Top-k accuracy - std::vector > bottom_data_vector; - for (int k = 0; k < num_labels; ++k) { - bottom_data_vector.push_back(std::make_pair( - bottom_data[i * dim + k * inner_num_ + j], k)); + // Top-k accuracy using priority queue + typedef std::pair Dpair; + std::priority_queue, + std::greater > top_scores; + std::greater greater_; + // fill the first k elements + for (int k = 0; k < top_k_; ++k) { + const Dtype score = bottom_data[i * dim + k * inner_num_ + j]; + top_scores.push(std::make_pair(score, k)); + } + // only push new element if it's greater than the current smallest + for (int k = top_k_; k < num_labels; ++k) { + const Dtype score = bottom_data[i * dim + k * inner_num_ + j]; + if (greater_(std::make_pair(score, k), top_scores.top())) { + top_scores.pop(); + top_scores.push(std::make_pair(score, k)); + } } - std::partial_sort( - bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, - bottom_data_vector.end(), std::greater >()); // check if true label is in top k predictions - for (int k = 0; k < top_k_; k++) { - if (bottom_data_vector[k].second == label_value) { + while (!top_scores.empty()) { + if (top_scores.top().second == label_value) { ++accuracy; if (top.size() > 1) ++top[1]->mutable_cpu_data()[label_value]; break; } + top_scores.pop(); } ++count; }