Commit 34e11dde authored by pwalkow

Add orths

parent 9ee74cfd
@@ -3,7 +3,6 @@ import click
 import pandas as pd
 import os
 from tqdm import tqdm
-from multiprocessing import cpu_count, Pool
 from text_attacks.utils import get_classify_function
 from textfooler import Attack, TextFooler, BaseLine, process
@@ -11,6 +10,7 @@ from textfooler import Attack, TextFooler, BaseLine, process
 TEXT = "text"
 LEMMAS = "lemmas"
 TAGS = "tags"
+ORTHS = "orths"
 ATTACK_SUMMARY = "attacks_summary"
 ATTACK_SUCCEEDED = "attacks_succeeded"
@@ -35,12 +35,6 @@ DEFAULT_RES = {
 }
 
 
-def spoil_sentence(sentence, lemmas, tags, lang, similarity, max_sub):
-    attack = TextFooler(lang)
-    # attack = BaseLine(lang, 0.5, 0.4, 0.3)
-    return attack.spoil(sentence, [], lemmas, tags, similarity, max_sub)
-
-
 @click.command()
 @click.option(
     "--dataset_name",
@@ -61,62 +55,17 @@ def main(dataset_name: str):
     output_path = os.path.join(output_dir, "test.jsonl")
     classify = get_classify_function(dataset_name=dataset_name)
     dataset_df = pd.read_json(input_file, lines=True)
-    # dataset_df = dataset_df[:10]
     spoiled, results = [], []
     similarity, max_sub = 0.95, 1
-    cpus = cpu_count()
     classes = classify(dataset_df[TEXT].tolist())
-    # used_id = 0
-    # sent_nbr = len(dataset_df[TEXT])
-    # with Pool(processes=cpus) as pool:
-    #     for idx in range(0, min(cpus, sent_nbr)):
-    #         sentence, lemmas, tags = dataset_df[TEXT][idx], \
-    #                                  dataset_df[LEMMAS][idx], \
-    #                                  dataset_df[TAGS][idx]
-    #
     lang = "en" if dataset_name == "enron_spam" else "pl"
-    #         results.append(pool.apply_async(spoil_sentence, args=[sentence,
-    #                                                               lemmas,
-    #                                                               tags,
-    #                                                               lang,
-    #                                                               similarity,
-    #                                                               max_sub]))
-    #         used_id = idx
-    #     count = len(results)
-    #     while count and used_id < sent_nbr:
-    #         ready = 0
-    #         to_rm = []
-    #         for r in results:
-    #             if r.ready():
-    #                 ready += 1
-    #                 changed_sent = r.get()
-    #                 if changed_sent:
-    #                     spoiled.append(process(changed_sent, classes[i], classify))
-    #                 to_rm.append(r)
-    #         count = len(results) - ready
-    #         results = [res for res in results if res not in to_rm]
-    #         h_bound = min(used_id + cpus - len(results), sent_nbr)
-    #         for i in range(used_id + 1, h_bound):
-    #             used_id += 1
-    #             sentence, lemmas, tags = dataset_df[TEXT][idx], \
-    #                                      dataset_df[LEMMAS][idx], \
-    #                                      dataset_df[TAGS][idx]
-    #
-    #             results.append(pool.apply_async(spoil_sentence, args=[sentence,
-    #                                                                   lemmas,
-    #                                                                   tags,
-    #                                                                   lang,
-    #                                                                   similarity,
-    #                                                                   max_sub]))
+    attack = TextFooler(lang)
     for i, cols in tqdm(
-        dataset_df[[TEXT, LEMMAS, TAGS]].iterrows(), total=len(dataset_df)
+        dataset_df[[TEXT, LEMMAS, TAGS, ORTHS]].iterrows(), total=len(dataset_df)
     ):
-        sentence, lemmas, tags = cols[0], cols[1], cols[2]
-        changed_sent = spoil_sentence(
-            sentence, lemmas, tags, lang, similarity, max_sub
-        )
+        sentence, lemmas, tags, orths = cols[0], cols[1], cols[2], cols[3]
+        changed_sent = attack.spoil(sentence, [], lemmas, tags, orths, similarity, max_sub)
         if changed_sent:
             spoiled.append(process(changed_sent, classes[i], classify))
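For reference, the new single-process attack loop reduces to the sketch below. This is a minimal sketch rather than the project's code: TextFooler(lang) and the spoil(sentence, [], lemmas, tags, orths, similarity, max_sub) call are taken verbatim from the diff above, while the classify stub and the input path are hypothetical stand-ins for get_classify_function and the real test set.

    import pandas as pd
    from textfooler import TextFooler, process

    def classify(texts):
        # hypothetical stand-in for get_classify_function(dataset_name=...):
        # returns one predicted label per input text
        return [0 for _ in texts]

    dataset_df = pd.read_json("test.jsonl", lines=True)  # hypothetical path
    classes = classify(dataset_df["text"].tolist())
    similarity, max_sub = 0.95, 1

    attack = TextFooler("en")
    spoiled = []
    for i, cols in dataset_df[["text", "lemmas", "tags", "orths"]].iterrows():
        sentence, lemmas, tags, orths = cols[0], cols[1], cols[2], cols[3]
        # spoil() now also receives the tokens' surface forms (orths)
        changed_sent = attack.spoil(sentence, [], lemmas, tags, orths, similarity, max_sub)
        if changed_sent:
            spoiled.append(process(changed_sent, classes[i], classify))

Note that the commit also constructs TextFooler once, outside the loop, instead of once per sentence as the removed spoil_sentence helper did, so the attack is no longer re-initialized for every row.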
@@ -15,6 +15,7 @@ MSTAG = "mstag"
 TEXT = "text"
 LEMMAS = "lemmas"
 TAGS = "tags"
+ORTHS = "orths"
 
 
 def tag_sentence(sentence: str, lang: str):
@@ -41,17 +42,18 @@ def tag_sentence(sentence: str, lang: str):
     for line in lines:
         tokens.extend(line[TOKENS])
     os.remove(downloaded)
-    lemmas, tags = [], []
+    lemmas, tags, orths = [], [], []
     for token in tokens:
-        lexeme = token["lexemes"][0]
-        lemmas.append(lexeme["lemma"])
-        tags.append(lexeme["mstag"])
-    return lemmas, tags
+        lexeme = token[LEXEMES][0]
+        lemmas.append(lexeme[LEMMA])
+        tags.append(lexeme[MSTAG])
+        orths.append(token[ORTH])
+    return lemmas, tags, orths
 
 
 def process_file(dataset_df, lang, output_path):
     test_with_tags = pd.DataFrame(dataset_df)
-    lemmas_col, tags_col = [], []
+    lemmas_col, tags_col, orth_col = [], [], []
     cpus = 8
     with Pool(processes=cpus) as pool:
         results = []
@@ -62,12 +64,14 @@ def process_file(dataset_df, lang, output_path):
                 pool.apply_async(tag_sentence, args=[sentence, lang])
             )
         for res in results:
-            lemmas, tags = res.get()
+            lemmas, tags, orths = res.get()
             lemmas_col.append(lemmas)
             tags_col.append(tags)
+            orth_col.append(orths)
         results = []
     test_with_tags[LEMMAS] = lemmas_col
     test_with_tags[TAGS] = tags_col
+    test_with_tags[ORTHS] = orth_col
     with open(output_path, mode="wt") as fd:
         fd.write(test_with_tags.to_json(orient="records", lines=True))
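The collection pattern above keeps the new orths column aligned with the input rows because apply_async returns its AsyncResult handles in submission order and res.get() blocks until each one is ready. A self-contained sketch of just that pattern, with a whitespace tokenizer standing in (hypothetically) for the real tag_sentence:

    from multiprocessing import Pool

    def tag_sentence(sentence, lang):
        # hypothetical stand-in for the real tagger: returns the
        # (lemmas, tags, orths) triple, one entry per token
        tokens = sentence.split()
        return [t.lower() for t in tokens], ["tag"] * len(tokens), tokens

    if __name__ == "__main__":
        sentences = ["Ala ma kota", "Kot ma Ale"]
        lemmas_col, tags_col, orth_col = [], [], []
        with Pool(processes=2) as pool:
            results = [pool.apply_async(tag_sentence, args=[s, "pl"]) for s in sentences]
            for res in results:
                lemmas, tags, orths = res.get()  # blocks; order matches submission
                lemmas_col.append(lemmas)
                tags_col.append(tags)
                orth_col.append(orths)
        print(orth_col)  # [['Ala', 'ma', 'kota'], ['Kot', 'ma', 'Ale']]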
@@ -103,10 +107,10 @@ def main(dataset_name: str):
         test_with_tags = pd.DataFrame(
             pd.read_json(os.path.join(input_dir, file), lines=True)
         )
-        test_with_tags[LEMMAS] = [
-            "" for _ in range(len(test_with_tags))
-        ]
-        test_with_tags[TAGS] = ["" for _ in range(len(test_with_tags))]
+        empty_list = [[] for _ in range(len(test_with_tags))]
+        test_with_tags[LEMMAS] = empty_list
+        test_with_tags[TAGS] = empty_list
+        test_with_tags[ORTHS] = empty_list
         with open(os.path.join(output_dir, file), mode="wt") as fd:
             fd.write(
                 test_with_tags.to_json(orient="records", lines=True)
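One subtlety in that fallback branch: the same empty_list object is assigned to LEMMAS, TAGS and ORTHS, so row i holds one shared list across all three columns. That is harmless here, since the frame is serialized to JSONL immediately, but it would surface if the lists were ever mutated in place:

    import pandas as pd

    df = pd.DataFrame({"text": ["a", "b"]})
    empty_list = [[] for _ in range(len(df))]
    df["lemmas"] = empty_list
    df["tags"] = empty_list        # same inner list objects as "lemmas"

    df["lemmas"][0].append("x")    # mutate one cell in place
    print(df["tags"][0])           # ['x'] -- the sibling column sees it too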