test_main.py

    """Testing training invocation with synthetic data."""
    import logging
    import os
    import pathlib
    import shutil
    import unittest
    from allennlp.commands import train
    from allennlp.common import Params, util
    
    
    class TrainingEndToEndTest(unittest.TestCase):
        PROJECT_ROOT = (pathlib.Path(__file__).parent / "..").resolve()
        MODULE_ROOT = PROJECT_ROOT / "combo"
        TESTS_ROOT = PROJECT_ROOT / "tests"
        FIXTURES_ROOT = TESTS_ROOT / "fixtures"
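        # Scratch directory where the trained model is serialized; removed in tearDown.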
        TEST_DIR = pathlib.Path("/tmp/test")
    
        def setUp(self) -> None:
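            # Silence the noisiest AllenNLP loggers to keep test output readable.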
            logging.getLogger("allennlp.common.util").disabled = True
            logging.getLogger("allennlp.training.tensorboard_writer").disabled = True
            logging.getLogger("allennlp.common.params").disabled = True
            logging.getLogger("allennlp.nn.initializers").disabled = True
    
        def test_training_produces_model(self):
            # given
            self.TEST_DIR.mkdir(exist_ok=True)
            util.import_module_and_submodules("combo.models")
            util.import_module_and_submodules("combo.training")
            ext_vars = {
                "training_data_path": os.path.join(self.FIXTURES_ROOT, "example.conllu"),
                "validation_data_path": os.path.join(self.FIXTURES_ROOT, "example.conllu"),
                "features": "token char",
                "targets": "deprel head lemma feats upostag xpostag",
                "type": "default",
                "pretrained_tokens": os.path.join(self.FIXTURES_ROOT, "example.vec"),
                "pretrained_transformer_name": "",
                "embedding_dim": "300",
                "cuda_device": "-1",
                "num_epochs": "1",
                "word_batch_size": "1",
                "use_tensorboard": "False"
            }
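            # Render the jsonnet template into a concrete AllenNLP config.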
            params = Params.from_file(os.path.join(self.PROJECT_ROOT, "config.template.jsonnet"),
                                      ext_vars=ext_vars)
    
            # when
            model = train.train_model(params, serialization_dir=self.TEST_DIR)
    
            # then
            self.assertIsNotNone(model)
    
        def tearDown(self) -> None:
            shutil.rmtree(self.TEST_DIR)
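
    # Optional convenience entry point: lets the test be run directly, e.g.
    # `python tests/test_main.py`; standard unittest discovery works as well.
    if __name__ == "__main__":
        unittest.main()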