Skip to content
Snippets Groups Projects
Unverified Commit 51ae81aa authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Fix integration

parent 02e76f7f
2 merge requests: !4 "Feature/add poetry", !3 "Add ids to words"
from typing import Any, Dict

from sziszapangma.core.alignment.alignment_step import AlignmentStep
from sziszapangma.integration.mapper.step_words_mapper import StepWordsMapper
class AlignmentStepMapper:
    """Serialise ``AlignmentStep`` objects into JSON-compatible dictionaries."""

    @staticmethod
    def to_json_dict(alignment_step: AlignmentStep) -> Dict[str, Any]:
        """Convert a single alignment step into a plain dictionary.

        Args:
            alignment_step: The step to serialise.

        Returns:
            A dictionary with three keys: ``'step_type'`` (the ``name`` of
            the step's ``step_type``), ``'step_words'`` (the result of
            ``StepWordsMapper.to_json_dict`` for the step's words) and
            ``'step_cost'`` (copied unchanged from the step).
        """
        return {
            'step_type': alignment_step.step_type.name,
            'step_words': StepWordsMapper.to_json_dict(
                alignment_step.step_words),
            'step_cost': alignment_step.step_cost,
        }
from typing import Dict
from sziszapangma.core.alignment.step_words import StepWords
from sziszapangma.integration.mapper.word_mapper import WordMapper
class StepWordsMapper:
    """Convert ``StepWords`` objects to and from JSON-compatible dictionaries."""

    @staticmethod
    def to_json_dict(step_words: StepWords) -> Dict[str, Any]:
        """Serialise the words of an alignment step.

        Args:
            step_words: Object exposing ``hypothesis_word`` and
                ``reference_word`` attributes; either may be ``None``.

        Returns:
            A dictionary holding ``'hypothesis_word'`` and/or
            ``'reference_word'`` entries (each produced by
            ``WordMapper.to_json_dict``). A word that is ``None`` is left
            out entirely instead of being stored as a null value.
        """
        json_dict: Dict[str, Any] = {}
        if step_words.hypothesis_word is not None:
            json_dict['hypothesis_word'] = WordMapper.to_json_dict(
                step_words.hypothesis_word)
        if step_words.reference_word is not None:
            json_dict['reference_word'] = WordMapper.to_json_dict(
                step_words.reference_word)
        return json_dict

    @staticmethod
    def from_json_dict(input_json_dict: Dict[str, Any]) -> StepWords:
        """Rebuild a ``StepWords`` object from its dictionary form.

        Args:
            input_json_dict: Dictionary in the layout written by
                ``to_json_dict``; a missing key means "no word".

        Returns:
            A ``StepWords`` built from the reference word and the
            hypothesis word, each ``None`` when its key is absent.
        """
        # Presence of the key (not its value) decides whether a word exists,
        # mirroring how to_json_dict omits missing words.
        if 'reference_word' in input_json_dict:
            reference_word = WordMapper.from_json_dict(
                input_json_dict['reference_word'])
        else:
            reference_word = None

        if 'hypothesis_word' in input_json_dict:
            hypothesis_word = WordMapper.from_json_dict(
                input_json_dict['hypothesis_word'])
        else:
            hypothesis_word = None

        # Positional order is kept as in the original call: reference first,
        # hypothesis second.
        return StepWords(reference_word, hypothesis_word)
from typing import Dict
from sziszapangma.core.alignment.alignment_step import AlignmentStep
from sziszapangma.core.alignment.step_words import StepWords
from sziszapangma.core.alignment.word import Word
_ID = 'id'
_VALUE = 'value'
class WordMapper:
    """Translate ``Word`` objects to and from their dictionary form."""

    @staticmethod
    def to_json_dict(word: Word) -> Dict[str, str]:
        """Return a dictionary with the word's ``id`` and ``value``.

        The entries are stored under the module-level ``_ID`` and ``_VALUE``
        keys.
        """
        json_dict = {}
        json_dict[_ID] = word.id
        json_dict[_VALUE] = word.value
        return json_dict

    @staticmethod
    def from_json_dict(input_json_dict: Dict[str, str]) -> Word:
        """Create a ``Word`` from a dictionary written by ``to_json_dict``.

        Both the ``_ID`` and ``_VALUE`` keys must be present; a missing key
        raises ``KeyError``.
        """
        word_id = input_json_dict[_ID]
        word_value = input_json_dict[_VALUE]
        return Word(word_id, word_value)
......@@ -2,17 +2,22 @@ from typing import List, Dict
from sziszapangma.core.alignment.alignment_classic_calculator import \
AlignmentClassicCalculator
from sziszapangma.core.alignment.alignment_step import AlignmentStep
from sziszapangma.core.alignment.word import Word
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import \
AlignmentStepMapper
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.repository.experiment_repository import \
ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
_CLASSIC_WER = 'classic_wer'
_WORD = 'word'
class ClassicWerMetricTask(ProcessingTask):
_metrics_property_name: str
_alignment_property_name: str
_gold_transcript_property_name: str
_alignment_classic_calculator: AlignmentClassicCalculator
_wer_calculator: WerCalculator
......@@ -23,17 +28,22 @@ class ClassicWerMetricTask(ProcessingTask):
gold_transcript_property_name: str,
asr_property_name: str,
metrics_property_name: str,
alignment_property_name: str,
require_update: bool
):
super().__init__(task_name, require_update)
self._gold_transcript_property_name = gold_transcript_property_name
self._asr_property_name = asr_property_name
self._alignment_property_name = alignment_property_name
self._metrics_property_name = metrics_property_name
self._alignment_classic_calculator = AlignmentClassicCalculator()
self._wer_calculator = WerCalculator()
def skip_for_record(
    self,
    record_id: str,
    experiment_repository: ExperimentRepository
) -> bool:
    """Report whether a metrics value is already stored for the record.

    Args:
        record_id: Identifier of the record to check.
        experiment_repository: Repository queried through
            ``get_property_for_key``.

    Returns:
        ``True`` when the repository returns a truthy value for this
        record under ``self._metrics_property_name``, ``False`` otherwise.
    """
    stored_metrics = experiment_repository.get_property_for_key(
        record_id, self._metrics_property_name)
    # The signature promises a bool, so the stored value itself is not
    # leaked to the caller; truthiness is unchanged.
    return bool(stored_metrics)
......@@ -45,34 +55,44 @@ class ClassicWerMetricTask(ProcessingTask):
asr_result = experiment_repository \
.get_property_for_key(record_id, self._asr_property_name)
if 'transcription' in asr_result:
alignment_steps = self._get_alignment(
gold_transcript, asr_result['transcription']
)
experiment_repository.update_property_for_key(
record_id,
self._alignment_property_name,
[AlignmentStepMapper.to_json_dict(it)
for it in alignment_steps]
)
experiment_repository.update_property_for_key(
record_id,
self._metrics_property_name,
self.calculate_metrics(
gold_transcript=gold_transcript,
asr_result=asr_result['transcription']
)
self.calculate_metrics(alignment_steps)
)
def _get_alignment(
    self,
    gold_transcript: List[Dict[str, Any]],
    asr_result: List[Dict[str, Any]]
) -> List[AlignmentStep]:
    """Align the gold transcript with the ASR transcription.

    Args:
        gold_transcript: Reference words as dictionaries readable by
            ``WordMapper.from_json_dict``.
        asr_result: Hypothesis words in the same dictionary layout.

    Returns:
        The alignment steps returned by the classic alignment calculator,
        with the gold transcript as reference and the ASR words as
        hypothesis.
    """
    gold_transcript_words = [
        WordMapper.from_json_dict(word_dict)
        for word_dict in gold_transcript
    ]
    asr_words = [
        WordMapper.from_json_dict(word_dict)
        for word_dict in asr_result
    ]
    return self._alignment_classic_calculator.calculate_alignment(
        reference=gold_transcript_words,
        hypothesis=asr_words,
    )

def calculate_metrics(
    self,
    alignment_steps: List[AlignmentStep]
) -> Dict[str, Any]:
    """Calculate all metrics for data sample.

    Args:
        alignment_steps: Alignment previously computed for the sample.

    Returns:
        A dictionary with a single ``_CLASSIC_WER`` entry holding the value
        returned by ``self._wer_calculator.calculate_wer``.
    """
    return {
        _CLASSIC_WER: self._wer_calculator.calculate_wer(alignment_steps),
    }
