Skip to content
Snippets Groups Projects
Commit 7210dc10 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Make tensorboard metrics lighter, group output file config flags.

parent d9dd119c
Branches
Tags
No related merge requests found
...@@ -46,6 +46,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -46,6 +46,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
fields = list(parser.DEFAULT_FIELDS) fields = list(parser.DEFAULT_FIELDS)
fields[1] = 'token' # use 'token' instead of 'form' fields[1] = 'token' # use 'token' instead of 'form'
field_parsers = parser.DEFAULT_FIELD_PARSERS field_parsers = parser.DEFAULT_FIELD_PARSERS
# Do not make it nullable
field_parsers.pop('xpostag', None)
if self._use_sem: if self._use_sem:
fields = list(fields) fields = list(fields)
fields.append('semrel') fields.append('semrel')
......
...@@ -25,6 +25,8 @@ flags.DEFINE_enum(name='mode', default=None, enum_values=['train', 'predict'], ...@@ -25,6 +25,8 @@ flags.DEFINE_enum(name='mode', default=None, enum_values=['train', 'predict'],
# Common flags # Common flags
flags.DEFINE_integer(name='cuda_device', default=-1, flags.DEFINE_integer(name='cuda_device', default=-1,
help="Cuda device id (default -1 cpu)") help="Cuda device id (default -1 cpu)")
flags.DEFINE_string(name='output_file', default='output.log',
help='Predictions result file.')
# Training flags # Training flags
flags.DEFINE_list(name='training_data_path', default="./tests/fixtures/example.conllu", flags.DEFINE_list(name='training_data_path', default="./tests/fixtures/example.conllu",
...@@ -60,8 +62,6 @@ flags.DEFINE_string(name='config_path', default='config.template.jsonnet', ...@@ -60,8 +62,6 @@ flags.DEFINE_string(name='config_path', default='config.template.jsonnet',
help='Config file path.') help='Config file path.')
# Test after training flags # Test after training flags
flags.DEFINE_string(name='result', default='result.conllu',
help='Test result path file')
flags.DEFINE_string(name='test_path', default=None, flags.DEFINE_string(name='test_path', default=None,
help='Test path file.') help='Test path file.')
...@@ -74,8 +74,6 @@ flags.DEFINE_string(name='model_path', default=None, ...@@ -74,8 +74,6 @@ flags.DEFINE_string(name='model_path', default=None,
help='Pretrained model path.') help='Pretrained model path.')
flags.DEFINE_string(name='input_file', default=None, flags.DEFINE_string(name='input_file', default=None,
help='File to predict path') help='File to predict path')
flags.DEFINE_string(name='output_file', default='output.log',
help='Predictions result file.')
flags.DEFINE_integer(name='batch_size', default=1, flags.DEFINE_integer(name='batch_size', default=1,
help='Prediction batch size.') help='Prediction batch size.')
flags.DEFINE_boolean(name='silent', default=True, flags.DEFINE_boolean(name='silent', default=True,
...@@ -126,7 +124,7 @@ def run(_): ...@@ -126,7 +124,7 @@ def run(_):
logger.info(f'Finetuned model stored in: {serialization_dir}') logger.info(f'Finetuned model stored in: {serialization_dir}')
if FLAGS.test_path and FLAGS.result: if FLAGS.test_path and FLAGS.output_file:
checks.file_exists(FLAGS.test_path) checks.file_exists(FLAGS.test_path)
params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())['dataset_reader'] params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())['dataset_reader']
params.pop('type') params.pop('type')
...@@ -137,7 +135,7 @@ def run(_): ...@@ -137,7 +135,7 @@ def run(_):
) )
test_path = FLAGS.test_path test_path = FLAGS.test_path
test_trees = dataset_reader.read(test_path) test_trees = dataset_reader.read(test_path)
with open(FLAGS.result, 'w') as f: with open(FLAGS.output_file, 'w') as f:
for tree in test_trees: for tree in test_trees:
f.writelines(predictor.predict_instance_as_tree(tree).serialize()) f.writelines(predictor.predict_instance_as_tree(tree).serialize())
else: else:
......
...@@ -102,7 +102,8 @@ class FeedForwardPredictor(Predictor): ...@@ -102,7 +102,8 @@ class FeedForwardPredictor(Predictor):
"len(hidden_dims) (%d) + 1 != num_layers (%d)" % (len(hidden_dims), num_layers) "len(hidden_dims) (%d) + 1 != num_layers (%d)" % (len(hidden_dims), num_layers)
) )
assert vocab_namespace in vocab.get_namespaces() assert vocab_namespace in vocab.get_namespaces(),\
f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)] hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
return cls(feedforward.FeedForward( return cls(feedforward.FeedForward(
......
...@@ -364,8 +364,9 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't ...@@ -364,8 +364,9 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
}, },
tensorboard_writer: { tensorboard_writer: {
serialization_dir: metrics_dir, serialization_dir: metrics_dir,
should_log_learning_rate: true, should_log_learning_rate: false,
summary_interval: 2, should_log_parameter_statistics: false,
summary_interval: 100,
}, },
validation_metric: "+EM", validation_metric: "+EM",
}, },
......
0% Loading or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment