Skip to content
Snippets Groups Projects
Commit 7210dc10 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Make tensorboard metrics lighter, group output file config flags.

parent d9dd119c
No related branches found
No related tags found
No related merge requests found
......@@ -46,6 +46,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
fields = list(parser.DEFAULT_FIELDS)
fields[1] = 'token' # use 'token' instead of 'form'
field_parsers = parser.DEFAULT_FIELD_PARSERS
# Do not make it nullable
field_parsers.pop('xpostag', None)
if self._use_sem:
fields = list(fields)
fields.append('semrel')
......
......@@ -25,6 +25,8 @@ flags.DEFINE_enum(name='mode', default=None, enum_values=['train', 'predict'],
# Common flags
flags.DEFINE_integer(name='cuda_device', default=-1,
help="Cuda device id (default -1 cpu)")
flags.DEFINE_string(name='output_file', default='output.log',
help='Predictions result file.')
# Training flags
flags.DEFINE_list(name='training_data_path', default="./tests/fixtures/example.conllu",
......@@ -60,8 +62,6 @@ flags.DEFINE_string(name='config_path', default='config.template.jsonnet',
help='Config file path.')
# Test after training flags
flags.DEFINE_string(name='result', default='result.conllu',
help='Test result path file')
flags.DEFINE_string(name='test_path', default=None,
help='Test path file.')
......@@ -74,8 +74,6 @@ flags.DEFINE_string(name='model_path', default=None,
help='Pretrained model path.')
flags.DEFINE_string(name='input_file', default=None,
help='File to predict path')
flags.DEFINE_string(name='output_file', default='output.log',
help='Predictions result file.')
flags.DEFINE_integer(name='batch_size', default=1,
help='Prediction batch size.')
flags.DEFINE_boolean(name='silent', default=True,
......@@ -126,7 +124,7 @@ def run(_):
logger.info(f'Finetuned model stored in: {serialization_dir}')
if FLAGS.test_path and FLAGS.result:
if FLAGS.test_path and FLAGS.output_file:
checks.file_exists(FLAGS.test_path)
params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())['dataset_reader']
params.pop('type')
......@@ -137,7 +135,7 @@ def run(_):
)
test_path = FLAGS.test_path
test_trees = dataset_reader.read(test_path)
with open(FLAGS.result, 'w') as f:
with open(FLAGS.output_file, 'w') as f:
for tree in test_trees:
f.writelines(predictor.predict_instance_as_tree(tree).serialize())
else:
......
......@@ -102,7 +102,8 @@ class FeedForwardPredictor(Predictor):
"len(hidden_dims) (%d) + 1 != num_layers (%d)" % (len(hidden_dims), num_layers)
)
assert vocab_namespace in vocab.get_namespaces()
assert vocab_namespace in vocab.get_namespaces(),\
f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
return cls(feedforward.FeedForward(
......
......@@ -364,8 +364,9 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
},
tensorboard_writer: {
serialization_dir: metrics_dir,
should_log_learning_rate: true,
summary_interval: 2,
should_log_learning_rate: false,
should_log_parameter_statistics: false,
summary_interval: 100,
},
validation_metric: "+EM",
},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment