Commit e1bfa0c7 authored by Maja Jablonska

Add testing

parent d80a60c1
1 merge request: !46 Merge COMBO 3.0 into master
@@ -54,21 +54,21 @@ class Sentence:
         return len(self.tokens)
 
-class _TokenList(conllu.models.TokenList):
+class _TokenList(conllu.TokenList):
     @overrides
     def __repr__(self):
-        return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
+        return 'TokenList<' + ', '.join(token['text'] for token in self) + '>'
 
-def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.models.TokenList:
+def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
     tokens = []
     for token in sentence.tokens:
-        token_dict = collections.OrderedDict(dataclasses.asdict(token))
+        token_dict = collections.OrderedDict(token.as_dict(keep_semrel))
         # Remove semrel to have default conllu format.
-        if not keep_semrel:
-            del token_dict["semrel"]
-        del token_dict["embeddings"]
+        # if not keep_semrel:
+        #     del token_dict["semrel"]
+        # del token_dict["embeddings"]
         tokens.append(token_dict)
     # Range tokens must be tuple not list, this is conllu library requirement
     for t in tokens:
...
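The new code path delegates per-token serialization to Token.as_dict (added in the next file), so the semrel and embeddings fields no longer have to be stripped from a dataclasses.asdict dump. A toy illustration of the difference, written for this note rather than taken from the repository:

    import dataclasses

    @dataclasses.dataclass
    class ToyToken:
        # Stand-in for combo's Token; only a few fields, purely illustrative.
        text: str
        head: int
        semrel: str = "_"
        embeddings: dict = dataclasses.field(default_factory=dict)

        def as_dict(self, semrel: bool = True):
            out = {"text": self.text, "head": self.head}
            if semrel:
                out["semrel"] = self.semrel
            return out

    t = ToyToken(text="kot", head=0)
    print(dataclasses.asdict(t))    # {'text': 'kot', 'head': 0, 'semrel': '_', 'embeddings': {}}
    print(t.as_dict(semrel=False))  # {'text': 'kot', 'head': 0}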
@@ -2,7 +2,7 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/token_class.py
 """
+from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Union
 import logging
 from dataclasses import dataclass, field
@@ -95,6 +95,20 @@ class Token:
         self.text_id = text_id
         self.type_id = type_id
 
+    def as_dict(self, semrel: bool = True) -> Dict[str, Any]:
+        repr = {}
+        repr_keys = [
+            'text', 'idx', 'lemma', 'upostag', 'xpostag', 'entity_type', 'feats',
+            'head', 'deprel', 'deps', 'misc'
+        ]
+        for rk in repr_keys:
+            repr[rk] = self.__getattribute__(rk)
+        if semrel:
+            repr['semrel'] = self.semrel
+        return repr
+
     def __str__(self):
         return self.text
...
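For reference, the body of as_dict amounts to the following stand-alone sketch (getattr in place of __getattribute__; token_as_dict and CONLLU_KEYS are illustrative names, not identifiers from the repository):

    from typing import Any, Dict

    # Columns collected for every token; mirrors the repr_keys list in the diff above.
    CONLLU_KEYS = ['text', 'idx', 'lemma', 'upostag', 'xpostag', 'entity_type', 'feats',
                   'head', 'deprel', 'deps', 'misc']

    def token_as_dict(token: Any, semrel: bool = True) -> Dict[str, Any]:
        result = {key: getattr(token, key) for key in CONLLU_KEYS}
        if semrel:
            # The semantic-relation column is optional so plain CoNLL-U output stays valid.
            result['semrel'] = token.semrel
        return result

Since Python 3.7, dicts preserve insertion order, so the columns come out in CoNLL-U column order; sentence2conllu additionally wraps the result in collections.OrderedDict.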
 import json
 import logging
+import os
 import pathlib
 import tempfile
 from typing import Dict
@@ -7,6 +8,7 @@ from typing import Dict
 import torch
 from absl import app, flags
 import pytorch_lightning as pl
+from tqdm import tqdm
 
 from combo.training.trainable_combo import TrainableCombo
 from combo.utils import checks, ComboLogger
@@ -15,6 +17,7 @@ from combo.config import resolve
 from combo.default_model import default_ud_dataset_reader, default_data_loader
 from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
+from combo.data import api
 
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
@@ -65,8 +68,9 @@ flags.DEFINE_string(name="finetuning_validation_data_path", default="",
                     help="Validation data path(s)")
 
 # Test after training flags
-flags.DEFINE_string(name="test_path", default=None,
+flags.DEFINE_string(name="test_data_path", default=None,
                     help="Test path file.")
+flags.DEFINE_alias(name="test_data", original_name="test_data_path")
 
 # Experimental
 flags.DEFINE_boolean(name="use_pure_config", default=False,
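The test flag is renamed from test_path to test_data_path, with test_data kept as an alias. A minimal stand-alone sketch of how the absl alias behaves (script and values are made up, not repository code):

    from absl import app, flags

    FLAGS = flags.FLAGS
    flags.DEFINE_string(name="test_data_path", default=None, help="Test path file.")
    flags.DEFINE_alias(name="test_data", original_name="test_data_path")

    def main(_):
        # Either --test_data_path=... or --test_data=... ends up here.
        print(FLAGS.test_data_path)

    if __name__ == "__main__":
        app.run(main)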
@@ -111,12 +115,16 @@ def run(_):
         serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
 
+        params['vocabulary']['parameters']['directory'] = os.path.join('/'.join(FLAGS.config_path.split('/')[:-1]),
+                                                                       params['vocabulary']['parameters']['directory'])
+
         try:
             vocabulary = resolve(params['vocabulary'])
         except KeyError:
             logger.error('No vocabulary in config.json!')
             return
 
         model = resolve(params['model'], pass_down_parameters={'vocabulary': vocabulary})
 
         dataset_reader = None
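The added lines re-root the vocabulary directory from config.json so it is resolved relative to the config file itself rather than the current working directory. A stand-alone sketch with made-up values:

    import os

    config_path = "model_dir/config.json"    # stands in for FLAGS.config_path
    vocab_dir = "vocabulary"                 # value read from the config

    resolved = os.path.join('/'.join(config_path.split('/')[:-1]), vocab_dir)
    print(resolved)  # model_dir/vocabulary
    # os.path.join(os.path.dirname(config_path), vocab_dir) gives the same result.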
@@ -184,9 +192,23 @@ def run(_):
                              gradient_clip_val=5)
         trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
-        logger.info(f'Archiving the fine-tuned model in {serialization_dir}', prefix=prefix)
+        logger.info(f'Archiving the model in {serialization_dir}', prefix=prefix)
         archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader)
-        logger.info(f"Training model stored in: {serialization_dir}", prefix=prefix)
+        logger.info(f"Model stored in: {serialization_dir}", prefix=prefix)
+
+        if FLAGS.test_data_path and FLAGS.output_file:
+            checks.file_exists(FLAGS.test_data_path)
+            if not dataset_reader:
+                logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader",
+                            prefix=prefix)
+                dataset_reader = default_ud_dataset_reader()
+            logger.info("Predicting test examples", prefix=prefix)
+            test_trees = dataset_reader.read(FLAGS.test_data_path)
+            predictor = COMBO(model, dataset_reader)
+            with open(FLAGS.output_file, "w") as file:
+                for tree in tqdm(test_trees):
+                    file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                        keep_semrel=dataset_reader.use_sem).serialize())
 
     elif FLAGS.mode == 'predict':
         predictor = get_predictor()
...
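With this branch in place, a single training run can also write test-set predictions in CoNLL-U format when both --test_data_path (or its --test_data alias) and --output_file are given. An illustrative invocation, assuming the script is launched as main.py and that the training branch is selected with --mode train (paths are placeholders; only flags referenced in this file are used):

    python main.py --mode train --config_path model_dir/config.json \
        --serialization_dir ./serialization --test_data_path test.conllu \
        --output_file predictions.conllu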