diff --git a/tools/extra/parse_log.py b/tools/extra/parse_log.py index 4248e2b87a3..81479b3bf34 100755 --- a/tools/extra/parse_log.py +++ b/tools/extra/parse_log.py @@ -23,12 +23,14 @@ def parse_log(path_to_log): """ regex_iteration = re.compile('Iteration (\d+)') + regex_netnum = re.compile('net \(#(\d+)\)') regex_train_output = re.compile('Train net output #(\d+): (\S+) = ([\.\deE+-]+)') regex_test_output = re.compile('Test net output #(\d+): (\S+) = ([\.\deE+-]+)') regex_learning_rate = re.compile('lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)') # Pick out lines of interest iteration = -1 + netnum = 0 learning_rate = float('NaN') train_dict_list = [] test_dict_list = [] @@ -44,6 +46,9 @@ def parse_log(path_to_log): iteration_match = regex_iteration.search(line) if iteration_match: iteration = float(iteration_match.group(1)) + netnum_match = regex_netnum.search(line) + if netnum_match: + netnum = int(netnum_match.group(1)) if iteration == -1: # Only start parsing for other stuff if we've found the first # iteration @@ -70,11 +75,11 @@ def parse_log(path_to_log): train_dict_list, train_row = parse_line_for_net_output( regex_train_output, train_row, train_dict_list, - line, iteration, seconds, learning_rate + line, iteration, netnum, seconds, learning_rate ) test_dict_list, test_row = parse_line_for_net_output( regex_test_output, test_row, test_dict_list, - line, iteration, seconds, learning_rate + line, iteration, netnum, seconds, learning_rate ) fix_initial_nan_learning_rate(train_dict_list) @@ -84,7 +89,7 @@ def parse_log(path_to_log): def parse_line_for_net_output(regex_obj, row, row_dict_list, - line, iteration, seconds, learning_rate): + line, iteration, netnum, seconds, learning_rate): """Parse a single line for training or test output Returns a a tuple with (row_dict_list, row) @@ -95,7 +100,7 @@ def parse_line_for_net_output(regex_obj, row, row_dict_list, output_match = regex_obj.search(line) if output_match: - if not row or row['NumIters'] != iteration: + if not row or row['NumIters'] != iteration or row['NetNum'] != netnum : # Push the last row and start a new one if row: # If we're on a new iteration, push the last row @@ -106,6 +111,7 @@ def parse_line_for_net_output(regex_obj, row, row_dict_list, row = OrderedDict([ ('NumIters', iteration), + ('NetNum', netnum), ('Seconds', seconds), ('LearningRate', learning_rate) ])