Commit fb9eb280 authored by Mateusz Klimaszewski

Working scratch of joint CoNLL-U and IOB-based training.

Script for preprocessing the NKJP .iob file.
Multitask dataset readers, data loaders and model using the AllenNLP toolkit.
parent 63d3cf6f
Pipeline #2904 passed
# Configuration file for jointly training a model using CoNLL-U and IOB.
local shared_config = import "config.shared.libsonnet";
########################################################################################
# BASIC configuration #
########################################################################################
# Training data path, str
# Must be in CoNLL-U format (or its extended version with a semantic relation field).
# Multiple paths can be provided, concatenated with ',': "path1,path2"
local training_data_path = std.extVar("training_data_path");
# Validation data path, str
# Multiple paths can be provided, concatenated with ',': "path1,path2"
local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
# Path to pretrained tokens, str or null
local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
# Name of pretrained transformer model, str or null
local pretrained_transformer_name = if std.length(std.extVar("pretrained_transformer_name")) > 0 then std.extVar("pretrained_transformer_name");
# Learning rate value, float
local learning_rate = 0.002;
# Number of epochs, int
local num_epochs = std.parseInt(std.extVar("num_epochs"));
# Cuda device id, -1 for cpu, int
local cuda_device = std.parseInt(std.extVar("cuda_device"));
# Minimum number of words in batch, int
local word_batch_size = std.parseInt(std.extVar("word_batch_size"));
# Features used as input, list of str
# Choice "upostag", "xpostag", "lemma"
# Required "token", "char"
local features = std.split(std.extVar("features"), " ");
# Targets of the model, list of str
# Choice "feats", "lemma", "upostag", "xpostag", "semrel". "sent"
# Required "deprel", "head"
local targets = std.split(std.extVar("targets"), " ");
# Word embedding dimension, int
# If pretrained_tokens is not null, embedding_dim must match their dimensionality
local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
# Dropout rate on predictors, float
# All of the models on top of the encoder use this dropout
local predictors_dropout = 0.25;
# Xpostag embedding dimension, int
# (discarded if xpostag not in features)
local xpostag_dim = 32;
# Upostag embedding dimension, int
# (discarded if upostag not in features)
local upostag_dim = 32;
# Feats embedding dimension, int
# (discarded if feats not in features)
local feats_dim = 32;
# Lemma embedding dimension, int
# (discarded if lemma not in features)
local lemma_char_dim = 64;
# Character embedding dim, int
local char_dim = 64;
# Word embedding projection dim, int
local projected_embedding_dim = 100;
# Loss weights, dict[str, float]
local loss_weights = {
xpostag: 0.05,
upostag: 0.05,
lemma: 0.05,
feats: 0.2,
deprel: 0.8,
head: 0.2,
semrel: 0.05,
};
# Encoder hidden size, int
local hidden_size = 512;
# Number of layers in the encoder, int
local num_layers = 2;
# Cycle loss iterations, int
local cycle_loss_n = 0;
# Maximum length of the word, int
# Shorter words are padded, longer ones are truncated
local word_length = 30;
# Whether to use tensorboard, bool
local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
# Helper functions
local in_features(name) = !(std.length(std.find(name, features)) == 0);
local in_targets(name) = !(std.length(std.find(name, targets)) == 0);
local use_transformer = pretrained_transformer_name != null;
# Verify some configuration requirements
assert in_features("token"): "Key 'token' must be in features!";
assert in_features("char"): "Key 'char' must be in features!";
assert in_targets("deprel"): "Key 'deprel' must be in targets!";
assert in_targets("head"): "Key 'head' must be in targets!";
assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use pretrained tokens and pretrained transformer at the same time!";
########################################################################################
# ADVANCED configuration #
########################################################################################
# Detailed dataset, training, vocabulary and model configuration.
{
# Configuration type (default or finetuning), str
type: std.extVar('type'),
# Datasets used for vocab creation, list of str
# Choice "train", "valid"
datasets_for_vocab_creation: ['train'],
# Paths to training data per task, dict[str, str]
train_data_path: {
conllu: training_data_path,
# TODO Add configuration
iob: "./data/nkjp-nested-simplified-v2.fixed.iob",
},
# Path to validation data, str
validation_data_path: validation_data_path,
# Dataset reader configuration (multitask: conllu and iob readers)
dataset_reader: {
type: "multitask",
readers: {
"conllu": {
type: "conllu",
features: features,
targets: targets,
# Whether data contains semantic relation field, bool
use_sem: if in_targets("semrel") then true else false,
token_indexers: {
token: if use_transformer then {
type: "pretrained_transformer_mismatched",
model_name: pretrained_transformer_name,
} else {
# SingleIdTokenIndexer, token as single int
type: "single_id",
},
upostag: {
type: "single_id",
namespace: "upostag",
feature_name: "pos_",
},
xpostag: {
type: "single_id",
namespace: "xpostag",
feature_name: "tag_",
},
lemma: {
type: "characters_const_padding",
character_tokenizer: {
start_tokens: ["__START__"],
end_tokens: ["__END__"],
},
# +2 for start and end token
min_padding_length: word_length + 2,
},
char: {
type: "characters_const_padding",
character_tokenizer: {
start_tokens: ["__START__"],
end_tokens: ["__END__"],
},
# +2 for start and end token
min_padding_length: word_length + 2,
},
feats: {
type: "feats_indexer",
},
},
lemma_indexers: {
char: {
type: "characters_const_padding",
namespace: "lemma_characters",
character_tokenizer: {
start_tokens: ["__START__"],
end_tokens: ["__END__"],
},
# +2 for start and end token
min_padding_length: word_length + 2,
},
},
},
iob: {
type: "iob",
tag_label: "ner",
label_namespace: "ner_labels",
token_indexers: {
token: if use_transformer then {
type: "pretrained_transformer_mismatched",
model_name: pretrained_transformer_name,
} else {
# SingleIdTokenIndexer, token as single int
type: "single_id",
},
char: {
type: "characters_const_padding",
character_tokenizer: {
start_tokens: ["__START__"],
end_tokens: ["__END__"],
},
# +2 for start and end token
min_padding_length: word_length + 2,
},
},
},
},
},
# Data loader configuration
data_loader: {
type: "multitask",
scheduler: {
batch_size: 10
},
shuffle: true,
// batch_sampler: {
// type: "token_count",
// word_batch_size: word_batch_size,
// },
},
# Vocabulary configuration
vocabulary: std.prune({
type: "from_instances_extended",
only_include_pretrained_words: true,
pretrained_files: {
tokens: pretrained_tokens,
},
oov_token: "_",
padding_token: "__PAD__",
non_padded_namespaces: ["head_labels"],
}),
model: std.prune({
type: "multitask_extended",
backbone: {
type: "combo_backbone",
text_field_embedder: {
type: "basic",
token_embedders: {
xpostag: if in_features("xpostag") then {
type: "embedding",
padding_index: 0,
embedding_dim: xpostag_dim,
vocab_namespace: "xpostag",
},
upostag: if in_features("upostag") then {
type: "embedding",
padding_index: 0,
embedding_dim: upostag_dim,
vocab_namespace: "upostag",
},
token: if use_transformer then {
type: "transformers_word_embeddings",
model_name: pretrained_transformer_name,
projection_dim: projected_embedding_dim,
} else {
type: "embeddings_projected",
embedding_dim: embedding_dim,
projection_layer: {
in_features: embedding_dim,
out_features: projected_embedding_dim,
dropout_rate: 0.25,
activation: "tanh"
},
vocab_namespace: "tokens",
pretrained_file: pretrained_tokens,
trainable: if pretrained_tokens == null then true else false,
},
char: {
type: "char_embeddings_from_config",
embedding_dim: char_dim,
dilated_cnn_encoder: {
input_dim: char_dim,
filters: [512, 256, char_dim],
kernel_size: [3, 3, 3],
stride: [1, 1, 1],
padding: [1, 2, 4],
dilation: [1, 2, 4],
activations: ["relu", "relu", "linear"],
},
},
lemma: if in_features("lemma") then {
type: "char_embeddings_from_config",
embedding_dim: lemma_char_dim,
dilated_cnn_encoder: {
input_dim: lemma_char_dim,
filters: [512, 256, lemma_char_dim],
kernel_size: [3, 3, 3],
stride: [1, 1, 1],
padding: [1, 2, 4],
dilation: [1, 2, 4],
activations: ["relu", "relu", "linear"],
},
},
feats: if in_features("feats") then {
type: "feats_embedding",
padding_index: 0,
embedding_dim: feats_dim,
vocab_namespace: "feats",
},
},
},
seq_encoder: {
type: "combo_encoder",
layer_dropout_probability: 0.33,
stacked_bilstm: {
input_size:
(char_dim + projected_embedding_dim +
(if in_features('xpostag') then xpostag_dim else 0) +
(if in_features('lemma') then lemma_char_dim else 0) +
(if in_features('upostag') then upostag_dim else 0) +
(if in_features('feats') then feats_dim else 0)),
hidden_size: hidden_size,
num_layers: num_layers,
recurrent_dropout_probability: 0.33,
layer_dropout_probability: 0.33
},
}
},
heads: {
conllu: {
type: "semantic_multitask_head",
loss_weights: loss_weights,
dependency_relation: {
type: "combo_dependency_parsing_from_vocab",
vocab_namespace: 'deprel_labels',
head_predictor: {
local projection_dim = 512,
cycle_loss_n: cycle_loss_n,
head_projection_layer: {
in_features: hidden_size * 2,
out_features: projection_dim,
activation: "tanh",
},
dependency_projection_layer: {
in_features: hidden_size * 2,
out_features: projection_dim,
activation: "tanh",
},
},
local projection_dim = 128,
head_projection_layer: {
in_features: hidden_size * 2,
out_features: projection_dim,
dropout_rate: predictors_dropout,
activation: "tanh"
},
dependency_projection_layer: {
in_features: hidden_size * 2,
out_features: projection_dim,
dropout_rate: predictors_dropout,
activation: "tanh"
},
},
morphological_feat: if in_targets("feats") then {
type: "combo_morpho_from_vocab",
vocab_namespace: "feats_labels",
input_dim: hidden_size * 2,
hidden_dims: [128],
activations: ["tanh", "linear"],
dropout: [predictors_dropout, 0.0],
num_layers: 2,
},
lemmatizer: if in_targets("lemma") then shared_config.Model.lemma(hidden_size, predictors_dropout),
upos_tagger: if in_targets("upostag") then {
input_dim: hidden_size * 2,
hidden_dims: [64],
activations: ["tanh", "linear"],
dropout: [predictors_dropout, 0.0],
num_layers: 2,
vocab_namespace: "upostag_labels"
},
xpos_tagger: if in_targets("xpostag") then {
input_dim: hidden_size * 2,
hidden_dims: [128],
activations: ["tanh", "linear"],
dropout: [predictors_dropout, 0.0],
num_layers: 2,
vocab_namespace: "xpostag_labels"
},
semantic_relation: if in_targets("semrel") then {
input_dim: hidden_size * 2,
hidden_dims: [64],
activations: ["tanh", "linear"],
dropout: [predictors_dropout, 0.0],
num_layers: 2,
vocab_namespace: "semrel_labels"
},
regularizer: {
regexes: [
[".*conv1d.*", {type: "l2", alpha: 1e-6}],
[".*forward.*", {type: "l2", alpha: 1e-6}],
[".*backward.*", {type: "l2", alpha: 1e-6}],
[".*char_embed.*", {type: "l2", alpha: 1e-5}],
],
},
},
iob: {
type: "ner_head",
feedforward_predictor: {
type: "feedforward_predictor_from_vocab",
input_dim: hidden_size * 2,
hidden_dims: [128],
activations: ["tanh", "linear"],
dropout: [predictors_dropout, 0.0],
num_layers: 2,
vocab_namespace: "ner_labels"
}
},
},
}),
trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard),
random_seed: 8787,
pytorch_seed: 8787,
numpy_seed: 8787,
}
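The configuration above reads every hyperparameter from external variables (`std.extVar`). Below is a minimal sketch of how such a file could be evaluated from Python with the `jsonnet` bindings; the file name, paths and variable values are illustrative placeholders, not taken from the repository.

```python
# A minimal sketch of evaluating the configuration above with the `jsonnet`
# Python bindings (assumption: `pip install jsonnet`). The file name and all
# variable values are illustrative placeholders, not taken from the repository.
import json

import _jsonnet

ext_vars = {
    "type": "default",
    "training_data_path": "./data/train.conllu",
    "validation_data_path": "./data/valid.conllu",
    "pretrained_tokens": "",              # empty string evaluates to null above
    "pretrained_transformer_name": "",    # likewise null -> no transformer
    "num_epochs": "100",
    "cuda_device": "-1",
    "word_batch_size": "2500",
    "features": "token char",
    "targets": "deprel head upostag lemma feats",
    "embedding_dim": "300",
    "use_tensorboard": "False",
}

config = json.loads(_jsonnet.evaluate_file("config.joint.jsonnet", ext_vars=ext_vars))
print(config["model"]["type"])  # -> "multitask_extended"
```

Note that all `std.extVar` values arrive as strings; the config itself handles the `std.parseInt` conversions and the empty-string-to-null checks.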
......@@ -25,5 +25,27 @@
},
validation_metric: "+EM",
}),
Trainer: trainer
local lemma(hidden_size, dropout) = {
type: "combo_lemma_predictor_from_vocab",
char_vocab_namespace: "token_characters",
lemma_vocab_namespace: "lemma_characters",
embedding_dim: 256,
input_projection_layer: {
in_features: hidden_size * 2,
out_features: 32,
dropout_rate: dropout,
activation: "tanh"
},
filters: [256, 256, 256],
kernel_size: [3, 3, 3, 1],
stride: [1, 1, 1, 1],
padding: [1, 2, 4, 0],
dilation: [1, 2, 4, 1],
activations: ["relu", "relu", "relu", "linear"],
},
Trainer: trainer,
Model: {
lemma: lemma,
}
}
\ No newline at end of file
import copy
import itertools
import logging
import pathlib
from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
......@@ -6,8 +7,9 @@ from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
import conllu
import torch
from allennlp import data as allen_data
from allennlp.common import checks, util
from allennlp.common import checks, util, file_utils
from allennlp.data import fields as allen_fields, vocabulary
from allennlp.data.dataset_readers import dataset_reader, conll2003
from conllu import parser
from dataclasses import dataclass
from overrides import overrides
......@@ -17,6 +19,36 @@ from combo.data import fields
logger = logging.getLogger(__name__)
@allen_data.DatasetReader.register("iob")
class IOBDatasetReader(conll2003.Conll2003DatasetReader):
"""Extension of the AllenNLP Conll2003DatasetReader with tab as a separator."""
def _read(self, file_path: dataset_reader.PathOrStr) -> Iterable[allen_data.Instance]:
# if `file_path` is a URL, redirect to the cache
file_path = file_utils.cached_path(file_path)
with open(file_path, "r") as data_file:
logger.info("Reading instances from lines in file at: %s", file_path)
# Group lines into sentence chunks based on the divider.
line_chunks = (
lines
for is_divider, lines in itertools.groupby(data_file, conll2003._is_divider)
# Ignore the divider chunks, so that `lines` corresponds to the words
# of a single sentence.
if not is_divider
)
for lines in self.shard_iterable(line_chunks):
fields = [line.strip().split("\t") for line in lines]
# unzipping trick returns tuples, but our Fields need lists
fields = [list(field) for field in zip(*fields)]
tokens_, pos_tags, chunk_tags, ner_tags = fields
# TextField requires `Token` objects
tokens = [allen_data.Token(token) for token in tokens_]
yield self.text_to_instance(tokens, pos_tags, chunk_tags, ner_tags)
@allen_data.DatasetReader.register("conllu")
class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
......@@ -99,7 +131,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
# features
text_field = allen_fields.TextField(tokens, self._token_indexers)
fields_["sentence"] = text_field
fields_["tokens"] = text_field
# targets
if self.generate_labels:
......
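The `IOBDatasetReader._read` override above relies on the `itertools.groupby` divider trick and a transpose via `zip(*rows)` to turn tab-separated lines into column lists. The following standalone, simplified sketch shows that grouping logic in isolation; `_is_divider` is a hypothetical stand-in for AllenNLP's `conll2003._is_divider`, and the tag values are illustrative.

```python
# Simplified sketch of the sentence grouping used by IOBDatasetReader._read.
import itertools
from typing import List


def _is_divider(line: str) -> bool:
    # Blank lines and -DOCSTART- headers separate sentences/documents.
    stripped = line.strip()
    return stripped == "" or stripped.startswith("-DOCSTART-")


def read_iob_sentences(lines: List[str]) -> List[List[List[str]]]:
    sentences = []
    for is_divider, chunk in itertools.groupby(lines, _is_divider):
        if is_divider:
            continue
        rows = [line.strip().split("\t") for line in chunk]
        # zip(*rows) transposes rows into columns: tokens, pos tags, chunk tags, ner tags.
        sentences.append([list(column) for column in zip(*rows)])
    return sentences


example = [
    "Jan\tsubst\tO\tB-persName\n",
    "Kowalski\tsubst\tO\tI-persName\n",
    "\n",
    "Warszawa\tsubst\tO\tB-placeName\n",
]
for tokens, pos_tags, chunk_tags, ner_tags in read_iob_sentences(example):
    print(tokens, ner_tags)
# ['Jan', 'Kowalski'] ['B-persName', 'I-persName']
# ['Warszawa'] ['B-placeName']
```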
......@@ -18,7 +18,7 @@ class TokenCountBatchSampler(allen_data.BatchSampler):
batches = []
batch = []
words_count = 0
lengths = [len(instance.fields["sentence"].tokens) for instance in dataset]
lengths = [len(instance.fields["tokens"].tokens) for instance in dataset]
argsorted_lengths = np.argsort(lengths)
for idx in argsorted_lengths:
words_count += lengths[idx]
......
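The `TokenCountBatchSampler` hunk above sorts instances by the length of the renamed "tokens" field and accumulates word counts per batch. The sketch below illustrates the token-count bucketing idea only; it is not the sampler's exact implementation, and the word budget and lengths are made up.

```python
# A simplified sketch of token-count batching: instances are sorted by length
# and grouped until a word budget is reached (illustration of the concept only).
from typing import List

import numpy as np


def token_count_batches(lengths: List[int], word_batch_size: int) -> List[List[int]]:
    batches, batch, words_count = [], [], 0
    for idx in np.argsort(lengths):
        batch.append(int(idx))
        words_count += lengths[idx]
        if words_count >= word_batch_size:
            batches.append(batch)
            batch, words_count = [], 0
    if batch:
        batches.append(batch)
    return batches


print(token_count_batches([3, 10, 4, 8, 2], word_batch_size=10))
# [[4, 0, 2, 3], [1]]
```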
......@@ -7,3 +7,4 @@ from .encoder import ComboEncoder
from .lemma import LemmatizerModel
from .model import ComboModel
from .morpho import MorphologicalFeatures
from .multitask import MultiTaskModel
......@@ -106,7 +106,8 @@ class FeedForwardPredictor(Predictor):
)
assert vocab_namespace in vocab.get_namespaces(), \
f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
f"There is not {vocab_namespace} in created vocabs ({','.join(vocab.get_namespaces())}), " \
f"check if this field has any values to predict!"
hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
return cls(feedforward.FeedForward(
......
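The `FeedForwardPredictor` change above only improves the assertion message; the surrounding logic appends the label-vocabulary size to `hidden_dims`, so the final layer projects onto the label space. Below is a plain-torch sketch of that dimension handling (illustrative sizes, not the project's own `FeedForward` wrapper).

```python
# Sketch of appending the label-vocabulary size as the last hidden dimension.
import torch

input_dim, hidden_dims, vocab_size = 1024, [128], 17   # illustrative sizes
hidden_dims = hidden_dims + [vocab_size]               # last layer predicts labels

layers, in_dim = [], input_dim
for out_dim, activation in zip(hidden_dims, [torch.nn.Tanh(), torch.nn.Identity()]):
    layers += [torch.nn.Linear(in_dim, out_dim), activation]
    in_dim = out_dim
predictor = torch.nn.Sequential(*layers)

print(predictor(torch.randn(2, 5, input_dim)).shape)  # torch.Size([2, 5, 17])
```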
"""Main COMBO model."""
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Union
import torch
from allennlp import data, modules, models as allen_models, nn as allen_nn
from allennlp import data, modules, nn as allen_nn
from allennlp.models import heads
from allennlp.modules import text_field_embedders
from allennlp.nn import util
from allennlp.training import metrics as allen_metrics
from overrides import overrides
from combo.models import base
from combo.utils import metrics
@allen_models.Model.register("semantic_multitask")
class ComboModel(allen_models.Model):
@modules.Backbone.register("combo_backbone")
class ComboBackbone(modules.Backbone):
def __init__(self, text_field_embedder: text_field_embedders.TextFieldEmbedder,
seq_encoder: modules.Seq2SeqEncoder):
super().__init__()
self.text_field_embedder = text_field_embedder
self.seq_encoder = seq_encoder
def forward(self, tokens: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
char_mask = tokens["char"]["token_characters"].gt(0)
word_mask = util.get_text_field_mask(tokens)
encoder_input = self.text_field_embedder(tokens, char_mask=char_mask)
return dict(encoder_emb=self.seq_encoder(encoder_input, word_mask),
word_mask=word_mask,
char_mask=char_mask)
@heads.Head.register("ner_head")
class NERModel(heads.Head):
def __init__(self, feedforward_predictor: base.Predictor, vocab: data.Vocabulary):
super().__init__(vocab)
self.feedforward_predictor = feedforward_predictor
self._accuracy_metric = allen_metrics.CategoricalAccuracy()
# self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1",
# ignore_classes=["_"])
self._loss = 0.0
def forward(self,
encoder_emb: Union[torch.Tensor, List[torch.Tensor]],
word_mask: Optional[torch.BoolTensor] = None,
tags: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
output = self.feedforward_predictor(
x=encoder_emb,
mask=word_mask,
labels=tags,
)
if tags is not None:
self._loss = output["loss"]
self._accuracy_metric(output["probability"], tags, word_mask)
# self._f1_metric(output["probability"], tags, word_mask)
return output
@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
return {
**{"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss},
# **self._f1_metric.get_metric(reset)
}
@heads.Head.register("semantic_multitask_head")
class ComboModel(heads.Head):
"""Main COMBO model."""
def __init__(self,
vocab: data.Vocabulary,
loss_weights: Dict[str, float],
text_field_embedder: text_field_embedders.TextFieldEmbedder,
seq_encoder: modules.Seq2SeqEncoder,
use_sample_weight: bool = True,
lemmatizer: Optional[base.Predictor] = None,
upos_tagger: Optional[base.Predictor] = None,
......@@ -30,10 +84,8 @@ class ComboModel(allen_models.Model):
enhanced_dependency_relation: Optional[base.Predictor] = None,
regularizer: allen_nn.RegularizerApplicator = None) -> None:
super().__init__(vocab, regularizer)
self.text_field_embedder = text_field_embedder
self.loss_weights = loss_weights
self.use_sample_weight = use_sample_weight
self.seq_encoder = seq_encoder
self.lemmatizer = lemmatizer
self.upos_tagger = upos_tagger
self.xpos_tagger = xpos_tagger
......@@ -41,13 +93,16 @@ class ComboModel(allen_models.Model):
self.morphological_feat = morphological_feat
self.dependency_relation = dependency_relation
self.enhanced_dependency_relation = enhanced_dependency_relation
self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()]))
self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, 1024]))
self.scores = metrics.SemanticMetrics()
self._partial_losses = None
@overrides
def forward(self,
sentence: Dict[str, Dict[str, torch.Tensor]],
tokens: Dict[str, Dict[str, torch.Tensor]],
encoder_emb: torch.Tensor,
word_mask: torch.Tensor,
char_mask: torch.Tensor,
metadata: List[Dict[str, Any]],
upostag: torch.Tensor = None,
xpostag: torch.Tensor = None,
......@@ -58,19 +113,11 @@ class ComboModel(allen_models.Model):
semrel: torch.Tensor = None,
enhanced_heads: torch.Tensor = None,
enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
# Prepare masks
char_mask = sentence["char"]["token_characters"].gt(0)
word_mask = util.get_text_field_mask(sentence)
device = word_mask.device
device = encoder_emb.device
# If enabled weight samples loss by log(sentence_length)
sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None
encoder_input = self.text_field_embedder(sentence, char_mask=char_mask)
encoder_emb = self.seq_encoder(encoder_input, word_mask)
batch_size, _, encoding_dim = encoder_emb.size()
# Concatenate the head sentinel (ROOT) onto the sentence representation.
......@@ -99,8 +146,8 @@ class ComboModel(allen_models.Model):
labels=feats,
sample_weights=sample_weights)
lemma_output = self._optional(self.lemmatizer,
(encoder_emb, sentence.get("char").get("token_characters")
if sentence.get("char") else None),
(encoder_emb, tokens.get("char").get("token_characters")
if tokens.get("char") else None),
mask=word_mask,
labels=lemma.get("char").get("token_characters") if lemma else None,
sample_weights=sample_weights)
......
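In the refactored model above, `ComboBackbone` embeds and encodes the sentence once, and each head receives the shared `encoder_emb` together with `word_mask` and `char_mask`; the parsing head then prepends a ROOT "head sentinel", whose size is hard-coded to 1024 (twice the default hidden size of the bidirectional encoder). A minimal sketch of that sentinel concatenation with illustrative shapes:

```python
# Sketch of prepending the ROOT head sentinel to the shared encoder output.
import torch

batch_size, seq_len, encoding_dim = 2, 5, 1024

encoder_emb = torch.randn(batch_size, seq_len, encoding_dim)   # shared backbone output
word_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)  # word-level mask

# Learned ROOT vector, prepended to every sentence representation.
head_sentinel = torch.nn.Parameter(torch.randn(1, 1, encoding_dim))

encoder_emb = torch.cat([head_sentinel.expand(batch_size, -1, -1), encoder_emb], dim=1)
word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask], dim=1)

print(encoder_emb.shape, word_mask.shape)  # torch.Size([2, 6, 1024]) torch.Size([2, 6])
```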
import collections
from typing import Mapping, List, Dict, Union
import torch
from allennlp import models
@models.Model.register("multitask_extended")
class MultiTaskModel(models.MultiTaskModel):
"""Extension of the AllenNLP MultiTaskModel to handle dictionary inputs."""
def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore
if "task" not in kwargs:
raise ValueError(
"Instances for multitask training need to contain a MetadataField with "
"the name 'task' to indicate which task they belong to. Usually the "
"MultitaskDataLoader provides this field and you don't have to do anything."
)
task_indices_just_for_mypy: Mapping[str, List[int]] = collections.defaultdict(lambda: [])
for i, task in enumerate(kwargs["task"]):
task_indices_just_for_mypy[task].append(i)
task_indices: Dict[str, torch.LongTensor] = {
task: torch.LongTensor(indices) for task, indices in task_indices_just_for_mypy.items()
}
def make_inputs_for_task(task: str, whole_batch_input: Union[torch.Tensor, List, Dict]):
if isinstance(whole_batch_input, torch.Tensor):
task_indices[task] = task_indices[task].to(whole_batch_input.device)
return torch.index_select(whole_batch_input, 0, task_indices[task])
if isinstance(whole_batch_input, dict):
return {k: make_inputs_for_task(task, v) for k, v in whole_batch_input.items()}
else:
return [whole_batch_input[i] for i in task_indices[task]]
backbone_arguments = self._get_arguments(kwargs, "backbone")
backbone_outputs = self._backbone(**backbone_arguments)
combined_arguments = {**backbone_outputs, **kwargs}
outputs = {**backbone_outputs}
loss = None
for head_name in self._heads:
if head_name not in task_indices:
continue
head_arguments = self._get_arguments(combined_arguments, head_name)
head_arguments = {
key: make_inputs_for_task(head_name, value) for key, value in head_arguments.items()
}
head_outputs = self._heads[head_name](**head_arguments)
for key in head_outputs:
outputs[f"{head_name}_{key}"] = head_outputs[key]
if "loss" in head_outputs:
self._heads_called.add(head_name)
head_loss = self._loss_weights[head_name] * head_outputs["loss"]
if loss is None:
loss = head_loss
else:
loss += head_loss
if loss is not None:
outputs["loss"] = loss
return outputs
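`MultiTaskModel.forward` above groups instances by their "task" metadata and slices each tensor with `torch.index_select` before calling the matching head. A small self-contained sketch of that per-task slicing, with made-up task labels and shapes:

```python
# Sketch of per-task slicing: each head only sees rows whose "task" matches.
import collections

import torch

tasks = ["conllu", "iob", "conllu", "iob"]   # MetadataField "task" per instance
encoder_emb = torch.randn(4, 6, 1024)        # shared backbone output for the batch

task_indices = collections.defaultdict(list)
for i, task in enumerate(tasks):
    task_indices[task].append(i)

for task, indices in task_indices.items():
    head_input = torch.index_select(encoder_emb, 0, torch.LongTensor(indices))
    print(task, head_input.shape)
# conllu torch.Size([2, 6, 1024])
# iob torch.Size([2, 6, 1024])
```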
"""Script which fixes -DOCSTART- misspellings."""
import pathlib
from absl import app
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_string(name="input_path", default=None,
help="Path to IOB file.")
flags.DEFINE_string(name="output_path", default=None,
help="Path to store fixed IOB file.")
flags.mark_flag_as_required("input_path")
flags.mark_flag_as_required("output_path")
def run(_):
input_path = pathlib.Path(FLAGS.input_path)
assert input_path.exists() and input_path.is_file(), "Input doesn't exist or is not a file."
with input_path.open("r") as input_fh:
with pathlib.Path(FLAGS.output_path).open("w") as output_fh:
for line in input_fh:
# Replace -DOCSTART with -DOCSTART-
if "-DOCSTART" in line and "-DOCSTART-" not in line:
line = line.replace("-DOCSTART", "-DOCSTART-")
output_fh.write(line)
def main():
app.run(run)
if __name__ == "__main__":
main()
......@@ -16,7 +16,7 @@ class TokenCountBatchSamplerTest(unittest.TestCase):
tokens = [data.Token(t)
for t in sentence.split()]
text_field = fields.TextField(tokens, {})
self.dataset.append(data.Instance({"sentence": text_field}))
self.dataset.append(data.Instance({"tokens": text_field}))
def test_batches(self):
# given
......