From 7cc3f2c2943d51ffe464fc63608a226ed7831c30 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Mon, 16 Oct 2023 15:51:00 +1100
Subject: [PATCH] Add prediction to main.py

Build the COMBO predictor from a serialized model archive: get_predictor()
now loads the model with load_archive() and pairs it with the default UD
dataset reader instead of reading params.json and best.th by hand.
Vocabulary.from_files() no longer computes slices eagerly; they are built by
get_slices_if_not_provided() when it is called.
---
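A minimal sketch of the prediction path this patch enables. The import
locations below are assumptions (only load_archive(),
default_ud_dataset_reader() and COMBO(model, dataset_reader) appear in the
diff), so treat this as an illustration rather than the project's API:

    # Sketch only; the import paths are guesses and may differ in the repo.
    from combo.predict import COMBO                             # assumed path
    from combo.models import load_archive                       # assumed path
    from combo.default_model import default_ud_dataset_reader   # assumed path

    def make_predictor(model_path: str) -> COMBO:
        # load_archive() restores the serialized model; its .model attribute
        # is what the predictor wraps, mirroring get_predictor() below.
        archive = load_archive(model_path)
        dataset_reader = default_ud_dataset_reader()
        return COMBO(archive.model, dataset_reader)
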
 combo/data/vocabulary.py |  2 +-
 combo/main.py            | 20 +++-----------------
 2 files changed, 4 insertions(+), 18 deletions(-)
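
The vocabulary change defers slice computation: Vocabulary.from_files() no
longer calls get_slices_if_not_provided(), so slices are built on demand.
A rough usage sketch; the directory path is hypothetical and the caching
behaviour is inferred from the hasattr() check in the hunk:

    from combo.data.vocabulary import Vocabulary, get_slices_if_not_provided

    # Loading no longer attaches slices to the vocabulary.
    vocab = Vocabulary.from_files("serialization_dir/vocabulary")  # hypothetical path
    # The first call computes the slices; later calls return vocab.slices.
    slices = get_slices_if_not_provided(vocab)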

diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index 3cf5fed..1c0a384 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -300,7 +300,6 @@ class Vocabulary(FromParameters):
                 filename = os.path.join(directory, namespace_filename)
                 vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
 
-        get_slices_if_not_provided(vocab)
         vocab.constructed_from = 'from_files'
         return vocab
 
@@ -758,6 +757,7 @@ class Vocabulary(FromParameters):
 def get_slices_if_not_provided(vocab: Vocabulary):
     if hasattr(vocab, "slices"):
         return vocab.slices
+    print("Getting slices...")
 
     if "feats_labels" in vocab.get_namespaces():
         idx2token = vocab.get_index_to_token_vocabulary("feats_labels")
diff --git a/combo/main.py b/combo/main.py
index a426b10..aeb1dc8 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -92,25 +92,11 @@ flags.DEFINE_enum(name="predictor_name", default="combo-lambo",
                   help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.")
 
 
-def get_saved_model(parameters) -> ComboModel:
-    return ComboModel.load(os.path.join(FLAGS.model_path),
-                           config=parameters,
-                           weights_file=os.path.join(FLAGS.model_path, 'best.th'),
-                           cuda_device=FLAGS.cuda_device)
-
-
 def get_predictor() -> COMBO:
-    # Check for GPU
-    # allen_checks.check_for_gpu(FLAGS.cuda_device)
     checks.file_exists(FLAGS.model_path)
-    with open(os.path.join(FLAGS.model_path, 'params.json'), 'r') as f:
-        serialized = json.load(f)
-    model = get_saved_model(serialized)
-    if 'dataset_reader' in serialized:
-        dataset_reader = resolve(serialized['dataset_reader'])
-    else:
-        dataset_reader = default_ud_dataset_reader()
-    return COMBO(model, dataset_reader)
+    arch = load_archive(FLAGS.model_path)
+    dataset_reader = default_ud_dataset_reader()
+    return COMBO(arch.model, dataset_reader)
 
 
 def run(_):
-- 
GitLab