rbbt-dm 1.2.6 → 1.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9744ab9faeaf4f9cc04947eb11103dbf0694dda624f805a5c6be27bb22af81ce
4
- data.tar.gz: d3a3903aa276a69e20cbd71213286449db396ecf5f6a4b4d80a64ab299041fbb
3
+ metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
4
+ data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
5
5
  SHA512:
6
- metadata.gz: 263fb609b37522874426bcd79374760399b4a9aaab443ae6d74c727f2d148474dd71ee0b2cfda7a50131dafbc314f66352f4285562a75f62144d2e05ccd214c7
7
- data.tar.gz: 1e0426429a38028a19b3f8c955e975138199c791dad8691de7fb760a5cbec3304f19341906a4457563a90de965ee2f12a5a639b928866b46258af2507eeb39fa
6
+ metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
7
+ data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
@@ -75,14 +75,6 @@ class SpaCyModel < VectorModel
75
75
  d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
76
76
  end
77
77
  end
78
- #nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
79
- #Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
80
- # texts.collect do |text|
81
- # cats = nlp.(text).cats
82
- # bar.tick
83
- # cats.sort_by{|l,v| v.to_f }.last.first
84
- # end
85
- #end
86
78
  end
87
79
  end
88
80
  end
@@ -9,4 +9,22 @@ class VectorModel
9
9
  @bar.init
10
10
  @bar
11
11
  end
12
+
13
+ def balance_labels
14
+ counts = Misc.counts(@labels)
15
+ min = counts.values.min
16
+
17
+ used = {}
18
+ new_labels = []
19
+ new_features = []
20
+ @labels.zip(@features).shuffle.each do |label, features|
21
+ used[label] ||= 0
22
+ next if used[label] > min
23
+ used[label] += 1
24
+ new_labels << label
25
+ new_features << features
26
+ end
27
+ @labels = new_labels
28
+ @features = new_features
29
+ end
12
30
  end
@@ -2,7 +2,7 @@ require 'rbbt/util/R'
2
2
  require 'rbbt/vector/model/util'
3
3
 
4
4
  class VectorModel
5
- attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
5
+ attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
6
6
  attr_accessor :features, :names, :labels, :factor_levels
7
7
  attr_accessor :model_options
8
8
 
@@ -270,19 +270,32 @@ cat(paste(label, sep="\\n", collapse="\\n"));
270
270
  Open.write(@post_process_file_R, post_process)
271
271
  end
272
272
 
273
-
274
273
  Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
275
274
  Open.write(@names_file, @names * "\n" + "\n") if @names
276
275
  Open.write(@options_file, @model_options.to_json) if @model_options
277
276
  end
278
277
 
279
278
  def train
280
- case
281
- when Proc === @train_model
282
- self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
283
- when String === @train_model
284
- VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
279
+ begin
280
+ if @balance
281
+ @original_features = @features
282
+ @original_labels = @labels
283
+ self.balance_labels
284
+ end
285
+
286
+ case
287
+ when Proc === @train_model
288
+ self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
289
+ when String === @train_model
290
+ VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
291
+ end
292
+ ensure
293
+ if @balance
294
+ @features = @original_features
295
+ @labels = @original_labels
296
+ end
285
297
  end
298
+
286
299
  save_models if @directory
287
300
  end
288
301
 
@@ -438,6 +451,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
438
451
  @features = orig_features
439
452
  @labels = orig_labels
440
453
  end unless folds == -1
454
+
441
455
  self.reset_model if self.respond_to? :reset_model
442
456
  self.train unless folds == 1
443
457
  res
@@ -53,7 +53,7 @@ def load_tsv(tsv_file):
53
53
 
54
54
  def tsv_dataset(tokenizer, tsv_file):
55
55
  dataset = load_tsv(tsv_file)
56
- tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True) , batched=True)
56
+ tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True, max_length=512) , batched=True)
57
57
  return tokenized_dataset
58
58
 
59
59
  def training_args(*args, **kwargs):
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rbbt-dm
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.6
4
+ version: 1.2.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - Miguel Vazquez