Commit 102ba87b authored by Łukasz Kopociński's avatar Łukasz Kopociński

Clean up spert package

parent 49536aef
......@@ -18,5 +18,5 @@ credentials
.tox
data
semrel/data/data
spert/data
/vectors
/sent2vec
/corpora
/elmo
/fasttext
/maps
/relations
/relations_files.list
IN_RELATION_LABEL = 'in_relation'
NO_RELATION_LABEL = 'no_relation'
CHANNELS = (('BRAND_NAME', 'PRODUCT_NAME'),
('PRODUCT_NAME', 'BRAND_NAME'))
BRAND_NAME_KEY = 'BRAND_NAME'
PRODUCT_NAME_KEY = 'BRAND_NAME'
CHANNELS = ((BRAND_NAME_KEY, PRODUCT_NAME_KEY),
(PRODUCT_NAME_KEY, BRAND_NAME_KEY))
LABEL = 'label'
DOMAIN = 'id_domain'
......
This diff is collapsed.
......@@ -149,8 +149,13 @@ class DatasetGenerator:
test.extend(test_indices)
return train, valid, test
def generate_datasets(self, balanced: bool, lexical_split: bool,
in_domain: str, out_domain: str = None):
def generate_datasets(
self,
balanced: bool,
lexical_split: bool,
in_domain: str,
out_domain: str = None
) -> Tuple[List, List, List]:
if in_domain:
indices = [index for index, descriptor in self.dataset_keys.items()
if descriptor[1] == in_domain]
......
/spert.json
/dataset
......@@ -2,16 +2,22 @@
pushd "$(git rev-parse --show-toplevel)"
DATA_DIR="./data/vectors"
SCRIPTS_DIR="./data/spert"
OUTPUT_PATH="./data/spert/indices.json"
DATA_DIR="./semrel/data/data"
VECTORS_DIR="${DATA_DIR}/vectors"
SCRIPTS_DIR="./spert/scripts"
OUTPUT_DIR="./spert/data"
OUTPUT_PATH="${OUTPUT_DIR}/indices.json"
mkdir -p ${OUTPUT_DIR}
dvc run \
-d ${DATA_DIR} \
-d ${VECTORS_DIR} \
-d ${SCRIPTS_DIR}/generate_indices.py \
-O ${OUTPUT_PATH} \
-f spert.indices.dvc \
${SCRIPTS_DIR}/generate_indices.py --dataset-keys ${DATA_DIR}/elmo.rel.keys \
-o ${OUTPUT_PATH} \
-f _spert.indices.dvc \
${SCRIPTS_DIR}/generate_indices.py --dataset-keys ${VECTORS_DIR}/elmo.rel.keys \
--output-path ${OUTPUT_PATH}
popd
......@@ -2,11 +2,13 @@
pushd "$(git rev-parse --show-toplevel)"
SCRIPTS_DIR="./data/spert"
INDICES_FILE="./data/spert/indices.json"
DATA_DIR="./semrel/data/data"
INPUT_PATH="./data/relations/relations.tsv"
OUTPUT_DIR="./data/spert/dataset"
SCRIPTS_DIR="./spert/scripts"
INDICES_FILE="./spert/data/indices.json"
INPUT_PATH="${DATA_DIR}/relations/relations.tsv"
OUTPUT_DIR="./spert/data/dataset"
mkdir -p ${OUTPUT_DIR}
......@@ -15,7 +17,7 @@ dvc run \
-d ${INDICES_FILE} \
-d ${SCRIPTS_DIR}/generate_spert_json.py \
-o ${OUTPUT_DIR} \
-f spert.jsons.dvc \
-f _spert.jsons.dvc \
${SCRIPTS_DIR}/generate_spert_json.py --input-path ${INPUT_PATH} \
--indices-file ${INDICES_FILE} \
--output-dir ${OUTPUT_DIR}
......
This diff is collapsed.
from typing import NamedTuple, List, Set
from typing import NamedTuple, List, Set, Dict
class Indices(NamedTuple):
......@@ -48,7 +48,7 @@ class SPERTDocument:
self.entities = entities or []
self.relations = relations or set()
def to_dict(self):
def to_dict(self) -> Dict:
return {
'tokens': self.tokens,
'entities': [entity.to_dict() for entity in self.entities],
......
#!/usr/bin/env python3.6
from pathlib import Path
from typing import List, Tuple, Dict
import click
from scripts import save_json
from semrel.model import RUNS
from semrel.model.scripts.utils import BrandProductDataset, \
from semrel.data.scripts.utils.io import save_json
from semrel.model.runs import RUNS, IN_DOMAIN_KEY, LEXICAL_SPLIT_KEY
from semrel.model.scripts.utils.data_loader import BrandProductDataset, \
DatasetGenerator
......@@ -16,12 +17,13 @@ def get_indices(
lexical_split: bool = False,
in_domain: str = None,
random_seed: int = 42
):
) -> Tuple[Tuple[List, List, List], Dict[int, str]]:
keys = BrandProductDataset._load_keys(keys_file)
ds_generator = DatasetGenerator(keys, random_seed)
return ds_generator.generate_datasets(
train, valid, test = ds_generator.generate_datasets(
balanced, lexical_split, in_domain
), keys
)
return train, valid, test, keys
@click.command()
......@@ -32,8 +34,8 @@ def main(dataset_keys, output_path):
spert_runs = RUNS['spert']
for index, params in spert_runs.items():
in_domain = params.get('in_domain')
lexical_split = params.get('lexical_split', False)
in_domain = params.get(IN_DOMAIN_KEY)
lexical_split = params.get(LEXICAL_SPLIT_KEY, False)
(train, valid, test), keys = get_indices(
keys_file=Path(dataset_keys),
......
......@@ -5,11 +5,12 @@ from typing import Iterator, Dict
import click
from entities import Relation
from scripts import RelationsLoader
from scripts import save_json, load_json
from spert import SPERTDocument, SPERTDocRelation
from spert.mapper import InSentenceSPERTMapper, BetweenSentencesSPERTMapper
from semrel.data.scripts.entities import Relation
from semrel.data.scripts.relations import RelationsLoader
from semrel.data.scripts.utils.io import save_json, load_json
from spert.scripts.entities import SPERTDocument, SPERTDocRelation
from spert.scripts.mapper import InSentenceSPERTMapper, \
BetweenSentencesSPERTMapper
def split_relations(indices: Dict, relations_dict: Dict) -> Dict:
......@@ -25,24 +26,26 @@ def split_relations(indices: Dict, relations_dict: Dict) -> Dict:
'test': test_relations}
def map_relations(relations: Iterator[Relation],
in_sentence_mapper: InSentenceSPERTMapper,
between_sentence_mapper: BetweenSentencesSPERTMapper):
def map_relations(
relations: Iterator[Relation],
in_sentence_mapper: InSentenceSPERTMapper,
between_sentence_mapper: BetweenSentencesSPERTMapper
) -> Dict:
documents = defaultdict(SPERTDocument)
for relation in relations:
id_document, member_from, member_to = relation
in_same_context = member_from.id_sentence == member_to.id_sentence
id_from = relation.member_from.id_sentence
id_to = relation.member_to.id_sentence
key = f'{id_document}-{id_from}-{id_to}'
if in_same_context:
spert_relation = in_sentence_mapper.map(relation)
else:
spert_relation = between_sentence_mapper.map(relation)
id_from = relation.member_from.id_sentence
id_to = relation.member_to.id_sentence
key = f'{id_document}-{id_from}-{id_to}'
document = documents[key]
document.tokens = spert_relation.tokens
......@@ -70,7 +73,6 @@ def map_relations(relations: Iterator[Relation],
@click.option('--output-dir', required=True, type=str,
help='Paths for saving SPERT json file.')
def main(input_path, indices_file, output_dir):
indices = load_json(Path(indices_file))
relations_loader = RelationsLoader(Path(input_path))
relations = relations_loader._filter_relations(filter_label='in_relation')
......
from abc import ABC, abstractmethod
from typing import Tuple, List
from entities import Relation, Member
from spert import SPERTEntity, SPERTRelation
from semrel.data.scripts.constant import BRAND_NAME_KEY, PRODUCT_NAME_KEY
from semrel.data.scripts.entities import Relation, Member
from spert.scripts.entities import SPERTEntity, SPERTRelation
class BrandProductSPERTMapper(ABC):
ENTITY_TYPE_MAP = {
'BRAND_NAME': 'Brand',
'PRODUCT_NAME': 'Product',
BRAND_NAME_KEY: 'Brand',
PRODUCT_NAME_KEY: 'Product',
}
RELATION_TYPE = 'Brand-Product'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment