rbbt-dm 1.2.6 → 1.2.7

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9744ab9faeaf4f9cc04947eb11103dbf0694dda624f805a5c6be27bb22af81ce
4
- data.tar.gz: d3a3903aa276a69e20cbd71213286449db396ecf5f6a4b4d80a64ab299041fbb
3
+ metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
4
+ data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
5
5
  SHA512:
6
- metadata.gz: 263fb609b37522874426bcd79374760399b4a9aaab443ae6d74c727f2d148474dd71ee0b2cfda7a50131dafbc314f66352f4285562a75f62144d2e05ccd214c7
7
- data.tar.gz: 1e0426429a38028a19b3f8c955e975138199c791dad8691de7fb760a5cbec3304f19341906a4457563a90de965ee2f12a5a639b928866b46258af2507eeb39fa
6
+ metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
7
+ data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
@@ -75,14 +75,6 @@ class SpaCyModel < VectorModel
75
75
  d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
76
76
  end
77
77
  end
78
- #nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
79
- #Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
80
- # texts.collect do |text|
81
- # cats = nlp.(text).cats
82
- # bar.tick
83
- # cats.sort_by{|l,v| v.to_f }.last.first
84
- # end
85
- #end
86
78
  end
87
79
  end
88
80
  end
@@ -9,4 +9,22 @@ class VectorModel
9
9
  @bar.init
10
10
  @bar
11
11
  end
12
+
13
+ def balance_labels
14
+ counts = Misc.counts(@labels)
15
+ min = counts.values.min
16
+
17
+ used = {}
18
+ new_labels = []
19
+ new_features = []
20
+ @labels.zip(@features).shuffle.each do |label, features|
21
+ used[label] ||= 0
22
+ next if used[label] > min
23
+ used[label] += 1
24
+ new_labels << label
25
+ new_features << features
26
+ end
27
+ @labels = new_labels
28
+ @features = new_features
29
+ end
12
30
  end
@@ -2,7 +2,7 @@ require 'rbbt/util/R'
2
2
  require 'rbbt/vector/model/util'
3
3
 
4
4
  class VectorModel
5
- attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
5
+ attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
6
6
  attr_accessor :features, :names, :labels, :factor_levels
7
7
  attr_accessor :model_options
8
8
 
@@ -270,19 +270,32 @@ cat(paste(label, sep="\\n", collapse="\\n"));
270
270
  Open.write(@post_process_file_R, post_process)
271
271
  end
272
272
 
273
-
274
273
  Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
275
274
  Open.write(@names_file, @names * "\n" + "\n") if @names
276
275
  Open.write(@options_file, @model_options.to_json) if @model_options
277
276
  end
278
277
 
279
278
  def train
280
- case
281
- when Proc === @train_model
282
- self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
283
- when String === @train_model
284
- VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
279
+ begin
280
+ if @balance
281
+ @original_features = @features
282
+ @original_labels = @labels
283
+ self.balance_labels
284
+ end
285
+
286
+ case
287
+ when Proc === @train_model
288
+ self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
289
+ when String === @train_model
290
+ VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
291
+ end
292
+ ensure
293
+ if @balance
294
+ @features = @original_features
295
+ @labels = @original_labels
296
+ end
285
297
  end
298
+
286
299
  save_models if @directory
287
300
  end
288
301
 
@@ -438,6 +451,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
438
451
  @features = orig_features
439
452
  @labels = orig_labels
440
453
  end unless folds == -1
454
+
441
455
  self.reset_model if self.respond_to? :reset_model
442
456
  self.train unless folds == 1
443
457
  res
@@ -53,7 +53,7 @@ def load_tsv(tsv_file):
53
53
 
54
54
  def tsv_dataset(tokenizer, tsv_file):
55
55
  dataset = load_tsv(tsv_file)
56
- tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True) , batched=True)
56
+ tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True, max_length=512) , batched=True)
57
57
  return tokenized_dataset
58
58
 
59
59
  def training_args(*args, **kwargs):
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rbbt-dm
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.6
4
+ version: 1.2.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - Miguel Vazquez