rbbt-dm 1.2.7 → 1.2.9
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +57 -38
- data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -6
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch.rb +37 -0
- data/lib/rbbt/vector/model.rb +82 -52
- data/python/rbbt_dm/__init__.py +48 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +57 -26
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +30 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +26 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
data/test/rbbt/vector/test_model.rb
CHANGED
@@ -26,7 +26,7 @@ class TestVectorModel < Test::Unit::TestCase
|
|
26
26
|
element.split(";")
|
27
27
|
}
|
28
28
|
|
29
|
-
model.train_model = Proc.new{|
|
29
|
+
model.train_model = Proc.new{|features,labels|
|
30
30
|
TmpFile.with_file do |feature_file|
|
31
31
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
32
32
|
Open.write(feature_file + '.class', labels * "\n")
|
@@ -36,23 +36,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
|
|
36
36
|
features = cbind(features, class = labels);
|
37
37
|
rbbt.require('e1071')
|
38
38
|
model = svm(class ~ ., data = features)
|
39
|
-
save(model, file="#{
|
39
|
+
save(model, file="#{ @model_path }");
|
40
40
|
EOF
|
41
41
|
end
|
42
42
|
}
|
43
43
|
|
44
|
-
model.eval_model = Proc.new{|
|
44
|
+
model.eval_model = Proc.new{|features|
|
45
45
|
TmpFile.with_file do |feature_file|
|
46
46
|
TmpFile.with_file do |results|
|
47
47
|
Open.write(feature_file, features * "\t")
|
48
|
-
|
48
|
+
R.run <<-EOF
|
49
49
|
features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
|
50
50
|
library(e1071)
|
51
|
-
load(file="#{
|
51
|
+
load(file="#{ @model_path }")
|
52
52
|
label = predict(model, features);
|
53
53
|
cat(label, file="#{results}");
|
54
54
|
EOF
|
55
|
-
|
55
|
+
|
56
56
|
Open.read(results)
|
57
57
|
end
|
58
58
|
end
|
@@ -96,7 +96,7 @@ cat(label, file="#{results}");
|
|
96
96
|
end
|
97
97
|
}
|
98
98
|
|
99
|
-
model.train_model = Proc.new{|
|
99
|
+
model.train_model = Proc.new{|features,labels|
|
100
100
|
TmpFile.with_file do |feature_file|
|
101
101
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
102
102
|
Open.write(feature_file + '.class', labels * "\n")
|
@@ -106,23 +106,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
|
|
106
106
|
features = cbind(features, class = labels);
|
107
107
|
rbbt.require('e1071')
|
108
108
|
model = svm(class ~ ., data = features)
|
109
|
-
save(model, file="#{
|
109
|
+
save(model, file="#{ @model_path }");
|
110
110
|
EOF
|
111
111
|
end
|
112
112
|
}
|
113
113
|
|
114
|
-
model.eval_model = Proc.new{|
|
114
|
+
model.eval_model = Proc.new{|features|
|
115
115
|
TmpFile.with_file do |feature_file|
|
116
116
|
TmpFile.with_file do |results|
|
117
117
|
Open.write(feature_file, features * "\t")
|
118
|
-
|
118
|
+
R.run <<-EOF
|
119
119
|
features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
|
120
120
|
library(e1071)
|
121
|
-
load(file="#{
|
121
|
+
load(file="#{ @model_path }")
|
122
122
|
label = predict(model, features);
|
123
123
|
cat(label, file="#{results}");
|
124
124
|
EOF
|
125
|
-
|
125
|
+
|
126
126
|
Open.read(results)
|
127
127
|
end
|
128
128
|
end
|
@@ -164,7 +164,7 @@ cat(label, file="#{results}");
|
|
164
164
|
element.split(";")
|
165
165
|
}
|
166
166
|
|
167
|
-
model.train_model = Proc.new{|
|
167
|
+
model.train_model = Proc.new{|features,labels|
|
168
168
|
TmpFile.with_file do |feature_file|
|
169
169
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
170
170
|
Open.write(feature_file + '.class', labels * "\n")
|
@@ -174,23 +174,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
|
|
174
174
|
features = cbind(features, class = labels);
|
175
175
|
rbbt.require('e1071')
|
176
176
|
model = svm(class ~ ., data = features)
|
177
|
-
save(model, file="#{
|
177
|
+
save(model, file="#{ @model_path }");
|
178
178
|
EOF
|
179
179
|
end
|
180
180
|
}
|
181
181
|
|
182
|
-
model.eval_model = Proc.new{|
|
182
|
+
model.eval_model = Proc.new{|features|
|
183
183
|
TmpFile.with_file do |feature_file|
|
184
184
|
TmpFile.with_file do |results|
|
185
185
|
Open.write(feature_file, features * "\t")
|
186
|
-
|
186
|
+
R.run <<-EOF
|
187
187
|
features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
|
188
188
|
library(e1071)
|
189
|
-
load(file="#{
|
189
|
+
load(file="#{ @model_path }")
|
190
190
|
label = predict(model, features);
|
191
191
|
cat(label, file="#{results}");
|
192
192
|
EOF
|
193
|
-
|
193
|
+
|
194
194
|
Open.read(results)
|
195
195
|
end
|
196
196
|
end
|
@@ -236,7 +236,7 @@ cat(label, file="#{results}");
|
|
236
236
|
end
|
237
237
|
}
|
238
238
|
|
239
|
-
model.train_model = Proc.new{|
|
239
|
+
model.train_model = Proc.new{|features,labels|
|
240
240
|
TmpFile.with_file do |feature_file|
|
241
241
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
242
242
|
Open.write(feature_file + '.class', labels * "\n")
|
@@ -246,23 +246,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
|
|
246
246
|
features = cbind(features, label = labels);
|
247
247
|
rbbt.require('e1071')
|
248
248
|
model = svm(label ~ ., data = features)
|
249
|
-
save(model, file="#{
|
249
|
+
save(model, file="#{ @model_path }");
|
250
250
|
EOF
|
251
251
|
end
|
252
252
|
}
|
253
253
|
|
254
|
-
model.eval_model = Proc.new{|
|
254
|
+
model.eval_model = Proc.new{|features|
|
255
255
|
TmpFile.with_file do |feature_file|
|
256
256
|
TmpFile.with_file do |results|
|
257
257
|
Open.write(feature_file, features * "\t")
|
258
|
-
|
258
|
+
R.run <<-EOF
|
259
259
|
features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
|
260
260
|
library(e1071)
|
261
|
-
load(file="#{
|
261
|
+
load(file="#{ @model_path }")
|
262
262
|
label = predict(model, features);
|
263
263
|
cat(label, file="#{results}");
|
264
264
|
EOF
|
265
|
-
|
265
|
+
|
266
266
|
Open.read(results)
|
267
267
|
end
|
268
268
|
end
|
@@ -282,7 +282,6 @@ cat(label, file="#{results}");
|
|
282
282
|
model.add features, label
|
283
283
|
end
|
284
284
|
|
285
|
-
iii model.eval("1;1;1")
|
286
285
|
assert model.eval("1;1;1").to_f > 0.5
|
287
286
|
assert model.eval("0;0;0").to_f < 0.5
|
288
287
|
end
|
@@ -515,7 +514,7 @@ label = predict(model, features);
|
|
515
514
|
TmpFile.with_file do |dir|
|
516
515
|
model = VectorModel.new dir
|
517
516
|
|
518
|
-
model.eval_model do |
|
517
|
+
model.eval_model do |elements|
|
519
518
|
elements = [elements] unless Array === elements
|
520
519
|
RbbtPython.binding_run do
|
521
520
|
pyimport :torch
|
data/test/test_helper.rb
CHANGED
@@ -19,4 +19,17 @@ class Test::Unit::TestCase
|
|
19
19
|
def datafile_test(file)
|
20
20
|
Test::Unit::TestCase.datafile_test(file)
|
21
21
|
end
|
22
|
+
|
23
|
+
def with_python(code, &block)
|
24
|
+
TmpFile.with_file do |dir|
|
25
|
+
pkg = "pkg#{rand(100)}"
|
26
|
+
Open.write File.join(dir, "#{pkg}/__init__.py"), code
|
27
|
+
|
28
|
+
RbbtPython.add_path dir
|
29
|
+
|
30
|
+
Misc.in_dir dir do
|
31
|
+
yield pkg
|
32
|
+
end
|
33
|
+
end
|
34
|
+
end
|
22
35
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rbbt-dm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.7
|
4
|
+
version: 1.2.9
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Miguel Vazquez
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-08-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rbbt-util
|
@@ -105,17 +105,23 @@ files:
|
|
105
105
|
- lib/rbbt/statistics/hypergeometric.rb
|
106
106
|
- lib/rbbt/statistics/random_walk.rb
|
107
107
|
- lib/rbbt/statistics/rank_product.rb
|
108
|
-
- lib/rbbt/tensorflow.rb
|
109
108
|
- lib/rbbt/vector/model.rb
|
110
|
-
- lib/rbbt/vector/model/huggingface.old.rb
|
111
109
|
- lib/rbbt/vector/model/huggingface.rb
|
110
|
+
- lib/rbbt/vector/model/huggingface/masked_lm.rb
|
111
|
+
- lib/rbbt/vector/model/pytorch_lightning.rb
|
112
112
|
- lib/rbbt/vector/model/random_forest.rb
|
113
113
|
- lib/rbbt/vector/model/spaCy.rb
|
114
114
|
- lib/rbbt/vector/model/svm.rb
|
115
115
|
- lib/rbbt/vector/model/tensorflow.rb
|
116
|
+
- lib/rbbt/vector/model/torch.rb
|
116
117
|
- lib/rbbt/vector/model/util.rb
|
117
118
|
- python/rbbt_dm/__init__.py
|
119
|
+
- python/rbbt_dm/atcold/__init__.py
|
120
|
+
- python/rbbt_dm/atcold/plot_lib.py
|
121
|
+
- python/rbbt_dm/atcold/spiral.py
|
118
122
|
- python/rbbt_dm/huggingface.py
|
123
|
+
- python/rbbt_dm/language_model.py
|
124
|
+
- python/rbbt_dm/util.py
|
119
125
|
- share/R/MA.R
|
120
126
|
- share/R/barcode.R
|
121
127
|
- share/R/heatmap.3.R
|
@@ -135,7 +141,9 @@ files:
|
|
135
141
|
- test/rbbt/statistics/test_random_walk.rb
|
136
142
|
- test/rbbt/test_ml_task.rb
|
137
143
|
- test/rbbt/test_stan.rb
|
144
|
+
- test/rbbt/vector/model/huggingface/test_masked_lm.rb
|
138
145
|
- test/rbbt/vector/model/test_huggingface.rb
|
146
|
+
- test/rbbt/vector/model/test_pytorch_lightning.rb
|
139
147
|
- test/rbbt/vector/model/test_spaCy.rb
|
140
148
|
- test/rbbt/vector/model/test_svm.rb
|
141
149
|
- test/rbbt/vector/model/test_tensorflow.rb
|
@@ -159,22 +167,24 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
159
167
|
- !ruby/object:Gem::Version
|
160
168
|
version: '0'
|
161
169
|
requirements: []
|
162
|
-
rubygems_version: 3.
|
170
|
+
rubygems_version: 3.4.19
|
163
171
|
signing_key:
|
164
172
|
specification_version: 4
|
165
173
|
summary: Data-mining and statistics
|
166
174
|
test_files:
|
167
|
-
- test/
|
168
|
-
- test/rbbt/
|
169
|
-
- test/rbbt/vector/model/test_huggingface.rb
|
170
|
-
- test/rbbt/vector/model/test_tensorflow.rb
|
171
|
-
- test/rbbt/vector/model/test_spaCy.rb
|
172
|
-
- test/rbbt/vector/model/test_svm.rb
|
173
|
-
- test/rbbt/statistics/test_random_walk.rb
|
174
|
-
- test/rbbt/statistics/test_fisher.rb
|
175
|
+
- test/rbbt/matrix/test_barcode.rb
|
176
|
+
- test/rbbt/network/test_paths.rb
|
175
177
|
- test/rbbt/statistics/test_fdr.rb
|
178
|
+
- test/rbbt/statistics/test_fisher.rb
|
176
179
|
- test/rbbt/statistics/test_hypergeometric.rb
|
177
|
-
- test/rbbt/
|
178
|
-
- test/rbbt/matrix/test_barcode.rb
|
180
|
+
- test/rbbt/statistics/test_random_walk.rb
|
179
181
|
- test/rbbt/test_ml_task.rb
|
180
|
-
- test/rbbt/
|
182
|
+
- test/rbbt/test_stan.rb
|
183
|
+
- test/rbbt/vector/model/huggingface/test_masked_lm.rb
|
184
|
+
- test/rbbt/vector/model/test_huggingface.rb
|
185
|
+
- test/rbbt/vector/model/test_pytorch_lightning.rb
|
186
|
+
- test/rbbt/vector/model/test_spaCy.rb
|
187
|
+
- test/rbbt/vector/model/test_svm.rb
|
188
|
+
- test/rbbt/vector/model/test_tensorflow.rb
|
189
|
+
- test/rbbt/vector/test_model.rb
|
190
|
+
- test/test_helper.rb
|
data/lib/rbbt/tensorflow.rb
DELETED
@@ -1,43 +0,0 @@
|
|
1
|
-
require 'rbbt/util/python'
|
2
|
-
|
3
|
-
module RbbtTensorflow
|
4
|
-
|
5
|
-
def self.init
|
6
|
-
RbbtPython.run do
|
7
|
-
pyimport "tensorflow", as: "tf"
|
8
|
-
end
|
9
|
-
end
|
10
|
-
|
11
|
-
def self.test
|
12
|
-
|
13
|
-
mod = x_test = y_test = nil
|
14
|
-
RbbtPython.run do
|
15
|
-
|
16
|
-
mnist_db = tf.keras.datasets.mnist
|
17
|
-
|
18
|
-
(x_train, y_train), (x_test, y_test) = mnist_db.load_data()
|
19
|
-
x_train, x_test = x_train / 255.0, x_test / 255.0
|
20
|
-
|
21
|
-
mod = tf.keras.models.Sequential.new([
|
22
|
-
tf.keras.layers.Flatten.new(input_shape: [28, 28]),
|
23
|
-
tf.keras.layers.Dense.new(128, activation:'relu'),
|
24
|
-
tf.keras.layers.Dropout.new(0.2),
|
25
|
-
tf.keras.layers.Dense.new(10, activation:'softmax')
|
26
|
-
])
|
27
|
-
mod.compile(optimizer='adam',
|
28
|
-
loss='sparse_categorical_crossentropy',
|
29
|
-
metrics=['accuracy'])
|
30
|
-
mod.fit(x_train, y_train, epochs:3)
|
31
|
-
mod
|
32
|
-
end
|
33
|
-
|
34
|
-
RbbtPython.run do
|
35
|
-
mod.evaluate(x_test, y_test, verbose:2)
|
36
|
-
end
|
37
|
-
end
|
38
|
-
end
|
39
|
-
|
40
|
-
if __FILE__ == $0
|
41
|
-
RbbtTensorflow.init
|
42
|
-
RbbtTensorflow.test
|
43
|
-
end
|
data/lib/rbbt/vector/model/huggingface.old.rb
DELETED
@@ -1,160 +0,0 @@
|
|
1
|
-
require 'rbbt/vector/model'
|
2
|
-
require 'rbbt/util/python'
|
3
|
-
|
4
|
-
RbbtPython.add_path Rbbt.python.find(:lib)
|
5
|
-
RbbtPython.init_rbbt
|
6
|
-
|
7
|
-
class HuggingfaceModel < VectorModel
|
8
|
-
|
9
|
-
attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args
|
10
|
-
|
11
|
-
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
12
|
-
|
13
|
-
if labels
|
14
|
-
Open.write(tsv_dataset_file) do |ffile|
|
15
|
-
ffile.puts ["label", "text"].flatten * "\t"
|
16
|
-
elements.zip(labels).each do |element,label|
|
17
|
-
ffile.puts [label, element].flatten * "\t"
|
18
|
-
end
|
19
|
-
end
|
20
|
-
else
|
21
|
-
Open.write(tsv_dataset_file) do |ffile|
|
22
|
-
ffile.puts ["text"].flatten * "\t"
|
23
|
-
elements.each{|element| ffile.puts element }
|
24
|
-
end
|
25
|
-
end
|
26
|
-
|
27
|
-
tsv_dataset_file
|
28
|
-
end
|
29
|
-
|
30
|
-
def self.call_method(name, *args)
|
31
|
-
RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
|
32
|
-
end
|
33
|
-
|
34
|
-
def call_method(name, *args)
|
35
|
-
HuggingfaceModel.call_method(name, *args)
|
36
|
-
end
|
37
|
-
|
38
|
-
#def input_tsv_file
|
39
|
-
# File.join(@directory, 'dataset.tsv') if @directory
|
40
|
-
#end
|
41
|
-
|
42
|
-
#def checkpoint_dir
|
43
|
-
# File.join(@directory, 'checkpoints') if @directory
|
44
|
-
#end
|
45
|
-
|
46
|
-
def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
|
47
|
-
TmpFile.with_file do |tmpfile|
|
48
|
-
tsv_file = File.join(tmpfile, 'dataset.tsv')
|
49
|
-
|
50
|
-
if training_args
|
51
|
-
training_args = training_args.dup
|
52
|
-
checkpoint_dir = training_args.delete(:checkpoint_dir)
|
53
|
-
end
|
54
|
-
|
55
|
-
checkpoint_dir = File.join(tmpfile, 'checkpoints')
|
56
|
-
|
57
|
-
Open.mkdir File.dirname(tsv_file)
|
58
|
-
Open.mkdir File.dirname(checkpoint_dir)
|
59
|
-
|
60
|
-
if labels
|
61
|
-
training_args_obj = call_method(:training_args, checkpoint_dir, **training_args)
|
62
|
-
call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
|
63
|
-
else
|
64
|
-
locate_tokens, training_args = training_args, {}
|
65
|
-
if Array === elements
|
66
|
-
training_args_obj = call_method(:training_args, checkpoint_dir)
|
67
|
-
call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
|
68
|
-
else
|
69
|
-
call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
|
70
|
-
end
|
71
|
-
end
|
72
|
-
end
|
73
|
-
end
|
74
|
-
|
75
|
-
def init_model
|
76
|
-
@model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
|
77
|
-
end
|
78
|
-
|
79
|
-
def reset_model
|
80
|
-
init_model
|
81
|
-
end
|
82
|
-
|
83
|
-
def initialize(task, initial_checkpoint = nil, *args)
|
84
|
-
super(*args)
|
85
|
-
@task = task
|
86
|
-
|
87
|
-
@checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
|
88
|
-
|
89
|
-
init_model
|
90
|
-
|
91
|
-
@locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
|
92
|
-
|
93
|
-
@training_args = {}
|
94
|
-
|
95
|
-
train_model do |file,elements,labels|
|
96
|
-
HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)
|
97
|
-
|
98
|
-
@model.save_pretrained(file) if file
|
99
|
-
@tokenizer.save_pretrained(file) if file
|
100
|
-
end
|
101
|
-
|
102
|
-
eval_model do |file,elements|
|
103
|
-
@model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
|
104
|
-
HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
|
105
|
-
end
|
106
|
-
|
107
|
-
post_process do |result|
|
108
|
-
if result.respond_to?(:predictions)
|
109
|
-
single = false
|
110
|
-
predictions = result.predictions
|
111
|
-
elsif result["token_positions"]
|
112
|
-
predictions = result["result"].predictions
|
113
|
-
token_positions = result["token_positions"]
|
114
|
-
else
|
115
|
-
single = true
|
116
|
-
predictions = result["logits"]
|
117
|
-
end
|
118
|
-
|
119
|
-
result = case @task
|
120
|
-
when "SequenceClassification"
|
121
|
-
RbbtPython.collect(predictions) do |logits|
|
122
|
-
logits = RbbtPython.numpy2ruby logits
|
123
|
-
best_class = logits.index logits.max
|
124
|
-
best_class = @class_labels[best_class] if @class_labels
|
125
|
-
best_class
|
126
|
-
end
|
127
|
-
when "MaskedLM"
|
128
|
-
all_token_positions = token_positions.to_a
|
129
|
-
|
130
|
-
i = 0
|
131
|
-
RbbtPython.collect(predictions) do |item_logits|
|
132
|
-
item_token_positions = all_token_positions[i]
|
133
|
-
i += 1
|
134
|
-
|
135
|
-
item_logits = RbbtPython.numpy2ruby(item_logits)
|
136
|
-
item_masks = item_token_positions.collect do |token_positions|
|
137
|
-
|
138
|
-
best = item_logits.values_at(*token_positions).collect do |logits|
|
139
|
-
best_token, best_score = nil
|
140
|
-
logits.each_with_index do |v,i|
|
141
|
-
if best_score.nil? || v > best_score
|
142
|
-
best_token, best_score = i, v
|
143
|
-
end
|
144
|
-
end
|
145
|
-
best_token
|
146
|
-
end
|
147
|
-
|
148
|
-
best.collect{|b| @tokenizer.decode(b) } * "|"
|
149
|
-
end
|
150
|
-
Array === @locate_tokens ? item_masks : item_masks.first
|
151
|
-
end
|
152
|
-
else
|
153
|
-
logits
|
154
|
-
end
|
155
|
-
|
156
|
-
single ? result.first : result
|
157
|
-
end
|
158
|
-
end
|
159
|
-
end
|
160
|
-
|