rbbt-dm 1.1.63 → 1.2.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7939251873b8bf86dccb5260dc9f4b64abb7c24b5001f436a1b457d4ad2333af
4
- data.tar.gz: 476b2c7175b557bc287f928225d7da020d8e29a1e59b534d595ea6f75597f23f
3
+ metadata.gz: ab775c0224960820e5c62e294e6a183be49201da15710b66544762e1aaf97ebf
4
+ data.tar.gz: 8fffb47ba226f06d1f41a8893d085bdc12c33c021cf2f0152f4cc741db36e420
5
5
  SHA512:
6
- metadata.gz: bf503009cf5bc8d1ac239f1c8fe07288102560c3bc5324368187690b523f04f92ff94b74bf97c512cb0d4378985f9269177d7dda0897694a0eb2df62f369decc
7
- data.tar.gz: 13c8cd26daff6205de91f5c733563fe3d28af056869584882a9eb08d0d1ddcf5bde88bf38645dc82d9f04864631b3dccc93bf1a56945372c3b2a4bad618f3144
6
+ metadata.gz: 8be084156063cd93c7fe905bc4b6248dd376bbfcff8e650cbb03a4cc5c28f29dbcdaa2801895a3663067d7660e8bc2cf96682829519ebd6511a7a74cec021da0
7
+ data.tar.gz: 9c8570722319caf5afe60c90778d0b8517e70064030e0081968d919a314cfe1af90f63d13a7b8ce8dd56da6644db89916a729a4005f1e20745c8d4b45d50394c
@@ -0,0 +1,140 @@
1
+ require 'rbbt/vector/model'
2
+ require 'rbbt/util/python'
3
+
4
+ RbbtPython.init_rbbt
5
+
6
# VectorModel backed by a HuggingFace model + tokenizer pair that is driven
# through the Python helpers in the `rbbt_dm.huggingface` module.
#
# Post-processing of predictions depends on +task+:
# * "SequenceClassification" -> index of the highest logit per item, mapped
#   through +class_labels+ when those are set.
# * "MaskedLM" -> decoded best token(s) for each located mask position.
# * anything else -> the raw predictions, unchanged.
class HuggingfaceModel < VectorModel

  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels

  # Writes +elements+ (and optionally +labels+) as a TSV dataset file.
  #
  # Tabs and line breaks inside fields are collapsed to a single space;
  # otherwise they would shift columns or split a record in two and
  # misalign the labels.
  #
  # @param tsv_dataset_file [String] path of the file to write
  # @param elements [Array] texts (or arrays of fields), one row each
  # @param labels [Array, nil] labels aligned with +elements+; when nil only
  #   a "text" column is written
  # @return [String] +tsv_dataset_file+
  def tsv_dataset(tsv_dataset_file, elements, labels = nil)
    Open.write(tsv_dataset_file) do |ffile|
      if labels
        ffile.puts tsv_line("label", "text")
        elements.zip(labels).each do |element, label|
          ffile.puts tsv_line(label, element)
        end
      else
        ffile.puts tsv_line("text")
        elements.each { |element| ffile.puts tsv_line(element) }
      end
    end

    tsv_dataset_file
  end

  # Calls function +name+ from the `rbbt_dm.huggingface` Python module.
  def call_method(name, *args)
    RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
  end

  # Dataset file inside the model directory, or nil when there is none.
  def input_tsv_file
    File.join(@directory, 'dataset.tsv') if @directory
  end

  # Checkpoint output directory inside the model directory, or nil when there is none.
  def checkpoint_dir
    File.join(@directory, 'checkpoints') if @directory
  end

  # Runs the Python side on +elements+.
  #
  # With +labels+ it calls `train_model` on a TSV dataset. Without labels it
  # behaves according to the type of +elements+:
  # * an Array is sent to `predict_model` through a TSV dataset;
  # * a single element is sent to `eval_model` directly.
  # Intermediate files live under @directory when set, otherwise under a
  # temporary path.
  def run_model(elements, labels = nil)
    TmpFile.with_file do |tmpfile|
      tsv_file = input_tsv_file || File.join(tmpfile, 'dataset.tsv')
      output_dir = checkpoint_dir || File.join(tmpfile, 'checkpoints')

      Open.mkdir File.dirname(output_dir)
      Open.mkdir File.dirname(tsv_file)

      if labels
        training_args = call_method(:training_args, output_dir)
        call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
      elsif Array === elements
        training_args = call_method(:training_args, output_dir)
        call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
      else
        call_method(:eval_model, @model, @tokenizer, [elements], @locate_tokens)
      end
    end
  end

  # @param task [String] e.g. "SequenceClassification" or "MaskedLM"
  # @param initial_checkpoint [String, nil] checkpoint used when no saved
  #   model exists yet in the model directory
  # @param args remaining VectorModel#initialize arguments (directory, ...)
  def initialize(task, initial_checkpoint = nil, *args)
    super(*args)
    @task = task

    # Resume from a previously saved model in the directory when there is one
    @checkpoint = model_file && File.exist?(model_file) ? model_file : initial_checkpoint

    @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)

    @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"

    train_model do |file, elements, labels|
      run_model(elements, labels)

      @model.save_pretrained(file) if file
      @tokenizer.save_pretrained(file) if file
    end

    eval_model do |file, elements|
      run_model(elements)
    end

    post_process do |result|
      single = false
      token_positions = nil

      if result.respond_to?(:predictions)
        predictions = result.predictions
      elsif result["token_positions"]
        predictions = result["result"].predictions
        token_positions = result["token_positions"]
      else
        single = true
        predictions = result["logits"]
      end

      processed = case @task
                  when "SequenceClassification"
                    RbbtPython.collect(predictions) do |logits|
                      best_class_label(RbbtPython.numpy2ruby(logits))
                    end
                  when "MaskedLM"
                    decode_masked_predictions(predictions, token_positions)
                  else
                    # Unknown task: hand back the predictions untouched
                    predictions
                  end

      single ? processed.first : processed
    end
  end

  private

  # Joins +fields+ (flattened) into one TSV line, replacing embedded tabs and
  # line breaks with a space so each record stays on one line with stable columns.
  def tsv_line(*fields)
    fields.flatten.collect { |field| field.to_s.gsub(/[\t\r\n]+/, " ") } * "\t"
  end

  # Index of the highest logit (first one on ties), mapped through
  # @class_labels when set.
  def best_class_label(logits)
    best_class = logits.index(logits.max)
    @class_labels ? @class_labels[best_class] : best_class
  end

  # Decodes, for every item, the best token at each located mask position.
  # Multiple positions of one mask are joined with "|". The result per item is
  # the list of masks when @locate_tokens is an Array, otherwise the first one.
  def decode_masked_predictions(predictions, token_positions)
    all_token_positions = token_positions.to_a

    item_index = 0
    RbbtPython.collect(predictions) do |item_logits|
      item_token_positions = all_token_positions[item_index]
      item_index += 1

      item_logits = RbbtPython.numpy2ruby(item_logits)
      item_masks = item_token_positions.collect do |positions|
        item_logits.values_at(*positions).collect do |logits|
          @tokenizer.decode(logits.index(logits.max))
        end * "|"
      end

      Array === @locate_tokens ? item_masks : item_masks.first
    end
  end
end
137
+
138
# Script-mode guard: executing this file directly currently runs nothing.
if $PROGRAM_NAME == __FILE__
end
@@ -2,9 +2,30 @@ require 'rbbt/util/R'
2
2
  require 'rbbt/vector/model/util'
3
3
 
4
4
  class VectorModel
5
- attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model
5
+ attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
6
6
  attr_accessor :features, :names, :labels, :factor_levels
7
7
 
8
# Block-or-reader accessors for the model's pluggable steps. Given a block,
# each one stores it and returns it; without a block it returns the step
# currently stored (nil when none was set).
def extract_features(&block)
  return @extract_features unless block
  @extract_features = block
end

def train_model(&block)
  return @train_model unless block
  @train_model = block
end

def eval_model(&block)
  return @eval_model unless block
  @eval_model = block
end

def post_process(&block)
  return @post_process unless block
  @post_process = block
end
27
+
28
+
8
29
  def self.R_run(model_file, features, labels, code, names = nil, factor_levels = nil)
9
30
  TmpFile.with_file do |feature_file|
10
31
  Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
@@ -101,25 +122,27 @@ cat(paste(label, sep="\\n", collapse="\\n"));
101
122
 
102
123
  def __load_method(file)
103
124
  code = Open.read(file)
104
- code.sub!(/.*Proc\.new/, "Proc.new")
125
+ code.sub!(/.*(\sdo\b|{)/, 'Proc.new\1')
105
126
  instance_eval code, file
106
127
  end
107
128
 
108
- def initialize(directory, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
129
+ def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
109
130
  @directory = directory
110
- FileUtils.mkdir_p @directory unless File.exists? @directory
111
-
112
- @model_file = File.join(@directory, "model")
113
- @extract_features_file = File.join(@directory, "features")
114
- @train_model_file = File.join(@directory, "train_model")
115
- @eval_model_file = File.join(@directory, "eval_model")
116
- @train_model_file_R = File.join(@directory, "train_model.R")
117
- @eval_model_file_R = File.join(@directory, "eval_model.R")
118
- @names_file = File.join(@directory, "feature_names")
119
- @levels_file = File.join(@directory, "levels")
131
+ if @directory
132
+ FileUtils.mkdir_p @directory unless File.exists?(@directory)
133
+
134
+ @model_file = File.join(@directory, "model")
135
+ @extract_features_file = File.join(@directory, "features")
136
+ @train_model_file = File.join(@directory, "train_model")
137
+ @eval_model_file = File.join(@directory, "eval_model")
138
+ @train_model_file_R = File.join(@directory, "train_model.R")
139
+ @eval_model_file_R = File.join(@directory, "eval_model.R")
140
+ @names_file = File.join(@directory, "feature_names")
141
+ @levels_file = File.join(@directory, "levels")
142
+ end
120
143
 
121
144
  if extract_features.nil?
122
- if File.exists?(@extract_features_file)
145
+ if @extract_features_file && File.exists?(@extract_features_file)
123
146
  @extract_features = __load_method @extract_features_file
124
147
  end
125
148
  else
@@ -127,9 +150,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
127
150
  end
128
151
 
129
152
  if train_model.nil?
130
- if File.exists?(@train_model_file)
153
+ if @train_model_file && File.exists?(@train_model_file)
131
154
  @train_model = __load_method @train_model_file
132
- elsif File.exists?(@train_model_file_R)
155
+ elsif @train_model_file_R && File.exists?(@train_model_file_R)
133
156
  @train_model = Open.read(@train_model_file_R)
134
157
  end
135
158
  else
@@ -137,9 +160,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
137
160
  end
138
161
 
139
162
  if eval_model.nil?
140
- if File.exists?(@eval_model_file)
163
+ if @eval_model_file && File.exists?(@eval_model_file)
141
164
  @eval_model = __load_method @eval_model_file
142
- elsif File.exists?(@eval_model_file_R)
165
+ elsif @eval_model_file_R && File.exists?(@eval_model_file_R)
143
166
  @eval_model = Open.read(@eval_model_file_R)
144
167
  end
145
168
  else
@@ -147,7 +170,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
147
170
  end
148
171
 
149
172
  if names.nil?
150
- if File.exists?(@names_file)
173
+ if @names_file && File.exists?(@names_file)
151
174
  @names = Open.read(@names_file).split("\n")
152
175
  end
153
176
  else
@@ -155,10 +178,10 @@ cat(paste(label, sep="\\n", collapse="\\n"));
155
178
  end
156
179
 
157
180
  if factor_levels.nil?
158
- if File.exists?(@levels_file)
181
+ if @levels_file && File.exists?(@levels_file)
159
182
  @factor_levels = YAML.load(Open.read(@levels_file))
160
183
  end
161
- if File.exists?(@model_file + '.factor_levels')
184
+ if @model_file && File.exists?(@model_file + '.factor_levels')
162
185
  @factor_levels = TSV.open(@model_file + '.factor_levels')
163
186
  end
164
187
  else
@@ -175,7 +198,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
175
198
  end
176
199
 
177
200
  def add(element, label = nil)
178
- features = @extract_features ? extract_features.call(element) : element
201
+ features = @extract_features ? self.instance_exec(element, &@extract_features) : element
179
202
  @features << features
180
203
  @labels << label
181
204
  end
@@ -186,7 +209,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
186
209
  add(elem, label)
187
210
  end
188
211
  else
189
- features = @extract_features.call(nil, elements)
212
+ features = self.instance_exec(nil, elements, &@extract_features)
190
213
  @features.concat features
191
214
  @labels.concat labels if labels
192
215
  end
@@ -223,9 +246,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
223
246
 
224
247
  def train
225
248
  case
226
- when Proc === train_model
227
- train_model.call(@model_file, @features, @labels, @names, @factor_levels)
228
- when String === train_model
249
+ when Proc === @train_model
250
+ self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
251
+ when String === @train_model
229
252
  VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
230
253
  end
231
254
  save_models
@@ -236,32 +259,44 @@ cat(paste(label, sep="\\n", collapse="\\n"));
236
259
  end
237
260
 
238
261
  def eval(element)
239
- case
240
- when Proc === @eval_model
241
- @eval_model.call(@model_file, @extract_features.call(element), false, nil, @names, @factor_levels)
242
- when String === @eval_model
243
- VectorModel.R_eval(@model_file, @extract_features.call(element), false, eval_model, @names, @factor_levels)
244
- end
262
+ features = @extract_features.nil? ? element : self.instance_exec(element, &@extract_features)
263
+
264
+ result = case
265
+ when Proc === @eval_model
266
+ self.instance_exec(@model_file, features, false, nil, @names, @factor_levels, &@eval_model)
267
+ when String === @eval_model
268
+ VectorModel.R_eval(@model_file, features, false, eval_model, @names, @factor_levels)
269
+ else
270
+ raise "No @eval_model function or R script"
271
+ end
272
+
273
+ result = self.instance_exec(result, &@post_process) if Proc === @post_process
274
+
275
+ result
245
276
  end
246
277
 
247
278
  def eval_list(elements, extract = true)
248
279
 
249
280
  if extract && ! @extract_features.nil?
250
281
  features = if @extract_features.arity == 1
251
- elements.collect{|element| @extract_features.call(element) }
282
+ elements.collect{|element| self.instance_exec(element, &@extract_features) }
252
283
  else
253
- @extract_features.call(nil, elements)
284
+ self.instance_exec(nil, elements, &@extract_features)
254
285
  end
255
286
  else
256
287
  features = elements
257
288
  end
258
289
 
259
- case
260
- when Proc === eval_model
261
- eval_model.call(@model_file, features, true, nil, @names, @factor_levels)
262
- when String === eval_model
263
- VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
264
- end
290
+ result = case
291
+ when Proc === eval_model
292
+ self.instance_exec(@model_file, features, true, nil, @names, @factor_levels, &@eval_model)
293
+ when String === eval_model
294
+ VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
295
+ end
296
+
297
+ result = self.instance_exec(result, &@post_process) if Proc === @post_process
298
+
299
+ result
265
300
  end
266
301
 
267
302
  #def cross_validation(folds = 10)
@@ -0,0 +1,116 @@
1
+ require File.join(File.expand_path(File.dirname(__FILE__)),'../../..', 'test_helper.rb')
2
+ require 'rbbt/vector/model/huggingface'
3
+
4
# Integration tests for HuggingfaceModel. They load pretrained checkpoints by
# name (presumably resolved through the HuggingFace hub) and need the Python
# `rbbt_dm.huggingface` helpers available.
class TestHuggingface < Test::Unit::TestCase

  SST_CHECKPOINT = "distilbert-base-uncased-finetuned-sst-2-english"

  # SequenceClassification model from +checkpoint+, stored under +dir+ when
  # given, labelled Bad/Good.
  def build_sst_model(dir = nil, checkpoint: SST_CHECKPOINT)
    model = HuggingfaceModel.new("SequenceClassification", checkpoint, *[dir].compact)
    model.class_labels = ["Bad", "Good"]
    model
  end

  # Labels predicted for the two fixed probe sentences.
  def sst_probe(model)
    model.eval(["This is dog", "This is cat"])
  end

  def test_sst_eval
    TmpFile.with_file do |dir|
      model = build_sst_model(dir)
      assert_equal ["Bad", "Good"], sst_probe(model)
    end
  end

  def test_sst_train
    TmpFile.with_file do |dir|
      model = build_sst_model(dir)
      assert_equal ["Bad", "Good"], sst_probe(model)

      100.times { model.add "Dog is good", 1 }
      model.train

      assert_equal ["Good", "Good"], sst_probe(model)
    end
  end

  def test_sst_train_no_save
    model = build_sst_model
    assert_equal ["Bad", "Good"], sst_probe(model)

    100.times { model.add "Dog is good", 1 }
    model.train

    assert_equal ["Good", "Good"], sst_probe(model)
  end

  def test_sst_train_save_and_load
    TmpFile.with_file do |dir|
      model = build_sst_model(dir)
      assert_equal ["Bad", "Good"], sst_probe(model)

      100.times { model.add "Dog is good", 1 }
      model.train

      # Reopening the directory picks up the saved model file
      model = build_sst_model(dir)
      assert_equal ["Good", "Good"], sst_probe(model)

      # Using the saved model file directly as the checkpoint
      model = build_sst_model(checkpoint: model.model_file)
      assert_equal ["Good", "Good"], sst_probe(model)
    end
  end

  def test_sst_stress_test
    TmpFile.with_file do |dir|
      model = HuggingfaceModel.new "SequenceClassification", SST_CHECKPOINT, dir

      100.times do
        model.add "Dog is good", 1
        model.add "Cat is bad", 0
      end

      Misc.benchmark(10) { model.train }

      Misc.benchmark(1000) do
        model.eval(["This is good", "This is terrible", "This is dog", "This is cat", "Very different stuff", "Dog is bad", "Cat is good"])
      end
    end
  end

  def test_mask_eval
    model = HuggingfaceModel.new "MaskedLM", "bert-base-uncased"
    masked = model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."])
    assert_equal 3, masked.reject(&:empty?).length
  end

end
116
+
@@ -282,6 +282,7 @@ cat(label, file="#{results}");
282
282
  model.add features, label
283
283
  end
284
284
 
285
+ iii model.eval("1;1;1")
285
286
  assert model.eval("1;1;1").to_f > 0.5
286
287
  assert model.eval("0;0;0").to_f < 0.5
287
288
  end
@@ -509,5 +510,28 @@ label = predict(model, features);
509
510
  end
510
511
  end
511
512
 
513
def test_python
  require 'rbbt/util/python'

  TmpFile.with_file do |dir|
    model = VectorModel.new dir

    # Scores each element 1 when it is >= a random threshold drawn with torch,
    # so a larger element can never score lower than a smaller one.
    model.eval_model do |file, elements|
      elements = [elements] unless Array === elements
      RbbtPython.binding_run do
        pyimport :torch
        threshold = torch.rand(1).numpy[0].to_f
        elements.collect { |e| e >= threshold ? 1 : 0 }
      end
    end

    high, low = model.eval [0.9, 0.1]
    assert low <= high

    # NOTE(review): the model is reloaded but never evaluated again, so this
    # assertion only re-checks the values computed above.
    model = VectorModel.new dir
    assert low <= high
  end
end
535
+
512
536
 
513
537
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rbbt-dm
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.63
4
+ version: 1.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Miguel Vazquez
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-12-15 00:00:00.000000000 Z
11
+ date: 2023-02-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rbbt-util
@@ -107,6 +107,7 @@ files:
107
107
  - lib/rbbt/statistics/rank_product.rb
108
108
  - lib/rbbt/tensorflow.rb
109
109
  - lib/rbbt/vector/model.rb
110
+ - lib/rbbt/vector/model/huggingface.rb
110
111
  - lib/rbbt/vector/model/random_forest.rb
111
112
  - lib/rbbt/vector/model/spaCy.rb
112
113
  - lib/rbbt/vector/model/svm.rb
@@ -131,6 +132,7 @@ files:
131
132
  - test/rbbt/statistics/test_random_walk.rb
132
133
  - test/rbbt/test_ml_task.rb
133
134
  - test/rbbt/test_stan.rb
135
+ - test/rbbt/vector/model/test_huggingface.rb
134
136
  - test/rbbt/vector/model/test_spaCy.rb
135
137
  - test/rbbt/vector/model/test_svm.rb
136
138
  - test/rbbt/vector/model/test_tensorflow.rb
@@ -154,21 +156,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
154
156
  - !ruby/object:Gem::Version
155
157
  version: '0'
156
158
  requirements: []
157
- rubygems_version: 3.1.4
159
+ rubygems_version: 3.1.2
158
160
  signing_key:
159
161
  specification_version: 4
160
162
  summary: Data-mining and statistics
161
163
  test_files:
162
- - test/rbbt/network/test_paths.rb
163
- - test/rbbt/matrix/test_barcode.rb
164
+ - test/test_helper.rb
165
+ - test/rbbt/vector/test_model.rb
166
+ - test/rbbt/vector/model/test_huggingface.rb
167
+ - test/rbbt/vector/model/test_tensorflow.rb
168
+ - test/rbbt/vector/model/test_spaCy.rb
169
+ - test/rbbt/vector/model/test_svm.rb
164
170
  - test/rbbt/statistics/test_random_walk.rb
165
171
  - test/rbbt/statistics/test_fisher.rb
166
172
  - test/rbbt/statistics/test_fdr.rb
167
173
  - test/rbbt/statistics/test_hypergeometric.rb
168
- - test/rbbt/test_ml_task.rb
169
- - test/rbbt/vector/test_model.rb
170
- - test/rbbt/vector/model/test_spaCy.rb
171
- - test/rbbt/vector/model/test_tensorflow.rb
172
- - test/rbbt/vector/model/test_svm.rb
173
174
  - test/rbbt/test_stan.rb
174
- - test/test_helper.rb
175
+ - test/rbbt/matrix/test_barcode.rb
176
+ - test/rbbt/test_ml_task.rb
177
+ - test/rbbt/network/test_paths.rb