From bea72f566d8a0a432df994c542017c5c6f37187a Mon Sep 17 00:00:00 2001 From: piotrmp <piotr.m.przybyla@gmail.com> Date: Thu, 24 Nov 2022 21:38:08 +0100 Subject: [PATCH] Reduced window size. --- src/lambo/evaluation/evaluate.py | 15 ++++++++++----- src/lambo/learning/train.py | 4 ++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/lambo/evaluation/evaluate.py b/src/lambo/evaluation/evaluate.py index f4bb70c..b86a62f 100644 --- a/src/lambo/evaluation/evaluate.py +++ b/src/lambo/evaluation/evaluate.py @@ -1,4 +1,4 @@ -from lambo.evaluation.conll18_ud_eval import load_conllu, evaluate +from lambo.evaluation.conll18_ud_eval import load_conllu, evaluate, UDError from lambo.utils.printer import print_document_to_conll @@ -19,8 +19,13 @@ def evaluate_segmenter(segmenter, test_text, gold_path, tmp_path): with open(gold_path) as fGold: pred = load_conllu(fPred) gold = load_conllu(fGold) - conll_result = evaluate(gold, pred) - for category in ['Tokens', 'Words', 'Sentences']: - result[category] = {'F1': conll_result[category].f1, 'precision': conll_result[category].precision, - 'recall': conll_result[category].recall} + try: + conll_result = evaluate(gold, pred) + for category in ['Tokens', 'Words', 'Sentences']: + result[category] = {'F1': conll_result[category].f1, 'precision': conll_result[category].precision, + 'recall': conll_result[category].recall} + except UDError as e: + for category in ['Tokens', 'Words', 'Sentences']: + result[category] = {'F1': 0.0, 'precision': 0.0, + 'recall': 0.0} return result diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index d3f3a06..a0efdfc 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -121,7 +121,7 @@ def train_new_and_save(model_name, treebank_path, save_path, epochs=10, device=' BATCH_SIZE = 32 print("Initiating the model.") - MAX_LEN = 1024 + MAX_LEN = 256 dict, train_dataloader, test_dataloader = prepare_dataloaders_withdict([train_doc, dev_doc], [test_doc], MAX_LEN, BATCH_SIZE) @@ -168,7 +168,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat train_doc, dev_doc, test_doc = read_treebank(treebank_path, True) print("Initiating the model.") - MAX_LEN = 1024 + MAX_LEN = 256 model = LamboNetwork(MAX_LEN, dict, len(utf_category_dictionary), pretrained=pretrained_model) print("Preparing data") -- GitLab