rbbt-dm 1.2.1 → 1.2.3
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model.rb +1 -1
- data/python/rbbt_dm/huggingface.py +112 -0
- metadata +2 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 2ff72107967b0f7c654697f3a7b3c0ef10f7a5264d775117f12f74084a2819b2
+  data.tar.gz: 6b9a58b5a2723c095332f79a37d9c1c7f4bc1431410f23a55beeed1c3b52f7ad
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: a0fb4198cb0be3aa5253df0f655ee230621dd26a31956a774fffe95eac35f4c8b558a41c0e340c25c6eef463760ff6230b967f09eb09671b2078a50066067384
+  data.tar.gz: 0bd6c3667a8ec26ed092c54e78176671807ed0634e136e497303e34a17b2740e5e023041bc06389ac187de39b54942b9b1c5cd77abbc067c89250424654b6974
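These digests can be re-derived from the two archives that make up the gem. A minimal sketch in Python, assuming metadata.gz and data.tar.gz have been unpacked from the .gem file (a plain tar archive) into the current directory; the expected values are the SHA256 sums from checksums.yaml above:

import hashlib

# Expected SHA256 digests, copied from checksums.yaml above
expected_sha256 = {
    "metadata.gz": "2ff72107967b0f7c654697f3a7b3c0ef10f7a5264d775117f12f74084a2819b2",
    "data.tar.gz": "6b9a58b5a2723c095332f79a37d9c1c7f4bc1431410f23a55beeed1c3b52f7ad",
}

for name, expected in expected_sha256.items():
    with open(name, "rb") as f:
        actual = hashlib.sha256(f.read()).hexdigest()
    print(name, "OK" if actual == expected else "MISMATCH: " + actual)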
data/python/rbbt_dm/huggingface.py
ADDED
@@ -0,0 +1,112 @@
+#{{{ LOAD MODEL
+
+def import_module_class(module, class_name):
+    import importlib
+    return getattr(importlib.import_module(module), class_name)
+
+def load_model(task, checkpoint):
+    class_name = 'AutoModelFor' + task
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_tokenizer(task, checkpoint):
+    class_name = 'AutoTokenizer'
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_model_and_tokenizer(task, checkpoint):
+    model = load_model(task, checkpoint)
+    tokenizer = load_tokenizer(task, checkpoint)
+    return model, tokenizer
+
+#{{{ SIMPLE EVALUATE
+
+def forward(model, features):
+    return model(**features)
+
+def logits(predictions):
+    logits = predictions["logits"]
+    return [v.detach().cpu().numpy() for v in logits]
+
+def eval_model(model, tokenizer, texts, return_logits=True):
+    features = tokenizer(texts, return_tensors='pt', truncation=True).to(model.device)
+    predictions = forward(model, features)
+    if return_logits:
+        return logits(predictions)
+    else:
+        return predictions
+
+#{{{ TRAIN AND PREDICT
+
+def load_tsv(tsv_file):
+    from datasets import load_dataset
+    return load_dataset('csv', data_files=[tsv_file], sep="\t")
+
+def tsv_dataset(tokenizer, tsv_file):
+    dataset = load_tsv(tsv_file)
+    tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True), batched=True)
+    return tokenized_dataset
+
+def training_args(*args, **kwargs):
+    from transformers import TrainingArguments
+    training_args = TrainingArguments(*args, **kwargs)
+    return training_args
+
+
+def train_model(model, tokenizer, training_args, tsv_file):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    trainer = Trainer(
+        model,
+        training_args,
+        train_dataset=tokenized_dataset["train"],
+        tokenizer=tokenizer
+    )
+
+    trainer.train()
+
+def find_tokens_in_input(dataset, token_ids):
+    position_rows = []
+
+    for row in dataset:
+        input_ids = row["input_ids"]
+
+        if not hasattr(token_ids, "__len__"):
+            token_ids = [token_ids]
+
+        positions = []
+        for token_id in token_ids:
+
+            item_positions = []
+            for i in range(len(input_ids)):
+                if input_ids[i] == token_id:
+                    item_positions.append(i)
+
+            positions.append(item_positions)
+
+
+        position_rows.append(positions)
+
+    return position_rows
+
+
+
+def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens=None):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    trainer = Trainer(
+        model,
+        training_args,
+        tokenizer=tokenizer
+    )
+
+    result = trainer.predict(test_dataset=tokenized_dataset["train"])
+    if locate_tokens is not None:
+        token_ids = tokenizer.convert_tokens_to_ids(locate_tokens)
+        token_positions = find_tokens_in_input(tokenized_dataset["train"], token_ids)
+        return dict(result=result, token_positions=token_positions)
+    else:
+        return result
+
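The new module wraps a handful of Hugging Face transformers workflows: resolving an Auto* class by task name, simple evaluation to logits, and TSV-driven fine-tuning and prediction via Trainer. Below is a minimal usage sketch from the Python side; the checkpoint name, output directory, and TSV file are illustrative assumptions, as is the import path (it presumes the gem's python/ directory is on PYTHONPATH):

# Illustrative only: checkpoint, paths, and file names are assumptions.
from rbbt_dm.huggingface import (
    load_model_and_tokenizer, eval_model, training_args,
    train_model, predict_model,
)

# 'SequenceClassification' resolves to transformers.AutoModelForSequenceClassification
model, tokenizer = load_model_and_tokenizer('SequenceClassification',
                                            'distilbert-base-uncased')

# Per-example logits (the default return value of eval_model)
print(eval_model(model, tokenizer, ["This is a test sentence"]))

# Fine-tune on a tab-separated file; tsv_dataset expects a 'text' column,
# and training additionally needs a label column.
args = training_args('/tmp/hf_run', num_train_epochs=1,
                     per_device_train_batch_size=8)
train_model(model, tokenizer, args, 'examples.tsv')

# Predict over the same split; passing locate_tokens would also report the
# positions of those tokens in each tokenized input.
print(predict_model(model, tokenizer, args, 'examples.tsv'))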
metadata
CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: rbbt-dm
 version: !ruby/object:Gem::Version
-  version: 1.2.1
+  version: 1.2.3
 platform: ruby
 authors:
 - Miguel Vazquez
@@ -113,6 +113,7 @@ files:
 - lib/rbbt/vector/model/svm.rb
 - lib/rbbt/vector/model/tensorflow.rb
 - lib/rbbt/vector/model/util.rb
+- python/rbbt_dm/huggingface.py
 - share/R/MA.R
 - share/R/barcode.R
 - share/R/heatmap.3.R