rbbt-dm 1.1.63 → 1.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7939251873b8bf86dccb5260dc9f4b64abb7c24b5001f436a1b457d4ad2333af
4
- data.tar.gz: 476b2c7175b557bc287f928225d7da020d8e29a1e59b534d595ea6f75597f23f
3
+ metadata.gz: ab775c0224960820e5c62e294e6a183be49201da15710b66544762e1aaf97ebf
4
+ data.tar.gz: 8fffb47ba226f06d1f41a8893d085bdc12c33c021cf2f0152f4cc741db36e420
5
5
  SHA512:
6
- metadata.gz: bf503009cf5bc8d1ac239f1c8fe07288102560c3bc5324368187690b523f04f92ff94b74bf97c512cb0d4378985f9269177d7dda0897694a0eb2df62f369decc
7
- data.tar.gz: 13c8cd26daff6205de91f5c733563fe3d28af056869584882a9eb08d0d1ddcf5bde88bf38645dc82d9f04864631b3dccc93bf1a56945372c3b2a4bad618f3144
6
+ metadata.gz: 8be084156063cd93c7fe905bc4b6248dd376bbfcff8e650cbb03a4cc5c28f29dbcdaa2801895a3663067d7660e8bc2cf96682829519ebd6511a7a74cec021da0
7
+ data.tar.gz: 9c8570722319caf5afe60c90778d0b8517e70064030e0081968d919a314cfe1af90f63d13a7b8ce8dd56da6644db89916a729a4005f1e20745c8d4b45d50394c
@@ -0,0 +1,140 @@
1
+ require 'rbbt/vector/model'
2
+ require 'rbbt/util/python'
3
+
4
+ RbbtPython.init_rbbt
5
+
6
+ class HuggingfaceModel < VectorModel
7
+
8
+ attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
9
+
10
+ def tsv_dataset(tsv_dataset_file, elements, labels = nil)
11
+
12
+ if labels
13
+ Open.write(tsv_dataset_file) do |ffile|
14
+ ffile.puts ["label", "text"].flatten * "\t"
15
+ elements.zip(labels).each do |element,label|
16
+ ffile.puts [label, element].flatten * "\t"
17
+ end
18
+ end
19
+ else
20
+ Open.write(tsv_dataset_file) do |ffile|
21
+ ffile.puts ["text"].flatten * "\t"
22
+ elements.each{|element| ffile.puts element }
23
+ end
24
+ end
25
+
26
+ tsv_dataset_file
27
+ end
28
+
29
+ def call_method(name, *args)
30
+ RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
31
+ end
32
+
33
+ def input_tsv_file
34
+ File.join(@directory, 'dataset.tsv') if @directory
35
+ end
36
+
37
+ def checkpoint_dir
38
+ File.join(@directory, 'checkpoints') if @directory
39
+ end
40
+
41
+ def run_model(elements, labels = nil)
42
+ TmpFile.with_file do |tmpfile|
43
+ tsv_file = input_tsv_file || File.join(tmpfile, 'dataset.tsv')
44
+ output_dir = checkpoint_dir || File.join(tmpfile, 'checkpoints')
45
+
46
+ Open.mkdir File.dirname(output_dir)
47
+ Open.mkdir File.dirname(tsv_file)
48
+
49
+ if labels
50
+ training_args = call_method(:training_args, output_dir)
51
+ call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
52
+ else
53
+ if Array === elements
54
+ training_args = call_method(:training_args, output_dir)
55
+ call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
56
+ else
57
+ call_method(:eval_model, @model, @tokenizer, [elements], @locate_tokens)
58
+ end
59
+ end
60
+ end
61
+ end
62
+
63
+ def initialize(task, initial_checkpoint = nil, *args)
64
+ super(*args)
65
+ @task = task
66
+
67
+ @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
68
+
69
+ @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
70
+
71
+ @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
72
+
73
+ train_model do |file,elements,labels|
74
+ run_model(elements, labels)
75
+
76
+ @model.save_pretrained(file) if file
77
+ @tokenizer.save_pretrained(file) if file
78
+ end
79
+
80
+ eval_model do |file,elements|
81
+ run_model(elements)
82
+ end
83
+
84
+ post_process do |result|
85
+ if result.respond_to?(:predictions)
86
+ single = false
87
+ predictions = result.predictions
88
+ elsif result["token_positions"]
89
+ predictions = result["result"].predictions
90
+ token_positions = result["token_positions"]
91
+ else
92
+ single = true
93
+ predictions = result["logits"]
94
+ end
95
+
96
+ result = case @task
97
+ when "SequenceClassification"
98
+ RbbtPython.collect(predictions) do |logits|
99
+ logits = RbbtPython.numpy2ruby logits
100
+ best_class = logits.index logits.max
101
+ best_class = @class_labels[best_class] if @class_labels
102
+ best_class
103
+ end
104
+ when "MaskedLM"
105
+ all_token_positions = token_positions.to_a
106
+
107
+ i = 0
108
+ RbbtPython.collect(predictions) do |item_logits|
109
+ item_token_positions = all_token_positions[i]
110
+ i += 1
111
+
112
+ item_logits = RbbtPython.numpy2ruby(item_logits)
113
+ item_masks = item_token_positions.collect do |token_positions|
114
+
115
+ best = item_logits.values_at(*token_positions).collect do |logits|
116
+ best_token, best_score = nil
117
+ logits.each_with_index do |v,i|
118
+ if best_score.nil? || v > best_score
119
+ best_token, best_score = i, v
120
+ end
121
+ end
122
+ best_token
123
+ end
124
+
125
+ best.collect{|b| @tokenizer.decode(b) } * "|"
126
+ end
127
+ Array === @locate_tokens ? item_masks : item_masks.first
128
+ end
129
+ else
130
+ logits
131
+ end
132
+
133
+ single ? result.first : result
134
+ end
135
+ end
136
+ end
137
+
138
+ if __FILE__ == $0
139
+
140
+ end
@@ -2,9 +2,30 @@ require 'rbbt/util/R'
2
2
  require 'rbbt/vector/model/util'
