rbbt-dm 1.2.1 → 1.2.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model/huggingface.rb +3 -5
- data/lib/rbbt/vector/model.rb +1 -1
- data/python/rbbt_dm/__init__.py +1 -0
- data/python/rbbt_dm/huggingface.py +135 -0
- data/test/rbbt/vector/model/test_huggingface.rb +16 -1
- metadata +17 -15
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
|
4
|
+
data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
|
7
|
+
data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
|
@@ -1,11 +1,12 @@
|
|
1
1
|
require 'rbbt/vector/model'
|
2
2
|
require 'rbbt/util/python'
|
3
3
|
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
4
5
|
RbbtPython.init_rbbt
|
5
6
|
|
6
7
|
class HuggingfaceModel < VectorModel
|
7
8
|
|
8
|
-
attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
|
9
|
+
attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
|
9
10
|
|
10
11
|
def tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
11
12
|
|
@@ -48,7 +49,7 @@ class HuggingfaceModel < VectorModel
|
|
48
49
|
|
49
50
|
if labels
|
50
51
|
training_args = call_method(:training_args, output_dir)
|
51
|
-
call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
|
52
|
+
call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
|
52
53
|
else
|
53
54
|
if Array === elements
|
54
55
|
training_args = call_method(:training_args, output_dir)
|
@@ -135,6 +136,3 @@ class HuggingfaceModel < VectorModel
|
|
135
136
|
end
|
136
137
|
end
|
137
138
|
|
138
|
-
if __FILE__ == $0
|
139
|
-
|
140
|
-
end
|
data/lib/rbbt/vector/model.rb
CHANGED
@@ -0,0 +1 @@
|
|
1
|
+
# Keep
|
@@ -0,0 +1,135 @@
|
|
1
|
+
#{{{ LOAD MODEL
|
2
|
+
|
3
|
+
def import_module_class(module, class_name):
|
4
|
+
exec(f"from {module} import {class_name}")
|
5
|
+
return eval(class_name)
|
6
|
+
|
7
|
+
def load_model(task, checkpoint):
|
8
|
+
class_name = 'AutoModelFor' + task
|
9
|
+
return import_module_class('transformers', class_name).from_pretrained(checkpoint)
|
10
|
+
|
11
|
+
def load_tokenizer(task, checkpoint):
|
12
|
+
class_name = 'AutoTokenizer'
|
13
|
+
return import_module_class('transformers', class_name).from_pretrained(checkpoint)
|
14
|
+
|
15
|
+
def load_model_and_tokenizer(task, checkpoint):
|
16
|
+
model = load_model(task, checkpoint)
|
17
|
+
tokenizer = load_tokenizer(task, checkpoint)
|
18
|
+
return model, tokenizer
|
19
|
+
|
20
|
+
#{{{ SIMPLE EVALUATE
|
21
|
+
|
22
|
+
def forward(model, features):
|
23
|
+
return model(**features)
|
24
|
+
|
25
|
+
def logits(predictions):
|
26
|
+
logits = predictions["logits"]
|
27
|
+
return [v.detach().cpu().numpy() for v in logits]
|
28
|
+
|
29
|
+
def eval_model(model, tokenizer, texts, return_logits = True):
|
30
|
+
features = tokenizer(texts, return_tensors='pt', truncation=True).to(model.device)
|
31
|
+
predictions = forward(model, features)
|
32
|
+
if (return_logits):
|
33
|
+
return logits(predictions)
|
34
|
+
else:
|
35
|
+
return predictions
|
36
|
+
|
37
|
+
#{{{ TRAIN AND PREDICT
|
38
|
+
|
39
|
+
def load_tsv(tsv_file):
|
40
|
+
from datasets import load_dataset
|
41
|
+
return load_dataset('csv', data_files=[tsv_file], sep="\t")
|
42
|
+
|
43
|
+
def tsv_dataset(tokenizer, tsv_file):
|
44
|
+
dataset = load_tsv(tsv_file)
|
45
|
+
tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True) , batched=True)
|
46
|
+
return tokenized_dataset
|
47
|
+
|
48
|
+
def training_args(*args, **kwargs):
|
49
|
+
from transformers import TrainingArguments
|
50
|
+
training_args = TrainingArguments(*args, **kwargs)
|
51
|
+
return training_args
|
52
|
+
|
53
|
+
|
54
|
+
def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
55
|
+
from transformers import Trainer
|
56
|
+
|
57
|
+
tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
|
58
|
+
|
59
|
+
if (not class_weights == None):
|
60
|
+
import torch
|
61
|
+
from torch import nn
|
62
|
+
|
63
|
+
class WeightTrainer(Trainer):
|
64
|
+
def compute_loss(self, model, inputs, return_outputs=False):
|
65
|
+
labels = inputs.get("labels")
|
66
|
+
# forward pass
|
67
|
+
outputs = model(**inputs)
|
68
|
+
logits = outputs.get('logits')
|
69
|
+
# compute custom loss
|
70
|
+
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
|
71
|
+
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
72
|
+
return (loss, outputs) if return_outputs else loss
|
73
|
+
|
74
|
+
trainer = WeightTrainer(
|
75
|
+
model,
|
76
|
+
training_args,
|
77
|
+
train_dataset = tokenized_dataset["train"],
|
78
|
+
tokenizer = tokenizer
|
79
|
+
)
|
80
|
+
else:
|
81
|
+
|
82
|
+
trainer = Trainer(
|
83
|
+
model,
|
84
|
+
training_args,
|
85
|
+
train_dataset = tokenized_dataset["train"],
|
86
|
+
tokenizer = tokenizer
|
87
|
+
)
|
88
|
+
|
89
|
+
trainer.train()
|
90
|
+
|
91
|
+
def find_tokens_in_input(dataset, token_ids):
|
92
|
+
position_rows = []
|
93
|
+
|
94
|
+
for row in dataset:
|
95
|
+
input_ids = row["input_ids"]
|
96
|
+
|
97
|
+
if (not hasattr(token_ids, "__len__")):
|
98
|
+
token_ids = [token_ids]
|
99
|
+
|
100
|
+
positions = []
|
101
|
+
for token_id in token_ids:
|
102
|
+
|
103
|
+
item_positions = []
|
104
|
+
for i in range(len(input_ids)):
|
105
|
+
if input_ids[i] == token_id:
|
106
|
+
item_positions.append(i)
|
107
|
+
|
108
|
+
positions.append(item_positions)
|
109
|
+
|
110
|
+
|
111
|
+
position_rows.append(positions)
|
112
|
+
|
113
|
+
return position_rows
|
114
|
+
|
115
|
+
|
116
|
+
def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
|
117
|
+
from transformers import Trainer
|
118
|
+
|
119
|
+
tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
|
120
|
+
|
121
|
+
trainer = Trainer(
|
122
|
+
model,
|
123
|
+
training_args,
|
124
|
+
tokenizer = tokenizer
|
125
|
+
)
|
126
|
+
|
127
|
+
result = trainer.predict(test_dataset = tokenized_dataset["train"])
|
128
|
+
if (locate_tokens != None):
|
129
|
+
token_ids = tokenizer.convert_tokens_to_ids(locate_tokens)
|
130
|
+
token_positions = find_tokens_in_input(tokenized_dataset["train"], token_ids)
|
131
|
+
return dict(result=result, token_positions=token_positions)
|
132
|
+
else:
|
133
|
+
return result
|
134
|
+
|
135
|
+
|
@@ -3,6 +3,22 @@ require 'rbbt/vector/model/huggingface'
|
|
3
3
|
|
4
4
|
class TestHuggingface < Test::Unit::TestCase
|
5
5
|
|
6
|
+
def test_pipeline
|
7
|
+
require 'rbbt/util/python'
|
8
|
+
model = VectorModel.new
|
9
|
+
model.post_process do |elements|
|
10
|
+
elements.collect{|e| e['label'] }
|
11
|
+
end
|
12
|
+
model.eval_model do |file, elements|
|
13
|
+
RbbtPython.run :transformers do
|
14
|
+
classifier ||= transformers.pipeline("sentiment-analysis")
|
15
|
+
classifier.call(elements)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
|
19
|
+
assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
|
20
|
+
end
|
21
|
+
|
6
22
|
def test_sst_eval
|
7
23
|
TmpFile.with_file do |dir|
|
8
24
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
@@ -12,7 +28,6 @@ class TestHuggingface < Test::Unit::TestCase
|
|
12
28
|
model.class_labels = ["Bad", "Good"]
|
13
29
|
|
14
30
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
15
|
-
|
16
31
|
end
|
17
32
|
end
|
18
33
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rbbt-dm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.
|
4
|
+
version: 1.2.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Miguel Vazquez
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-02-
|
11
|
+
date: 2023-02-07 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rbbt-util
|
@@ -113,6 +113,8 @@ files:
|
|
113
113
|
- lib/rbbt/vector/model/svm.rb
|
114
114
|
- lib/rbbt/vector/model/tensorflow.rb
|
115
115
|
- lib/rbbt/vector/model/util.rb
|
116
|
+
- python/rbbt_dm/__init__.py
|
117
|
+
- python/rbbt_dm/huggingface.py
|
116
118
|
- share/R/MA.R
|
117
119
|
- share/R/barcode.R
|
118
120
|
- share/R/heatmap.3.R
|
@@ -141,7 +143,7 @@ files:
|
|
141
143
|
homepage: http://github.com/mikisvaz/rbbt-phgx
|
142
144
|
licenses: []
|
143
145
|
metadata: {}
|
144
|
-
post_install_message:
|
146
|
+
post_install_message:
|
145
147
|
rdoc_options: []
|
146
148
|
require_paths:
|
147
149
|
- lib
|
@@ -156,22 +158,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
156
158
|
- !ruby/object:Gem::Version
|
157
159
|
version: '0'
|
158
160
|
requirements: []
|
159
|
-
rubygems_version: 3.1.
|
160
|
-
signing_key:
|
161
|
+
rubygems_version: 3.1.6
|
162
|
+
signing_key:
|
161
163
|
specification_version: 4
|
162
164
|
summary: Data-mining and statistics
|
163
165
|
test_files:
|
164
|
-
- test/
|
166
|
+
- test/rbbt/statistics/test_hypergeometric.rb
|
167
|
+
- test/rbbt/statistics/test_fisher.rb
|
168
|
+
- test/rbbt/statistics/test_fdr.rb
|
169
|
+
- test/rbbt/statistics/test_random_walk.rb
|
170
|
+
- test/rbbt/test_ml_task.rb
|
165
171
|
- test/rbbt/vector/test_model.rb
|
166
|
-
- test/rbbt/vector/model/test_huggingface.rb
|
167
172
|
- test/rbbt/vector/model/test_tensorflow.rb
|
168
173
|
- test/rbbt/vector/model/test_spaCy.rb
|
174
|
+
- test/rbbt/vector/model/test_huggingface.rb
|
169
175
|
- test/rbbt/vector/model/test_svm.rb
|
170
|
-
- test/rbbt/statistics/test_random_walk.rb
|
171
|
-
- test/rbbt/statistics/test_fisher.rb
|
172
|
-
- test/rbbt/statistics/test_fdr.rb
|
173
|
-
- test/rbbt/statistics/test_hypergeometric.rb
|
174
|
-
- test/rbbt/test_stan.rb
|
175
|
-
- test/rbbt/matrix/test_barcode.rb
|
176
|
-
- test/rbbt/test_ml_task.rb
|
177
176
|
- test/rbbt/network/test_paths.rb
|
177
|
+
- test/rbbt/matrix/test_barcode.rb
|
178
|
+
- test/rbbt/test_stan.rb
|
179
|
+
- test/test_helper.rb
|