rbbt-dm 1.2.3 → 1.2.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model/huggingface.rb +3 -5
- data/python/rbbt_dm/__init__.py +1 -0
- data/python/rbbt_dm/huggingface.py +31 -8
- data/test/rbbt/vector/model/test_huggingface.rb +16 -1
- metadata +16 -15
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
|
4
|
+
data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
|
7
|
+
data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
|
@@ -1,11 +1,12 @@
|
|
1
1
|
require 'rbbt/vector/model'
|
2
2
|
require 'rbbt/util/python'
|
3
3
|
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
4
5
|
RbbtPython.init_rbbt
|
5
6
|
|
6
7
|
class HuggingfaceModel < VectorModel
|
7
8
|
|
8
|
-
attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
|
9
|
+
attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
|
9
10
|
|
10
11
|
def tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
11
12
|
|
@@ -48,7 +49,7 @@ class HuggingfaceModel < VectorModel
|
|
48
49
|
|
49
50
|
if labels
|
50
51
|
training_args = call_method(:training_args, output_dir)
|
51
|
-
call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
|
52
|
+
call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
|
52
53
|
else
|
53
54
|
if Array === elements
|
54
55
|
training_args = call_method(:training_args, output_dir)
|
@@ -135,6 +136,3 @@ class HuggingfaceModel < VectorModel
|
|
135
136
|
end
|
136
137
|
end
|
137
138
|
|
138
|
-
if __FILE__ == $0
|
139
|
-
|
140
|
-
end
|
@@ -0,0 +1 @@
|
|
1
|
+
# Keep
|
@@ -51,17 +51,40 @@ def training_args(*args, **kwargs):
|
|
51
51
|
return training_args
|
52
52
|
|
53
53
|
|
54
|
-
def train_model(model, tokenizer, training_args, tsv_file):
|
54
|
+
def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
55
55
|
from transformers import Trainer
|
56
56
|
|
57
57
|
tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
|
58
58
|
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
59
|
+
if (not class_weights == None):
|
60
|
+
import torch
|
61
|
+
from torch import nn
|
62
|
+
|
63
|
+
class WeightTrainer(Trainer):
|
64
|
+
def compute_loss(self, model, inputs, return_outputs=False):
|
65
|
+
labels = inputs.get("labels")
|
66
|
+
# forward pass
|
67
|
+
outputs = model(**inputs)
|
68
|
+
logits = outputs.get('logits')
|
69
|
+
# compute custom loss
|
70
|
+
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
|
71
|
+
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
72
|
+
return (loss, outputs) if return_outputs else loss
|
73
|
+
|
74
|
+
trainer = WeightTrainer(
|
75
|
+
model,
|
76
|
+
training_args,
|
77
|
+
train_dataset = tokenized_dataset["train"],
|
78
|
+
tokenizer = tokenizer
|
79
|
+
)
|
80
|
+
else:
|
81
|
+
|
82
|
+
trainer = Trainer(
|
83
|
+
model,
|
84
|
+
training_args,
|
85
|
+
train_dataset = tokenized_dataset["train"],
|
86
|
+
tokenizer = tokenizer
|
87
|
+
)
|
65
88
|
|
66
89
|
trainer.train()
|
67
90
|
|
@@ -90,7 +113,6 @@ def find_tokens_in_input(dataset, token_ids):
|
|
90
113
|
return position_rows
|
91
114
|
|
92
115
|
|
93
|
-
|
94
116
|
def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
|
95
117
|
from transformers import Trainer
|
96
118
|
|
@@ -110,3 +132,4 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
|
|
110
132
|
else:
|
111
133
|
return result
|
112
134
|
|
135
|
+
|
@@ -3,6 +3,22 @@ require 'rbbt/vector/model/huggingface'
|
|
3
3
|
|
4
4
|
class TestHuggingface < Test::Unit::TestCase
|
5
5
|
|
6
|
+
def test_pipeline
|
7
|
+
require 'rbbt/util/python'
|
8
|
+
model = VectorModel.new
|
9
|
+
model.post_process do |elements|
|
10
|
+
elements.collect{|e| e['label'] }
|
11
|
+
end
|
12
|
+
model.eval_model do |file, elements|
|
13
|
+
RbbtPython.run :transformers do
|
14
|
+
classifier ||= transformers.pipeline("sentiment-analysis")
|
15
|
+
classifier.call(elements)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
|
19
|
+
assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
|
20
|
+
end
|
21
|
+
|
6
22
|
def test_sst_eval
|
7
23
|
TmpFile.with_file do |dir|
|
8
24
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
@@ -12,7 +28,6 @@ class TestHuggingface < Test::Unit::TestCase
|
|
12
28
|
model.class_labels = ["Bad", "Good"]
|
13
29
|
|
14
30
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
15
|
-
|
16
31
|
end
|
17
32
|
end
|
18
33
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rbbt-dm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.3
|
4
|
+
version: 1.2.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Miguel Vazquez
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-02-
|
11
|
+
date: 2023-02-07 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rbbt-util
|
@@ -113,6 +113,7 @@ files:
|
|
113
113
|
- lib/rbbt/vector/model/svm.rb
|
114
114
|
- lib/rbbt/vector/model/tensorflow.rb
|
115
115
|
- lib/rbbt/vector/model/util.rb
|
116
|
+
- python/rbbt_dm/__init__.py
|
116
117
|
- python/rbbt_dm/huggingface.py
|
117
118
|
- share/R/MA.R
|
118
119
|
- share/R/barcode.R
|
@@ -142,7 +143,7 @@ files:
|
|
142
143
|
homepage: http://github.com/mikisvaz/rbbt-phgx
|
143
144
|
licenses: []
|
144
145
|
metadata: {}
|
145
|
-
post_install_message:
|
146
|
+
post_install_message:
|
146
147
|
rdoc_options: []
|
147
148
|
require_paths:
|
148
149
|
- lib
|
@@ -157,22 +158,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
157
158
|
- !ruby/object:Gem::Version
|
158
159
|
version: '0'
|
159
160
|
requirements: []
|
160
|
-
rubygems_version: 3.1.
|
161
|
-
signing_key:
|
161
|
+
rubygems_version: 3.1.6
|
162
|
+
signing_key:
|
162
163
|
specification_version: 4
|
163
164
|
summary: Data-mining and statistics
|
164
165
|
test_files:
|
165
|
-
- test/
|
166
|
+
- test/rbbt/statistics/test_hypergeometric.rb
|
167
|
+
- test/rbbt/statistics/test_fisher.rb
|
168
|
+
- test/rbbt/statistics/test_fdr.rb
|
169
|
+
- test/rbbt/statistics/test_random_walk.rb
|
170
|
+
- test/rbbt/test_ml_task.rb
|
166
171
|
- test/rbbt/vector/test_model.rb
|
167
|
-
- test/rbbt/vector/model/test_huggingface.rb
|
168
172
|
- test/rbbt/vector/model/test_tensorflow.rb
|
169
173
|
- test/rbbt/vector/model/test_spaCy.rb
|
174
|
+
- test/rbbt/vector/model/test_huggingface.rb
|
170
175
|
- test/rbbt/vector/model/test_svm.rb
|
171
|
-
- test/rbbt/statistics/test_random_walk.rb
|
172
|
-
- test/rbbt/statistics/test_fisher.rb
|
173
|
-
- test/rbbt/statistics/test_fdr.rb
|
174
|
-
- test/rbbt/statistics/test_hypergeometric.rb
|
175
|
-
- test/rbbt/test_stan.rb
|
176
|
-
- test/rbbt/matrix/test_barcode.rb
|
177
|
-
- test/rbbt/test_ml_task.rb
|
178
176
|
- test/rbbt/network/test_paths.rb
|
177
|
+
- test/rbbt/matrix/test_barcode.rb
|
178
|
+
- test/rbbt/test_stan.rb
|
179
|
+
- test/test_helper.rb
|