......@@ -4,11 +4,15 @@ from sziszapangma.core.alignment.alignment_embedding_calculator import \
AlignmentEmbeddingCalculator
from sziszapangma.core.alignment.alignment_soft_calculator import \
AlignmentSoftCalculator
from sziszapangma.core.alignment.word import Word
from sziszapangma.core.transformer.cached_embedding_transformer import \
CachedEmbeddingTransformer
from sziszapangma.core.transformer.embedding_transformer import \
EmbeddingTransformer
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import \
AlignmentStepMapper
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.repository.experiment_repository import \
ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
......@@ -20,6 +24,7 @@ _WORD = 'word'
class EmbeddingWerMetricsTask(ProcessingTask):
_metrics_property_name: str
_alignment_property_name: str
_gold_transcript_property_name: str
_embedding_transformer: CachedEmbeddingTransformer
_alignment_embedding_calculator: AlignmentEmbeddingCalculator
......@@ -32,6 +37,7 @@ class EmbeddingWerMetricsTask(ProcessingTask):
gold_transcript_property_name: str,
asr_property_name: str,
metrics_property_name: str,
alignment_property_name: str,
require_update: bool,
embedding_transformer: EmbeddingTransformer
):
......@@ -46,6 +52,7 @@ class EmbeddingWerMetricsTask(ProcessingTask):
self._alignment_soft_calculator = \
AlignmentSoftCalculator(self._embedding_transformer)
self._wer_calculator = WerCalculator()
self._alignment_property_name = alignment_property_name
def skip_for_record(self, record_id: str,
experiment_repository: ExperimentRepository) -> bool:
......@@ -60,26 +67,39 @@ class EmbeddingWerMetricsTask(ProcessingTask):
asr_result = experiment_repository \
.get_property_for_key(record_id, self._asr_property_name)
if 'transcription' in asr_result:
gold_transcript_words = self._map_words_to_domain(gold_transcript)
asr_words = self._map_words_to_domain(asr_result['transcription'])
soft_alignment = self._alignment_soft_calculator \
.calculate_alignment(gold_transcript_words, asr_words)
embedding_alignment = self._alignment_embedding_calculator \
.calculate_alignment(gold_transcript_words, asr_words)
soft_wer = self._wer_calculator.calculate_wer(soft_alignment)
embedding_wer = self._wer_calculator \
.calculate_wer(embedding_alignment)
alignment_results = {
'soft_alignment': [AlignmentStepMapper.to_json_dict(it)
for it in soft_alignment],
'embedding_alignment': [AlignmentStepMapper.to_json_dict(it)
for it in embedding_alignment],
}
wer_results = {'soft_wer': soft_wer,
'embedding_wer': embedding_wer}
experiment_repository.update_property_for_key(
record_id,
self._metrics_property_name,
self.calculate_metrics(
gold_transcript=gold_transcript,
asr_result=asr_result['transcription']
)
)
record_id, self._alignment_property_name, alignment_results)
experiment_repository.update_property_for_key(
record_id, self._metrics_property_name, wer_results)
self._embedding_transformer.clear()
def calculate_metrics(
    self,
    gold_transcript: List[Dict[str, Any]],
    asr_result: List[str]
) -> Dict[str, Any]:
    """Calculate all metrics for data sample.

    Args:
        gold_transcript: Reference transcript; each entry is a dictionary
            whose ``_WORD`` key holds the word.
        asr_result: Hypothesis words produced by the ASR system.

    Returns:
        A dictionary with ``_SOFT_WER`` and ``_EMBEDDING_WER`` entries, each
        being element ``0`` of the respective calculator's result.
    """
    reference = [entry[_WORD] for entry in gold_transcript]
    # NOTE(review): the soft metric is read from calculate_alignment(...)[0]
    # while the embedding metric uses calculate_wer(...)[0]; confirm whether
    # this asymmetry is intended.
    soft_wer = self._alignment_soft_calculator \
        .calculate_alignment(reference, asr_result)[0]
    embedding_wer = self._alignment_embedding_calculator \
        .calculate_wer(reference, asr_result)[0]
    return {
        _SOFT_WER: soft_wer,
        _EMBEDDING_WER: embedding_wer,
    }
@staticmethod
def _map_words_to_domain(
    input_json_dicts: List[Dict[str, str]]
) -> List[Word]:
    """Turn word dictionaries into ``Word`` objects, keeping their order.

    Each dictionary is converted with ``WordMapper.from_json_dict``.
    """
    domain_words = []
    for json_word in input_json_dicts:
        domain_words.append(WordMapper.from_json_dict(json_word))
    return domain_words
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment