diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py
index fbd11595ac9b2f5f27b8da913686d90ec321475d..873009849c141e156ab4cddc91c8dcc90c148989 100644
--- a/combo/data/dataset_readers/dataset_reader.py
+++ b/combo/data/dataset_readers/dataset_reader.py
@@ -42,6 +42,10 @@ class DatasetReader(IterableDataset, FromParameters):
     def tokenizer(self) -> Optional[Tokenizer]:
         return self.__tokenizer
 
+    @tokenizer.setter
+    def tokenizer(self, value):
+        self.__tokenizer = value
+
     @property
     def token_indexers(self) -> Optional[Dict[str, TokenIndexer]]:
         return self.__token_indexers
diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
index 5b726f3641f14cbcd19153b9dbd049415f3f3754..a3b040889671acbec754ed6ef2631181d6710156 100644
--- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
+++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
@@ -269,4 +269,4 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
                 one_hot_encoding[index] = 1
             return one_hot_encoding
 
-        return _m_from_n_ones_encoding
\ No newline at end of file
+        return _m_from_n_ones_encoding
diff --git a/combo/main.py b/combo/main.py
index 9e40120761bc44c9575bf339cf8611e9bbd0cce6..dd4bdd292626789424fb7216d5b3a22f5dd1b4ba 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -365,14 +365,20 @@ def run(_):
     elif FLAGS.mode == 'predict':
         prefix = 'Predicting'
         logger.info('Loading the model', prefix=prefix)
-        model, _, _, _, dataset_reader = load_archive(FLAGS.model_path)
+        model, config, _, _, dataset_reader = load_archive(FLAGS.model_path)
+
+        if config.get("tokenizer_language") is None:
+            logger.warning("Tokenizer language was not found in archive's configuration file - " +
+                           "using the --tokenizer_language parameter (default: English)")
+        tokenizer_language = config.get("tokenizer_language", FLAGS.tokenizer_language)
 
         if not dataset_reader:
             logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader",
                         prefix=prefix)
             dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name)
+        dataset_reader.tokenizer = LamboTokenizer(tokenizer_language)
 
-        predictor = COMBO(model, dataset_reader, LamboTokenizer(language=FLAGS.tokenizer_language))
+        predictor = COMBO(model, dataset_reader)
 
         if FLAGS.input_file == '-':
             print("Interactive mode.")
@@ -388,10 +394,7 @@
             except ConfigurationError as e:
                 handle_error(e, prefix)
 
-        try:
-            pathlib.Path(FLAGS.output_file).mkdir(parents=True, exist_ok=True)
-        except FileExistsError:
-            pass
+        pathlib.Path(FLAGS.output_file).touch(exist_ok=True)
 
         logger.info("Predicting examples from file", prefix=prefix)
 
@@ -406,7 +409,7 @@
                     predictions.append(prediction)
 
         else:
-            tokenizer = LamboTokenizer(FLAGS.tokenizer_language)
+            tokenizer = LamboTokenizer(tokenizer_language)
             with open(FLAGS.input_file, "r") as file:
                 input_sentences = tokenizer.segment(file.read())
             with open(FLAGS.output_file, "w") as file:
@@ -418,6 +421,9 @@
 
     if FLAGS.save_matrices:
         logger.info("Saving matrices", prefix=prefix)
+        if FLAGS.serialization_dir is None or pathlib.Path(FLAGS.serialization_dir).exists():
+            logger.warning('Serialization path was not passed as an argument - skipping matrix extraction.')
+            return
         extract_combo_matrices(predictions,
                                pathlib.Path(FLAGS.serialization_dir),
                                pathlib.Path(FLAGS.input_file),