From d9abf6eacdfefe3843191f203eae516a2222073f Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Tue, 16 Jan 2024 20:37:21 +0100
Subject: [PATCH] Correct config.json naming bug

---
 combo/modules/archival.py | 6 +++---
 combo/modules/model.py    | 2 +-
 setup.py                  | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/combo/modules/archival.py b/combo/modules/archival.py
index 63a4cd8..496753f 100644
--- a/combo/modules/archival.py
+++ b/combo/modules/archival.py
@@ -81,7 +81,7 @@ def archive(model: Model,
     with (TemporaryDirectory(os.path.join('tmp')) as t,
           BytesIO() as out_stream,
           tarfile.open(os.path.join(serialization_dir, 'model.tar.gz'), 'w|gz') as tar_file):
-        add_to_tar(tar_file, out_stream, json.dumps(parameters).encode(), 'config.template.json')
+        add_to_tar(tar_file, out_stream, json.dumps(parameters).encode(), 'config.json')
         weights_path = os.path.join(t, 'weights.th')
         torch.save(model.state_dict(), weights_path)
         tar_file.add(weights_path, 'weights.th')
@@ -100,7 +100,7 @@ def extracted_archive(resolved_archive_file, cleanup=True):
         with tarfile.open(resolved_archive_file) as archive:
             subdir_and_files = [
                 tarinfo for tarinfo in archive.getmembers()
-                if (any([tarinfo.name.endswith(f) for f in ['config.template.json', 'weights.th']])
+                if (any([tarinfo.name.endswith(f) for f in ['config.json', 'weights.th']])
                     or 'vocabulary' in tarinfo.name)
             ]
             for f in subdir_and_files:
@@ -127,7 +127,7 @@ def load_archive(url_or_filename: Union[PathLike, str],
     with extracted_archive(rarchive_file) as archive_file:
         model = Model.load(archive_file, cuda_device=cuda_device)
 
-        config_path = os.path.join(archive_file, 'config.template.json')
+        config_path = os.path.join(archive_file, 'config.json')
         with open(config_path, 'r') as f:
             config = json.load(f)
 
diff --git a/combo/modules/model.py b/combo/modules/model.py
index 93b9740..76e83be 100644
--- a/combo/modules/model.py
+++ b/combo/modules/model.py
@@ -447,7 +447,7 @@ class Model(Module, pl.LightningModule, FromParameters):
             vocabulary and the trained weights.
         """
         if config is None:
-            with open(os.path.join(serialization_dir, 'config.template.json'), 'r') as f:
+            with open(os.path.join(serialization_dir, 'config.json'), 'r') as f:
                 config = json.load(f)
         elif isinstance(config, str) or isinstance(config, PathLike):
             with open(config, 'r') as f:
diff --git a/setup.py b/setup.py
index 05c1991..6943e29 100644
--- a/setup.py
+++ b/setup.py
@@ -27,8 +27,8 @@ REQUIREMENTS = [
 ]
 
 setup(
-    name="combo",
-    version="3.0.0",
+    name="combo-nlp",
+    version="3.0.1",
     author="Maja Jablonska",
     author_email="maja.jablonska@ipipan.waw.pl",
     install_requires=REQUIREMENTS,
-- 
GitLab