rbbt-dm 1.2.6 → 1.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model/spaCy.rb +0 -8
- data/lib/rbbt/vector/model/util.rb +18 -0
- data/lib/rbbt/vector/model.rb +21 -7
- data/python/rbbt_dm/huggingface.py +1 -1
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
|
4
|
+
data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
|
7
|
+
data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
|
@@ -75,14 +75,6 @@ class SpaCyModel < VectorModel
|
|
75
75
|
d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
|
76
76
|
end
|
77
77
|
end
|
78
|
-
#nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
|
79
|
-
#Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
|
80
|
-
# texts.collect do |text|
|
81
|
-
# cats = nlp.(text).cats
|
82
|
-
# bar.tick
|
83
|
-
# cats.sort_by{|l,v| v.to_f }.last.first
|
84
|
-
# end
|
85
|
-
#end
|
86
78
|
end
|
87
79
|
end
|
88
80
|
end
|
@@ -9,4 +9,22 @@ class VectorModel
|
|
9
9
|
@bar.init
|
10
10
|
@bar
|
11
11
|
end
|
12
|
+
|
13
|
+
def balance_labels
|
14
|
+
counts = Misc.counts(@labels)
|
15
|
+
min = counts.values.min
|
16
|
+
|
17
|
+
used = {}
|
18
|
+
new_labels = []
|
19
|
+
new_features = []
|
20
|
+
@labels.zip(@features).shuffle.each do |label, features|
|
21
|
+
used[label] ||= 0
|
22
|
+
next if used[label] > min
|
23
|
+
used[label] += 1
|
24
|
+
new_labels << label
|
25
|
+
new_features << features
|
26
|
+
end
|
27
|
+
@labels = new_labels
|
28
|
+
@features = new_features
|
29
|
+
end
|
12
30
|
end
|
data/lib/rbbt/vector/model.rb
CHANGED
@@ -2,7 +2,7 @@ require 'rbbt/util/R'
|
|
2
2
|
require 'rbbt/vector/model/util'
|
3
3
|
|
4
4
|
class VectorModel
|
5
|
-
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
|
5
|
+
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
|
6
6
|
attr_accessor :features, :names, :labels, :factor_levels
|
7
7
|
attr_accessor :model_options
|
8
8
|
|
@@ -270,19 +270,32 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
270
270
|
Open.write(@post_process_file_R, post_process)
|
271
271
|
end
|
272
272
|
|
273
|
-
|
274
273
|
Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
|
275
274
|
Open.write(@names_file, @names * "\n" + "\n") if @names
|
276
275
|
Open.write(@options_file, @model_options.to_json) if @model_options
|
277
276
|
end
|
278
277
|
|
279
278
|
def train
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
279
|
+
begin
|
280
|
+
if @balance
|
281
|
+
@original_features = @features
|
282
|
+
@original_labels = @labels
|
283
|
+
self.balance_labels
|
284
|
+
end
|
285
|
+
|
286
|
+
case
|
287
|
+
when Proc === @train_model
|
288
|
+
self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
|
289
|
+
when String === @train_model
|
290
|
+
VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
|
291
|
+
end
|
292
|
+
ensure
|
293
|
+
if @balance
|
294
|
+
@features = @original_features
|
295
|
+
@labels = @original_labels
|
296
|
+
end
|
285
297
|
end
|
298
|
+
|
286
299
|
save_models if @directory
|
287
300
|
end
|
288
301
|
|
@@ -438,6 +451,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
438
451
|
@features = orig_features
|
439
452
|
@labels = orig_labels
|
440
453
|
end unless folds == -1
|
454
|
+
|
441
455
|
self.reset_model if self.respond_to? :reset_model
|
442
456
|
self.train unless folds == 1
|
443
457
|
res
|
@@ -53,7 +53,7 @@ def load_tsv(tsv_file):
|
|
53
53
|
|
54
54
|
def tsv_dataset(tokenizer, tsv_file):
|
55
55
|
dataset = load_tsv(tsv_file)
|
56
|
-
tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True) , batched=True)
|
56
|
+
tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True, max_length=512) , batched=True)
|
57
57
|
return tokenized_dataset
|
58
58
|
|
59
59
|
def training_args(*args, **kwargs):
|