Skip to content
Snippets Groups Projects
Select Git revision
  • 57349740f9e4009f457eabdf2dd6f9f9592f43dd
  • main default protected
  • ud_training_script
  • fix_seed
  • merged-with-ner
  • multiword_fix_transformer
  • transformer_encoder
  • combo3
  • save_deprel_matrix_to_npz
  • master protected
  • combo-lambo
  • lambo-sent-attributes
  • adding_lambo
  • develop
  • update_allenlp2
  • develop_tmp
  • tokens_truncation
  • LR_test
  • eud_iwpt
  • iob
  • eud_iwpt_shared_task_bert_finetuning
  • 3.3.1
  • list
  • 3.2.1
  • 3.0.3
  • 3.0.1
  • 3.0.0
  • v1.0.6
  • v1.0.5
  • v1.0.4
  • v1.0.3
  • v1.0.2
  • v1.0.1
  • v1.0.0
34 results

test_main.py

Blame
  • test_main.py 2.00 KiB
    """Testing training invocation with synthetic data."""
    import logging
    import os
    import pathlib
    import shutil
    import tempfile
    import unittest
    from allennlp.commands import train
    from allennlp.common import Params, util
    
    
    class TrainingEndToEndTest(unittest.TestCase):
        PROJECT_ROOT = (pathlib.Path(__file__).parent / "..").resolve()
        MODULE_ROOT = PROJECT_ROOT / "combo"
        TESTS_ROOT = PROJECT_ROOT / "tests"
        FIXTURES_ROOT = TESTS_ROOT / "fixtures"
        TEST_DIR = pathlib.Path("/tmp/test")
    
        def setUp(self) -> None:
            logging.getLogger("allennlp.common.util").disabled = True
            logging.getLogger("allennlp.training.tensorboard_writer").disabled = True
            logging.getLogger("allennlp.common.params").disabled = True
            logging.getLogger("allennlp.nn.initializers").disabled = True
    
        def test_training_produces_model(self):
            # given
            self.TEST_DIR.mkdir(exist_ok=True)
            util.import_module_and_submodules("combo.models")
            util.import_module_and_submodules("combo.training")
            ext_vars = {
                "training_data_path": os.path.join(self.FIXTURES_ROOT, "example.conllu"),
                "validation_data_path": os.path.join(self.FIXTURES_ROOT, "example.conllu"),
                "features": "token char",
                "targets": "deprel head lemma feats upostag xpostag",
                "type": "default",
                "pretrained_tokens": os.path.join(self.FIXTURES_ROOT, "example.vec"),
                "pretrained_transformer_name": "",
                "embedding_dim": "300",
                "cuda_device": "-1",
                "num_epochs": "1",
                "word_batch_size": "1",
                "use_tensorboard": "False"
            }
            params = Params.from_file(os.path.join(self.PROJECT_ROOT, "config.template.jsonnet"),
                                      ext_vars=ext_vars)
    
            # when
            model = train.train_model(params, serialization_dir=self.TEST_DIR)
    
            # then
            self.assertIsNotNone(model)
    
        def tearDown(self) -> None:
            shutil.rmtree(self.TEST_DIR)