From f0811ea238aa6afbcc2eb6c910b906ead37c2501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl> Date: Wed, 14 Feb 2024 11:33:45 +0100 Subject: [PATCH 1/3] fixed imports to be absolute, not relative --- combo/combo_model.py | 2 +- combo/modules/seq2seq_encoders/transformer_encoder.py | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/combo/combo_model.py b/combo/combo_model.py index d87d514..162a3b1 100644 --- a/combo/combo_model.py +++ b/combo/combo_model.py @@ -24,7 +24,7 @@ from combo.nn.utils import get_text_field_mask from combo.predictors import Predictor from combo.utils import metrics from combo.utils import ConfigurationError -from modules.seq2seq_encoders.transformer_encoder import TransformerEncoder +from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder @Registry.register("semantic_multitask") diff --git a/combo/modules/seq2seq_encoders/transformer_encoder.py b/combo/modules/seq2seq_encoders/transformer_encoder.py index a49100d..3b126fa 100644 --- a/combo/modules/seq2seq_encoders/transformer_encoder.py +++ b/combo/modules/seq2seq_encoders/transformer_encoder.py @@ -1,18 +1,11 @@ from typing import Optional -from overrides import overrides import torch from torch import nn from combo.modules.encoder import _EncoderBase from combo.config.from_parameters import FromParameters, register_arguments - -# from modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder -from nn.utils import add_positional_features - - -# from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder -# from allennlp.nn.util import add_positional_features +from combo.nn.utils import add_positional_features class TransformerEncoder(_EncoderBase, FromParameters): -- GitLab From ed29c7a0ee9e289560424dec3f2ee280376197cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl> Date: Wed, 14 Feb 2024 11:43:54 +0100 Subject: [PATCH 2/3] add setting random seed --- combo/main.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/combo/main.py b/combo/main.py index 79e9820..db75cb7 100755 --- a/combo/main.py +++ b/combo/main.py @@ -4,8 +4,10 @@ import os import pathlib import tempfile from itertools import chain +import random from typing import Dict, Optional, Any, Tuple +import numpy.random.bit_generator import torch from absl import app, flags import pytorch_lightning as pl @@ -47,6 +49,7 @@ flags.DEFINE_integer(name="n_cuda_devices", default=-1, help="Number of devices to train on (default -1 auto mode - train on as many as possible)") flags.DEFINE_string(name="output_file", default="output.log", help="Predictions result file.") +flags.DEFINE_integer(name="seed", default=None, help="Random seed.") # Training flags flags.DEFINE_string(name="training_data_path", default="", help="Training data path(s)") @@ -293,7 +296,17 @@ def read_model_from_config(logging_prefix: str) -> Optional[ return model, dataset_reader, training_data_loader, validation_data_loader, vocabulary +def set_seed(seed: int) -> None: + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + + def run(_): + + if FLAGS.seed: + set_seed(FLAGS.seed) + if FLAGS.mode == 'train': model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = None, None, None, None, None -- GitLab From 99fe2a13cb544bae2342a3155feda6a4e929c62f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl> Date: Wed, 14 Feb 2024 11:46:55 +0100 Subject: [PATCH 3/3] fix numpy import --- combo/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/combo/main.py b/combo/main.py index db75cb7..73b3a26 100755 --- a/combo/main.py +++ b/combo/main.py @@ -7,7 +7,7 @@ from itertools import chain import random from typing import Dict, Optional, Any, Tuple -import numpy.random.bit_generator +import numpy import torch from absl import app, flags import pytorch_lightning as pl -- GitLab