Commit e8f8754d authored by Łukasz Kopociński's avatar Łukasz Kopociński

Refactor spert json generator

parent 8868581a
......@@ -152,7 +152,7 @@ class RelationsLoader:
yield label, id_domain, relation
def _filter_relations(self, filter_label: str) -> Dict[int, Relation]:
def filter_relations(self, filter_label: str) -> Dict[int, Relation]:
"""Experimental feature"""
relations_dict = {}
index = 0
......
#!/usr/bin/env python3.6
from collections import defaultdict
from pathlib import Path
from typing import Iterator, Dict
from typing import Dict, List
import click
from semrel.data.scripts import constant
from semrel.data.scripts.entities import Relation
from semrel.data.scripts.relations import RelationsLoader
from semrel.data.scripts.utils.io import save_json, load_json
......@@ -14,55 +15,100 @@ from spert.scripts.mapper import InSentenceSPERTMapper, \
def split_relations(indices: Dict, relations_dict: Dict) -> Dict:
train_relations = [relation for index, relation in relations_dict.items()
if index in indices['train']]
valid_relations = [relation for index, relation in relations_dict.items()
if index in indices['valid']]
test_relations = [relation for index, relation in relations_dict.items()
if index in indices['test']]
return {'train': train_relations,
'valid': valid_relations,
'test': test_relations}
train_relations = [
relation
for index, relation in relations_dict.items()
if index in indices['train']
]
valid_relations = [
relation
for index, relation in relations_dict.items()
if index in indices['valid']
]
test_relations = [
relation
for index, relation in relations_dict.items()
if index in indices['test']
]
return {
'train': train_relations,
'valid': valid_relations,
'test': test_relations
}
def in_same_context(relation: Relation) -> bool:
_, member_from, member_to = relation
return relation.member_from.id_sentence == relation.member_to.id_sentence
def make_key(relation: Relation):
id_document = relation.id_document
id_from = relation.member_from.id_sentence
id_to = relation.member_to.id_sentence
return f'{id_document}-{id_from}-{id_to}'
def map_relations(
relations: Iterator[Relation],
relations: List[Relation],
in_sentence_mapper: InSentenceSPERTMapper,
between_sentence_mapper: BetweenSentencesSPERTMapper
) -> Dict:
documents = defaultdict(SPERTDocument)
for relation in relations:
id_document, member_from, member_to = relation
in_same_context = member_from.id_sentence == member_to.id_sentence
if in_same_context:
spert_relation = in_sentence_mapper.map(relation)
else:
spert_relation = between_sentence_mapper.map(relation)
id_from = relation.member_from.id_sentence
id_to = relation.member_to.id_sentence
key = f'{id_document}-{id_from}-{id_to}'
same_context_relations = [
relation
for relation in relations
if in_same_context(relation)
]
diff_context_relations = [
relation
for relation in relations
if not in_same_context(relation)
]
spert_same_context_relations = [
in_sentence_mapper.map(relation)
for relation in same_context_relations
]
spert_diff_context_relations = [
between_sentence_mapper.map(relation)
for relation in diff_context_relations
]
spert_same_context_relations_keys = [
make_key(relation)
for relation in same_context_relations
]
spert_diff_context_relations_keys = [
make_key(relation)
for relation in diff_context_relations
]
spert_relations = spert_same_context_relations + spert_diff_context_relations
spert_keys = spert_same_context_relations_keys + spert_diff_context_relations_keys
documents = defaultdict(SPERTDocument)
for relation, key in zip(spert_relations, spert_keys):
document = documents[key]
document.tokens = spert_relation.tokens
document.tokens = relation.tokens
if spert_relation.head not in document.entities:
document.entities.append(spert_relation.head)
if relation.head not in document.entities:
document.entities.append(relation.head)
if spert_relation.tail not in document.entities:
document.entities.append(spert_relation.tail)
if relation.tail not in document.entities:
document.entities.append(relation.tail)
index_from = document.entities.index(spert_relation.head)
index_to = document.entities.index(spert_relation.tail)
index_from = document.entities.index(relation.head)
index_to = document.entities.index(relation.tail)
document.relations.add(
SPERTDocRelation(index_from, index_to, spert_relation.relation_type)
SPERTDocRelation(index_from, index_to, relation.relation_type)
)
return documents
return documents
@click.command()
......@@ -75,7 +121,9 @@ def map_relations(
def main(input_path, indices_file, output_dir):
indices = load_json(Path(indices_file))
relations_loader = RelationsLoader(Path(input_path))
relations = relations_loader._filter_relations(filter_label='in_relation')
relations = relations_loader.filter_relations(
filter_label=constant.IN_RELATION_LABEL
)
in_sentence_mapper = InSentenceSPERTMapper()
between_sentence_mapper = BetweenSentencesSPERTMapper()
......@@ -94,7 +142,8 @@ def main(input_path, indices_file, output_dir):
documents = [document.to_dict() for document in documents.values()]
save_json(documents, Path(f'{output_dir}/{run_id}/{set_name}.json'))
save_path = Path(f'{output_dir}/{run_id}/{set_name}.json')
save_json(documents, save_path)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment