From 7cc3f2c2943d51ffe464fc63608a226ed7831c30 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Mon, 16 Oct 2023 15:51:00 +1100 Subject: [PATCH] Add prediction to main.py --- combo/data/vocabulary.py | 2 +- combo/main.py | 20 +++----------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py index 3cf5fed..1c0a384 100644 --- a/combo/data/vocabulary.py +++ b/combo/data/vocabulary.py @@ -300,7 +300,6 @@ class Vocabulary(FromParameters): filename = os.path.join(directory, namespace_filename) vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token) - get_slices_if_not_provided(vocab) vocab.constructed_from = 'from_files' return vocab @@ -758,6 +757,7 @@ class Vocabulary(FromParameters): def get_slices_if_not_provided(vocab: Vocabulary): if hasattr(vocab, "slices"): return vocab.slices + print("Getting slices...") if "feats_labels" in vocab.get_namespaces(): idx2token = vocab.get_index_to_token_vocabulary("feats_labels") diff --git a/combo/main.py b/combo/main.py index a426b10..aeb1dc8 100755 --- a/combo/main.py +++ b/combo/main.py @@ -92,25 +92,11 @@ flags.DEFINE_enum(name="predictor_name", default="combo-lambo", help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.") -def get_saved_model(parameters) -> ComboModel: - return ComboModel.load(os.path.join(FLAGS.model_path), - config=parameters, - weights_file=os.path.join(FLAGS.model_path, 'best.th'), - cuda_device=FLAGS.cuda_device) - - def get_predictor() -> COMBO: - # Check for GPU - # allen_checks.check_for_gpu(FLAGS.cuda_device) checks.file_exists(FLAGS.model_path) - with open(os.path.join(FLAGS.model_path, 'params.json'), 'r') as f: - serialized = json.load(f) - model = get_saved_model(serialized) - if 'dataset_reader' in serialized: - dataset_reader = resolve(serialized['dataset_reader']) - else: - dataset_reader = default_ud_dataset_reader() - return COMBO(model, dataset_reader) + arch = load_archive(FLAGS.model_path) + dataset_reader = default_ud_dataset_reader() + return COMBO(arch.model, dataset_reader) def run(_): -- GitLab