Unverified Commit e800454d authored by Marcin Wątroba

Change model and update pipeline

parent e7a1f7ac
1 merge request: !13 Change data model
Showing 253 additions and 109 deletions
@@ -3,6 +3,7 @@ from typing import List
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from .record_id_iterator import RecordIdIterator
from .relation_manager_provider import RelationManagerProvider
from .task.processing_task import ProcessingTask
@@ -10,6 +11,7 @@ class ExperimentManager:
_experiment_repository: ExperimentRepository
_record_id_iterator: RecordIdIterator
_processing_tasks: List[ProcessingTask]
_relation_manager_provider: RelationManagerProvider
def __init__(
self,
@@ -24,4 +26,5 @@ class ExperimentManager:
def process(self):
self._experiment_repository.initialise()
for processing_task in self._processing_tasks:
processing_task.process(self._record_id_iterator, self._experiment_repository)
processing_task.process(self._record_id_iterator, self._experiment_repository,
self._relation_manager_provider)
from typing import Any, Dict
from sziszapangma.core.alignment.step_words import StepWords
from sziszapangma.integration.mapper.word_mapper import WordMapper
class StepWordsMapper:
@@ -9,18 +8,15 @@ class StepWordsMapper:
def to_json_dict(step_words: StepWords) -> Dict[str, Any]:
to_return = dict()
if step_words.hypothesis_word is not None:
to_return["hypothesis_word"] = WordMapper.to_json_dict(step_words.hypothesis_word)
to_return["hypothesis_word"] = step_words.hypothesis_word
if step_words.reference_word is not None:
to_return["reference_word"] = WordMapper.to_json_dict(step_words.reference_word)
to_return["reference_word"] = step_words.reference_word
return to_return
@staticmethod
def from_json_dict(input_json_dict: Dict[str, Any]) -> StepWords:
return StepWords(
None
if "reference_word" not in input_json_dict
else WordMapper.from_json_dict(input_json_dict["reference_word"]),
None
if "hypothesis_word" not in input_json_dict
else WordMapper.from_json_dict(input_json_dict["hypothesis_word"]),
)
reference_word = None if "reference_word" not in input_json_dict \
else input_json_dict["reference_word"]
hypothesis_word = None if "hypothesis_word" not in input_json_dict \
else input_json_dict["hypothesis_word"]
return StepWords(reference_word, hypothesis_word)
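The mapper now stores the model-level word dicts directly, so a round trip no longer goes through WordMapper. A minimal sketch, assuming the mapper module is sziszapangma.integration.mapper.step_words_mapper and using create_new_word from this commit (the example words are made up):

from sziszapangma.core.alignment.step_words import StepWords
from sziszapangma.integration.mapper.step_words_mapper import StepWordsMapper
from sziszapangma.model.model_creators import create_new_word

# StepWords takes (reference_word, hypothesis_word); both are plain Word dicts now.
step = StepWords(create_new_word("dom"), create_new_word("dam"))
as_json = StepWordsMapper.to_json_dict(step)        # {"reference_word": {...}, "hypothesis_word": {...}}
restored = StepWordsMapper.from_json_dict(as_json)  # StepWords built from the same dicts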
from typing import Dict
from sziszapangma.core.alignment.word import Word
_ID = "id"
_VALUE = "value"
class WordMapper:
@staticmethod
def to_json_dict(word: Word) -> Dict[str, str]:
return {_ID: word.id, _VALUE: word.value}
@staticmethod
def from_json_dict(input_json_dict: Dict[str, str]) -> Word:
return Word(input_json_dict[_ID], input_json_dict[_VALUE])
from abc import ABC, abstractmethod
from sziszapangma.model.relation_manager import RelationManager
class RelationManagerProvider(ABC):
@abstractmethod
def get_relation_manager(self, record_id: str) -> RelationManager:
pass
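A provider only has to hand back a RelationManager for a record id. A minimal file-backed sketch (the class name and per-record directory layout are illustrative, not part of the commit), built on the FileRelationManager added further down:

import os

from sziszapangma.integration.relation_manager_provider import RelationManagerProvider
from sziszapangma.model.relation_manager import FileRelationManager, RelationManager


class DirectoryRelationManagerProvider(RelationManagerProvider):
    # Hypothetical provider: one CSV/JSON-backed relation manager per record under base_dir.
    def __init__(self, base_dir: str):
        self._base_dir = base_dir

    def get_relation_manager(self, record_id: str) -> RelationManager:
        record_dir = os.path.join(self._base_dir, record_id)
        os.makedirs(record_dir, exist_ok=True)
        return FileRelationManager(
            os.path.join(record_dir, "relations.csv"),
            os.path.join(record_dir, "items.json"),
        )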
from sziszapangma.core.alignment.word import Word
from sziszapangma.integration.asr_processor import AsrProcessor
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.record_path_provider import RecordPathProvider
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
from sziszapangma.model.model_creators import create_new_word
from sziszapangma.model.relation_manager import RelationManager
class AsrTask(ProcessingTask):
@@ -30,12 +30,13 @@ class AsrTask(ProcessingTask):
return asr_value is not None and "transcription" in asr_value
def run_single_process(
self, record_id: str, experiment_repository: ExperimentRepository
self, record_id: str, experiment_repository: ExperimentRepository,
relation_manager: RelationManager
) -> None:
file_record_path = self._record_path_provider.get_path(record_id)
asr_result = self._asr_processor.call_recognise(file_record_path)
asr_result["transcription"] = [
WordMapper.to_json_dict(Word.from_string(it)) for it in asr_result["transcription"]
create_new_word(it) for it in asr_result["transcription"]
]
experiment_repository.update_property_for_key(
record_id, self._asr_property_name, asr_result
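With create_new_word, the ASR property now holds model-level Word dicts instead of WordMapper output. Roughly, the value stored for a record looks like this sketch (ids are generated UUIDs, texts are made up):

asr_result = {
    "transcription": [
        {"id": "6f1c…", "type": "Word", "text": "ala"},
        {"id": "9b2d…", "type": "Word", "text": "ma"},
        {"id": "0a3e…", "type": "Word", "text": "kota"},
    ],
    # any other keys returned by AsrProcessor.call_recognise are stored unchanged
}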
@@ -2,12 +2,13 @@ from typing import Any, Dict, List
from sziszapangma.core.alignment.alignment_classic_calculator import AlignmentClassicCalculator
from sziszapangma.core.alignment.alignment_step import AlignmentStep
from sziszapangma.core.alignment.word import Word
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import AlignmentStepMapper
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
from sziszapangma.integration.task.task_util import TaskUtil
from sziszapangma.model.model import Word
from sziszapangma.model.relation_manager import RelationManager
_CLASSIC_WER = "classic_wer"
@@ -42,10 +43,9 @@ class ClassicWerMetricTask(ProcessingTask):
is not None
)
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository):
gold_transcript = experiment_repository.get_property_for_key(
record_id, self._gold_transcript_property_name
)
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
relation_manager: RelationManager):
gold_transcript = TaskUtil.get_words_from_record(relation_manager)
asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
alignment_steps = self._get_alignment(gold_transcript, asr_result["transcription"])
@@ -59,16 +59,12 @@ class ClassicWerMetricTask(ProcessingTask):
)
def _get_alignment(
self, gold_transcript: List[Dict[str, Any]], asr_result: List[Dict[str, Any]]
self,
gold_transcript: List[Word],
asr_transcript: List[Word]
) -> List[AlignmentStep]:
gold_transcript_words = [
# WordMapper.from_json_dict(word_dict)
Word(word_dict["id"], word_dict["word"])
for word_dict in gold_transcript
]
asr_words = [WordMapper.from_json_dict(word_dict).to_lower() for word_dict in asr_result]
return self._alignment_classic_calculator.calculate_alignment(
reference=gold_transcript_words, hypothesis=asr_words
reference=gold_transcript, hypothesis=asr_transcript
)
def calculate_metrics(self, alignment_steps: List[AlignmentStep]) -> Dict[str, Any]:
from typing import Dict, List
from sziszapangma.core.alignment.alignment_embedding_calculator import AlignmentEmbeddingCalculator
from sziszapangma.core.alignment.alignment_soft_calculator import AlignmentSoftCalculator
from sziszapangma.core.alignment.word import Word
from sziszapangma.core.transformer.cached_embedding_transformer import CachedEmbeddingTransformer
from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import AlignmentStepMapper
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
from sziszapangma.integration.task.task_util import TaskUtil
from sziszapangma.model.relation_manager import RelationManager
_SOFT_WER = "soft_wer"
_EMBEDDING_WER = "embedding_wer"
@@ -53,24 +51,19 @@ class EmbeddingWerMetricsTask(ProcessingTask):
is not None
)
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository):
gold_transcript = experiment_repository.get_property_for_key(
record_id, self._gold_transcript_property_name
)
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
relation_manager: RelationManager):
gold_transcript = TaskUtil.get_words_from_record(relation_manager)
asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
gold_transcript_words = self._map_words_to_domain_gold_transcript(gold_transcript)
asr_words = [
it
for it in self._map_words_to_domain(asr_result["transcription"])
if len(it.value) > 0
]
gold_transcript_lower = TaskUtil.words_to_lower(gold_transcript)
asr_transcript_lower = TaskUtil.words_to_lower(asr_result["transcription"])
soft_alignment = self._alignment_soft_calculator.calculate_alignment(
gold_transcript_words, asr_words
gold_transcript_lower, asr_transcript_lower
)
embedding_alignment = self._alignment_embedding_calculator.calculate_alignment(
gold_transcript_words, asr_words
gold_transcript_lower, asr_transcript_lower
)
soft_wer = self._wer_calculator.calculate_wer(soft_alignment)
@@ -92,13 +85,3 @@ class EmbeddingWerMetricsTask(ProcessingTask):
)
self._embedding_transformer.clear()
@staticmethod
def _map_words_to_domain(input_json_dicts: List[Dict[str, str]]) -> List[Word]:
return [WordMapper.from_json_dict(word_dict).to_lower() for word_dict in input_json_dicts]
@staticmethod
def _map_words_to_domain_gold_transcript(input_json_dicts: List[Dict[str, str]]) -> List[Word]:
return [
Word(word_dict["id"], word_dict["word"]).to_lower() for word_dict in input_json_dicts
]
from sziszapangma.integration.gold_transcript_processor import GoldTranscriptProcessor
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
class GoldTranscriptTask(ProcessingTask):
_gold_transcript_processor: GoldTranscriptProcessor
_gold_transcript_property_name: str
def __init__(
self,
task_name: str,
gold_transcript_processor: GoldTranscriptProcessor,
gold_transcript_property_name: str,
require_update: bool,
):
super().__init__(task_name, require_update)
self._gold_transcript_processor = gold_transcript_processor
self._gold_transcript_property_name = gold_transcript_property_name
def skip_for_record(self, record_id: str, experiment_repository: ExperimentRepository) -> bool:
return (
experiment_repository.get_property_for_key(
record_id, self._gold_transcript_property_name
)
is not None
)
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository):
experiment_repository.update_property_for_key(
record_id,
self._gold_transcript_property_name,
self._gold_transcript_processor.get_gold_transcript(record_id),
)
@@ -2,7 +2,9 @@ import traceback
from abc import ABC, abstractmethod
from sziszapangma.integration.record_id_iterator import RecordIdIterator
from sziszapangma.integration.relation_manager_provider import RelationManagerProvider
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.model.relation_manager import RelationManager
class ProcessingTask(ABC):
@@ -14,7 +16,8 @@ class ProcessingTask(ABC):
self._task_name = task_name
@abstractmethod
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository):
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
relation_manager: RelationManager):
pass
@abstractmethod
@@ -22,7 +25,8 @@ class ProcessingTask(ABC):
pass
def process(
self, record_id_iterator: RecordIdIterator, experiment_repository: ExperimentRepository
self, record_id_iterator: RecordIdIterator, experiment_repository: ExperimentRepository,
relation_manager_provider: RelationManagerProvider
):
records_ids = list(record_id_iterator.get_all_records())
for record_index in range(len(records_ids)):
@@ -35,7 +39,8 @@ class ProcessingTask(ABC):
if not skip or self._require_update:
print(base_log)
try:
self.run_single_process(record_id, experiment_repository)
relation_manager = relation_manager_provider.get_relation_manager(record_id)
self.run_single_process(record_id, experiment_repository, relation_manager)
except Exception as err:
print("Handling run-time error:", err)
traceback.print_exc()
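Concrete tasks now receive the record's RelationManager alongside the repository. A minimal sketch of a subclass against the new signature (WordCountTask and its property name are hypothetical, purely to illustrate the hook points):

from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
from sziszapangma.integration.task.task_util import TaskUtil
from sziszapangma.model.relation_manager import RelationManager


class WordCountTask(ProcessingTask):
    def __init__(self, property_name: str, require_update: bool):
        super().__init__("word_count", require_update)
        self._property_name = property_name

    def skip_for_record(self, record_id: str, experiment_repository: ExperimentRepository) -> bool:
        return experiment_repository.get_property_for_key(record_id, self._property_name) is not None

    def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
                           relation_manager: RelationManager) -> None:
        # The gold transcript now comes from the per-record RelationManager.
        words = TaskUtil.get_words_from_record(relation_manager)
        experiment_repository.update_property_for_key(record_id, self._property_name, len(words))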
from typing import List
from sziszapangma.model.model import Word
from sziszapangma.model.relation_manager import RelationManager
class TaskUtil:
@staticmethod
def get_words_from_record(relation_manager: RelationManager) -> List[Word]:
document = [itt for itt in relation_manager.get_all_items() if itt['type'] == 'Document'][0]
return [relation_manager.get_item_by_id(item_id) for item_id in document['word_ids']]
@staticmethod
def _word_to_lower(word: Word) -> Word:
return Word(id=word['id'], type='Word', text=word['text'].lower())
@staticmethod
def words_to_lower(words: List[Word]) -> List[Word]:
return [TaskUtil._word_to_lower(it) for it in words]
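TaskUtil assumes the relation manager holds a Document item whose word_ids point at Word items. A small sketch of that contract, using the creators and the FileRelationManager from this commit (file paths are illustrative):

from sziszapangma.integration.task.task_util import TaskUtil
from sziszapangma.model.model_creators import create_new_document, create_new_word
from sziszapangma.model.relation_manager import FileRelationManager

relation_manager = FileRelationManager("relations.csv", "items.json")
words = [create_new_word(text) for text in ["Ala", "ma", "kota"]]
for word in words:
    relation_manager.save_item(word)
relation_manager.save_item(create_new_document([word["id"] for word in words]))

gold_transcript = TaskUtil.get_words_from_record(relation_manager)  # the three Word dicts, in order
lowered = TaskUtil.words_to_lower(gold_transcript)                  # same ids, lower-cased text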
from typing import List, TypedDict, Literal, Any, Union
AnnotationType = Literal["lemma", "pos", "morph", "concept", "chunk", "turn"]
ReferenceType = Literal["Token", "Word", "Document", "SpanAnnotation"]
UUIDableType = Union[AnnotationType, ReferenceType]
class UUIDable(TypedDict):
id: str
type: UUIDableType
class Token(UUIDable):
text: str
class Word(UUIDable):
text: str
class Document(UUIDable):
word_ids: List[str]
class Annotation(UUIDable):
value: Any
class SingleAnnotation(Annotation):
reference_id: str
class SpanAnnotation(UUIDable):
name: str
elements: List[str]
class RelationAnnotation(UUIDable):
parent: UUIDable
child: UUIDable
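These TypedDicts are plain dicts at runtime: a Document references its Words by id, and a SingleAnnotation points back at an item through reference_id. An illustrative sketch (ids shortened, values made up):

word = {"id": "7c0a…", "type": "Word", "text": "kota"}
document = {"id": "d41e…", "type": "Document", "word_ids": ["5b19…", "a8c2…", word["id"]]}
lemma = {"id": "f903…", "type": "lemma", "value": "kot", "reference_id": word["id"]}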
import uuid
from typing import Any, List
from sziszapangma.model.model import Word, SingleAnnotation, AnnotationType, SpanAnnotation, \
Document
def _get_uuid() -> str:
return str(uuid.uuid4())
def create_new_word(text: str) -> Word:
return Word(id=_get_uuid(), type='Word', text=text)
def create_new_single_annotation(
annotation_type: AnnotationType,
value: Any,
reference_id: str
) -> SingleAnnotation:
return SingleAnnotation(id=_get_uuid(), type=annotation_type, value=value,
reference_id=reference_id)
def create_new_span_annotation(name: str, elements: List[str]) -> SpanAnnotation:
return SpanAnnotation(id=_get_uuid(), type='SpanAnnotation', name=name, elements=elements)
def create_new_document(word_ids: List[str]) -> Document:
return Document(id=_get_uuid(), type='Document', word_ids=word_ids)
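The helpers generate the ids, so assembling a small document and its annotations reduces to a few calls (annotation values and the span name below are made up):

from sziszapangma.model.model_creators import (
    create_new_document, create_new_single_annotation, create_new_span_annotation, create_new_word,
)

words = [create_new_word(text) for text in ["Ala", "ma", "kota"]]
document = create_new_document([word["id"] for word in words])
lemma = create_new_single_annotation("lemma", "kot", words[2]["id"])
phrase = create_new_span_annotation("verb_phrase", [words[1]["id"], words[2]["id"]])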
import json
import os
from abc import abstractmethod, ABC
from typing import List, TypedDict, Dict
import pandas as pd
from sziszapangma.model.model import UUIDable
class RelationItem(TypedDict):
first_id: str
first_type: str
second_id: str
second_type: str
class RelationManager(ABC):
@abstractmethod
def get_all_relations_for_item(self, item_id: str) -> List[RelationItem]:
pass
@abstractmethod
def get_item_by_id(self, item_id: str) -> UUIDable:
pass
@abstractmethod
def get_items_by_query(self, query: str) -> List[UUIDable]:
pass
@abstractmethod
def save_item(self, item: UUIDable):
pass
@abstractmethod
def save_relation(self, item1: UUIDable, item2: UUIDable):
pass
@abstractmethod
def commit(self):
pass
@abstractmethod
def get_all_items(self) -> List[UUIDable]:
pass
class FileRelationManager(RelationManager):
relations_csv_path: str
items_json_path: str
relations_dataframe: pd.DataFrame
items_dict: Dict[str, UUIDable]
def __init__(self, relations_csv_path: str, items_json_path: str):
self.relations_csv_path = relations_csv_path
self.items_json_path = items_json_path
if os.path.isfile(relations_csv_path):
self.relations_dataframe = pd.read_csv(relations_csv_path)
else:
self.relations_dataframe = pd.DataFrame(
[], columns=["first_id", "first_type", "second_id", "second_type"]
)
if os.path.isfile(items_json_path):
with open(items_json_path, "r") as f:
self.items_dict = json.loads(f.read())
else:
self.items_dict = dict()
def get_all_items(self) -> List[UUIDable]:
return list(self.items_dict.values())
def get_items_by_query(self, query: str) -> List[UUIDable]:
return self.relations_dataframe.query(query).to_dict("records")
def get_all_relations_for_item(self, item_id: str) -> List[RelationItem]:
return self.relations_dataframe[
self.relations_dataframe["first_id"] == item_id
].to_dict("records")
def get_item_by_id(self, item_id) -> UUIDable:
return self.items_dict[item_id]
def save_item(self, item: UUIDable):
self.items_dict[item.get("id")] = item
def save_relation(self, item1: UUIDable, item2: UUIDable):
relation_1_2 = {
"first_id": item1["id"],
"first_type": item1["type"],
"second_id": item2["id"],
"second_type": item2["type"],
}
relation_2_1 = {
"first_id": item2["id"],
"first_type": item2["type"],
"second_id": item1["id"],
"second_type": item1["type"],
}
self.relations_dataframe = self.relations_dataframe.append(
relation_1_2, ignore_index=True
)
self.relations_dataframe = self.relations_dataframe.append(
relation_2_1, ignore_index=True
)
def commit(self):
self.relations_dataframe.to_csv(self.relations_csv_path, index=False)
items_json = json.dumps(self.items_dict)
with open(self.items_json_path, "w") as f:
f.write(items_json)
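A short usage sketch of the file-backed manager (paths are illustrative). Note that save_relation relies on pandas.DataFrame.append, so it needs a pandas version that still provides that method:

from sziszapangma.model.model_creators import create_new_single_annotation, create_new_word
from sziszapangma.model.relation_manager import FileRelationManager

manager = FileRelationManager("relations.csv", "items.json")

word = create_new_word("kota")
lemma = create_new_single_annotation("lemma", "kot", word["id"])
manager.save_item(word)
manager.save_item(lemma)
manager.save_relation(word, lemma)  # stored in both directions in the relations frame
manager.commit()                    # writes relations.csv and items.json

manager.get_item_by_id(word["id"])              # -> the word dict
manager.get_all_relations_for_item(word["id"])  # -> rows whose first_id is the word's id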