rbbt-dm 1.2.3 → 1.2.6

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2ff72107967b0f7c654697f3a7b3c0ef10f7a5264d775117f12f74084a2819b2
4
- data.tar.gz: 6b9a58b5a2723c095332f79a37d9c1c7f4bc1431410f23a55beeed1c3b52f7ad
3
+ metadata.gz: 9744ab9faeaf4f9cc04947eb11103dbf0694dda624f805a5c6be27bb22af81ce
4
+ data.tar.gz: d3a3903aa276a69e20cbd71213286449db396ecf5f6a4b4d80a64ab299041fbb
5
5
  SHA512:
6
- metadata.gz: a0fb4198cb0be3aa5253df0f655ee230621dd26a31956a774fffe95eac35f4c8b558a41c0e340c25c6eef463760ff6230b967f09eb09671b2078a50066067384
7
- data.tar.gz: 0bd6c3667a8ec26ed092c54e78176671807ed0634e136e497303e34a17b2740e5e023041bc06389ac187de39b54942b9b1c5cd77abbc067c89250424654b6974
6
+ metadata.gz: 263fb609b37522874426bcd79374760399b4a9aaab443ae6d74c727f2d148474dd71ee0b2cfda7a50131dafbc314f66352f4285562a75f62144d2e05ccd214c7
7
+ data.tar.gz: 1e0426429a38028a19b3f8c955e975138199c791dad8691de7fb760a5cbec3304f19341906a4457563a90de965ee2f12a5a639b928866b46258af2507eeb39fa
@@ -0,0 +1,160 @@
1
+ require 'rbbt/vector/model'
2
+ require 'rbbt/util/python'
3
+
4
+ RbbtPython.add_path Rbbt.python.find(:lib)
5
+ RbbtPython.init_rbbt
6
+
7
+ class HuggingfaceModel < VectorModel
8
+
9
+ attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args
10
+
11
+ def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
12
+
13
+ if labels
14
+ Open.write(tsv_dataset_file) do |ffile|
15
+ ffile.puts ["label", "text"].flatten * "\t"
16
+ elements.zip(labels).each do |element,label|
17
+ ffile.puts [label, element].flatten * "\t"
18
+ end
19
+ end
20
+ else
21
+ Open.write(tsv_dataset_file) do |ffile|
22
+ ffile.puts ["text"].flatten * "\t"
23
+ elements.each{|element| ffile.puts element }
24
+ end
25
+ end
26
+
27
+ tsv_dataset_file
28
+ end
29
+
30
+ def self.call_method(name, *args)
31
+ RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
32
+ end
33
+
34
+ def call_method(name, *args)
35
+ HuggingfaceModel.call_method(name, *args)
36
+ end
37
+
38
+ #def input_tsv_file
39
+ # File.join(@directory, 'dataset.tsv') if @directory
40
+ #end
41
+
42
+ #def checkpoint_dir
43
+ # File.join(@directory, 'checkpoints') if @directory
44
+ #end
45
+
46
+ def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
47
+ TmpFile.with_file do |tmpfile|
48
+ tsv_file = File.join(tmpfile, 'dataset.tsv')
49
+
50
+ if training_args
51
+ training_args = training_args.dup
52
+ checkpoint_dir = training_args.delete(:checkpoint_dir)
53
+ end
54
+
55
+ checkpoint_dir = File.join(tmpfile, 'checkpoints')
56
+
57
+ Open.mkdir File.dirname(tsv_file)
58
+ Open.mkdir File.dirname(checkpoint_dir)
59
+
60
+ if labels
61
+ training_args_obj = call_method(:training_args, checkpoint_dir, **training_args)
62
+ call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
63
+ else
64
+ locate_tokens, training_args = training_args, {}
65
+ if Array === elements
66
+ training_args_obj = call_method(:training_args, checkpoint_dir)
67
+ call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
68
+ else
69
+ call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
70
+ end
71
+ end
72
+ end
73
+ end
74
+
75
+ def init_model
76
+ @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
77
+ end
78
+
79
+ def reset_model
80
+ init_model
81
+ end
82
+
83
+ def initialize(task, initial_checkpoint = nil, *args)
84
+ super(*args)
85
+ @task = task
86
+
87
+ @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
88
+
89
+ init_model
90
+
91
+ @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
92
+
93
+ @training_args = {}
94
+
95
+ train_model do |file,elements,labels|
96
+ HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)
97
+
98
+ @model.save_pretrained(file) if file
99
+ @tokenizer.save_pretrained(file) if file
100
+ end
101
+
102
+ eval_model do |file,elements|
103
+ @model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
104
+ HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
105
+ end
106
+
107
+ post_process do |result|
108
+ if result.respond_to?(:predictions)
109
+ single = false
110
+ predictions = result.predictions
111
+ elsif result["token_positions"]
112
+ predictions = result["result"].predictions
113
+ token_positions = result["token_positions"]
114
+ else
115
+ single = true
116
+ predictions = result["logits"]
117
+ end
118
+
119
+ result = case @task
120
+ when "SequenceClassification"
121
+ RbbtPython.collect(predictions) do |logits|
122
+ logits = RbbtPython.numpy2ruby logits
123
+ best_class = logits.index logits.max
124
+ best_class = @class_labels[best_class] if @class_labels
125
+ best_class
126
+ end
127
+ when "MaskedLM"
128
+ all_token_positions = token_positions.to_a
129
+
130
+ i = 0
131
+ RbbtPython.collect(predictions) do |item_logits|
132
+ item_token_positions = all_token_positions[i]
133
+ i += 1
134
+
135
+ item_logits = RbbtPython.numpy2ruby(item_logits)
136
+ item_masks = item_token_positions.collect do |token_positions|
137
+
138
+ best = item_logits.values_at(*token_positions).collect do |logits|
139
+ best_token, best_score = nil
140
+ logits.each_with_index do |v,i|
141
+ if best_score.nil? || v > best_score
142
+ best_token, best_score = i, v
143
+ end
144
+ end
145
+ best_token
146
+ end
147
+
148
+ best.collect{|b| @tokenizer.decode(b) } * "|"
149
+ end
150
+ Array === @locate_tokens ? item_masks : item_masks.first
151
+ end
152
+ else
153
+ logits
154
+ end
155
+
156
+ single ? result.first : result
157
+ end
158
+ end
159
+ end
160
+
@@ -1,13 +1,12 @@
1
1
  require 'rbbt/vector/model'
2
2
  require 'rbbt/util/python'
3
3
 
4
+ RbbtPython.add_path Rbbt.python.find(:lib)
4
5
  RbbtPython.init_rbbt
5
6
 
6
7
  class HuggingfaceModel < VectorModel
7
8
 
8
- attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
9
-
10
- def tsv_dataset(tsv_dataset_file, elements, labels = nil)
9
+ def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
11
10
 
12
11
  if labels
13
12
  Open.write(tsv_dataset_file) do |ffile|
@@ -26,59 +25,74 @@ class HuggingfaceModel < VectorModel
26
25
  tsv_dataset_file
27
26
  end
28
27
 
29
- def call_method(name, *args)
30
- RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
31
- end
32
-
33
- def input_tsv_file
34
- File.join(@directory, 'dataset.tsv') if @directory
35
- end
28
+ def initialize(task, checkpoint, *args)
29
+ options = args.pop if Hash === args.last
30
+ options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
31
+ super(*args)
32
+ @model_options ||= {}
33
+ @model_options.merge!(options)
36
34
 
37
- def checkpoint_dir
38
- File.join(@directory, 'checkpoints') if @directory
39
- end
35
+ eval_model do |directory,texts|
36
+ checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
40
37
 
41
- def run_model(elements, labels = nil)
42
- TmpFile.with_file do |tmpfile|
43
- tsv_file = input_tsv_file || File.join(tmpfile, 'dataset.tsv')
44
- output_dir = checkpoint_dir || File.join(tmpfile, 'checkpoints')
38
+ if @model.nil?
39
+ @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
40
+ end
41
+
42
+ if Array === texts
45
43
 
46
- Open.mkdir File.dirname(output_dir)
47
- Open.mkdir File.dirname(tsv_file)
44
+ if @model_options.include?(:locate_tokens)
45
+ locate_tokens = @model_options[:locate_tokens]
46
+ elsif @model_options[:task] == "MaskedLM"
47
+ @model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
48
+ end
48
49
 
49
- if labels
50
- training_args = call_method(:training_args, output_dir)
51
- call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
52
- else
53
- if Array === elements
54
- training_args = call_method(:training_args, output_dir)
55
- call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
50
+ if @directory
51
+ tsv_file = File.join(@directory, 'dataset.tsv')
52
+ checkpoint_dir = File.join(@directory, 'checkpoints')
56
53
  else
57
- call_method(:eval_model, @model, @tokenizer, [elements], @locate_tokens)
54
+ tmpdir = TmpFile.tmp_file
55
+ Open.mkdir tmpdir
56
+ tsv_file = File.join(tmpdir, 'dataset.tsv')
57
+ checkpoint_dir = File.join(tmpdir, 'checkpoints')
58
+ end
59
+
60
+ dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
61
+ training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
62
+
63
+ begin
64
+ RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
65
+ ensure
66
+ Open.rm_rf tmpdir if tmpdir
58
67
  end
68
+ else
69
+ RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
59
70
  end
60
71
  end
61
- end
62
-
63
- def initialize(task, initial_checkpoint = nil, *args)
64
- super(*args)
65
- @task = task
66
72
 
67
- @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
73
+ train_model do |directory,texts,labels|
74
+ checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
75
+ @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
68
76
 
69
- @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
77
+ if @directory
78
+ tsv_file = File.join(@directory, 'dataset.tsv')
79
+ checkpoint_dir = File.join(@directory, 'checkpoints')
80
+ else
81
+ tmpdir = TmpFile.tmp_file
82
+ Open.mkdir tmpdir
83
+ tsv_file = File.join(tmpdir, 'dataset.tsv')
84
+ checkpoint_dir = File.join(tmpdir, 'checkpoints')
85
+ end
70
86
 
71
- @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
87
+ training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
88
+ dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
72
89
 
73
- train_model do |file,elements,labels|
74
- run_model(elements, labels)
90
+ RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
75
91
 
76
- @model.save_pretrained(file) if file
77
- @tokenizer.save_pretrained(file) if file
78
- end
92
+ Open.rm_rf tmpdir if tmpdir
79
93
 
80
- eval_model do |file,elements|
81
- run_model(elements)
94
+ @model.save_pretrained(directory) if directory
95
+ @tokenizer.save_pretrained(directory) if directory
82
96
  end
83
97
 
84
98
  post_process do |result|
@@ -93,12 +107,13 @@ class HuggingfaceModel < VectorModel
93
107
  predictions = result["logits"]
94
108
  end
95
109
 
96
- result = case @task
110
+ task, class_labels, locate_tokens = @model_options.values_at :task, :class_labels, :locate_tokens
111
+ result = case task
97
112
  when "SequenceClassification"
98
113
  RbbtPython.collect(predictions) do |logits|
99
114
  logits = RbbtPython.numpy2ruby logits
100
115
  best_class = logits.index logits.max
101
- best_class = @class_labels[best_class] if @class_labels
116
+ best_class = class_labels[best_class] if class_labels
102
117
  best_class
103
118
  end
104
119
  when "MaskedLM"
@@ -124,7 +139,7 @@ class HuggingfaceModel < VectorModel
124
139
 
125
140
  best.collect{|b| @tokenizer.decode(b) } * "|"
126
141
  end
127
- Array === @locate_tokens ? item_masks : item_masks.first
142
+ Array === locate_tokens ? item_masks : item_masks.first
128
143
  end
129
144
  else
130
145
  logits
@@ -132,9 +147,15 @@ class HuggingfaceModel < VectorModel
132
147
 
133
148
  single ? result.first : result
134
149
  end
150
+
151
+
152
+ save_models if @directory
135
153
  end
136
- end
137
154
 
138
- if __FILE__ == $0
155
+ def reset_model
156
+ @model, @tokenizer = nil
157
+ Open.rm @model_file
158
+ end
139
159
 
140
160
  end
161
+
@@ -4,6 +4,7 @@ require 'rbbt/vector/model/util'
4
4
  class VectorModel
5
5
  attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
6
6
  attr_accessor :features, :names, :labels, :factor_levels
7
+ attr_accessor :model_options
7
8
 
