Skip to content
Snippets Groups Projects
Commit a3009ec2 authored by Maja Jablonska
Browse files

Fixes to dataset readers

parent 8d1d7a09
Branches
Tags
1 merge request: !46 Merge COMBO 3.0 into master
......@@ -84,7 +84,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
return _TokenList(
[collections.OrderedDict({"idx": idx, "token": token}) for
[collections.OrderedDict({"id": idx, "token": token}) for
idx, token
in enumerate(tokens, start=1)],
metadata=collections.OrderedDict()
......
......@@ -116,19 +116,12 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exists!"
with file.open("r", encoding="utf-8") as f:
for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
yield self.text_to_instance([Token.from_conllu_token(t) for t in annotation])
yield self.text_to_instance(annotation)
def text_to_instance(self, tree: conllu.TokenList) -> Instance:
fields_: Dict[str, Field] = {}
tokens = [Token.from_conllu_token(t) for t in tree if isinstance(t["idx"], int)]
# tokens = [Token(text=t["token"],
# upostag=t.get("upostag"),
# xpostag=t.get("xpostag"),ś
# lemma=t.get("lemma"),
# feats=t.get("feats"))
# for t in tree_tokens]
tokens = [Token.from_conllu_token(t) for t in tree if isinstance(t["id"], int)]
# features
text_field = TextField(tokens, self.token_indexers)
......
This diff is collapsed.
......@@ -12,14 +12,14 @@ class ConllDatasetReaderTest(unittest.TestCase):
def test_tokenize_correct_tokens(self):
reader = ConllDatasetReader(coding_scheme='IOB2')
token = next(iter(reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
token = next(iter(reader.read(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
self.assertListEqual([str(t) for t in token['tokens'].tokens],
['SOCCER', '-', 'JAPAN', 'GET', 'LUCKY', 'WIN', ',',
'CHINA', 'IN', 'SURPRISE', 'DEFEAT', '.'])
def test_tokenize_correct_tags(self):
reader = ConllDatasetReader(coding_scheme='IOB2')
token = next(iter(reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
token = next(iter(reader.read(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
self.assertListEqual(token['tags'].labels,
['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O',
'B-PER', 'O', 'O', 'O', 'O'])
......@@ -9,12 +9,12 @@ from combo.data.tokenizers import SpacySentenceSplitter
class TextClassificationJSONReaderTest(unittest.TestCase):
def test_read_two_tokens(self):
reader = TextClassificationJSONReader()
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
tokens = [token for token in reader.read(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
self.assertEqual(len(tokens), 2)
def test_read_two_examples_fields_without_sentence_splitting(self):
reader = TextClassificationJSONReader()
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
tokens = [token for token in reader.read(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
self.assertEqual(len(tokens[0].fields.items()), 2)
self.assertIsInstance(tokens[0].fields.get('label'), LabelField)
self.assertEqual(tokens[0].fields.get('label').label, 'label1')
......@@ -24,7 +24,7 @@ class TextClassificationJSONReaderTest(unittest.TestCase):
def test_read_two_examples_tokens_without_sentence_splitting(self):
reader = TextClassificationJSONReader()
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
tokens = [token for token in reader.read(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
self.assertEqual(len(tokens[0].fields.items()), 2)
self.assertIsInstance(tokens[0].fields.get('tokens'), TextField)
self.assertEqual(len(tokens[0].fields.get('tokens').tokens), 13)
......@@ -34,7 +34,7 @@ class TextClassificationJSONReaderTest(unittest.TestCase):
def test_read_two_examples_tokens_with_sentence_splitting(self):
reader = TextClassificationJSONReader(sentence_segmenter=SpacySentenceSplitter())
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
tokens = [token for token in reader.read(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
self.assertEqual(len(tokens[0].fields.items()), 2)
self.assertIsInstance(tokens[0].fields.get('tokens'), ListField)
self.assertEqual(len(tokens[0].fields.get('tokens').field_list), 2)
......
......@@ -7,24 +7,24 @@ from combo.data import UniversalDependenciesDatasetReader
class UniversalDependenciesDatasetReaderTest(unittest.TestCase):
def test_read_all_tokens(self):
t = UniversalDependenciesDatasetReader()
tokens = [token for token in t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))]
tokens = [token for token in t.read(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))]
self.assertEqual(len(tokens), 128)
def test_read_text(self):
t = UniversalDependenciesDatasetReader()
token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
self.assertListEqual([str(t) for t in token['sentence'].tokens],
token = next(iter(t.read(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
self.assertListEqual([t.text for t in token['sentence'].tokens],
['Gumising', 'ang', 'bata', '.'])
def test_read_deprel(self):
t = UniversalDependenciesDatasetReader()
token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
token = next(iter(t.read(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
self.assertListEqual(token['deprel'].labels,
['root', 'case', 'nsubj', 'punct'])
def test_read_upostag(self):
t = UniversalDependenciesDatasetReader()
token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
token = next(iter(t.read(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
self.assertListEqual(token['upostag'].labels,
['VERB', 'ADP', 'NOUN', 'PUNCT'])
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment