diff --git a/combo/data/api.py b/combo/data/api.py index 10a3a727c9220601ebf243752a7e605e127a1774..ca8f75a01c4725d039729efc2997348d74d56b71 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -50,6 +50,10 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke for t in tokens: if type(t["id"]) == list: t["id"] = tuple(t["id"]) + if t["deps"]: + for dep in t["deps"]: + if type(dep[1]) == list: + dep[1] = tuple(dep[1]) return _TokenList(tokens=tokens, metadata=sentence.metadata) diff --git a/scripts/train.py b/scripts/train.py index 7ca0fce656fde639c27d3087f6011ef8fb9ab142..939088800f772c113693eb6d0858304ed82f766d 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -135,7 +135,7 @@ def run(_): embeddings_dir = FLAGS.embeddings_dir embeddings_file = None if embeddings_dir: - embeddings_dir = embeddings_dir / language + embeddings_dir = pathlib.Path(embeddings_dir) / language embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name] assert len(embeddings_file) == 1, f"Couldn't find embeddings file." embeddings_file = embeddings_file[0]