diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py index 1b4814a45c6..31cde7ad946 100644 --- a/python/caffe/net_spec.py +++ b/python/caffe/net_spec.py @@ -18,7 +18,7 @@ class -- assign to its attributes directly to name layers, and call are not guaranteed to be forward-compatible. """ -from collections import OrderedDict +from collections import OrderedDict, Counter from .proto import caffe_pb2 from google import protobuf @@ -44,10 +44,8 @@ def to_proto(*tops): """Generate a NetParameter that contains all layers needed to compute all arguments.""" - if not isinstance(tops, tuple): - tops = (tops,) layers = OrderedDict() - autonames = {} + autonames = Counter() for top in tops: top.fn._to_proto(layers, {}, autonames) net = caffe_pb2.NetParameter() @@ -89,6 +87,9 @@ def to_proto(self): return to_proto(self) + def _to_proto(self, layers, names, autonames): + return self.fn._to_proto(layers, names, autonames) + class Function(object): """A Function specifies a layer, its parameters, and its inputs (which @@ -107,11 +108,18 @@ def __init__(self, type_name, inputs, params): del self.params['in_place'] self.tops = tuple(Top(self, n) for n in range(self.ntop)) - def _get_name(self, top, names, autonames): + def _get_name(self, names, autonames): + if self not in names and self.ntop > 0: + names[self] = self._get_top_name(self.tops[0], names, autonames) + elif self not in names: + autonames[self.type_name] += 1 + names[self] = self.type_name + str(autonames[self.type_name]) + return names[self] + + def _get_top_name(self, top, names, autonames): if top not in names: - n = autonames.setdefault(top.fn.type_name, 1) autonames[top.fn.type_name] += 1 - names[top] = top.fn.type_name + str(n) + names[top] = top.fn.type_name + str(autonames[top.fn.type_name]) return names[top] def _to_proto(self, layers, names, autonames): @@ -119,7 +127,7 @@ def _to_proto(self, layers, names, autonames): return bottom_names = [] for inp in self.inputs: - inp.fn._to_proto(layers, names, autonames) + inp._to_proto(layers, names, autonames) bottom_names.append(layers[inp.fn].top[inp.n]) layer = caffe_pb2.LayerParameter() layer.type = self.type_name @@ -129,8 +137,8 @@ def _to_proto(self, layers, names, autonames): layer.top.extend(layer.bottom) else: for top in self.tops: - layer.top.append(self._get_name(top, names, autonames)) - layer.name = self._get_name(self.tops[0], names, autonames) + layer.top.append(self._get_top_name(top, names, autonames)) + layer.name = self._get_name(names, autonames) for k, v in six.iteritems(self.params): # special case to handle generic *params @@ -163,10 +171,10 @@ def __getattr__(self, name): def to_proto(self): names = {v: k for k, v in six.iteritems(self.tops)} - autonames = {} + autonames = Counter() layers = OrderedDict() for name, top in six.iteritems(self.tops): - top.fn._to_proto(layers, names, autonames) + top._to_proto(layers, names, autonames) net = caffe_pb2.NetParameter() net.layer.extend(layers.values()) return net @@ -180,7 +188,9 @@ class Layers(object): def __getattr__(self, name): def layer_fn(*args, **kwargs): fn = Function(name, args, kwargs) - if fn.ntop == 1: + if fn.ntop == 0: + return fn + elif fn.ntop == 1: return fn.tops[0] else: return fn.tops diff --git a/python/caffe/test/test_net_spec.py b/python/caffe/test/test_net_spec.py index 65b73b96f73..b344946932a 100644 --- a/python/caffe/test/test_net_spec.py +++ b/python/caffe/test/test_net_spec.py @@ -41,6 +41,14 @@ def anon_lenet(batch_size): loss = L.SoftmaxWithLoss(ip2, label) return loss.to_proto() +def silent_net(): + n = caffe.NetSpec() + n.data, n.data2 = L.DummyData(shape=[dict(dim=[3]), dict(dim=[4, 2])], + ntop=2) + n.silence_data = L.Silence(n.data, ntop=0) + n.silence_data2 = L.Silence(n.data2, ntop=0) + return n.to_proto() + class TestNetSpec(unittest.TestCase): def load_net(self, net_proto): f = tempfile.NamedTemporaryFile(delete=False) @@ -65,3 +73,10 @@ def test_lenet(self): net_proto.layer[6].top) net = self.load_net(net_proto) self.assertEqual(len(net.layers), 9) + + def test_zero_tops(self): + """Test net construction for top-less layers.""" + + net_proto = silent_net() + net = self.load_net(net_proto) + self.assertEqual(len(net.forward()), 0)