Skip to content
Snippets Groups Projects
Commit 0c324297 authored by pwalkow's avatar pwalkow
Browse files

Change script

parent 786bc9f2
No related merge requests found
......@@ -40,28 +40,18 @@ def tag_sentence(connection: Connection, sentence: str, lang: str):
return lemmas, tags
@click.command()
@click.option(
"--dataset_name",
help="Dataset name",
type=str,
)
def main(dataset_name: str):
"""Downloads the dataset to the output directory."""
lang = 'en' if dataset_name == 'enron_spam' else 'pl'
test = pd.read_json(f"data/datasets/{dataset_name}/test.jsonl", lines=True)
test_with_tags = pd.DataFrame(test)
conn = Connection(config_file="experiments/configs/config.yml")
def process_file(dataset_df, connection, lang, output_path):
test_with_tags = pd.DataFrame(dataset_df)
lemmas_col, tags_col = [], []
cpus = cpu_count()
with Pool(processes=cpus) as pool:
results = []
for idx in tqdm(range(0, len(test), cpus)):
end = min(idx+cpus, len(test) + 1)
for sentence in test[TEXT][idx:end]:
results.append(pool.apply_async(tag_sentence, args=[conn,
for idx in tqdm(range(0, len(dataset_df), cpus)):
end = min(idx+cpus, len(dataset_df) + 1)
for sentence in dataset_df[TEXT][idx:end]:
results.append(pool.apply_async(tag_sentence, args=(connection,
sentence,
lang]))
lang,)))
for res in results:
lemmas, tags = res.get()
lemmas_col.append(lemmas)
......@@ -70,10 +60,28 @@ def main(dataset_name: str):
test_with_tags[LEMMAS] = lemmas_col
test_with_tags[TAGS] = tags_col
with open(output_path, mode="wt") as fd:
fd.write(test_with_tags.to_json(orient='records', lines=True))
@click.command()
@click.option(
"--dataset_name",
help="Dataset name",
type=str,
)
def main(dataset_name: str):
"""Downloads the dataset to the output directory."""
lang = 'en' if dataset_name == 'enron_spam' else 'pl'
conn = Connection(config_file="experiments/configs/config.yml")
output_dir = f"data/preprocessed/{dataset_name}"
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/test.jsonl", mode="wt") as fd:
fd.write(test_with_tags.to_json(orient='records', lines=True))
input_dir = f"data/datasets/{dataset_name}"
for file in os.listdir(input_dir):
if os.path.isfile(os.path.join(input_dir, file)):
process_file(pd.read_json(os.path.join(input_dir, file), lines=True),
conn, lang, os.path.join(output_dir, file))
if __name__ == "__main__":
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment