Commit e1bfa0c7
authored Nov 12, 2023 by Maja Jablonska

Add testing

parent d80a60c1
1 merge request: !46 Merge COMBO 3.0 into master

Showing 3 changed files with 47 additions and 11 deletions:

  combo/data/api.py               +7  -7
  combo/data/tokenizers/token.py  +15 -1
  combo/main.py                   +25 -3
combo/data/api.py  +7 -7

@@ -54,21 +54,21 @@ class Sentence:
         return len(self.tokens)
 
 
-class _TokenList(conllu.models.TokenList):
+class _TokenList(conllu.TokenList):
 
     @overrides
     def __repr__(self):
-        return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
+        return 'TokenList<' + ', '.join(token['text'] for token in self) + '>'
 
 
-def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.models.TokenList:
+def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
     tokens = []
     for token in sentence.tokens:
-        token_dict = collections.OrderedDict(dataclasses.asdict(token))
+        token_dict = collections.OrderedDict(token.as_dict(keep_semrel))
         # Remove semrel to have default conllu format.
-        if not keep_semrel:
-            del token_dict["semrel"]
-        del token_dict["embeddings"]
+        # if not keep_semrel:
+        #     del token_dict["semrel"]
+        # del token_dict["embeddings"]
         tokens.append(token_dict)
     # Range tokens must be tuple not list, this is conllu library requirement
     for t in tokens:
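A note on this hunk: the annotations drop the conllu.models prefix, and the per-token dict is now built by the new Token.as_dict (added in token.py below) instead of dataclasses.asdict, which makes the old semrel/embeddings deletions dead code. The short check below is a sketch assuming a recent conllu release that re-exports TokenList at the package top level; verify against the version you have installed:

    import conllu

    # Recent `conllu` releases expose TokenList at the top level, so the
    # shorter annotation names the same class as conllu.models.TokenList.
    sentences = conllu.parse("1\tdom\t_\t_\t_\t_\t0\troot\t_\t_\n\n")
    print(type(sentences[0]))                           # a conllu TokenList
    print(conllu.TokenList is conllu.models.TokenList)  # expected: True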
combo/data/tokenizers/token.py  +15 -1

@@ -2,7 +2,7 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/token_class.py
 """
-from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Union
 import logging
 from dataclasses import dataclass, field

@@ -95,6 +95,20 @@ class Token:
         self.text_id = text_id
         self.type_id = type_id
 
+    def as_dict(self, semrel: bool = True) -> Dict[str, Any]:
+        repr = {}
+        repr_keys = ['text', 'idx', 'lemma', 'upostag', 'xpostag',
+                     'entity_type', 'feats', 'head', 'deprel',
+                     'deps', 'misc']
+        for rk in repr_keys:
+            repr[rk] = self.__getattribute__(rk)
+        if semrel:
+            repr['semrel'] = self.semrel
+        return repr
+
     def __str__(self):
         return self.text
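as_dict whitelists the CoNLL-U columns and gates semrel behind a flag, which is what lets api.py above stop deleting fields after serialization. A minimal sketch of the asdict/as_dict difference, using a three-field stand-in rather than COMBO's real Token:

    import collections
    import dataclasses

    @dataclasses.dataclass
    class Token:
        # Stand-in with three fields; the real Token carries the full
        # CoNLL-U column set plus embeddings.
        text: str = None
        semrel: str = None
        embeddings: dict = dataclasses.field(default_factory=dict)

        def as_dict(self, semrel: bool = True):
            keys = ['text'] + (['semrel'] if semrel else [])
            return {k: getattr(self, k) for k in keys}

    t = Token(text='dom', semrel='COORD', embeddings={'char': [0.1]})
    print(collections.OrderedDict(dataclasses.asdict(t)))
    # -> includes 'embeddings', which a CoNLL-U writer must then strip
    print(collections.OrderedDict(t.as_dict(semrel=False)))
    # -> OrderedDict([('text', 'dom')])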
combo/main.py  +25 -3

 import json
 import logging
+import os
 import pathlib
 import tempfile
 from typing import Dict

@@ -7,6 +8,7 @@ from typing import Dict
 import torch
 from absl import app, flags
 import pytorch_lightning as pl
+from tqdm import tqdm
 from combo.training.trainable_combo import TrainableCombo
 from combo.utils import checks, ComboLogger

@@ -15,6 +17,7 @@ from combo.config import resolve
 from combo.default_model import default_ud_dataset_reader, default_data_loader
 from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
+from combo.data import api
 
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)

@@ -65,8 +68,9 @@ flags.DEFINE_string(name="finetuning_validation_data_path", default="",
                     help="Validation data path(s)")
 # Test after training flags
-flags.DEFINE_string(name="test_path", default=None,
+flags.DEFINE_string(name="test_data_path", default=None,
                     help="Test path file.")
+flags.DEFINE_alias(name="test_data", original_name="test_data_path")
 # Experimental
 flags.DEFINE_boolean(name="use_pure_config", default=False,

@@ -111,12 +115,16 @@ def run(_):
             serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
+            params['vocabulary']['parameters']['directory'] = os.path.join(
+                '/'.join(FLAGS.config_path.split('/')[:-1]),
+                params['vocabulary']['parameters']['directory'])
             try:
                 vocabulary = resolve(params['vocabulary'])
             except KeyError:
                 logger.error('No vocabulary in config.json!')
                 return
             model = resolve(params['model'], pass_down_parameters={'vocabulary': vocabulary})
             dataset_reader = None

@@ -184,9 +192,23 @@ def run(_):
                                  gradient_clip_val=5)
             trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
-            logger.info(f'Archiving the fine-tuned model in {serialization_dir}', prefix=prefix)
+            logger.info(f'Archiving the model in {serialization_dir}', prefix=prefix)
             archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader)
-            logger.info(f"Training model stored in: {serialization_dir}", prefix=prefix)
+            logger.info(f"Model stored in: {serialization_dir}", prefix=prefix)
+
+        if FLAGS.test_data_path and FLAGS.output_file:
+            checks.file_exists(FLAGS.test_data_path)
+            if not dataset_reader:
+                logger.info("No dataset reader in the configuration or archive file - "
+                            "using a default UD dataset reader", prefix=prefix)
+                dataset_reader = default_ud_dataset_reader()
+            logger.info("Predicting test examples", prefix=prefix)
+            test_trees = dataset_reader.read(FLAGS.test_data_path)
+            predictor = COMBO(model, dataset_reader)
+            with open(FLAGS.output_file, "w") as file:
+                for tree in tqdm(test_trees):
+                    file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                        keep_semrel=dataset_reader.use_sem).serialize())
 
     elif FLAGS.mode == 'predict':
         predictor = get_predictor()
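The test flag is renamed from test_path to test_data_path, with a shorter --test_data alias registered via absl. A minimal absl-py sketch of how DEFINE_alias behaves (the toy script and paths are invented; the flag names are from the diff):

    from absl import app, flags

    FLAGS = flags.FLAGS
    flags.DEFINE_string(name="test_data_path", default=None, help="Test path file.")
    # The alias makes --test_data set the same underlying flag as --test_data_path.
    flags.DEFINE_alias(name="test_data", original_name="test_data_path")

    def main(_):
        # Both invocations print the same value:
        #   python toy.py --test_data_path=test.conllu
        #   python toy.py --test_data=test.conllu
        print(FLAGS.test_data_path)

    if __name__ == "__main__":
        app.run(main)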