3
3
 
4
4
  class VectorModel
5
- attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model
5
+ attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
6
6
  attr_accessor :features, :names, :labels, :factor_levels
7
7
 
8
+ def extract_features(&block)
9
+ @extract_features = block if block_given?
10
+ @extract_features
11
+ end
12
+
13
+ def train_model(&block)
14
+ @train_model = block if block_given?
15
+ @train_model
16
+ end
17
+
18
+ def eval_model(&block)
19
+ @eval_model = block if block_given?
20
+ @eval_model
21
+ end
22
+
23
+ def post_process(&block)
24
+ @post_process = block if block_given?
25
+ @post_process
26
+ end
27
+
28
+
8
29
  def self.R_run(model_file, features, labels, code, names = nil, factor_levels = nil)
9
30
  TmpFile.with_file do |feature_file|
10
31
  Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
@@ -101,25 +122,27 @@ cat(paste(label, sep="\\n", collapse="\\n"));
101
122
 
102
123
  def __load_method(file)
103
124
  code = Open.read(file)
104
- code.sub!(/.*Proc\.new/, "Proc.new")
125
+ code.sub!(/.*(\sdo\b|{)/, 'Proc.new\1')
105
126
  instance_eval code, file
106
127
  end
107
128
 
108
- def initialize(directory, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
129
+ def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
109
130
  @directory = directory
110
- FileUtils.mkdir_p @directory unless File.exists? @directory
111
-
112
- @model_file = File.join(@directory, "model")
113
- @extract_features_file = File.join(@directory, "features")
114
- @train_model_file = File.join(@directory, "train_model")
115
- @eval_model_file = File.join(@directory, "eval_model")
116
- @train_model_file_R = File.join(@directory, "train_model.R")
117
- @eval_model_file_R = File.join(@directory, "eval_model.R")
118
- @names_file = File.join(@directory, "feature_names")
119
- @levels_file = File.join(@directory, "levels")
131
+ if @directory
132
+ FileUtils.mkdir_p @directory unless File.exists?(@directory)
133
+
134
+ @model_file = File.join(@directory, "model")
135
+ @extract_features_file = File.join(@directory, "features")
136
+ @train_model_file = File.join(@directory, "train_model")
137
+ @eval_model_file = File.join(@directory, "eval_model")
138
+ @train_model_file_R = File.join(@directory, "train_model.R")
139
+ @eval_model_file_R = File.join(@directory, "eval_model.R")
140
+ @names_file = File.join(@directory, "feature_names")
141
+ @levels_file = File.join(@directory, "levels")
142
+ end
120
143
 
121
144
  if extract_features.nil?
122
- if File.exists?(@extract_features_file)
145
+ if @extract_features_file && File.exists?(@extract_features_file)
123
146
  @extract_features = __load_method @extract_features_file
124
147
  end
125
148
  else
@@ -127,9 +150,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
127
150
  end
128
151
 
129
152
  if train_model.nil?
130
- if File.exists?(@train_model_file)
153
+ if @train_model_file && File.exists?(@train_model_file)
131
154
  @train_model = __load_method @train_model_file
132
- elsif File.exists?(@train_model_file_R)
155
+ elsif @train_model_file_R && File.exists?(@train_model_file_R)
133
156
  @train_model = Open.read(@train_model_file_R)
134
157
  end
135
158
  else
@@ -137,9 +160,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
137
160
  end
138
161
 
139
162
  if eval_model.nil?
140
- if File.exists?(@eval_model_file)
163
+ if @eval_model_file && File.exists?(@eval_model_file)
141
164
  @eval_model = __load_method @eval_model_file
142
- elsif File.exists?(@eval_model_file_R)
165
+ elsif @eval_model_file_R && File.exists?(@eval_model_file_R)
143
166
  @eval_model = Open.read(@eval_model_file_R)
144
167
  end
145
168
  else
@@ -147,7 +170,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
147
170
  end
148
171
 
149
172
  if names.nil?
150
- if File.exists?(@names_file)
173
+ if @names_file && File.exists?(@names_file)
151
174
  @names = Open.read(@names_file).split("\n")
152
175
  end
153
176
  else
@@ -155,10 +178,10 @@ cat(paste(label, sep="\\n", collapse="\\n"));
155
178
  end
156
179
 
157
180
  if factor_levels.nil?
158
- if File.exists?(@levels_file)
181
+ if @levels_file && File.exists?(@levels_file)
159
182
  @factor_levels = YAML.load(Open.read(@levels_file))
160
183
  end
161
- if File.exists?(@model_file + '.factor_levels')
184
+ if @model_file && File.exists?(@model_file + '.factor_levels')
162
185
  @factor_levels = TSV.open(@model_file + '.factor_levels')
163
186
  end
164
187
  else
@@ -175,7 +198,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
175
198
  end
176
199
 
177
200
  def add(element, label = nil)
178
- features = @extract_features ? extract_features.call(element) : element
201
+ features = @extract_features ? self.instance_exec(element, &@extract_features) : element
179
202
  @features << features
180
203
  @labels << label
181
204
  end
@@ -186,7 +209,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
186
209
  add(elem, label)
187
210
  end
188
211
  else
189
- features = @extract_features.call(nil, elements)
212
+ features = self.instance_exec(nil, elements, &@extract_features)
190
213
  @features.concat features
191
214
  @labels.concat labels if labels
192
215
  end
@@ -223,9 +246,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
223
246
 
224
247
  def train
225
248
  case
226
- when Proc === train_model
227
- train_model.call(@model_file, @features, @labels, @names, @factor_levels)
228
- when String === train_model
249
+ when Proc === @train_model
250
+ self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
251
+ when String === @train_model
229
252
  VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
230
253
  end
231
254
  save_models
@@ -236,32 +259,44 @@ cat(paste(label, sep="\\n", collapse="\\n"));
236
259
  end
237
260
 
238
261
  def eval(element)
239
- case
240
- when Proc === @eval_model
241
- @eval_model.call(@model_file, @extract_features.call(element), false, nil, @names, @factor_levels)
242
- when String === @eval_model
243
- VectorModel.R_eval(@model_file, @extract_features.call(element), false, eval_model, @names, @factor_levels)
244
- end
262
+ features = @extract_features.nil? ? element : self.instance_exec(element, &@extract_features)
263
+
264
+ result = case
265
+ when Proc === @eval_model
266
+ self.instance_exec(@model_file, features, false, nil, @names, @factor_levels, &@eval_model)
267
+ when String === @eval_model
268
+ VectorModel.R_eval(@model_file, features, false, eval_model, @names, @factor_levels)
269
+ else
270
+ raise "No @eval_model function or R script"
271
+ end
272
+
273
+ result = self.instance_exec(result, &@post_process) if Proc === @post_process
274
+
275
+ result
245
276
  end
246
277
 
247
278
  def eval_list(elements, extract = true)
248
279
 
249
280
  if extract && ! @extract_features.nil?
250
281
  features = if @extract_features.arity == 1
251
- elements.collect{|element| @extract_features.call(element) }
282
+ elements.collect{|element| self.instance_exec(element, &@extract_features) }
252
283
  else
253
- @extract_features.call(nil, elements)
284
+ self.instance_exec(nil, elements, &@extract_features)
254
285
  end
255
286
  else
256
287
  features = elements
257
288
  end
258
289
 
259
- case
260
- when Proc === eval_model
261
- eval_model.call(@model_file, features, true, nil, @names, @factor_levels)
262
- when String === eval_model
263
- VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
264
- end
290
+ result = case
291
+ when Proc === eval_model
292
+ self.instance_exec(@model_file, features, true, nil, @names, @factor_levels, &@eval_model)
293
+ when String === eval_model
294
+ VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
295
+ end
296
+
297
+ result = self.instance_exec(result, &@post_process) if Proc === @post_process
298
+
299
+ result
265
300
  end
266
301
 
267
302
  #def cross_validation(folds = 10)
@@ -0,0 +1,116 @@
1
+ require File.join(File.expand_path(File.dirname(__FILE__)),'../../..', 'test_helper.rb')
2
+ require 'rbbt/vector/model/huggingface'
3
+
4
+ class TestHuggingface < Test::Unit::TestCase
5
+
6
+ def test_sst_eval
7
+ TmpFile.with_file do |dir|
8
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
9
+
10
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
11
+
12
+ model.class_labels = ["Bad", "Good"]
13
+
14
+ assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
15
+
16
+ end
17
+ end
18
+
19
+
20
+ def test_sst_train
21
+ TmpFile.with_file do |dir|
22
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
23
+
24
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
25
+ model.class_labels = ["Bad", "Good"]
26
+
27
+ assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
28
+
29
+ 100.times do
30
+ model.add "Dog is good", 1
31
+ end
32
+
33
+ model.train
34
+
35
+ assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
36
+ end
37
+ end
38
+
39
+ def test_sst_train_no_save
40
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
41
+
42
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint
43
+ model.class_labels = ["Bad", "Good"]
44
+
45
+ assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
46
+
47
+ 100.times do
48
+ model.add "Dog is good", 1
49
+ end
50
+
51
+ model.train
52
+
53
+ assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
54
+ end
55
+
56
+ def test_sst_train_save_and_load
57
+ TmpFile.with_file do |dir|
58
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
59
+
60
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
61
+ model.class_labels = ["Bad", "Good"]
62
+
63
+ assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
64
+
65
+ 100.times do
66
+ model.add "Dog is good", 1
67
+ end
68
+
69
+ model.train
70
+
71
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
72
+ model.class_labels = ["Bad", "Good"]
73
+
74
+ assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
75
+
76
+ model = HuggingfaceModel.new "SequenceClassification", model.model_file
77
+ model.class_labels = ["Bad", "Good"]
78
+
79
+ assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
80
+
81
+ end
82
+ end
83
+
84
+ def test_sst_stress_test
85
+ TmpFile.with_file do |dir|
86
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
87
+
88
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
89
+
90
+ 100.times do
91
+ model.add "Dog is good", 1
92
+ model.add "Cat is bad", 0
93
+ end
94
+
95
+ Misc.benchmark(10) do
96
+ model.train
97
+ end
98
+
99
+ Misc.benchmark 1000 do
100
+ model.eval(["This is good", "This is terrible", "This is dog", "This is cat", "Very different stuff", "Dog is bad", "Cat is good"])
101
+ end
102
+ end
103
+ end
104
+
105
+ def test_mask_eval
106
+ checkpoint = "bert-base-uncased"
107
+
108
+ model = HuggingfaceModel.new "MaskedLM", checkpoint
109
+ assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
110
+ reject{|v| v.empty?}.length
111
+
112
+ end
113
+
114
+
115
+ end
116
+
@@ -282,6 +282,7 @@ cat(label, file="#{results}");
282
282
  model.add features, label
283
283
  end
284
284
 
285
+ iii model.eval("1;1;1")
285
286
  assert model.eval("1;1;1").to_f > 0.5
286
287
  assert model.eval("0;0;0").to_f < 0.5
287
288
  end
@@ -509,5 +510,28 @@ label = predict(model, features);
509
510
  end
510
511
  end
511
512
 
513
+ def test_python
514
+ require 'rbbt/util/python'
515
+ TmpFile.with_file do |dir|
516
+ model = VectorModel.new dir
517
+
518
+ model.eval_model do |file, elements|
519
+ elements = [elements] unless Array === elements
520
+ RbbtPython.binding_run do
521
+ pyimport :torch
522
+ rand = torch.rand(1).numpy[0].to_f
523
+ elements.collect{|e| e >= rand ? 1 : 0 }
524
+ end
525
+ end
526
+ p1, p2 = model.eval [0.9, 0.1]
527
+ assert p2 <= p1
528
+
529
+ model = VectorModel.new dir
530
+ assert p2 <= p1
531
+
532
+ end
533
+
534
+ end
535
+
512
536
 
513
537
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rbbt-dm
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.63
4
+ version: 1.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Miguel Vazquez
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-12-15 00:00:00.000000000 Z
11
+ date: 2023-02-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rbbt-util
@@ -107,6 +107,7 @@ files:
107
107
  - lib/rbbt/statistics/rank_product.rb
108
108
  - lib/rbbt/tensorflow.rb
109
109
  - lib/rbbt/vector/model.rb
110
+ - lib/rbbt/vector/model/huggingface.rb
110
111
  - lib/rbbt/vector/model/random_forest.rb
111
112
  - lib/rbbt/vector/model/spaCy.rb
112
113
  - lib/rbbt/vector/model/svm.rb
@@ -131,6 +132,7 @@ files:
131
132
  - test/rbbt/statistics/test_random_walk.rb
132
133
  - test/rbbt/test_ml_task.rb
133
134
  - test/rbbt/test_stan.rb
135
+ - test/rbbt/vector/model/test_huggingface.rb
134
136
  - test/rbbt/vector/model/test_spaCy.rb
135
137
  - test/rbbt/vector/model/test_svm.rb
136
138
  - test/rbbt/vector/model/test_tensorflow.rb
@@ -154,21 +156,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
154
156
  - !ruby/object:Gem::Version
155
157
  version: '0'
156
158
  requirements: []
157
- rubygems_version: 3.1.4
159
+ rubygems_version: 3.1.2
158
160
  signing_key:
159
161
  specification_version: 4
160
162
  summary: Data-mining and statistics
161
163
  test_files:
162
- - test/rbbt/network/test_paths.rb
163
- - test/rbbt/matrix/test_barcode.rb
164
+ - test/test_helper.rb
165
+ - test/rbbt/vector/test_model.rb
166
+ - test/rbbt/vector/model/test_huggingface.rb
167
+ - test/rbbt/vector/model/test_tensorflow.rb
168
+ - test/rbbt/vector/model/test_spaCy.rb
169
+ - test/rbbt/vector/model/test_svm.rb
164
170
  - test/rbbt/statistics/test_random_walk.rb
165
171
  - test/rbbt/statistics/test_fisher.rb
166
172
  - test/rbbt/statistics/test_fdr.rb
167
173
  - test/rbbt/statistics/test_hypergeometric.rb
168
- - test/rbbt/test_ml_task.rb
169
- - test/rbbt/vector/test_model.rb
170
- - test/rbbt/vector/model/test_spaCy.rb
171
- - test/rbbt/vector/model/test_tensorflow.rb
172
- - test/rbbt/vector/model/test_svm.rb
173
174
  - test/rbbt/test_stan.rb
174
- - test/test_helper.rb
175
+ - test/rbbt/matrix/test_barcode.rb
176
+ - test/rbbt/test_ml_task.rb
177
+ - test/rbbt/network/test_paths.rb