diff --git a/combo/data/api.py b/combo/data/api.py
index 6233cc000a94342aad0fe97e38e1c6ec703d2a9d..82cca44f19fe8fb57a5cf7b16f336036fbf6a135 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -54,21 +54,21 @@ class Sentence:
         return len(self.tokens)
 
 
-class _TokenList(conllu.models.TokenList):
+class _TokenList(conllu.TokenList):
 
     @overrides
     def __repr__(self):
-        return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
+        return 'TokenList<' + ', '.join(token['text'] for token in self) + '>'
 
 
-def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.models.TokenList:
+def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
     tokens = []
     for token in sentence.tokens:
-        token_dict = collections.OrderedDict(dataclasses.asdict(token))
+        token_dict = collections.OrderedDict(token.as_dict(keep_semrel))
         # Remove semrel to have default conllu format.
-        if not keep_semrel:
-            del token_dict["semrel"]
-        del token_dict["embeddings"]
+        # if not keep_semrel:
+        #     del token_dict["semrel"]
+        # del token_dict["embeddings"]
         tokens.append(token_dict)
     # Range tokens must be tuple not list, this is conllu library requirement
     for t in tokens:
diff --git a/combo/data/tokenizers/token.py b/combo/data/tokenizers/token.py
index 536e87de869a961ff3a0f3c4806c9791dbc5f038..33182ddd3a30dafa7886fed5c337af43b1b116d3 100644
--- a/combo/data/tokenizers/token.py
+++ b/combo/data/tokenizers/token.py
@@ -2,7 +2,7 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/token_class.py
 """
-
+from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Union
 import logging
 from dataclasses import dataclass, field
@@ -95,6 +95,20 @@ class Token:
         self.text_id = text_id
         self.type_id = type_id
 
+    def as_dict(self, semrel: bool = True) -> Dict[str, Any]:
+        repr = {}
+        repr_keys = [
+            'text', 'idx', 'lemma', 'upostag', 'xpostag', 'entity_type', 'feats',
+            'head', 'deprel', 'deps', 'misc'
+        ]
+        for rk in repr_keys:
+            repr[rk] = self.__getattribute__(rk)
+
+        if semrel:
+            repr['semrel'] = self.semrel
+
+        return repr
+
     def __str__(self):
         return self.text
 
diff --git a/combo/main.py b/combo/main.py
index 5dee611b547d1a6e562d273921d904bf417a5652..d5fb9dc317c1931e3a3f6f4ebeb6934949b3500f 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 import pathlib
 import tempfile
 from typing import Dict
@@ -7,6 +8,7 @@ from typing import Dict
 import torch
 from absl import app, flags
 import pytorch_lightning as pl
+from tqdm import tqdm
 
 from combo.training.trainable_combo import TrainableCombo
 from combo.utils import checks, ComboLogger
@@ -15,6 +17,7 @@ from combo.config import resolve
 from combo.default_model import default_ud_dataset_reader, default_data_loader
 from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
+from combo.data import api
 
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
@@ -65,8 +68,9 @@ flags.DEFINE_string(name="finetuning_validation_data_path", default="",
                     help="Validation data path(s)")
 
 # Test after training flags
-flags.DEFINE_string(name="test_path", default=None,
+flags.DEFINE_string(name="test_data_path", default=None,
                     help="Test path file.")
+flags.DEFINE_alias(name="test_data", original_name="test_data_path")
 
 # Experimental
 flags.DEFINE_boolean(name="use_pure_config", default=False,
@@ -111,12 +115,16 @@ def run(_):
         serialization_dir = tempfile.mkdtemp(prefix='combo',
                                              dir=FLAGS.serialization_dir)
+        params['vocabulary']['parameters']['directory'] = os.path.join('/'.join(FLAGS.config_path.split('/')[:-1]),
+                                                                       params['vocabulary']['parameters']['directory'])
+
         try:
             vocabulary = resolve(params['vocabulary'])
         except KeyError:
             logger.error('No vocabulary in config.json!')
             return
+
         model = resolve(params['model'], pass_down_parameters={'vocabulary': vocabulary})
 
         dataset_reader = None
@@ -184,9 +192,23 @@ def run(_):
                              gradient_clip_val=5)
         trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
-        logger.info(f'Archiving the fine-tuned model in {serialization_dir}', prefix=prefix)
+        logger.info(f'Archiving the model in {serialization_dir}', prefix=prefix)
         archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader)
-        logger.info(f"Training model stored in: {serialization_dir}", prefix=prefix)
+        logger.info(f"Model stored in: {serialization_dir}", prefix=prefix)
+
+        if FLAGS.test_data_path and FLAGS.output_file:
+            checks.file_exists(FLAGS.test_data_path)
+            if not dataset_reader:
+                logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader",
+                            prefix=prefix)
+                dataset_reader = default_ud_dataset_reader()
+            logger.info("Predicting test examples", prefix=prefix)
+            test_trees = dataset_reader.read(FLAGS.test_data_path)
+            predictor = COMBO(model, dataset_reader)
+            with open(FLAGS.output_file, "w") as file:
+                for tree in tqdm(test_trees):
+                    file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                        keep_semrel=dataset_reader.use_sem).serialize())
 
     elif FLAGS.mode == 'predict':
         predictor = get_predictor()
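
A minimal sketch of the new serialization path, for reviewers. It assumes Token and Sentence accept the fields above as keyword arguments (their full constructors are not shown in this diff):

from combo.data.api import Sentence, sentence2conllu
from combo.data.tokenizers.token import Token

# Build a one-token sentence; the keyword arguments mirror the keys
# listed in Token.as_dict and are assumed to be accepted by __init__.
token = Token(text="Hello", idx=1, lemma="hello", upostag="INTJ",
              head=0, deprel="root")
sentence = Sentence(tokens=[token])

# With keep_semrel=False the semrel column is simply never emitted by
# Token.as_dict, instead of being deleted from dataclasses.asdict output;
# embeddings likewise never reach the CoNLL-U dict, which is why the
# explicit del calls above became dead code and are commented out.
print(sentence2conllu(sentence, keep_semrel=False).serialize())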
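
The post-training prediction block added to run() can also be exercised standalone. A sketch assuming an already-trained archive and that the object returned by load_archive exposes model and dataset_reader attributes (not shown in this diff); all file paths are hypothetical:

from tqdm import tqdm

from combo.data import api
from combo.default_model import default_ud_dataset_reader
from combo.modules.archival import load_archive
from combo.predict import COMBO

archive = load_archive("model.tar.gz")  # hypothetical path
# Fall back to the default UD reader when none was archived, as run() does.
dataset_reader = archive.dataset_reader or default_ud_dataset_reader()
predictor = COMBO(archive.model, dataset_reader)

with open("predictions.conllu", "w") as file:  # hypothetical output path
    for tree in tqdm(dataset_reader.read("test.conllu")):  # hypothetical test set
        sentence = predictor.predict_instance(tree)
        file.writelines(api.sentence2conllu(sentence,
                                            keep_semrel=dataset_reader.use_sem).serialize())

After this change, the same loop runs automatically when training is launched with --test_data_path (or its new alias --test_data) together with --output_file.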