rbbt-dm 1.2.4 → 1.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model/huggingface.old.rb +160 -0
- data/lib/rbbt/vector/model/huggingface.rb +68 -45
- data/lib/rbbt/vector/model/spaCy.rb +0 -8
- data/lib/rbbt/vector/model/util.rb +18 -0
- data/lib/rbbt/vector/model.rb +56 -40
- data/python/rbbt_dm/huggingface.py +38 -27
- data/test/rbbt/vector/model/test_huggingface.rb +31 -9
- metadata +16 -15
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
|
4
|
+
data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
|
7
|
+
data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
|
@@ -0,0 +1,160 @@
|
|
1
|
+
require 'rbbt/vector/model'
|
2
|
+
require 'rbbt/util/python'
|
3
|
+
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
5
|
+
RbbtPython.init_rbbt
|
6
|
+
|
7
|
+
# VectorModel backed by a Hugging Face transformers model.
#
# All heavy lifting is delegated to the Python module "rbbt_dm.huggingface"
# through RbbtPython; this class only prepares the TSV datasets, forwards the
# calls and converts the raw predictions into Ruby values.
class HuggingfaceModel < VectorModel

  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args

  # Writes a tab-separated dataset file for the Python side.
  #
  # tsv_dataset_file - path of the file to write.
  # elements         - texts to write, one per row.
  # labels           - optional labels, paired with elements by position.
  #                    When given, rows are "label<TAB>text" under a
  #                    "label<TAB>text" header; otherwise a single "text"
  #                    column is written.
  #
  # Returns tsv_dataset_file.
  def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
    Open.write(tsv_dataset_file) do |ffile|
      if labels
        ffile.puts ["label", "text"] * "\t"
        elements.zip(labels).each do |element, label|
          ffile.puts [label, element].flatten * "\t"
        end
      else
        ffile.puts "text"
        elements.each { |element| ffile.puts element }
      end
    end

    tsv_dataset_file
  end

  # Calls the function +name+ of the Python module "rbbt_dm.huggingface"
  # with +args+ and returns its result.
  def self.call_method(name, *args)
    RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
  end

  # Instance-level shortcut for HuggingfaceModel.call_method.
  def call_method(name, *args)
    HuggingfaceModel.call_method(name, *args)
  end

  # Trains or evaluates +model+ inside a temporary working directory.
  #
  # model, tokenizer - objects returned by :load_model_and_tokenizer.
  # elements         - Array of texts, or a single text (evaluation only).
  # labels           - when given the model is trained; when nil it is
  #                    evaluated.
  # training_args    - when training: Hash of keyword arguments for the Python
  #                    :training_args function; the :checkpoint_dir key, if
  #                    present, selects the checkpoint directory.
  #                    When evaluating: this positional slot carries the
  #                    tokens to locate (e.g. the mask token), as the
  #                    eval_model block in #initialize passes it.
  # class_weights    - optional class weights forwarded to :train_model.
  #
  # Returns whatever the Python function that was called returns.
  def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
    TmpFile.with_file do |tmpfile|
      tsv_file = File.join(tmpfile, 'dataset.tsv')

      # Only a Hash can carry :checkpoint_dir. In evaluation mode this
      # argument may be a String or Array of tokens, and String#delete with a
      # Symbol argument raises TypeError.
      if Hash === training_args
        training_args = training_args.dup
        checkpoint_dir = training_args.delete(:checkpoint_dir)
      end

      # Fall back to the temporary directory only when the caller did not
      # choose a checkpoint directory.
      checkpoint_dir ||= File.join(tmpfile, 'checkpoints')

      Open.mkdir File.dirname(tsv_file)
      Open.mkdir File.dirname(checkpoint_dir)

      if labels
        training_args_obj = call_method(:training_args, checkpoint_dir, **(training_args || {}))
        call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
      else
        locate_tokens = training_args
        if Array === elements
          training_args_obj = call_method(:training_args, checkpoint_dir)
          call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
        else
          call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
        end
      end
    end
  end

  # Loads @model and @tokenizer for @task from @checkpoint.
  def init_model
    @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
  end

  # Discards the in-memory model by loading it again from @checkpoint.
  def reset_model
    init_model
  end

  # task               - task name passed to the Python loader; the
  #                      post-processing below handles
  #                      "SequenceClassification" and "MaskedLM".
  # initial_checkpoint - checkpoint used when model_file does not exist yet.
  # args               - forwarded to VectorModel#initialize.
  def initialize(task, initial_checkpoint = nil, *args)
    super(*args)
    @task = task

    # File.exist? instead of File.exists?, which was removed in Ruby 3.2.
    @checkpoint = model_file && File.exist?(model_file) ? model_file : initial_checkpoint

    init_model

    @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"

    @training_args = {}

    train_model do |file,elements,labels|
      HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)

      if file
        @model.save_pretrained(file)
        @tokenizer.save_pretrained(file)
      end
    end

    eval_model do |file,elements|
      # Load only when nothing is in memory. Reloading @checkpoint on every
      # evaluation replaced a freshly trained model with the checkpoint
      # chosen at construction time.
      if @model.nil?
        @model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
      end
      HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
    end

    post_process do |result|
      if result.respond_to?(:predictions)
        single = false
        predictions = result.predictions
      elsif result["token_positions"]
        predictions = result["result"].predictions
        token_positions = result["token_positions"]
      else
        single = true
        predictions = result["logits"]
      end

      result = case @task
               when "SequenceClassification"
                 # Index of the highest logit, mapped to a label if available.
                 RbbtPython.collect(predictions) do |logits|
                   logits = RbbtPython.numpy2ruby logits
                   best_class = logits.index logits.max
                   best_class = @class_labels[best_class] if @class_labels
                   best_class
                 end
               when "MaskedLM"
                 all_token_positions = token_positions.to_a

                 item_index = 0
                 RbbtPython.collect(predictions) do |item_logits|
                   item_token_positions = all_token_positions[item_index]
                   item_index += 1

                   item_logits = RbbtPython.numpy2ruby(item_logits)
                   item_masks = item_token_positions.collect do |positions|

                     # For each located position keep the id of the
                     # highest-scoring token (first one wins on ties).
                     best = item_logits.values_at(*positions).collect do |position_logits|
                       best_token, best_score = nil
                       position_logits.each_with_index do |score, token_id|
                         if best_score.nil? || score > best_score
                           best_token, best_score = token_id, score
                         end
                       end
                       best_token
                     end

                     best.collect{|b| @tokenizer.decode(b) } * "|"
                   end
                   Array === @locate_tokens ? item_masks : item_masks.first
                 end
               else
                 # Unknown task: hand back the raw predictions. This branch
                 # used to reference an undefined local and raised NameError.
                 predictions
               end

      single ? result.first : result
    end
  end
end
|
160
|
+
|
@@ -6,9 +6,7 @@ RbbtPython.init_rbbt
|
|
6
6
|
|
7
7
|
class HuggingfaceModel < VectorModel
|
8
8
|
|
9
|
-
|
10
|
-
|
11
|
-
def tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
9
|
+
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
12
10
|
|
13
11
|
if labels
|
14
12
|
Open.write(tsv_dataset_file) do |ffile|
|
@@ -27,59 +25,74 @@ class HuggingfaceModel < VectorModel
|
|
27
25
|
tsv_dataset_file
|
28
26
|
end
|
29
27
|
|
30
|
-
def
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
end
|
28
|
+
def initialize(task, checkpoint, *args)
|
29
|
+
options = args.pop if Hash === args.last
|
30
|
+
options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
|
31
|
+
super(*args)
|
32
|
+
@model_options ||= {}
|
33
|
+
@model_options.merge!(options)
|
37
34
|
|
38
|
-
|
39
|
-
|
40
|
-
end
|
35
|
+
eval_model do |directory,texts|
|
36
|
+
checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
|
41
37
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
38
|
+
if @model.nil?
|
39
|
+
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
40
|
+
end
|
41
|
+
|
42
|
+
if Array === texts
|
46
43
|
|
47
|
-
|
48
|
-
|
44
|
+
if @model_options.include?(:locate_tokens)
|
45
|
+
locate_tokens = @model_options[:locate_tokens]
|
46
|
+
elsif @model_options[:task] == "MaskedLM"
|
47
|
+
@model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
|
48
|
+
end
|
49
49
|
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
else
|
54
|
-
if Array === elements
|
55
|
-
training_args = call_method(:training_args, output_dir)
|
56
|
-
call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
|
50
|
+
if @directory
|
51
|
+
tsv_file = File.join(@directory, 'dataset.tsv')
|
52
|
+
checkpoint_dir = File.join(@directory, 'checkpoints')
|
57
53
|
else
|
58
|
-
|
54
|
+
tmpdir = TmpFile.tmp_file
|
55
|
+
Open.mkdir tmpdir
|
56
|
+
tsv_file = File.join(tmpdir, 'dataset.tsv')
|
57
|
+
checkpoint_dir = File.join(tmpdir, 'checkpoints')
|
58
|
+
end
|
59
|
+
|
60
|
+
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
|
61
|
+
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
62
|
+
|
63
|
+
begin
|
64
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
|
65
|
+
ensure
|
66
|
+
Open.rm_rf tmpdir if tmpdir
|
59
67
|
end
|
68
|
+
else
|
69
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
|
60
70
|
end
|
61
71
|
end
|
62
|
-
end
|
63
|
-
|
64
|
-
def initialize(task, initial_checkpoint = nil, *args)
|
65
|
-
super(*args)
|
66
|
-
@task = task
|
67
72
|
|
68
|
-
|
73
|
+
train_model do |directory,texts,labels|
|
74
|
+
checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
|
75
|
+
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
69
76
|
|
70
|
-
|
77
|
+
if @directory
|
78
|
+
tsv_file = File.join(@directory, 'dataset.tsv')
|
79
|
+
checkpoint_dir = File.join(@directory, 'checkpoints')
|
80
|
+
else
|
81
|
+
tmpdir = TmpFile.tmp_file
|
82
|
+
Open.mkdir tmpdir
|
83
|
+
tsv_file = File.join(tmpdir, 'dataset.tsv')
|
84
|
+
checkpoint_dir = File.join(tmpdir, 'checkpoints')
|
85
|
+
end
|
71
86
|
|
72
|
-
|
87
|
+
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
88
|
+
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
|
73
89
|
|
74
|
-
|
75
|
-
run_model(elements, labels)
|
90
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
|
76
91
|
|
77
|
-
|
78
|
-
@tokenizer.save_pretrained(file) if file
|
79
|
-
end
|
92
|
+
Open.rm_rf tmpdir if tmpdir
|
80
93
|
|
81
|
-
|
82
|
-
|
94
|
+
@model.save_pretrained(directory) if directory
|
95
|
+
@tokenizer.save_pretrained(directory) if directory
|
83
96
|
end
|
84
97
|
|
85
98
|
post_process do |result|
|
@@ -94,12 +107,13 @@ class HuggingfaceModel < VectorModel
|
|
94
107
|
predictions = result["logits"]
|
95
108
|
end
|
96
109
|
|
97
|
-
|
110
|
+
task, class_labels, locate_tokens = @model_options.values_at :task, :class_labels, :locate_tokens
|
111
|
+
result = case task
|
98
112
|
when "SequenceClassification"
|
99
113
|
RbbtPython.collect(predictions) do |logits|
|
100
114
|
logits = RbbtPython.numpy2ruby logits
|
101
115
|
best_class = logits.index logits.max
|
102
|
-
best_class =
|
116
|
+
best_class = class_labels[best_class] if class_labels
|
103
117
|
best_class
|
104
118
|
end
|
105
119
|
when "MaskedLM"
|
@@ -125,7 +139,7 @@ class HuggingfaceModel < VectorModel
|
|
125
139
|
|
126
140
|
best.collect{|b| @tokenizer.decode(b) } * "|"
|
127
141
|
end
|
128
|
-
Array ===
|
142
|
+
Array === locate_tokens ? item_masks : item_masks.first
|
129
143
|
end
|
130
144
|
else
|
131
145
|
logits
|
@@ -133,6 +147,15 @@ class HuggingfaceModel < VectorModel
|
|
133
147
|
|
134
148
|
single ? result.first : result
|
135
149
|
end
|
150
|
+
|
151
|
+
|
152
|
+
save_models if @directory
|
136
153
|
end
|
154
|
+
|
155
|
+
# Forgets the in-memory model and tokenizer and removes the saved model so
# that the next training or evaluation starts again from the configured
# checkpoint.
def reset_model
  @model, @tokenizer = nil
  # Nothing to remove when the model has no model file configured.
  # rm_rf because the trained model is written with save_pretrained(directory).
  # NOTE(review): assumes Open.rm_rf accepts a missing path, as its
  # unconditional use on the temporary directory elsewhere suggests.
  Open.rm_rf @model_file if @model_file
end
|
159
|
+
|
137
160
|
end
|
138
161
|
|
@@ -75,14 +75,6 @@ class SpaCyModel < VectorModel
|
|
75
75
|
d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
|
76
76
|
end
|
77
77
|
end
|
78
|
-
#nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
|
79
|
-
#Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
|
80
|
-
# texts.collect do |text|
|
81
|
-
# cats = nlp.(text).cats
|
82
|
-
# bar.tick
|
83
|
-
# cats.sort_by{|l,v| v.to_f }.last.first
|
84
|
-
# end
|
85
|
-
#end
|
86
78
|
end
|
87
79
|
end
|
88
80
|
end
|
@@ -9,4 +9,22 @@ class VectorModel
|
|
9
9
|
@bar.init
|
10
10
|
@bar
|
11
11
|
end
|
12
|
+
|
13
|
+
# Downsamples @labels and @features in place so that every label keeps the
# same number of examples as the least frequent label.
#
# Examples are visited in random order, so which ones survive differs
# between calls. Each surviving label stays paired with its own features.
def balance_labels
  counts = Misc.counts(@labels)
  min = counts.values.min

  used = Hash.new(0)
  new_labels = []
  new_features = []
  @labels.zip(@features).shuffle.each do |label, features|
    # >= rather than >: with > every non-minority label kept min + 1 examples.
    next if used[label] >= min
    used[label] += 1
    new_labels << label
    new_features << features
  end
  @labels = new_labels
  @features = new_features