8
9
  def extract_features(&block)
9
10
  @extract_features = block if block_given?
@@ -126,7 +127,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
126
127
  instance_eval code, file
127
128
  end
128
129
 
129
- def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
130
+ def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
130
131
  @directory = directory
131
132
  if @directory
132
133
  FileUtils.mkdir_p @directory unless File.exists?(@directory)
@@ -135,10 +136,18 @@ cat(paste(label, sep="\\n", collapse="\\n"));
135
136
  @extract_features_file = File.join(@directory, "features")
136
137
  @train_model_file = File.join(@directory, "train_model")
137
138
  @eval_model_file = File.join(@directory, "eval_model")
139
+ @post_process_file = File.join(@directory, "post_process")
138
140
  @train_model_file_R = File.join(@directory, "train_model.R")
139
141
  @eval_model_file_R = File.join(@directory, "eval_model.R")
142
+ @post_process_file_R = File.join(@directory, "post_process.R")
140
143
  @names_file = File.join(@directory, "feature_names")
141
144
  @levels_file = File.join(@directory, "levels")
145
+ @options_file = File.join(@directory, "options.json")
146
+
147
+ if File.exists?(@options_file)
148
+ @model_options = JSON.parse(Open.read(@options_file))
149
+ IndiferentHash.setup(@model_options)
150
+ end
142
151
  end
143
152
 
144
153
  if extract_features.nil?
@@ -169,6 +178,17 @@ cat(paste(label, sep="\\n", collapse="\\n"));
169
178
  @eval_model = eval_model
170
179
  end
171
180
 
181
+ if post_process.nil?
182
+ if @post_process_file && File.exists?(@post_process_file)
183
+ @post_process = __load_method @post_process_file
184
+ elsif @post_process_file_R && File.exists?(@post_process_file_R)
185
+ @post_process = Open.read(@post_process_file_R)
186
+ end
187
+ else
188
+ @post_process = post_process
189
+ end
190
+
191
+
172
192
  if names.nil?
173
193
  if @names_file && File.exists?(@names_file)
174
194
  @names = Open.read(@names_file).split("\n")
@@ -240,8 +260,20 @@ cat(paste(label, sep="\\n", collapse="\\n"));
240
260
  Open.write(@eval_model_file_R, eval_model)
241
261
  end
242
262
 
263
+ case
264
+ when Proc === post_process
265
+ begin
266
+ Open.write(@post_process_file, post_process.source)
267
+ rescue
268
+ end
269
+ when String === post_process
270
+ Open.write(@post_process_file_R, post_process)
271
+ end
272
+
273
+
243
274
  Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
244
275
  Open.write(@names_file, @names * "\n" + "\n") if @names
276
+ Open.write(@options_file, @model_options.to_json) if @model_options
245
277
  end
246
278
 
247
279
  def train
@@ -251,7 +283,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
251
283
  when String === @train_model
252
284
  VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
253
285
  end
254
- save_models
286
+ save_models if @directory
255
287
  end
256
288
 
257
289
  def run(code)
@@ -299,38 +331,6 @@ cat(paste(label, sep="\\n", collapse="\\n"));
299
331
  result
300
332
  end
301
333
 
302
- #def cross_validation(folds = 10)
303
- # saved_features = @features
304
- # saved_labels = @labels
305
- # seq = (0..features.length - 1).to_a
306
-
307
- # chunk_size = features.length / folds
308
-
309
- # acc = []
310
- # folds.times do
311
- # seq = seq.shuffle
312
- # eval_chunk = seq[0..chunk_size]
313
- # train_chunk = seq[chunk_size.. -1]
314
-
315
- # eval_features = @features.values_at *eval_chunk
316
- # eval_labels = @labels.values_at *eval_chunk
317
-
318
- # @features = @features.values_at *train_chunk
319
- # @labels = @labels.values_at *train_chunk
320
-
321
- # train
322
- # predictions = eval_list eval_features, false
323
-
324
- # acc << predictions.zip(eval_labels).collect{|pred,lab| pred - lab < 0.5 ? 1 : 0}.inject(0){|acc,e| acc +=e} / chunk_size
325
-
326
- # @features = saved_features
327
- # @labels = saved_labels
328
- # end
329
-
330
- # acc
331
- #end
332
- #
333
-
334
334
  def self.f1_metrics(test, predicted, good_label = nil)
335
335
  tp, tn, fp, fn, pr, re, f1 = [0, 0, 0, 0, nil, nil, nil]
336
336
 
@@ -413,6 +413,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
413
413
  @features = train_set
414
414
  @labels = train_labels
415
415
 
416
+ self.reset_model if self.respond_to? :reset_model
416
417
  self.train
417
418
  predictions = self.eval_list test_set, false
418
419
 
@@ -437,6 +438,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
437
438
  @features = orig_features
438
439
  @labels = orig_labels
439
440
  end unless folds == -1
441
+ self.reset_model if self.respond_to? :reset_model
440
442
  self.train unless folds == 1
441
443
  res
442
444
  end
@@ -0,0 +1 @@
1
+ # Keep
@@ -17,6 +17,17 @@ def load_model_and_tokenizer(task, checkpoint):
17
17
  tokenizer = load_tokenizer(task, checkpoint)
18
18
  return model, tokenizer
19
19
 
20
+ def load_model_and_tokenizer_from_directory(directory):
21
+ import os
22
+ import json
23
+ options_file = os.path.join(directory, 'options.json')
24
+ f = open(options_file, "r")
25
+ options = json.load(f.read())
26
+ f.close()
27
+ task = options["task"]
28
+ checkpoint = options["checkpoint"]
29
+ return load_model_and_tokenizer(task, checkpoint)
30
+
20
31
  #{{{ SIMPLE EVALUATE
21
32
 
22
33
  def forward(model, features):
@@ -51,17 +62,40 @@ def training_args(*args, **kwargs):
51
62
  return training_args
52
63
 
53
64
 
54
- def train_model(model, tokenizer, training_args, tsv_file):
65
+ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
55
66
  from transformers import Trainer
56
67
 
57
68
  tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
58
69
 
59
- trainer = Trainer(
60
- model,
61
- training_args,
62
- train_dataset = tokenized_dataset["train"],
63
- tokenizer = tokenizer
64
- )
70
+ if (not class_weights == None):
71
+ import torch
72
+ from torch import nn
73
+
74
+ class WeightTrainer(Trainer):
75
+ def compute_loss(self, model, inputs, return_outputs=False):
76
+ labels = inputs.get("labels")
77
+ # forward pass
78
+ outputs = model(**inputs)
79
+ logits = outputs.get('logits')
80
+ # compute custom loss
81
+ loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
82
+ loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
83
+ return (loss, outputs) if return_outputs else loss
84
+
85
+ trainer = WeightTrainer(
86
+ model,
87
+ training_args,
88
+ train_dataset = tokenized_dataset["train"],
89
+ tokenizer = tokenizer
90
+ )
91
+ else:
92
+
93
+ trainer = Trainer(
94
+ model,
95
+ training_args,
96
+ train_dataset = tokenized_dataset["train"],
97
+ tokenizer = tokenizer
98
+ )
65
99
 
66
100
  trainer.train()
67
101
 
@@ -90,7 +124,6 @@ def find_tokens_in_input(dataset, token_ids):
90
124
  return position_rows
91
125
 
92
126
 
93
-
94
127
  def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
95
128
  from transformers import Trainer
96
129
 
@@ -110,3 +143,4 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
110
143
  else:
111
144
  return result
112
145
 
146
+
@@ -3,16 +3,46 @@ require 'rbbt/vector/model/huggingface'
3
3
 
4
4
  class TestHuggingface < Test::Unit::TestCase
5
5
 
6
+ def test_options
7
+ TmpFile.with_file do |dir|
8
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
9
+ task = "SequenceClassification"
10
+
11
+ model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
12
+ iii model.eval "This is dog"
13
+ iii model.eval "This is cat"
14
+ iii model.eval(["This is dog", "This is cat"])
15
+
16
+ model = VectorModel.new dir
17
+ iii model.eval(["This is dog", "This is cat"])
18
+ end
19
+ end
20
+
21
+ def test_pipeline
22
+ require 'rbbt/util/python'
23
+ model = VectorModel.new
24
+ model.post_process do |elements|
25
+ elements.collect{|e| e['label'] }
26
+ end
27
+ model.eval_model do |file, elements|
28
+ RbbtPython.run :transformers do
29
+ classifier ||= transformers.pipeline("sentiment-analysis")
30
+ classifier.call(elements)
31
+ end
32
+ end
33
+
34
+ assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
35
+ end
36
+
6
37
  def test_sst_eval
7
38
  TmpFile.with_file do |dir|
8
39
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
9
40
 
10
41
  model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
11
42
 
12
- model.class_labels = ["Bad", "Good"]
43
+ model.model_options[:class_labels] = ["Bad", "Good"]
13
44
 
14
45
  assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
15
-
16
46
  end
17
47
  end
18
48
 
@@ -22,7 +52,8 @@ class TestHuggingface < Test::Unit::TestCase
22
52
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
23
53
 
24
54
  model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
25
- model.class_labels = ["Bad", "Good"]
55
+
56
+ model.model_options[:class_labels] = %w(Bad Good)
26
57
 
27
58
  assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
28
59
 
@@ -33,6 +64,9 @@ class TestHuggingface < Test::Unit::TestCase
33
64
  model.train
34
65
 
35
66
  assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
67
+
68
+ model = VectorModel.new dir
69
+ assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
36
70
  end
37
71
  end
38
72
 
@@ -40,7 +74,7 @@ class TestHuggingface < Test::Unit::TestCase
40
74
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
41
75
 
42
76
  model = HuggingfaceModel.new "SequenceClassification", checkpoint
43
- model.class_labels = ["Bad", "Good"]
77
+ model.model_options[:class_labels] = ["Bad", "Good"]
44
78
 
45
79
  assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
46
80
 
@@ -58,7 +92,7 @@ class TestHuggingface < Test::Unit::TestCase
58
92
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
59
93
 
60
94
  model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
61
- model.class_labels = ["Bad", "Good"]
95
+ model.model_options[:class_labels] = ["Bad", "Good"]
62
96
 
63
97
  assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
64
98
 
@@ -69,15 +103,20 @@ class TestHuggingface < Test::Unit::TestCase
69
103
  model.train
70
104
 
71
105
  model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
72
- model.class_labels = ["Bad", "Good"]
73
106
 
74
107
  assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
75
108
 
76
- model = HuggingfaceModel.new "SequenceClassification", model.model_file
77
- model.class_labels = ["Bad", "Good"]
109
+ model_file = model.model_file
110
+
111
+ model = HuggingfaceModel.new "SequenceClassification", model_file
112
+ model.model_options[:class_labels] = ["Bad", "Good"]
78
113
 
79
114
  assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
80
115
 
116
+ model = VectorModel.new dir
117
+
118
+ assert_equal "Good", model.eval("This is dog")
119
+
81
120
  end
82
121
  end
83
122
 
@@ -108,9 +147,7 @@ class TestHuggingface < Test::Unit::TestCase
108
147
  model = HuggingfaceModel.new "MaskedLM", checkpoint
109
148
  assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
110
149
  reject{|v| v.empty?}.length
111
-
112
150
  end
113
151
 
114
-
115
152
  end
116
153
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rbbt-dm
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.3
4
+ version: 1.2.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - Miguel Vazquez
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-02-04 00:00:00.000000000 Z
11
+ date: 2023-02-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rbbt-util
@@ -107,12 +107,14 @@ files:
107
107
  - lib/rbbt/statistics/rank_product.rb
108
108
  - lib/rbbt/tensorflow.rb
109
109
  - lib/rbbt/vector/model.rb
110
+ - lib/rbbt/vector/model/huggingface.old.rb
110
111
  - lib/rbbt/vector/model/huggingface.rb
111
112
  - lib/rbbt/vector/model/random_forest.rb
112
113
  - lib/rbbt/vector/model/spaCy.rb
113
114
  - lib/rbbt/vector/model/svm.rb
114
115
  - lib/rbbt/vector/model/tensorflow.rb
115
116
  - lib/rbbt/vector/model/util.rb
117
+ - python/rbbt_dm/__init__.py
116
118
  - python/rbbt_dm/huggingface.py
117
119
  - share/R/MA.R
118
120
  - share/R/barcode.R