rbbt-dm 1.2.7 → 1.2.10
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +39 -52
- data/lib/rbbt/vector/model/python.rb +33 -0
- data/lib/rbbt/vector/model/pytorch_lightning.rb +31 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -6
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch/dataloader.rb +58 -0
- data/lib/rbbt/vector/model/torch/helpers.rb +52 -0
- data/lib/rbbt/vector/model/torch/introspection.rb +31 -0
- data/lib/rbbt/vector/model/torch/load_and_save.rb +30 -0
- data/lib/rbbt/vector/model/torch.rb +71 -0
- data/lib/rbbt/vector/model.rb +84 -54
- data/python/rbbt_dm/__init__.py +31 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +64 -28
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +32 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_python.rb +31 -0
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +97 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/model/test_torch.rb +61 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +35 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9d53609453e1c3bd589c95071569583bff5f11224a200850c0d9e85775e5a2ce
|
4
|
+
data.tar.gz: 501f436caff07c990c09adec4caa60d32fe577b35209a65e8f96418ab2acc422
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4a53ccc8a5eac633e344beffd36073e1828d848d9743552f12e75a54cbda337050152b72fd95d7d97d4aeed50f811dcc8475c97bc292ba9bcdd57e9d32f0e91b
|
7
|
+
data.tar.gz: 30fef4ee2c023e141c0ef991ae35d1a0913d4444138fda778760b821ccc48d85d67dc44d4c251d7f84e3b1bbe6ecb6f3b64aee13fbdc443cb33bbed2b34a586a
|
data/lib/rbbt/matrix/barcode.rb
CHANGED
@@ -3,7 +3,7 @@ require 'rbbt/util/R'
|
|
3
3
|
class RbbtMatrix
|
4
4
|
def barcode(outfile, factor = 2)
|
5
5
|
|
6
|
-
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
6
|
+
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
7
7
|
cmd =<<-EOF
|
8
8
|
source('#{Rbbt.share.R['barcode.R'].find}')
|
9
9
|
rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })
|
@@ -49,7 +49,7 @@ rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.r
|
|
49
49
|
|
50
50
|
clusters = Array === clusters ? clusters : (2..clusters).to_a
|
51
51
|
|
52
|
-
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
52
|
+
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
53
53
|
cmd =<<-EOF
|
54
54
|
source('#{Rbbt.share.R['barcode.R'].find}')
|
55
55
|
rbbt.GE.activity_cluster(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{R.ruby2R key_field}, #{R.ruby2R clusters})
|
@@ -45,7 +45,7 @@ class RbbtMatrix
|
|
45
45
|
end
|
46
46
|
|
47
47
|
file = file.find if Path === file
|
48
|
-
FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.
|
48
|
+
FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exist? File.dirname(file)
|
49
49
|
|
50
50
|
cmd = <<-EOS
|
51
51
|
|
@@ -71,10 +71,10 @@ end
|
|
71
71
|
|
72
72
|
|
73
73
|
#def self.analyze(datafile, main, contrast = nil, log2 = false, outfile = nil, key_field = nil, two_channel = nil)
|
74
|
-
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
74
|
+
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
75
75
|
# GE.run_R("rbbt.GE.process(#{ R.ruby2R datafile }, main = #{R.ruby2R(main, :strings => true)}, contrast = #{R.ruby2R(contrast, :strings => true)}, log2=#{ R.ruby2R log2 }, outfile = #{R.ruby2R outfile}, key.field = #{R.ruby2R key_field}, two.channel = #{R.ruby2R two_channel})")
|
76
76
|
#end
|
77
77
|
#def self.barcode(datafile, outfile, factor = 2)
|
78
|
-
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
78
|
+
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
79
79
|
# GE.run_R("rbbt.GE.barcode(#{ R.ruby2R datafile }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })")
|
80
80
|
#end
|
@@ -13,7 +13,7 @@ class KnowledgeBase
|
|
13
13
|
|
14
14
|
return matrix if RbbtMatrix === matrix
|
15
15
|
|
16
|
-
Path.setup(matrix) if not Path === matrix and File.
|
16
|
+
Path.setup(matrix) if not Path === matrix and File.exist? matrix
|
17
17
|
|
18
18
|
raise "Registered matrix is strange: #{Misc.fingerprint matrix}" unless Path === matrix
|
19
19
|
|
data/lib/rbbt/plots/bar.rb
CHANGED
@@ -119,7 +119,7 @@ module BarPlot
|
|
119
119
|
options = Misc.add_defaults options, :width => [options[:total], MAX_WIDTH].min, :height => 20, :background => PNG::Color::White
|
120
120
|
width, height, background, canvas = Misc.process_options options, :width, :height, :background, :canvas
|
121
121
|
|
122
|
-
canvas ||= if options[:update] and options[:filename] and File.
|
122
|
+
canvas ||= if options[:update] and options[:filename] and File.exist? options[:filename]
|
123
123
|
PNG.load_file options[:filename]
|
124
124
|
else
|
125
125
|
PNG::Canvas.new width, height, get_color(background)
|
data/lib/rbbt/stan.rb
CHANGED
@@ -130,7 +130,7 @@ print(fit)
|
|
130
130
|
erase = true
|
131
131
|
end
|
132
132
|
|
133
|
-
FileUtils.mkdir_p directory unless File.
|
133
|
+
FileUtils.mkdir_p directory unless File.exist? directory
|
134
134
|
input_directory = File.join(directory, 'inputs')
|
135
135
|
parameter_chains = File.join(directory, 'chains') unless erase
|
136
136
|
summary = File.join(directory, 'summary') unless erase
|
@@ -118,7 +118,7 @@ module TSV
|
|
118
118
|
|
119
119
|
field_pos = fields.collect{|f| self.fields.index f}.compact
|
120
120
|
persistence_path = self.respond_to?(:persistence_path)? self.persistence_path : nil
|
121
|
-
Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:background => background, :rename => rename, :persistence_path => persistence_path}) do
|
121
|
+
Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:fields => fields, :background => background, :rename => rename, :persistence_path => persistence_path}) do
|
122
122
|
data ||= {}
|
123
123
|
|
124
124
|
with_unnamed do
|
@@ -152,6 +152,7 @@ module TSV
|
|
152
152
|
|
153
153
|
end
|
154
154
|
|
155
|
+
|
155
156
|
if rename
|
156
157
|
Log.debug("Using renames during annotation counts")
|
157
158
|
Hash[*data.keys.zip(data.values.collect{|l| l.collect{|e| rename.include?(e)? rename[e] : e }.uniq.length }).flatten]
|
@@ -0,0 +1,50 @@
|
|
1
|
+
require 'rbbt/vector/model/huggingface'
|
2
|
+
# Masked-language-model fine-tuning built on HuggingfaceModel (task "MaskedLM").
#
# Training texts mark the positions to predict with literal "[MASK]"
# placeholders; the labels for a text are consumed in order, one per
# placeholder. A label the tokenizer maps to several ids expands its
# placeholder into that many mask tokens, so each id gets its own position.
class MaskedLMModel < HuggingfaceModel

  # checkpoint    - checkpoint handed to HuggingfaceModel as the model source
  # dir           - optional model directory; the training dataset is written
  #                 there as dataset.json
  # model_options - options hash; :max_length (default 128) is the padded and
  #                 truncated sequence length used when tokenizing for training
  def initialize(checkpoint, dir = nil, model_options = {})

    model_options = Misc.add_defaults model_options, :max_length => 128
    super("MaskedLM", checkpoint, dir, model_options)

    train_model do |texts,labels|
      model, tokenizer = self.init
      max_length = @model_options[:max_length]
      mask_id = tokenizer.mask_token_id
      # Use the tokenizer's own mask spelling ("[MASK]" for BERT, "<mask>" for
      # RoBERTa-style vocabularies). A hard-coded "[MASK]" is split into
      # ordinary sub-tokens by the latter, so mask_id would never match below.
      mask_token = tokenizer.mask_token

      dataset = []
      texts.zip(labels).each do |text,label_values|
        # Park user placeholders so they can be expanded one label at a time
        fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
        # Array() lets a single label be given without wrapping it in a list
        label_tokens = Array(label_values).collect{|label| tokenizer.convert_tokens_to_ids(label) }
        label_tokens.each do |ids|
          ids = [ids] unless Array === ids
          fixed_text.sub!("[PENDINGMASK]", mask_token * ids.length)
        end

        # max_length must be passed explicitly: padding "max_length" on its own
        # pads to the tokenizer's model maximum and ignores :max_length.
        tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length", max_length: max_length)
        input_ids = tokenized_text["input_ids"].to_a
        attention_mask = tokenized_text["attention_mask"].to_a

        # Mask positions take the next label id in order; every other position
        # gets -100 (presumably the loss ignore index, per HF convention).
        all_label_tokens = label_tokens.flatten
        label_ids = input_ids.collect do |id|
          id == mask_id ? all_label_tokens.shift : -100
        end
        dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
      end

      # One JSON record per line
      dataset_file = File.join(@directory, 'dataset.json')
      Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")

      training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
      data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
      RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)

      model.save_pretrained(@model_path) if @model_path
      tokenizer.save_pretrained(@model_path) if @model_path
    end

  end
end
|
@@ -1,50 +1,36 @@
|
|
1
|
-
require 'rbbt/vector/model'
|
2
|
-
require 'rbbt/util/python'
|
1
|
+
require 'rbbt/vector/model/torch'
|
3
2
|
|
4
|
-
|
5
|
-
RbbtPython.init_rbbt
|
3
|
+
class HuggingfaceModel < TorchModel
|
6
4
|
|
7
|
-
|
5
|
+
def initialize(task, checkpoint, dir = nil, model_options = {})
|
6
|
+
super(dir, nil, model_options)
|
8
7
|
|
9
|
-
|
8
|
+
@model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint
|
10
9
|
|
11
|
-
|
12
|
-
|
13
|
-
ffile.puts ["label", "text"].flatten * "\t"
|
14
|
-
elements.zip(labels).each do |element,label|
|
15
|
-
ffile.puts [label, element].flatten * "\t"
|
16
|
-
end
|
17
|
-
end
|
18
|
-
else
|
19
|
-
Open.write(tsv_dataset_file) do |ffile|
|
20
|
-
ffile.puts ["text"].flatten * "\t"
|
21
|
-
elements.each{|element| ffile.puts element }
|
22
|
-
end
|
23
|
-
end
|
10
|
+
init_model do
|
11
|
+
checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]
|
24
12
|
|
25
|
-
|
26
|
-
|
13
|
+
model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
|
14
|
+
@model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:model_args]) || {}))
|
27
15
|
|
28
|
-
|
29
|
-
options = args.pop if Hash === args.last
|
30
|
-
options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
|
31
|
-
super(*args)
|
32
|
-
@model_options ||= {}
|
33
|
-
@model_options.merge!(options)
|
16
|
+
tokenizer_checkpoint = @model_options[:tokenizer_checkpoint] || checkpoint
|
34
17
|
|
35
|
-
|
36
|
-
|
18
|
+
tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
|
19
|
+
@model_options[:task], tokenizer_checkpoint, **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))
|
37
20
|
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
21
|
+
[model, tokenizer]
|
22
|
+
end
|
23
|
+
|
24
|
+
eval_model do |texts,is_list|
|
25
|
+
model, tokenizer = self.init
|
26
|
+
|
27
|
+
if is_list || @model_options[:task] == "MaskedLM"
|
28
|
+
texts = [texts] if ! is_list
|
43
29
|
|
44
30
|
if @model_options.include?(:locate_tokens)
|
45
31
|
locate_tokens = @model_options[:locate_tokens]
|
46
32
|
elsif @model_options[:task] == "MaskedLM"
|
47
|
-
@model_options[:locate_tokens] = locate_tokens =
|
33
|
+
@model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
|
48
34
|
end
|
49
35
|
|
50
36
|
if @directory
|
@@ -57,22 +43,21 @@ class HuggingfaceModel < VectorModel
|
|
57
43
|
checkpoint_dir = File.join(tmpdir, 'checkpoints')
|
58
44
|
end
|
59
45
|
|
60
|
-
dataset_file =
|
46
|
+
dataset_file = TorchModel.text_dataset(tsv_file, texts)
|
61
47
|
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
62
48
|
|
63
49
|
begin
|
64
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model,
|
50
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, model, tokenizer, training_args_obj, dataset_file, locate_tokens)
|
65
51
|
ensure
|
66
52
|
Open.rm_rf tmpdir if tmpdir
|
67
53
|
end
|
68
54
|
else
|
69
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model,
|
55
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
|
70
56
|
end
|
71
57
|
end
|
72
58
|
|
73
|
-
train_model do |
|
74
|
-
|
75
|
-
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
59
|
+
train_model do |texts,labels|
|
60
|
+
model, tokenizer = self.init
|
76
61
|
|
77
62
|
if @directory
|
78
63
|
tsv_file = File.join(@directory, 'dataset.tsv')
|
@@ -85,17 +70,19 @@ class HuggingfaceModel < VectorModel
|
|
85
70
|
end
|
86
71
|
|
87
72
|
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
88
|
-
dataset_file = HuggingfaceModel.
|
73
|
+
dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, @model_options[:class_labels])
|
89
74
|
|
90
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :train_model,
|
75
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
|
91
76
|
|
92
77
|
Open.rm_rf tmpdir if tmpdir
|
93
78
|
|
94
|
-
|
95
|
-
|
79
|
+
model.save_pretrained(@model_path) if @model_path
|
80
|
+
tokenizer.save_pretrained(@model_path) if @model_path
|
96
81
|
end
|
97
82
|
|
98
|
-
post_process do |result|
|
83
|
+
post_process do |result,is_list|
|
84
|
+
model, tokenizer = self.init
|
85
|
+
|
99
86
|
if result.respond_to?(:predictions)
|
100
87
|
single = false
|
101
88
|
predictions = result.predictions
|
@@ -137,25 +124,25 @@ class HuggingfaceModel < VectorModel
|
|
137
124
|
best_token
|
138
125
|
end
|
139
126
|
|
140
|
-
best.collect{|b|
|
127
|
+
best.collect{|b| tokenizer.decode(b) } * "|"
|
141
128
|
end
|
142
129
|
Array === locate_tokens ? item_masks : item_masks.first
|
143
130
|
end
|
144
131
|
else
|
145
|
-
|
132
|
+
predictions
|
146
133
|
end
|
147
134
|
|
148
|
-
single ? result.first : result
|
135
|
+
(! is_list || single) && Array === result ? result.first : result
|
149
136
|
end
|
150
137
|
|
151
138
|
|
152
|
-
save_models if @
|
139
|
+
save_models if @model_path
|
153
140
|
end
|
154
141
|
|
155
142
|
def reset_model
|
156
143
|
@model, @tokenizer = nil
|
157
|
-
Open.
|
144
|
+
Open.rm_rf @model_path
|
145
|
+
init
|
158
146
|
end
|
159
|
-
|
160
147
|
end
|
161
148
|
|
@@ -0,0 +1,33 @@
|
|
1
|
+
require 'rbbt/vector/model'
|
2
|
+
require 'rbbt/util/python'
|
3
|
+
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
5
|
+
RbbtPython.init_rbbt
|
6
|
+
|
7
|
+
# VectorModel whose estimator is an instance of a Python class.
#
# dir           - model directory; it is added to the Python path before the
#                 estimator is built, so a module stored there can be imported
# python_class  - name of the Python class to instantiate (no init_model is
#                 registered when this is nil)
# python_module - module holding python_class (default :model); when a Hash is
#                 given here and model_options is omitted, the Hash is taken as
#                 the options and the module falls back to :model
# model_options - keyword arguments for the Python constructor (default {})
class PythonModel < VectorModel
  attr_accessor :python_class, :python_module

  def initialize(dir, python_class = nil, python_module = nil, model_options = nil)
    python_module = :model if python_module.nil?
    if model_options.nil? && Hash === python_module
      model_options = python_module
      python_module = :model
    end
    model_options = {} if model_options.nil?

    super(dir, model_options)

    @python_class, @python_module = python_class, python_module

    if python_class
      init_model do
        RbbtPython.add_path @directory
        RbbtPython.class_new_obj(@python_module, @python_class, **model_options)
      end
    end

    # A single feature vector is wrapped in a one-element batch and unwrapped
    # again; a list is handed to the Python model as is.
    eval_model do |features, list = false|
      init
      list ? model.eval(features) : model.eval([features])[0]
    end
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
require 'rbbt/vector/model/torch'
|
2
|
+
|
3
|
+
# TorchModel trained with a PyTorch Lightning Trainer.
#
# Training data comes either from a ready-made Python dataloader assigned to
# #loader (plus optional #val_loader), or from Ruby features/labels, which are
# written to a temporary TSV and loaded with rbbt_dm.tsv.
class PytorchLightningModel < TorchModel
  attr_accessor :loader, :val_loader, :trainer
  def initialize(...)
    super(...)

    train_model do |features,labels|
      model = init
      loader = self.loader
      val_loader = self.val_loader
      if (features && features.any?) && loader.nil?
        TmpFile.with_file do |tsv_dataset_file|
          TorchModel.feature_dataset(tsv_dataset_file, features, labels)
          RbbtPython.pyimport :rbbt_dm
          loader = RbbtPython.rbbt_dm.tsv(tsv_dataset_file)
          # Fit while the temporary TSV is still in scope. The loader may read
          # the file lazily (its implementation is not visible here), and
          # TmpFile.with_file presumably removes the file once this block ends.
          trainer.fit(model, loader, val_loader)
        end
      else
        trainer.fit(model, loader, val_loader)
      end
      TorchModel.save_architecture(model, model_path) if @directory
      TorchModel.save_state(model, model_path) if @directory
    end
  end

  # Lazily built pytorch_lightning.Trainer, configured from
  # model_options[:training_args] or, failing that, [:trainer_args].
  def trainer
    @trainer ||= begin
                   options = @model_options[:training_args] || @model_options[:trainer_args]
                   RbbtPython.class_new_obj("pytorch_lightning", "Trainer", options || {})
                 end
  end
end
|
@@ -28,12 +28,12 @@ class SpaCyModel < VectorModel
|
|
28
28
|
|
29
29
|
super(dir)
|
30
30
|
|
31
|
-
@train_model = Proc.new do |
|
31
|
+
@train_model = Proc.new do |features, labels|
|
32
32
|
texts = features
|
33
33
|
docs = []
|
34
34
|
unique_labels = labels.uniq
|
35
|
-
tmpconfig = File.join(
|
36
|
-
tmptrain = File.join(
|
35
|
+
tmpconfig = File.join(@model_path, 'config')
|
36
|
+
tmptrain = File.join(@model_path, 'train.spacy')
|
37
37
|
SpaCy.config(@config, tmpconfig)
|
38
38
|
|
39
39
|
bar = bar(features.length, "Training documents into spacy format")
|
@@ -54,20 +54,22 @@ class SpaCyModel < VectorModel
|
|
54
54
|
end
|
55
55
|
|
56
56
|
gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
|
57
|
-
CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{
|
57
|
+
CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{@model_path} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
|
58
58
|
end
|
59
59
|
|
60
|
-
@eval_model = Proc.new do |
|
60
|
+
@eval_model = Proc.new do |features,list|
|
61
61
|
texts = features
|
62
62
|
texts = [texts] unless list
|
63
63
|
|
64
|
+
model_path = @model_path
|
65
|
+
|
64
66
|
docs = []
|
65
67
|
bar = bar(features.length, "Evaluating model")
|
66
68
|
SpaCyModel.spacy do
|
67
69
|
gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
|
68
70
|
gpu = gpu.to_i if gpu && gpu != ""
|
69
71
|
spacy.require_gpu(gpu) if gpu
|
70
|
-
nlp = spacy.load("#{
|
72
|
+
nlp = spacy.load("#{model_path}/model-best")
|
71
73
|
|
72
74
|
docs = nlp.pipe(texts)
|
73
75
|
RbbtPython.collect docs, :bar => bar do |d|
|
@@ -25,7 +25,7 @@ class TensorFlowModel < VectorModel
|
|
25
25
|
|
26
26
|
super(dir)
|
27
27
|
|
28
|
-
@train_model = Proc.new do |
|
28
|
+
@train_model = Proc.new do |features, labels|
|
29
29
|
tensorflow do
|
30
30
|
features = tensorflow.convert_to_tensor(features)
|
31
31
|
labels = tensorflow.convert_to_tensor(labels)
|
@@ -33,16 +33,17 @@ class TensorFlowModel < VectorModel
|
|
33
33
|
@graph ||= keras_graph
|
34
34
|
@graph.compile(**@compile_options)
|
35
35
|
@graph.fit(features, labels, :epochs => @epochs, :verbose => true)
|
36
|
-
@graph.save(
|
36
|
+
@graph.save(@model_path)
|
37
37
|
end
|
38
38
|
|
39
|
-
@eval_model = Proc.new do |
|
39
|
+
@eval_model = Proc.new do |features|
|
40
40
|
tensorflow do
|
41
41
|
features = tensorflow.convert_to_tensor(features)
|
42
42
|
end
|
43
|
+
model_path = @model_path
|
44
|
+
graph = @graph ||= keras.models.load_model(model_path)
|
43
45
|
keras do
|
44
|
-
|
45
|
-
indices = @graph.predict(features, :verbose => false).tolist()
|
46
|
+
indices = graph.predict(features, :verbose => false).tolist()
|
46
47
|
labels = indices.collect{|p| p.length > 1 ? p.index(p.max): p.first }
|
47
48
|
labels
|
48
49
|
end
|
@@ -0,0 +1,58 @@
|
|
1
|
+
class TorchModel
  # Build an in-memory TSV, keyed by integer position, from feature vectors and
  # optional labels.
  #
  # elements     - feature vectors (Arrays) or scalar features
  # labels       - optional labels, one per element, appended as a last column
  # class_labels - optional label encoding: an Array maps a label to its index;
  #                a Hash (code => label) maps a label back to its code
  def self.feature_tsv(elements, labels = nil, class_labels = nil)
    tsv = TSV.setup({}, :key_field => "ID", :fields => ["features"], :type => :flat)
    if labels
      tsv.fields = tsv.fields + ["label"]
      labels = case class_labels
               when Array
                 labels.collect{|l| class_labels.index l}
               when Hash
                 inverse_class_labels = {}
                 class_labels.each{|c,l| inverse_class_labels[l] = c }
                 labels.collect{|l| inverse_class_labels[l]}
               else
                 labels
               end
      elements.zip(labels).each_with_index do |p,i|
        features, label = p
        id = i
        if Array === features
          tsv[id] = features + [label]
        else
          tsv[id] = [features, label]
        end
      end
    else
      elements.each_with_index do |features,i|
        id = i
        if Array === features
          tsv[id] = features
        else
          tsv[id] = [features]
        end
      end
    end
    tsv
  end

  # Write feature_tsv(elements, labels, class_labels) to tsv_dataset_file and
  # return the file path.
  def self.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
    tsv = feature_tsv(elements, labels, class_labels)
    Open.write(tsv_dataset_file, tsv.to_s)
    tsv_dataset_file
  end

  # Write a text dataset to tsv_dataset_file and return the file path. The
  # first column is named "text"; labels, when given, follow as a list column.
  def self.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
    # Tabs and line breaks inside a text would be read back as column or
    # record separators of the TSV, shifting or splitting the row.
    elements = elements.collect{|e| e.gsub(/[\t\r\n]/, ' ') }
    tsv = feature_tsv(elements, labels, class_labels)
    tsv.fields[0] = "text"
    tsv.type = labels.nil? ? :single : :list
    Open.write(tsv_dataset_file, tsv.to_s)
    tsv_dataset_file
  end

end
|
@@ -0,0 +1,52 @@
|
|
1
|
+
class TorchModel
  # Mixin for Python tensors wrapped in Ruby, adding conversion to Ruby data.
  module Tensor
    def to_ruby
      RbbtPython.numpy2ruby(self)
    end
    def self.setup(obj)
      obj.extend Tensor
    end
  end

  # Import the Python modules the torch models rely on.
  def self.init_python
    RbbtPython.pyimport :torch
    RbbtPython.pyimport :rbbt
    RbbtPython.pyimport :rbbt_dm
    RbbtPython.pyfrom :rbbt_dm, import: :util
    RbbtPython.pyfrom :torch, import: :nn
  end

  # SGD optimizer over model.parameters(). training_args may be nil; the
  # learning rate is read from its :learning_rate key and defaults to 0.01.
  def self.optimizer(model, training_args)
    training_args = {} if training_args.nil?
    learning_rate = training_args[:learning_rate] || 0.01
    RbbtPython.torch.optim.SGD.new(model.parameters(), lr: learning_rate)
  end

  # Resolve model_options[:device]: a String/Symbol becomes torch.device(name),
  # nil defers to rbbt_dm.util.device(), anything else is returned unchanged.
  def self.device(model_options)
    case model_options[:device]
    when String, Symbol
      RbbtPython.torch.device(model_options[:device].to_s)
    when nil
      RbbtPython.rbbt_dm.util.device()
    else
      model_options[:device]
    end
  end

  # Resolve model_options[:dtype]. A String/Symbol such as "float32" or
  # "torch.float32" names an attribute of the torch module, nil stays nil, and
  # anything else is returned unchanged.
  def self.dtype(model_options)
    case model_options[:dtype]
    when String, Symbol
      # Look the dtype up as an attribute: PyCall's #call would try to invoke
      # the torch module itself, which is not callable.
      name = model_options[:dtype].to_s.sub(/\Atorch\./, '')
      PyCall.getattr(RbbtPython.torch, name.to_sym)
    when nil
      nil
    else
      model_options[:dtype]
    end
  end

  # torch.tensor(obj) with the given dtype and device.
  def self.tensor(obj, device, dtype)
    RbbtPython.torch.tensor(obj, dtype: dtype, device: device)
  end

end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
class TorchModel
  # Return the sub-module of +model+ named by a dotted attribute path such as
  # "encoder.layer", or +model+ itself when +layer+ is nil.
  def self.get_layer(model, layer = nil)
    if layer.nil?
      model
    else
      layer.split(".").inject(model){|acc,l| PyCall.getattr(acc, l.to_sym) }
    end
  end
  def get_layer(...); TorchModel.get_layer(model, ...); end

  # Weight attribute of the selected layer, extended with TorchModel::Tensor.
  def self.get_weights(model, layer = nil)
    Tensor.setup PyCall.getattr(get_layer(model, layer), :weight)
  end
  def get_weights(...); TorchModel.get_weights(model, ...); end

  # Stop gradient updates for +layer+ and, recursively, all of its children.
  # Both :weight and :bias are frozen; freezing only weights would leave
  # biases trainable. Modules lacking either attribute (or whose attribute
  # cannot be updated) are skipped on purpose.
  def self.freeze(layer)
    [:weight, :bias].each do |param|
      begin
        PyCall.getattr(layer, param).requires_grad = false
      rescue StandardError
        # best effort: not every module has this parameter
      end
    end
    RbbtPython.iterate(layer.children) do |child|
      freeze(child)
    end
  end

  # Freeze the sub-module of +model+ selected by the dotted path +layer+.
  def self.freeze_layer(model, layer)
    layer = get_layer(model, layer)
    freeze(layer)
  end
  def freeze_layer(...); TorchModel.freeze_layer(model, ...); end

end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
class TorchModel
  # Path of the file that stores the pickled model architecture for model_path.
  def self.model_architecture(model_path)
    model_path + '.architecture'
  end

  # Save model.state_dict() to model_path with torch.save.
  def self.save_state(model, model_path)
    Log.debug "Saving model state into #{model_path}"
    RbbtPython.torch.save(model.state_dict(), model_path)
  end

  # Load a state dict from model_path into +model+. Returns +model+, unchanged
  # when the file does not exist.
  def self.load_state(model, model_path)
    return model unless Open.exists?(model_path)
    Log.debug "Loading model state from #{model_path}"
    model.load_state_dict(RbbtPython.torch.load(model_path))
    model
  end

  # Pickle the whole model object next to model_path (see model_architecture).
  def self.save_architecture(model, model_path)
    model_architecture = model_architecture(model_path)
    Log.debug "Saving model architecture into #{model_architecture}"
    RbbtPython.torch.save(model, model_architecture)
  end

  # Load the model object saved by save_architecture, or nil when no
  # architecture file exists.
  #
  # weights_only: false is required because the file holds a full pickled
  # module, not just tensors. torch >= 2.6 defaults torch.load to
  # weights_only=True, which rejects such files; the keyword itself needs
  # torch >= 1.13. NOTE: unpickling runs arbitrary code, so only load
  # architecture files from trusted locations.
  def self.load_architecture(model_path)
    model_architecture = model_architecture(model_path)
    return unless Open.exists?(model_architecture)
    Log.debug "Loading model architecture from #{model_architecture}"
    RbbtPython.torch.load(model_architecture, weights_only: false)
  end
end
|