end
|
12
30
|
end
|
data/lib/rbbt/vector/model.rb
CHANGED
@@ -2,8 +2,9 @@ require 'rbbt/util/R'
|
|
2
2
|
require 'rbbt/vector/model/util'
|
3
3
|
|
4
4
|
class VectorModel
|
5
|
-
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
|
5
|
+
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
|
6
6
|
attr_accessor :features, :names, :labels, :factor_levels
|
7
|
+
attr_accessor :model_options
|
7
8
|
|
8
9
|
def extract_features(&block)
|
9
10
|
@extract_features = block if block_given?
|
@@ -126,7 +127,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
126
127
|
instance_eval code, file
|
127
128
|
end
|
128
129
|
|
129
|
-
def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
|
130
|
+
def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
|
130
131
|
@directory = directory
|
131
132
|
if @directory
|
132
133
|
FileUtils.mkdir_p @directory unless File.exists?(@directory)
|
@@ -135,10 +136,18 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
135
136
|
@extract_features_file = File.join(@directory, "features")
|
136
137
|
@train_model_file = File.join(@directory, "train_model")
|
137
138
|
@eval_model_file = File.join(@directory, "eval_model")
|
139
|
+
@post_process_file = File.join(@directory, "post_process")
|
138
140
|
@train_model_file_R = File.join(@directory, "train_model.R")
|
139
141
|
@eval_model_file_R = File.join(@directory, "eval_model.R")
|
142
|
+
@post_process_file_R = File.join(@directory, "post_process.R")
|
140
143
|
@names_file = File.join(@directory, "feature_names")
|
141
144
|
@levels_file = File.join(@directory, "levels")
|
145
|
+
@options_file = File.join(@directory, "options.json")
|
146
|
+
|
147
|
+
if File.exists?(@options_file)
|
148
|
+
@model_options = JSON.parse(Open.read(@options_file))
|
149
|
+
IndiferentHash.setup(@model_options)
|
150
|
+
end
|
142
151
|
end
|
143
152
|
|
144
153
|
if extract_features.nil?
|
@@ -169,6 +178,17 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
169
178
|
@eval_model = eval_model
|
170
179
|
end
|
171
180
|
|
181
|
+
if post_process.nil?
|
182
|
+
if @post_process_file && File.exists?(@post_process_file)
|
183
|
+
@post_process = __load_method @post_process_file
|
184
|
+
elsif @post_process_file_R && File.exists?(@post_process_file_R)
|
185
|
+
@post_process = Open.read(@post_process_file_R)
|
186
|
+
end
|
187
|
+
else
|
188
|
+
@post_process = post_process
|
189
|
+
end
|
190
|
+
|
191
|
+
|
172
192
|
if names.nil?
|
173
193
|
if @names_file && File.exists?(@names_file)
|
174
194
|
@names = Open.read(@names_file).split("\n")
|
@@ -240,18 +260,43 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
240
260
|
Open.write(@eval_model_file_R, eval_model)
|
241
261
|
end
|
242
262
|
|
263
|
+
case
|
264
|
+
when Proc === post_process
|
265
|
+
begin
|
266
|
+
Open.write(@post_process_file, post_process.source)
|
267
|
+
rescue
|
268
|
+
end
|
269
|
+
when String === post_process
|
270
|
+
Open.write(@post_process_file_R, post_process)
|
271
|
+
end
|
272
|
+
|
243
273
|
Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
|
244
274
|
Open.write(@names_file, @names * "\n" + "\n") if @names
|
275
|
+
Open.write(@options_file, @model_options.to_json) if @model_options
|
245
276
|
end
|
246
277
|
|
247
278
|
def train
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
279
|
+
begin
|
280
|
+
if @balance
|
281
|
+
@original_features = @features
|
282
|
+
@original_labels = @labels
|
283
|
+
self.balance_labels
|
284
|
+
end
|
285
|
+
|
286
|
+
case
|
287
|
+
when Proc === @train_model
|
288
|
+
self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
|
289
|
+
when String === @train_model
|
290
|
+
VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
|
291
|
+
end
|
292
|
+
ensure
|
293
|
+
if @balance
|
294
|
+
@features = @original_features
|
295
|
+
@labels = @original_labels
|
296
|
+
end
|
253
297
|
end
|
254
|
-
|
298
|
+
|
299
|
+
save_models if @directory
|
255
300
|
end
|
256
301
|
|
257
302
|
def run(code)
|
@@ -299,38 +344,6 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
299
344
|
result
|
300
345
|
end
|
301
346
|
|
302
|
-
#def cross_validation(folds = 10)
|
303
|
-
# saved_features = @features
|
304
|
-
# saved_labels = @labels
|
305
|
-
# seq = (0..features.length - 1).to_a
|
306
|
-
|
307
|
-
# chunk_size = features.length / folds
|
308
|
-
|
309
|
-
# acc = []
|
310
|
-
# folds.times do
|
311
|
-
# seq = seq.shuffle
|
312
|
-
# eval_chunk = seq[0..chunk_size]
|
313
|
-
# train_chunk = seq[chunk_size.. -1]
|
314
|
-
|
315
|
-
# eval_features = @features.values_at *eval_chunk
|
316
|
-
# eval_labels = @labels.values_at *eval_chunk
|
317
|
-
|
318
|
-
# @features = @features.values_at *train_chunk
|
319
|
-
# @labels = @labels.values_at *train_chunk
|
320
|
-
|
321
|
-
# train
|
322
|
-
# predictions = eval_list eval_features, false
|
323
|
-
|
324
|
-
# acc << predictions.zip(eval_labels).collect{|pred,lab| pred - lab < 0.5 ? 1 : 0}.inject(0){|acc,e| acc +=e} / chunk_size
|
325
|
-
|
326
|
-
# @features = saved_features
|
327
|
-
# @labels = saved_labels
|
328
|
-
# end
|
329
|
-
|
330
|
-
# acc
|
331
|
-
#end
|
332
|
-
#
|
333
|
-
|
334
347
|
def self.f1_metrics(test, predicted, good_label = nil)
|
335
348
|
tp, tn, fp, fn, pr, re, f1 = [0, 0, 0, 0, nil, nil, nil]
|
336
349
|
|
@@ -413,6 +426,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
413
426
|
@features = train_set
|
414
427
|
@labels = train_labels
|
415
428
|
|
429
|
+
self.reset_model if self.respond_to? :reset_model
|
416
430
|
self.train
|
417
431
|
predictions = self.eval_list test_set, false
|
418
432
|
|
@@ -437,6 +451,8 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
437
451
|
@features = orig_features
|
438
452
|
@labels = orig_labels
|
439
453
|
end unless folds == -1
|
454
|
+
|
455
|
+
self.reset_model if self.respond_to? :reset_model
|
440
456
|
self.train unless folds == 1
|
441
457
|
res
|
442
458
|
end
|
@@ -17,6 +17,17 @@ def load_model_and_tokenizer(task, checkpoint):
|
|
17
17
|
tokenizer = load_tokenizer(task, checkpoint)
|
18
18
|
return model, tokenizer
|
19
19
|
|
20
|
+
def load_model_and_tokenizer_from_directory(directory):
    """Load a model and tokenizer using the options stored in a model directory.

    Reads ``options.json`` inside ``directory`` and passes its ``"task"`` and
    ``"checkpoint"`` entries to ``load_model_and_tokenizer``.

    Returns whatever ``load_model_and_tokenizer`` returns.
    Raises ``OSError`` if the options file cannot be opened and ``KeyError``
    if either entry is missing.
    """
    import os
    import json
    options_file = os.path.join(directory, 'options.json')
    # json.load takes a file object; passing f.read() (a string) raised
    # AttributeError. The with-block also closes the file on errors.
    with open(options_file, "r") as f:
        options = json.load(f)
    task = options["task"]
    checkpoint = options["checkpoint"]
    return load_model_and_tokenizer(task, checkpoint)
