From 1c49130c33ebdec042ff6da18d03b7c5f6ad8c93 Mon Sep 17 00:00:00 2001 From: ZhouYzzz Date: Fri, 15 Apr 2016 22:51:49 +0800 Subject: [PATCH 1/2] Allow the python layer have attribute "phase" --- include/caffe/layers/python_layer.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/caffe/layers/python_layer.hpp b/include/caffe/layers/python_layer.hpp index b839d52684e..66dbbdf13b8 100644 --- a/include/caffe/layers/python_layer.hpp +++ b/include/caffe/layers/python_layer.hpp @@ -26,6 +26,7 @@ class PythonLayer : public Layer { } self_.attr("param_str") = bp::str( this->layer_param_.python_param().param_str()); + self_.attr("phase") = static_cast(this->phase_); self_.attr("setup")(bottom, top); } virtual void Reshape(const vector*>& bottom, From c2dba923b82c669f2998a3174310fbbb5c64c39f Mon Sep 17 00:00:00 2001 From: ZhouYzzz Date: Wed, 4 May 2016 18:00:12 +0800 Subject: [PATCH 2/2] Add test for attribute "phase" in python layer --- python/caffe/test/test_python_layer.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/python/caffe/test/test_python_layer.py b/python/caffe/test/test_python_layer.py index e46b7118014..899514e90f1 100644 --- a/python/caffe/test/test_python_layer.py +++ b/python/caffe/test/test_python_layer.py @@ -44,6 +44,18 @@ def forward(self, bottom, top): def backward(self, top, propagate_down, bottom): self.blobs[0].diff[0] = 1 +class PhaseLayer(caffe.Layer): + """A layer for checking attribute `phase`""" + + def setup(self, bottom, top): + pass + + def reshape(self, bootom, top): + top[0].reshape() + + def forward(self, bottom, top): + top[0].data[()] = self.phase + def python_net_file(): with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f: f.write("""name: 'pythonnet' force_backward: true @@ -76,6 +88,14 @@ def parameter_net_file(): """) return f.name +def phase_net_file(): + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f: + f.write("""name: 'pythonnet' force_backward: true + layer { type: 'Python' name: 'layer' top: 'phase' + python_param { module: 'test_python_layer' layer: 'PhaseLayer' } } + """) + return f.name + @unittest.skipIf('Python' not in caffe.layer_type_list(), 'Caffe built without Python layer support') @@ -140,3 +160,9 @@ def test_parameter(self): self.assertEqual(layer.blobs[0].data[0], 1) os.remove(net_file) + + def test_phase(self): + net_file = phase_net_file() + for phase in caffe.TRAIN, caffe.TEST: + net = caffe.Net(net_file, phase) + self.assertEqual(net.forward()['phase'], phase)