Commit 54f1183a authored by Łukasz Kopociński

Refactor spert json generator

parent e8f8754d
......@@ -17,7 +17,7 @@ def get_indices(
lexical_split: bool = False,
in_domain: str = None,
random_seed: int = 42
) -> Tuple[Tuple[List, List, List], Dict[int, str]]:
) -> Tuple[List, List, List, Dict[int, str]]:
keys = BrandProductDataset._load_keys(keys_file)
ds_generator = DatasetGenerator(keys, random_seed)
train, valid, test = ds_generator.generate_datasets(
......
#!/usr/bin/env python3.6
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
......@@ -9,52 +10,22 @@ from semrel.data.scripts import constant
from semrel.data.scripts.entities import Relation
from semrel.data.scripts.relations import RelationsLoader
from semrel.data.scripts.utils.io import save_json, load_json
from spert.scripts.entities import SPERTDocument, SPERTDocRelation
from spert.scripts.mapper import InSentenceSPERTMapper, \
from spert.scripts.entities import \
SPERTDocument, \
SPERTDocRelation, \
SPERTRelation
from spert.scripts.mapper import \
InSentenceSPERTMapper, \
BetweenSentencesSPERTMapper
def split_relations(indices: Dict, relations_dict: Dict) -> Dict:
    """Partition relations into train/valid/test lists by relation index.

    Args:
        indices: Mapping with 'train', 'valid' and 'test' entries, each a
            collection of relation indices.
        relations_dict: Mapping from relation index to relation.

    Returns:
        Dict with the keys 'train', 'valid' and 'test' (in that order). Each
        value lists the relations whose index occurs in the matching entry
        of ``indices``, in ``relations_dict`` iteration order.

    Raises:
        KeyError: If ``indices`` lacks one of the three split names.
    """
    def select(split_indices) -> list:
        # A set makes every membership test O(1) instead of a list scan per
        # relation; unhashable entries fall back to the original scan.
        try:
            wanted = set(split_indices)
        except TypeError:
            wanted = split_indices
        return [
            relation
            for index, relation in relations_dict.items()
            if index in wanted
        ]

    return {
        set_name: select(indices[set_name])
        for set_name in ('train', 'valid', 'test')
    }
def in_same_context(relation: Relation) -> bool:
    """Return whether both members of ``relation`` share a sentence id.

    Only ``member_from.id_sentence`` and ``member_to.id_sentence`` are read;
    the relation is not unpacked, so it does not need to be a three-item
    iterable.

    Args:
        relation: Relation whose two members are compared.

    Returns:
        True when the two sentence ids compare equal, False otherwise.
    """
    return relation.member_from.id_sentence == relation.member_to.id_sentence
def make_key(relation: Relation):
    """Build the '<document>-<sentence from>-<sentence to>' key of a relation.

    The key is formed from ``relation.id_document`` and the ``id_sentence``
    of ``member_from`` and ``member_to``, joined with hyphens.
    """
    return (
        f'{relation.id_document}'
        f'-{relation.member_from.id_sentence}'
        f'-{relation.member_to.id_sentence}'
    )
from spert.scripts.utils import in_same_context, make_relation_key, \
split_relations
def map_relations(
relations: List[Relation],
in_sentence_mapper: InSentenceSPERTMapper,
between_sentence_mapper: BetweenSentencesSPERTMapper
) -> Dict:
) -> List[dict]:
same_context_relations = [
relation
for relation in relations
......@@ -78,18 +49,25 @@ def map_relations(
]
spert_same_context_relations_keys = [
make_key(relation)
make_relation_key(relation)
for relation in same_context_relations
]
spert_diff_context_relations_keys = [
make_key(relation)
make_relation_key(relation)
for relation in diff_context_relations
]
spert_relations = spert_same_context_relations + spert_diff_context_relations
spert_keys = spert_same_context_relations_keys + spert_diff_context_relations_keys
return generate_spert_jsons(spert_relations, spert_keys)
def generate_spert_jsons(
spert_relations: List[SPERTRelation],
spert_keys: List[str]
) -> List[Dict]:
documents = defaultdict(SPERTDocument)
for relation, key in zip(spert_relations, spert_keys):
document = documents[key]
......@@ -108,7 +86,7 @@ def map_relations(
SPERTDocRelation(index_from, index_to, relation.relation_type)
)
return documents
return [document.to_dict() for document in documents.values()]
@click.command()
......@@ -129,19 +107,15 @@ def main(input_path, indices_file, output_dir):
between_sentence_mapper = BetweenSentencesSPERTMapper()
for run_id, run_indices in indices.items():
print(f"\n\nRUN_ID: {run_id}")
run_relations = split_relations(run_indices, relations)
for set_name, set_relations in run_relations.items():
print(f"SET_NAME: {set_name}", end=" ")
documents = map_relations(
relations=set_relations,
in_sentence_mapper=in_sentence_mapper,
between_sentence_mapper=between_sentence_mapper
)
documents = [document.to_dict() for document in documents.values()]
save_path = Path(f'{output_dir}/{run_id}/{set_name}.json')
save_json(documents, save_path)
......
from typing import Dict, List
from entities import Relation
def in_same_context(relation: Relation) -> bool:
    """Tell whether the relation's two members point at the same sentence.

    Compares ``member_from.id_sentence`` with ``member_to.id_sentence``.
    """
    return relation.member_from.id_sentence == relation.member_to.id_sentence
def make_relation_key(relation: Relation) -> str:
    """Compose a relation's key as '<document>-<sentence from>-<sentence to>'.

    Uses ``relation.id_document`` plus the ``id_sentence`` of the
    ``member_from`` and ``member_to`` members.
    """
    document_id, from_sentence, to_sentence = (
        relation.id_document,
        relation.member_from.id_sentence,
        relation.member_to.id_sentence,
    )
    return f'{document_id}-{from_sentence}-{to_sentence}'
def split_relations(
        indices: Dict, relations_dict: Dict
) -> Dict[str, List[Relation]]:
    """Group relations into 'train', 'valid' and 'test' lists.

    For each split name, the matching entry of ``indices`` is handed to
    ``_split_relations`` together with ``relations_dict``; the results are
    returned under the same split name.
    """
    return {
        set_name: _split_relations(indices[set_name], relations_dict)
        for set_name in ('train', 'valid', 'test')
    }
def _split_relations(indices: List, relations_dict: Dict) -> List[Relation]:
    """Select the relations whose index is listed in ``indices``.

    Args:
        indices: Relation indices that belong to one split.
        relations_dict: Mapping from relation index to relation.

    Returns:
        The matching relations in ``relations_dict`` iteration order; each
        relation appears once even if its index is repeated in ``indices``.
    """
    # A set makes every membership test O(1) instead of a list scan per
    # relation; unhashable entries fall back to the original scan.
    try:
        wanted = set(indices)
    except TypeError:
        wanted = indices
    return [
        relation
        for index, relation in relations_dict.items()
        if index in wanted
    ]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment