Skip to content
Snippets Groups Projects
Commit 3e11893a authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Make targets and features flags as list.

parent 90c6fbee
No related merge requests found
...@@ -24,7 +24,7 @@ Examples (for clarity without training/validation data paths): ...@@ -24,7 +24,7 @@ Examples (for clarity without training/validation data paths):
* train on gpu 0 * train on gpu 0
```bash ```bash
combo --mode train --cuda_davice 0 combo --mode train --cuda_device 0
``` ```
* use pretrained embeddings: * use pretrained embeddings:
...@@ -42,13 +42,13 @@ Examples (for clarity without training/validation data paths): ...@@ -42,13 +42,13 @@ Examples (for clarity without training/validation data paths):
* predict only dependency tree: * predict only dependency tree:
```bash ```bash
combo --mode train --targets head --targets deprel combo --mode train --targets head,deprel
``` ```
* use part-of-speech tags for predicting only dependency tree * use part-of-speech tags for predicting only dependency tree
```bash ```bash
combo --mode train --targets head --targets deprel --features token --features char --features upostag combo --mode train --targets head,deprel --features token,char,upostag
``` ```
Advanced configuration: [Configuration](#configuration) Advanced configuration: [Configuration](#configuration)
......
...@@ -69,7 +69,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -69,7 +69,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
@overrides @overrides
def _read(self, file_path: str) -> Iterable[allen_data.Instance]: def _read(self, file_path: str) -> Iterable[allen_data.Instance]:
file_path = [file_path] if len(file_path.split(":")) == 0 else file_path.split(":") file_path = [file_path] if len(file_path.split(",")) == 0 else file_path.split(",")
for conllu_file in file_path: for conllu_file in file_path:
with open(conllu_file, "r") as file: with open(conllu_file, "r") as file:
......
...@@ -17,6 +17,8 @@ from combo.data import dataset ...@@ -17,6 +17,8 @@ from combo.data import dataset
from combo.utils import checks from combo.utils import checks
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent"]
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"], flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"],
...@@ -30,9 +32,9 @@ flags.DEFINE_string(name="output_file", default="output.log", ...@@ -30,9 +32,9 @@ flags.DEFINE_string(name="output_file", default="output.log",
# Training flags # Training flags
flags.DEFINE_list(name="training_data_path", default="./tests/fixtures/example.conllu", flags.DEFINE_list(name="training_data_path", default="./tests/fixtures/example.conllu",
help="Training data path") help="Training data path(s)")
flags.DEFINE_list(name="validation_data_path", default="", flags.DEFINE_list(name="validation_data_path", default="",
help="Validation data path") help="Validation data path(s)")
flags.DEFINE_string(name="pretrained_tokens", default="", flags.DEFINE_string(name="pretrained_tokens", default="",
help="Pretrained tokens embeddings path") help="Pretrained tokens embeddings path")
flags.DEFINE_integer(name="embedding_dim", default=300, flags.DEFINE_integer(name="embedding_dim", default=300,
...@@ -42,14 +44,12 @@ flags.DEFINE_integer(name="num_epochs", default=400, ...@@ -42,14 +44,12 @@ flags.DEFINE_integer(name="num_epochs", default=400,
flags.DEFINE_integer(name="word_batch_size", default=2500, flags.DEFINE_integer(name="word_batch_size", default=2500,
help="Minimum words in batch") help="Minimum words in batch")
flags.DEFINE_string(name="pretrained_transformer_name", default="", flags.DEFINE_string(name="pretrained_transformer_name", default="",
help="Pretrained transformer model name (see transformers from HuggingFace library for list of" help="Pretrained transformer model name (see transformers from HuggingFace library for list of "
"available models) for transformers based embeddings.") "available models) for transformers based embeddings.")
flags.DEFINE_multi_enum(name="features", default=["token", "char"], flags.DEFINE_list(name="features", default=["token", "char"],
enum_values=["token", "char", "upostag", "xpostag", "lemma", "feats"], help=f"Features used to train model (required 'token' and 'char'). Possible values: {_FEATURES}.")
help="Features used to train model (required 'token' and 'char')") flags.DEFINE_list(name="targets", default=["deprel", "feats", "head", "lemma", "upostag", "xpostag"],
flags.DEFINE_multi_enum(name="targets", default=["deprel", "feats", "head", "lemma", "upostag", "xpostag"], help=f"Targets of the model (required `deprel` and `head`). Possible values: {_TARGETS}.")
enum_values=["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent"],
help="Targets of the model (required `deprel` and `head`)")
flags.DEFINE_string(name="serialization_dir", default=None, flags.DEFINE_string(name="serialization_dir", default=None,
help="Model serialization directory (default - system temp dir).") help="Model serialization directory (default - system temp dir).")
flags.DEFINE_boolean(name="tensorboard", default=False, flags.DEFINE_boolean(name="tensorboard", default=False,
...@@ -189,10 +189,22 @@ def _get_ext_vars(finetuning: bool = False) -> Dict: ...@@ -189,10 +189,22 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
def main(): def main():
"""Parse flags.""" """Parse flags."""
flags.register_validator(
"features",
lambda values: all(
value in _FEATURES for value in values),
message="Flags --features contains unknown value(s)."
)
flags.register_validator( flags.register_validator(
"mode", "mode",
lambda value: value is not None, lambda value: value is not None,
message="Flag --mode must be set with either `predict` or `train` value") message="Flag --mode must be set with either `predict` or `train` value")
flags.register_validator(
"targets",
lambda values: all(
value in _TARGETS for value in values),
message="Flag --targets contains unknown value(s)."
)
app.run(run) app.run(run)
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
######################################################################################## ########################################################################################
# Training data path, str # Training data path, str
# Must be in CoNLL-U format (or its extended version with semantic relation field). # Must be in CoNLL-U format (or its extended version with semantic relation field).
# Can accept multiple paths when concatenated with ':', "path1:path2" # Can accept multiple paths when concatenated with ',', "path1,path2"
local training_data_path = std.extVar("training_data_path"); local training_data_path = std.extVar("training_data_path");
# Validation data path, str # Validation data path, str
# Can accept multiple paths when concatenated with ':', "path1:path2" # Can accept multiple paths when concatenated with ',', "path1,path2"
local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path"); local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
# Path to pretrained tokens, str or null # Path to pretrained tokens, str or null
local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens"); local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
...@@ -36,13 +36,13 @@ local embedding_dim = std.parseInt(std.extVar("embedding_dim")); ...@@ -36,13 +36,13 @@ local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
local predictors_dropout = 0.25; local predictors_dropout = 0.25;
# Xpostag embedding dimension, int # Xpostag embedding dimension, int
# (discarded if xpostag not in features) # (discarded if xpostag not in features)
local xpostag_dim = 100; local xpostag_dim = 32;
# Upostag embedding dimension, int # Upostag embedding dimension, int
# (discarded if upostag not in features) # (discarded if upostag not in features)
local upostag_dim = 100; local upostag_dim = 32;
# Feats embedding dimension, int # Feats embedding dimension, int
# (discarded if feats not in features) # (discarded if feats not in features)
local feats_dim = 100; local feats_dim = 32;
# Lemma embedding dimension, int # Lemma embedding dimension, int
# (discarded if lemma not in features) # (discarded if lemma not in features)
local lemma_char_dim = 64; local lemma_char_dim = 64;
......
...@@ -3,11 +3,11 @@ from setuptools import find_packages, setup ...@@ -3,11 +3,11 @@ from setuptools import find_packages, setup
REQUIREMENTS = [ REQUIREMENTS = [
'absl-py==0.9.0', 'absl-py==0.9.0',
'allennlp==1.0.0rc4', 'allennlp==1.0.0rc5',
'conllu==2.3.2', 'conllu==2.3.2',
'joblib==0.14.1', 'joblib==0.14.1',
'jsonnet==0.15.0', 'jsonnet==0.15.0',
'overrides==2.8.0', 'overrides==3.0.0',
'tensorboard==2.1.0', 'tensorboard==2.1.0',
'torch==1.5.0', 'torch==1.5.0',
'torchvision==0.6.0', 'torchvision==0.6.0',
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment