Skip to content
Snippets Groups Projects
postprocessing.py 18.48 KiB
# TODO lemma remove punctuation - ukrainian
# TODO lemma remove punctuation - russian
# TODO consider handling multiple 'case'
import sys

import conllu

from re import *

# Matches the Russian lemma 'из-за' exactly; it is rewritten to 'изза' before
# being used as a sublabel -- presumably because hyphens are not allowed in
# enhanced-dependency label subtypes (TODO confirm).
rus = compile('^из-за$')
# Matches decimal ids such as "8.1" (the CoNLL-U id format of empty nodes).
# Raw string: '\d' and '\.' are invalid escapes in a regular string literal.
expand = compile(r'^\d+\.\d+$')

'''
A script correcting automatically predicted enhanced dependency graphs.
Running the script: python postprocessing.py cs

You have to modify the paths to the input CoNLL-U file and the output file.

The last argument (e.g. cs) corresponds to the language symbol.
All language symbols:
ar (Arabic), bg (Bulgarian), cs (Czech), nl (Dutch), en (English), et (Estonian), fi (Finnish),
fr (French), it (Italian), lv (Latvian), lt (Lithuanian), pl (Polish), ru (Russian),
sk (Slovak), sv (Swedish), ta (Tamil), uk (Ukrainian)

There are two main rules:
1) the first rule adds case information to the following labels: nmod, obl, acl, advcl.
The case information comes from the case/mark dependent of the current token and from its morphological feature Case.
Depending on the language, not all of this information is added.
In some languages ('en', 'it', 'nl', 'sv') the lemma of the coordinating conjunction (cc) is appended to the conjunct label (conj).
Functions: fix_nmod_deps, fix_obl_deps, fix_acl_deps, fix_advcl_deps and fix_conj_deps

2) the second rule corrects enhanced edges coming into function words labelled ref, mark, punct, root, case, det, cc, cop, aux.
Such tokens should not be assigned other functions. For example, if a token such as "and" is labelled cc (coordinating conjunction),
it cannot simultaneously be a subject (nsubj); if such a wrong enhanced edge exists, it should be removed from the graph.

There is one additional rule for Estonian: 
if the label is nsubj:cop or csubj:cop, the cop sublabel is removed and we have nsubj and csubj, respectively. 
'''


def fix_nmod_deps(dep, token, sentence, relation, lang=None):
    """
    This function modifies enhanced edges labelled 'nmod' (for Finnish also 'nmod:poss').

    The label is extended with the lemma of the (complex) preposition or
    subordinating conjunction attached to ``token`` as 'case'/'mark' (joined
    with its 'fixed' dependents by '_'), followed by the lower-cased value of
    the token's Case feature (not for bg, en, nl, sv), e.g. 'nmod' -> 'nmod:po:loc'.

    :param dep: (label, head) pair, one enhanced dependency of ``token``
    :param token: the token that owns ``dep``
    :param sentence: the sentence containing ``token``
    :param relation: label prefix handled by this function (called with "nmod")
    :param lang: language code; when None, the module-level ``language`` is used
    :return: the (possibly relabelled) (label, head) pair; ``dep`` itself when
        the label is not handled
    """
    label: str
    label, head = dep

    # All labels starting with 'relation' are checked
    if not label.startswith(relation):
        return dep
    if lang is None:
        lang = language

    # TODO: check for other languages
    handled = ['nmod', 'nmod:poss'] if lang in ['fi'] else ['nmod']
    if label not in handled:
        return dep

    # case_lemma is a (complex) preposition labelled 'case', e.g. 'po' in nmod:po:loc,
    # or a (complex) subordinating conjunction labelled 'mark'. Only the first such
    # dependent in sentence order is used (see the TODO about multiple 'case').
    case_token = next(
        (t for t in sentence
         if t["deprel"] in ["case", "mark"] and t["head"] == token["id"]),
        None,
    )
    case_lemma = None
    if case_token is not None:
        # 'fixed' dependents complete a multiword preposition/conjunction
        parts = [case_token] + [
            t for t in sentence
            if t["deprel"] == "fixed" and t["head"] == case_token["id"]
        ]
        parts.sort(key=lambda t: int(t["id"]))
        case_lemma = "_".join(rus.sub('изза', t["lemma"]) for t in parts)

    # case_val is a value of Case, e.g. 'gen' in nmod:gen and 'loc' in nmod:po:loc
    feats = token["feats"]
    case_val = feats["Case"].lower() if feats is not None and "Case" in feats else None

    label_lst = [label]
    if case_lemma:
        label_lst.append(case_lemma)
    # TODO: check for other languages
    if case_val and lang not in ['bg', 'en', 'nl', 'sv']:
        label_lst.append(case_val)
    return ":".join(label_lst), head