|
30
|
+
|
20
31
|
#{{{ SIMPLE EVALUATE
|
21
32
|
|
22
33
|
def forward(model, features):
|
@@ -42,7 +53,7 @@ def load_tsv(tsv_file):
|
|
42
53
|
|
43
54
|
def tsv_dataset(tokenizer, tsv_file):
|
44
55
|
dataset = load_tsv(tsv_file)
|
45
|
-
tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True) , batched=True)
|
56
|
+
tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True, max_length=512) , batched=True)
|
46
57
|
return tokenized_dataset
|
47
58
|
|
48
59
|
def training_args(*args, **kwargs):
|
@@ -57,34 +68,34 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
|
57
68
|
tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
|
58
69
|
|
59
70
|
if (not class_weights == None):
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
71
|
+
import torch
|
72
|
+
from torch import nn
|
73
|
+
|
74
|
+
class WeightTrainer(Trainer):
|
75
|
+
def compute_loss(self, model, inputs, return_outputs=False):
|
76
|
+
labels = inputs.get("labels")
|
77
|
+
# forward pass
|
78
|
+
outputs = model(**inputs)
|
79
|
+
logits = outputs.get('logits')
|
80
|
+
# compute custom loss
|
81
|
+
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
|
82
|
+
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
83
|
+
return (loss, outputs) if return_outputs else loss
|
84
|
+
|
85
|
+
trainer = WeightTrainer(
|
86
|
+
model,
|
87
|
+
training_args,
|
88
|
+
train_dataset = tokenized_dataset["train"],
|
89
|
+
tokenizer = tokenizer
|
90
|
+
)
|
80
91
|
else:
|
81
92
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
93
|
+
trainer = Trainer(
|
94
|
+
model,
|
95
|
+
training_args,
|
96
|
+
train_dataset = tokenized_dataset["train"],
|
97
|
+
tokenizer = tokenizer
|
98
|
+
)
|
88
99
|
|
89
100
|
trainer.train()
|
90
101
|
|
@@ -3,6 +3,21 @@ require 'rbbt/vector/model/huggingface'
|
|
3
3
|
|
4
4
|
class TestHuggingface < Test::Unit::TestCase
|
5
5
|
|
6
|
+
# Options given to the constructor (here :class_labels) must be used during
# evaluation and must still apply when the model is reloaded from its
# directory as a plain VectorModel.
def test_options
  TmpFile.with_file do |dir|
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    task = "SequenceClassification"

    model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
    # Expected classes mirror the sibling tests that use this same checkpoint
    # and these same sentences; previously the results were only printed.
    assert_equal "bad", model.eval("This is dog")
    assert_equal "good", model.eval("This is cat")
    assert_equal %w(bad good), model.eval(["This is dog", "This is cat"])

    model = VectorModel.new dir
    assert_equal %w(bad good), model.eval(["This is dog", "This is cat"])
  end
