rbbt-dm 1.2.4 → 1.2.7

This diff shows the changes between publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
-   data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
+   metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
+   data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
  SHA512:
-   metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
-   data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
+   metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
+   data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
lib/rbbt/vector/model/huggingface.old.rb ADDED
@@ -0,0 +1,160 @@
+ require 'rbbt/vector/model'
+ require 'rbbt/util/python'
+
+ RbbtPython.add_path Rbbt.python.find(:lib)
+ RbbtPython.init_rbbt
+
+ class HuggingfaceModel < VectorModel
+
+   attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args
+
+   def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
+
+     if labels
+       Open.write(tsv_dataset_file) do |ffile|
+         ffile.puts ["label", "text"].flatten * "\t"
+         elements.zip(labels).each do |element,label|
+           ffile.puts [label, element].flatten * "\t"
+         end
+       end
+     else
+       Open.write(tsv_dataset_file) do |ffile|
+         ffile.puts ["text"].flatten * "\t"
+         elements.each{|element| ffile.puts element }
+       end
+     end
+
+     tsv_dataset_file
+   end
+
+   def self.call_method(name, *args)
+     RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
+   end
+
+   def call_method(name, *args)
+     HuggingfaceModel.call_method(name, *args)
+   end
+
+   #def input_tsv_file
+   #  File.join(@directory, 'dataset.tsv') if @directory
+   #end
+
+   #def checkpoint_dir
+   #  File.join(@directory, 'checkpoints') if @directory
+   #end
+
+   def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
+     TmpFile.with_file do |tmpfile|
+       tsv_file = File.join(tmpfile, 'dataset.tsv')
+
+       if training_args
+         training_args = training_args.dup
+         checkpoint_dir = training_args.delete(:checkpoint_dir)
+       end
+
+       checkpoint_dir = File.join(tmpfile, 'checkpoints')
+
+       Open.mkdir File.dirname(tsv_file)
+       Open.mkdir File.dirname(checkpoint_dir)
+
+       if labels
+         training_args_obj = call_method(:training_args, checkpoint_dir, **training_args)
+         call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
+       else
+         locate_tokens, training_args = training_args, {}
+         if Array === elements
+           training_args_obj = call_method(:training_args, checkpoint_dir)
+           call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
+         else
+           call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
+         end
+       end
+     end
+   end
+
+   def init_model
+     @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
+   end
+
+   def reset_model
+     init_model
+   end
+
+   def initialize(task, initial_checkpoint = nil, *args)
+     super(*args)
+     @task = task
+
+     @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
+
+     init_model
+
+     @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
+
+     @training_args = {}
+
+     train_model do |file,elements,labels|
+       HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)
+
+       @model.save_pretrained(file) if file
+       @tokenizer.save_pretrained(file) if file
+     end
+
+     eval_model do |file,elements|
+       @model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
+       HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
+     end
+
+     post_process do |result|
+       if result.respond_to?(:predictions)
+         single = false
+         predictions = result.predictions
+       elsif result["token_positions"]
+         predictions = result["result"].predictions
+         token_positions = result["token_positions"]
+       else
+         single = true
+         predictions = result["logits"]
+       end
+
+       result = case @task
+       when "SequenceClassification"
+         RbbtPython.collect(predictions) do |logits|
+           logits = RbbtPython.numpy2ruby logits
+           best_class = logits.index logits.max
+           best_class = @class_labels[best_class] if @class_labels
+           best_class
+         end
+       when "MaskedLM"
+         all_token_positions = token_positions.to_a
+
+         i = 0
+         RbbtPython.collect(predictions) do |item_logits|
+           item_token_positions = all_token_positions[i]
+           i += 1
+
+           item_logits = RbbtPython.numpy2ruby(item_logits)
+           item_masks = item_token_positions.collect do |token_positions|
+
+             best = item_logits.values_at(*token_positions).collect do |logits|
+               best_token, best_score = nil
+               logits.each_with_index do |v,i|
+                 if best_score.nil? || v > best_score
+                   best_token, best_score = i, v
+                 end
+               end
+               best_token
+             end
+
+             best.collect{|b| @tokenizer.decode(b) } * "|"
+           end
+           Array === @locate_tokens ? item_masks : item_masks.first
+         end
+       else
+         logits
+       end
+
+       single ? result.first : result
+     end
+   end
+ end
+
lib/rbbt/vector/model/huggingface.rb CHANGED
@@ -6,9 +6,7 @@ RbbtPython.init_rbbt
 
  class HuggingfaceModel < VectorModel
 
-   attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
-
-   def tsv_dataset(tsv_dataset_file, elements, labels = nil)
+   def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
 
      if labels
        Open.write(tsv_dataset_file) do |ffile|
@@ -27,59 +25,74 @@ class HuggingfaceModel < VectorModel
      tsv_dataset_file
    end
 
-   def call_method(name, *args)
-     RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
-   end
-
-   def input_tsv_file
-     File.join(@directory, 'dataset.tsv') if @directory
-   end
+   def initialize(task, checkpoint, *args)
+     options = args.pop if Hash === args.last
+     options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
+     super(*args)
+     @model_options ||= {}
+     @model_options.merge!(options)
 
-   def checkpoint_dir
-     File.join(@directory, 'checkpoints') if @directory
-   end
+     eval_model do |directory,texts|
+       checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
 
-   def run_model(elements, labels = nil)
-     TmpFile.with_file do |tmpfile|
-       tsv_file = input_tsv_file || File.join(tmpfile, 'dataset.tsv')
-       output_dir = checkpoint_dir || File.join(tmpfile, 'checkpoints')
+       if @model.nil?
+         @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
+       end
+
+       if Array === texts
 
-       Open.mkdir File.dirname(output_dir)
-       Open.mkdir File.dirname(tsv_file)
+         if @model_options.include?(:locate_tokens)
+           locate_tokens = @model_options[:locate_tokens]
+         elsif @model_options[:task] == "MaskedLM"
+           @model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
+         end
 
-       if labels
-         training_args = call_method(:training_args, output_dir)
-         call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
-       else
-         if Array === elements
-           training_args = call_method(:training_args, output_dir)
-           call_method(:predict_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements), @locate_tokens)
+         if @directory
+           tsv_file = File.join(@directory, 'dataset.tsv')
+           checkpoint_dir = File.join(@directory, 'checkpoints')
          else
-           call_method(:eval_model, @model, @tokenizer, [elements], @locate_tokens)
+           tmpdir = TmpFile.tmp_file
+           Open.mkdir tmpdir
+           tsv_file = File.join(tmpdir, 'dataset.tsv')
+           checkpoint_dir = File.join(tmpdir, 'checkpoints')
+         end
+
+         dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
+         training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
+
+         begin
+           RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
+         ensure
+           Open.rm_rf tmpdir if tmpdir
          end
+       else
+         RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
        end
      end
-   end
-
-   def initialize(task, initial_checkpoint = nil, *args)
-     super(*args)
-     @task = task
 
-     @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
+     train_model do |directory,texts,labels|
+       checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
+       @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
 
-     @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
+       if @directory
+         tsv_file = File.join(@directory, 'dataset.tsv')
+         checkpoint_dir = File.join(@directory, 'checkpoints')
+       else
+         tmpdir = TmpFile.tmp_file
+         Open.mkdir tmpdir
+         tsv_file = File.join(tmpdir, 'dataset.tsv')
+         checkpoint_dir = File.join(tmpdir, 'checkpoints')
+       end
 
-     @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
+       training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
+       dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
 
-     train_model do |file,elements,labels|
-       run_model(elements, labels)
+       RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
 
-       @model.save_pretrained(file) if file
-       @tokenizer.save_pretrained(file) if file
-     end
+       Open.rm_rf tmpdir if tmpdir
 
-     eval_model do |file,elements|
-       run_model(elements)
+       @model.save_pretrained(directory) if directory
+       @tokenizer.save_pretrained(directory) if directory
      end
 
      post_process do |result|
@@ -94,12 +107,13 @@ class HuggingfaceModel < VectorModel
          predictions = result["logits"]
        end
 
-       result = case @task
+       task, class_labels, locate_tokens = @model_options.values_at :task, :class_labels, :locate_tokens
+       result = case task
        when "SequenceClassification"
          RbbtPython.collect(predictions) do |logits|
            logits = RbbtPython.numpy2ruby logits
            best_class = logits.index logits.max
-           best_class = @class_labels[best_class] if @class_labels
+           best_class = class_labels[best_class] if class_labels
            best_class
          end
        when "MaskedLM"
@@ -125,7 +139,7 @@ class HuggingfaceModel < VectorModel
 
             best.collect{|b| @tokenizer.decode(b) } * "|"
           end
-          Array === @locate_tokens ? item_masks : item_masks.first
+          Array === locate_tokens ? item_masks : item_masks.first
         end
       else
         logits
@@ -133,6 +147,15 @@ class HuggingfaceModel < VectorModel
 
        single ? result.first : result
      end
+
+
+     save_models if @directory
    end
+
+   def reset_model
+     @model, @tokenizer = nil
+     Open.rm @model_file
+   end
+
  end
 
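For orientation, here is a minimal usage sketch of the reworked class, mirroring the test suite further down; the directory path is an arbitrary example and the checkpoint must be downloadable from the Hugging Face hub:

```ruby
require 'rbbt/vector/model/huggingface'

# Task, checkpoint and any extra options now live in model_options and are
# persisted to <directory>/options.json when a directory is given.
model = HuggingfaceModel.new "SequenceClassification",
                             "distilbert-base-uncased-finetuned-sst-2-english",
                             "/tmp/sst2_model",
                             :class_labels => %w(Bad Good)

model.eval "This is dog"                    # => "Bad" or "Good"
model.eval ["This is dog", "This is cat"]   # => one label per text
```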
lib/rbbt/vector/model/spaCy.rb CHANGED
@@ -75,14 +75,6 @@ class SpaCyModel < VectorModel
          d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
        end
      end
-     #nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
-     #Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
-     #  texts.collect do |text|
-     #    cats = nlp.(text).cats
-     #    bar.tick
-     #    cats.sort_by{|l,v| v.to_f }.last.first
-     #  end
-     #end
 
    end
  end
lib/rbbt/vector/model/util.rb CHANGED
@@ -9,4 +9,22 @@ class VectorModel
      @bar.init
      @bar
    end
+
+   def balance_labels
+     counts = Misc.counts(@labels)
+     min = counts.values.min
+
+     used = {}
+     new_labels = []
+     new_features = []
+     @labels.zip(@features).shuffle.each do |label, features|
+       used[label] ||= 0
+       next if used[label] > min
+       used[label] += 1
+       new_labels << label
+       new_features << features
+     end
+     @labels = new_labels
+     @features = new_features
+   end
  end
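The new balance_labels helper undersamples over-represented classes down to roughly the size of the smallest class; on a model instance it is triggered by setting the new balance flag before training. A small standalone Ruby illustration of the same idea (independent of the gem; the sample data is made up):

```ruby
# Shuffle (label, features) pairs and keep about `min` examples per label,
# where min is the size of the smallest class. With the gem you would simply
# set model.balance = true before calling model.train.
labels   = %w(good good good good bad)
features = [[1], [2], [3], [4], [5]]

min  = labels.tally.values.min
used = Hash.new(0)
balanced = labels.zip(features).shuffle.select do |label, _|
  (used[label] += 1) <= min
end

balanced_labels, balanced_features = balanced.transpose
# => one "good" and one "bad" pair, in random order
```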
lib/rbbt/vector/model.rb CHANGED
@@ -2,8 +2,9 @@ require 'rbbt/util/R'
  require 'rbbt/vector/model/util'
 
  class VectorModel
-   attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
+   attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
    attr_accessor :features, :names, :labels, :factor_levels
+   attr_accessor :model_options
 
    def extract_features(&block)
      @extract_features = block if block_given?
@@ -126,7 +127,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
      instance_eval code, file
    end
 
-   def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, names = nil, factor_levels = nil)
+   def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
      @directory = directory
      if @directory
        FileUtils.mkdir_p @directory unless File.exists?(@directory)
@@ -135,10 +136,18 @@ cat(paste(label, sep="\\n", collapse="\\n"));
        @extract_features_file = File.join(@directory, "features")
        @train_model_file = File.join(@directory, "train_model")
        @eval_model_file = File.join(@directory, "eval_model")
+       @post_process_file = File.join(@directory, "post_process")
        @train_model_file_R = File.join(@directory, "train_model.R")
        @eval_model_file_R = File.join(@directory, "eval_model.R")
+       @post_process_file_R = File.join(@directory, "post_process.R")
        @names_file = File.join(@directory, "feature_names")
        @levels_file = File.join(@directory, "levels")
+       @options_file = File.join(@directory, "options.json")
+
+       if File.exists?(@options_file)
+         @model_options = JSON.parse(Open.read(@options_file))
+         IndiferentHash.setup(@model_options)
+       end
      end
 
      if extract_features.nil?
@@ -169,6 +178,17 @@ cat(paste(label, sep="\\n", collapse="\\n"));
        @eval_model = eval_model
      end
 
+     if post_process.nil?
+       if @post_process_file && File.exists?(@post_process_file)
+         @post_process = __load_method @post_process_file
+       elsif @post_process_file_R && File.exists?(@post_process_file_R)
+         @post_process = Open.read(@post_process_file_R)
+       end
+     else
+       @post_process = post_process
+     end
+
+
      if names.nil?
        if @names_file && File.exists?(@names_file)
          @names = Open.read(@names_file).split("\n")
@@ -240,18 +260,43 @@ cat(paste(label, sep="\\n", collapse="\\n"));
        Open.write(@eval_model_file_R, eval_model)
      end
 
+     case
+     when Proc === post_process
+       begin
+         Open.write(@post_process_file, post_process.source)
+       rescue
+       end
+     when String === post_process
+       Open.write(@post_process_file_R, post_process)
+     end
+
      Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
      Open.write(@names_file, @names * "\n" + "\n") if @names
+     Open.write(@options_file, @model_options.to_json) if @model_options
    end
 
    def train
-     case
-     when Proc === @train_model
-       self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
-     when String === @train_model
-       VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
+     begin
+       if @balance
+         @original_features = @features
+         @original_labels = @labels
+         self.balance_labels
+       end
+
+       case
+       when Proc === @train_model
+         self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
+       when String === @train_model
+         VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
+       end
+     ensure
+       if @balance
+         @features = @original_features
+         @labels = @original_labels
+       end
      end
-     save_models
+
+     save_models if @directory
    end
 
    def run(code)
@@ -299,38 +344,6 @@ cat(paste(label, sep="\\n", collapse="\\n"));
      result
    end
 
-   #def cross_validation(folds = 10)
-   #  saved_features = @features
-   #  saved_labels = @labels
-   #  seq = (0..features.length - 1).to_a
-
-   #  chunk_size = features.length / folds
-
-   #  acc = []
-   #  folds.times do
-   #    seq = seq.shuffle
-   #    eval_chunk = seq[0..chunk_size]
-   #    train_chunk = seq[chunk_size.. -1]
-
-   #    eval_features = @features.values_at *eval_chunk
-   #    eval_labels = @labels.values_at *eval_chunk
-
-   #    @features = @features.values_at *train_chunk
-   #    @labels = @labels.values_at *train_chunk
-
-   #    train
-   #    predictions = eval_list eval_features, false
-
-   #    acc << predictions.zip(eval_labels).collect{|pred,lab| pred - lab < 0.5 ? 1 : 0}.inject(0){|acc,e| acc +=e} / chunk_size
-
-   #    @features = saved_features
-   #    @labels = saved_labels
-   #  end
-
-   #  acc
-   #end
-   #
-
    def self.f1_metrics(test, predicted, good_label = nil)
      tp, tn, fp, fn, pr, re, f1 = [0, 0, 0, 0, nil, nil, nil]
 
@@ -413,6 +426,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
        @features = train_set
        @labels = train_labels
 
+       self.reset_model if self.respond_to? :reset_model
        self.train
        predictions = self.eval_list test_set, false
 
@@ -437,6 +451,8 @@ cat(paste(label, sep="\\n", collapse="\\n"));
        @features = orig_features
        @labels = orig_labels
      end unless folds == -1
+
+     self.reset_model if self.respond_to? :reset_model
      self.train unless folds == 1
      res
    end
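Because the post_process block and the model options are now serialized next to the other model files (post_process, post_process.R, options.json), a directory trained with a specialized subclass can later be reopened through the generic class, as the tests below also do. A brief sketch of that round trip, assuming dir was previously passed to HuggingfaceModel.new and trained:

```ruby
require 'rbbt/vector/model'

dir = "/path/to/trained/model"   # hypothetical directory used earlier for training

# The generic class reloads the saved eval/post_process blocks and options.json,
# so no task- or checkpoint-specific code is needed here.
model = VectorModel.new dir
model.eval "This is dog"
```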
@@ -17,6 +17,17 @@ def load_model_and_tokenizer(task, checkpoint):
      tokenizer = load_tokenizer(task, checkpoint)
      return model, tokenizer
 
+ def load_model_and_tokenizer_from_directory(directory):
+     import os
+     import json
+     options_file = os.path.join(directory, 'options.json')
+     f = open(options_file, "r")
+     options = json.load(f.read())
+     f.close()
+     task = options["task"]
+     checkpoint = options["checkpoint"]
+     return load_model_and_tokenizer(task, checkpoint)
+
  #{{{ SIMPLE EVALUATE
 
  def forward(model, features):
@@ -42,7 +53,7 @@ def load_tsv(tsv_file):
 
  def tsv_dataset(tokenizer, tsv_file):
      dataset = load_tsv(tsv_file)
-     tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True) , batched=True)
+     tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True, max_length=512) , batched=True)
      return tokenized_dataset
 
  def training_args(*args, **kwargs):
@@ -57,34 +68,34 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
      tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
 
      if (not class_weights == None):
-         import torch
-         from torch import nn
-
-         class WeightTrainer(Trainer):
-             def compute_loss(self, model, inputs, return_outputs=False):
-                 labels = inputs.get("labels")
-                 # forward pass
-                 outputs = model(**inputs)
-                 logits = outputs.get('logits')
-                 # compute custom loss
-                 loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
-                 loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
-                 return (loss, outputs) if return_outputs else loss
-
-         trainer = WeightTrainer(
-             model,
-             training_args,
-             train_dataset = tokenized_dataset["train"],
-             tokenizer = tokenizer
-         )
+         import torch
+         from torch import nn
+
+         class WeightTrainer(Trainer):
+             def compute_loss(self, model, inputs, return_outputs=False):
+                 labels = inputs.get("labels")
+                 # forward pass
+                 outputs = model(**inputs)
+                 logits = outputs.get('logits')
+                 # compute custom loss
+                 loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
+                 loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+                 return (loss, outputs) if return_outputs else loss
+
+         trainer = WeightTrainer(
+             model,
+             training_args,
+             train_dataset = tokenized_dataset["train"],
+             tokenizer = tokenizer
+         )
      else:
 
-         trainer = Trainer(
-             model,
-             training_args,
-             train_dataset = tokenized_dataset["train"],
-             tokenizer = tokenizer
-         )
+         trainer = Trainer(
+             model,
+             training_args,
+             train_dataset = tokenized_dataset["train"],
+             tokenizer = tokenizer
+         )
 
      trainer.train()
 
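On the Ruby side, the weighted loss above is reached by putting the weights into the model options, which the train_model block forwards to this Python helper; a sketch under that assumption (the weight values are arbitrary):

```ruby
model = HuggingfaceModel.new "SequenceClassification",
                             "distilbert-base-uncased-finetuned-sst-2-english",
                             "/tmp/weighted_model",
                             :class_labels => %w(Bad Good)

# One weight per class; applied by the WeightTrainer above via nn.CrossEntropyLoss.
model.model_options[:class_weights] = [1.0, 5.0]

# model.features / model.labels would then be populated and model.train called
# to fine-tune with the weighted loss.
```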
test/rbbt/vector/model/test_huggingface.rb CHANGED
@@ -3,6 +3,21 @@ require 'rbbt/vector/model/huggingface'
 
  class TestHuggingface < Test::Unit::TestCase
 
+   def test_options
+     TmpFile.with_file do |dir|
+       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+       task = "SequenceClassification"
+
+       model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
+       iii model.eval "This is dog"
+       iii model.eval "This is cat"
+       iii model.eval(["This is dog", "This is cat"])
+
+       model = VectorModel.new dir
+       iii model.eval(["This is dog", "This is cat"])
+     end
+   end
+
    def test_pipeline
      require 'rbbt/util/python'
      model = VectorModel.new
@@ -25,7 +40,7 @@ class TestHuggingface < Test::Unit::TestCase
 
        model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
 
-       model.class_labels = ["Bad", "Good"]
+       model.model_options[:class_labels] = ["Bad", "Good"]
 
        assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
      end
@@ -37,7 +52,8 @@ class TestHuggingface < Test::Unit::TestCase
        checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
        model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-       model.class_labels = ["Bad", "Good"]
+
+       model.model_options[:class_labels] = %w(Bad Good)
 
        assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -48,6 +64,9 @@ class TestHuggingface < Test::Unit::TestCase
        model.train
 
        assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
+
+       model = VectorModel.new dir
+       assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
      end
    end
 
@@ -55,7 +74,7 @@ class TestHuggingface < Test::Unit::TestCase
      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
      model = HuggingfaceModel.new "SequenceClassification", checkpoint
-     model.class_labels = ["Bad", "Good"]
+     model.model_options[:class_labels] = ["Bad", "Good"]
 
      assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -73,7 +92,7 @@ class TestHuggingface < Test::Unit::TestCase
        checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
        model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-       model.class_labels = ["Bad", "Good"]
+       model.model_options[:class_labels] = ["Bad", "Good"]
 
        assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
 
@@ -84,15 +103,20 @@ class TestHuggingface < Test::Unit::TestCase
        model.train
 
        model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
-       model.class_labels = ["Bad", "Good"]
 
        assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
 
-       model = HuggingfaceModel.new "SequenceClassification", model.model_file
-       model.class_labels = ["Bad", "Good"]
+       model_file = model.model_file
+
+       model = HuggingfaceModel.new "SequenceClassification", model_file
+       model.model_options[:class_labels] = ["Bad", "Good"]
 
        assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
 
+       model = VectorModel.new dir
+
+       assert_equal "Good", model.eval("This is dog")
+
      end
    end
 
@@ -123,9 +147,7 @@ class TestHuggingface < Test::Unit::TestCase
      model = HuggingfaceModel.new "MaskedLM", checkpoint
      assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
        reject{|v| v.empty?}.length
-
    end
 
-
  end
 
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: rbbt-dm
  version: !ruby/object:Gem::Version
-   version: 1.2.4
+   version: 1.2.7
  platform: ruby
  authors:
  - Miguel Vazquez
- autorequire:
+ autorequire:
  bindir: bin
  cert_chain: []
- date: 2023-02-07 00:00:00.000000000 Z
+ date: 2023-02-08 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rbbt-util
@@ -107,6 +107,7 @@ files:
  - lib/rbbt/statistics/rank_product.rb
  - lib/rbbt/tensorflow.rb
  - lib/rbbt/vector/model.rb
+ - lib/rbbt/vector/model/huggingface.old.rb
  - lib/rbbt/vector/model/huggingface.rb
  - lib/rbbt/vector/model/random_forest.rb
  - lib/rbbt/vector/model/spaCy.rb
@@ -143,7 +144,7 @@
  homepage: http://github.com/mikisvaz/rbbt-phgx
  licenses: []
  metadata: {}
- post_install_message:
+ post_install_message:
  rdoc_options: []
  require_paths:
  - lib
@@ -158,22 +159,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
  - !ruby/object:Gem::Version
    version: '0'
  requirements: []
- rubygems_version: 3.1.6
- signing_key:
+ rubygems_version: 3.1.2
+ signing_key:
  specification_version: 4
  summary: Data-mining and statistics
  test_files:
- - test/rbbt/statistics/test_hypergeometric.rb
- - test/rbbt/statistics/test_fisher.rb
- - test/rbbt/statistics/test_fdr.rb
- - test/rbbt/statistics/test_random_walk.rb
- - test/rbbt/test_ml_task.rb
+ - test/test_helper.rb
  - test/rbbt/vector/test_model.rb
+ - test/rbbt/vector/model/test_huggingface.rb
  - test/rbbt/vector/model/test_tensorflow.rb
  - test/rbbt/vector/model/test_spaCy.rb
- - test/rbbt/vector/model/test_huggingface.rb
  - test/rbbt/vector/model/test_svm.rb
- - test/rbbt/network/test_paths.rb
- - test/rbbt/matrix/test_barcode.rb
+ - test/rbbt/statistics/test_random_walk.rb
+ - test/rbbt/statistics/test_fisher.rb
+ - test/rbbt/statistics/test_fdr.rb
+ - test/rbbt/statistics/test_hypergeometric.rb
  - test/rbbt/test_stan.rb
- - test/test_helper.rb
+ - test/rbbt/matrix/test_barcode.rb
+ - test/rbbt/test_ml_task.rb
+ - test/rbbt/network/test_paths.rb