rbbt-dm 1.2.4 → 1.2.6

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
-  data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
+  metadata.gz: 9744ab9faeaf4f9cc04947eb11103dbf0694dda624f805a5c6be27bb22af81ce
+  data.tar.gz: d3a3903aa276a69e20cbd71213286449db396ecf5f6a4b4d80a64ab299041fbb
 SHA512:
-  metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
-  data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
+  metadata.gz: 263fb609b37522874426bcd79374760399b4a9aaab443ae6d74c727f2d148474dd71ee0b2cfda7a50131dafbc314f66352f4285562a75f62144d2e05ccd214c7
+  data.tar.gz: 1e0426429a38028a19b3f8c955e975138199c791dad8691de7fb760a5cbec3304f19341906a4457563a90de965ee2f12a5a639b928866b46258af2507eeb39fa
lib/rbbt/vector/model/huggingface.old.rb ADDED
@@ -0,0 +1,160 @@
+require 'rbbt/vector/model'
+require 'rbbt/util/python'
+
+RbbtPython.add_path Rbbt.python.find(:lib)
+RbbtPython.init_rbbt
+
+class HuggingfaceModel < VectorModel
+
+  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args
+
+  def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
+
+    if labels
+      Open.write(tsv_dataset_file) do |ffile|
+        ffile.puts ["label", "text"].flatten * "\t"
+        elements.zip(labels).each do |element,label|
+          ffile.puts [label, element].flatten * "\t"
+        end
+      end
+    else
+      Open.write(tsv_dataset_file) do |ffile|
+        ffile.puts ["text"].flatten * "\t"
+        elements.each{|element| ffile.puts element }
+      end
+    end
+
+    tsv_dataset_file
+  end
+
+  def self.call_method(name, *args)
+    RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
+  end
+
+  def call_method(name, *args)
+    HuggingfaceModel.call_method(name, *args)
+  end
+
+  #def input_tsv_file
+  #  File.join(@directory, 'dataset.tsv') if @directory
+  #end
+
+  #def checkpoint_dir
+  #  File.join(@directory, 'checkpoints') if @directory
+  #end
+
+  def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
+    TmpFile.with_file do |tmpfile|
+      tsv_file = File.join(tmpfile, 'dataset.tsv')
+
+      if training_args
+        training_args = training_args.dup
+        checkpoint_dir = training_args.delete(:checkpoint_dir)
+      end
+
+      checkpoint_dir = File.join(tmpfile, 'checkpoints')
+
+      Open.mkdir File.dirname(tsv_file)
+      Open.mkdir File.dirname(checkpoint_dir)
+
+      if labels
+        training_args_obj = call_method(:training_args, checkpoint_dir, **training_args)
+        call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
+      else
+        locate_tokens, training_args = training_args, {}
+        if Array === elements
+          training_args_obj = call_method(:training_args, checkpoint_dir)
+          call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
+        else
+          call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
+        end
+      end
+    end
+  end
+
+  def init_model
+    @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
+  end
+
+  def reset_model
+    init_model
+  end
+
+  def initialize(task, initial_checkpoint = nil, *args)
+    super(*args)
+    @task = task
+
+    @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
+
+    init_model
+
+    @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
+
+    @training_args = {}
+
+    train_model do |file,elements,labels|
+      HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)
+
+      @model.save_pretrained(file) if file
+      @tokenizer.save_pretrained(file) if file
+    end
+
+    eval_model do |file,elements|
+      @model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
+      HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
+    end
+
+    post_process do |result|
+      if result.respond_to?(:predictions)
+        single = false
+        predictions = result.predictions
+      elsif result["token_positions"]
+        predictions = result["result"].predictions
+        token_positions = result["token_positions"]
+      else
+        single = true
+        predictions = result["logits"]
+      end
+
+      result = case @task
+               when "SequenceClassification"
+                 RbbtPython.collect(predictions) do |logits|
+                   logits = RbbtPython.numpy2ruby logits
+                   best_class = logits.index logits.max
+                   best_class = @class_labels[best_class] if @class_labels
+                   best_class
+                 end
+               when "MaskedLM"
+                 all_token_positions = token_positions.to_a
+
+                 i = 0
+                 RbbtPython.collect(predictions) do |item_logits|
+                   item_token_positions = all_token_positions[i]
+                   i += 1
+
+                   item_logits = RbbtPython.numpy2ruby(item_logits)
+                   item_masks = item_token_positions.collect do |token_positions|
+
+                     best = item_logits.values_at(*token_positions).collect do |logits|
+                       best_token, best_score = nil
+                       logits.each_with_index do |v,i|
+                         if best_score.nil? || v > best_score
+                           best_token, best_score = i, v
+                         end
+                       end
+                       best_token
+                     end
+
+                     best.collect{|b| @tokenizer.decode(b) } * "|"
+                   end
+                   Array === @locate_tokens ? item_masks : item_masks.first
+                 end
+               else
+                 logits
+               end
+
+      single ? result.first : result
+    end
+  end
+end
+
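Both this legacy class and the reworked one below funnel training and prediction data through self.tsv_dataset; a minimal sketch of the file it writes (the path, texts and labels are illustrative and not part of the diff):

    require 'rbbt/vector/model/huggingface'

    HuggingfaceModel.tsv_dataset("/tmp/dataset.tsv", ["This is dog", "This is cat"], [0, 1])
    # /tmp/dataset.tsv now holds tab-separated lines:
    #   label   text
    #   0       This is dog
    #   1       This is cat
    # Without labels only the "text" header and one element per line are written.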
lib/rbbt/vector/model/huggingface.rb CHANGED
@@ -6,9 +6,7 @@ RbbtPython.init_rbbt
 
 class HuggingfaceModel < VectorModel
 
-  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
-
-  def tsv_dataset(tsv_dataset_file, elements, labels = nil)
+  def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
 
     if labels
       Open.write(tsv_dataset_file) do |ffile|
@@ -27,59 +25,74 @@ class HuggingfaceModel < VectorModel
     tsv_dataset_file
   end
 
-  def call_method(name, *args)
-    RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
-  end
-
-  def input_tsv_file
-    File.join(@directory, 'dataset.tsv') if @directory
-  end
+  def initialize(task, checkpoint, *args)
+    options = args.pop if Hash === args.last
+    options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
+    super(*args)
+    @model_options ||= {}
+    @model_options.merge!(options)
 
-  def checkpoint_dir
-    File.join(@directory, 'checkpoints') if @directory
-  end
+    eval_model do |directory,texts|
+      checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
 
-  def run_model(elements, labels = nil)
-    TmpFile.with_file do |tmpfile|
-      tsv_file = input_tsv_file || File.join(tmpfile, 'dataset.tsv')
-      output_dir = checkpoint_dir || File.join(tmpfile, 'checkpoints')
+      if @model.nil?
+        @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
+      end
+
+      if Array === texts
 
-      Open.mkdir File.dirname(output_dir)
-      Open.mkdir File.dirname(tsv_file)
+        if @model_options.include?(:locate_tokens)
+          locate_tokens = @model_options[:locate_tokens]
+        elsif @model_options[:task] == "MaskedLM"
+          @model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
+        end
 
-      if labels
-        training_args = call_method(:training_args, output_dir)
-        call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
-      else
-        if Array === elements
-          training_args = call_method(:training_args, output_dir)
-          call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
+        if @directory
+          tsv_file = File.join(@directory, 'dataset.tsv')
+          checkpoint_dir = File.join(@directory, 'checkpoints')
         else
-          call_method(:eval_model, @model, @tokenizer, [elements], @locate_tokens)
+          tmpdir = TmpFile.tmp_file
+          Open.mkdir tmpdir
+          tsv_file = File.join(tmpdir, 'dataset.tsv')
+          checkpoint_dir = File.join(tmpdir, 'checkpoints')
+        end
+
+        dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
+        training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
+
+        begin
+          RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
+        ensure
+          Open.rm_rf tmpdir if tmpdir
         end
+      else
+        RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
       end
     end
-  end
-
-  def initialize(task, initial_checkpoint = nil, *args)
-    super(*args)
-    @task = task
 
-    @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
+    train_model do |directory,texts,labels|
+      checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
+      @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
 
-    @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
+      if @directory
+        tsv_file = File.join(@directory, 'dataset.tsv')
+        checkpoint_dir = File.join(@directory, 'checkpoints')
+      else
+        tmpdir = TmpFile.tmp_file
+        Open.mkdir tmpdir
+        tsv_file = File.join(tmpdir, 'dataset.tsv')
+        checkpoint_dir = File.join(tmpdir, 'checkpoints')
+      end
 
-    @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
+      training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
+      dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
 
-    train_model do |file,elements,labels|
-      run_model(elements, labels)
+      RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
 
-      @model.save_pretrained(file) if file
-      @tokenizer.save_pretrained(file) if file
-    end
+      Open.rm_rf tmpdir if tmpdir
 
-    eval_model do |file,elements|
-      run_model(elements)
+      @model.save_pretrained(directory) if directory
+      @tokenizer.save_pretrained(directory) if directory
     end
 
     post_process do |result|
@@ -94,12 +107,13 @@ class HuggingfaceModel < VectorModel
         predictions = result["logits"]
       end
 
-      result = case @task
+      task, class_labels, locate_tokens = @model_options.values_at :task, :class_labels, :locate_tokens
+      result = case task
                when "SequenceClassification"
                  RbbtPython.collect(predictions) do |logits|
                    logits = RbbtPython.numpy2ruby logits
                    best_class = logits.index logits.max
-                   best_class = @class_labels[best_class] if @class_labels
+                   best_class = class_labels[best_class] if class_labels
                    best_class
                  end
                when "MaskedLM"
@@ -125,7 +139,7 @@ class HuggingfaceModel < VectorModel
 
                      best.collect{|b| @tokenizer.decode(b) } * "|"
                    end
-                   Array === @locate_tokens ? item_masks : item_masks.first
+                   Array === locate_tokens ? item_masks : item_masks.first
                  end
                else
                  logits
@@ -133,6 +147,15 @@ class HuggingfaceModel < VectorModel
 
       single ? result.first : result
     end
+
+
+    save_models if @directory
   end
+
+  def reset_model
+    @model, @tokenizer = nil
+    Open.rm @model_file
+  end
+
 end
 
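Usage sketch for the reworked HuggingfaceModel above (not part of the diff; the checkpoint name, directory and labels are illustrative and mirror the tests further down):

    require 'rbbt/vector/model/huggingface'

    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    model = HuggingfaceModel.new "SequenceClassification", checkpoint, "/tmp/sst2.model"
    model.model_options[:class_labels] = ["Bad", "Good"]

    model.eval "This is dog"                        # => "Bad" or "Good"
    model.eval(["This is dog", "This is cat"])      # => one label per text

Since the task, checkpoint and remaining options now live in @model_options and are persisted to options.json by save_models, the same directory can later be reopened with HuggingfaceModel.new or with a plain VectorModel.new.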
lib/rbbt/vector/model.rb CHANGED
@@ -4,6 +4,7 @@ require 'rbbt/vector/model/util'
 class VectorModel
   attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
   attr_accessor :features, :names, :labels, :factor_levels
+  attr_accessor :model_options
 
   def extract_features(&block)
     @extract_features = block if block_given?
@@ -126,7 +127,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     instance_eval code, file
   end
 
-  def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
+  def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
     @directory = directory
     if @directory
       FileUtils.mkdir_p @directory unless File.exists?(@directory)
@@ -135,10 +136,18 @@ cat(paste(label, sep="\\n", collapse="\\n"));
       @extract_features_file = File.join(@directory, "features")
       @train_model_file = File.join(@directory, "train_model")
       @eval_model_file = File.join(@directory, "eval_model")
+      @post_process_file = File.join(@directory, "post_process")
       @train_model_file_R = File.join(@directory, "train_model.R")
      @eval_model_file_R = File.join(@directory, "eval_model.R")
+      @post_process_file_R = File.join(@directory, "post_process.R")
       @names_file = File.join(@directory, "feature_names")
       @levels_file = File.join(@directory, "levels")
+      @options_file = File.join(@directory, "options.json")
+
+      if File.exists?(@options_file)
+        @model_options = JSON.parse(Open.read(@options_file))
+        IndiferentHash.setup(@model_options)
+      end
     end
 
     if extract_features.nil?
@@ -169,6 +178,17 @@ cat(paste(label, sep="\\n", collapse="\\n"));
       @eval_model = eval_model
     end
 
+    if post_process.nil?
+      if @post_process_file && File.exists?(@post_process_file)
+        @post_process = __load_method @post_process_file
+      elsif @post_process_file_R && File.exists?(@post_process_file_R)
+        @post_process = Open.read(@post_process_file_R)
+      end
+    else
+      @post_process = post_process
+    end
+
+
     if names.nil?
       if @names_file && File.exists?(@names_file)
         @names = Open.read(@names_file).split("\n")
@@ -240,8 +260,20 @@ cat(paste(label, sep="\\n", collapse="\\n"));
       Open.write(@eval_model_file_R, eval_model)
     end
 
+    case
+    when Proc === post_process
+      begin
+        Open.write(@post_process_file, post_process.source)
+      rescue
+      end
+    when String === post_process
+      Open.write(@post_process_file_R, post_process)
+    end
+
+
     Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
     Open.write(@names_file, @names * "\n" + "\n") if @names
+    Open.write(@options_file, @model_options.to_json) if @model_options
   end
 
   def train
@@ -251,7 +283,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     when String === @train_model
       VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
     end
-    save_models
+    save_models if @directory
   end
 
   def run(code)
@@ -299,38 +331,6 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     result
   end
 
-  #def cross_validation(folds = 10)
-  #  saved_features = @features
-  #  saved_labels = @labels
-  #  seq = (0..features.length - 1).to_a
-
-  #  chunk_size = features.length / folds
-
-  #  acc = []
-  #  folds.times do
-  #    seq = seq.shuffle
-  #    eval_chunk = seq[0..chunk_size]
-  #    train_chunk = seq[chunk_size.. -1]
-
-  #    eval_features = @features.values_at *eval_chunk
-  #    eval_labels = @labels.values_at *eval_chunk
-
-  #    @features = @features.values_at *train_chunk
-  #    @labels = @labels.values_at *train_chunk
-
-  #    train
-  #    predictions = eval_list eval_features, false
-
-  #    acc << predictions.zip(eval_labels).collect{|pred,lab| pred - lab < 0.5 ? 1 : 0}.inject(0){|acc,e| acc +=e} / chunk_size
-
-  #    @features = saved_features
-  #    @labels = saved_labels
-  #  end
-
-  #  acc
-  #end
-  #
-
   def self.f1_metrics(test, predicted, good_label = nil)
     tp, tn, fp, fn, pr, re, f1 = [0, 0, 0, 0, nil, nil, nil]
 
@@ -413,6 +413,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
       @features = train_set
       @labels = train_labels
 
+      self.reset_model if self.respond_to? :reset_model
       self.train
       predictions = self.eval_list test_set, false
 
@@ -437,6 +438,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
      @features = orig_features
      @labels = orig_labels
    end unless folds == -1
+    self.reset_model if self.respond_to? :reset_model
    self.train unless folds == 1
    res
  end
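With post_process and options.json now persisted next to the other serialized blocks, a directory trained through a subclass can be reloaded by the generic class alone; a minimal sketch, assuming the directory was previously populated by a trained HuggingfaceModel (the path is illustrative):

    model = VectorModel.new "/tmp/sst2.model"
    model.eval "This is dog"                        # post_process and model_options are restored from disk
    model.eval(["This is dog", "This is cat"])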
@@ -17,6 +17,17 @@ def load_model_and_tokenizer(task, checkpoint):
     tokenizer = load_tokenizer(task, checkpoint)
     return model, tokenizer
 
+def load_model_and_tokenizer_from_directory(directory):
+    import os
+    import json
+    options_file = os.path.join(directory, 'options.json')
+    f = open(options_file, "r")
+    options = json.load(f.read())
+    f.close()
+    task = options["task"]
+    checkpoint = options["checkpoint"]
+    return load_model_and_tokenizer(task, checkpoint)
+
 #{{{ SIMPLE EVALUATE
 
 def forward(model, features):
@@ -57,34 +68,34 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
     tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
 
     if (not class_weights == None):
-        import torch
-        from torch import nn
-
-        class WeightTrainer(Trainer):
-            def compute_loss(self, model, inputs, return_outputs=False):
-                labels = inputs.get("labels")
-                # forward pass
-                outputs = model(**inputs)
-                logits = outputs.get('logits')
-                # compute custom loss
-                loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
-                loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
-                return (loss, outputs) if return_outputs else loss
-
-        trainer = WeightTrainer(
-            model,
-            training_args,
-            train_dataset = tokenized_dataset["train"],
-            tokenizer = tokenizer
-        )
+        import torch
+        from torch import nn
+
+        class WeightTrainer(Trainer):
+            def compute_loss(self, model, inputs, return_outputs=False):
+                labels = inputs.get("labels")
+                # forward pass
+                outputs = model(**inputs)
+                logits = outputs.get('logits')
+                # compute custom loss
+                loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
+                loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+                return (loss, outputs) if return_outputs else loss
+
+        trainer = WeightTrainer(
+            model,
+            training_args,
+            train_dataset = tokenized_dataset["train"],
+            tokenizer = tokenizer
+        )
     else:
 
-        trainer = Trainer(
-            model,
-            training_args,
-            train_dataset = tokenized_dataset["train"],
-            tokenizer = tokenizer
-        )
+        trainer = Trainer(
+            model,
+            training_args,
+            train_dataset = tokenized_dataset["train"],
+            tokenizer = tokenizer
+        )
 
     trainer.train()
 
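The weighted Trainer above is selected whenever the Ruby side passes a :class_weights entry through @model_options; a hedged sketch from the Ruby end (checkpoint, directory, weights, texts and labels are illustrative, and the training data is loaded through the features/labels accessors shown in model.rb):

    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    model = HuggingfaceModel.new "SequenceClassification", checkpoint, "/tmp/weighted.model"
    model.model_options[:class_labels]  = ["Bad", "Good"]
    model.model_options[:class_weights] = [1.0, 2.0]    # upweight the second class in the loss
    model.features = ["This is dog", "This is cat", "This is horse"]
    model.labels   = [0, 1, 1]
    model.train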
test/rbbt/vector/model/test_huggingface.rb CHANGED
@@ -3,6 +3,21 @@ require 'rbbt/vector/model/huggingface'
 
 class TestHuggingface < Test::Unit::TestCase
 
+  def test_options
+    TmpFile.with_file do |dir|
+      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+      task = "SequenceClassification"
+
+      model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
+      iii model.eval "This is dog"
+      iii model.eval "This is cat"
+      iii model.eval(["This is dog", "This is cat"])
+
+      model = VectorModel.new dir
+      iii model.eval(["This is dog", "This is cat"])
+    end
+  end
+
   def test_pipeline
     require 'rbbt/util/python'
     model = VectorModel.new
@@ -25,7 +40,7 @@ class TestHuggingface < Test::Unit::TestCase
 
       model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
 
-      model.class_labels = ["Bad", "Good"]
+      model.model_options[:class_labels] = ["Bad", "Good"]
 
       assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
     end
@@ -37,7 +52,8 @@ class TestHuggingface < Test::Unit::TestCase
      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-      model.class_labels = ["Bad", "Good"]
+
+      model.model_options[:class_labels] = %w(Bad Good)
 
      assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -48,6 +64,9 @@ class TestHuggingface < Test::Unit::TestCase
      model.train
 
      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
+
+      model = VectorModel.new dir
+      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
    end
  end
 
@@ -55,7 +74,7 @@ class TestHuggingface < Test::Unit::TestCase
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
    model = HuggingfaceModel.new "SequenceClassification", checkpoint
-    model.class_labels = ["Bad", "Good"]
+    model.model_options[:class_labels] = ["Bad", "Good"]
 
    assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -73,7 +92,7 @@ class TestHuggingface < Test::Unit::TestCase
      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-      model.class_labels = ["Bad", "Good"]
+      model.model_options[:class_labels] = ["Bad", "Good"]
 
      assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -84,15 +103,20 @@ class TestHuggingface < Test::Unit::TestCase
      model.train
 
      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-      model.class_labels = ["Bad", "Good"]
 
      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
 
-      model = HuggingfaceModel.new "SequenceClassification", model.model_file
-      model.class_labels = ["Bad", "Good"]
+      model_file = model.model_file
+
+      model = HuggingfaceModel.new "SequenceClassification", model_file
+      model.model_options[:class_labels] = ["Bad", "Good"]
 
      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
 
+      model = VectorModel.new dir
+
+      assert_equal "Good", model.eval("This is dog")
+
    end
  end
 
@@ -123,9 +147,7 @@ class TestHuggingface < Test::Unit::TestCase
    model = HuggingfaceModel.new "MaskedLM", checkpoint
    assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
      reject{|v| v.empty?}.length
-
  end
 
-
 end
 
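The last test above exercises the MaskedLM task, where eval returns the best-scoring token for each [MASK] (several masks in one text come back joined with "|", per the post_process block in huggingface.rb); a short sketch with an illustrative checkpoint and outputs:

    model = HuggingfaceModel.new "MaskedLM", "bert-base-uncased"
    model.eval "Paris is the [MASK] of the France."
    # => e.g. "capital"
    model.eval(["The [MASK] arrested the dangerous [MASK]."])
    # => e.g. ["police|criminal"]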
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: rbbt-dm
 version: !ruby/object:Gem::Version
-  version: 1.2.4
+  version: 1.2.6
 platform: ruby
 authors:
 - Miguel Vazquez
-autorequire:
+autorequire:
 bindir: bin
 cert_chain: []
-date: 2023-02-07 00:00:00.000000000 Z
+date: 2023-02-08 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rbbt-util
@@ -107,6 +107,7 @@ files:
 - lib/rbbt/statistics/rank_product.rb
 - lib/rbbt/tensorflow.rb
 - lib/rbbt/vector/model.rb
+- lib/rbbt/vector/model/huggingface.old.rb
 - lib/rbbt/vector/model/huggingface.rb
 - lib/rbbt/vector/model/random_forest.rb
 - lib/rbbt/vector/model/spaCy.rb
@@ -143,7 +144,7 @@ files:
 homepage: http://github.com/mikisvaz/rbbt-phgx
 licenses: []
 metadata: {}
-post_install_message:
+post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -158,22 +159,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.1.6
-signing_key:
+rubygems_version: 3.1.2
+signing_key:
 specification_version: 4
 summary: Data-mining and statistics
 test_files:
-- test/rbbt/statistics/test_hypergeometric.rb
-- test/rbbt/statistics/test_fisher.rb
-- test/rbbt/statistics/test_fdr.rb
-- test/rbbt/statistics/test_random_walk.rb
-- test/rbbt/test_ml_task.rb
+- test/test_helper.rb
 - test/rbbt/vector/test_model.rb
+- test/rbbt/vector/model/test_huggingface.rb
 - test/rbbt/vector/model/test_tensorflow.rb
 - test/rbbt/vector/model/test_spaCy.rb
-- test/rbbt/vector/model/test_huggingface.rb
 - test/rbbt/vector/model/test_svm.rb
-- test/rbbt/network/test_paths.rb
-- test/rbbt/matrix/test_barcode.rb
+- test/rbbt/statistics/test_random_walk.rb
+- test/rbbt/statistics/test_fisher.rb
+- test/rbbt/statistics/test_fdr.rb
+- test/rbbt/statistics/test_hypergeometric.rb
 - test/rbbt/test_stan.rb
-- test/test_helper.rb
+- test/rbbt/matrix/test_barcode.rb
+- test/rbbt/test_ml_task.rb
+- test/rbbt/network/test_paths.rb