rbbt-dm 1.2.7 → 1.2.9
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +57 -38
- data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -6
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch.rb +37 -0
- data/lib/rbbt/vector/model.rb +82 -52
- data/python/rbbt_dm/__init__.py +48 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +57 -26
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +30 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +26 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
data/test/rbbt/vector/test_model.rb
CHANGED
@@ -26,7 +26,7 @@ class TestVectorModel < Test::Unit::TestCase
 element.split(";")
 }

-model.train_model = Proc.new{|
+model.train_model = Proc.new{|features,labels|
 TmpFile.with_file do |feature_file|
 Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
 Open.write(feature_file + '.class', labels * "\n")
@@ -36,23 +36,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
 features = cbind(features, class = labels);
 rbbt.require('e1071')
 model = svm(class ~ ., data = features)
-save(model, file="#{
+save(model, file="#{ @model_path }");
 EOF
 end
 }

-model.eval_model = Proc.new{|
+model.eval_model = Proc.new{|features|
 TmpFile.with_file do |feature_file|
 TmpFile.with_file do |results|
 Open.write(feature_file, features * "\t")
-
+R.run <<-EOF
 features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
 library(e1071)
-load(file="#{
+load(file="#{ @model_path }")
 label = predict(model, features);
 cat(label, file="#{results}");
 EOF
-
+
 Open.read(results)
 end
 end
@@ -96,7 +96,7 @@ cat(label, file="#{results}");
 end
 }

-model.train_model = Proc.new{|
+model.train_model = Proc.new{|features,labels|
 TmpFile.with_file do |feature_file|
 Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
 Open.write(feature_file + '.class', labels * "\n")
@@ -106,23 +106,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
 features = cbind(features, class = labels);
 rbbt.require('e1071')
 model = svm(class ~ ., data = features)
-save(model, file="#{
+save(model, file="#{ @model_path }");
 EOF
 end
 }

-model.eval_model = Proc.new{|
+model.eval_model = Proc.new{|features|
 TmpFile.with_file do |feature_file|
 TmpFile.with_file do |results|
 Open.write(feature_file, features * "\t")
-
+R.run <<-EOF
 features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
 library(e1071)
-load(file="#{
+load(file="#{ @model_path }")
 label = predict(model, features);
 cat(label, file="#{results}");
 EOF
-
+
 Open.read(results)
 end
 end
@@ -164,7 +164,7 @@ cat(label, file="#{results}");
 element.split(";")
 }

-model.train_model = Proc.new{|
+model.train_model = Proc.new{|features,labels|
 TmpFile.with_file do |feature_file|
 Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
 Open.write(feature_file + '.class', labels * "\n")
@@ -174,23 +174,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
 features = cbind(features, class = labels);
 rbbt.require('e1071')
 model = svm(class ~ ., data = features)
-save(model, file="#{
+save(model, file="#{ @model_path }");
 EOF
 end
 }

-model.eval_model = Proc.new{|
+model.eval_model = Proc.new{|features|
 TmpFile.with_file do |feature_file|
 TmpFile.with_file do |results|
 Open.write(feature_file, features * "\t")
-
+R.run <<-EOF
 features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
 library(e1071)
-load(file="#{
+load(file="#{ @model_path }")
 label = predict(model, features);
 cat(label, file="#{results}");
 EOF
-
+
 Open.read(results)
 end
 end
@@ -236,7 +236,7 @@ cat(label, file="#{results}");
 end
 }

-model.train_model = Proc.new{|
+model.train_model = Proc.new{|features,labels|
 TmpFile.with_file do |feature_file|
 Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
 Open.write(feature_file + '.class', labels * "\n")
@@ -246,23 +246,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
 features = cbind(features, label = labels);
 rbbt.require('e1071')
 model = svm(label ~ ., data = features)
-save(model, file="#{
+save(model, file="#{ @model_path }");
 EOF
 end
 }

-model.eval_model = Proc.new{|
+model.eval_model = Proc.new{|features|
 TmpFile.with_file do |feature_file|
 TmpFile.with_file do |results|
 Open.write(feature_file, features * "\t")
-
+R.run <<-EOF
 features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
 library(e1071)
-load(file="#{
+load(file="#{ @model_path }")
 label = predict(model, features);
 cat(label, file="#{results}");
 EOF
-
+
 Open.read(results)
 end
 end
@@ -282,7 +282,6 @@ cat(label, file="#{results}");
 model.add features, label
 end

-iii model.eval("1;1;1")
 assert model.eval("1;1;1").to_f > 0.5
 assert model.eval("0;0;0").to_f < 0.5
 end
@@ -515,7 +514,7 @@ label = predict(model, features);
 TmpFile.with_file do |dir|
 model = VectorModel.new dir

-model.eval_model do |
+model.eval_model do |elements|
 elements = [elements] unless Array === elements
 RbbtPython.binding_run do
 pyimport :torch
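Every hunk above makes the same change: the train_model and eval_model Procs no longer receive the model file as a block parameter (the 1.2.7 parameter lists are truncated in this diff), taking only |features,labels|, |features| or |elements|, while the R templates now interpolate the @model_path instance variable instead. Below is a minimal sketch of the new callback shape, based only on the calls visible in these hunks; the trivial mean-of-labels "model" is illustrative and not part of the gem's test suite, which trains an e1071 SVM through R as shown above.

require 'rbbt/vector/model'

TmpFile.with_file do |dir|
  model = VectorModel.new dir

  # The Procs appear to be instance-evaluated by VectorModel, since the tests
  # above reference @model_path directly inside them.
  model.train_model = Proc.new{|features,labels|
    # Persist something derived from the training data under @model_path.
    Open.write(@model_path, (labels.collect(&:to_f).inject(:+) / labels.length).to_s)
  }

  model.eval_model = Proc.new{|features|
    # Read the trained state back from @model_path and produce a prediction.
    Open.read(@model_path)
  }
end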
data/test/test_helper.rb
CHANGED
@@ -19,4 +19,17 @@ class Test::Unit::TestCase
   def datafile_test(file)
     Test::Unit::TestCase.datafile_test(file)
   end
+
+  def with_python(code, &block)
+    TmpFile.with_file do |dir|
+      pkg = "pkg#{rand(100)}"
+      Open.write File.join(dir, "#{pkg}/__init__.py"), code
+
+      RbbtPython.add_path dir
+
+      Misc.in_dir dir do
+        yield pkg
+      end
+    end
+  end
 end
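The new with_python helper writes the given code into a throwaway package (pkg<N>/__init__.py) under a temporary directory, adds that directory to the RbbtPython path, and yields the generated package name from within that directory. A hypothetical usage sketch follows; the double function is invented for illustration, and the import_method call follows the pattern used in the huggingface code elsewhere in this diff.

python_code = <<~PYTHON
  def double(x):
      return 2 * x
PYTHON

with_python(python_code) do |pkg|
  # pkg is a generated name such as "pkg42"; its module is now importable.
  doubled = RbbtPython.import_method(pkg, :double).call(21)
end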
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: rbbt-dm
 version: !ruby/object:Gem::Version
-  version: 1.2.7
+  version: 1.2.9
 platform: ruby
 authors:
 - Miguel Vazquez
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2023-
+date: 2023-08-30 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rbbt-util
@@ -105,17 +105,23 @@ files:
 - lib/rbbt/statistics/hypergeometric.rb
 - lib/rbbt/statistics/random_walk.rb
 - lib/rbbt/statistics/rank_product.rb
-- lib/rbbt/tensorflow.rb
 - lib/rbbt/vector/model.rb
-- lib/rbbt/vector/model/huggingface.old.rb
 - lib/rbbt/vector/model/huggingface.rb
+- lib/rbbt/vector/model/huggingface/masked_lm.rb
+- lib/rbbt/vector/model/pytorch_lightning.rb
 - lib/rbbt/vector/model/random_forest.rb
 - lib/rbbt/vector/model/spaCy.rb
 - lib/rbbt/vector/model/svm.rb
 - lib/rbbt/vector/model/tensorflow.rb
+- lib/rbbt/vector/model/torch.rb
 - lib/rbbt/vector/model/util.rb
 - python/rbbt_dm/__init__.py
+- python/rbbt_dm/atcold/__init__.py
+- python/rbbt_dm/atcold/plot_lib.py
+- python/rbbt_dm/atcold/spiral.py
 - python/rbbt_dm/huggingface.py
+- python/rbbt_dm/language_model.py
+- python/rbbt_dm/util.py
 - share/R/MA.R
 - share/R/barcode.R
 - share/R/heatmap.3.R
@@ -135,7 +141,9 @@ files:
 - test/rbbt/statistics/test_random_walk.rb
 - test/rbbt/test_ml_task.rb
 - test/rbbt/test_stan.rb
+- test/rbbt/vector/model/huggingface/test_masked_lm.rb
 - test/rbbt/vector/model/test_huggingface.rb
+- test/rbbt/vector/model/test_pytorch_lightning.rb
 - test/rbbt/vector/model/test_spaCy.rb
 - test/rbbt/vector/model/test_svm.rb
 - test/rbbt/vector/model/test_tensorflow.rb
@@ -159,22 +167,24 @@ required_rubygems_version: !ruby/object:Gem::Requirement
 - !ruby/object:Gem::Version
   version: '0'
 requirements: []
-rubygems_version: 3.
+rubygems_version: 3.4.19
 signing_key:
 specification_version: 4
 summary: Data-mining and statistics
 test_files:
-- test/
-- test/rbbt/
-- test/rbbt/vector/model/test_huggingface.rb
-- test/rbbt/vector/model/test_tensorflow.rb
-- test/rbbt/vector/model/test_spaCy.rb
-- test/rbbt/vector/model/test_svm.rb
-- test/rbbt/statistics/test_random_walk.rb
-- test/rbbt/statistics/test_fisher.rb
+- test/rbbt/matrix/test_barcode.rb
+- test/rbbt/network/test_paths.rb
 - test/rbbt/statistics/test_fdr.rb
+- test/rbbt/statistics/test_fisher.rb
 - test/rbbt/statistics/test_hypergeometric.rb
-- test/rbbt/
-- test/rbbt/matrix/test_barcode.rb
+- test/rbbt/statistics/test_random_walk.rb
 - test/rbbt/test_ml_task.rb
-- test/rbbt/
+- test/rbbt/test_stan.rb
+- test/rbbt/vector/model/huggingface/test_masked_lm.rb
+- test/rbbt/vector/model/test_huggingface.rb
+- test/rbbt/vector/model/test_pytorch_lightning.rb
+- test/rbbt/vector/model/test_spaCy.rb
+- test/rbbt/vector/model/test_svm.rb
+- test/rbbt/vector/model/test_tensorflow.rb
+- test/rbbt/vector/test_model.rb
+- test/test_helper.rb
data/lib/rbbt/tensorflow.rb
DELETED
@@ -1,43 +0,0 @@
-require 'rbbt/util/python'
-
-module RbbtTensorflow
-
-  def self.init
-    RbbtPython.run do
-      pyimport "tensorflow", as: "tf"
-    end
-  end
-
-  def self.test
-
-    mod = x_test = y_test = nil
-    RbbtPython.run do
-
-      mnist_db = tf.keras.datasets.mnist
-
-      (x_train, y_train), (x_test, y_test) = mnist_db.load_data()
-      x_train, x_test = x_train / 255.0, x_test / 255.0
-
-      mod = tf.keras.models.Sequential.new([
-        tf.keras.layers.Flatten.new(input_shape: [28, 28]),
-        tf.keras.layers.Dense.new(128, activation:'relu'),
-        tf.keras.layers.Dropout.new(0.2),
-        tf.keras.layers.Dense.new(10, activation:'softmax')
-      ])
-      mod.compile(optimizer='adam',
-                  loss='sparse_categorical_crossentropy',
-                  metrics=['accuracy'])
-      mod.fit(x_train, y_train, epochs:3)
-      mod
-    end
-
-    RbbtPython.run do
-      mod.evaluate(x_test, y_test, verbose:2)
-    end
-  end
-end
-
-if __FILE__ == $0
-  RbbtTensorflow.init
-  RbbtTensorflow.test
-end
data/lib/rbbt/vector/model/huggingface.old.rb
DELETED
@@ -1,160 +0,0 @@
-require 'rbbt/vector/model'
-require 'rbbt/util/python'
-
-RbbtPython.add_path Rbbt.python.find(:lib)
-RbbtPython.init_rbbt
-
-class HuggingfaceModel < VectorModel
-
-  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights, :training_args
-
-  def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
-
-    if labels
-      Open.write(tsv_dataset_file) do |ffile|
-        ffile.puts ["label", "text"].flatten * "\t"
-        elements.zip(labels).each do |element,label|
-          ffile.puts [label, element].flatten * "\t"
-        end
-      end
-    else
-      Open.write(tsv_dataset_file) do |ffile|
-        ffile.puts ["text"].flatten * "\t"
-        elements.each{|element| ffile.puts element }
-      end
-    end
-
-    tsv_dataset_file
-  end
-
-  def self.call_method(name, *args)
-    RbbtPython.import_method("rbbt_dm.huggingface", name).call(*args)
-  end
-
-  def call_method(name, *args)
-    HuggingfaceModel.call_method(name, *args)
-  end
-
-  #def input_tsv_file
-  #  File.join(@directory, 'dataset.tsv') if @directory
-  #end
-
-  #def checkpoint_dir
-  #  File.join(@directory, 'checkpoints') if @directory
-  #end
-
-  def self.run_model(model, tokenizer, elements, labels = nil, training_args = {}, class_weights = nil)
-    TmpFile.with_file do |tmpfile|
-      tsv_file = File.join(tmpfile, 'dataset.tsv')
-
-      if training_args
-        training_args = training_args.dup
-        checkpoint_dir = training_args.delete(:checkpoint_dir)
-      end
-
-      checkpoint_dir = File.join(tmpfile, 'checkpoints')
-
-      Open.mkdir File.dirname(tsv_file)
-      Open.mkdir File.dirname(checkpoint_dir)
-
-      if labels
-        training_args_obj = call_method(:training_args, checkpoint_dir, **training_args)
-        call_method(:train_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements, labels), class_weights)
-      else
-        locate_tokens, training_args = training_args, {}
-        if Array === elements
-          training_args_obj = call_method(:training_args, checkpoint_dir)
-          call_method(:predict_model, model, tokenizer, training_args_obj, tsv_dataset(tsv_file, elements), locate_tokens)
-        else
-          call_method(:eval_model, model, tokenizer, [elements], locate_tokens)
-        end
-      end
-    end
-  end
-
-  def init_model
-    @model, @tokenizer = call_method(:load_model_and_tokenizer, @task, @checkpoint)
-  end
-
-  def reset_model
-    init_model
-  end
-
-  def initialize(task, initial_checkpoint = nil, *args)
-    super(*args)
-    @task = task
-
-    @checkpoint = model_file && File.exists?(model_file)? model_file : initial_checkpoint
-
-    init_model
-
-    @locate_tokens = @tokenizer.special_tokens_map["mask_token"] if @task == "MaskedLM"
-
-    @training_args = {}
-
-    train_model do |file,elements,labels|
-      HuggingfaceModel.run_model(@model, @tokenizer, elements, labels, @training_args, @class_weights)
-
-      @model.save_pretrained(file) if file
-      @tokenizer.save_pretrained(file) if file
-    end
-
-    eval_model do |file,elements|
-      @model, @tokenizer = HuggingfaceModel.call_method(:load_model_and_tokenizer, @task, @checkpoint)
-      HuggingfaceModel.run_model(@model, @tokenizer, elements, nil, @locate_tokens)
-    end
-
-    post_process do |result|
-      if result.respond_to?(:predictions)
-        single = false
-        predictions = result.predictions
-      elsif result["token_positions"]
-        predictions = result["result"].predictions
-        token_positions = result["token_positions"]
-      else
-        single = true
-        predictions = result["logits"]
-      end
-
-      result = case @task
-               when "SequenceClassification"
-                 RbbtPython.collect(predictions) do |logits|
-                   logits = RbbtPython.numpy2ruby logits
-                   best_class = logits.index logits.max
-                   best_class = @class_labels[best_class] if @class_labels
-                   best_class
-                 end
-               when "MaskedLM"
-                 all_token_positions = token_positions.to_a
-
-                 i = 0
-                 RbbtPython.collect(predictions) do |item_logits|
-                   item_token_positions = all_token_positions[i]
-                   i += 1
-
-                   item_logits = RbbtPython.numpy2ruby(item_logits)
-                   item_masks = item_token_positions.collect do |token_positions|
-
-                     best = item_logits.values_at(*token_positions).collect do |logits|
-                       best_token, best_score = nil
-                       logits.each_with_index do |v,i|
-                         if best_score.nil? || v > best_score
-                           best_token, best_score = i, v
-                         end
-                       end
-                       best_token
-                     end
-
-                     best.collect{|b| @tokenizer.decode(b) } * "|"
-                   end
-                   Array === @locate_tokens ? item_masks : item_masks.first
-                 end
-               else
-                 logits
-               end
-
-      single ? result.first : result
-    end
-  end
-end
-
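For reference, the argmax-to-label step that the removed post_process block performs for the SequenceClassification task reduces, outside of the RbbtPython bindings, to a plain Ruby helper. This is a sketch; the method name and sample values are illustrative and not part of the gem.

# Pick the index of the highest logit and map it to a class label if one is
# given, mirroring the SequenceClassification branch of the removed post_process.
def best_class_for(logits, class_labels = nil)
  best_index = logits.index(logits.max)
  class_labels ? class_labels[best_index] : best_index
end

best_class_for([0.2, 1.7, -0.3])                      # => 1
best_class_for([0.2, 1.7, -0.3], %w(neg pos neutral)) # => "pos"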