def fix_obl_deps(dep, token, sentence, relation, lang=None):
    """
    This function modifies enhanced edges labelled 'obl', 'obl:arg' and 'obl:agent'.

    When ``token`` has a (complex) preposition or subordinating conjunction
    attached as 'case'/'mark', its lemma (joined with its 'fixed' dependents by
    '_') is appended, followed by the lower-cased Case feature value (not for
    bg, en, lv, nl, sv), e.g. 'obl' -> 'obl:pod:loc'. Without such a dependent,
    only the Case value is appended, and only when the enhanced label equals the
    token's basic deprel (not for pl, sv).

    :param dep: (label, head) pair, one enhanced dependency of ``token``
    :param token: the token that owns ``dep``
    :param sentence: the sentence containing ``token``
    :param relation: label prefix handled by this function (called with "obl")
    :param lang: language code; when None, the module-level ``language`` is used
    :return: the (possibly relabelled) (label, head) pair; ``dep`` itself when
        the label is not handled
    """
    label: str
    label, head = dep

    if not label.startswith(relation):
        return dep
    if lang is None:
        lang = language
    if label not in ['obl', 'obl:arg', 'obl:agent']:
        return dep

    # case_lemma is a (complex) preposition labelled 'case', e.g. 'pod' in obl:pod:loc,
    # or a (complex) subordinating conjunction labelled 'mark'. Only the first such
    # dependent in sentence order is used (see the TODO about multiple 'case').
    case_token = next(
        (t for t in sentence
         if t["deprel"] in ["case", "mark"] and t["head"] == token["id"]),
        None,
    )
    case_lemma = None
    if case_token is not None:
        # fixed tokens complete a complex preposition, e.g. 'przypadek' in obl:w_przypadku:gen
        parts = [case_token] + [
            t for t in sentence
            if t["deprel"] == "fixed" and t["head"] == case_token["id"]
        ]
        parts.sort(key=lambda t: int(t["id"]))
        case_lemma = "_".join(rus.sub('изза', t["lemma"]) for t in parts)

    # case_val is a value of Case feature, e.g. 'loc' in obl:pod:loc
    feats = token["feats"]
    case_val = feats["Case"].lower() if feats is not None and "Case" in feats else None

    label_lst = [label]
    if case_lemma:
        label_lst.append(case_lemma)
        # TODO: check for other languages
        if case_val and lang not in ['bg', 'en', 'lv', 'nl', 'sv']:
            label_lst.append(case_val)
    # TODO: check it for other languages
    elif case_val and lang not in ['pl', 'sv'] and label == token['deprel']:
        label_lst.append(case_val)
    return ":".join(label_lst), head


