rbbt-dm 1.2.6 → 1.2.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/rbbt/vector/model/spaCy.rb +0 -8
- data/lib/rbbt/vector/model/util.rb +18 -0
- data/lib/rbbt/vector/model.rb +21 -7
- data/python/rbbt_dm/huggingface.py +1 -1
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
|
4
|
+
data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
|
7
|
+
data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
|
@@ -75,14 +75,6 @@ class SpaCyModel < VectorModel
|
|
75
75
|
d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
|
76
76
|
end
|
77
77
|
end
|
78
|
-
#nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
|
79
|
-
#Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
|
80
|
-
# texts.collect do |text|
|
81
|
-
# cats = nlp.(text).cats
|
82
|
-
# bar.tick
|
83
|
-
# cats.sort_by{|l,v| v.to_f }.last.first
|
84
|
-
# end
|
85
|
-
#end
|
86
78
|
end
|
87
79
|
end
|
88
80
|
end
|
@@ -9,4 +9,22 @@ class VectorModel
|
|
9
9
|
@bar.init
|
10
10
|
@bar
|
11
11
|
end
|
12
|
+
|
13
|
+
# Downsample the training data so that every label keeps the same number of
# examples: the count of the least frequent label.
#
# Rows are shuffled before selection, so which examples survive is random.
# @labels and @features are replaced in place by the reduced arrays, which
# stay aligned row-for-row. Callers that need the full data set afterwards
# must keep their own references to the previous arrays.
#
# NOTE(review): Misc.counts is assumed to return a Hash of label => count;
# only `.values.min` is relied upon here.
def balance_labels
  label_counts = Misc.counts(@labels)
  target = label_counts.values.min

  kept = Hash.new(0)
  new_labels = []
  new_features = []

  @labels.zip(@features).shuffle.each do |label, features|
    # Use >= rather than >: once a label already has `target` rows, drop the
    # rest. With > every non-minority label kept target + 1 rows.
    next if kept[label] >= target

    kept[label] += 1
    new_labels << label
    new_features << features
  end

  @labels = new_labels
  @features = new_features
end
|
12
30
|
end
|
data/lib/rbbt/vector/model.rb
CHANGED
@@ -2,7 +2,7 @@ require 'rbbt/util/R'
|
|
2
2
|
require 'rbbt/vector/model/util'
|
3
3
|
|
4
4
|
class VectorModel
|
5
|
-
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
|
5
|
+
attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
|
6
6
|
attr_accessor :features, :names, :labels, :factor_levels
|
7
7
|
attr_accessor :model_options
|
8
8
|
|
@@ -270,19 +270,32 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
270
270
|
Open.write(@post_process_file_R, post_process)
|
271
271
|
end
|
272
272
|
|
273
|
-
|
274
273
|
Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
|
275
274
|
Open.write(@names_file, @names * "\n" + "\n") if @names
|
276
275
|
Open.write(@options_file, @model_options.to_json) if @model_options
|
277
276
|
end
|
278
277
|
|
279
278
|
# Fit the model on @features / @labels, then persist it when @directory is set.
#
# How fitting happens depends on @train_model:
#   * Proc   -> instance_exec'd on this model with
#               (model_file, features, labels, names, factor_levels);
#   * String -> passed to VectorModel.R_train (presumably R training code);
#   * anything else (e.g. nil) -> the fitting step is skipped.
#
# When @balance is truthy, balance_labels downsamples the data for the
# duration of fitting only. The ensure clause puts the full arrays back even
# if fitting raises, so save_models and later calls see the original data.
def train
  begin
    if @balance
      # Snapshot the full data set so it can be restored after fitting.
      @original_features, @original_labels = @features, @labels
      balance_labels
    end

    case @train_model
    when Proc
      instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
    when String
      VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
    end
  ensure
    if @balance
      @features = @original_features
      @labels = @original_labels
    end
  end

  save_models if @directory
end
|
288
301
|
|
@@ -438,6 +451,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
438
451
|
@features = orig_features
|
439
452
|
@labels = orig_labels
|
440
453
|
end unless folds == -1
|
454
|
+
|
441
455
|
self.reset_model if self.respond_to? :reset_model
|
442
456
|
self.train unless folds == 1
|
443
457
|
res
|
@@ -53,7 +53,7 @@ def load_tsv(tsv_file):
|
|
53
53
|
|
54
54
|
def tsv_dataset(tokenizer, tsv_file, max_length=512):
    """Load ``tsv_file`` with ``load_tsv`` and tokenize its ``"text"`` column.

    Args:
        tokenizer: callable applied to batches of ``example["text"]``; it is
            called with ``truncation=True`` and ``max_length=max_length``.
        tsv_file: path handed to ``load_tsv``.
        max_length: truncation length passed to the tokenizer. It defaults to
            512, the previously hard-coded value, so existing callers are
            unaffected. NOTE(review): 512 presumably matches the model's
            maximum input length; confirm for non-BERT-sized models.

    Returns:
        The result of ``dataset.map(..., batched=True)``.
        NOTE(review): assumes ``load_tsv`` returns a Hugging Face
        ``datasets.Dataset``.
    """
    dataset = load_tsv(tsv_file)
    tokenized_dataset = dataset.map(
        lambda example: tokenizer(example["text"], truncation=True, max_length=max_length),
        batched=True,
    )
    return tokenized_dataset
|
58
58
|
|
59
59
|
def training_args(*args, **kwargs):
|