end
|
20
|
+
|
6
21
|
def test_pipeline
|
7
22
|
require 'rbbt/util/python'
|
8
23
|
model = VectorModel.new
|
@@ -25,7 +40,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
25
40
|
|
26
41
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
27
42
|
|
28
|
-
model.class_labels = ["Bad", "Good"]
|
43
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
29
44
|
|
30
45
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
31
46
|
end
|
@@ -37,7 +52,8 @@ class TestHuggingface < Test::Unit::TestCase
|
|
37
52
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
38
53
|
|
39
54
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
40
|
-
|
55
|
+
|
56
|
+
model.model_options[:class_labels] = %w(Bad Good)
|
41
57
|
|
42
58
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
43
59
|
|
@@ -48,6 +64,9 @@ class TestHuggingface < Test::Unit::TestCase
|
|
48
64
|
model.train
|
49
65
|
|
50
66
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
67
|
+
|
68
|
+
model = VectorModel.new dir
|
69
|
+
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
51
70
|
end
|
52
71
|
end
|
53
72
|
|
@@ -55,7 +74,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
55
74
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
56
75
|
|
57
76
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint
|
58
|
-
model.class_labels = ["Bad", "Good"]
|
77
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
59
78
|
|
60
79
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
61
80
|
|
@@ -73,7 +92,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
73
92
|
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
|
74
93
|
|
75
94
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
76
|
-
model.class_labels = ["Bad", "Good"]
|
95
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
77
96
|
|
78
97
|
assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
|
79
98
|
|
@@ -84,15 +103,20 @@ class TestHuggingface < Test::Unit::TestCase
|
|
84
103
|
model.train
|
85
104
|
|
86
105
|
model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
|
87
|
-
model.class_labels = ["Bad", "Good"]
|
88
106
|
|
89
107
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
90
108
|
|
91
|
-
|
92
|
-
|
109
|
+
model_file = model.model_file
|
110
|
+
|
111
|
+
model = HuggingfaceModel.new "SequenceClassification", model_file
|
112
|
+
model.model_options[:class_labels] = ["Bad", "Good"]
|
93
113
|
|
94
114
|
assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
|
95
115
|
|
116
|
+
model = VectorModel.new dir
|
117
|
+
|
118
|
+
assert_equal "Good", model.eval("This is dog")
|
119
|
+
|
96
120
|
end
|
97
121
|
end
|
98
122
|
|
@@ -123,9 +147,7 @@ class TestHuggingface < Test::Unit::TestCase
|
|
123
147
|
model = HuggingfaceModel.new "MaskedLM", checkpoint
|
124
148
|
assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
|
125
149
|
reject{|v| v.empty?}.length
|
126
|
-
|
127
150
|
end
|
128
151
|
|
129
|
-
|
130
152
|
end
|
131
153
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rbbt-dm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.
|
4
|
+
version: 1.2.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Miguel Vazquez
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-02-
|
11
|
+
date: 2023-02-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rbbt-util
|
@@ -107,6 +107,7 @@ files:
|
|
107
107
|
- lib/rbbt/statistics/rank_product.rb
|
108
108
|
- lib/rbbt/tensorflow.rb
|
109
109
|
- lib/rbbt/vector/model.rb
|
110
|
+
- lib/rbbt/vector/model/huggingface.old.rb
|
110
111
|
- lib/rbbt/vector/model/huggingface.rb
|
111
112
|
- lib/rbbt/vector/model/random_forest.rb
|
112
113
|
- lib/rbbt/vector/model/spaCy.rb
|
@@ -143,7 +144,7 @@ files:
|
|
143
144
|
homepage: http://github.com/mikisvaz/rbbt-phgx
|
144
145
|
licenses: []
|
145
146
|
metadata: {}
|
146
|
-
post_install_message:
|
147
|
+
post_install_message:
|
147
148
|
rdoc_options: []
|
148
149
|
require_paths:
|
149
150
|
- lib
|
@@ -158,22 +159,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
158
159
|
- !ruby/object:Gem::Version
|
159
160
|
version: '0'
|
160
161
|
requirements: []
|
161
|
-
rubygems_version: 3.1.
|
162
|
-
signing_key:
|
162
|
+
rubygems_version: 3.1.2
|
163
|
+
signing_key:
|
163
164
|
specification_version: 4
|
164
165
|
summary: Data-mining and statistics
|
165
166
|
test_files:
|
166
|
-
- test/
|
167
|
-
- test/rbbt/statistics/test_fisher.rb
|
168
|
-
- test/rbbt/statistics/test_fdr.rb
|
169
|
-
- test/rbbt/statistics/test_random_walk.rb
|
170
|
-
- test/rbbt/test_ml_task.rb
|
167
|
+
- test/test_helper.rb
|
171
168
|
- test/rbbt/vector/test_model.rb
|
169
|
+
- test/rbbt/vector/model/test_huggingface.rb
|
172
170
|
- test/rbbt/vector/model/test_tensorflow.rb
|
173
171
|
- test/rbbt/vector/model/test_spaCy.rb
|
174
|
-
- test/rbbt/vector/model/test_huggingface.rb
|
175
172
|
- test/rbbt/vector/model/test_svm.rb
|
176
|
-
- test/rbbt/
|
177
|
-
- test/rbbt/
|
173
|
+
- test/rbbt/statistics/test_random_walk.rb
|
174
|
+
- test/rbbt/statistics/test_fisher.rb
|
175
|
+
- test/rbbt/statistics/test_fdr.rb
|
176
|
+
- test/rbbt/statistics/test_hypergeometric.rb
|
178
177
|
- test/rbbt/test_stan.rb
|
179
|
-
- test/
|
178
|
+
- test/rbbt/matrix/test_barcode.rb
|
179
|
+
- test/rbbt/test_ml_task.rb
|
180
|
+
- test/rbbt/network/test_paths.rb
|