From 846f2c3cce8b937637e0b46a7f62be068b835ade Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Fri, 28 Aug 2015 21:27:11 -0700 Subject: [PATCH] fix GPU data race Previously, the prefetch GPU -> top GPU and prefetch CPU -> prefetch GPU copies were launched concurrently in separate streams, allowing the next batch to be copied in before the current one is read. This patch explicitly synchronizes the prefetch -> top copy wrt the host, preventing the CPU -> GPU from being launched until its completion. --- src/caffe/layers/base_data_layer.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu index 56439bc506a..ff6e412aba6 100644 --- a/src/caffe/layers/base_data_layer.cu +++ b/src/caffe/layers/base_data_layer.cu @@ -20,7 +20,9 @@ void BasePrefetchingDataLayer::Forward_gpu( caffe_copy(batch->label_.count(), batch->label_.gpu_data(), top[1]->mutable_gpu_data()); } - + // Ensure the copy is synchronous wrt the host, so that the next batch isn't + // copied in meanwhile. + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); prefetch_free_.push(batch); }