def fix_acl_deps(dep, acl_token, sentence, acl, lang):
    """
    This function modifies enhanced edges labelled 'acl' (for Ukrainian also 'acl:relcl').

    The lemma of the first subordinating conjunction ('mark') attached to
    ``acl_token`` is appended, joined by '_' with the lemmas of its 'fixed'
    dependents in sentence order, e.g. 'acl' -> 'acl:so_that'.

    :param dep: (label, head) pair, one enhanced dependency of ``acl_token``
    :param acl_token: the token that owns ``dep``
    :param sentence: the sentence containing ``acl_token``
    :param acl: label prefix handled by this function (called with "acl")
    :param lang: language code
    :return: the (possibly relabelled) (label, head) pair; ``dep`` itself when
        the label is not handled
    """
    label: str
    label, head = dep

    if not label.startswith(acl):
        return dep
    # 'acl:relcl' is extended only for Ukrainian; other acl subtypes never are
    handled = ['acl', 'acl:relcl'] if lang in ['uk'] else ['acl']
    if label not in handled:
        return dep

    # Only the first 'mark' dependent in sentence order is used
    mark_token = next(
        (t for t in sentence
         if t["deprel"] == "mark" and t["head"] == acl_token["id"]),
        None,
    )
    label_lst = [label]
    if mark_token is not None:
        # 'fixed' dependents complete a multiword conjunction
        parts = [mark_token] + [
            t for t in sentence
            if t["deprel"] == "fixed" and t["head"] == mark_token["id"]
        ]
        parts.sort(key=lambda t: int(t["id"]))
        case_lemma = "_".join(t["lemma"] for t in parts)
        if case_lemma:
            label_lst.append(case_lemma)
    return ":".join(label_lst), head

def fix_advcl_deps(dep, advcl_token, sentence, advcl, lang=None):
    """
    This function modifies enhanced edges labelled 'advcl'.

    The lemmas of all subordinating conjunctions ('mark'; for Bulgarian and
    Lithuanian also 'case') attached to ``advcl_token`` are appended, together
    with the lemmas of their 'fixed' dependents (not for Bulgarian and Dutch),
    joined by '_' in sentence order, e.g. 'advcl' -> 'advcl:even_though'.

    :param dep: (label, head) pair, one enhanced dependency of ``advcl_token``
    :param advcl_token: the token that owns ``dep``
    :param sentence: the sentence containing ``advcl_token``
    :param advcl: label prefix handled by this function (called with "advcl")
    :param lang: language code; when None, the module-level ``language`` is used
    :return: the (possibly relabelled) (label, head) pair; ``dep`` itself when
        the label is not handled
    """
    label: str
    label, head = dep

    # Only the plain 'advcl' label is extended
    if not label.startswith(advcl) or label != 'advcl':
        return dep
    if lang is None:
        lang = language

    # TODO: check for other languages
    marker_rels = ["mark", "case"] if lang in ['bg', 'lt'] else ["mark"]
    case_tokens = [t for t in sentence
                   if t["deprel"] in marker_rels and t["head"] == advcl_token["id"]]

    fixed_tokens = []
    # TODO: check for other languages
    if case_tokens and lang not in ['bg', 'nl']:
        case_ids = {t["id"] for t in case_tokens}
        fixed_tokens = [t for t in sentence
                        if t["deprel"] == "fixed" and t["head"] in case_ids]

    label_lst = [label]
    if case_tokens:
        parts = sorted(case_tokens + fixed_tokens, key=lambda t: int(t["id"]))
        case_lemma = "_".join(t["lemma"] for t in parts)
        if case_lemma:
            label_lst.append(case_lemma)
    return ":".join(label_lst), head


def fix_conj_deps(dep, conj_token, sentence, conj):
    """
    This function modifies enhanced edges labelled 'conj', which should be
    assigned the lemma of the coordinating conjunction (cc) as a sublabel.

    The lemmas of all 'cc' dependents of ``conj_token`` and of their 'fixed'
    dependents are joined by '_' in sentence order, e.g. 'conj' -> 'conj:and'.

    :param dep: (label, head) pair, one enhanced dependency of ``conj_token``
    :param conj_token: the token that owns ``dep``
    :param sentence: the sentence containing ``conj_token``
    :param conj: label prefix handled by this function (called with "conj")
    :return: the (possibly relabelled) (label, head) pair; ``dep`` itself when
        the label is not handled
    """
    label: str
    label, head = dep

    # Only the plain 'conj' label is extended
    if not label.startswith(conj) or label != 'conj':
        return dep

    cc_tokens = [t for t in sentence
                 if t["deprel"] == "cc" and t["head"] == conj_token["id"]]
    label_lst = [label]
    if cc_tokens:
        cc_ids = {t["id"] for t in cc_tokens}
        fixed_tokens = [t for t in sentence
                        if t["deprel"] == "fixed" and t["head"] in cc_ids]
        parts = sorted(cc_tokens + fixed_tokens, key=lambda t: int(t["id"]))
        cc_lemma = "_".join(t["lemma"] for t in parts)
        if cc_lemma:
            label_lst.append(cc_lemma)
    return ":".join(label_lst), head



def quicksort(tokens):
    """
    Return ``tokens`` sorted by their integer ``id`` (ascending, stable).

    Kept under its historical name; it delegates to the built-in ``sorted``.
    The former hand-written quicksort used the first element as pivot, which is
    quadratic and recurses once per element on input that is already in
    sentence order -- the usual case for tokens collected from a sentence.

    :param tokens: list of tokens whose "id" is convertible by ``int``
    :return: a new list in ascending id order
    """
    return sorted(tokens, key=lambda t: int(t["id"]))


def _apply_fixer(deps, fixer, *args):
    """
    Relabel every enhanced dependency in ``deps`` in place using ``fixer``.

    :param deps: a token's list of (label, head) enhanced dependencies
    :param fixer: callable ``fixer(dep, *args) -> (label, head)``
    :param args: extra positional arguments passed to ``fixer`` after ``dep``
    :return: how many dependencies got a different label
    """
    changed = 0
    for idx, dep in enumerate(deps):
        assert len(dep) == 2, dep
        new_dep = fixer(dep, *args)
        deps[idx] = new_dep
        if new_dep[0] != dep[0]:
            changed += 1
    return changed


def _fix_et_cop(dep, token):
    """
    Estonian rule: turn nsubj:cop / csubj:cop into nsubj / csubj.

    Applied only when the token's basic deprel carries the same label.
    """
    label, head = dep
    if label in ('nsubj:cop', 'csubj:cop') and token['deprel'] == label:
        return label[:-len(':cop')], head
    return dep


if len(sys.argv) < 2:
    sys.exit("usage: python postprocessing.py LANGUAGE_CODE")
language = sys.argv[1]
errors = 0  # number of enhanced labels changed by the relabelling rules

input_file = f"./token_test/{language}_pred.fixed.conllu"
output_file = f"./token_test/{language}.nofixed.conllu"
# CoNLL-U is UTF-8; do not depend on the platform's default encoding.
with open(input_file, encoding="utf-8") as fh, \
        open(output_file, "w", encoding="utf-8") as oh:
    for sentence in conllu.parse_incr(fh):
        for token in sentence:
            deps = token["deps"]
            if not deps:
                continue

            # RULE 1: add case / mark / cc lemmas and Case values to labels.
            # The fixers update ``deps`` (the same list as token["deps"]) in place.
            if language not in ['fr']:
                errors += _apply_fixer(deps, fix_obl_deps, token, sentence, "obl")
                errors += _apply_fixer(deps, fix_nmod_deps, token, sentence, "nmod")
            # TODO: check for other languages
            if language not in ['fr', 'lv']:
                errors += _apply_fixer(deps, fix_acl_deps, token, sentence, "acl", language)
                errors += _apply_fixer(deps, fix_advcl_deps, token, sentence, "advcl")
            # TODO: check for other languages
            if language in ['en', 'it', 'nl', 'sv']:
                errors += _apply_fixer(deps, fix_conj_deps, token, sentence, "conj")
            # TODO: check for other languages
            if language in ['et']:
                errors += _apply_fixer(deps, _fix_et_cop, token)

            # RULE 2: correction of the function words
            # labelled ref, mark, punct, root, case, det, cc, cop, aux.
            # They should not be assigned other functions.
            # NOTE(review): every rule below filters ``deps`` as it was after
            # RULE 1, not the current token["deps"]; when several rules fire,
            # the last one wins. Kept as-is because the per-language lists were
            # tuned against this behaviour -- confirm before changing.
            # TODO: to check for other languages
            if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ru']:
                refs = [s for s in deps if s[0] == 'ref']
                if refs:
                    token["deps"] = refs
            # TODO: to check for other languages
            if language in ['ar', 'bg', 'en', 'et', 'fi', 'it', 'lt', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']:
                marks = [s for s in deps if s[0] == 'mark']
                if marks and token['deprel'] == 'mark':
                    token["deps"] = marks
            # TODO: to check for other languages
            if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']:
                puncts = [s for s in deps if s[0] == 'punct' and s[1] == token['head']]
                if puncts and token['deprel'] == 'punct':
                    token["deps"] = puncts
            # TODO: to check for other languages
            if language in ['ar', 'lt', 'pl']:
                roots = [s for s in deps if s[0] == 'root']
                if roots and token['deprel'] == 'root':
                    token["deps"] = roots
            # TODO: to check for other languages
            if language in ['en', 'ar', 'bg', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']:
                cases = [s for s in deps if s[0] == 'case']
                if cases and token['deprel'] == 'case':
                    token["deps"] = cases
            # TODO: to check for other languages
            if language in ['en', 'ar', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']:
                dets = [s for s in deps if s[0] == 'det']
                if dets and token['deprel'] == 'det':
                    token["deps"] = dets
            # TODO: to check for other languages
            if language in ['et', 'fi', 'it', 'lv', 'nl', 'pl', 'sk', 'sv', 'uk', 'fr', 'ar', 'ru', 'ta']:
                ccs = [s for s in deps if s[0] == 'cc']
                if ccs and token['deprel'] == 'cc':
                    token["deps"] = ccs
            # TODO: to check for other languages
            if language in ['bg', 'fi', 'et', 'it', 'sk', 'sv', 'uk', 'nl', 'fr', 'ru']:
                cops = [s for s in deps if s[0] == 'cop']
                if cops and token['deprel'] == 'cop':
                    token["deps"] = cops
            # TODO: to check for other languages
            if language in ['bg', 'et', 'fi', 'it', 'lv', 'pl', 'sv']:
                auxs = [s for s in deps if s[0] == 'aux']
                if auxs and token['deprel'] == 'aux':
                    token["deps"] = auxs

            # TODO: to check for other languages
            if language in ['ar', 'bg', 'cs', 'et', 'fi', 'fr', 'lt', 'lv', 'pl', 'sk', 'sv', 'uk', 'ru', 'ta']:
                conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']]
                other = [s for s in deps if s[0] != 'conj']
                if conjs and token['deprel'] == 'conj':
                    token["deps"] = conjs + other

            # TODO: to check for other languages
            # EXTRA rule 1: if the token has edges whose head is a tuple id
            # (presumably an empty node), keep those plus the edges that do not
            # point to the basic head.
            if language in ['cs', 'et', 'fi', 'lv', 'pl', 'uk']:  # ar nl ru
                # not use for: lt, bg, fr, sk, ta, sv, en
                other_exp = [s for s in deps if isinstance(s[1], tuple)]
                other_noexp = [s for s in deps
                               if s[1] != token['head'] and not isinstance(s[1], tuple)]
                if other_exp:
                    token["deps"] = other_exp + other_noexp

            # EXTRA rule 2: a single conj edge to the basic head also gets the
            # head's root edge when that head is the sentence root.
            if language in ['cs', 'lt', 'pl', 'sk', 'uk']:  # ar nl ru
                conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']]
                if conjs and len(deps) == 1 and len(conjs) == 1:
                    for t in sentence:
                        if t['id'] == conjs[0][1] and t['deprel'] == 'root':
                            conjs.append((t['deprel'], t['head']))
                    token["deps"] = conjs

            # Tamil: a non-conj token with an enhanced conj edge keeps only the
            # edges pointing to its basic head.
            if language in ['ta']:
                if token['deprel'] != 'conj':
                    conjs = [s for s in deps if s[0] == 'conj']
                    if conjs:
                        token["deps"] = [s for s in deps if s[1] == token['head']]

        oh.write(sentence.serialize())
print(errors)