rbbt-dm 1.2.3 → 1.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 2ff72107967b0f7c654697f3a7b3c0ef10f7a5264d775117f12f74084a2819b2
-  data.tar.gz: 6b9a58b5a2723c095332f79a37d9c1c7f4bc1431410f23a55beeed1c3b52f7ad
+  metadata.gz: 9744ab9faeaf4f9cc04947eb11103dbf0694dda624f805a5c6be27bb22af81ce
+  data.tar.gz: d3a3903aa276a69e20cbd71213286449db396ecf5f6a4b4d80a64ab299041fbb
 SHA512:
-  metadata.gz: a0fb4198cb0be3aa5253df0f655ee230621dd26a31956a774fffe95eac35f4c8b558a41c0e340c25c6eef463760ff6230b967f09eb09671b2078a50066067384
-  data.tar.gz: 0bd6c3667a8ec26ed092c54e78176671807ed0634e136e497303e34a17b2740e5e023041bc06389ac187de39b54942b9b1c5cd77abbc067c89250424654b6974
+  metadata.gz: 263fb609b37522874426bcd79374760399b4a9aaab443ae6d74c727f2d148474dd71ee0b2cfda7a50131dafbc314f66352f4285562a75f62144d2e05ccd214c7
+  data.tar.gz: 1e0426429a38028a19b3f8c955e975138199c791dad8691de7fb760a5cbec3304f19341906a4457563a90de965ee2f12a5a639b928866b46258af2507eeb39fa

lib/rbbt/vector/model/huggingface.old.rb ADDED
@@ -0,0 +1,160 @@
+require 'rbbt/vector/model'
+require 'rbbt/util/python'
+
+RbbtPython.add_path Rbbt.python.find(:lib)
+RbbtPython.init_rbbt
+
+class HuggingfaceModel < VectorModel
+
+  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args
+
+  def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
+
+    if labels
+      Open.write(tsv_dataset_file) do |ffile|
+        ffile.puts ["label", "text"].flatten * "\t"
+        elements.zip(labels).each do |element,label|
+          ffile.puts [label, element].flatten * "\t"
+        end
+      end
+    else
+      Open.write(tsv_dataset_file) do |ffile|
+        ffile.puts ["text"].flatten * "\t"
+        elements.each{|element| ffile.puts element }
+      end
+    end
+
+    tsv_dataset_file
+  end
+
+  def self.call_method(name, *args)
+    RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
+  end
+
+  def call_method(name, *args)
+    HuggingfaceModel.call_method(name, *args)
+  end
+
+  #def input_tsv_file
+  #  File.join(@directory, 'dataset.tsv') if @directory
+  #end
+
+  #def checkpoint_dir
+  #  File.join(@directory, 'checkpoints') if @directory
+  #end
+
+  def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
+    TmpFile.with_file do |tmpfile|
+      tsv_file = File.join(tmpfile, 'dataset.tsv')
+
+      if training_args
+        training_args = training_args.dup
+        checkpoint_dir = training_args.delete(:checkpoint_dir)
+      end
+
+      checkpoint_dir = File.join(tmpfile, 'checkpoints')
+
+      Open.mkdir File.dirname(tsv_file)
+      Open.mkdir File.dirname(checkpoint_dir)
+
+      if labels
+        training_args_obj = call_method(:training_args, checkpoint_dir, **training_args)
+        call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
+      else
+        locate_tokens, training_args = training_args, {}
+        if Array === elements
+          training_args_obj = call_method(:training_args, checkpoint_dir)
+          call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
+        else
+          call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
+        end
+      end
+    end
+  end
+
+  def init_model
+    @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
+  end
+
+  def reset_model
+    init_model
+  end
+
+  def initialize(task, initial_checkpoint = nil, *args)
+    super(*args)
+    @task = task
+
+    @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
+
+    init_model
+
+    @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
+
+    @training_args = {}
+
+    train_model do |file,elements,labels|
+      HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)
+
+      @model.save_pretrained(file) if file
+      @tokenizer.save_pretrained(file) if file
+    end
+
+    eval_model do |file,elements|
+      @model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
+      HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
+    end
+
+    post_process do |result|
+      if result.respond_to?(:predictions)
+        single = false
+        predictions = result.predictions
+      elsif result["token_positions"]
+        predictions = result["result"].predictions
+        token_positions = result["token_positions"]
+      else
+        single = true
+        predictions = result["logits"]
+      end
+
+      result = case @task
+      when "SequenceClassification"
+        RbbtPython.collect(predictions) do |logits|
+          logits = RbbtPython.numpy2ruby logits
+          best_class = logits.index logits.max
+          best_class = @class_labels[best_class] if @class_labels
+          best_class
+        end
+      when "MaskedLM"
+        all_token_positions = token_positions.to_a
+
+        i = 0
+        RbbtPython.collect(predictions) do |item_logits|
+          item_token_positions = all_token_positions[i]
+          i += 1
+
+          item_logits = RbbtPython.numpy2ruby(item_logits)
+          item_masks = item_token_positions.collect do |token_positions|
+
+            best = item_logits.values_at(*token_positions).collect do |logits|
+              best_token, best_score = nil
+              logits.each_with_index do |v,i|
+                if best_score.nil? || v > best_score
+                  best_token, best_score = i, v
+                end
+              end
+              best_token
+            end
+
+            best.collect{|b| @tokenizer.decode(b) } * "|"
+          end
+          Array === @locate_tokens ? item_masks : item_masks.first
+        end
+      else
+        logits
+      end
+
+      single ? result.first : result
+    end
+  end
+end
+
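
The file above is a verbatim copy of the 1.2.3 implementation, kept next to the rewritten huggingface.rb that follows. Under that older interface the task, checkpoint, and labels are plain accessors rather than model_options entries; a minimal sketch of how it was driven, adapted from the 1.2.3 tests and not itself part of this diff:

    require 'rbbt/vector/model/huggingface'   # in 1.2.3 this path held the code shown above

    # Legacy interface: per-attribute accessors instead of an options hash
    model = HuggingfaceModel.new "SequenceClassification",
                                 "distilbert-base-uncased-finetuned-sst-2-english"
    model.class_labels = ["Bad", "Good"]
    model.eval(["This is dog", "This is cat"])   # => ["Bad", "Good"] in the old tests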

lib/rbbt/vector/model/huggingface.rb CHANGED
@@ -1,13 +1,12 @@
 require 'rbbt/vector/model'
 require 'rbbt/util/python'
 
+RbbtPython.add_path Rbbt.python.find(:lib)
 RbbtPython.init_rbbt
 
 class HuggingfaceModel < VectorModel
 
-  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
-
-  def tsv_dataset(tsv_dataset_file, elements, labels = nil)
+  def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
 
     if labels
       Open.write(tsv_dataset_file) do |ffile|
@@ -26,59 +25,74 @@ class HuggingfaceModel < VectorModel
     tsv_dataset_file
   end
 
-  def call_method(name, *args)
-    RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
-  end
-
-  def input_tsv_file
-    File.join(@directory, 'dataset.tsv') if @directory
-  end
+  def initialize(task, checkpoint, *args)
+    options = args.pop if Hash === args.last
+    options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
+    super(*args)
+    @model_options ||= {}
+    @model_options.merge!(options)
 
-  def checkpoint_dir
-    File.join(@directory, 'checkpoints') if @directory
-  end
+    eval_model do |directory,texts|
+      checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
 
-  def run_model(elements, labels = nil)
-    TmpFile.with_file do |tmpfile|
-      tsv_file = input_tsv_file || File.join(tmpfile, 'dataset.tsv')
-      output_dir = checkpoint_dir || File.join(tmpfile, 'checkpoints')
+      if @model.nil?
+        @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
+      end
+
+      if Array === texts
 
-      Open.mkdir File.dirname(output_dir)
-      Open.mkdir File.dirname(tsv_file)
+        if @model_options.include?(:locate_tokens)
+          locate_tokens = @model_options[:locate_tokens]
+        elsif @model_options[:task] == "MaskedLM"
+          @model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
+        end
 
-      if labels
-        training_args = call_method(:training_args, output_dir)
-        call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
-      else
-        if Array === elements
-          training_args = call_method(:training_args, output_dir)
-          call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
+        if @directory
+          tsv_file = File.join(@directory, 'dataset.tsv')
+          checkpoint_dir = File.join(@directory, 'checkpoints')
        else
-          call_method(:eval_model, @model, @tokenizer, [elements], @locate_tokens)
+          tmpdir = TmpFile.tmp_file
+          Open.mkdir tmpdir
+          tsv_file = File.join(tmpdir, 'dataset.tsv')
+          checkpoint_dir = File.join(tmpdir, 'checkpoints')
+        end
+
+        dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
+        training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
+
+        begin
+          RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
+        ensure
+          Open.rm_rf tmpdir if tmpdir
        end
+      else
+        RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
      end
     end
-  end
-
-  def initialize(task, initial_checkpoint = nil, *args)
-    super(*args)
-    @task = task
 
-    @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
+    train_model do |directory,texts,labels|
+      checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
+      @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
 
-    @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
+      if @directory
+        tsv_file = File.join(@directory, 'dataset.tsv')
+        checkpoint_dir = File.join(@directory, 'checkpoints')
+      else
+        tmpdir = TmpFile.tmp_file
+        Open.mkdir tmpdir
+        tsv_file = File.join(tmpdir, 'dataset.tsv')
+        checkpoint_dir = File.join(tmpdir, 'checkpoints')
+      end
 
-    @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
+      training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
+      dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
 
-    train_model do |file,elements,labels|
-      run_model(elements, labels)
+      RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
 
-      @model.save_pretrained(file) if file
-      @tokenizer.save_pretrained(file) if file
-    end
+      Open.rm_rf tmpdir if tmpdir
 
-    eval_model do |file,elements|
-      run_model(elements)
+      @model.save_pretrained(directory) if directory
+      @tokenizer.save_pretrained(directory) if directory
     end
 
     post_process do |result|
@@ -93,12 +107,13 @@ class HuggingfaceModel < VectorModel
         predictions = result["logits"]
       end
 
-      result = case @task
+      task, class_labels, locate_tokens = @model_options.values_at :task, :class_labels, :locate_tokens
+      result = case task
       when "SequenceClassification"
        RbbtPython.collect(predictions) do |logits|
          logits = RbbtPython.numpy2ruby logits
          best_class = logits.index logits.max
-         best_class = @class_labels[best_class] if @class_labels
+         best_class = class_labels[best_class] if class_labels
          best_class
        end
      when "MaskedLM"
@@ -124,7 +139,7 @@ class HuggingfaceModel < VectorModel
 
            best.collect{|b| @tokenizer.decode(b) } * "|"
          end
-          Array === @locate_tokens ? item_masks : item_masks.first
+          Array === locate_tokens ? item_masks : item_masks.first
        end
      else
        logits
@@ -132,9 +147,15 @@ class HuggingfaceModel < VectorModel
 
      single ? result.first : result
    end
+
+
+    save_models if @directory
   end
-end
 
-if __FILE__ == $0
+  def reset_model
+    @model, @tokenizer = nil
+    Open.rm @model_file
+  end
 
 end
+
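
The rewrite above moves all per-model state (task, checkpoint, class_labels, class_weights, locate_tokens, training_args) into @model_options, which VectorModel persists as options.json (see the model.rb changes below), so a saved model directory becomes self-describing. A sketch of the 1.2.6 usage, adapted from the new test_options test; the directory path is a made-up example:

    require 'rbbt/vector/model/huggingface'

    dir        = "/tmp/sst_model"   # hypothetical writable model directory
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"

    # Options are given at construction time and travel with the directory
    model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir,
                                 :class_labels => ["Bad", "Good"]
    model.eval(["This is dog", "This is cat"])   # => ["Bad", "Good"] in the sst tests

    # Because options.json is written under dir, a plain VectorModel can reload it later
    model = VectorModel.new dir
    model.eval(["This is dog", "This is cat"])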

lib/rbbt/vector/model.rb CHANGED
@@ -4,6 +4,7 @@ require 'rbbt/vector/model/util'
 class VectorModel
   attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
   attr_accessor :features, :names, :labels, :factor_levels
+  attr_accessor :model_options
 
   def extract_features(&block)
     @extract_features = block if block_given?
@@ -126,7 +127,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     instance_eval code, file
   end
 
-  def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
+  def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
    @directory = directory
    if @directory
      FileUtils.mkdir_p @directory unless File.exists?(@directory)
@@ -135,10 +136,18 @@ cat(paste(label, sep="\\n", collapse="\\n"));
      @extract_features_file = File.join(@directory, "features")
      @train_model_file = File.join(@directory, "train_model")
      @eval_model_file = File.join(@directory, "eval_model")
+      @post_process_file = File.join(@directory, "post_process")
      @train_model_file_R = File.join(@directory, "train_model.R")
      @eval_model_file_R = File.join(@directory, "eval_model.R")
+      @post_process_file_R = File.join(@directory, "post_process.R")
      @names_file = File.join(@directory, "feature_names")
      @levels_file = File.join(@directory, "levels")
+      @options_file = File.join(@directory, "options.json")
+
+      if File.exists?(@options_file)
+        @model_options = JSON.parse(Open.read(@options_file))
+        IndiferentHash.setup(@model_options)
+      end
    end
 
    if extract_features.nil?
@@ -169,6 +178,17 @@ cat(paste(label, sep="\\n", collapse="\\n"));
      @eval_model = eval_model
    end
 
+    if post_process.nil?
+      if @post_process_file && File.exists?(@post_process_file)
+        @post_process = __load_method @post_process_file
+      elsif @post_process_file_R && File.exists?(@post_process_file_R)
+        @post_process = Open.read(@post_process_file_R)
+      end
+    else
+      @post_process = post_process
+    end
+
+
    if names.nil?
      if @names_file && File.exists?(@names_file)
        @names = Open.read(@names_file).split("\n")
@@ -240,8 +260,20 @@ cat(paste(label, sep="\\n", collapse="\\n"));
      Open.write(@eval_model_file_R, eval_model)
    end
 
+    case
+    when Proc === post_process
+      begin
+        Open.write(@post_process_file, post_process.source)
+      rescue
+      end
+    when String === post_process
+      Open.write(@post_process_file_R, post_process)
+    end
+
+
    Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
    Open.write(@names_file, @names * "\n" + "\n") if @names
+    Open.write(@options_file, @model_options.to_json) if @model_options
   end
 
   def train
@@ -251,7 +283,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
    when String === @train_model
      VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
    end
-    save_models
+    save_models if @directory
   end
 
   def run(code)
@@ -299,38 +331,6 @@ cat(paste(label, sep="\\n", collapse="\\n"));
    result
   end
 
-  #def cross_validation(folds = 10)
-  #  saved_features = @features
-  #  saved_labels = @labels
-  #  seq = (0..features.length - 1).to_a
-
-  #  chunk_size = features.length / folds
-
-  #  acc = []
-  #  folds.times do
-  #    seq = seq.shuffle
-  #    eval_chunk = seq[0..chunk_size]
-  #    train_chunk = seq[chunk_size.. -1]
-
-  #    eval_features = @features.values_at *eval_chunk
-  #    eval_labels = @labels.values_at *eval_chunk
-
-  #    @features = @features.values_at *train_chunk
-  #    @labels = @labels.values_at *train_chunk
-
-  #    train
-  #    predictions = eval_list eval_features, false
-
-  #    acc << predictions.zip(eval_labels).collect{|pred,lab| pred - lab < 0.5 ? 1 : 0}.inject(0){|acc,e| acc +=e} / chunk_size
-
-  #    @features = saved_features
-  #    @labels = saved_labels
-  #  end
-
-  #  acc
-  #end
-  #
-
   def self.f1_metrics(test, predicted, good_label = nil)
    tp, tn, fp, fn, pr, re, f1 = [0, 0, 0, 0, nil, nil, nil]
 
@@ -413,6 +413,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
      @features = train_set
      @labels = train_labels
 
+      self.reset_model if self.respond_to? :reset_model
      self.train
      predictions = self.eval_list test_set, false
 
@@ -437,6 +438,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
      @features = orig_features
      @labels = orig_labels
    end unless folds == -1
+    self.reset_model if self.respond_to? :reset_model
    self.train unless folds == 1
    res
   end
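
On the VectorModel side the changes are mostly about persistence and hooks: the post_process block is now stored (post_process / post_process.R) alongside the train and eval blocks, model_options round-trips through options.json, save_models only runs when a directory is set, and cross_validation calls reset_model between folds when the subclass defines one. A minimal sketch of the post_process hook, adapted from the test_pipeline test added further down (no directory is used, so nothing is persisted in this example):

    require 'rbbt/vector/model'
    require 'rbbt/util/python'

    model = VectorModel.new
    model.eval_model do |file, elements|
      RbbtPython.run :transformers do
        transformers.pipeline("sentiment-analysis").call(elements)
      end
    end
    model.post_process do |elements|            # applied to whatever eval_model returns
      elements.collect{|e| e['label'] }
    end

    model.eval("I've been waiting for a HuggingFace course my whole life.")   # => ["POSITIVE"]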

python/rbbt_dm/__init__.py ADDED
@@ -0,0 +1 @@
+# Keep

python/rbbt_dm/huggingface.py CHANGED
@@ -17,6 +17,17 @@ def load_model_and_tokenizer(task, checkpoint):
     tokenizer = load_tokenizer(task, checkpoint)
     return model, tokenizer
 
+def load_model_and_tokenizer_from_directory(directory):
+    import os
+    import json
+    options_file = os.path.join(directory, 'options.json')
+    f = open(options_file, "r")
+    options = json.load(f.read())
+    f.close()
+    task = options["task"]
+    checkpoint = options["checkpoint"]
+    return load_model_and_tokenizer(task, checkpoint)
+
 #{{{ SIMPLE EVALUATE
 
 def forward(model, features):
@@ -51,17 +62,40 @@ def training_args(*args, **kwargs):
     return training_args
 
 
-def train_model(model, tokenizer, training_args, tsv_file):
+def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
     from transformers import Trainer
 
     tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
 
-    trainer = Trainer(
-            model,
-            training_args,
-            train_dataset = tokenized_dataset["train"],
-            tokenizer = tokenizer
-            )
+    if (not class_weights == None):
+        import torch
+        from torch import nn
+
+        class WeightTrainer(Trainer):
+            def compute_loss(self, model, inputs, return_outputs=False):
+                labels = inputs.get("labels")
+                # forward pass
+                outputs = model(**inputs)
+                logits = outputs.get('logits')
+                # compute custom loss
+                loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
+                loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+                return (loss, outputs) if return_outputs else loss
+
+        trainer = WeightTrainer(
+                model,
+                training_args,
+                train_dataset = tokenized_dataset["train"],
+                tokenizer = tokenizer
+                )
+    else:
+
+        trainer = Trainer(
+                model,
+                training_args,
+                train_dataset = tokenized_dataset["train"],
+                tokenizer = tokenizer
+                )
 
     trainer.train()
 
@@ -90,7 +124,6 @@ def find_tokens_in_input(dataset, token_ids):
     return position_rows
 
 
-
 def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
     from transformers import Trainer
 
@@ -110,3 +143,4 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
     else:
         return result
 
+
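
train_model now takes an optional class_weights vector: when it is given, the plain transformers.Trainer is replaced by a WeightTrainer whose compute_loss applies a CrossEntropyLoss weighted per label, the usual remedy for class imbalance. On the Ruby side the weights simply travel through model_options; a hedged sketch (the weight values and directory are made up for illustration):

    # Hypothetical imbalanced two-class setup where errors on "Good" should cost five times more
    model = HuggingfaceModel.new "SequenceClassification",
                                 "distilbert-base-uncased-finetuned-sst-2-english", "/tmp/weighted_model",
                                 :class_labels  => ["Bad", "Good"],
                                 :class_weights => [1.0, 5.0]      # one weight per label

    # During model.train the train_model block forwards model_options[:class_weights]
    # to rbbt_dm.huggingface.train_model, which then builds the WeightTrainer shown above.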

@@ -3,16 +3,46 @@ require 'rbbt/vector/model/huggingface'
 
 class TestHuggingface < Test::Unit::TestCase
 
+  def test_options
+    TmpFile.with_file do |dir|
+      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+      task = "SequenceClassification"
+
+      model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
+      iii model.eval "This is dog"
+      iii model.eval "This is cat"
+      iii model.eval(["This is dog", "This is cat"])
+
+      model = VectorModel.new dir
+      iii model.eval(["This is dog", "This is cat"])
+    end
+  end
+
+  def test_pipeline
+    require 'rbbt/util/python'
+    model = VectorModel.new
+    model.post_process do |elements|
+      elements.collect{|e| e['label'] }
+    end
+    model.eval_model do |file, elements|
+      RbbtPython.run :transformers do
+        classifier ||= transformers.pipeline("sentiment-analysis")
+        classifier.call(elements)
+      end
+    end
+
+    assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
+  end
+
   def test_sst_eval
     TmpFile.with_file do |dir|
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
       model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
 
-      model.class_labels = ["Bad", "Good"]
+      model.model_options[:class_labels] = ["Bad", "Good"]
 
       assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
-
     end
   end
 
@@ -22,7 +52,8 @@ class TestHuggingface < Test::Unit::TestCase
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
       model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-      model.class_labels = ["Bad", "Good"]
+
+      model.model_options[:class_labels] = %w(Bad Good)
 
       assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -33,6 +64,9 @@ class TestHuggingface < Test::Unit::TestCase
       model.train
 
       assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
+
+      model = VectorModel.new dir
+      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
     end
   end
 
@@ -40,7 +74,7 @@ class TestHuggingface < Test::Unit::TestCase
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
       model = HuggingfaceModel.new "SequenceClassification", checkpoint
-      model.class_labels = ["Bad", "Good"]
+      model.model_options[:class_labels] = ["Bad", "Good"]
 
       assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -58,7 +92,7 @@ class TestHuggingface < Test::Unit::TestCase
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
       model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-      model.class_labels = ["Bad", "Good"]
+      model.model_options[:class_labels] = ["Bad", "Good"]
 
       assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -69,15 +103,20 @@ class TestHuggingface < Test::Unit::TestCase
       model.train
 
       model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-      model.class_labels = ["Bad", "Good"]
 
       assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
 
-      model = HuggingfaceModel.new "SequenceClassification", model.model_file
-      model.class_labels = ["Bad", "Good"]
+      model_file = model.model_file
+
+      model = HuggingfaceModel.new "SequenceClassification", model_file
+      model.model_options[:class_labels] = ["Bad", "Good"]
 
       assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
 
+      model = VectorModel.new dir
+
+      assert_equal "Good", model.eval("This is dog")
+
     end
   end
 
@@ -108,9 +147,7 @@ class TestHuggingface < Test::Unit::TestCase
       model = HuggingfaceModel.new "MaskedLM", checkpoint
       assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
         reject{|v| v.empty?}.length
-
     end
 
-
   end
 
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: rbbt-dm
 version: !ruby/object:Gem::Version
-  version: 1.2.3
+  version: 1.2.6
 platform: ruby
 authors:
 - Miguel Vazquez
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2023-02-04 00:00:00.000000000 Z
+date: 2023-02-08 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rbbt-util
@@ -107,12 +107,14 @@ files:
 - lib/rbbt/statistics/rank_product.rb
 - lib/rbbt/tensorflow.rb
 - lib/rbbt/vector/model.rb
+- lib/rbbt/vector/model/huggingface.old.rb
 - lib/rbbt/vector/model/huggingface.rb
 - lib/rbbt/vector/model/random_forest.rb
 - lib/rbbt/vector/model/spaCy.rb
 - lib/rbbt/vector/model/svm.rb
 - lib/rbbt/vector/model/tensorflow.rb
 - lib/rbbt/vector/model/util.rb
+- python/rbbt_dm/__init__.py
 - python/rbbt_dm/huggingface.py
 - share/R/MA.R
 - share/R/barcode.R