rbbt-dm 1.2.3 → 1.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model/huggingface.rb +3 -5
- data/python/rbbt_dm/__init__.py +1 -0
- data/python/rbbt_dm/huggingface.py +31 -8
- data/test/rbbt/vector/model/test_huggingface.rb +16 -1
- metadata +16 -15
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
|
4
|
+
data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
|
7
|
+
data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
|
data/lib/rbbt/vector/model/huggingface.rb
CHANGED
@@ -1,11 +1,12 @@
|
|
1
1
|
require 'rbbt/vector/model'
|
2
2
|
require 'rbbt/util/python'
|
3
3
|
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
4
5
|
RbbtPython.init_rbbt
|
5
6
|
|
6
7
|
class HuggingfaceModel < VectorModel
|
7
8
|
|
8
|
-
attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
|
9
|
+
attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
|
9
10
|
|
10
11
|
def tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
11
12
|
|
@@ -48,7 +49,7 @@ class HuggingfaceModel < VectorModel
|
|
48
49
|
|
49
50
|
if labels
|
50
51
|
training_args = call_method(:training_args, output_dir)
|
51
|
-
call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
|
52
|
+
call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
|
52
53
|
else
|
53
54
|
if Array === elements
|
54
55
|
training_args = call_method(:training_args, output_dir)
|
@@ -135,6 +136,3 @@ class HuggingfaceModel < VectorModel
|
|
135
136
|
end
|
136
137
|
end
|
137
138
|
|
138
|
-
if __FILE__ == $0
|
139
|
-
|
140
|
-
end
|
data/python/rbbt_dm/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
# Keep
|
data/python/rbbt_dm/huggingface.py
CHANGED
@@ -51,17 +51,40 @@ def training_args(*args, **kwargs):
|
|
51
51
|
return training_args
|
52
52
|
|
53
53
|
|
54
|
-
def train_model(model, tokenizer, training_args, tsv_file):
|
54
|
+
def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
55
55
|
from transformers import Trainer
|
56
56
|
|
57
57
|
tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
|
58
58
|
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
59
|
+
if (not class_weights == None):
|
60
|
+
import torch
|
61
|
+
from torch import nn
|
62
|
+
|
63
|
+
class WeightTrainer(Trainer):
|
64
|
+
def compute_loss(self, model, inputs, return_outputs=False):
|
65
|
+
labels = inputs.get("labels")
|
66
|
+
# forward pass
|
67
|
+
outputs = model(**inputs)
|
68
|
+
logits = outputs.get('logits')
|
69
|
+
# compute custom loss
|
70
|
+
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
|
71
|
+
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
72
|
+
return (loss, outputs) if return_outputs else loss
|
73
|
+
|
74
|
+
trainer = WeightTrainer(
|
75
|
+
model,
|
76
|
+
training_args,
|
77
|
+
train_dataset = tokenized_dataset["train"],
|
78
|
+
tokenizer = tokenizer
|
79
|
+
)
|
80
|
+
else:
|
81
|
+
|
82
|
+
trainer = Trainer(
|
83
|
+
model,
|
84
|
+
training_args,
|
85
|
+
train_dataset = tokenized_dataset["train"],
|
86
|
+
tokenizer = tokenizer
|
87
|
+
)
|
65
88
|
|
66
89
|
trainer.train()
|
67
90
|
|
@@ -90,7 +113,6 @@ def find_tokens_in_input(dataset, token_ids):
|
|
90
113
|
return position_rows
|
91
114
|
|
92
115
|
|
93
|
-
|
94
116
|
def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
|
95
117
|
from transformers import Trainer
|
96
118
|
|
@@ -110,3 +132,4 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
|
|
110
132
|
else:
|
111
133
|
return result
|
112
134
|
|
135
|
+
|
data/test/rbbt/vector/model/test_huggingface.rb
CHANGED
@@ -3,6 +3,22 @@ require 'rbbt/vector/model/huggingface'
|
|
3
3
|
|
4
4
|
class TestHuggingface < Test::Unit::TestCase
|
5
5
|
|
6
|
+
def test_pipeline
|
7
|
+
require 'rbbt/util/python'
|
8
|
+
model = VectorModel.new
|
9
|
+
model.post_process do |elements|
|
10
|
+
elements.collect{|e| e['label'] }
|
11
|
+
end
|
12
|
+
model.eval_model do |file, elements|
|
13
|
+
RbbtPython.run :transformers do
|
14
|
+
classifier ||= transformers.pipeline("sentiment-analysis")
|
15
|
+
classifier.call(elements)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
|
19
|
+
assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
|
20
|
+
end
|
21
|
+
|
6
22
|
def test_sst_eval
|
7
23
|
TmpFile.with_file do |dir|
|
8
24
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
@@ -12,7 +28,6 @@ class TestHuggingface < Test::Unit::TestCase
|
|
12
28
|
model.class_labels = ["Bad", "Good"]
|
13
29
|
|
14
30
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
15
|
-
|
16
31
|
end
|
17
32
|
end
|
18
33
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rbbt-dm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.3
|
4
|
+
version: 1.2.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Miguel Vazquez
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-02-
|
11
|
+
date: 2023-02-07 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rbbt-util
|
@@ -113,6 +113,7 @@ files:
|
|
113
113
|
- lib/rbbt/vector/model/svm.rb
|
114
114
|
- lib/rbbt/vector/model/tensorflow.rb
|
115
115
|
- lib/rbbt/vector/model/util.rb
|
116
|
+
- python/rbbt_dm/__init__.py
|
116
117
|
- python/rbbt_dm/huggingface.py
|
117
118
|
- share/R/MA.R
|
118
119
|
- share/R/barcode.R
|
@@ -142,7 +143,7 @@ files:
|
|
142
143
|
homepage: http://github.com/mikisvaz/rbbt-phgx
|
143
144
|
licenses: []
|
144
145
|
metadata: {}
|
145
|
-
post_install_message:
|
146
|
+
post_install_message:
|
146
147
|
rdoc_options: []
|
147
148
|
require_paths:
|
148
149
|
- lib
|
@@ -157,22 +158,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
157
158
|
- !ruby/object:Gem::Version
|
158
159
|
version: '0'
|
159
160
|
requirements: []
|
160
|
-
rubygems_version: 3.1.
|
161
|
-
signing_key:
|
161
|
+
rubygems_version: 3.1.6
|
162
|
+
signing_key:
|
162
163
|
specification_version: 4
|
163
164
|
summary: Data-mining and statistics
|
164
165
|
test_files:
|
165
|
-
- test/test_helper.rb
|
166
|
+
- test/rbbt/statistics/test_hypergeometric.rb
|
167
|
+
- test/rbbt/statistics/test_fisher.rb
|
168
|
+
- test/rbbt/statistics/test_fdr.rb
|
169
|
+
- test/rbbt/statistics/test_random_walk.rb
|
170
|
+
- test/rbbt/test_ml_task.rb
|
166
171
|
- test/rbbt/vector/test_model.rb
|
167
|
-
- test/rbbt/vector/model/test_huggingface.rb
|
168
172
|
- test/rbbt/vector/model/test_tensorflow.rb
|
169
173
|
- test/rbbt/vector/model/test_spaCy.rb
|
174
|
+
- test/rbbt/vector/model/test_huggingface.rb
|
170
175
|
- test/rbbt/vector/model/test_svm.rb
|
171
|
-
- test/rbbt/statistics/test_random_walk.rb
|
172
|
-
- test/rbbt/statistics/test_fisher.rb
|
173
|
-
- test/rbbt/statistics/test_fdr.rb
|
174
|
-
- test/rbbt/statistics/test_hypergeometric.rb
|
175
|
-
- test/rbbt/test_stan.rb
|
176
|
-
- test/rbbt/matrix/test_barcode.rb
|
177
|
-
- test/rbbt/test_ml_task.rb
|
178
176
|
- test/rbbt/network/test_paths.rb
|
177
|
+
- test/rbbt/matrix/test_barcode.rb
|
178
|
+
- test/rbbt/test_stan.rb
|
179
|
+
- test/test_helper.rb
|