rbbt-dm 1.2.7 → 1.2.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +39 -52
- data/lib/rbbt/vector/model/python.rb +33 -0
- data/lib/rbbt/vector/model/pytorch_lightning.rb +31 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -6
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch/dataloader.rb +58 -0
- data/lib/rbbt/vector/model/torch/helpers.rb +52 -0
- data/lib/rbbt/vector/model/torch/introspection.rb +31 -0
- data/lib/rbbt/vector/model/torch/load_and_save.rb +30 -0
- data/lib/rbbt/vector/model/torch.rb +71 -0
- data/lib/rbbt/vector/model.rb +84 -54
- data/python/rbbt_dm/__init__.py +31 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +64 -28
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +32 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_python.rb +31 -0
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +97 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/model/test_torch.rb +61 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +35 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9d53609453e1c3bd589c95071569583bff5f11224a200850c0d9e85775e5a2ce
|
4
|
+
data.tar.gz: 501f436caff07c990c09adec4caa60d32fe577b35209a65e8f96418ab2acc422
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4a53ccc8a5eac633e344beffd36073e1828d848d9743552f12e75a54cbda337050152b72fd95d7d97d4aeed50f811dcc8475c97bc292ba9bcdd57e9d32f0e91b
|
7
|
+
data.tar.gz: 30fef4ee2c023e141c0ef991ae35d1a0913d4444138fda778760b821ccc48d85d67dc44d4c251d7f84e3b1bbe6ecb6f3b64aee13fbdc443cb33bbed2b34a586a
|
data/lib/rbbt/matrix/barcode.rb
CHANGED
@@ -3,7 +3,7 @@ require 'rbbt/util/R'
|
|
3
3
|
class RbbtMatrix
|
4
4
|
def barcode(outfile, factor = 2)
|
5
5
|
|
6
|
-
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
6
|
+
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
7
7
|
cmd =<<-EOF
|
8
8
|
source('#{Rbbt.share.R['barcode.R'].find}')
|
9
9
|
rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })
|
@@ -49,7 +49,7 @@ rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.r
|
|
49
49
|
|
50
50
|
clusters = Array === clusters ? clusters : (2..clusters).to_a
|
51
51
|
|
52
|
-
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
52
|
+
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
53
53
|
cmd =<<-EOF
|
54
54
|
source('#{Rbbt.share.R['barcode.R'].find}')
|
55
55
|
rbbt.GE.activity_cluster(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{R.ruby2R key_field}, #{R.ruby2R clusters})
|
@@ -45,7 +45,7 @@ class RbbtMatrix
|
|
45
45
|
end
|
46
46
|
|
47
47
|
file = file.find if Path === file
|
48
|
-
FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.
|
48
|
+
FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exist? File.dirname(file)
|
49
49
|
|
50
50
|
cmd = <<-EOS
|
51
51
|
|
@@ -71,10 +71,10 @@ end
|
|
71
71
|
|
72
72
|
|
73
73
|
#def self.analyze(datafile, main, contrast = nil, log2 = false, outfile = nil, key_field = nil, two_channel = nil)
|
74
|
-
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
74
|
+
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
75
75
|
# GE.run_R("rbbt.GE.process(#{ R.ruby2R datafile }, main = #{R.ruby2R(main, :strings => true)}, contrast = #{R.ruby2R(contrast, :strings => true)}, log2=#{ R.ruby2R log2 }, outfile = #{R.ruby2R outfile}, key.field = #{R.ruby2R key_field}, two.channel = #{R.ruby2R two_channel})")
|
76
76
|
#end
|
77
77
|
#def self.barcode(datafile, outfile, factor = 2)
|
78
|
-
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
78
|
+
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
79
79
|
# GE.run_R("rbbt.GE.barcode(#{ R.ruby2R datafile }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })")
|
80
80
|
#end
|
@@ -13,7 +13,7 @@ class KnowledgeBase
|
|
13
13
|
|
14
14
|
return matrix if RbbtMatrix === matrix
|
15
15
|
|
16
|
-
Path.setup(matrix) if not Path === matrix and File.
|
16
|
+
Path.setup(matrix) if not Path === matrix and File.exist? matrix
|
17
17
|
|
18
18
|
raise "Registered matrix is strange: #{Misc.fingerprint matrix}" unless Path === matrix
|
19
19
|
|
data/lib/rbbt/plots/bar.rb
CHANGED
@@ -119,7 +119,7 @@ module BarPlot
|
|
119
119
|
options = Misc.add_defaults options, :width => [options[:total], MAX_WIDTH].min, :height => 20, :background => PNG::Color::White
|
120
120
|
width, height, background, canvas = Misc.process_options options, :width, :height, :background, :canvas
|
121
121
|
|
122
|
-
canvas ||= if options[:update] and options[:filename] and File.
|
122
|
+
canvas ||= if options[:update] and options[:filename] and File.exist? options[:filename]
|
123
123
|
PNG.load_file options[:filename]
|
124
124
|
else
|
125
125
|
PNG::Canvas.new width, height, get_color(background)
|
data/lib/rbbt/stan.rb
CHANGED
@@ -130,7 +130,7 @@ print(fit)
|
|
130
130
|
erase = true
|
131
131
|
end
|
132
132
|
|
133
|
-
FileUtils.mkdir_p directory unless File.
|
133
|
+
FileUtils.mkdir_p directory unless File.exist? directory
|
134
134
|
input_directory = File.join(directory, 'inputs')
|
135
135
|
parameter_chains = File.join(directory, 'chains') unless erase
|
136
136
|
summary = File.join(directory, 'summary') unless erase
|
@@ -118,7 +118,7 @@ module TSV
|
|
118
118
|
|
119
119
|
field_pos = fields.collect{|f| self.fields.index f}.compact
|
120
120
|
persistence_path = self.respond_to?(:persistence_path)? self.persistence_path : nil
|
121
|
-
Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:background => background, :rename => rename, :persistence_path => persistence_path}) do
|
121
|
+
Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:fields => fields, :background => background, :rename => rename, :persistence_path => persistence_path}) do
|
122
122
|
data ||= {}
|
123
123
|
|
124
124
|
with_unnamed do
|
@@ -152,6 +152,7 @@ module TSV
|
|
152
152
|
|
153
153
|
end
|
154
154
|
|
155
|
+
|
155
156
|
if rename
|
156
157
|
Log.debug("Using renames during annotation counts")
|
157
158
|
Hash[*data.keys.zip(data.values.collect{|l| l.collect{|e| rename.include?(e)? rename[e] : e }.uniq.length }).flatten]
|
@@ -0,0 +1,50 @@
|
|
1
|
+
require 'rbbt/vector/model/huggingface'
|
2
|
+
class MaskedLMModel < HuggingfaceModel
|
3
|
+
|
4
|
+
def initialize(checkpoint, dir = nil, model_options = {})
|
5
|
+
|
6
|
+
model_options = Misc.add_defaults model_options, :max_length => 128
|
7
|
+
super("MaskedLM", checkpoint, dir, model_options)
|
8
|
+
|
9
|
+
train_model do |texts,labels|
|
10
|
+
model, tokenizer = self.init
|
11
|
+
max_length = @model_options[:max_length]
|
12
|
+
mask_id = tokenizer.mask_token_id
|
13
|
+
|
14
|
+
dataset = []
|
15
|
+
texts.zip(labels).each do |text,label_values|
|
16
|
+
fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
|
17
|
+
label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }
|
18
|
+
label_tokens.each do |ids|
|
19
|
+
ids = [ids] unless Array === ids
|
20
|
+
fixed_text.sub!("[PENDINGMASK]", "[MASK]" * ids.length)
|
21
|
+
end
|
22
|
+
|
23
|
+
tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length")
|
24
|
+
input_ids = tokenized_text["input_ids"].to_a
|
25
|
+
attention_mask = tokenized_text["attention_mask"].to_a
|
26
|
+
|
27
|
+
all_label_tokens = label_tokens.flatten
|
28
|
+
label_ids = input_ids.collect do |id|
|
29
|
+
if id == mask_id
|
30
|
+
all_label_tokens.shift
|
31
|
+
else
|
32
|
+
-100
|
33
|
+
end
|
34
|
+
end
|
35
|
+
dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
|
36
|
+
end
|
37
|
+
|
38
|
+
dataset_file = File.join(@directory, 'dataset.json')
|
39
|
+
Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")
|
40
|
+
|
41
|
+
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
|
42
|
+
data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
|
43
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)
|
44
|
+
|
45
|
+
model.save_pretrained(@model_path) if @model_path
|
46
|
+
tokenizer.save_pretrained(@model_path) if @model_path
|
47
|
+
end
|
48
|
+
|
49
|
+
end
|
50
|
+
end
|
@@ -1,50 +1,36 @@
|
|
1
|
-
require 'rbbt/vector/model'
|
2
|
-
require 'rbbt/util/python'
|
1
|
+
require 'rbbt/vector/model/torch'
|
3
2
|
|
4
|
-
|
5
|
-
RbbtPython.init_rbbt
|
3
|
+
class HuggingfaceModel < TorchModel
|
6
4
|
|
7
|
-
|
5
|
+
def initialize(task, checkpoint, dir = nil, model_options = {})
|
6
|
+
super(dir, nil, model_options)
|
8
7
|
|
9
|
-
|
8
|
+
@model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint
|
10
9
|
|
11
|
-
|
12
|
-
|
13
|
-
ffile.puts ["label", "text"].flatten * "\t"
|
14
|
-
elements.zip(labels).each do |element,label|
|
15
|
-
ffile.puts [label, element].flatten * "\t"
|
16
|
-
end
|
17
|
-
end
|
18
|
-
else
|
19
|
-
Open.write(tsv_dataset_file) do |ffile|
|
20
|
-
ffile.puts ["text"].flatten * "\t"
|
21
|
-
elements.each{|element| ffile.puts element }
|
22
|
-
end
|
23
|
-
end
|
10
|
+
init_model do
|
11
|
+
checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]
|
24
12
|
|
25
|
-
|
26
|
-
|
13
|
+
model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
|
14
|
+
@model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:model_args]) || {}))
|
27
15
|
|
28
|
-
|
29
|
-
options = args.pop if Hash === args.last
|
30
|
-
options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
|
31
|
-
super(*args)
|
32
|
-
@model_options ||= {}
|
33
|
-
@model_options.merge!(options)
|
16
|
+
tokenizer_checkpoint = @model_options[:tokenizer_checkpoint] || checkpoint
|
34
17
|
|
35
|
-
|
36
|
-
|
18
|
+
tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
|
19
|
+
@model_options[:task], tokenizer_checkpoint, **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))
|
37
20
|
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
21
|
+
[model, tokenizer]
|
22
|
+
end
|
23
|
+
|
24
|
+
eval_model do |texts,is_list|
|
25
|
+
model, tokenizer = self.init
|
26
|
+
|
27
|
+
if is_list || @model_options[:task] == "MaskedLM"
|
28
|
+
texts = [texts] if ! is_list
|
43
29
|
|
44
30
|
if @model_options.include?(:locate_tokens)
|
45
31
|
locate_tokens = @model_options[:locate_tokens]
|
46
32
|
elsif @model_options[:task] == "MaskedLM"
|
47
|
-
@model_options[:locate_tokens] = locate_tokens =
|
33
|
+
@model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
|
48
34
|
end
|
49
35
|
|
50
36
|
if @directory
|
@@ -57,22 +43,21 @@ class HuggingfaceModel < VectorModel
|
|
57
43
|
checkpoint_dir = File.join(tmpdir, 'checkpoints')
|
58
44
|
end
|
59
45
|
|
60
|
-
dataset_file =
|
46
|
+
dataset_file = TorchModel.text_dataset(tsv_file, texts)
|
61
47
|
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
62
48
|
|
63
49
|
begin
|
64
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model,
|
50
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, model, tokenizer, training_args_obj, dataset_file, locate_tokens)
|
65
51
|
ensure
|
66
52
|
Open.rm_rf tmpdir if tmpdir
|
67
53
|
end
|
68
54
|
else
|
69
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model,
|
55
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
|
70
56
|
end
|
71
57
|
end
|
72
58
|
|
73
|
-
train_model do |
|
74
|
-
|
75
|
-
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
59
|
+
train_model do |texts,labels|
|
60
|
+
model, tokenizer = self.init
|
76
61
|
|
77
62
|
if @directory
|
78
63
|
tsv_file = File.join(@directory, 'dataset.tsv')
|
@@ -85,17 +70,19 @@ class HuggingfaceModel < VectorModel
|
|
85
70
|
end
|
86
71
|
|
87
72
|
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
88
|
-
dataset_file = HuggingfaceModel.
|
73
|
+
dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, @model_options[:class_labels])
|
89
74
|
|
90
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :train_model,
|
75
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
|
91
76
|
|
92
77
|
Open.rm_rf tmpdir if tmpdir
|
93
78
|
|
94
|
-
|
95
|
-
|
79
|
+
model.save_pretrained(@model_path) if @model_path
|
80
|
+
tokenizer.save_pretrained(@model_path) if @model_path
|
96
81
|
end
|
97
82
|
|
98
|
-
post_process do |result|
|
83
|
+
post_process do |result,is_list|
|
84
|
+
model, tokenizer = self.init
|
85
|
+
|
99
86
|
if result.respond_to?(:predictions)
|
100
87
|
single = false
|
101
88
|
predictions = result.predictions
|
@@ -137,25 +124,25 @@ class HuggingfaceModel < VectorModel
|
|
137
124
|
best_token
|
138
125
|
end
|
139
126
|
|
140
|
-
best.collect{|b|
|
127
|
+
best.collect{|b| tokenizer.decode(b) } * "|"
|
141
128
|
end
|
142
129
|
Array === locate_tokens ? item_masks : item_masks.first
|
143
130
|
end
|
144
131
|
else
|
145
|
-
|
132
|
+
predictions
|
146
133
|
end
|
147
134
|
|
148
|
-
single ? result.first : result
|
135
|
+
(! is_list || single) && Array === result ? result.first : result
|
149
136
|
end
|
150
137
|
|
151
138
|
|
152
|
-
save_models if @
|
139
|
+
save_models if @model_path
|
153
140
|
end
|
154
141
|
|
155
142
|
def reset_model
|
156
143
|
@model, @tokenizer = nil
|
157
|
-
Open.
|
144
|
+
Open.rm_rf @model_path
|
145
|
+
init
|
158
146
|
end
|
159
|
-
|
160
147
|
end
|
161
148
|
|
@@ -0,0 +1,33 @@
|
|
1
|
+
require 'rbbt/vector/model'
|
2
|
+
require 'rbbt/util/python'
|
3
|
+
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
5
|
+
RbbtPython.init_rbbt
|
6
|
+
|
7
|
+
class PythonModel < VectorModel
|
8
|
+
attr_accessor :python_class, :python_module
|
9
|
+
def initialize(dir, python_class = nil, python_module = nil, model_options = nil)
|
10
|
+
python_module = :model if python_module.nil?
|
11
|
+
model_options, python_module = python_module, :model if model_options.nil? && Hash === python_module
|
12
|
+
model_options = {} if model_options.nil?
|
13
|
+
|
14
|
+
super(dir, model_options)
|
15
|
+
|
16
|
+
@python_class = python_class
|
17
|
+
@python_module = python_module
|
18
|
+
|
19
|
+
init_model do
|
20
|
+
RbbtPython.add_path @directory
|
21
|
+
RbbtPython.class_new_obj(@python_module, @python_class, **model_options)
|
22
|
+
end if python_class
|
23
|
+
|
24
|
+
eval_model do |features,list=false|
|
25
|
+
init
|
26
|
+
if list
|
27
|
+
model.eval(features)
|
28
|
+
else
|
29
|
+
model.eval([features])[0]
|
30
|
+
end
|
31
|
+
end
|
32
|
+
end
|
33
|
+
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
require 'rbbt/vector/model/torch'
|
2
|
+
|
3
|
+
class PytorchLightningModel < TorchModel
|
4
|
+
attr_accessor :loader, :val_loader, :trainer
|
5
|
+
def initialize(...)
|
6
|
+
super(...)
|
7
|
+
|
8
|
+
train_model do |features,labels|
|
9
|
+
model = init
|
10
|
+
loader = self.loader
|
11
|
+
val_loader = self.val_loader
|
12
|
+
if (features && features.any?) && loader.nil?
|
13
|
+
TmpFile.with_file do |tsv_dataset_file|
|
14
|
+
TorchModel.feature_dataset(tsv_dataset_file, features, labels)
|
15
|
+
RbbtPython.pyimport :rbbt_dm
|
16
|
+
loader = RbbtPython.rbbt_dm.tsv(tsv_dataset_file)
|
17
|
+
end
|
18
|
+
end
|
19
|
+
trainer.fit(model, loader, val_loader)
|
20
|
+
TorchModel.save_architecture(model, model_path) if @directory
|
21
|
+
TorchModel.save_state(model, model_path) if @directory
|
22
|
+
end
|
23
|
+
end
|
24
|
+
|
25
|
+
def trainer
|
26
|
+
@trainer ||= begin
|
27
|
+
options = @model_options[:training_args] || @model_options[:trainer_args]
|
28
|
+
RbbtPython.class_new_obj("pytorch_lightning", "Trainer", options || {})
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
@@ -28,12 +28,12 @@ class SpaCyModel < VectorModel
|
|
28
28
|
|
29
29
|
super(dir)
|
30
30
|
|
31
|
-
@train_model = Proc.new do |
|
31
|
+
@train_model = Proc.new do |features, labels|
|
32
32
|
texts = features
|
33
33
|
docs = []
|
34
34
|
unique_labels = labels.uniq
|
35
|
-
tmpconfig = File.join(
|
36
|
-
tmptrain = File.join(
|
35
|
+
tmpconfig = File.join(@model_path, 'config')
|
36
|
+
tmptrain = File.join(@model_path, 'train.spacy')
|
37
37
|
SpaCy.config(@config, tmpconfig)
|
38
38
|
|
39
39
|
bar = bar(features.length, "Training documents into spacy format")
|
@@ -54,20 +54,22 @@ class SpaCyModel < VectorModel
|
|
54
54
|
end
|
55
55
|
|
56
56
|
gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
|
57
|
-
CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{
|
57
|
+
CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{@model_path} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
|
58
58
|
end
|
59
59
|
|
60
|
-
@eval_model = Proc.new do |
|
60
|
+
@eval_model = Proc.new do |features,list|
|
61
61
|
texts = features
|
62
62
|
texts = [texts] unless list
|
63
63
|
|
64
|
+
model_path = @model_path
|
65
|
+
|
64
66
|
docs = []
|
65
67
|
bar = bar(features.length, "Evaluating model")
|
66
68
|
SpaCyModel.spacy do
|
67
69
|
gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
|
68
70
|
gpu = gpu.to_i if gpu && gpu != ""
|
69
71
|
spacy.require_gpu(gpu) if gpu
|
70
|
-
nlp = spacy.load("#{
|
72
|
+
nlp = spacy.load("#{model_path}/model-best")
|
71
73
|
|
72
74
|
docs = nlp.pipe(texts)
|
73
75
|
RbbtPython.collect docs, :bar => bar do |d|
|
@@ -25,7 +25,7 @@ class TensorFlowModel < VectorModel
|
|
25
25
|
|
26
26
|
super(dir)
|
27
27
|
|
28
|
-
@train_model = Proc.new do |
|
28
|
+
@train_model = Proc.new do |features, labels|
|
29
29
|
tensorflow do
|
30
30
|
features = tensorflow.convert_to_tensor(features)
|
31
31
|
labels = tensorflow.convert_to_tensor(labels)
|
@@ -33,16 +33,17 @@ class TensorFlowModel < VectorModel
|
|
33
33
|
@graph ||= keras_graph
|
34
34
|
@graph.compile(**@compile_options)
|
35
35
|
@graph.fit(features, labels, :epochs => @epochs, :verbose => true)
|
36
|
-
@graph.save(
|
36
|
+
@graph.save(@model_path)
|
37
37
|
end
|
38
38
|
|
39
|
-
@eval_model = Proc.new do |
|
39
|
+
@eval_model = Proc.new do |features|
|
40
40
|
tensorflow do
|
41
41
|
features = tensorflow.convert_to_tensor(features)
|
42
42
|
end
|
43
|
+
model_path = @model_path
|
44
|
+
graph = @graph ||= keras.models.load_model(model_path)
|
43
45
|
keras do
|
44
|
-
|
45
|
-
indices = @graph.predict(features, :verbose => false).tolist()
|
46
|
+
indices = graph.predict(features, :verbose => false).tolist()
|
46
47
|
labels = indices.collect{|p| p.length > 1 ? p.index(p.max): p.first }
|
47
48
|
labels
|
48
49
|
end
|
@@ -0,0 +1,58 @@
|
|
1
|
+
class TorchModel
|
2
|
+
def self.feature_tsv(elements, labels = nil, class_labels = nil)
|
3
|
+
tsv = TSV.setup({}, :key_field => "ID", :fields => ["features"], :type => :flat)
|
4
|
+
if labels
|
5
|
+
tsv.fields = tsv.fields + ["label"]
|
6
|
+
labels = case class_labels
|
7
|
+
when Array
|
8
|
+
labels.collect{|l| class_labels.index l}
|
9
|
+
when Hash
|
10
|
+
inverse_class_labels = {}
|
11
|
+
class_labels.each{|c,l| inverse_class_labels[l] = c }
|
12
|
+
labels.collect{|l| inverse_class_labels[l]}
|
13
|
+
else
|
14
|
+
labels
|
15
|
+
end
|
16
|
+
elements.zip(labels).each_with_index do |p,i|
|
17
|
+
features, label = p
|
18
|
+
id = i
|
19
|
+
if Array === features
|
20
|
+
tsv[id] = features + [label]
|
21
|
+
else
|
22
|
+
tsv[id] = [features, label]
|
23
|
+
end
|
24
|
+
end
|
25
|
+
else
|
26
|
+
elements.each_with_index do |features,i|
|
27
|
+
id = i
|
28
|
+
if Array === features
|
29
|
+
tsv[id] = features
|
30
|
+
else
|
31
|
+
tsv[id] = [features]
|
32
|
+
end
|
33
|
+
end
|
34
|
+
end
|
35
|
+
tsv
|
36
|
+
end
|
37
|
+
|
38
|
+
def self.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
|
39
|
+
tsv = feature_tsv(elements, labels, class_labels)
|
40
|
+
Open.write(tsv_dataset_file, tsv.to_s)
|
41
|
+
tsv_dataset_file
|
42
|
+
end
|
43
|
+
|
44
|
+
def self.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
|
45
|
+
elements = elements.collect{|e| e.gsub("\n", ' ') }
|
46
|
+
tsv = feature_tsv(elements, labels, class_labels)
|
47
|
+
if labels.nil?
|
48
|
+
tsv.fields[0] = "text"
|
49
|
+
tsv.type = :single
|
50
|
+
else
|
51
|
+
tsv.fields[0] = "text"
|
52
|
+
tsv.type = :list
|
53
|
+
end
|
54
|
+
Open.write(tsv_dataset_file, tsv.to_s)
|
55
|
+
tsv_dataset_file
|
56
|
+
end
|
57
|
+
|
58
|
+
end
|
@@ -0,0 +1,52 @@
|
|
1
|
+
class TorchModel
|
2
|
+
module Tensor
|
3
|
+
def to_ruby
|
4
|
+
RbbtPython.numpy2ruby(self)
|
5
|
+
end
|
6
|
+
def self.setup(obj)
|
7
|
+
obj.extend Tensor
|
8
|
+
end
|
9
|
+
end
|
10
|
+
|
11
|
+
def self.init_python
|
12
|
+
RbbtPython.pyimport :torch
|
13
|
+
RbbtPython.pyimport :rbbt
|
14
|
+
RbbtPython.pyimport :rbbt_dm
|
15
|
+
RbbtPython.pyfrom :rbbt_dm, import: :util
|
16
|
+
RbbtPython.pyfrom :torch, import: :nn
|
17
|
+
end
|
18
|
+
|
19
|
+
def self.optimizer(model, training_args)
|
20
|
+
begin
|
21
|
+
learning_rate = training_args[:learning_rate] || 0.01
|
22
|
+
RbbtPython.torch.optim.SGD.new(model.parameters(), lr: learning_rate)
|
23
|
+
end
|
24
|
+
end
|
25
|
+
|
26
|
+
def self.device(model_options)
|
27
|
+
case model_options[:device]
|
28
|
+
when String, Symbol
|
29
|
+
RbbtPython.torch.device(model_options[:device].to_s)
|
30
|
+
when nil
|
31
|
+
RbbtPython.rbbt_dm.util.device()
|
32
|
+
else
|
33
|
+
model_options[:device]
|
34
|
+
end
|
35
|
+
end
|
36
|
+
|
37
|
+
def self.dtype(model_options)
|
38
|
+
case model_options[:dtype]
|
39
|
+
when String, Symbol
|
40
|
+
RbbtPython.torch.call(model_options[:dtype])
|
41
|
+
when nil
|
42
|
+
nil
|
43
|
+
else
|
44
|
+
model_options[:dtype]
|
45
|
+
end
|
46
|
+
end
|
47
|
+
|
48
|
+
def self.tensor(obj, device, dtype)
|
49
|
+
RbbtPython.torch.tensor(obj, dtype: dtype, device: device)
|
50
|
+
end
|
51
|
+
|
52
|
+
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
class TorchModel
|
2
|
+
def self.get_layer(model, layer = nil)
|
3
|
+
if layer.nil?
|
4
|
+
model
|
5
|
+
else
|
6
|
+
layer.split(".").inject(model){|acc,l| PyCall.getattr(acc, l.to_sym) }
|
7
|
+
end
|
8
|
+
end
|
9
|
+
def get_layer(...); TorchModel.get_layer(model, ...); end
|
10
|
+
|
11
|
+
def self.get_weights(model, layer = nil)
|
12
|
+
Tensor.setup PyCall.getattr(get_layer(model, layer), :weight)
|
13
|
+
end
|
14
|
+
def get_weights(...); TorchModel.get_weights(model, ...); end
|
15
|
+
|
16
|
+
def self.freeze(layer)
|
17
|
+
begin
|
18
|
+
PyCall.getattr(layer, :weight).requires_grad = false
|
19
|
+
rescue
|
20
|
+
end
|
21
|
+
RbbtPython.iterate(layer.children) do |layer|
|
22
|
+
freeze(layer)
|
23
|
+
end
|
24
|
+
end
|
25
|
+
def self.freeze_layer(model, layer)
|
26
|
+
layer = get_layer(model, layer)
|
27
|
+
freeze(layer)
|
28
|
+
end
|
29
|
+
def freeze_layer(...); TorchModel.freeze_layer(model, ...); end
|
30
|
+
|
31
|
+
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
class TorchModel
|
2
|
+
def self.model_architecture(model_path)
|
3
|
+
model_path + '.architecture'
|
4
|
+
end
|
5
|
+
|
6
|
+
def self.save_state(model, model_path)
|
7
|
+
Log.debug "Saving model state into #{model_path}"
|
8
|
+
RbbtPython.torch.save(model.state_dict(), model_path)
|
9
|
+
end
|
10
|
+
|
11
|
+
def self.load_state(model, model_path)
|
12
|
+
return model unless Open.exists?(model_path)
|
13
|
+
Log.debug "Loading model state from #{model_path}"
|
14
|
+
model.load_state_dict(RbbtPython.torch.load(model_path))
|
15
|
+
model
|
16
|
+
end
|
17
|
+
|
18
|
+
def self.save_architecture(model, model_path)
|
19
|
+
model_architecture = model_architecture(model_path)
|
20
|
+
Log.debug "Saving model architecture into #{model_architecture}"
|
21
|
+
RbbtPython.torch.save(model, model_architecture)
|
22
|
+
end
|
23
|
+
|
24
|
+
def self.load_architecture(model_path)
|
25
|
+
model_architecture = model_architecture(model_path)
|
26
|
+
return unless Open.exists?(model_architecture)
|
27
|
+
Log.debug "Loading model architecture from #{model_architecture}"
|
28
|
+
RbbtPython.torch.load(model_architecture)
|
29
|
+
end
|
30
|
+
end
|