Syntactic Tools / combo / Commits

Commit e1bfa0c7, authored 1 year ago by Maja Jablonska

    Add testing

Parent: d80a60c1
Merge request: !46 "Merge COMBO 3.0 into master"

3 changed files, with 47 additions and 11 deletions:

- combo/data/api.py (+7, −7)
- combo/data/tokenizers/token.py (+15, −1)
- combo/main.py (+25, −3)
combo/data/api.py (+7, −7)

```diff
@@ -54,21 +54,21 @@ class Sentence:
         return len(self.tokens)
 
 
-class _TokenList(conllu.models.TokenList):
+class _TokenList(conllu.TokenList):
 
     @overrides
     def __repr__(self):
-        return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
+        return 'TokenList<' + ', '.join(token['text'] for token in self) + '>'
 
 
-def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.models.TokenList:
+def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
     tokens = []
     for token in sentence.tokens:
-        token_dict = collections.OrderedDict(dataclasses.asdict(token))
+        token_dict = collections.OrderedDict(token.as_dict(keep_semrel))
         # Remove semrel to have default conllu format.
-        if not keep_semrel:
-            del token_dict["semrel"]
-        del token_dict["embeddings"]
+        # if not keep_semrel:
+        #     del token_dict["semrel"]
+        # del token_dict["embeddings"]
         tokens.append(token_dict)
     # Range tokens must be tuple not list, this is conllu library requirement
     for t in tokens:
```
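The import-path change (`conllu.models.TokenList` to `conllu.TokenList`) tracks the `conllu` package, which exposes `TokenList` at the top level in recent releases. For context, a minimal sketch of the serialization this function feeds into; the key names here follow plain CoNLL-U conventions rather than combo's exact token fields, which come from `Token.as_dict` in the next file:

```python
import collections
import conllu

# Build a one-token "sentence" the way sentence2conllu does: a list of
# ordered dicts wrapped in conllu.TokenList. Key names are illustrative.
token = collections.OrderedDict(
    [("id", 1), ("form", "kot"), ("lemma", "kot"), ("upostag", "NOUN")]
)
sentence = conllu.TokenList([token])
print(sentence.serialize())  # one tab-separated token line plus a trailing blank line
```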
combo/data/tokenizers/token.py (+15, −1)

```diff
@@ -2,7 +2,7 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/token_class.py
 """
-from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Union
 import logging
 from dataclasses import dataclass, field
@@ -95,6 +95,20 @@ class Token:
         self.text_id = text_id
         self.type_id = type_id
 
+    def as_dict(self, semrel: bool = True) -> Dict[str, Any]:
+        repr = {}
+        repr_keys = ['text', 'idx', 'lemma', 'upostag', 'xpostag',
+                     'entity_type', 'feats', 'head', 'deprel', 'deps',
+                     'misc']
+        for rk in repr_keys:
+            repr[rk] = self.__getattribute__(rk)
+        if semrel:
+            repr['semrel'] = self.semrel
+        return repr
+
     def __str__(self):
         return self.text
```
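Taken together with the `api.py` change, the point of `as_dict` is to emit only the CoNLL-U-relevant fields, whereas `dataclasses.asdict` serializes every dataclass field (including `embeddings`), which then had to be deleted again. A sketch of the difference on a cut-down stand-in for `Token`; the field subset and defaults here are assumptions, not combo's real signature:

```python
import collections
import dataclasses
from typing import Any, Dict, List, Optional

@dataclasses.dataclass
class DemoToken:
    # Cut-down stand-in for combo's Token, with a few of the diff's fields.
    text: Optional[str] = None
    lemma: Optional[str] = None
    semrel: Optional[str] = None
    embeddings: List[float] = dataclasses.field(default_factory=list)

    def as_dict(self, semrel: bool = True) -> Dict[str, Any]:
        out = {}
        for key in ['text', 'lemma']:  # subset of the diff's repr_keys
            out[key] = self.__getattribute__(key)
        if semrel:
            out['semrel'] = self.semrel
        return out

t = DemoToken(text="kot", lemma="kot")
old = collections.OrderedDict(dataclasses.asdict(t))    # drags in 'embeddings'
new = collections.OrderedDict(t.as_dict(semrel=False))  # CoNLL-U fields only
print('embeddings' in old, 'embeddings' in new)         # True False
```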
combo/main.py (+25, −3)

```diff
 import json
 import logging
+import os
 import pathlib
 import tempfile
 from typing import Dict
@@ -7,6 +8,7 @@ from typing import Dict
 import torch
 from absl import app, flags
 import pytorch_lightning as pl
+from tqdm import tqdm
 
 from combo.training.trainable_combo import TrainableCombo
 from combo.utils import checks, ComboLogger
@@ -15,6 +17,7 @@ from combo.config import resolve
 from combo.default_model import default_ud_dataset_reader, default_data_loader
 from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
+from combo.data import api
 
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
@@ -65,8 +68,9 @@ flags.DEFINE_string(name="finetuning_validation_data_path", default="",
                     help="Validation data path(s)")
 
 # Test after training flags
-flags.DEFINE_string(name="test_path", default=None,
+flags.DEFINE_string(name="test_data_path", default=None,
                     help="Test path file.")
+flags.DEFINE_alias(name="test_data", original_name="test_data_path")
 
 # Experimental
 flags.DEFINE_boolean(name="use_pure_config", default=False,
@@ -111,12 +115,16 @@ def run(_):
         serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
+        params['vocabulary']['parameters']['directory'] = os.path.join(
+            '/'.join(FLAGS.config_path.split('/')[:-1]),
+            params['vocabulary']['parameters']['directory'])
         try:
             vocabulary = resolve(params['vocabulary'])
         except KeyError:
             logger.error('No vocabulary in config.json!')
             return
 
         model = resolve(params['model'], pass_down_parameters={'vocabulary': vocabulary})
         dataset_reader = None
@@ -184,9 +192,23 @@ def run(_):
                              gradient_clip_val=5)
         trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
-        logger.info(f'Archiving the fine-tuned model in {serialization_dir}', prefix=prefix)
+        logger.info(f'Archiving the model in {serialization_dir}', prefix=prefix)
         archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader)
-        logger.info(f"Training model stored in: {serialization_dir}", prefix=prefix)
+        logger.info(f"Model stored in: {serialization_dir}", prefix=prefix)
+
+        if FLAGS.test_data_path and FLAGS.output_file:
+            checks.file_exists(FLAGS.test_data_path)
+            if not dataset_reader:
+                logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader",
+                            prefix=prefix)
+                dataset_reader = default_ud_dataset_reader()
+            logger.info("Predicting test examples", prefix=prefix)
+            test_trees = dataset_reader.read(FLAGS.test_data_path)
+            predictor = COMBO(model, dataset_reader)
+            with open(FLAGS.output_file, "w") as file:
+                for tree in tqdm(test_trees):
+                    file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                        keep_semrel=dataset_reader.use_sem).serialize())
     elif FLAGS.mode == 'predict':
         predictor = get_predictor()
```
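The renamed flag plus `flags.DEFINE_alias` means both `--test_data_path` and `--test_data` select the test file, and the test pass only runs when an output file is also given. A condensed, runnable sketch of that gating, with combo's reader and predictor replaced by a pass-through stand-in; only the control flow mirrors the diff:

```python
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string(name="test_data_path", default=None, help="Test path file.")
flags.DEFINE_alias(name="test_data", original_name="test_data_path")
flags.DEFINE_string(name="output_file", default=None, help="Where to write predictions.")

def run(_):
    # Mirrors the diff's gating: both flags must be set for the test pass.
    if FLAGS.test_data_path and FLAGS.output_file:
        with open(FLAGS.test_data_path) as src, open(FLAGS.output_file, "w") as out:
            for line in src:
                # Stand-in for predictor.predict_instance(...) followed by
                # api.sentence2conllu(...).serialize() in the real code.
                out.write(line)

if __name__ == "__main__":
    app.run(run)
```

Invoked as, for example, `python sketch.py --test_data=test.conllu --output_file=pred.conllu`, the alias resolves to `test_data_path` automatically.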