From c6ab96596d9eae01c2c403487dc8be8e3edc8fbb Mon Sep 17 00:00:00 2001
From: Evan Shelhamer
Date: Tue, 15 Nov 2016 11:19:37 -0800
Subject: [PATCH 1/2] sigmoid cross-entropy loss: ignore selected targets by
 `ignore_label`

sig-ce learns to ignore by zeroing out the loss/diff at targets equal
to the configured `ignore_label`.

n.b. as of now the loss/diff are not properly normalized when there are
ignored targets. sig-ce loss should adopt the same normalization options
as softmax loss.
---
 .../layers/sigmoid_cross_entropy_loss_layer.hpp    |  5 ++++
 .../layers/sigmoid_cross_entropy_loss_layer.cpp    | 19 +++++++++++++++
 .../layers/sigmoid_cross_entropy_loss_layer.cu     | 23 ++++++++++++++++++
 .../test/test_sigmoid_cross_entropy_loss_layer.cpp | 28 ++++++++++++++++++++++
 4 files changed, 75 insertions(+)

diff --git a/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp b/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
index 6452ea5106a..a9fe33c8e08 100644
--- a/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
+++ b/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
@@ -105,6 +105,11 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   vector<Blob<Dtype>*> sigmoid_bottom_vec_;
   /// top vector holder to call the underlying SigmoidLayer::Forward
   vector<Blob<Dtype>*> sigmoid_top_vec_;
+
+  /// Whether to ignore instances with a certain label.
+  bool has_ignore_label_;
+  /// The label indicating that an instance should be ignored.
+  int ignore_label_;
 };
 
 }  // namespace caffe

diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
index eb77a9c2cb8..21b64c28002 100644
--- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
@@ -14,6 +14,12 @@ void SigmoidCrossEntropyLossLayer<Dtype>::LayerSetUp(
   sigmoid_top_vec_.clear();
   sigmoid_top_vec_.push_back(sigmoid_output_.get());
   sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);
+
+  has_ignore_label_ =
+    this->layer_param_.loss_param().has_ignore_label();
+  if (has_ignore_label_) {
+    ignore_label_ = this->layer_param_.loss_param().ignore_label();
+  }
 }
 
 template <typename Dtype>
@@ -39,6 +45,10 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
   const Dtype* target = bottom[1]->cpu_data();
   Dtype loss = 0;
   for (int i = 0; i < count; ++i) {
+    const int target_value = static_cast<int>(target[i]);
+    if (has_ignore_label_ && target_value == ignore_label_) {
+      continue;
+    }
     loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) -
         log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
   }
@@ -64,6 +74,15 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
     // Scale down gradient
     const Dtype loss_weight = top[0]->cpu_diff()[0];
     caffe_scal(count, loss_weight / num, bottom_diff);
+    // Zero out gradient of ignored targets.
+    if (has_ignore_label_) {
+      for (int i = 0; i < count; ++i) {
+        const int target_value = static_cast<int>(target[i]);
+        if (target_value == ignore_label_) {
+          bottom_diff[i] = 0;
+        }
+      }
+    }
   }
 }
 
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
index 7cb982d2d70..39eb050664b 100644
--- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
@@ -15,6 +15,17 @@ __global__ void SigmoidCrossEntropyLossForwardGPU(const int nthreads,
   }
 }
 
 template <typename Dtype>
+__global__ void SigmoidCrossEntropyLossIgnoreGPU(const int count,
+    const int ignore_label, const Dtype* target, Dtype* reference) {
+  CUDA_KERNEL_LOOP(index, count) {
+    const int target_value = static_cast<int>(target[index]);
+    if (target_value == ignore_label) {
+      reference[index] = 0;
+    }
+  }
+}
+
+template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   // The forward pass computes the sigmoid outputs.
@@ -33,6 +44,12 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
   // NOLINT_NEXT_LINE(whitespace/operators)
   SigmoidCrossEntropyLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
       CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data);
+  // Zero out loss of ignored targets.
+  if (has_ignore_label_) {
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    SigmoidCrossEntropyLossIgnoreGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
+        CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, loss_data);
+  }
   Dtype loss;
   caffe_gpu_asum(count, loss_data, &loss);
   top[0]->mutable_cpu_data()[0] = loss / num;
@@ -58,6 +75,12 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
     // Scale down gradient
     const Dtype loss_weight = top[0]->cpu_diff()[0];
     caffe_gpu_scal(count, loss_weight / num, bottom_diff);
+    // Zero out gradient of ignored targets.
+    if (has_ignore_label_) {
+      // NOLINT_NEXT_LINE(whitespace/operators)
+      SigmoidCrossEntropyLossIgnoreGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
+          CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, bottom_diff);
+    }
   }
 }
 
diff --git a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
index 5dfd7656db2..1bd5f93796f 100644
--- a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
+++ b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
@@ -116,5 +116,33 @@ TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestGradient) {
       this->blob_top_vec_, 0);
 }
 
+TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestIgnoreGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  FillerParameter data_filler_param;
+  data_filler_param.set_std(1);
+  GaussianFiller<Dtype> data_filler(data_filler_param);
+  data_filler.Fill(this->blob_bottom_data_);
+  LayerParameter layer_param;
+  LossParameter* loss_param = layer_param.mutable_loss_param();
+  loss_param->set_ignore_label(-1);
+  Dtype* target = this->blob_bottom_targets_->mutable_cpu_data();
+  const int count = this->blob_bottom_targets_->count();
+  // Ignore half of targets, then check that diff of this half is zero,
+  // while the other half is nonzero.
+  caffe_set(count / 2, Dtype(-1), target);
+  SigmoidCrossEntropyLossLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  vector<bool> propagate_down(2);
+  propagate_down[0] = true;
+  propagate_down[1] = false;
+  layer.Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_);
+  const Dtype* diff = this->blob_bottom_data_->cpu_diff();
+  for (int i = 0; i < count / 2; ++i) {
+    EXPECT_FLOAT_EQ(diff[i], 0.);
+    EXPECT_NE(diff[i + count / 2], 0.);
+  }
+}
+
 }  // namespace caffe
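n.b. for reference, the per-element loss computed above is the usual sigmoid
cross-entropy, algebraically rearranged so that exp() never sees a large
positive argument. With x the input (logit) and t the target:

  loss(x, t) = -[t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x))]
             = max(x, 0) - x * t + log(1 + exp(-|x|))
             = -(x * (t - (x >= 0)) - log(1 + exp(x - 2 * x * (x >= 0))))

which is what the loop accumulates via loss -= ..., since
x - 2 * x * (x >= 0) equals -|x|.

A minimal prototxt sketch of how the new `ignore_label` is configured (the
layer and blob names here are illustrative, not taken from the patch):

layer {
  name: "loss"
  type: "SigmoidCrossEntropyLoss"
  bottom: "pred"
  bottom: "label"
  top: "loss"
  loss_param { ignore_label: -1 }
}

Targets equal to -1 then contribute neither to the loss sum nor to the
gradient. As noted above, with this patch alone the sum is still divided by
the batch size, so ignored targets shrink the average loss; the next patch
adds the normalization options that address this.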
From 3d62e3cc9da66dbf3328567d0f30d5183b318d81 Mon Sep 17 00:00:00 2001
From: Evan Shelhamer
Date: Wed, 16 Nov 2016 20:39:42 -0800
Subject: [PATCH 2/2] sigmoid cross-entropy loss: normalize loss by different
 schemes

sig-ce loss handles all the same normalizations as the softmax loss;
refer to #3296 for more detail.

this preserves the default normalization for sig-ce loss: batch size.
---
 .../layers/sigmoid_cross_entropy_loss_layer.hpp | 11 ++++
 .../layers/sigmoid_cross_entropy_loss_layer.cpp | 60 +++++++++++++++++++---
 .../layers/sigmoid_cross_entropy_loss_layer.cu  | 57 ++++++++++++--------
 src/caffe/proto/caffe.proto                     |  4 +-
 4 files changed, 102 insertions(+), 30 deletions(-)

diff --git a/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp b/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
index a9fe33c8e08..3d92524421c 100644
--- a/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
+++ b/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
@@ -97,6 +97,13 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
+  /// Read the normalization mode parameter and compute the normalizer based
+  /// on the blob size. If normalization_mode is VALID, the count of valid
+  /// outputs will be read from valid_count, unless it is -1 in which case
+  /// all outputs are assumed to be valid.
+  virtual Dtype get_normalizer(
+      LossParameter_NormalizationMode normalization_mode, int valid_count);
+
   /// The internal SigmoidLayer used to map predictions to probabilities.
   shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
   /// sigmoid_output stores the output of the SigmoidLayer.
@@ -110,6 +117,10 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   bool has_ignore_label_;
   /// The label indicating that an instance should be ignored.
   int ignore_label_;
+  /// How to normalize the loss.
+  LossParameter_NormalizationMode normalization_;
+  Dtype normalizer_;
+  int outer_num_, inner_num_;
 };
 
 }  // namespace caffe

diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
index 21b64c28002..99fa3eb645a 100644
--- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
@@ -1,3 +1,4 @@
+#include <algorithm>
 #include <vector>
 
 #include "caffe/layers/sigmoid_cross_entropy_loss_layer.hpp"
@@ -20,17 +21,60 @@ void SigmoidCrossEntropyLossLayer<Dtype>::LayerSetUp(
   if (has_ignore_label_) {
     ignore_label_ = this->layer_param_.loss_param().ignore_label();
   }
+  if (this->layer_param_.loss_param().has_normalization()) {
+    normalization_ = this->layer_param_.loss_param().normalization();
+  } else if (this->layer_param_.loss_param().has_normalize()) {
+    normalization_ = this->layer_param_.loss_param().normalize() ?
+                     LossParameter_NormalizationMode_VALID :
+                     LossParameter_NormalizationMode_BATCH_SIZE;
+  } else {
+    normalization_ = LossParameter_NormalizationMode_BATCH_SIZE;
+  }
 }
 
 template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Reshape(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   LossLayer<Dtype>::Reshape(bottom, top);
+  outer_num_ = bottom[0]->shape(0);  // batch size
+  inner_num_ = bottom[0]->count(1);  // instance size: |output| == |target|
   CHECK_EQ(bottom[0]->count(), bottom[1]->count()) <<
       "SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count.";
   sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
 }
 
+// TODO(shelhamer) loss normalization should be pulled up into LossLayer,
+// instead of duplicated here and in SoftMaxWithLossLayer
+template <typename Dtype>
+Dtype SigmoidCrossEntropyLossLayer<Dtype>::get_normalizer(
+    LossParameter_NormalizationMode normalization_mode, int valid_count) {
+  Dtype normalizer;
+  switch (normalization_mode) {
+    case LossParameter_NormalizationMode_FULL:
+      normalizer = Dtype(outer_num_ * inner_num_);
+      break;
+    case LossParameter_NormalizationMode_VALID:
+      if (valid_count == -1) {
+        normalizer = Dtype(outer_num_ * inner_num_);
+      } else {
+        normalizer = Dtype(valid_count);
+      }
+      break;
+    case LossParameter_NormalizationMode_BATCH_SIZE:
+      normalizer = Dtype(outer_num_);
+      break;
+    case LossParameter_NormalizationMode_NONE:
+      normalizer = Dtype(1);
+      break;
+    default:
+      LOG(FATAL) << "Unknown normalization mode: "
+          << LossParameter_NormalizationMode_Name(normalization_mode);
+  }
+  // Some users will have no labels for some examples in order to 'turn off' a
+  // particular loss in a multi-task setup. The max prevents NaNs in that case.
+  return std::max(Dtype(1.0), normalizer);
+}
+
 template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
@@ -38,21 +82,22 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
   sigmoid_bottom_vec_[0] = bottom[0];
   sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
   // Compute the loss (negative log likelihood)
-  const int count = bottom[0]->count();
-  const int num = bottom[0]->num();
   // Stable version of loss computation from input data
   const Dtype* input_data = bottom[0]->cpu_data();
   const Dtype* target = bottom[1]->cpu_data();
+  int valid_count = 0;
   Dtype loss = 0;
-  for (int i = 0; i < count; ++i) {
+  for (int i = 0; i < bottom[0]->count(); ++i) {
     const int target_value = static_cast<int>(target[i]);
     if (has_ignore_label_ && target_value == ignore_label_) {
       continue;
     }
     loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) -
         log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+    ++valid_count;
   }
-  top[0]->mutable_cpu_data()[0] = loss / num;
+  normalizer_ = get_normalizer(normalization_, valid_count);
+  top[0]->mutable_cpu_data()[0] = loss / normalizer_;
 }
 
 template <typename Dtype>
@@ -66,14 +111,10 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
   if (propagate_down[0]) {
     // First, compute the diff
     const int count = bottom[0]->count();
-    const int num = bottom[0]->num();
     const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
     const Dtype* target = bottom[1]->cpu_data();
     Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
     caffe_sub(count, sigmoid_output_data, target, bottom_diff);
-    // Scale down gradient
-    const Dtype loss_weight = top[0]->cpu_diff()[0];
-    caffe_scal(count, loss_weight / num, bottom_diff);
     // Zero out gradient of ignored targets.
     if (has_ignore_label_) {
       for (int i = 0; i < count; ++i) {
@@ -83,6 +124,9 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
         }
       }
     }
+    // Scale down gradient
+    Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer_;
+    caffe_scal(count, loss_weight, bottom_diff);
   }
 }
 
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
index 39eb050664b..b9877e6a3f6 100644
--- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
@@ -5,26 +5,38 @@
 
 namespace caffe {
 
+
 template <typename Dtype>
 __global__ void SigmoidCrossEntropyLossForwardGPU(const int nthreads,
-    const Dtype* input_data, const Dtype* target, Dtype* loss) {
+    const Dtype* input_data, const Dtype* target, Dtype* loss,
+    const bool has_ignore_label_, const int ignore_label_,
+    Dtype* counts) {
   CUDA_KERNEL_LOOP(i, nthreads) {
-    loss[i] = input_data[i] * (target[i] - (input_data[i] >= 0)) -
-        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+    const int target_value = static_cast<int>(target[i]);
+    if (has_ignore_label_ && target_value == ignore_label_) {
+      loss[i] = 0;
+      counts[i] = 0;
+    } else {
+      loss[i] = input_data[i] * (target[i] - (input_data[i] >= 0)) -
+          log(1 + exp(input_data[i] - 2 * input_data[i] *
+          (input_data[i] >= 0)));
+      counts[i] = 1;
+    }
   }
 }
 
 template <typename Dtype>
-__global__ void SigmoidCrossEntropyLossIgnoreGPU(const int count,
-    const int ignore_label, const Dtype* target, Dtype* reference) {
-  CUDA_KERNEL_LOOP(index, count) {
-    const int target_value = static_cast<int>(target[index]);
+__global__ void SigmoidCrossEntropyLossIgnoreDiffGPU(const int count,
+    const int ignore_label, const Dtype* target, Dtype* diff) {
+  CUDA_KERNEL_LOOP(i, count) {
+    const int target_value = static_cast<int>(target[i]);
     if (target_value == ignore_label) {
-      reference[index] = 0;
+      diff[i] = 0;
     }
   }
 }
 
+
 template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
@@ -33,7 +45,6 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
   sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
   // Compute the loss (negative log likelihood)
   const int count = bottom[0]->count();
-  const int num = bottom[0]->num();
   // Stable version of loss computation from input data
   const Dtype* input_data = bottom[0]->gpu_data();
   const Dtype* target = bottom[1]->gpu_data();
@@ -41,18 +52,23 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
   // on the backward pass, we use it here to avoid having to allocate new GPU
   // memory to accumulate intermediate results in the kernel.
   Dtype* loss_data = bottom[0]->mutable_gpu_diff();
+  Dtype* count_data = bottom[1]->mutable_gpu_diff();
+  Dtype valid_count;
   // NOLINT_NEXT_LINE(whitespace/operators)
   SigmoidCrossEntropyLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
-      CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data);
-  // Zero out loss of ignored targets.
-  if (has_ignore_label_) {
-    // NOLINT_NEXT_LINE(whitespace/operators)
-    SigmoidCrossEntropyLossIgnoreGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
-        CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, loss_data);
+      CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data,
+      has_ignore_label_, ignore_label_, count_data);
+  // Only launch another CUDA kernel if we actually need the valid count.
+  if (normalization_ == LossParameter_NormalizationMode_VALID &&
+      has_ignore_label_) {
+    caffe_gpu_asum(count, count_data, &valid_count);
+  } else {
+    valid_count = count;
   }
   Dtype loss;
   caffe_gpu_asum(count, loss_data, &loss);
-  top[0]->mutable_cpu_data()[0] = loss / num;
+  normalizer_ = get_normalizer(normalization_, valid_count);
+  top[0]->mutable_cpu_data()[0] = loss / normalizer_;
 }
 
 template <typename Dtype>
@@ -66,21 +82,20 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
   if (propagate_down[0]) {
     // First, compute the diff
     const int count = bottom[0]->count();
-    const int num = bottom[0]->num();
     const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
     const Dtype* target = bottom[1]->gpu_data();
     Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
     caffe_copy(count, sigmoid_output_data, bottom_diff);
     caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff);
-    // Scale down gradient
-    const Dtype loss_weight = top[0]->cpu_diff()[0];
-    caffe_gpu_scal(count, loss_weight / num, bottom_diff);
     // Zero out gradient of ignored targets.
     if (has_ignore_label_) {
       // NOLINT_NEXT_LINE(whitespace/operators)
-      SigmoidCrossEntropyLossIgnoreGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
+      SigmoidCrossEntropyLossIgnoreDiffGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
           CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, bottom_diff);
     }
+    // Scale down gradient
+    Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer_;
+    caffe_gpu_scal(count, loss_weight, bottom_diff);
   }
 }
 
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 6940a705eb6..0b2768b7708 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -434,20 +434,22 @@ message LossParameter {
   optional int32 ignore_label = 1;
   // How to normalize the loss for loss layers that aggregate across batches,
   // spatial dimensions, or other dimensions. Currently only implemented in
-  // SoftmaxWithLoss layer.
+  // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers.
   enum NormalizationMode {
     // Divide by the number of examples in the batch times spatial dimensions.
     // Outputs that receive the ignore label will NOT be ignored in computing
     // the loss.
     FULL = 0;
     // Divide by the total number of output locations that do not take the
     // ignore_label.  If ignore_label is not set, this behaves like FULL.
     VALID = 1;
     // Divide by the batch size.
     BATCH_SIZE = 2;
     // Do not normalize the loss.
     NONE = 3;
   }
+  // For historical reasons, the default normalization for
+  // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID.
   optional NormalizationMode normalization = 3 [default = VALID];
   // Deprecated. Ignored if normalization is specified. If normalization
   // is not specified, then setting this to false will be equivalent to
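n.b. with both patches applied, ignoring and normalization compose. A closing
prototxt sketch (again with illustrative layer and blob names): ignore targets
labeled -1 and divide the loss by the number of non-ignored targets instead
of the batch size.

layer {
  name: "loss"
  type: "SigmoidCrossEntropyLoss"
  bottom: "pred"
  bottom: "label"
  top: "loss"
  loss_param {
    ignore_label: -1
    normalization: VALID
  }
}

Omitting the normalization field keeps the historical BATCH_SIZE default for
this layer, so existing nets behave exactly as before.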