rbbt-dm 1.2.4 → 1.2.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model/huggingface.old.rb +160 -0
- data/lib/rbbt/vector/model/huggingface.rb +68 -45
- data/lib/rbbt/vector/model/spaCy.rb +0 -8
- data/lib/rbbt/vector/model/util.rb +18 -0
- data/lib/rbbt/vector/model.rb +56 -40
- data/python/rbbt_dm/huggingface.py +38 -27
- data/test/rbbt/vector/model/test_huggingface.rb +31 -9
- metadata +16 -15
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
|
4
|
+
data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
|
7
|
+
data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
|
@@ -0,0 +1,160 @@
|
|
1
|
+
require 'rbbt/vector/model'
|
2
|
+
require 'rbbt/util/python'
|
3
|
+
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
5
|
+
RbbtPython.init_rbbt
|
6
|
+
|
7
|
+
# VectorModel wrapper around the Python helpers in rbbt_dm.huggingface.
#
# The model and tokenizer are loaded for a given task and checkpoint, trained
# and evaluated through temporary TSV dataset files, and the raw predictions
# are post-processed into class labels (SequenceClassification) or decoded
# mask tokens (MaskedLM).
class HuggingfaceModel < VectorModel

  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args

  # Writes +elements+ (and +labels+, when given) as a tab-separated file with a
  # header row ("label\ttext" or just "text") and returns the file path.
  def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
    Open.write(tsv_dataset_file) do |ffile|
      if labels
        ffile.puts ["label", "text"] * "\t"
        elements.zip(labels).each do |element, label|
          ffile.puts [label, element].flatten * "\t"
        end
      else
        ffile.puts "text"
        elements.each { |element| ffile.puts element }
      end
    end

    tsv_dataset_file
  end

  # Calls the function +name+ of the Python module rbbt_dm.huggingface.
  def self.call_method(name, *args)
    RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
  end

  # Instance-level shortcut for HuggingfaceModel.call_method.
  def call_method(name, *args)
    HuggingfaceModel.call_method(name, *args)
  end

  # Trains (when +labels+ is given) or evaluates +model+ on +elements+ using a
  # temporary working directory.
  #
  # The fifth argument is overloaded: when training it is the Hash of training
  # arguments (an optional :checkpoint_dir entry selects where checkpoints are
  # written); when evaluating it carries the locate_tokens value, which may be
  # a String, an Array or nil.
  def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
    TmpFile.with_file do |tmpfile|
      tsv_file = File.join(tmpfile, 'dataset.tsv')

      if labels
        # Work on a copy so the caller's Hash is not modified; tolerate nil.
        training_args = (training_args || {}).dup
        checkpoint_dir = training_args.delete(:checkpoint_dir)
      else
        # Not a Hash here: never try to extract :checkpoint_dir from it.
        locate_tokens = training_args
      end

      # Only fall back to the temporary directory when none was requested.
      checkpoint_dir ||= File.join(tmpfile, 'checkpoints')

      Open.mkdir File.dirname(tsv_file)
      Open.mkdir File.dirname(checkpoint_dir)

      if labels
        training_args_obj = call_method(:training_args, checkpoint_dir, **training_args)
        call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
      elsif Array === elements
        training_args_obj = call_method(:training_args, checkpoint_dir)
        call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
      else
        # A single element is evaluated directly, without a dataset file.
        call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
      end
    end
  end

  # Loads the model and tokenizer for @task from @checkpoint.
  def init_model
    @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
  end

  # Reloads the model and tokenizer from @checkpoint.
  def reset_model
    init_model
  end

  # task:: task name passed to the Python loader (e.g. "SequenceClassification", "MaskedLM")
  # initial_checkpoint:: checkpoint used when there is no saved model file yet
  # args:: forwarded unchanged to VectorModel#initialize
  def initialize(task, initial_checkpoint = nil, *args)
    super(*args)
    @task = task

    # Prefer a previously saved model. File.exist? is used because
    # File.exists? was removed in Ruby 3.2.
    @checkpoint = model_file && File.exist?(model_file) ? model_file : initial_checkpoint

    init_model

    @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"

    @training_args = {}

    train_model do |file,elements,labels|
      HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)

      @model.save_pretrained(file) if file
      @tokenizer.save_pretrained(file) if file
    end

    eval_model do |file,elements|
      @model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
      HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
    end

    post_process do |result|
      single = false
      token_positions = nil

      if result.respond_to?(:predictions)
        predictions = result.predictions
      elsif result["token_positions"]
        predictions = result["result"].predictions
        token_positions = result["token_positions"]
      else
        # Result of evaluating a single element: unwrap it at the end.
        single = true
        predictions = result["logits"]
      end

      result = case @task
               when "SequenceClassification"
                 RbbtPython.collect(predictions) do |logits|
                   logits = RbbtPython.numpy2ruby logits
                   best_class = logits.index logits.max
                   best_class = @class_labels[best_class] if @class_labels
                   best_class
                 end
               when "MaskedLM"
                 all_token_positions = token_positions.to_a

                 item_index = 0
                 RbbtPython.collect(predictions) do |item_logits|
                   item_token_positions = all_token_positions[item_index]
                   item_index += 1

                   item_logits = RbbtPython.numpy2ruby(item_logits)
                   item_masks = item_token_positions.collect do |positions|
                     # Index of the first highest logit at each located position.
                     best = item_logits.values_at(*positions).collect do |logits|
                       logits.index logits.max
                     end

                     best.collect{|b| @tokenizer.decode(b) } * "|"
                   end
                   Array === @locate_tokens ? item_masks : item_masks.first
                 end
               else
                 # Other tasks: hand back the raw predictions untouched.
                 predictions
               end

      single ? result.first : result
    end
  end
end
|
160
|
+
|
@@ -6,9 +6,7 @@ RbbtPython.init_rbbt
|
|
6
6
|
|
7
7
|
class HuggingfaceModel < VectorModel
|
8
8
|
|
9
|
-
|
10
|
-
|
11
|
-
def tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
9
|
+
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
12
10
|
|
13
11
|
if labels
|
14
12
|
Open.write(tsv_dataset_file) do |ffile|
|
@@ -27,59 +25,74 @@ class HuggingfaceModel < VectorModel
|
|
27
25
|
tsv_dataset_file
|
28
26
|
end
|
29
27
|
|
30
|
-
def
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
end
|
28
|
+
def initialize(task, checkpoint, *args)
|
29
|
+
options = args.pop if Hash === args.last
|
30
|
+
options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
|
31
|
+
super(*args)
|
32
|
+
@model_options ||= {}
|
33
|
+
@model_options.merge!(options)
|
37
34
|
|
38
|
-
|
39
|
-
|
40
|
-
end
|
35
|
+
eval_model do |directory,texts|
|
36
|
+
checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
|
41
37
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
38
|
+
if @model.nil?
|
39
|
+
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
40
|
+
end
|
41
|
+
|
42
|
+
if Array === texts
|
46
43
|
|
47
|
-
|
48
|
-
|
44
|
+
if @model_options.include?(:locate_tokens)
|
45
|
+
locate_tokens = @model_options[:locate_tokens]
|
46
|
+
elsif @model_options[:task] == "MaskedLM"
|
47
|
+
@model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
|
48
|
+
end
|
49
49
|
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
else
|
54
|
-
if Array === elements
|
55
|
-
training_args = call_method(:training_args, output_dir)
|
56
|
-
call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
|
50
|
+
if @directory
|
51
|
+
tsv_file = File.join(@directory, 'dataset.tsv')
|
52
|
+
checkpoint_dir = File.join(@directory, 'checkpoints')
|
57
53
|
else
|
58
|
-
|
54
|
+
tmpdir = TmpFile.tmp_file
|
55
|
+
Open.mkdir tmpdir
|
56
|
+
tsv_file = File.join(tmpdir, 'dataset.tsv')
|
57
|
+
checkpoint_dir = File.join(tmpdir, 'checkpoints')
|
58
|
+
end
|
59
|
+
|
60
|
+
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
|
61
|
+
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
62
|
+
|
63
|
+
begin
|
64
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
|
65
|
+
ensure
|
66
|
+
Open.rm_rf tmpdir if tmpdir
|
59
67
|
end
|
68
|
+
else
|
69
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
|
60
70
|
end
|
61
71
|
end
|
62
|
-
end
|
63
|
-
|
64
|
-
def initialize(task, initial_checkpoint = nil, *args)
|
65
|
-
super(*args)
|
66
|
-
@task = task
|
67
72
|
|
68
|
-
|
73
|
+
train_model do |directory,texts,labels|
|
74
|
+
checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
|
75
|
+
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
69
76
|
|
70
|
-
|
77
|
+
if @directory
|
78
|
+
tsv_file = File.join(@directory, 'dataset.tsv')
|
79
|
+
checkpoint_dir = File.join(@directory, 'checkpoints')
|
80
|
+
else
|
81
|
+
tmpdir = TmpFile.tmp_file
|
82
|
+
Open.mkdir tmpdir
|
83
|
+
tsv_file = File.join(tmpdir, 'dataset.tsv')
|
84
|
+
checkpoint_dir = File.join(tmpdir, 'checkpoints')
|
85
|
+
end
|
71
86
|
|
72
|
-
|
87
|
+
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
88
|
+
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
|
73
89
|
|
74
|
-
|
75
|
-
run_model(elements, labels)
|
90
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
|
76
91
|
|
77
|
-
|
78
|
-
@tokenizer.save_pretrained(file) if file
|
79
|
-
end
|
92
|
+
Open.rm_rf tmpdir if tmpdir
|
80
93
|
|
81
|
-
|
82
|
-
|
94
|
+
@model.save_pretrained(directory) if directory
|
95
|
+
@tokenizer.save_pretrained(directory) if directory
|
83
96
|
end
|
84
97
|
|
85
98
|
post_process do |result|
|
@@ -94,12 +107,13 @@ class HuggingfaceModel < VectorModel
|
|
94
107
|
predictions = result["logits"]
|
95
108
|
end
|
96
109
|
|
97
|
-
|
110
|
+
task, class_labels, locate_tokens = @model_options.values_at :task, :class_labels, :locate_tokens
|
111
|
+
result = case task
|
98
112
|
when "SequenceClassification"
|
99
113
|
RbbtPython.collect(predictions) do |logits|
|
100
114
|
logits = RbbtPython.numpy2ruby logits
|
101
115
|
best_class = logits.index logits.max
|
102
|
-
best_class =
|
116
|
+
best_class = class_labels[best_class] if class_labels
|
103
117
|
best_class
|
104
118
|
end
|
105
119
|
when "MaskedLM"
|
@@ -125,7 +139,7 @@ class HuggingfaceModel < VectorModel
|
|
125
139
|
|
126
140
|
best.collect{|b| @tokenizer.decode(b) } * "|"
|
127
141
|
end
|
128
|
-
Array ===
|
142
|
+
Array === locate_tokens ? item_masks : item_masks.first
|
129
143
|
end
|
130
144
|
else
|
131
145
|
logits
|
@@ -133,6 +147,15 @@ class HuggingfaceModel < VectorModel
|
|
133
147
|
|
134
148
|
single ? result.first : result
|
135
149
|
end
|
150
|
+
|
151
|
+
|
152
|
+
save_models if @directory
|
136
153
|
end
|
154
|
+
|
155
|
+
# Forgets the loaded model and tokenizer and removes the persisted model, if
# any. The training block saves with save_pretrained into the path it is
# given, so the path is removed recursively. Nothing is removed when no model
# file is set.
def reset_model
@model, @tokenizer = nil
Open.rm_rf @model_file if @model_file
end
|
159
|
+
|
137
160
|
end
|
138
161
|
|
@@ -75,14 +75,6 @@ class SpaCyModel < VectorModel
|
|
75
75
|
d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
|
76
76
|
end
|
77
77
|
end
|
78
|
-
#nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
|
79
|
-
#Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
|
80
|
-
# texts.collect do |text|
|
81
|
-
# cats = nlp.(text).cats
|
82
|
-
# bar.tick
|
83
|
-
# cats.sort_by{|l,v| v.to_f }.last.first
|
84
|
-
# end
|
85
|
-
#end
|
86
78
|
end
|
87
79
|
end
|
88
80
|
end
|
@@ -9,4 +9,22 @@ class VectorModel
|
|
9
9
|
@bar.init
|
10
10
|
@bar
|
11
11
|
end
|
12
|
+
|
13
|
+
# Downsamples @labels and @features in place so that no label keeps more
# examples than the rarest label has. Pairs are shuffled first, so the
# retained examples are a random subset; each label stays paired with its
# own features.
#
# Misc.counts is expected to return a label => count Hash (only its values
# are used here).
def balance_labels
counts = Misc.counts(@labels)
min = counts.values.min

used = Hash.new(0)
new_labels = []
new_features = []
@labels.zip(@features).shuffle.each do |label, features|
# >= rather than >: stop at exactly min examples per label, not min + 1
next if used[label] >= min
used[label] += 1
new_labels << label
new_features << features
end
@labels = new_labels
@features = new_features
end
|
12
30
|
end
|
data/lib/rbbt/vector/model.rb
CHANGED
@@ -2,8 +2,9 @@ require 'rbbt/util/R'
|
|
2
2
|
require 'rbbt/vector/model/util'
|
3
3
|
|
4
4
|
class VectorModel
|
5
|
-
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
|
5
|
+
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
|
6
6
|
attr_accessor :features, :names, :labels, :factor_levels
|
7
|
+
attr_accessor :model_options
|
7
8
|
|
8
9
|
def extract_features(&block)
|
9
10
|
@extract_features = block if block_given?
|
@@ -126,7 +127,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
126
127
|
instance_eval code, file
|
127
128
|
end
|
128
129
|
|
129
|
-
def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
|
130
|
+
def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
|
130
131
|
@directory = directory
|
131
132
|
if @directory
|
132
133
|
FileUtils.mkdir_p @directory unless File.exists?(@directory)
|
@@ -135,10 +136,18 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
135
136
|
@extract_features_file = File.join(@directory, "features")
|
136
137
|
@train_model_file = File.join(@directory, "train_model")
|
137
138
|
@eval_model_file = File.join(@directory, "eval_model")
|
139
|
+
@post_process_file = File.join(@directory, "post_process")
|
138
140
|
@train_model_file_R = File.join(@directory, "train_model.R")
|
139
141
|
@eval_model_file_R = File.join(@directory, "eval_model.R")
|
142
|
+
@post_process_file_R = File.join(@directory, "post_process.R")
|
140
143
|
@names_file = File.join(@directory, "feature_names")
|
141
144
|
@levels_file = File.join(@directory, "levels")
|
145
|
+
@options_file = File.join(@directory, "options.json")
|
146
|
+
|
147
|
+
if File.exists?(@options_file)
|
148
|
+
@model_options = JSON.parse(Open.read(@options_file))
|
149
|
+
IndiferentHash.setup(@model_options)
|
150
|
+
end
|
142
151
|
end
|
143
152
|
|
144
153
|
if extract_features.nil?
|
@@ -169,6 +178,17 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
169
178
|
@eval_model = eval_model
|
170
179
|
end
|
171
180
|
|
181
|
+
if post_process.nil?
|
182
|
+
if @post_process_file && File.exists?(@post_process_file)
|
183
|
+
@post_process = __load_method @post_process_file
|
184
|
+
elsif @post_process_file_R && File.exists?(@post_process_file_R)
|
185
|
+
@post_process = Open.read(@post_process_file_R)
|
186
|
+
end
|
187
|
+
else
|
188
|
+
@post_process = post_process
|
189
|
+
end
|
190
|
+
|
191
|
+
|
172
192
|
if names.nil?
|
173
193
|
if @names_file && File.exists?(@names_file)
|
174
194
|
@names = Open.read(@names_file).split("\n")
|
@@ -240,18 +260,43 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
240
260
|
Open.write(@eval_model_file_R, eval_model)
|
241
261
|
end
|
242
262
|
|
263
|
+
case
|
264
|
+
when Proc === post_process
|
265
|
+
begin
|
266
|
+
Open.write(@post_process_file, post_process.source)
|
267
|
+
rescue
|
268
|
+
end
|
269
|
+
when String === post_process
|
270
|
+
Open.write(@post_process_file_R, post_process)
|
271
|
+
end
|
272
|
+
|
243
273
|
Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
|
244
274
|
Open.write(@names_file, @names * "\n" + "\n") if @names
|
275
|
+
Open.write(@options_file, @model_options.to_json) if @model_options
|
245
276
|
end
|
246
277
|
|
247
278
|
def train
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
279
|
+
begin
|
280
|
+
if @balance
|
281
|
+
@original_features = @features
|
282
|
+
@original_labels = @labels
|
283
|
+
self.balance_labels
|
284
|
+
end
|
285
|
+
|
286
|
+
case
|
287
|
+
when Proc === @train_model
|
288
|
+
self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
|
289
|
+
when String === @train_model
|
290
|
+
VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
|
291
|
+
end
|
292
|
+
ensure
|
293
|
+
if @balance
|
294
|
+
@features = @original_features
|
295
|
+
@labels = @original_labels
|
296
|
+
end
|
253
297
|
end
|
254
|
-
|
298
|
+
|
299
|
+
save_models if @directory
|
255
300
|
end
|
256
301
|
|
257
302
|
def run(code)
|
@@ -299,38 +344,6 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
299
344
|
result
|
300
345
|
end
|
301
346
|
|
302
|
-
#def cross_validation(folds = 10)
|
303
|
-
# saved_features = @features
|
304
|
-
# saved_labels = @labels
|
305
|
-
# seq = (0..features.length - 1).to_a
|
306
|
-
|
307
|
-
# chunk_size = features.length / folds
|
308
|
-
|
309
|
-
# acc = []
|
310
|
-
# folds.times do
|
311
|
-
# seq = seq.shuffle
|
312
|
-
# eval_chunk = seq[0..chunk_size]
|
313
|
-
# train_chunk = seq[chunk_size.. -1]
|
314
|
-
|
315
|
-
# eval_features = @features.values_at *eval_chunk
|
316
|
-
# eval_labels = @labels.values_at *eval_chunk
|
317
|
-
|
318
|
-
# @features = @features.values_at *train_chunk
|
319
|
-
# @labels = @labels.values_at *train_chunk
|
320
|
-
|
321
|
-
# train
|
322
|
-
# predictions = eval_list eval_features, false
|
323
|
-
|
324
|
-
# acc << predictions.zip(eval_labels).collect{|pred,lab| pred - lab < 0.5 ? 1 : 0}.inject(0){|acc,e| acc +=e} / chunk_size
|
325
|
-
|
326
|
-
# @features = saved_features
|
327
|
-
# @labels = saved_labels
|
328
|
-
# end
|
329
|
-
|
330
|
-
# acc
|
331
|
-
#end
|
332
|
-
#
|
333
|
-
|
334
347
|
def self.f1_metrics(test, predicted, good_label = nil)
|
335
348
|
tp, tn, fp, fn, pr, re, f1 = [0, 0, 0, 0, nil, nil, nil]
|
336
349
|
|
@@ -413,6 +426,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
413
426
|
@features = train_set
|
414
427
|
@labels = train_labels
|
415
428
|
|
429
|
+
self.reset_model if self.respond_to? :reset_model
|
416
430
|
self.train
|
417
431
|
predictions = self.eval_list test_set, false
|
418
432
|
|
@@ -437,6 +451,8 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
437
451
|
@features = orig_features
|
438
452
|
@labels = orig_labels
|
439
453
|
end unless folds == -1
|
454
|
+
|
455
|
+
self.reset_model if self.respond_to? :reset_model
|
440
456
|
self.train unless folds == 1
|
441
457
|
res
|
442
458
|
end
|
@@ -17,6 +17,17 @@ def load_model_and_tokenizer(task, checkpoint):
|
|
17
17
|
tokenizer = load_tokenizer(task, checkpoint)
|
18
18
|
return model, tokenizer
|
19
19
|
|
20
|
+
def load_model_and_tokenizer_from_directory(directory):
    """Load the model and tokenizer described by ``options.json`` in *directory*.

    The file is parsed as JSON and must contain the ``task`` and
    ``checkpoint`` keys, which are passed to ``load_model_and_tokenizer``.

    Returns whatever ``load_model_and_tokenizer`` returns for those values.
    Raises ``OSError`` if the options file cannot be opened and ``KeyError``
    if either key is missing.
    """
    import os
    import json
    options_file = os.path.join(directory, 'options.json')
    # json.load reads from a file object (json.loads is the string variant);
    # the context manager closes the file even when parsing fails.
    with open(options_file, "r") as f:
        options = json.load(f)
    task = options["task"]
    checkpoint = options["checkpoint"]
    return load_model_and_tokenizer(task, checkpoint)
|
30
|
+
|
20
31
|
#{{{ SIMPLE EVALUATE
|
21
32
|
|
22
33
|
def forward(model, features):
|
@@ -42,7 +53,7 @@ def load_tsv(tsv_file):
|
|
42
53
|
|
43
54
|
def tsv_dataset(tokenizer, tsv_file):
|
44
55
|
dataset = load_tsv(tsv_file)
|
45
|
-
tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True) , batched=True)
|
56
|
+
tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True, max_length=512) , batched=True)
|
46
57
|
return tokenized_dataset
|
47
58
|
|
48
59
|
def training_args(*args, **kwargs):
|
@@ -57,34 +68,34 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
|
57
68
|
tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
|
58
69
|
|
59
70
|
if (not class_weights == None):
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
71
|
+
import torch
|
72
|
+
from torch import nn
|
73
|
+
|
74
|
+
class WeightTrainer(Trainer):
|
75
|
+
def compute_loss(self, model, inputs, return_outputs=False):
|
76
|
+
labels = inputs.get("labels")
|
77
|
+
# forward pass
|
78
|
+
outputs = model(**inputs)
|
79
|
+
logits = outputs.get('logits')
|
80
|
+
# compute custom loss
|
81
|
+
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
|
82
|
+
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
83
|
+
return (loss, outputs) if return_outputs else loss
|
84
|
+
|
85
|
+
trainer = WeightTrainer(
|
86
|
+
model,
|
87
|
+
training_args,
|
88
|
+
train_dataset = tokenized_dataset["train"],
|
89
|
+
tokenizer = tokenizer
|
90
|
+
)
|
80
91
|
else:
|
81
92
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
93
|
+
trainer = Trainer(
|
94
|
+
model,
|
95
|
+
training_args,
|
96
|
+
train_dataset = tokenized_dataset["train"],
|
97
|
+
tokenizer = tokenizer
|
98
|
+
)
|
88
99
|
|
89
100
|
trainer.train()
|
90
101
|
|
@@ -3,6 +3,21 @@ require 'rbbt/vector/model/huggingface'
|
|
3
3
|
|
4
4
|
class TestHuggingface < Test::Unit::TestCase
|
5
5
|
|
6
|
+
# Options passed to the constructor (here :class_labels) must be applied when
# evaluating and must still apply after the model is reloaded from its
# directory as a plain VectorModel.
def test_options
TmpFile.with_file do |dir|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
task = "SequenceClassification"

model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
# Assert instead of printing, so a wrong prediction or label mapping fails the test.
assert_equal "bad", model.eval("This is dog")
assert_equal "good", model.eval("This is cat")
assert_equal %w(bad good), model.eval(["This is dog", "This is cat"])

model = VectorModel.new dir
assert_equal %w(bad good), model.eval(["This is dog", "This is cat"])
end
end
|
20
|
+
|
6
21
|
def test_pipeline
|
7
22
|
require 'rbbt/util/python'
|
8
23
|
model = VectorModel.new
|
@@ -25,7 +40,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
25
40
|
|
26
41
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
27
42
|
|
28
|
-
model.class_labels = ["Bad", "Good"]
|
43
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
29
44
|
|
30
45
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
31
46
|
end
|
@@ -37,7 +52,8 @@ class TestHuggingface < Test::Unit::TestCase
|
|
37
52
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
38
53
|
|
39
54
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
40
|
-
|
55
|
+
|
56
|
+
model.model_options[:class_labels] = %w(Bad Good)
|
41
57
|
|
42
58
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
43
59
|
|
@@ -48,6 +64,9 @@ class TestHuggingface < Test::Unit::TestCase
|
|
48
64
|
model.train
|
49
65
|
|
50
66
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
67
|
+
|
68
|
+
model = VectorModel.new dir
|
69
|
+
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
51
70
|
end
|
52
71
|
end
|
53
72
|
|
@@ -55,7 +74,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
55
74
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
56
75
|
|
57
76
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint
|
58
|
-
model.class_labels = ["Bad", "Good"]
|
77
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
59
78
|
|
60
79
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
61
80
|
|
@@ -73,7 +92,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
73
92
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
74
93
|
|
75
94
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
76
|
-
model.class_labels = ["Bad", "Good"]
|
95
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
77
96
|
|
78
97
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
79
98
|
|
@@ -84,15 +103,20 @@ class TestHuggingface < Test::Unit::TestCase
|
|
84
103
|
model.train
|
85
104
|
|
86
105
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
87
|
-
model.class_labels = ["Bad", "Good"]
|
88
106
|
|
89
107
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
90
108
|
|
91
|
-
|
92
|
-
|
109
|
+
model_file = model.model_file
|
110
|
+
|
111
|
+
model = HuggingfaceModel.new "SequenceClassification", model_file
|
112
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
93
113
|
|
94
114
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
95
115
|
|
116
|
+
model = VectorModel.new dir
|
117
|
+
|
118
|
+
assert_equal "Good", model.eval("This is dog")
|
119
|
+
|
96
120
|
end
|
97
121
|
end
|
98
122
|
|
@@ -123,9 +147,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
123
147
|
model = HuggingfaceModel.new "MaskedLM", checkpoint
|
124
148
|
assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
|
125
149
|
reject{|v| v.empty?}.length
|
126
|
-
|
127
150
|
end
|
128
151
|
|
129
|
-
|
130
152
|
end
|
131
153
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rbbt-dm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.
|
4
|
+
version: 1.2.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Miguel Vazquez
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-02-
|
11
|
+
date: 2023-02-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rbbt-util
|
@@ -107,6 +107,7 @@ files:
|
|
107
107
|
- lib/rbbt/statistics/rank_product.rb
|
108
108
|
- lib/rbbt/tensorflow.rb
|
109
109
|
- lib/rbbt/vector/model.rb
|
110
|
+
- lib/rbbt/vector/model/huggingface.old.rb
|
110
111
|
- lib/rbbt/vector/model/huggingface.rb
|
111
112
|
- lib/rbbt/vector/model/random_forest.rb
|
112
113
|
- lib/rbbt/vector/model/spaCy.rb
|
@@ -143,7 +144,7 @@ files:
|
|
143
144
|
homepage: http://github.com/mikisvaz/rbbt-phgx
|
144
145
|
licenses: []
|
145
146
|
metadata: {}
|
146
|
-
post_install_message:
|
147
|
+
post_install_message:
|
147
148
|
rdoc_options: []
|
148
149
|
require_paths:
|
149
150
|
- lib
|
@@ -158,22 +159,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
158
159
|
- !ruby/object:Gem::Version
|
159
160
|
version: '0'
|
160
161
|
requirements: []
|
161
|
-
rubygems_version: 3.1.
|
162
|
-
signing_key:
|
162
|
+
rubygems_version: 3.1.2
|
163
|
+
signing_key:
|
163
164
|
specification_version: 4
|
164
165
|
summary: Data-mining and statistics
|
165
166
|
test_files:
|
166
|
-
- test/
|
167
|
-
- test/rbbt/statistics/test_fisher.rb
|
168
|
-
- test/rbbt/statistics/test_fdr.rb
|
169
|
-
- test/rbbt/statistics/test_random_walk.rb
|
170
|
-
- test/rbbt/test_ml_task.rb
|
167
|
+
- test/test_helper.rb
|
171
168
|
- test/rbbt/vector/test_model.rb
|
169
|
+
- test/rbbt/vector/model/test_huggingface.rb
|
172
170
|
- test/rbbt/vector/model/test_tensorflow.rb
|
173
171
|
- test/rbbt/vector/model/test_spaCy.rb
|
174
|
-
- test/rbbt/vector/model/test_huggingface.rb
|
175
172
|
- test/rbbt/vector/model/test_svm.rb
|
176
|
-
- test/rbbt/
|
177
|
-
- test/rbbt/
|
173
|
+
- test/rbbt/statistics/test_random_walk.rb
|
174
|
+
- test/rbbt/statistics/test_fisher.rb
|
175
|
+
- test/rbbt/statistics/test_fdr.rb
|
176
|
+
- test/rbbt/statistics/test_hypergeometric.rb
|
178
177
|
- test/rbbt/test_stan.rb
|
179
|
-
- test/
|
178
|
+
- test/rbbt/matrix/test_barcode.rb
|
179
|
+
- test/rbbt/test_ml_task.rb
|
180
|
+
- test/rbbt/network/test_paths.rb
|