From 69afc593a8c1d9fff59790a9f7aab48b2b5fdb34 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Tue, 21 Nov 2023 23:36:16 +1100 Subject: [PATCH] Add default serialization path and tokenizer language settings in predict.py --- combo/data/dataset_readers/dataset_reader.py | 4 ++++ .../universal_dependencies_dataset_reader.py | 2 +- combo/main.py | 20 ++++++++++++------- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py index fbd1159..8730098 100644 --- a/combo/data/dataset_readers/dataset_reader.py +++ b/combo/data/dataset_readers/dataset_reader.py @@ -42,6 +42,10 @@ class DatasetReader(IterableDataset, FromParameters): def tokenizer(self) -> Optional[Tokenizer]: return self.__tokenizer + @tokenizer.setter + def tokenizer(self, value): + self.__tokenizer = value + @property def token_indexers(self) -> Optional[Dict[str, TokenIndexer]]: return self.__token_indexers diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py index 5b726f3..a3b0408 100644 --- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py +++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py @@ -269,4 +269,4 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC): one_hot_encoding[index] = 1 return one_hot_encoding - return _m_from_n_ones_encoding \ No newline at end of file + return _m_from_n_ones_encoding diff --git a/combo/main.py b/combo/main.py index 9e40120..dd4bdd2 100755 --- a/combo/main.py +++ b/combo/main.py @@ -365,14 +365,20 @@ def run(_): elif FLAGS.mode == 'predict': prefix = 'Predicting' logger.info('Loading the model', prefix=prefix) - model, _, _, _, dataset_reader = load_archive(FLAGS.model_path) + model, config, _, _, dataset_reader = load_archive(FLAGS.model_path) + + if config.get("tokenizer_language") is None: + logger.warning("Tokenizer language was not found in archive's configuration file - " + + "using the --tokenizer_language parameter (default: English)") + tokenizer_language = config.get("tokenizer_language", FLAGS.tokenizer_language) if not dataset_reader: logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader", prefix=prefix) dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name) + dataset_reader.tokenizer = LamboTokenizer(tokenizer_language) - predictor = COMBO(model, dataset_reader, LamboTokenizer(language=FLAGS.tokenizer_language)) + predictor = COMBO(model, dataset_reader) if FLAGS.input_file == '-': print("Interactive mode.") @@ -388,10 +394,7 @@ def run(_): except ConfigurationError as e: handle_error(e, prefix) - try: - pathlib.Path(FLAGS.output_file).mkdir(parents=True, exist_ok=True) - except FileExistsError: - pass + pathlib.Path(FLAGS.output_file).touch(exist_ok=True) logger.info("Predicting examples from file", prefix=prefix) @@ -406,7 +409,7 @@ def run(_): predictions.append(prediction) else: - tokenizer = LamboTokenizer(FLAGS.tokenizer_language) + tokenizer = LamboTokenizer(tokenizer_language) with open(FLAGS.input_file, "r") as file: input_sentences = tokenizer.segment(file.read()) with open(FLAGS.output_file, "w") as file: @@ -418,6 +421,9 @@ def run(_): if FLAGS.save_matrices: logger.info("Saving matrices", prefix=prefix) + if FLAGS.serialization_dir is None or pathlib.Path(FLAGS.serialization_dir).exists(): + logger.warning('Serialization path was not passed as an argument - skipping matrix extraction.') + return extract_combo_matrices(predictions, pathlib.Path(FLAGS.serialization_dir), pathlib.Path(FLAGS.input_file), -- GitLab