From 7ddc4b0ee23a42601b27c97435cf51c5ac5068d0 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 5 Aug 2021 10:43:50 +0200
Subject: [PATCH] Add postprocessing EUD script.

---
 scripts/postprocessing.py | 454 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 454 insertions(+)
 create mode 100644 scripts/postprocessing.py

diff --git a/scripts/postprocessing.py b/scripts/postprocessing.py
new file mode 100644
index 0000000..2f4da16
--- /dev/null
+++ b/scripts/postprocessing.py
@@ -0,0 +1,454 @@
+# TODO lemma remove punctuation - ukrainian
+# TODO lemma remove punctuation - russian
+# TODO consider handling multiple 'case'
+import sys
+
+import conllu
+
+from re import *
+
+rus = compile(u'^из-за$')
+expand = compile('^\d+\.\d+$')
+
+'''
+A script correcting automatically predicted enhanced dependency graphs.
+Running the script: python postprocessing.py cs
+
+You have to modified the paths to the input CoNLL-U file and the output file.
+
+The last argument (e.g. cs) corresponds to the language symbol.
+All language symbols:
+ar (Arabic), bg (Bulgarian), cs (Czech), nl (Dutch), en (English), et (Estonian), fi (Finnish)
+fr (French), it (Italian), lv (Latvian), lt (Lithuanian), pl (Polish), ru (Russian)
+sk (Slovak), sv (Swedish), ta (Tamil), uk (Ukrainian)
+
+There are two main rules:
+1) the first one add case information to the following labels: nmod, obl, acl, advcl. 
+The case information comes from case/mark dependent of the current token and from the morphological feature Case.
+Depending on the language, not all information is added.
+In some languages ('en', 'it', 'nl', 'sv') the lemma of coordinating conjunction (cc) is appendend to the conjunct label (conj). 
+Functions: fix_mod_deps, fix_obj_deps, fix_acl_deps, fix_advcl_deps and fix_conj_deps
+
+2) the second rule correct enhanced edges comming into function words labelled ref, mark, punct, root, case, det, cc, cop, aux
+They should not be assinged other functions. For example, if a token, e.g. "and" is labelled cc (coordinating conjunction), 
+it cannot be simultaneously a subject (nsubj) and if this wrong enhanced edge exists, it should be removed from the graph.
+
+There is one additional rule for Estonian: 
+if the label is nsubj:cop or csubj:cop, the cop sublabel is removed and we have nsubj and csubj, respectively. 
+'''
+
+
+def fix_nmod_deps(dep, token, sentence, relation):
+    """
+    This function modifies enhanced edges labelled 'nmod'
+    """
+    label: str
+    label, head = dep
+
+    # All labels starting with 'relation' are checked
+    if not label.startswith(relation):
+        return dep
+
+    # case_lemma is a (complex) preposition labelled 'case' e.g. 'po' in nmod:po:loc
+    # or a (complex) subordinating conjunction labelled 'mark'
+    case_lemma = None
+    case_tokens = []
+    for t in sentence:
+        if t["deprel"] in ["case", "mark"] and t["head"] == token["id"]:
+            case_tokens.append(t)
+            break
+
+    if case_tokens:
+        fixed_tokens = []
+        for t in sentence:
+            for c in case_tokens:
+                if t["deprel"] == "fixed" and t["head"] == c["id"]:
+                    fixed_tokens.append(t)
+
+        if fixed_tokens:
+            case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens + fixed_tokens))
+        else:
+            case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens))
+
+    # case_val is a value of Case, e.g. 'gen' in nmod:gen and 'loc' in nmod:po:loc
+    case_val = None
+    if token['feats'] is not None:
+        if 'Case' in token["feats"]:
+            case_val = token["feats"]['Case'].lower()
+
+    #TODO: check for other languages
+    if language in ['fi'] and label not in ['nmod', 'nmod:poss']:
+        return dep
+    elif language not in ['fi'] and label not in ['nmod']:
+        return dep
+    else:
+        label_lst = [label]
+        if case_lemma:
+            label_lst.append(case_lemma)
+        if case_val:
+            #TODO: check for other languages
+            if language not in ['bg', 'en', 'nl', 'sv']:
+                label_lst.append(case_val)
+        label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+
+def fix_obl_deps(dep, token, sentence, relation):
+    """
+    This function modifies enhanced edges labelled 'obl', 'obl:arg', 'obl:rel'
+    """
+    label: str
+    label, head = dep
+
+    if not label.startswith(relation):
+        return dep
+
+    # case_lemma is a (complex) preposition labelled 'case' e.g. 'pod' in obl:pod:loc
+    # or a (complex) subordinating conjunction labelled 'mark'
+    case_lemma = None
+    case_tokens = []
+    for t in sentence:
+        if t["deprel"] in ["case", "mark"] and t["head"] == token["id"]:
+            case_tokens.append(t)
+            break
+
+    if case_tokens:
+        # fixed_token is the lemma of a complex preposition, e.g. 'przypadek' in obl:w_przypadku:gen
+        fixed_tokens = []
+        for t in sentence:
+            for c in case_tokens:
+                if t["deprel"] == "fixed" and t["head"] == c["id"]:
+                    fixed_tokens.append(t)
+
+        if fixed_tokens:
+            case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens + fixed_tokens))
+        else:
+            case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens))
+
+    # case_val is a value of Case feature, e.g. 'loc' in obl:pod:loc
+    case_val = None
+    if token['feats'] is not None:
+        if 'Case' in token["feats"]:
+            case_val = token["feats"]['Case'].lower()
+
+    if label not in ['obl', 'obl:arg', 'obl:agent']:
+        return dep
+    else:
+        label_lst = [label]
+        if case_lemma:
+            label_lst.append(case_lemma)
+            if case_val:
+                # TODO: check for other languages
+                if language not in ['bg', 'en', 'lv', 'nl', 'sv']:
+                    label_lst.append(case_val)
+        # TODO: check it for other languages
+        if language not in ['pl', 'sv']:
+            if case_val and not case_lemma:
+                if label == token['deprel']:
+                    label_lst.append(case_val)
+        label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+
+def fix_acl_deps(dep, acl_token, sentence, acl, lang):
+    """
+    This function modifies enhanced edges labelled 'acl'
+    """
+    label: str
+    label, head = dep
+
+    if not label.startswith(acl):
+        return dep
+
+    if label.startswith("acl:relcl"):
+        if lang not in ['uk']:
+            return dep
+
+    case_lemma = None
+    case_tokens = []
+    for token in sentence:
+        if token["deprel"] == "mark" and token["head"] == acl_token["id"]:
+            case_tokens.append(token)
+            break
+
+    if case_tokens:
+        fixed_tokens = []
+        for token in sentence:
+            if token["deprel"] == "fixed" and token["head"] == quicksort(case_tokens)[0]["id"]:
+                fixed_tokens.append(token)
+
+        if fixed_tokens:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)])
+        else:
+            case_lemma = quicksort(case_tokens)[0]["lemma"]
+
+    if lang in ['uk']:
+        if label not in ['acl', 'acl:relcl']:
+            return dep
+        else:
+            label_lst = [label]
+            if case_lemma:
+                label_lst.append(case_lemma)
+            label = ":".join(label_lst)
+    else:
+        if label not in ['acl']:
+            return dep
+        else:
+            label_lst = [label]
+            if case_lemma:
+                label_lst.append(case_lemma)
+            label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+def fix_advcl_deps(dep, advcl_token, sentence, advcl):
+    """
+    This function modifies enhanced edges labelled 'advcl'
+    """
+    label: str
+    label, head = dep
+
+    if not label.startswith(advcl):
+        return dep
+
+    case_lemma = None
+    case_tokens = []
+    # TODO: check for other languages
+    if language in ['bg', 'lt']:
+        for token in sentence:
+            if token["deprel"] in ["mark", "case"] and token["head"] == advcl_token["id"]:
+                case_tokens.append(token)
+    else:
+        for token in sentence:
+            if token["deprel"] == "mark" and token["head"] == advcl_token["id"]:
+                case_tokens.append(token)
+
+    if case_tokens:
+        fixed_tokens = []
+        # TODO: check for other languages
+        if language not in ['bg', 'nl']:
+            for token in sentence:
+                for case in quicksort(case_tokens):
+                    if token["deprel"] == "fixed" and token["head"] == case["id"]:
+                        fixed_tokens.append(token)
+
+        if fixed_tokens:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)])
+        else:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens)])
+
+    if label not in ['advcl']:
+        return dep
+    else:
+        label_lst = [label]
+        if case_lemma:
+            label_lst.append(case_lemma)
+        label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+
+def fix_conj_deps(dep, conj_token, sentence, conj):
+    """
+    This function modifies enhanced edges labelled 'conj' which should be assined the lemma of cc as sublabel
+    """
+    label: str
+    label, head = dep
+
+    if not label.startswith(conj):
+        return dep
+
+    case_lemma = None
+    case_tokens = []
+    for token in sentence:
+        if token["deprel"] == "cc" and token["head"] == conj_token["id"]:
+            case_tokens.append(token)
+
+    if case_tokens:
+        fixed_tokens = []
+        for token in sentence:
+            for case in quicksort(case_tokens):
+                if token["deprel"] == "fixed" and token["head"] == case["id"]:
+                    fixed_tokens.append(token)
+
+        if fixed_tokens:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)])
+        else:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens)])
+
+    if label not in ['conj']:
+        return dep
+    else:
+        label_lst = [label]
+        if case_lemma:
+            label_lst.append(case_lemma)
+        label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+
+
+def quicksort(tokens):
+    if len(tokens) <= 1:
+        return tokens
+    else:
+        return quicksort([x for x in tokens[1:] if int(x["id"]) < int(tokens[0]["id"])]) \
+               + [tokens[0]] \
+               + quicksort([y for y in tokens[1:] if int(y["id"]) >= int(tokens[0]["id"])])
+
+
+language = sys.argv[1]
+errors = 0
+
+input_file = f"./token_test/{language}_pred.fixed.conllu"
+output_file = f"./token_test/{language}.nofixed.conllu"
+with open(input_file) as fh:
+    with open(output_file, "w") as oh:
+        for sentence in conllu.parse_incr(fh):
+            for token in sentence:
+                deps = token["deps"]
+                if deps:
+                    if language not in ['fr']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_obl_deps(dep, token, sentence, "obl")
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+                    if language not in ['fr']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_nmod_deps(dep, token, sentence, "nmod")
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+                    # TODO: check for other languages
+                    if language not in ['fr', 'lv']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_acl_deps(dep, token, sentence, "acl", language)
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+
+                    # TODO: check for other languages
+                    if language not in ['fr', 'lv']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_advcl_deps(dep, token, sentence, "advcl")
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+                    # TODO: check for other languages
+                    if language in ['en', 'it', 'nl', 'sv']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_conj_deps(dep, token, sentence, "conj")
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+                    # TODO: check for other languages
+                    if language in ['et']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            if token['deprel'] == 'nsubj:cop' and dep[0] == 'nsubj:cop':
+                                new_dep = ('nsubj', dep[1])
+                                token["deps"][idx] = new_dep
+                                if new_dep[0] != dep[0]:
+                                    errors += 1
+                            if token['deprel'] == 'csubj:cop' and dep[0] == 'csubj:cop':
+                                new_dep = ('csubj', dep[1])
+                                token["deps"][idx] = new_dep
+                                if new_dep[0] != dep[0]:
+                                    errors += 1
+                    # BELOW ARE THE RULES FOR CORRECTION OF THE FUNCTION WORDS
+                    # labelled ref, mark, punct, root, case, det, cc, cop, aux
+                    # They should not be assinged other functions
+                    #TODO: to check for other languages
+                    if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ru']:
+                        refs = [s for s in deps if s[0] == 'ref']
+                        if refs:
+                            token["deps"] = refs
+                    #TODO: to check for other languages
+                    if language in ['ar', 'bg', 'en', 'et', 'fi', 'it', 'lt', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']:
+                        marks = [s for s in deps if s[0] == 'mark']
+                        if marks and token['deprel'] == 'mark':
+                            token["deps"] = marks
+                    #TODO: to check for other languages
+                    if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']:
+                        puncts = [s for s in deps if s[0] == 'punct' and s[1] == token['head']]
+                        if puncts and token['deprel'] == 'punct':
+                            token["deps"] = puncts
+                    #TODO: to check for other languages
+                    if language in ['ar', 'lt', 'pl']:
+                        roots = [s for s in deps if s[0] == 'root']
+                        if roots and token['deprel'] == 'root':
+                            token["deps"] = roots
+                    #TODO: to check for other languages
+                    if language in ['en', 'ar', 'bg', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']:
+                        cases = [s for s in deps if s[0] == 'case']
+                        if cases and token['deprel'] == 'case':
+                            token["deps"] = cases
+                    #TODO: to check for other languages
+                    if language in ['en', 'ar', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']:
+                        dets = [s for s in deps if s[0] == 'det']
+                        if dets and token['deprel'] == 'det':
+                            token["deps"] = dets
+                    #TODO: to check for other languages
+                    if language in ['et', 'fi', 'it', 'lv', 'nl', 'pl', 'sk', 'sv', 'uk', 'fr', 'ar', 'ru', 'ta']:
+                        ccs = [s for s in deps if s[0] == 'cc']
+                        if ccs and token['deprel'] == 'cc':
+                            token["deps"] = ccs
+                    #TODO: to check for other languages
+                    if language in ['bg', 'fi','et', 'it', 'sk', 'sv', 'uk', 'nl', 'fr', 'ru']:
+                        cops = [s for s in deps if s[0] == 'cop']
+                        if cops and token['deprel'] == 'cop':
+                            token["deps"] = cops
+                    #TODO: to check for other languages
+                    if language in ['bg', 'et', 'fi', 'it', 'lv', 'pl', 'sv']:
+                        auxs = [s for s in deps if s[0] == 'aux']
+                        if auxs and token['deprel'] == 'aux':
+                            token["deps"] = auxs
+
+                    #TODO: to check for other languages
+                    if language in ['ar', 'bg', 'cs', 'et', 'fi', 'fr', 'lt', 'lv', 'pl', 'sk', 'sv', 'uk', 'ru', 'ta']:
+                        conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']]
+                        other = [s for s in deps if s[0] != 'conj']
+                        if conjs and token['deprel'] == 'conj':
+                            token["deps"] = conjs+other
+
+                    #TODO: to check for other languages
+                    # EXTRA rule 1
+                    if language in ['cs', 'et', 'fi', 'lv', 'pl', 'uk']: #ar nl ru
+                        # not use for: lt, bg, fr, sk, ta, sv, en
+                        deprel = [s for s in deps if s[0] == token['deprel'] and s[1] == token['head']]
+                        other_exp = [s for s in deps if type(s[1]) == tuple]
+                        other_noexp = [s for s in deps if s[1] != token['head'] and type(s[1]) != tuple]
+                        if other_exp:
+                            token["deps"] = other_exp+other_noexp
+
+                    # EXTRA rule 2
+                    if language in ['cs', 'lt', 'pl', 'sk', 'uk']: #ar nl ru
+                        conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']]
+                        if conjs and len(deps) == 1 and len(conjs) == 1:
+                            for t in sentence:
+                                if t['id'] == conjs[0][1] and t['deprel'] == 'root':
+                                    conjs.append((t['deprel'], t['head']))
+                            token["deps"] = conjs
+
+                    if language in ['ta']:
+                        if token['deprel'] != 'conj':
+                            conjs = [s for s in deps if s[0] == 'conj']
+                            if conjs:
+                                new_dep = [s for s in deps if s[1] == token['head']]
+                                token["deps"] = new_dep
+
+            oh.write(sentence.serialize())
+print(errors)
-- 
GitLab