Skip to content
Snippets Groups Projects
Commit ac107081 authored by piotrmp's avatar piotrmp
Browse files

Added partial execution to pretraining.

parent 98383a84
1 merge request: !1 "Migration to UD 2.11"
...@@ -15,13 +15,13 @@ from lambo.learning.preprocessing_pretraining import encode_pretraining, prepare ...@@ -15,13 +15,13 @@ from lambo.learning.preprocessing_pretraining import encode_pretraining, prepare
from lambo.learning.train import pretrain from lambo.learning.train import pretrain
from lambo.utils.oscar import read_jsonl_to_documents, download_archive1_from_oscar from lambo.utils.oscar import read_jsonl_to_documents, download_archive1_from_oscar
if __name__=='__main__': if __name__ == '__main__':
outpath = Path(sys.argv[1]) #Path.home() / 'PATH-TO/models/pretrained/' outpath = Path(sys.argv[1]) # Path.home() / 'PATH-TO/models/pretrained/'
tmppath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/tmp/tmp.jsonl.gz' tmppath = Path(sys.argv[2]) # Path.home() / 'PATH-TO/tmp/tmp.jsonl.gz'
# These need to be filled in before running. OSCAR is available on request. # These need to be filled in before running. OSCAR is available on request.
OSCAR_LOGIN = sys.argv[3] OSCAR_LOGIN = sys.argv[3]
OSCAR_PASSWORD = sys.argv[4] OSCAR_PASSWORD = sys.argv[4]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict') languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
...@@ -32,7 +32,9 @@ if __name__=='__main__': ...@@ -32,7 +32,9 @@ if __name__=='__main__':
MAX_DOCUMENTS = 100 MAX_DOCUMENTS = 100
CONTEXT_LEN = 1024 CONTEXT_LEN = 1024
for language in languages: for l, language in enumerate(languages):
if l % 5 != int(sys.argv[5]):
continue
if (outpath / ('oscar_' + language + '.pth')).exists(): if (outpath / ('oscar_' + language + '.pth')).exists():
continue continue
print("Language: " + language) print("Language: " + language)
...@@ -52,7 +54,8 @@ if __name__=='__main__': ...@@ -52,7 +54,8 @@ if __name__=='__main__':
print(str(i + 1) + '/' + str(min(len(train_documents), MAX_DOCUMENTS))) print(str(i + 1) + '/' + str(min(len(train_documents), MAX_DOCUMENTS)))
Xchars, Xutfs, Xmasks, Yvecs = encode_pretraining([document_train], dict, CONTEXT_LEN) Xchars, Xutfs, Xmasks, Yvecs = encode_pretraining([document_train], dict, CONTEXT_LEN)
_, train_dataloader, test_dataloader = prepare_dataloaders_pretraining([document_train], _, train_dataloader, test_dataloader = prepare_dataloaders_pretraining([document_train],
[document_test], CONTEXT_LEN, 32, dict) [document_test], CONTEXT_LEN, 32,
dict)
pretrain(model, train_dataloader, test_dataloader, 1, device) pretrain(model, train_dataloader, test_dataloader, 1, device)
torch.save(model, outpath / ('oscar_' + language + '.pth')) torch.save(model, outpath / ('oscar_' + language + '.pth'))
with open(outpath / ('oscar_' + language + '.dict'), "w") as file1: with open(outpath / ('oscar_' + language + '.dict'), "w") as file1:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment