Skip to content
Snippets Groups Projects
Commit c5031d73 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Switched id and text in token list

parent fed57c54
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
......@@ -55,6 +55,54 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.mode
return _TokenList(tokens=tokens,
metadata=sentence.metadata if sentence.metadata is None else Metadata())
def serialize_field(field: Any) -> str:
if field is None:
return '_'
if isinstance(field, dict):
if field == {}:
return '_'
fields = []
for key, value in field.items():
if value is None:
value = "_"
if value == "":
fields.append(key)
continue
fields.append('='.join((key, value)))
return '|'.join(fields)
if isinstance(field, tuple):
return "".join([serialize_field(item) for item in field])
if isinstance(field, list):
if len(field[0]) != 2:
raise ValueError("Can't serialize '{}', invalid format".format(field))
return "|".join([serialize_field(value) + ":" + str(key) for key, value in field])
return "{}".format(field)
def serialize_token_list(tokenlist: conllu.models.TokenList) -> str:
KEYS_ORDER = ['idx', 'text', 'lemma', 'upostag', 'xpostag',
'entity_type', 'feats', 'head', 'deprel', 'deps', 'misc']
lines = []
if tokenlist.metadata:
for key, value in tokenlist.metadata.items():
if value:
line = f"# {key} = {value}"
else:
line = f"# {key}"
lines.append(line)
for token_data in tokenlist:
line = '\t'.join(serialize_field(token_data[k]) for k in KEYS_ORDER)
lines.append(line)
return '\n'.join(lines) + "\n\n"
def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList:
return _TokenList(
......
......@@ -27,7 +27,6 @@ from combo.modules.model import Model
from combo.utils import ConfigurationError
from combo.utils.matrices import extract_combo_matrices
import codecs
logging.setLoggerClass(ComboLogger)
logger = logging.getLogger(__name__)
......@@ -103,8 +102,6 @@ flags.DEFINE_string(name="input_file", default=None,
help="File to predict path")
flags.DEFINE_boolean(name="conllu_format", default=True,
help="Prediction based on conllu format (instead of raw text).")
flags.DEFINE_boolean(name="silent", default=True,
help="Silent prediction to file (without printing to console).")
flags.DEFINE_boolean(name="finetuning", default=False,
help="Finetuning mode for training.")
flags.DEFINE_string(name="tokenizer_language", default="English",
......@@ -410,8 +407,8 @@ def run(_):
with open(FLAGS.output_file, "w") as file:
for tree in tqdm(test_trees):
prediction = predictor.predict_instance(tree)
file.writelines(api.sentence2conllu(prediction,
keep_semrel=dataset_reader.use_sem).serialize())
file.writelines(api.serialize_token_list(api.sentence2conllu(prediction,
keep_semrel=dataset_reader.use_sem)))
predictions.append(prediction)
else:
......@@ -421,8 +418,8 @@ def run(_):
predictions = predictor.predict(input_sentences)
with open(FLAGS.output_file, "w") as file:
for prediction in tqdm(predictions):
file.writelines(api.sentence2conllu(prediction,
keep_semrel=dataset_reader.use_sem).serialize())
file.writelines(api.serialize_token_list(api.sentence2conllu(prediction,
keep_semrel=dataset_reader.use_sem)))
if FLAGS.save_matrices:
logger.info("Saving matrices", prefix=prefix)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment