Commit dfec6d56 authored by Mateusz Klimaszewski

Fix herberta training.

parent 9d339859
Merge requests: !4 Documentation, !3 Herbert configuration and AllenNLP 1.2.0 update.
This commit is part of merge request !4.
## Installation

### HERBERTA notes
Install the herberta transformers package **before** running the command below.

Clone this repository and run:
```bash
python setup.py develop
```
@@ -86,7 +82,7 @@ Input: one sentence per line.
Output: List of token JSONs.
```bash
-combo --mode predict --model_path your_model_tar_gz --input_file your_text_file --output_file your_output_file --silent
+combo --mode predict --model_path your_model_tar_gz --input_file your_text_file --output_file your_output_file --silent --noconllu_format
```
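The `--noconllu_format` switch is not a separate flag: the `conllu_format` boolean flag added in this commit follows absl-style semantics, where every boolean flag also accepts a `no`-prefixed negation. A minimal standalone sketch (a hypothetical demo script, assuming absl-py, which the `flags.DEFINE_*` calls in this commit appear to use):

```python
# flag_demo.py -- hypothetical sketch of the boolean flag and its negation.
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_boolean(name="conllu_format", default=True,
                     help="Prediction based on conllu format (instead of raw text).")


def main(_):
    # `python flag_demo.py` prints True (the default);
    # `python flag_demo.py --noconllu_format` prints False.
    print(FLAGS.conllu_format)


if __name__ == "__main__":
    app.run(main)
```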
#### Advanced
@@ -20,6 +20,7 @@ class Token:
    deprel: Optional[str] = None
    deps: Optional[str] = None
    misc: Optional[str] = None
+   semrel: Optional[str] = None

@dataclass_json
@@ -37,8 +38,14 @@ class _TokenList(conllu.TokenList):
        return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'

-def sentence2conllu(sentence: Sentence) -> conllu.TokenList:
-    tokens = [collections.OrderedDict(t.to_dict()) for t in sentence.tokens]
+def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
+    tokens = []
+    for token in sentence.tokens:
+        token_dict = collections.OrderedDict(token.to_dict())
+        # Remove semrel to keep the default CoNLL-U format.
+        if not keep_semrel:
+            del token_dict["semrel"]
+        tokens.append(token_dict)
    # Range tokens must be tuples, not lists (a conllu library requirement).
    for t in tokens:
        if type(t["id"]) == list:
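To make the `keep_semrel` behaviour concrete, here is a self-contained sketch of the dictionary manipulation above, with a simplified stand-in for the repository's `Token` dataclass (field names and values are illustrative only):

```python
import collections
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Token:  # simplified stand-in, not combo's actual Token
    id: int
    token: str
    semrel: Optional[str] = None

    def to_dict(self):
        return dataclasses.asdict(self)


tokens = [Token(1, "Kot", semrel="AGENT")]
keep_semrel = False

dicts = []
for token in tokens:
    token_dict = collections.OrderedDict(token.to_dict())
    if not keep_semrel:
        del token_dict["semrel"]  # drop the extra column -> plain CoNLL-U
    dicts.append(token_dict)

print(dicts)  # [OrderedDict([('id', 1), ('token', 'Kot')])]
```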
@@ -41,7 +41,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                "Features and targets cannot share elements! "
                "Remove {} from either features or targets.".format(intersection)
            )
-        self._use_sem = use_sem
+        self.use_sem = use_sem

        # *.conllu readers configuration
        fields = list(parser.DEFAULT_FIELDS)
@@ -49,7 +49,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
        field_parsers = parser.DEFAULT_FIELD_PARSERS
        # Do not make it nullable.
        field_parsers.pop("xpostag", None)
-        if self._use_sem:
+        if self.use_sem:
            fields = list(fields)
            fields.append("semrel")
            field_parsers["semrel"] = lambda line, i: line[i]
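The effect of the `use_sem` branch is to teach the conllu parser an eleventh column. Below is a sketch of the same configuration against the conllu library directly (made-up row values; assumes a conllu release whose `parse()` accepts `fields`/`field_parsers`, as the reader's use of `parser.DEFAULT_FIELD_PARSERS` suggests):

```python
import conllu
from conllu import parser

# Extend the default CoNLL-U columns with a trailing "semrel" column.
fields = list(parser.DEFAULT_FIELDS)
fields.append("semrel")
field_parsers = dict(parser.DEFAULT_FIELD_PARSERS)  # copy: avoid mutating the module global
field_parsers["semrel"] = lambda line, i: line[i]   # keep the raw column value

data = "1\tKot\tkot\tNOUN\t_\t_\t0\troot\t_\t_\tAGENT\n\n"
sentence = conllu.parse(data, fields=fields, field_parsers=field_parsers)[0]
print(sentence[0]["semrel"])  # -> AGENT
```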
@@ -113,8 +113,23 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
            fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                   label_namespace=target_name + "_labels")

+        # Restore feats fields to their string representation;
+        # parser.serialize_field doesn't handle a key without a value.
+        for token in tree.tokens:
+            if "feats" in token:
+                feats = token["feats"]
+                if feats:
+                    feats_values = []
+                    for k, v in feats.items():
+                        feats_values.append('='.join((k, v)) if v else k)
+                    field = "|".join(feats_values)
+                else:
+                    field = "_"
+                token["feats"] = field
+
        # metadata
        fields_["metadata"] = allen_fields.MetadataField({"input": tree, "field_names": self.fields})
        return allen_data.Instance(fields_)

    @staticmethod
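The restoration loop above is the inverse of feats parsing: conllu turns a string such as `Mood=Ind|Foreign` into a dict where a bare key maps to `None`, and `parser.serialize_field` cannot write such a key back. A standalone illustration of the serialization branch (feature values are made up):

```python
def serialize_feats(feats):
    """Rebuild the CoNLL-U FEATS string; bare keys (value None) are emitted as-is."""
    if not feats:
        return "_"
    return "|".join("=".join((k, v)) if v else k for k, v in feats.items())


print(serialize_feats({"Case": "Nom", "Foreign": None}))  # Case=Nom|Foreign
print(serialize_feats(None))                              # _
```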
@@ -13,7 +13,7 @@ from allennlp.common import checks as allen_checks, util
from allennlp.models import archival

from combo import predict
-from combo.data import dataset
+from combo.data import api, dataset
from combo.utils import checks

logger = logging.getLogger(__name__)
@@ -76,6 +76,8 @@ flags.DEFINE_string(name="model_path", default=None,
                    help="Pretrained model path.")
flags.DEFINE_string(name="input_file", default=None,
                    help="File to predict path.")
+flags.DEFINE_boolean(name="conllu_format", default=True,
+                     help="Prediction based on conllu format (instead of raw text).")
flags.DEFINE_integer(name="batch_size", default=1,
                     help="Prediction batch size.")
flags.DEFINE_boolean(name="silent", default=True,
@@ -136,13 +138,13 @@ def run(_):
            model=model,
            dataset_reader=dataset_reader
        )
-        test_path = FLAGS.test_path
-        test_trees = dataset_reader.read(test_path)
+        test_trees = dataset_reader.read(FLAGS.test_path)
        with open(FLAGS.output_file, "w") as file:
            for tree in test_trees:
-                file.writelines(predictor.predict_instance_as_tree(tree).serialize())
+                file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                    keep_semrel=dataset_reader.use_sem).serialize())
    else:
-        use_dataset_reader = ".conllu" in FLAGS.input_file.lower()
+        use_dataset_reader = FLAGS.conllu_format
        predictor = _get_predictor()
        if use_dataset_reader:
            predictor.line_to_conllu = True
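The last hunk also changes how the input format is decided: a file-name heuristic is replaced by the explicit flag, so e.g. CoNLL-U data stored in a `.txt` file is no longer misrouted. A tiny standalone contrast of the two behaviours (function names are ours, for illustration):

```python
def use_dataset_reader_old(input_file: str) -> bool:
    # Old behaviour: guess the format from the file name.
    return ".conllu" in input_file.lower()


def use_dataset_reader_new(conllu_format: bool) -> bool:
    # New behaviour: trust the explicit --conllu_format / --noconllu_format flag.
    return conllu_format


print(use_dataset_reader_old("data.txt"))          # False, even if the file is CoNLL-U
print(use_dataset_reader_new(conllu_format=True))  # True, regardless of the file name
```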
@@ -128,7 +128,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
        # Check whether it is a serialized (str) tree or a token list;
        # a serialized tree already has separators between lines.
        if self.line_to_conllu:
-            return sentence2conllu(outputs).serialize()
+            return sentence2conllu(outputs, keep_semrel=self._dataset_reader.use_sem).serialize()
        else:
            return outputs.to_json()
@@ -127,6 +127,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
            val_reg_loss,
            num_batches=num_batches,
            batch_loss=None,
+            batch_reg_loss=None,
            reset=True,
            world_size=self._world_size,
            cuda_device=self.cuda_device,
@@ -14,7 +14,7 @@ REQUIREMENTS = [
    'torch==1.6.0',
    'tqdm==4.43.0',
    'transformers>=3.0.0,<3.1.0',
-    'urllib3==1.24.2',
+    'urllib3>=1.25.11',
]

setup(