Commit 34e11dde authored by pwalkow

Add orths

parent 9ee74cfd
@@ -3,7 +3,6 @@ import click
 import pandas as pd
 import os
 from tqdm import tqdm
-from multiprocessing import cpu_count, Pool
 from text_attacks.utils import get_classify_function
 from textfooler import Attack, TextFooler, BaseLine, process
@@ -11,6 +10,7 @@ from textfooler import Attack, TextFooler, BaseLine, process
 TEXT = "text"
 LEMMAS = "lemmas"
 TAGS = "tags"
+ORTHS = "orths"
 ATTACK_SUMMARY = "attacks_summary"
 ATTACK_SUCCEEDED = "attacks_succeeded"
@@ -35,12 +35,6 @@
 }
 
 
-def spoil_sentence(sentence, lemmas, tags, lang, similarity, max_sub):
-    attack = TextFooler(lang)
-    # attack = BaseLine(lang, 0.5, 0.4, 0.3)
-    return attack.spoil(sentence, [], lemmas, tags, similarity, max_sub)
-
-
 @click.command()
 @click.option(
     "--dataset_name",
@@ -61,62 +55,17 @@ def main(dataset_name: str):
     output_path = os.path.join(output_dir, "test.jsonl")
     classify = get_classify_function(dataset_name=dataset_name)
     dataset_df = pd.read_json(input_file, lines=True)
-    # dataset_df = dataset_df[:10]
     spoiled, results = [], []
     similarity, max_sub = 0.95, 1
-    cpus = cpu_count()
     classes = classify(dataset_df[TEXT].tolist())
-    # used_id = 0
-    # sent_nbr = len(dataset_df[TEXT])
-    # with Pool(processes=cpus) as pool:
-    #     for idx in range(0, min(cpus, sent_nbr)):
-    #         sentence, lemmas, tags = dataset_df[TEXT][idx], \
-    #                                  dataset_df[LEMMAS][idx], \
-    #                                  dataset_df[TAGS][idx]
-    #
     lang = "en" if dataset_name == "enron_spam" else "pl"
-    #         results.append(pool.apply_async(spoil_sentence, args=[sentence,
-    #                                                               lemmas,
-    #                                                               tags,
-    #                                                               lang,
-    #                                                               similarity,
-    #                                                               max_sub]))
-    #         used_id = idx
-    # count = len(results)
-    # while count and used_id < sent_nbr:
-    #     ready = 0
-    #     to_rm = []
-    #     for r in results:
-    #         if r.ready():
-    #             ready += 1
-    #             changed_sent = r.get()
-    #             if changed_sent:
-    #                 spoiled.append(process(changed_sent, classes[i], classify))
-    #             to_rm.append(r)
-    #     count = len(results) - ready
-    #     results = [res for res in results if res not in to_rm]
-    #     h_bound = min(used_id + cpus - len(results), sent_nbr)
-    #     for i in range(used_id + 1, h_bound):
-    #         used_id += 1
-    #         sentence, lemmas, tags = dataset_df[TEXT][idx], \
-    #                                  dataset_df[LEMMAS][idx], \
-    #                                  dataset_df[TAGS][idx]
-    #
-    #         results.append(pool.apply_async(spoil_sentence, args=[sentence,
-    #                                                               lemmas,
-    #                                                               tags,
-    #                                                               lang,
-    #                                                               similarity,
-    #                                                               max_sub]))
+    attack = TextFooler(lang)
     for i, cols in tqdm(
-        dataset_df[[TEXT, LEMMAS, TAGS]].iterrows(), total=len(dataset_df)
+        dataset_df[[TEXT, LEMMAS, TAGS, ORTHS]].iterrows(), total=len(dataset_df)
     ):
-        sentence, lemmas, tags = cols[0], cols[1], cols[2]
-        changed_sent = spoil_sentence(
-            sentence, lemmas, tags, lang, similarity, max_sub
-        )
+        sentence, lemmas, tags, orths = cols[0], cols[1], cols[2], cols[3]
+        changed_sent = attack.spoil(sentence, [], lemmas, tags, orths, similarity, max_sub)
         if changed_sent:
             spoiled.append(process(changed_sent, classes[i], classify))
@@ -15,6 +15,7 @@ MSTAG = "mstag"
 TEXT = "text"
 LEMMAS = "lemmas"
 TAGS = "tags"
+ORTHS = "orths"
 
 def tag_sentence(sentence: str, lang: str):
@@ -41,17 +42,18 @@ def tag_sentence(sentence: str, lang: str):
     for line in lines:
         tokens.extend(line[TOKENS])
     os.remove(downloaded)
-    lemmas, tags = [], []
+    lemmas, tags, orths = [], [], []
     for token in tokens:
-        lexeme = token["lexemes"][0]
-        lemmas.append(lexeme["lemma"])
-        tags.append(lexeme["mstag"])
-    return lemmas, tags
+        lexeme = token[LEXEMES][0]
+        lemmas.append(lexeme[LEMMA])
+        tags.append(lexeme[MSTAG])
+        orths.append(token[ORTH])
+    return lemmas, tags, orths
 
 
 def process_file(dataset_df, lang, output_path):
     test_with_tags = pd.DataFrame(dataset_df)
-    lemmas_col, tags_col = [], []
+    lemmas_col, tags_col, orth_col = [], [], []
     cpus = 8
     with Pool(processes=cpus) as pool:
         results = []
@@ -62,12 +64,14 @@ def process_file(dataset_df, lang, output_path):
             pool.apply_async(tag_sentence, args=[sentence, lang])
         )
     for res in results:
-        lemmas, tags = res.get()
+        lemmas, tags, orths = res.get()
         lemmas_col.append(lemmas)
         tags_col.append(tags)
+        orth_col.append(orths)
     results = []
     test_with_tags[LEMMAS] = lemmas_col
     test_with_tags[TAGS] = tags_col
+    test_with_tags[ORTHS] = orth_col
     with open(output_path, mode="wt") as fd:
         fd.write(test_with_tags.to_json(orient="records", lines=True))
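The updated tag_sentence now returns three parallel lists, and process_file collects all of them from the async results. A sketch of the expected shape, with illustrative values rather than real tagger output:

    # Hypothetical result of tag_sentence("Ala ma kota", "pl"):
    lemmas = ["Ala", "mieć", "kot"]
    tags = ["subst:sg:nom:f", "fin:sg:ter:imperf", "subst:sg:acc:m2"]
    orths = ["Ala", "ma", "kota"]  # surface forms exactly as written in the text

    # The lists must stay index-aligned so a substitution chosen by lemma/tag
    # can be mapped back to the orthographic form it replaces.
    assert len(lemmas) == len(tags) == len(orths)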
@@ -103,10 +107,10 @@ def main(dataset_name: str):
         test_with_tags = pd.DataFrame(
             pd.read_json(os.path.join(input_dir, file), lines=True)
         )
-        test_with_tags[LEMMAS] = [
-            "" for _ in range(len(test_with_tags))
-        ]
-        test_with_tags[TAGS] = ["" for _ in range(len(test_with_tags))]
+        empty_list = [[] for _ in range(len(test_with_tags))]
+        test_with_tags[LEMMAS] = empty_list
+        test_with_tags[TAGS] = empty_list
+        test_with_tags[ORTHS] = empty_list
         with open(os.path.join(output_dir, file), mode="wt") as fd:
             fd.write(
                 test_with_tags.to_json(orient="records", lines=True)
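One caveat about the placeholder branch above: assigning the same empty_list object to all three columns makes each row's LEMMAS, TAGS, and ORTHS cells reference one shared empty list. That is harmless here because the frame is serialized to JSON immediately, but separate lists would be safer if the columns were ever mutated in place. A small illustration of the aliasing:

    import pandas as pd

    df = pd.DataFrame({"text": ["a", "b"]})

    # Shared references: both columns point at the same inner [] objects.
    empty_list = [[] for _ in range(len(df))]
    df["lemmas"] = empty_list
    df["tags"] = empty_list
    df["lemmas"][0].append("x")
    print(df["tags"][0])  # ['x'] (mutated through the shared reference)

    # Independent lists per column avoid the aliasing entirely.
    for col in ("lemmas", "tags", "orths"):
        df[col] = [[] for _ in range(len(df))]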