rbbt-dm 1.2.3 → 1.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model/huggingface.old.rb +160 -0
- data/lib/rbbt/vector/model/huggingface.rb +68 -47
- data/lib/rbbt/vector/model.rb +36 -34
- data/python/rbbt_dm/__init__.py +1 -0
- data/python/rbbt_dm/huggingface.py +42 -8
- data/test/rbbt/vector/model/test_huggingface.rb +47 -10
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9744ab9faeaf4f9cc04947eb11103dbf0694dda624f805a5c6be27bb22af81ce
|
4
|
+
data.tar.gz: d3a3903aa276a69e20cbd71213286449db396ecf5f6a4b4d80a64ab299041fbb
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 263fb609b37522874426bcd79374760399b4a9aaab443ae6d74c727f2d148474dd71ee0b2cfda7a50131dafbc314f66352f4285562a75f62144d2e05ccd214c7
|
7
|
+
data.tar.gz: 1e0426429a38028a19b3f8c955e975138199c791dad8691de7fb760a5cbec3304f19341906a4457563a90de965ee2f12a5a639b928866b46258af2507eeb39fa
|
@@ -0,0 +1,160 @@
|
|
1
|
+
require 'rbbt/vector/model'
|
2
|
+
require 'rbbt/util/python'
|
3
|
+
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
5
|
+
RbbtPython.init_rbbt
|
6
|
+
|
7
|
+
class HuggingfaceModel < VectorModel
|
8
|
+
|
9
|
+
attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args
|
10
|
+
|
11
|
+
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
12
|
+
|
13
|
+
if labels
|
14
|
+
Open.write(tsv_dataset_file) do |ffile|
|
15
|
+
ffile.puts ["label", "text"].flatten * "\t"
|
16
|
+
elements.zip(labels).each do |element,label|
|
17
|
+
ffile.puts [label, element].flatten * "\t"
|
18
|
+
end
|
19
|
+
end
|
20
|
+
else
|
21
|
+
Open.write(tsv_dataset_file) do |ffile|
|
22
|
+
ffile.puts ["text"].flatten * "\t"
|
23
|
+
elements.each{|element| ffile.puts element }
|
24
|
+
end
|
25
|
+
end
|
26
|
+
|
27
|
+
tsv_dataset_file
|
28
|
+
end
|
29
|
+
|
30
|
+
def self.call_method(name, *args)
|
31
|
+
RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
|
32
|
+
end
|
33
|
+
|
34
|
+
def call_method(name, *args)
|
35
|
+
HuggingfaceModel.call_method(name, *args)
|
36
|
+
end
|
37
|
+
|
38
|
+
#def input_tsv_file
|
39
|
+
# File.join(@directory, 'dataset.tsv') if @directory
|
40
|
+
#end
|
41
|
+
|
42
|
+
#def checkpoint_dir
|
43
|
+
# File.join(@directory, 'checkpoints') if @directory
|
44
|
+
#end
|
45
|
+
|
46
|
+
def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
|
47
|
+
TmpFile.with_file do |tmpfile|
|
48
|
+
tsv_file = File.join(tmpfile, 'dataset.tsv')
|
49
|
+
|
50
|
+
if training_args
|
51
|
+
training_args = training_args.dup
|
52
|
+
checkpoint_dir = training_args.delete(:checkpoint_dir)
|
53
|
+
end
|
54
|
+
|
55
|
+
checkpoint_dir = File.join(tmpfile, 'checkpoints')
|
56
|
+
|
57
|
+
Open.mkdir File.dirname(tsv_file)
|
58
|
+
Open.mkdir File.dirname(checkpoint_dir)
|
59
|
+
|
60
|
+
if labels
|
61
|
+
training_args_obj = call_method(:training_args, checkpoint_dir, **training_args)
|
62
|
+
call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
|
63
|
+
else
|
64
|
+
locate_tokens, training_args = training_args, {}
|
65
|
+
if Array === elements
|
66
|
+
training_args_obj = call_method(:training_args, checkpoint_dir)
|
67
|
+
call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
|
68
|
+
else
|
69
|
+
call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
|
70
|
+
end
|
71
|
+
end
|
72
|
+
end
|
73
|
+
end
|
74
|
+
|
75
|
+
def init_model
|
76
|
+
@model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
|
77
|
+
end
|
78
|
+
|
79
|
+
def reset_model
|
80
|
+
init_model
|
81
|
+
end
|
82
|
+
|
83
|
+
def initialize(task, initial_checkpoint = nil, *args)
|
84
|
+
super(*args)
|
85
|
+
@task = task
|
86
|
+
|
87
|
+
@checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
|
88
|
+
|
89
|
+
init_model
|
90
|
+
|
91
|
+
@locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
|
92
|
+
|
93
|
+
@training_args = {}
|
94
|
+
|
95
|
+
train_model do |file,elements,labels|
|
96
|
+
HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)
|
97
|
+
|
98
|
+
@model.save_pretrained(file) if file
|
99
|
+
@tokenizer.save_pretrained(file) if file
|
100
|
+
end
|
101
|
+
|
102
|
+
eval_model do |file,elements|
|
103
|
+
@model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
|
104
|
+
HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
|
105
|
+
end
|
106
|
+
|
107
|
+
post_process do |result|
|
108
|
+
if result.respond_to?(:predictions)
|
109
|
+
single = false
|
110
|
+
predictions = result.predictions
|
111
|
+
elsif result["token_positions"]
|
112
|
+
predictions = result["result"].predictions
|
113
|
+
token_positions = result["token_positions"]
|
114
|
+
else
|
115
|
+
single = true
|
116
|
+
predictions = result["logits"]
|
117
|
+
end
|
118
|
+
|
119
|
+
result = case @task
|
120
|
+
when "SequenceClassification"
|
121
|
+
RbbtPython.collect(predictions) do |logits|
|
122
|
+
logits = RbbtPython.numpy2ruby logits
|
123
|
+
best_class = logits.index logits.max
|
124
|
+
best_class = @class_labels[best_class] if @class_labels
|
125
|
+
best_class
|
126
|
+
end
|
127
|
+
when "MaskedLM"
|
128
|
+
all_token_positions = token_positions.to_a
|
129
|
+
|
130
|
+
i = 0
|
131
|
+
RbbtPython.collect(predictions) do |item_logits|
|
132
|
+
item_token_positions = all_token_positions[i]
|
133
|
+
i += 1
|
134
|
+
|
135
|
+
item_logits = RbbtPython.numpy2ruby(item_logits)
|
136
|
+
item_masks = item_token_positions.collect do |token_positions|
|
137
|
+
|
138
|
+
best = item_logits.values_at(*token_positions).collect do |logits|
|
139
|
+
best_token, best_score = nil
|
140
|
+
logits.each_with_index do |v,i|
|
141
|
+
if best_score.nil? || v > best_score
|
142
|
+
best_token, best_score = i, v
|
143
|
+
end
|
144
|
+
end
|
145
|
+
best_token
|
146
|
+
end
|
147
|
+
|
148
|
+
best.collect{|b| @tokenizer.decode(b) } * "|"
|
149
|
+
end
|
150
|
+
Array === @locate_tokens ? item_masks : item_masks.first
|
151
|
+
end
|
152
|
+
else
|
153
|
+
logits
|
154
|
+
end
|
155
|
+
|
156
|
+
single ? result.first : result
|
157
|
+
end
|
158
|
+
end
|
159
|
+
end
|
160
|
+
|
@@ -1,13 +1,12 @@
|
|
1
1
|
require 'rbbt/vector/model'
|
2
2
|
require 'rbbt/util/python'
|
3
3
|
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
4
5
|
RbbtPython.init_rbbt
|
5
6
|
|
6
7
|
class HuggingfaceModel < VectorModel
|
7
8
|
|
8
|
-
|
9
|
-
|
10
|
-
def tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
9
|
+
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
11
10
|
|
12
11
|
if labels
|
13
12
|
Open.write(tsv_dataset_file) do |ffile|
|
@@ -26,59 +25,74 @@ class HuggingfaceModel < VectorModel
|
|
26
25
|
tsv_dataset_file
|
27
26
|
end
|
28
27
|
|
29
|
-
def
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
end
|
28
|
+
def initialize(task, checkpoint, *args)
|
29
|
+
options = args.pop if Hash === args.last
|
30
|
+
options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
|
31
|
+
super(*args)
|
32
|
+
@model_options ||= {}
|
33
|
+
@model_options.merge!(options)
|
36
34
|
|
37
|
-
|
38
|
-
|
39
|
-
end
|
35
|
+
eval_model do |directory,texts|
|
36
|
+
checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
|
40
37
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
38
|
+
if @model.nil?
|
39
|
+
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
40
|
+
end
|
41
|
+
|
42
|
+
if Array === texts
|
45
43
|
|
46
|
-
|
47
|
-
|
44
|
+
if @model_options.include?(:locate_tokens)
|
45
|
+
locate_tokens = @model_options[:locate_tokens]
|
46
|
+
elsif @model_options[:task] == "MaskedLM"
|
47
|
+
@model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
|
48
|
+
end
|
48
49
|
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
else
|
53
|
-
if Array === elements
|
54
|
-
training_args = call_method(:training_args, output_dir)
|
55
|
-
call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
|
50
|
+
if @directory
|
51
|
+
tsv_file = File.join(@directory, 'dataset.tsv')
|
52
|
+
checkpoint_dir = File.join(@directory, 'checkpoints')
|
56
53
|
else
|
57
|
-
|
54
|
+
tmpdir = TmpFile.tmp_file
|
55
|
+
Open.mkdir tmpdir
|
56
|
+
tsv_file = File.join(tmpdir, 'dataset.tsv')
|
57
|
+
checkpoint_dir = File.join(tmpdir, 'checkpoints')
|
58
|
+
end
|
59
|
+
|
60
|
+
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
|
61
|
+
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
62
|
+
|
63
|
+
begin
|
64
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
|
65
|
+
ensure
|
66
|
+
Open.rm_rf tmpdir if tmpdir
|
58
67
|
end
|
68
|
+
else
|
69
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
|
59
70
|
end
|
60
71
|
end
|
61
|
-
end
|
62
|
-
|
63
|
-
def initialize(task, initial_checkpoint = nil, *args)
|
64
|
-
super(*args)
|
65
|
-
@task = task
|
66
72
|
|
67
|
-
|
73
|
+
train_model do |directory,texts,labels|
|
74
|
+
checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
|
75
|
+
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
68
76
|
|
69
|
-
|
77
|
+
if @directory
|
78
|
+
tsv_file = File.join(@directory, 'dataset.tsv')
|
79
|
+
checkpoint_dir = File.join(@directory, 'checkpoints')
|
80
|
+
else
|
81
|
+
tmpdir = TmpFile.tmp_file
|
82
|
+
Open.mkdir tmpdir
|
83
|
+
tsv_file = File.join(tmpdir, 'dataset.tsv')
|
84
|
+
checkpoint_dir = File.join(tmpdir, 'checkpoints')
|
85
|
+
end
|
70
86
|
|
71
|
-
|
87
|
+
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
88
|
+
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
|
72
89
|
|
73
|
-
|
74
|
-
run_model(elements, labels)
|
90
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
|
75
91
|
|
76
|
-
|
77
|
-
@tokenizer.save_pretrained(file) if file
|
78
|
-
end
|
92
|
+
Open.rm_rf tmpdir if tmpdir
|
79
93
|
|
80
|
-
|
81
|
-
|
94
|
+
@model.save_pretrained(directory) if directory
|
95
|
+
@tokenizer.save_pretrained(directory) if directory
|
82
96
|
end
|
83
97
|
|
84
98
|
post_process do |result|
|
@@ -93,12 +107,13 @@ class HuggingfaceModel < VectorModel
|
|
93
107
|
predictions = result["logits"]
|
94
108
|
end
|
95
109
|
|
96
|
-
|
110
|
+
task, class_labels, locate_tokens = @model_options.values_at :task, :class_labels, :locate_tokens
|
111
|
+
result = case task
|
97
112
|
when "SequenceClassification"
|
98
113
|
RbbtPython.collect(predictions) do |logits|
|
99
114
|
logits = RbbtPython.numpy2ruby logits
|
100
115
|
best_class = logits.index logits.max
|
101
|
-
best_class =
|
116
|
+
best_class = class_labels[best_class] if class_labels
|
102
117
|
best_class
|
103
118
|
end
|
104
119
|
when "MaskedLM"
|
@@ -124,7 +139,7 @@ class HuggingfaceModel < VectorModel
|
|
124
139
|
|
125
140
|
best.collect{|b| @tokenizer.decode(b) } * "|"
|
126
141
|
end
|
127
|
-
Array ===
|
142
|
+
Array === locate_tokens ? item_masks : item_masks.first
|
128
143
|
end
|
129
144
|
else
|
130
145
|
logits
|
@@ -132,9 +147,15 @@ class HuggingfaceModel < VectorModel
|
|
132
147
|
|
133
148
|
single ? result.first : result
|
134
149
|
end
|
150
|
+
|
151
|
+
|
152
|
+
save_models if @directory
|
135
153
|
end
|
136
|
-
end
|
137
154
|
|
138
|
-
|
155
|
+
def reset_model
|
156
|
+
@model, @tokenizer = nil
|
157
|
+
Open.rm @model_file
|
158
|
+
end
|
139
159
|
|
140
160
|
end
|
161
|
+
|
data/lib/rbbt/vector/model.rb
CHANGED
@@ -4,6 +4,7 @@ require 'rbbt/vector/model/util'
|
|
4
4
|
class VectorModel
|
5
5
|
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
|
6
6
|
attr_accessor :features, :names, :labels, :factor_levels
|
7
|
+
attr_accessor :model_options
|
7
8
|
|
8
9
|
def extract_features(&block)
|
9
10
|
@extract_features = block if block_given?
|
@@ -126,7 +127,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
126
127
|
instance_eval code, file
|
127
128
|
end
|
128
129
|
|
129
|
-
def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
|
130
|
+
def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
|
130
131
|
@directory = directory
|
131
132
|
if @directory
|
132
133
|
FileUtils.mkdir_p @directory unless File.exists?(@directory)
|
@@ -135,10 +136,18 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
135
136
|
@extract_features_file = File.join(@directory, "features")
|
136
137
|
@train_model_file = File.join(@directory, "train_model")
|
137
138
|
@eval_model_file = File.join(@directory, "eval_model")
|
139
|
+
@post_process_file = File.join(@directory, "post_process")
|
138
140
|
@train_model_file_R = File.join(@directory, "train_model.R")
|
139
141
|
@eval_model_file_R = File.join(@directory, "eval_model.R")
|
142
|
+
@post_process_file_R = File.join(@directory, "post_process.R")
|
140
143
|
@names_file = File.join(@directory, "feature_names")
|
141
144
|
@levels_file = File.join(@directory, "levels")
|
145
|
+
@options_file = File.join(@directory, "options.json")
|
146
|
+
|
147
|
+
if File.exists?(@options_file)
|
148
|
+
@model_options = JSON.parse(Open.read(@options_file))
|
149
|
+
IndiferentHash.setup(@model_options)
|
150
|
+
end
|
142
151
|
end
|
143
152
|
|
144
153
|
if extract_features.nil?
|
@@ -169,6 +178,17 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
169
178
|
@eval_model = eval_model
|
170
179
|
end
|
171
180
|
|
181
|
+
if post_process.nil?
|
182
|
+
if @post_process_file && File.exists?(@post_process_file)
|
183
|
+
@post_process = __load_method @post_process_file
|
184
|
+
elsif @post_process_file_R && File.exists?(@post_process_file_R)
|
185
|
+
@post_process = Open.read(@post_process_file_R)
|
186
|
+
end
|
187
|
+
else
|
188
|
+
@post_process = post_process
|
189
|
+
end
|
190
|
+
|
191
|
+
|
172
192
|
if names.nil?
|
173
193
|
if @names_file && File.exists?(@names_file)
|
174
194
|
@names = Open.read(@names_file).split("\n")
|
@@ -240,8 +260,20 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
240
260
|
Open.write(@eval_model_file_R, eval_model)
|
241
261
|
end
|
242
262
|
|
263
|
+
case
|
264
|
+
when Proc === post_process
|
265
|
+
begin
|
266
|
+
Open.write(@post_process_file, post_process.source)
|
267
|
+
rescue
|
268
|
+
end
|
269
|
+
when String === post_process
|
270
|
+
Open.write(@post_process_file_R, post_process)
|
271
|
+
end
|
272
|
+
|
273
|
+
|
243
274
|
Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
|
244
275
|
Open.write(@names_file, @names * "\n" + "\n") if @names
|
276
|
+
Open.write(@options_file, @model_options.to_json) if @model_options
|
245
277
|
end
|
246
278
|
|
247
279
|
def train
|
@@ -251,7 +283,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
251
283
|
when String === @train_model
|
252
284
|
VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
|
253
285
|
end
|
254
|
-
save_models
|
286
|
+
save_models if @directory
|
255
287
|
end
|
256
288
|
|
257
289
|
def run(code)
|
@@ -299,38 +331,6 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
299
331
|
result
|
300
332
|
end
|
301
333
|
|
302
|
-
#def cross_validation(folds = 10)
|
303
|
-
# saved_features = @features
|
304
|
-
# saved_labels = @labels
|
305
|
-
# seq = (0..features.length - 1).to_a
|
306
|
-
|
307
|
-
# chunk_size = features.length / folds
|
308
|
-
|
309
|
-
# acc = []
|
310
|
-
# folds.times do
|
311
|
-
# seq = seq.shuffle
|
312
|
-
# eval_chunk = seq[0..chunk_size]
|
313
|
-
# train_chunk = seq[chunk_size.. -1]
|
314
|
-
|
315
|
-
# eval_features = @features.values_at *eval_chunk
|
316
|
-
# eval_labels = @labels.values_at *eval_chunk
|
317
|
-
|
318
|
-
# @features = @features.values_at *train_chunk
|
319
|
-
# @labels = @labels.values_at *train_chunk
|
320
|
-
|
321
|
-
# train
|
322
|
-
# predictions = eval_list eval_features, false
|
323
|
-
|
324
|
-
# acc << predictions.zip(eval_labels).collect{|pred,lab| pred - lab < 0.5 ? 1 : 0}.inject(0){|acc,e| acc +=e} / chunk_size
|
325
|
-
|
326
|
-
# @features = saved_features
|
327
|
-
# @labels = saved_labels
|
328
|
-
# end
|
329
|
-
|
330
|
-
# acc
|
331
|
-
#end
|
332
|
-
#
|
333
|
-
|
334
334
|
def self.f1_metrics(test, predicted, good_label = nil)
|
335
335
|
tp, tn, fp, fn, pr, re, f1 = [0, 0, 0, 0, nil, nil, nil]
|
336
336
|
|
@@ -413,6 +413,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
413
413
|
@features = train_set
|
414
414
|
@labels = train_labels
|
415
415
|
|
416
|
+
self.reset_model if self.respond_to? :reset_model
|
416
417
|
self.train
|
417
418
|
predictions = self.eval_list test_set, false
|
418
419
|
|
@@ -437,6 +438,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
437
438
|
@features = orig_features
|
438
439
|
@labels = orig_labels
|
439
440
|
end unless folds == -1
|
441
|
+
self.reset_model if self.respond_to? :reset_model
|
440
442
|
self.train unless folds == 1
|
441
443
|
res
|
442
444
|
end
|
@@ -0,0 +1 @@
|
|
1
|
+
# Keep
|
@@ -17,6 +17,17 @@ def load_model_and_tokenizer(task, checkpoint):
|
|
17
17
|
tokenizer = load_tokenizer(task, checkpoint)
|
18
18
|
return model, tokenizer
|
19
19
|
|
20
|
+
def load_model_and_tokenizer_from_directory(directory):
|
21
|
+
import os
|
22
|
+
import json
|
23
|
+
options_file = os.path.join(directory, 'options.json')
|
24
|
+
f = open(options_file, "r")
|
25
|
+
options = json.load(f.read())
|
26
|
+
f.close()
|
27
|
+
task = options["task"]
|
28
|
+
checkpoint = options["checkpoint"]
|
29
|
+
return load_model_and_tokenizer(task, checkpoint)
|
30
|
+
|
20
31
|
#{{{ SIMPLE EVALUATE
|
21
32
|
|
22
33
|
def forward(model, features):
|
@@ -51,17 +62,40 @@ def training_args(*args, **kwargs):
|
|
51
62
|
return training_args
|
52
63
|
|
53
64
|
|
54
|
-
def train_model(model, tokenizer, training_args, tsv_file):
|
65
|
+
def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
55
66
|
from transformers import Trainer
|
56
67
|
|
57
68
|
tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
|
58
69
|
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
)
|
70
|
+
if (not class_weights == None):
|
71
|
+
import torch
|
72
|
+
from torch import nn
|
73
|
+
|
74
|
+
class WeightTrainer(Trainer):
|
75
|
+
def compute_loss(self, model, inputs, return_outputs=False):
|
76
|
+
labels = inputs.get("labels")
|
77
|
+
# forward pass
|
78
|
+
outputs = model(**inputs)
|
79
|
+
logits = outputs.get('logits')
|
80
|
+
# compute custom loss
|
81
|
+
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
|
82
|
+
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
83
|
+
return (loss, outputs) if return_outputs else loss
|
84
|
+
|
85
|
+
trainer = WeightTrainer(
|
86
|
+
model,
|
87
|
+
training_args,
|
88
|
+
train_dataset = tokenized_dataset["train"],
|
89
|
+
tokenizer = tokenizer
|
90
|
+
)
|
91
|
+
else:
|
92
|
+
|
93
|
+
trainer = Trainer(
|
94
|
+
model,
|
95
|
+
training_args,
|
96
|
+
train_dataset = tokenized_dataset["train"],
|
97
|
+
tokenizer = tokenizer
|
98
|
+
)
|
65
99
|
|
66
100
|
trainer.train()
|
67
101
|
|
@@ -90,7 +124,6 @@ def find_tokens_in_input(dataset, token_ids):
|
|
90
124
|
return position_rows
|
91
125
|
|
92
126
|
|
93
|
-
|
94
127
|
def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
|
95
128
|
from transformers import Trainer
|
96
129
|
|
@@ -110,3 +143,4 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
|
|
110
143
|
else:
|
111
144
|
return result
|
112
145
|
|
146
|
+
|
@@ -3,16 +3,46 @@ require 'rbbt/vector/model/huggingface'
|
|
3
3
|
|
4
4
|
class TestHuggingface < Test::Unit::TestCase
|
5
5
|
|
6
|
+
def test_options
|
7
|
+
TmpFile.with_file do |dir|
|
8
|
+
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
9
|
+
task = "SequenceClassification"
|
10
|
+
|
11
|
+
model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
|
12
|
+
iii model.eval "This is dog"
|
13
|
+
iii model.eval "This is cat"
|
14
|
+
iii model.eval(["This is dog", "This is cat"])
|
15
|
+
|
16
|
+
model = VectorModel.new dir
|
17
|
+
iii model.eval(["This is dog", "This is cat"])
|
18
|
+
end
|
19
|
+
end
|
20
|
+
|
21
|
+
def test_pipeline
|
22
|
+
require 'rbbt/util/python'
|
23
|
+
model = VectorModel.new
|
24
|
+
model.post_process do |elements|
|
25
|
+
elements.collect{|e| e['label'] }
|
26
|
+
end
|
27
|
+
model.eval_model do |file, elements|
|
28
|
+
RbbtPython.run :transformers do
|
29
|
+
classifier ||= transformers.pipeline("sentiment-analysis")
|
30
|
+
classifier.call(elements)
|
31
|
+
end
|
32
|
+
end
|
33
|
+
|
34
|
+
assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
|
35
|
+
end
|
36
|
+
|
6
37
|
def test_sst_eval
|
7
38
|
TmpFile.with_file do |dir|
|
8
39
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
9
40
|
|
10
41
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
11
42
|
|
12
|
-
model.class_labels = ["Bad", "Good"]
|
43
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
13
44
|
|
14
45
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
15
|
-
|
16
46
|
end
|
17
47
|
end
|
18
48
|
|
@@ -22,7 +52,8 @@ class TestHuggingface < Test::Unit::TestCase
|
|
22
52
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
23
53
|
|
24
54
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
25
|
-
|
55
|
+
|
56
|
+
model.model_options[:class_labels] = %w(Bad Good)
|
26
57
|
|
27
58
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
28
59
|
|
@@ -33,6 +64,9 @@ class TestHuggingface < Test::Unit::TestCase
|
|
33
64
|
model.train
|
34
65
|
|
35
66
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
67
|
+
|
68
|
+
model = VectorModel.new dir
|
69
|
+
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
36
70
|
end
|
37
71
|
end
|
38
72
|
|
@@ -40,7 +74,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
40
74
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
41
75
|
|
42
76
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint
|
43
|
-
model.class_labels = ["Bad", "Good"]
|
77
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
44
78
|
|
45
79
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
46
80
|
|
@@ -58,7 +92,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
58
92
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
59
93
|
|
60
94
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
61
|
-
model.class_labels = ["Bad", "Good"]
|
95
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
62
96
|
|
63
97
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
64
98
|
|
@@ -69,15 +103,20 @@ class TestHuggingface < Test::Unit::TestCase
|
|
69
103
|
model.train
|
70
104
|
|
71
105
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
72
|
-
model.class_labels = ["Bad", "Good"]
|
73
106
|
|
74
107
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
75
108
|
|
76
|
-
|
77
|
-
|
109
|
+
model_file = model.model_file
|
110
|
+
|
111
|
+
model = HuggingfaceModel.new "SequenceClassification", model_file
|
112
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
78
113
|
|
79
114
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
80
115
|
|
116
|
+
model = VectorModel.new dir
|
117
|
+
|
118
|
+
assert_equal "Good", model.eval("This is dog")
|
119
|
+
|
81
120
|
end
|
82
121
|
end
|
83
122
|
|
@@ -108,9 +147,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
108
147
|
model = HuggingfaceModel.new "MaskedLM", checkpoint
|
109
148
|
assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
|
110
149
|
reject{|v| v.empty?}.length
|
111
|
-
|
112
150
|
end
|
113
151
|
|
114
|
-
|
115
152
|
end
|
116
153
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rbbt-dm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.
|
4
|
+
version: 1.2.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Miguel Vazquez
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-02-
|
11
|
+
date: 2023-02-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rbbt-util
|
@@ -107,12 +107,14 @@ files:
|
|
107
107
|
- lib/rbbt/statistics/rank_product.rb
|
108
108
|
- lib/rbbt/tensorflow.rb
|
109
109
|
- lib/rbbt/vector/model.rb
|
110
|
+
- lib/rbbt/vector/model/huggingface.old.rb
|
110
111
|
- lib/rbbt/vector/model/huggingface.rb
|
111
112
|
- lib/rbbt/vector/model/random_forest.rb
|
112
113
|
- lib/rbbt/vector/model/spaCy.rb
|
113
114
|
- lib/rbbt/vector/model/svm.rb
|
114
115
|
- lib/rbbt/vector/model/tensorflow.rb
|
115
116
|
- lib/rbbt/vector/model/util.rb
|
117
|
+
- python/rbbt_dm/__init__.py
|
116
118
|
- python/rbbt_dm/huggingface.py
|
117
119
|
- share/R/MA.R
|
118
120
|
- share/R/barcode.R
|