Skip to content
Snippets Groups Projects
pipeline_process_flair_upos.py 1.87 KiB
Newer Older
Marcin Wątroba's avatar
Marcin Wątroba committed
import argparse

from experiment.const_pipeline_names import GOLD_TRANSCRIPT
from experiment.experiment_dependency_provider import get_record_provider, get_repository
from experiment.sentence_wer_processor.flair_upos_multi_transformers_wer_processor_base import \
    FlairUposMultiTransformersWerProcessorBase
from experiment.sentence_wer_processor.wikineural_multilingual_ner_transformers_wer_processor_base import \
    WikineuralMultilingualNerTransformersWerProcessorBase
from sziszapangma.core.transformer.web_embedding_transformer import WebEmbeddingTransformer
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask
from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask


def run_word_wer_pipeline(dataset_name: str, asr_name: str):
    """Run the Flair UPOS WER processing task for one dataset / ASR system pair.

    Despite the "word" in its name, this pipeline schedules a single
    ``FlairUposMultiTransformersWerProcessorBase`` task. The task is wired to
    these property names:

    * gold transcript: ``GOLD_TRANSCRIPT``
    * ASR output: ``<asr_name>__result``
    * alignment: ``<asr_name>__flair_upos_alignment``
    * metrics: ``<asr_name>__flair_upos_metrics``

    The task is then executed through an ``ExperimentManager``.

    Args:
        dataset_name: Dataset identifier. It selects the record provider and
            the experiment repository, and is embedded in the task name.
        asr_name: ASR system identifier. It prefixes every per-ASR property
            name and is embedded in the task name.
    """
    # The same provider object serves as both the record-id iterator and
    # the relation-manager provider below.
    provider = get_record_provider(dataset_name)

    # NOTE(review): require_update=False presumably means records that were
    # already processed are not recomputed; confirm against the task base class.
    upos_task = FlairUposMultiTransformersWerProcessorBase(
        gold_transcript_property_name=GOLD_TRANSCRIPT,
        asr_property_name=f'{asr_name}__result',
        alignment_property_name=f'{asr_name}__flair_upos_alignment',
        wer_property_name=f'{asr_name}__flair_upos_metrics',
        task_name=f'FlairUposMultiTransformersWerProcessorBase___{dataset_name}___{asr_name}',
        require_update=False,
    )

    manager = ExperimentManager(
        record_id_iterator=provider,
        processing_tasks=[upos_task],
        experiment_repository=get_repository(dataset_name),
        relation_manager_provider=provider,
    )
    manager.process()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset")
    parser.add_argument("--asr")
    args = parser.parse_args()
    run_word_wer_pipeline(args.dataset, args.asr)