Commit a3009ec2 authored by Maja Jablonska

Fixes to dataset readers

parent 8d1d7a09
1 merge request: !46 Merge COMBO 3.0 into master
@@ -84,7 +84,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
 def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
     return _TokenList(
-        [collections.OrderedDict({"idx": idx, "token": token}) for
+        [collections.OrderedDict({"id": idx, "token": token}) for
          idx, token
          in enumerate(tokens, start=1)],
         metadata=collections.OrderedDict()
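The rename from "idx" to "id" matters because the third-party conllu package keys parsed tokens by "id", so anything COMBO emits must use the same key to round-trip. A minimal sketch, assuming only the conllu package:

# Sketch (assumption: the `conllu` package): parsed CoNLL-U tokens are keyed
# by "id", not "idx", which is why tokens2conllu must emit the "id" key.
import conllu

sentence = conllu.parse("1\tGumising\n2\tang\n3\tbata\n4\t.\n\n", fields=("id", "form"))[0]
print([t["id"] for t in sentence])    # [1, 2, 3, 4]
print([t["form"] for t in sentence])  # ['Gumising', 'ang', 'bata', '.']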
@@ -116,19 +116,12 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
         assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exists!"
         with file.open("r", encoding="utf-8") as f:
             for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
-                yield self.text_to_instance([Token.from_conllu_token(t) for t in annotation])
+                yield self.text_to_instance(annotation)

     def text_to_instance(self, tree: conllu.TokenList) -> Instance:
         fields_: Dict[str, Field] = {}
-        tokens = [Token.from_conllu_token(t) for t in tree if isinstance(t["idx"], int)]
-        # tokens = [Token(text=t["token"],
-        #                 upostag=t.get("upostag"),
-        #                 xpostag=t.get("xpostag"),
-        #                 lemma=t.get("lemma"),
-        #                 feats=t.get("feats"))
-        #           for t in tree_tokens]
+        tokens = [Token.from_conllu_token(t) for t in tree if isinstance(t["id"], int)]
         # features
         text_field = TextField(tokens, self.token_indexers)
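The isinstance(t["id"], int) filter in text_to_instance works because conllu parses multiword-token range lines such as "1-2" into tuple ids rather than ints. A hedged sketch with hypothetical data, under the same library assumption:

# Sketch: a "1-2" range line gets the tuple id (1, '-', 2), so filtering on
# isinstance(t["id"], int) keeps only the syntactic words, as the reader
# above relies on.
import conllu

data = "1-2\tdel\n1\tde\n2\tel\n\n"
sentence = conllu.parse(data, fields=("id", "form"))[0]
print([t["id"] for t in sentence])  # [(1, '-', 2), 1, 2]
print([t["form"] for t in sentence if isinstance(t["id"], int)])  # ['de', 'el']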
This diff is collapsed.
@@ -12,14 +12,14 @@ class ConllDatasetReaderTest(unittest.TestCase):
     def test_tokenize_correct_tokens(self):
         reader = ConllDatasetReader(coding_scheme='IOB2')
-        token = next(iter(reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
+        token = next(iter(reader.read(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
         self.assertListEqual([str(t) for t in token['tokens'].tokens],
                              ['SOCCER', '-', 'JAPAN', 'GET', 'LUCKY', 'WIN', ',',
                               'CHINA', 'IN', 'SURPRISE', 'DEFEAT', '.'])

     def test_tokenize_correct_tags(self):
         reader = ConllDatasetReader(coding_scheme='IOB2')
-        token = next(iter(reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
+        token = next(iter(reader.read(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
         self.assertListEqual(token['tags'].labels,
                              ['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O',
                               'B-PER', 'O', 'O', 'O', 'O'])
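Across all the test files in this commit, readers are now driven through read(path) instead of being called directly. A usage sketch, assuming COMBO's readers follow the AllenNLP-style DatasetReader API in which read() yields Instance objects; the import path for ConllDatasetReader is a guess, since the tests only show the class name:

import os
from combo.data import ConllDatasetReader  # assumed import path

reader = ConllDatasetReader(coding_scheme='IOB2')
path = os.path.join(os.path.dirname(__file__), 'conll_test_file.txt')
for instance in reader.read(path):  # read(), not reader(path)
    print([str(t) for t in instance['tokens'].tokens])
    print(instance['tags'].labels)
    break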
@@ -9,12 +9,12 @@ from combo.data.tokenizers import SpacySentenceSplitter

 class TextClassificationJSONReaderTest(unittest.TestCase):
     def test_read_two_tokens(self):
         reader = TextClassificationJSONReader()
-        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
+        tokens = [token for token in reader.read(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
         self.assertEqual(len(tokens), 2)

     def test_read_two_examples_fields_without_sentence_splitting(self):
         reader = TextClassificationJSONReader()
-        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
+        tokens = [token for token in reader.read(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
         self.assertEqual(len(tokens[0].fields.items()), 2)
         self.assertIsInstance(tokens[0].fields.get('label'), LabelField)
         self.assertEqual(tokens[0].fields.get('label').label, 'label1')
@@ -24,7 +24,7 @@ class TextClassificationJSONReaderTest(unittest.TestCase):
     def test_read_two_examples_tokens_without_sentence_splitting(self):
         reader = TextClassificationJSONReader()
-        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
+        tokens = [token for token in reader.read(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
         self.assertEqual(len(tokens[0].fields.items()), 2)
         self.assertIsInstance(tokens[0].fields.get('tokens'), TextField)
         self.assertEqual(len(tokens[0].fields.get('tokens').tokens), 13)
@@ -34,7 +34,7 @@ class TextClassificationJSONReaderTest(unittest.TestCase):
     def test_read_two_examples_tokens_with_sentence_splitting(self):
         reader = TextClassificationJSONReader(sentence_segmenter=SpacySentenceSplitter())
-        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
+        tokens = [token for token in reader.read(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
         self.assertEqual(len(tokens[0].fields.items()), 2)
         self.assertIsInstance(tokens[0].fields.get('tokens'), ListField)
         self.assertEqual(len(tokens[0].fields.get('tokens').field_list), 2)
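Read together, these three tests show the field-type contract: without a sentence segmenter the reader stores all 13 tokens in one TextField under 'tokens', while passing SpacySentenceSplitter wraps per-sentence TextFields in a ListField of length 2. A hedged sketch; the import path for TextClassificationJSONReader is assumed, SpacySentenceSplitter's is taken from the test imports above:

import os
from combo.data import TextClassificationJSONReader  # assumed import path
from combo.data.tokenizers import SpacySentenceSplitter

path = os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json')

flat = next(iter(TextClassificationJSONReader().read(path)))
print(type(flat.fields['tokens']).__name__)   # TextField holding 13 tokens

split_reader = TextClassificationJSONReader(sentence_segmenter=SpacySentenceSplitter())
split = next(iter(split_reader.read(path)))
print(type(split.fields['tokens']).__name__)  # ListField of 2 per-sentence TextFields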
@@ -7,24 +7,24 @@ from combo.data import UniversalDependenciesDatasetReader

 class UniversalDependenciesDatasetReaderTest(unittest.TestCase):
     def test_read_all_tokens(self):
         t = UniversalDependenciesDatasetReader()
-        tokens = [token for token in t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))]
+        tokens = [token for token in t.read(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))]
         self.assertEqual(len(tokens), 128)

     def test_read_text(self):
         t = UniversalDependenciesDatasetReader()
-        token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
-        self.assertListEqual([str(t) for t in token['sentence'].tokens],
+        token = next(iter(t.read(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
+        self.assertListEqual([t.text for t in token['sentence'].tokens],
                              ['Gumising', 'ang', 'bata', '.'])

     def test_read_deprel(self):
         t = UniversalDependenciesDatasetReader()
-        token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
+        token = next(iter(t.read(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
         self.assertListEqual(token['deprel'].labels,
                              ['root', 'case', 'nsubj', 'punct'])

     def test_read_upostag(self):
         t = UniversalDependenciesDatasetReader()
-        token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
+        token = next(iter(t.read(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
         self.assertListEqual(token['upostag'].labels,
                              ['VERB', 'ADP', 'NOUN', 'PUNCT'])
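Besides the switch to read(), test_read_text now takes the surface form from Token.text instead of str(token), which keeps the assertion independent of whatever Token.__str__ returns. A combined usage sketch, grounded in the assertions above and the import shown at the top of this test file:

import os
from combo.data import UniversalDependenciesDatasetReader

reader = UniversalDependenciesDatasetReader()
path = os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu')
instance = next(iter(reader.read(path)))
print([t.text for t in instance['sentence'].tokens])  # ['Gumising', 'ang', 'bata', '.']
print(instance['upostag'].labels)                     # ['VERB', 'ADP', 'NOUN', 'PUNCT']
print(instance['deprel'].labels)                      # ['root', 'case', 'nsubj', 'punct']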