diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index b1ac3a93eff..cae1c3e4ee6 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -386,8 +386,8 @@ class CuDNNSoftmaxLayer : public SoftmaxLayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_; - cudnnTensor4dDescriptor_t top_desc_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; }; #endif diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp index 8669923d1fa..323215134c7 100644 --- a/include/caffe/neuron_layers.hpp +++ b/include/caffe/neuron_layers.hpp @@ -433,8 +433,8 @@ class CuDNNReLULayer : public ReLULayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_; - cudnnTensor4dDescriptor_t top_desc_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; }; #endif @@ -516,8 +516,8 @@ class CuDNNSigmoidLayer : public SigmoidLayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_; - cudnnTensor4dDescriptor_t top_desc_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; }; #endif @@ -601,8 +601,8 @@ class CuDNNTanHLayer : public TanHLayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_; - cudnnTensor4dDescriptor_t top_desc_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; }; #endif diff --git a/include/caffe/util/cudnn.hpp b/include/caffe/util/cudnn.hpp index eaed7333df8..b531dd5fa7a 100644 --- a/include/caffe/util/cudnn.hpp +++ b/include/caffe/util/cudnn.hpp @@ -50,41 +50,45 @@ template class dataType; template<> class dataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; + static float oneval, zeroval; + static const void *one, *zero; }; template<> class dataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; + static double oneval, 
zeroval; + static const void *one, *zero; }; template -inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) { - CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc)); +inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) { + CUDNN_CHECK(cudnnCreateTensorDescriptor(desc)); } template -inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc, +inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, int n, int c, int h, int w, int stride_n, int stride_c, int stride_h, int stride_w) { CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType::type, - n, c, h, w, stride_n, stride_c, stride_h, stride_w)); + n, c, h, w, stride_n, stride_c, stride_h, stride_w)); } template -inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc, +inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, int n, int c, int h, int w) { const int stride_w = 1; const int stride_h = w * stride_w; const int stride_c = h * stride_h; const int stride_n = c * stride_c; setTensor4dDesc(desc, n, c, h, w, - stride_n, stride_c, stride_h, stride_w); + stride_n, stride_c, stride_h, stride_w); } template inline void createFilterDesc(cudnnFilterDescriptor_t* desc, int n, int c, int h, int w) { CUDNN_CHECK(cudnnCreateFilterDescriptor(desc)); - CUDNN_CHECK(cudnnSetFilterDescriptor(*desc, dataType::type, + CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType::type, n, c, h, w)); } @@ -95,29 +99,29 @@ inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) { template inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv, - cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter, + cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter, int pad_h, int pad_w, int stride_h, int stride_w) { - CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter, + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION)); } template -inline void createPoolingDesc(cudnnPoolingDescriptor_t* 
conv, +inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc, PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode, - int h, int w, int stride_h, int stride_w) { + int h, int w, int pad_h, int pad_w, int stride_h, int stride_w) { switch (poolmethod) { case PoolingParameter_PoolMethod_MAX: *mode = CUDNN_POOLING_MAX; break; case PoolingParameter_PoolMethod_AVE: - *mode = CUDNN_POOLING_AVERAGE; + *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; break; default: LOG(FATAL) << "Unknown pooling method."; } - CUDNN_CHECK(cudnnCreatePoolingDescriptor(conv)); - CUDNN_CHECK(cudnnSetPoolingDescriptor(*conv, *mode, h, w, - stride_h, stride_w)); + CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc)); + CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w, + pad_h, pad_w, stride_h, stride_w)); } } // namespace cudnn diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 6cb507a5780..cd0ab8babb0 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -246,11 +246,13 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { bool handles_setup_; cudnnHandle_t* handle_; cudaStream_t* stream_; - vector bottom_descs_, top_descs_; - cudnnTensor4dDescriptor_t bias_desc_; + vector bottom_descs_, top_descs_; + cudnnTensorDescriptor_t bias_desc_; cudnnFilterDescriptor_t filter_desc_; vector conv_descs_; int bottom_offset_, top_offset_, weight_offset_, bias_offset_; + size_t workspaceSizeInBytes; + void *workspace; }; #endif @@ -445,7 +447,7 @@ class CuDNNPoolingLayer : public PoolingLayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_, top_desc_; + cudnnTensorDescriptor_t bottom_desc_, top_desc_; cudnnPoolingDescriptor_t pooling_desc_; cudnnPoolingMode_t mode_; }; diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index 4a69ca20d0a..524caf1320f 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ 
b/src/caffe/layers/cudnn_conv_layer.cpp @@ -43,10 +43,10 @@ void CuDNNConvolutionLayer::LayerSetUp( // Create tensor descriptor(s) for data and corresponding convolution(s). for (int i = 0; i < bottom.size(); i++) { - cudnnTensor4dDescriptor_t bottom_desc; + cudnnTensorDescriptor_t bottom_desc; cudnn::createTensor4dDesc(&bottom_desc); bottom_descs_.push_back(bottom_desc); - cudnnTensor4dDescriptor_t top_desc; + cudnnTensorDescriptor_t top_desc; cudnn::createTensor4dDesc(&top_desc); top_descs_.push_back(top_desc); cudnnConvolutionDescriptor_t conv_desc; @@ -104,12 +104,12 @@ CuDNNConvolutionLayer::~CuDNNConvolutionLayer() { if (!handles_setup_) { return; } for (int i = 0; i < bottom_descs_.size(); i++) { - cudnnDestroyTensor4dDescriptor(bottom_descs_[i]); - cudnnDestroyTensor4dDescriptor(top_descs_[i]); + cudnnDestroyTensorDescriptor(bottom_descs_[i]); + cudnnDestroyTensorDescriptor(top_descs_[i]); cudnnDestroyConvolutionDescriptor(conv_descs_[i]); } if (this->bias_term_) { - cudnnDestroyTensor4dDescriptor(bias_desc_); + cudnnDestroyTensorDescriptor(bias_desc_); } cudnnDestroyFilterDescriptor(filter_desc_); diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu index 071014e1b48..08f5201bc22 100644 --- a/src/caffe/layers/cudnn_conv_layer.cu +++ b/src/caffe/layers/cudnn_conv_layer.cu @@ -21,21 +21,57 @@ void CuDNNConvolutionLayer::Forward_gpu( // Forward through cuDNN in parallel over groups. 
for (int g = 0; g < this->group_; g++) { + cudnnConvolutionFwdAlgo_t algo; + + // pick the convolution algorithm + // TODO(shelhamer) this should be done during reshape + // TODO(shelhamer) the choice of automatic or manual algorithm picking + // should be exposed in proto + CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, + 0, // memoryLimitInBytes, + &algo)); + + // query the workspace size the chosen algorithm needs into a temporary + size_t workspaceSizeInBytes_temp = 0; + + CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + algo, + &workspaceSizeInBytes_temp)); + + if (workspaceSizeInBytes_temp > workspaceSizeInBytes) { + workspaceSizeInBytes = workspaceSizeInBytes_temp; + // free the existing workspace and allocate a new (larger) one + cudaFree(this->workspace); + cudaMalloc(&(this->workspace), workspaceSizeInBytes); + } + // Filters. CUDNN_CHECK(cudnnConvolutionForward(handle_[g], - bottom_descs_[i], bottom_data + bottom_offset_ * g, - filter_desc_, weight + weight_offset_ * g, - conv_descs_[i], - top_descs_[i], top_data + top_offset_ * g, - CUDNN_RESULT_NO_ACCUMULATE)); + cudnn::dataType::one, + bottom_descs_[i], bottom_data + bottom_offset_ * g, + filter_desc_, weight + weight_offset_ * g, + conv_descs_[i], + algo, workspace, workspaceSizeInBytes, + cudnn::dataType::zero, + top_descs_[i], top_data + top_offset_ * g)); // Bias. 
if (this->bias_term_) { const Dtype* bias_data = this->blobs_[1]->gpu_data(); - Dtype alpha = 1.; - CUDNN_CHECK(cudnnAddTensor4d(handle_[g], CUDNN_ADD_SAME_C, &alpha, - bias_desc_, bias_data + bias_offset_ * g, - top_descs_[i], top_data + top_offset_ * g)); + CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C, + cudnn::dataType::one, + bias_desc_, bias_data + bias_offset_ * g, + cudnn::dataType::one, + top_descs_[i], top_data + top_offset_ * g)); } } @@ -68,20 +104,22 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, // Gradient w.r.t. bias. if (this->bias_term_ && this->param_propagate_down_[1]) { CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g], - top_descs_[i], top_diff + top_offset_ * g, - bias_desc_, bias_diff + bias_offset_ * g, - CUDNN_RESULT_ACCUMULATE)); + cudnn::dataType::one, + top_descs_[i], top_diff + top_offset_ * g, + cudnn::dataType::one, + bias_desc_, bias_diff + bias_offset_ * g)); } // Gradient w.r.t. weights. if (this->param_propagate_down_[0]) { const Dtype* bottom_data = bottom[i]->gpu_data(); CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g], - bottom_descs_[i], bottom_data + bottom_offset_ * g, - top_descs_[i], top_diff + top_offset_ * g, - conv_descs_[i], - filter_desc_, weight_diff + weight_offset_ * g, - CUDNN_RESULT_ACCUMULATE)); + cudnn::dataType::one, + bottom_descs_[i], bottom_data + bottom_offset_ * g, + top_descs_[i], top_diff + top_offset_ * g, + conv_descs_[i], + cudnn::dataType::one, + filter_desc_, weight_diff + weight_offset_ * g)); } // Gradient w.r.t. bottom data. 
@@ -91,11 +129,12 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, } Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g], - filter_desc_, weight + weight_offset_ * g, - top_descs_[i], top_diff + top_offset_ * g, - conv_descs_[i], - bottom_descs_[i], bottom_diff + bottom_offset_ * g, - CUDNN_RESULT_NO_ACCUMULATE)); + cudnn::dataType::one, + filter_desc_, weight + weight_offset_ * g, + top_descs_[i], top_diff + top_offset_ * g, + conv_descs_[i], + cudnn::dataType::zero, + bottom_descs_[i], bottom_diff + bottom_offset_ * g)); } } diff --git a/src/caffe/layers/cudnn_pooling_layer.cpp b/src/caffe/layers/cudnn_pooling_layer.cpp index dd90195637b..c92c4e477b5 100644 --- a/src/caffe/layers/cudnn_pooling_layer.cpp +++ b/src/caffe/layers/cudnn_pooling_layer.cpp @@ -13,15 +13,13 @@ template void CuDNNPoolingLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) { PoolingLayer::LayerSetUp(bottom, top); - // Sanity check: CUDNN currently only supports pad == 0. - CHECK_EQ(this->pad_h_, 0); - CHECK_EQ(this->pad_w_, 0); CUDNN_CHECK(cudnnCreate(&handle_)); cudnn::createTensor4dDesc(&bottom_desc_); cudnn::createTensor4dDesc(&top_desc_); cudnn::createPoolingDesc(&pooling_desc_, this->layer_param_.pooling_param().pool(), &mode_, - this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_); + this->kernel_h_, this->kernel_w_, this->pad_h_, this->pad_w_, + this->stride_h_, this->stride_w_); handles_setup_ = true; } @@ -40,8 +38,8 @@ CuDNNPoolingLayer::~CuDNNPoolingLayer() { // Check that handles have been setup before destroying. 
if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(bottom_desc_); - cudnnDestroyTensor4dDescriptor(top_desc_); + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); cudnnDestroyPoolingDescriptor(pooling_desc_); cudnnDestroy(handle_); } diff --git a/src/caffe/layers/cudnn_pooling_layer.cu b/src/caffe/layers/cudnn_pooling_layer.cu index 1c113aad75f..a952b855a48 100644 --- a/src/caffe/layers/cudnn_pooling_layer.cu +++ b/src/caffe/layers/cudnn_pooling_layer.cu @@ -15,7 +15,10 @@ void CuDNNPoolingLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_, - bottom_desc_, bottom_data, top_desc_, top_data)); + cudnn::dataType::one, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + top_desc_, top_data)); } template @@ -29,8 +32,11 @@ void CuDNNPoolingLayer::Backward_gpu(const vector*>& top, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_, - top_desc_, top_data, top_desc_, top_diff, - bottom_desc_, bottom_data, bottom_desc_, bottom_diff)); + cudnn::dataType::one, + top_desc_, top_data, top_desc_, top_diff, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + bottom_desc_, bottom_diff)); } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNPoolingLayer); diff --git a/src/caffe/layers/cudnn_relu_layer.cpp b/src/caffe/layers/cudnn_relu_layer.cpp index 0b8a6bc3248..759d83984ef 100644 --- a/src/caffe/layers/cudnn_relu_layer.cpp +++ b/src/caffe/layers/cudnn_relu_layer.cpp @@ -35,8 +35,8 @@ CuDNNReLULayer::~CuDNNReLULayer() { // Check that handles have been setup before destroying. 
if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(this->bottom_desc_); - cudnnDestroyTensor4dDescriptor(this->top_desc_); + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); cudnnDestroy(this->handle_); } diff --git a/src/caffe/layers/cudnn_relu_layer.cu b/src/caffe/layers/cudnn_relu_layer.cu index 862508707a0..21d14857dd2 100644 --- a/src/caffe/layers/cudnn_relu_layer.cu +++ b/src/caffe/layers/cudnn_relu_layer.cu @@ -18,8 +18,11 @@ void CuDNNReLULayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); CUDNN_CHECK(cudnnActivationForward(this->handle_, - CUDNN_ACTIVATION_RELU, - this->bottom_desc_, bottom_data, this->top_desc_, top_data)); + CUDNN_ACTIVATION_RELU, + cudnn::dataType::one, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->top_desc_, top_data)); } template @@ -40,9 +43,12 @@ void CuDNNReLULayer::Backward_gpu(const vector*>& top, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); CUDNN_CHECK(cudnnActivationBackward(this->handle_, - CUDNN_ACTIVATION_RELU, - this->top_desc_, top_data, this->top_desc_, top_diff, - this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff)); + CUDNN_ACTIVATION_RELU, + cudnn::dataType::one, + this->top_desc_, top_data, this->top_desc_, top_diff, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->bottom_desc_, bottom_diff)); } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer); diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cpp b/src/caffe/layers/cudnn_sigmoid_layer.cpp index 67bd9c373b0..32637873d46 100644 --- a/src/caffe/layers/cudnn_sigmoid_layer.cpp +++ b/src/caffe/layers/cudnn_sigmoid_layer.cpp @@ -35,8 +35,8 @@ CuDNNSigmoidLayer::~CuDNNSigmoidLayer() { // Check that handles have been setup before destroying. 
if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(this->bottom_desc_); - cudnnDestroyTensor4dDescriptor(this->top_desc_); + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); cudnnDestroy(this->handle_); } diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cu b/src/caffe/layers/cudnn_sigmoid_layer.cu index 31b094e25d4..7a06cf721da 100644 --- a/src/caffe/layers/cudnn_sigmoid_layer.cu +++ b/src/caffe/layers/cudnn_sigmoid_layer.cu @@ -13,8 +13,11 @@ void CuDNNSigmoidLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); CUDNN_CHECK(cudnnActivationForward(this->handle_, - CUDNN_ACTIVATION_SIGMOID, - this->bottom_desc_, bottom_data, this->top_desc_, top_data)); + CUDNN_ACTIVATION_SIGMOID, + cudnn::dataType::one, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->top_desc_, top_data)); } template @@ -30,9 +33,12 @@ void CuDNNSigmoidLayer::Backward_gpu(const vector*>& top, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); CUDNN_CHECK(cudnnActivationBackward(this->handle_, - CUDNN_ACTIVATION_SIGMOID, - this->top_desc_, top_data, this->top_desc_, top_diff, - this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff)); + CUDNN_ACTIVATION_SIGMOID, + cudnn::dataType::one, + this->top_desc_, top_data, this->top_desc_, top_diff, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->bottom_desc_, bottom_diff)); } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSigmoidLayer); diff --git a/src/caffe/layers/cudnn_softmax_layer.cpp b/src/caffe/layers/cudnn_softmax_layer.cpp index 211701cad49..77a3225adcd 100644 --- a/src/caffe/layers/cudnn_softmax_layer.cpp +++ b/src/caffe/layers/cudnn_softmax_layer.cpp @@ -39,8 +39,8 @@ CuDNNSoftmaxLayer::~CuDNNSoftmaxLayer() { // Check that handles have been setup before destroying. 
if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(bottom_desc_); - cudnnDestroyTensor4dDescriptor(top_desc_); + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); cudnnDestroy(handle_); } diff --git a/src/caffe/layers/cudnn_softmax_layer.cu b/src/caffe/layers/cudnn_softmax_layer.cu index f328afdd831..a9e2fcefaf7 100644 --- a/src/caffe/layers/cudnn_softmax_layer.cu +++ b/src/caffe/layers/cudnn_softmax_layer.cu @@ -17,8 +17,11 @@ void CuDNNSoftmaxLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); CUDNN_CHECK(cudnnSoftmaxForward(handle_, CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - bottom_desc_, bottom_data, top_desc_, top_data)); + CUDNN_SOFTMAX_MODE_CHANNEL, + cudnn::dataType::one, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + top_desc_, top_data)); } template @@ -29,9 +32,13 @@ void CuDNNSoftmaxLayer::Backward_gpu(const vector*>& top, const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + CUDNN_CHECK(cudnnSoftmaxBackward(handle_, CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - top_desc_, top_data, top_desc_, top_diff, bottom_desc_, bottom_diff)); + CUDNN_SOFTMAX_MODE_CHANNEL, + cudnn::dataType::one, + top_desc_, top_data, top_desc_, top_diff, + cudnn::dataType::zero, + bottom_desc_, bottom_diff)); } } diff --git a/src/caffe/layers/cudnn_tanh_layer.cpp b/src/caffe/layers/cudnn_tanh_layer.cpp index b1d2b86384e..376faad324d 100644 --- a/src/caffe/layers/cudnn_tanh_layer.cpp +++ b/src/caffe/layers/cudnn_tanh_layer.cpp @@ -35,8 +35,8 @@ CuDNNTanHLayer::~CuDNNTanHLayer() { // Check that handles have been setup before destroying. 
if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(this->bottom_desc_); - cudnnDestroyTensor4dDescriptor(this->top_desc_); + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); cudnnDestroy(this->handle_); } diff --git a/src/caffe/layers/cudnn_tanh_layer.cu b/src/caffe/layers/cudnn_tanh_layer.cu index bf9ec7cfac4..d287f6fee85 100644 --- a/src/caffe/layers/cudnn_tanh_layer.cu +++ b/src/caffe/layers/cudnn_tanh_layer.cu @@ -13,8 +13,11 @@ void CuDNNTanHLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); CUDNN_CHECK(cudnnActivationForward(this->handle_, - CUDNN_ACTIVATION_TANH, - this->bottom_desc_, bottom_data, this->top_desc_, top_data)); + CUDNN_ACTIVATION_TANH, + cudnn::dataType::one, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->top_desc_, top_data)); } template @@ -29,10 +32,14 @@ void CuDNNTanHLayer::Backward_gpu(const vector*>& top, const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + CUDNN_CHECK(cudnnActivationBackward(this->handle_, - CUDNN_ACTIVATION_TANH, - this->top_desc_, top_data, this->top_desc_, top_diff, - this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff)); + CUDNN_ACTIVATION_TANH, + cudnn::dataType::one, + this->top_desc_, top_data, this->top_desc_, top_diff, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->bottom_desc_, bottom_diff)); } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNTanHLayer); diff --git a/src/caffe/test/test_pooling_layer.cpp b/src/caffe/test/test_pooling_layer.cpp index 435caa8381e..e9964e7f0b7 100644 --- a/src/caffe/test/test_pooling_layer.cpp +++ b/src/caffe/test/test_pooling_layer.cpp @@ -976,9 +976,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestSetupCuDNN) { EXPECT_EQ(this->blob_top_->width(), 2); } -// This test and all following cuDNN pooling 
tests with padding are commented -// for now, since cuDNN pooling does not currently support padding. -/* TYPED_TEST(CuDNNPoolingLayerTest, TestSetupPaddedCuDNN) { Caffe::set_mode(Caffe::GPU); LayerParameter layer_param; @@ -994,7 +991,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestSetupPaddedCuDNN) { EXPECT_EQ(this->blob_top_->height(), 4); EXPECT_EQ(this->blob_top_->width(), 3); } -*/ /* TYPED_TEST(CuDNNPoolingLayerTest, PrintBackwardCuDNN) { @@ -1062,7 +1058,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestGradientMaxCuDNN) { } } -/* TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxPaddedCuDNN) { Caffe::set_mode(Caffe::GPU); LayerParameter layer_param; @@ -1107,7 +1102,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxPaddedCuDNN) { EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4, epsilon); EXPECT_NEAR(this->blob_top_->cpu_data()[8], 1, epsilon); } -*/ /* TYPED_TEST(CuDNNPoolingLayerTest, TestGradientMaxTopMaskCuDNN) { @@ -1175,7 +1169,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAveCuDNN) { } } -/* TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAvePaddedCuDNN) { Caffe::set_mode(Caffe::GPU); for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { @@ -1194,7 +1187,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAvePaddedCuDNN) { } } } -*/ #endif diff --git a/src/caffe/util/cudnn.cpp b/src/caffe/util/cudnn.cpp new file mode 100644 index 00000000000..1772f0099ce --- /dev/null +++ b/src/caffe/util/cudnn.cpp @@ -0,0 +1,23 @@ +#ifdef USE_CUDNN +#include "caffe/util/cudnn.hpp" + +namespace caffe { +namespace cudnn { + +float dataType::oneval = 1.0; +float dataType::zeroval = 0.0; +const void* dataType::one = + static_cast(&dataType::oneval); +const void* dataType::zero = + static_cast(&dataType::zeroval); + +double dataType::oneval = 1.0; +double dataType::zeroval = 0.0; +const void* dataType::one = + static_cast(&dataType::oneval); +const void* dataType::zero = + static_cast(&dataType::zeroval); + +} // namespace cudnn +} // namespace caffe +#endif