Skip to content
Snippets Groups Projects
Select Git revision
  • master
  • main default protected
  • ud_training_script
  • fix_seed
  • merged-with-ner
  • multiword_fix_transformer
  • transformer_encoder
  • combo3
  • save_deprel_matrix_to_npz
  • combo-lambo
  • lambo-sent-attributes
  • adding_lambo
  • develop
  • update_allenlp2
  • develop_tmp
  • tokens_truncation
  • LR_test
  • eud_iwpt
  • iob
  • eud_iwpt_shared_task_bert_finetuning
  • 3.3.1
  • list
  • 3.2.1
  • 3.0.3
  • 3.0.1
  • 3.0.0
  • v1.0.6
  • v1.0.5
  • v1.0.4
  • v1.0.3
  • v1.0.2
  • v1.0.1
  • v1.0.0
33 results

postprocessing.py

Blame
  • postprocessing.py 18.48 KiB
    # TODO lemma remove punctuation - ukrainian
    # TODO lemma remove punctuation - russian
    # TODO consider handling multiple 'case'
    import sys
    
    import conllu
    
    import re

    # Matches the Russian preposition "из-за"; its hyphen is stripped
    # (-> "изза") before the lemma is embedded in a colon-separated label.
    rus = re.compile(u'^из-за$')
    # Matches decimal token ids such as "8.1" (empty-node ids in CoNLL-U).
    # Raw strings avoid invalid-escape warnings for '\d' on modern Python.
    expand = re.compile(r'^\d+\.\d+$')
    
    '''
    A script correcting automatically predicted enhanced dependency graphs.
    Running the script: python postprocessing.py cs
    
    You have to modify the paths to the input CoNLL-U file and the output file.
    
    The last argument (e.g. cs) corresponds to the language symbol.
    All language symbols:
    ar (Arabic), bg (Bulgarian), cs (Czech), nl (Dutch), en (English), et (Estonian), fi (Finnish)
    fr (French), it (Italian), lv (Latvian), lt (Lithuanian), pl (Polish), ru (Russian)
    sk (Slovak), sv (Swedish), ta (Tamil), uk (Ukrainian)
    
    There are two main rules:
    1) the first one adds case information to the following labels: nmod, obl, acl, advcl.
    The case information comes from case/mark dependent of the current token and from the morphological feature Case.
    Depending on the language, not all information is added.
    In some languages ('en', 'it', 'nl', 'sv') the lemma of the coordinating conjunction (cc) is appended to the conjunct label (conj).
    Functions: fix_mod_deps, fix_obj_deps, fix_acl_deps, fix_advcl_deps and fix_conj_deps
    
    2) the second rule corrects enhanced edges coming into function words labelled ref, mark, punct, root, case, det, cc, cop, aux.
    They should not be assigned other functions. For example, if a token, e.g. "and", is labelled cc (coordinating conjunction),
    it cannot simultaneously be a subject (nsubj), and if this wrong enhanced edge exists, it should be removed from the graph.
    
    There is one additional rule for Estonian: 
    if the label is nsubj:cop or csubj:cop, the cop sublabel is removed and we have nsubj and csubj, respectively. 
    '''
    
    
    def fix_nmod_deps(dep, token, sentence, relation):
        """
        Rewrite an enhanced edge labelled 'nmod' so that the label also carries
        the case/mark lemma and the morphological Case value, e.g. nmod:po:loc.
        Returns the (label, head) pair, or `dep` unchanged when no rule applies.
        """
        label, head = dep

        # Only labels starting with `relation` are candidates.
        if not label.startswith(relation):
            return dep

        # First 'case' or 'mark' dependent of this token (at most one is used).
        markers = []
        for candidate in sentence:
            if candidate["head"] == token["id"] and candidate["deprel"] in ("case", "mark"):
                markers.append(candidate)
                break

        # case_lemma is the (complex) preposition / subordinating conjunction,
        # e.g. 'po' in nmod:po:loc; 'fixed' dependents extend it.
        case_lemma = None
        if markers:
            extensions = [
                candidate
                for candidate in sentence
                for marker in markers
                if candidate["deprel"] == "fixed" and candidate["head"] == marker["id"]
            ]
            ordered = quicksort(markers + extensions) if extensions else quicksort(markers)
            case_lemma = "_".join(rus.sub('изза', part["lemma"]) for part in ordered)

        # case_val is the value of the Case feature, e.g. 'gen' in nmod:gen.
        case_val = None
        feats = token['feats']
        if feats is not None and 'Case' in feats:
            case_val = feats['Case'].lower()

        # TODO: check for other languages
        if language in ['fi']:
            if label not in ['nmod', 'nmod:poss']:
                return dep
        elif label not in ['nmod']:
            return dep

        pieces = [label]
        if case_lemma:
            pieces.append(case_lemma)
        # TODO: check for other languages
        if case_val and language not in ['bg', 'en', 'nl', 'sv']:
            pieces.append(case_val)

        # print(label, sentence.metadata["sent_id"])
        return ":".join(pieces), head
    
    
    def fix_obl_deps(dep, token, sentence, relation):
        """
        Rewrite an enhanced edge labelled 'obl', 'obl:arg' or 'obl:agent' so
        that the label also carries the adposition lemma and/or the Case
        feature, e.g. obl -> obl:pod:loc.
        Returns the (label, head) pair, or `dep` unchanged when no rule applies.
        """
        label, head = dep

        if not label.startswith(relation):
            return dep

        # First 'case' or 'mark' dependent of this token (at most one is used).
        markers = []
        for candidate in sentence:
            if candidate["head"] == token["id"] and candidate["deprel"] in ("case", "mark"):
                markers.append(candidate)
                break

        # case_lemma is the (complex) preposition, e.g. 'w_przypadku' in
        # obl:w_przypadku:gen; 'fixed' dependents extend the marker.
        case_lemma = None
        if markers:
            extensions = [
                candidate
                for candidate in sentence
                for marker in markers
                if candidate["deprel"] == "fixed" and candidate["head"] == marker["id"]
            ]
            ordered = quicksort(markers + extensions) if extensions else quicksort(markers)
            case_lemma = "_".join(rus.sub('изза', part["lemma"]) for part in ordered)

        # case_val is the value of the Case feature, e.g. 'loc' in obl:pod:loc.
        case_val = None
        feats = token['feats']
        if feats is not None and 'Case' in feats:
            case_val = feats['Case'].lower()

        if label not in ['obl', 'obl:arg', 'obl:agent']:
            return dep

        pieces = [label]
        if case_lemma:
            pieces.append(case_lemma)
            # TODO: check for other languages
            if case_val and language not in ['bg', 'en', 'lv', 'nl', 'sv']:
                pieces.append(case_val)
        # TODO: check it for other languages
        elif case_val and language not in ['pl', 'sv'] and label == token['deprel']:
            pieces.append(case_val)

        # print(label, sentence.metadata["sent_id"])
        return ":".join(pieces), head
    
    
    def fix_acl_deps(dep, acl_token, sentence, acl, lang):
        """
        Rewrite an enhanced edge labelled 'acl' (and, for Ukrainian, also
        'acl:relcl') so that the label carries the lemma of the subordinating
        'mark' dependent.
        Returns the (label, head) pair, or `dep` unchanged when no rule applies.
        """
        label, head = dep

        if not label.startswith(acl):
            return dep

        # Outside Ukrainian, relative clauses are left untouched.
        if label.startswith("acl:relcl") and lang not in ['uk']:
            return dep

        # First 'mark' dependent of the clause head (at most one is used).
        markers = []
        for candidate in sentence:
            if candidate["deprel"] == "mark" and candidate["head"] == acl_token["id"]:
                markers.append(candidate)
                break

        case_lemma = None
        if markers:
            first_marker = quicksort(markers)[0]
            # 'fixed' dependents extend the marker into a complex conjunction.
            extensions = [
                candidate
                for candidate in sentence
                if candidate["deprel"] == "fixed" and candidate["head"] == first_marker["id"]
            ]
            if extensions:
                case_lemma = "_".join(t["lemma"] for t in quicksort(markers + extensions))
            else:
                case_lemma = first_marker["lemma"]

        # Ukrainian also enriches acl:relcl; other languages only plain acl.
        allowed = ['acl', 'acl:relcl'] if lang in ['uk'] else ['acl']
        if label not in allowed:
            return dep

        pieces = [label]
        if case_lemma:
            pieces.append(case_lemma)

        # print(label, sentence.metadata["sent_id"])
        return ":".join(pieces), head
    
    def fix_advcl_deps(dep, advcl_token, sentence, advcl):
        """
        Rewrite an enhanced edge labelled 'advcl' so that the label carries the
        lemma(s) of the 'mark' (for 'bg'/'lt' also 'case') dependents of the
        clause head, e.g. advcl -> advcl:aby.

        :param dep: (label, head) pair taken from token["deps"]
        :param advcl_token: the token carrying the edge
        :param sentence: the sentence (iterable of tokens) containing the token
        :param advcl: label prefix to act on ("advcl")
        :return: the possibly re-labelled (label, head) pair, or `dep` unchanged
        """
        label, head = dep

        if not label.startswith(advcl):
            return dep

        case_lemma = None
        case_tokens = []
        # Unlike the nmod/obl rules, *all* marker dependents are collected
        # (no break after the first match).
        # TODO: check for other languages
        marker_deprels = ["mark", "case"] if language in ['bg', 'lt'] else ["mark"]
        for token in sentence:
            if token["deprel"] in marker_deprels and token["head"] == advcl_token["id"]:
                case_tokens.append(token)

        if case_tokens:
            fixed_tokens = []
            # TODO: check for other languages
            if language not in ['bg', 'nl']:
                # Sort once; the original re-sorted the marker list for every
                # token of the sentence (loop-invariant work hoisted).
                sorted_markers = quicksort(case_tokens)
                for token in sentence:
                    if token["deprel"] == "fixed":
                        for marker in sorted_markers:
                            if token["head"] == marker["id"]:
                                fixed_tokens.append(token)

            if fixed_tokens:
                case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)])
            else:
                case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens)])

        if label not in ['advcl']:
            return dep

        label_lst = [label]
        if case_lemma:
            label_lst.append(case_lemma)

        # print(label, sentence.metadata["sent_id"])
        return ":".join(label_lst), head
    
    
    def fix_conj_deps(dep, conj_token, sentence, conj):
        """
        Rewrite an enhanced edge labelled 'conj' so that the label carries the
        lemma of the coordinating conjunction ('cc'), e.g. conj -> conj:and.
        Returns the (label, head) pair, or `dep` unchanged when no rule applies.
        """
        label, head = dep

        if not label.startswith(conj):
            return dep

        # Every 'cc' dependent of the conjunct is collected (no break).
        conjunctions = [t for t in sentence
                        if t["deprel"] == "cc" and t["head"] == conj_token["id"]]

        case_lemma = None
        if conjunctions:
            extensions = []
            for candidate in sentence:
                for cc in quicksort(conjunctions):
                    if candidate["deprel"] == "fixed" and candidate["head"] == cc["id"]:
                        extensions.append(candidate)

            ordered = quicksort(conjunctions + extensions) if extensions else quicksort(conjunctions)
            case_lemma = "_".join(t["lemma"] for t in ordered)

        if label not in ['conj']:
            return dep

        pieces = [label]
        if case_lemma:
            pieces.append(case_lemma)

        # print(label, sentence.metadata["sent_id"])
        return ":".join(pieces), head
    
    
    
    def quicksort(tokens):
        """
        Return `tokens` sorted by their integer CoNLL-U id, ascending.

        Replaces a hand-rolled recursive quicksort that degraded to O(n^2)
        recursion depth on already-sorted input (the common case here);
        the builtin `sorted` is O(n log n), stable, and produces the same
        ordering for the token dicts used in this script.
        """
        return sorted(tokens, key=lambda token: int(token["id"]))
    
    
    # -------------------------------------------------------------------------
    # Main driver.  Usage: python postprocessing.py <language-symbol>
    # Applies the label-enrichment rules above to every enhanced dependency,
    # then the function-word clean-up rules, and writes the corrected
    # sentences to a new CoNLL-U file.
    # -------------------------------------------------------------------------
    language = sys.argv[1]
    # Counts edges whose label was changed by the enrichment rules
    # (a tally of applied corrections, despite the name); printed at the end.
    errors = 0
    
    input_file = f"./token_test/{language}_pred.fixed.conllu"
    output_file = f"./token_test/{language}.nofixed.conllu"
    # NOTE(review): files are opened with the platform default encoding;
    # CoNLL-U is UTF-8 — confirm whether encoding="utf-8" should be explicit.
    with open(input_file) as fh:
        with open(output_file, "w") as oh:
            for sentence in conllu.parse_incr(fh):
                for token in sentence:
                    # `deps` aliases the token's enhanced-dependency list.
                    # NOTE(review): the clean-up rules below rebind
                    # token["deps"] to new lists while later rules keep
                    # filtering this original alias — confirm that cascade
                    # is intentional.
                    deps = token["deps"]
                    if deps:
                        # Rule 1: enrich obl labels (skipped for French).
                        if language not in ['fr']:
                            for idx, dep in enumerate(deps):
                                assert len(dep) == 2, dep
                                new_dep = fix_obl_deps(dep, token, sentence, "obl")
                                token["deps"][idx] = new_dep
                                if new_dep[0] != dep[0]:
                                    errors += 1
                        # Rule 1: enrich nmod labels (skipped for French).
                        if language not in ['fr']:
                            for idx, dep in enumerate(deps):
                                assert len(dep) == 2, dep
                                new_dep = fix_nmod_deps(dep, token, sentence, "nmod")
                                token["deps"][idx] = new_dep
                                if new_dep[0] != dep[0]:
                                    errors += 1
                        # Rule 1: enrich acl labels.
                        # TODO: check for other languages
                        if language not in ['fr', 'lv']:
                            for idx, dep in enumerate(deps):
                                assert len(dep) == 2, dep
                                new_dep = fix_acl_deps(dep, token, sentence, "acl", language)
                                token["deps"][idx] = new_dep
                                if new_dep[0] != dep[0]:
                                    errors += 1
    
                        # Rule 1: enrich advcl labels.
                        # TODO: check for other languages
                        if language not in ['fr', 'lv']:
                            for idx, dep in enumerate(deps):
                                assert len(dep) == 2, dep
                                new_dep = fix_advcl_deps(dep, token, sentence, "advcl")
                                token["deps"][idx] = new_dep
                                if new_dep[0] != dep[0]:
                                    errors += 1
                        # Rule 1: append the cc lemma to conj labels
                        # (only for the listed languages).
                        # TODO: check for other languages
                        if language in ['en', 'it', 'nl', 'sv']:
                            for idx, dep in enumerate(deps):
                                assert len(dep) == 2, dep
                                new_dep = fix_conj_deps(dep, token, sentence, "conj")
                                token["deps"][idx] = new_dep
                                if new_dep[0] != dep[0]:
                                    errors += 1
                        # Estonian-only rule: strip the ':cop' sublabel from
                        # nsubj:cop / csubj:cop.
                        # TODO: check for other languages
                        if language in ['et']:
                            for idx, dep in enumerate(deps):
                                assert len(dep) == 2, dep
                                if token['deprel'] == 'nsubj:cop' and dep[0] == 'nsubj:cop':
                                    new_dep = ('nsubj', dep[1])
                                    token["deps"][idx] = new_dep
                                    if new_dep[0] != dep[0]:
                                        errors += 1
                                if token['deprel'] == 'csubj:cop' and dep[0] == 'csubj:cop':
                                    new_dep = ('csubj', dep[1])
                                    token["deps"][idx] = new_dep
                                    if new_dep[0] != dep[0]:
                                        errors += 1
                        # BELOW ARE THE RULES FOR CORRECTION OF THE FUNCTION WORDS
                        # labelled ref, mark, punct, root, case, det, cc, cop, aux
                        # They should not be assinged other functions
                        # If a 'ref' edge exists, keep only 'ref' edges.
                        #TODO: to check for other languages
                        if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ru']:
                            refs = [s for s in deps if s[0] == 'ref']
                            if refs:
                                token["deps"] = refs
                        # A basic-tree 'mark' keeps only its 'mark' edges.
                        #TODO: to check for other languages
                        if language in ['ar', 'bg', 'en', 'et', 'fi', 'it', 'lt', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']:
                            marks = [s for s in deps if s[0] == 'mark']
                            if marks and token['deprel'] == 'mark':
                                token["deps"] = marks
                        # A basic-tree 'punct' keeps only the 'punct' edge that
                        # mirrors its basic head.
                        #TODO: to check for other languages
                        if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']:
                            puncts = [s for s in deps if s[0] == 'punct' and s[1] == token['head']]
                            if puncts and token['deprel'] == 'punct':
                                token["deps"] = puncts
                        # A basic-tree 'root' keeps only its 'root' edges.
                        #TODO: to check for other languages
                        if language in ['ar', 'lt', 'pl']:
                            roots = [s for s in deps if s[0] == 'root']
                            if roots and token['deprel'] == 'root':
                                token["deps"] = roots
                        # A basic-tree 'case' keeps only its 'case' edges.
                        #TODO: to check for other languages
                        if language in ['en', 'ar', 'bg', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']:
                            cases = [s for s in deps if s[0] == 'case']
                            if cases and token['deprel'] == 'case':
                                token["deps"] = cases
                        # A basic-tree 'det' keeps only its 'det' edges.
                        #TODO: to check for other languages
                        if language in ['en', 'ar', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']:
                            dets = [s for s in deps if s[0] == 'det']
                            if dets and token['deprel'] == 'det':
                                token["deps"] = dets
                        # A basic-tree 'cc' keeps only its 'cc' edges.
                        #TODO: to check for other languages
                        if language in ['et', 'fi', 'it', 'lv', 'nl', 'pl', 'sk', 'sv', 'uk', 'fr', 'ar', 'ru', 'ta']:
                            ccs = [s for s in deps if s[0] == 'cc']
                            if ccs and token['deprel'] == 'cc':
                                token["deps"] = ccs
                        # A basic-tree 'cop' keeps only its 'cop' edges.
                        #TODO: to check for other languages
                        if language in ['bg', 'fi','et', 'it', 'sk', 'sv', 'uk', 'nl', 'fr', 'ru']:
                            cops = [s for s in deps if s[0] == 'cop']
                            if cops and token['deprel'] == 'cop':
                                token["deps"] = cops
                        # A basic-tree 'aux' keeps only its 'aux' edges.
                        #TODO: to check for other languages
                        if language in ['bg', 'et', 'fi', 'it', 'lv', 'pl', 'sv']:
                            auxs = [s for s in deps if s[0] == 'aux']
                            if auxs and token['deprel'] == 'aux':
                                token["deps"] = auxs
    
                        # A basic-tree 'conj' moves the 'conj' edge that mirrors
                        # its basic head to the front, keeping non-conj edges.
                        #TODO: to check for other languages
                        if language in ['ar', 'bg', 'cs', 'et', 'fi', 'fr', 'lt', 'lv', 'pl', 'sk', 'sv', 'uk', 'ru', 'ta']:
                            conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']]
                            other = [s for s in deps if s[0] != 'conj']
                            if conjs and token['deprel'] == 'conj':
                                token["deps"] = conjs+other
    
                        #TODO: to check for other languages
                        # EXTRA rule 1
                        # NOTE(review): tuple heads presumably denote empty
                        # (elided) nodes in the conllu representation; when any
                        # such edge exists, edges duplicating the basic head are
                        # dropped — confirm against the conllu library docs.
                        if language in ['cs', 'et', 'fi', 'lv', 'pl', 'uk']: #ar nl ru
                            # not use for: lt, bg, fr, sk, ta, sv, en
                            deprel = [s for s in deps if s[0] == token['deprel'] and s[1] == token['head']]
                            other_exp = [s for s in deps if type(s[1]) == tuple]
                            other_noexp = [s for s in deps if s[1] != token['head'] and type(s[1]) != tuple]
                            if other_exp:
                                token["deps"] = other_exp+other_noexp
    
                        # EXTRA rule 2: a lone conj edge pointing at the root
                        # token also inherits the root's own (deprel, head).
                        if language in ['cs', 'lt', 'pl', 'sk', 'uk']: #ar nl ru
                            conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']]
                            if conjs and len(deps) == 1 and len(conjs) == 1:
                                for t in sentence:
                                    if t['id'] == conjs[0][1] and t['deprel'] == 'root':
                                        conjs.append((t['deprel'], t['head']))
                                token["deps"] = conjs
    
                        # Tamil-only rule: a non-conj token with conj edges keeps
                        # only the edges that mirror its basic head.
                        if language in ['ta']:
                            if token['deprel'] != 'conj':
                                conjs = [s for s in deps if s[0] == 'conj']
                                if conjs:
                                    new_dep = [s for s in deps if s[1] == token['head']]
                                    token["deps"] = new_dep
    
                oh.write(sentence.serialize())
    print(errors)