Commit 34e11dde authored by pwalkow's avatar pwalkow

Add orths
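
This commit threads token surface forms ("orths") through the pipeline: the tagger now returns them alongside lemmas and tags, the preprocessing step writes them into an `orths` column in the JSONL files, and the attack loop passes them into `TextFooler.spoil`. Along the way the dead multiprocessing scaffolding is dropped from the attack script and the `TextFooler` instance is created once instead of per sentence.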

parent 9ee74cfd
@@ -3,7 +3,6 @@ import click
 import pandas as pd
 import os
 from tqdm import tqdm
-from multiprocessing import cpu_count, Pool
 from text_attacks.utils import get_classify_function
 from textfooler import Attack, TextFooler, BaseLine, process
@@ -11,6 +10,7 @@ from textfooler import Attack, TextFooler, BaseLine, process
 TEXT = "text"
 LEMMAS = "lemmas"
 TAGS = "tags"
+ORTHS = "orths"
 ATTACK_SUMMARY = "attacks_summary"
 ATTACK_SUCCEEDED = "attacks_succeeded"
@@ -35,12 +35,6 @@ DEFAULT_RES = {
 }
 
-
-def spoil_sentence(sentence, lemmas, tags, lang, similarity, max_sub):
-    attack = TextFooler(lang)
-    # attack = BaseLine(lang, 0.5, 0.4, 0.3)
-    return attack.spoil(sentence, [], lemmas, tags, similarity, max_sub)
-
 
 @click.command()
 @click.option(
     "--dataset_name",
@@ -61,62 +55,17 @@ def main(dataset_name: str):
     output_path = os.path.join(output_dir, "test.jsonl")
     classify = get_classify_function(dataset_name=dataset_name)
     dataset_df = pd.read_json(input_file, lines=True)
-    # dataset_df = dataset_df[:10]
 
     spoiled, results = [], []
     similarity, max_sub = 0.95, 1
-    cpus = cpu_count()
     classes = classify(dataset_df[TEXT].tolist())
-    # used_id = 0
-    # sent_nbr = len(dataset_df[TEXT])
-    # with Pool(processes=cpus) as pool:
-    #     for idx in range(0, min(cpus, sent_nbr)):
-    #         sentence, lemmas, tags = dataset_df[TEXT][idx], \
-    #                                  dataset_df[LEMMAS][idx], \
-    #                                  dataset_df[TAGS][idx]
-    #
     lang = "en" if dataset_name == "enron_spam" else "pl"
-    #         results.append(pool.apply_async(spoil_sentence, args=[sentence,
-    #                                                               lemmas,
-    #                                                               tags,
-    #                                                               lang,
-    #                                                               similarity,
-    #                                                               max_sub]))
-    #         used_id = idx
-    #     count = len(results)
-    #     while count and used_id < sent_nbr:
-    #         ready = 0
-    #         to_rm = []
-    #         for r in results:
-    #             if r.ready():
-    #                 ready += 1
-    #                 changed_sent = r.get()
-    #                 if changed_sent:
-    #                     spoiled.append(process(changed_sent, classes[i], classify))
-    #                 to_rm.append(r)
-    #         count = len(results) - ready
-    #         results = [res for res in results if res not in to_rm]
-    #         h_bound = min(used_id + cpus - len(results), sent_nbr)
-    #         for i in range(used_id + 1, h_bound):
-    #             used_id += 1
-    #             sentence, lemmas, tags = dataset_df[TEXT][idx], \
-    #                                      dataset_df[LEMMAS][idx], \
-    #                                      dataset_df[TAGS][idx]
-    #
-    #             results.append(pool.apply_async(spoil_sentence, args=[sentence,
-    #                                                                   lemmas,
-    #                                                                   tags,
-    #                                                                   lang,
-    #                                                                   similarity,
-    #                                                                   max_sub]))
+    attack = TextFooler(lang)
     for i, cols in tqdm(
-        dataset_df[[TEXT, LEMMAS, TAGS]].iterrows(), total=len(dataset_df)
+        dataset_df[[TEXT, LEMMAS, TAGS, ORTHS]].iterrows(), total=len(dataset_df)
     ):
-        sentence, lemmas, tags = cols[0], cols[1], cols[2]
-        changed_sent = spoil_sentence(
-            sentence, lemmas, tags, lang, similarity, max_sub
-        )
+        sentence, lemmas, tags, orths = cols[0], cols[1], cols[2], cols[3]
+        changed_sent = attack.spoil(sentence, [], lemmas, tags, orths, similarity, max_sub)
         if changed_sent:
             spoiled.append(process(changed_sent, classes[i], classify))
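
Reviewer note: with the `spoil_sentence` helper gone, `TextFooler` is constructed once and `spoil` is called directly with the new orths argument. A minimal sketch of the new call path, assuming the project's `textfooler` package is importable; the sentence and token lists are invented for illustration, and the empty list mirrors the second positional argument in the diff:

```python
from textfooler import TextFooler

# Build the attack once and reuse it across rows, as the new loop does.
attack = TextFooler("en")

sentence = "The prices were low"               # invented example
lemmas = ["the", "price", "be", "low"]         # invented lemmas
tags = ["det", "noun", "verb", "adj"]          # invented tagset
orths = ["The", "prices", "were", "low"]       # surface forms, new in this commit

similarity, max_sub = 0.95, 1
changed_sent = attack.spoil(sentence, [], lemmas, tags, orths, similarity, max_sub)
```

The hunks below belong to the commit's second file (its path is not shown in this capture), the one defining `tag_sentence` and `process_file`.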
@@ -15,6 +15,7 @@ MSTAG = "mstag"
 TEXT = "text"
 LEMMAS = "lemmas"
 TAGS = "tags"
+ORTHS = "orths"
 
 
 def tag_sentence(sentence: str, lang: str):
@@ -41,17 +42,18 @@ def tag_sentence(sentence: str, lang: str):
     for line in lines:
         tokens.extend(line[TOKENS])
     os.remove(downloaded)
-    lemmas, tags = [], []
+    lemmas, tags, orths = [], [], []
     for token in tokens:
-        lexeme = token["lexemes"][0]
-        lemmas.append(lexeme["lemma"])
-        tags.append(lexeme["mstag"])
-    return lemmas, tags
+        lexeme = token[LEXEMES][0]
+        lemmas.append(lexeme[LEMMA])
+        tags.append(lexeme[MSTAG])
+        orths.append(token[ORTH])
+    return lemmas, tags, orths
 
 
 def process_file(dataset_df, lang, output_path):
     test_with_tags = pd.DataFrame(dataset_df)
-    lemmas_col, tags_col = [], []
+    lemmas_col, tags_col, orth_col = [], [], []
     cpus = 8
     with Pool(processes=cpus) as pool:
         results = []
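
Since `tag_sentence` now returns a triple, every caller has to unpack three parallel lists. A quick sketch of the contract, with an invented input sentence:

```python
# tag_sentence returns one lemma, one morphosyntactic tag and one
# surface form (orth) per token, so the three lists stay aligned.
lemmas, tags, orths = tag_sentence("Ala ma kota", "pl")  # invented input
assert len(lemmas) == len(tags) == len(orths)
```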
@@ -62,12 +64,14 @@ def process_file(dataset_df, lang, output_path):
                 pool.apply_async(tag_sentence, args=[sentence, lang])
             )
         for res in results:
-            lemmas, tags = res.get()
+            lemmas, tags, orths = res.get()
             lemmas_col.append(lemmas)
             tags_col.append(tags)
+            orth_col.append(orths)
         results = []
     test_with_tags[LEMMAS] = lemmas_col
     test_with_tags[TAGS] = tags_col
+    test_with_tags[ORTHS] = orth_col
     with open(output_path, mode="wt") as fd:
         fd.write(test_with_tags.to_json(orient="records", lines=True))
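
With the extra column in place, each record written by `process_file` should carry three parallel lists; roughly, with invented field values:

```python
# Approximate shape of one JSONL record after this change (values invented):
record = {
    "text": "Ala ma kota",
    "lemmas": ["Ala", "mieć", "kot"],
    "tags": ["subst:sg:nom", "fin:sg:ter", "subst:sg:acc"],
    "orths": ["Ala", "ma", "kota"],
}
```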
@@ -103,10 +107,10 @@ def main(dataset_name: str):
             test_with_tags = pd.DataFrame(
                 pd.read_json(os.path.join(input_dir, file), lines=True)
             )
-            test_with_tags[LEMMAS] = [
-                "" for _ in range(len(test_with_tags))
-            ]
-            test_with_tags[TAGS] = ["" for _ in range(len(test_with_tags))]
+            empty_list = [[] for _ in range(len(test_with_tags))]
+            test_with_tags[LEMMAS] = empty_list
+            test_with_tags[TAGS] = empty_list
+            test_with_tags[ORTHS] = empty_list
             with open(os.path.join(output_dir, file), mode="wt") as fd:
                 fd.write(
                     test_with_tags.to_json(orient="records", lines=True)
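
One caveat on the fallback branch above: the same `empty_list` object is assigned to all three columns, so each row shares a single inner list across `lemmas`, `tags` and `orths`; an in-place append to one would show up in all three columns at once. If independent placeholders are ever needed, a fresh comprehension per column avoids the aliasing:

```python
# One fresh list per row and per column; no list object is shared.
test_with_tags[LEMMAS] = [[] for _ in range(len(test_with_tags))]
test_with_tags[TAGS] = [[] for _ in range(len(test_with_tags))]
test_with_tags[ORTHS] = [[] for _ in range(len(test_with_tags))]
```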