rbbt-dm 1.2.6 → 1.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +57 -38
- data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -14
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch.rb +37 -0
- data/lib/rbbt/vector/model/util.rb +18 -0
- data/lib/rbbt/vector/model.rb +100 -56
- data/python/rbbt_dm/__init__.py +48 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +57 -26
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +30 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +26 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: db2cbab94e21fd2ca67f7306fa9941b59cbfb2865382e5439edf6313f50309e7
|
4
|
+
data.tar.gz: f4acf3651daa90ef23bc454c62df68e208976a977d51e2e85d02558d48897187
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7786759636450821aabca306cd210ab3d201e094b81bb70052d57d7bfb6e4de73a198576fe4b002487baf7997138f9c53b91644632cd12cc79b40ff62141a70a
|
7
|
+
data.tar.gz: 9870745068a897909170f3a6187e520e6530b121f8ad4ab40224c3369a16a8bb1c1e55bb8ca8fd892943ec1fcead9f4661e49c3d97d07d3740236a1ec4f69a34
|
data/lib/rbbt/matrix/barcode.rb
CHANGED
@@ -3,7 +3,7 @@ require 'rbbt/util/R'
|
|
3
3
|
class RbbtMatrix
|
4
4
|
def barcode(outfile, factor = 2)
|
5
5
|
|
6
|
-
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
6
|
+
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
7
7
|
cmd =<<-EOF
|
8
8
|
source('#{Rbbt.share.R['barcode.R'].find}')
|
9
9
|
rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })
|
@@ -49,7 +49,7 @@ rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.r
|
|
49
49
|
|
50
50
|
clusters = Array === clusters ? clusters : (2..clusters).to_a
|
51
51
|
|
52
|
-
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
52
|
+
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
53
53
|
cmd =<<-EOF
|
54
54
|
source('#{Rbbt.share.R['barcode.R'].find}')
|
55
55
|
rbbt.GE.activity_cluster(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{R.ruby2R key_field}, #{R.ruby2R clusters})
|
@@ -45,7 +45,7 @@ class RbbtMatrix
|
|
45
45
|
end
|
46
46
|
|
47
47
|
file = file.find if Path === file
|
48
|
-
FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.
|
48
|
+
FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exist? File.dirname(file)
|
49
49
|
|
50
50
|
cmd = <<-EOS
|
51
51
|
|
@@ -71,10 +71,10 @@ end
|
|
71
71
|
|
72
72
|
|
73
73
|
#def self.analyze(datafile, main, contrast = nil, log2 = false, outfile = nil, key_field = nil, two_channel = nil)
|
74
|
-
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
74
|
+
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
75
75
|
# GE.run_R("rbbt.GE.process(#{ R.ruby2R datafile }, main = #{R.ruby2R(main, :strings => true)}, contrast = #{R.ruby2R(contrast, :strings => true)}, log2=#{ R.ruby2R log2 }, outfile = #{R.ruby2R outfile}, key.field = #{R.ruby2R key_field}, two.channel = #{R.ruby2R two_channel})")
|
76
76
|
#end
|
77
77
|
#def self.barcode(datafile, outfile, factor = 2)
|
78
|
-
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.
|
78
|
+
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
79
79
|
# GE.run_R("rbbt.GE.barcode(#{ R.ruby2R datafile }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })")
|
80
80
|
#end
|
@@ -13,7 +13,7 @@ class KnowledgeBase
|
|
13
13
|
|
14
14
|
return matrix if RbbtMatrix === matrix
|
15
15
|
|
16
|
-
Path.setup(matrix) if not Path === matrix and File.
|
16
|
+
Path.setup(matrix) if not Path === matrix and File.exist? matrix
|
17
17
|
|
18
18
|
raise "Registered matrix is strange: #{Misc.fingerprint matrix}" unless Path === matrix
|
19
19
|
|
data/lib/rbbt/plots/bar.rb
CHANGED
@@ -119,7 +119,7 @@ module BarPlot
|
|
119
119
|
options = Misc.add_defaults options, :width => [options[:total], MAX_WIDTH].min, :height => 20, :background => PNG::Color::White
|
120
120
|
width, height, background, canvas = Misc.process_options options, :width, :height, :background, :canvas
|
121
121
|
|
122
|
-
canvas ||= if options[:update] and options[:filename] and File.
|
122
|
+
canvas ||= if options[:update] and options[:filename] and File.exist? options[:filename]
|
123
123
|
PNG.load_file options[:filename]
|
124
124
|
else
|
125
125
|
PNG::Canvas.new width, height, get_color(background)
|
data/lib/rbbt/stan.rb
CHANGED
@@ -130,7 +130,7 @@ print(fit)
|
|
130
130
|
erase = true
|
131
131
|
end
|
132
132
|
|
133
|
-
FileUtils.mkdir_p directory unless File.
|
133
|
+
FileUtils.mkdir_p directory unless File.exist? directory
|
134
134
|
input_directory = File.join(directory, 'inputs')
|
135
135
|
parameter_chains = File.join(directory, 'chains') unless erase
|
136
136
|
summary = File.join(directory, 'summary') unless erase
|
@@ -118,7 +118,7 @@ module TSV
|
|
118
118
|
|
119
119
|
field_pos = fields.collect{|f| self.fields.index f}.compact
|
120
120
|
persistence_path = self.respond_to?(:persistence_path)? self.persistence_path : nil
|
121
|
-
Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:background => background, :rename => rename, :persistence_path => persistence_path}) do
|
121
|
+
Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:fields => fields, :background => background, :rename => rename, :persistence_path => persistence_path}) do
|
122
122
|
data ||= {}
|
123
123
|
|
124
124
|
with_unnamed do
|
@@ -152,6 +152,7 @@ module TSV
|
|
152
152
|
|
153
153
|
end
|
154
154
|
|
155
|
+
|
155
156
|
if rename
|
156
157
|
Log.debug("Using renames during annotation counts")
|
157
158
|
Hash[*data.keys.zip(data.values.collect{|l| l.collect{|e| rename.include?(e)? rename[e] : e }.uniq.length }).flatten]
|
@@ -0,0 +1,50 @@
|
|
1
|
+
require 'rbbt/vector/model/huggingface'
|
2
|
+
# Huggingface model configured for the "MaskedLM" task.
#
# Training input is a list of texts containing "[MASK]" placeholders plus,
# for each text, the list of tokens that should fill those placeholders in
# order of appearance.
class MaskedLMModel < HuggingfaceModel

  # checkpoint    - checkpoint forwarded to HuggingfaceModel
  # dir           - model directory forwarded to HuggingfaceModel (optional)
  # model_options - options hash; :max_length defaults to 128 and controls
  #                 the padded/truncated length of every training example
  def initialize(checkpoint, dir = nil, model_options = {})

    model_options = Misc.add_defaults model_options, :max_length => 128
    super("MaskedLM", checkpoint, dir, model_options)

    train_model do |texts,labels|
      model, tokenizer = self.init
      max_length = @model_options[:max_length]
      mask_id = tokenizer.mask_token_id

      dataset = []
      texts.zip(labels).each do |text,label_values|
        # Park every mask under a temporary marker so they can be expanded
        # one at a time, each to as many [MASK] tokens as its label needs.
        fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
        label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }
        label_tokens.each do |ids|
          ids = [ids] unless Array === ids
          fixed_text.sub!("[PENDINGMASK]", "[MASK]" * ids.length)
        end

        # max_length must be passed explicitly: without it both
        # padding: "max_length" and truncation fall back to the model
        # maximum and the :max_length option would be ignored.
        tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length", max_length: max_length)
        input_ids = tokenized_text["input_ids"].to_a
        attention_mask = tokenized_text["attention_mask"].to_a

        # Labels are -100 everywhere except at mask positions, which take
        # the expected token ids in order.
        all_label_tokens = label_tokens.flatten
        label_ids = input_ids.collect do |id|
          if id == mask_id
            all_label_tokens.shift
          else
            -100
          end
        end
        dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
      end

      # One JSON document per line
      dataset_file = File.join(@directory, 'dataset.json')
      Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")

      training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
      data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
      RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)

      model.save_pretrained(@model_path) if @model_path
      tokenizer.save_pretrained(@model_path) if @model_path
    end

  end
end
|
@@ -1,50 +1,68 @@
|
|
1
|
-
require 'rbbt/vector/model'
|
2
|
-
require 'rbbt/util/python'
|
1
|
+
require 'rbbt/vector/model/torch'
|
3
2
|
|
4
|
-
|
5
|
-
RbbtPython.init_rbbt
|
3
|
+
class HuggingfaceModel < TorchModel
|
6
4
|
|
7
|
-
|
8
|
-
|
9
|
-
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
5
|
+
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
|
10
6
|
|
11
7
|
if labels
|
8
|
+
labels = case class_labels
|
9
|
+
when Array
|
10
|
+
labels.collect{|l| class_labels.index l}
|
11
|
+
when Hash
|
12
|
+
inverse_class_labels = {}
|
13
|
+
class_labels.each{|c,l| inverse_class_labels[l] = c }
|
14
|
+
labels.collect{|l| inverse_class_labels[l]}
|
15
|
+
else
|
16
|
+
labels
|
17
|
+
end
|
18
|
+
|
12
19
|
Open.write(tsv_dataset_file) do |ffile|
|
13
20
|
ffile.puts ["label", "text"].flatten * "\t"
|
14
21
|
elements.zip(labels).each do |element,label|
|
22
|
+
element = element.gsub("\n", " ")
|
15
23
|
ffile.puts [label, element].flatten * "\t"
|
16
24
|
end
|
25
|
+
ffile.sync
|
17
26
|
end
|
18
27
|
else
|
19
28
|
Open.write(tsv_dataset_file) do |ffile|
|
20
29
|
ffile.puts ["text"].flatten * "\t"
|
21
|
-
elements.each
|
30
|
+
elements.each do |element|
|
31
|
+
element = element.gsub("\n", " ")
|
32
|
+
ffile.puts element
|
33
|
+
end
|
34
|
+
ffile.sync
|
22
35
|
end
|
23
36
|
end
|
24
37
|
|
25
38
|
tsv_dataset_file
|
26
39
|
end
|
27
40
|
|
28
|
-
def initialize(task, checkpoint,
|
29
|
-
|
30
|
-
options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
|
31
|
-
super(*args)
|
32
|
-
@model_options ||= {}
|
33
|
-
@model_options.merge!(options)
|
41
|
+
def initialize(task, checkpoint, dir = nil, model_options = {})
|
42
|
+
super(dir, model_options)
|
34
43
|
|
35
|
-
|
36
|
-
checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
|
44
|
+
@model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint
|
37
45
|
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
46
|
+
init_model do
|
47
|
+
checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]
|
48
|
+
model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
|
49
|
+
@model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:model_args]) || {}))
|
50
|
+
tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
|
51
|
+
@model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))
|
52
|
+
|
53
|
+
[model, tokenizer]
|
54
|
+
end
|
55
|
+
|
56
|
+
eval_model do |texts,is_list|
|
57
|
+
model, tokenizer = self.init
|
58
|
+
|
59
|
+
if is_list || @model_options[:task] == "MaskedLM"
|
60
|
+
texts = [texts] if ! is_list
|
43
61
|
|
44
62
|
if @model_options.include?(:locate_tokens)
|
45
63
|
locate_tokens = @model_options[:locate_tokens]
|
46
64
|
elsif @model_options[:task] == "MaskedLM"
|
47
|
-
@model_options[:locate_tokens] = locate_tokens =
|
65
|
+
@model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
|
48
66
|
end
|
49
67
|
|
50
68
|
if @directory
|
@@ -61,18 +79,17 @@ class HuggingfaceModel < VectorModel
|
|
61
79
|
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
62
80
|
|
63
81
|
begin
|
64
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model,
|
82
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, model, tokenizer, training_args_obj, dataset_file, locate_tokens)
|
65
83
|
ensure
|
66
84
|
Open.rm_rf tmpdir if tmpdir
|
67
85
|
end
|
68
86
|
else
|
69
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model,
|
87
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
|
70
88
|
end
|
71
89
|
end
|
72
90
|
|
73
|
-
train_model do |
|
74
|
-
|
75
|
-
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
91
|
+
train_model do |texts,labels|
|
92
|
+
model, tokenizer = self.init
|
76
93
|
|
77
94
|
if @directory
|
78
95
|
tsv_file = File.join(@directory, 'dataset.tsv')
|
@@ -85,17 +102,19 @@ class HuggingfaceModel < VectorModel
|
|
85
102
|
end
|
86
103
|
|
87
104
|
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
88
|
-
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
|
105
|
+
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels, @model_options[:class_labels])
|
89
106
|
|
90
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :train_model,
|
107
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
|
91
108
|
|
92
109
|
Open.rm_rf tmpdir if tmpdir
|
93
110
|
|
94
|
-
|
95
|
-
|
111
|
+
model.save_pretrained(@model_path) if @model_path
|
112
|
+
tokenizer.save_pretrained(@model_path) if @model_path
|
96
113
|
end
|
97
114
|
|
98
|
-
post_process do |result|
|
115
|
+
post_process do |result,is_list|
|
116
|
+
model, tokenizer = self.init
|
117
|
+
|
99
118
|
if result.respond_to?(:predictions)
|
100
119
|
single = false
|
101
120
|
predictions = result.predictions
|
@@ -137,25 +156,25 @@ class HuggingfaceModel < VectorModel
|
|
137
156
|
best_token
|
138
157
|
end
|
139
158
|
|
140
|
-
best.collect{|b|
|
159
|
+
best.collect{|b| tokenizer.decode(b) } * "|"
|
141
160
|
end
|
142
161
|
Array === locate_tokens ? item_masks : item_masks.first
|
143
162
|
end
|
144
163
|
else
|
145
|
-
|
164
|
+
predictions
|
146
165
|
end
|
147
166
|
|
148
|
-
single ? result.first : result
|
167
|
+
(! is_list || single) && Array === result ? result.first : result
|
149
168
|
end
|
150
169
|
|
151
170
|
|
152
|
-
save_models if @
|
171
|
+
save_models if @model_path
|
153
172
|
end
|
154
173
|
|
155
174
|
def reset_model
|
156
175
|
@model, @tokenizer = nil
|
157
|
-
Open.
|
176
|
+
Open.rm_rf @model_path
|
177
|
+
init
|
158
178
|
end
|
159
|
-
|
160
179
|
end
|
161
180
|
|
@@ -0,0 +1,35 @@
|
|
1
|
+
require 'rbbt/vector/model/torch'
|
2
|
+
|
3
|
+
# TorchModel whose underlying module is a Python class instantiated through
# RbbtPython and trained with an externally supplied trainer.
#
# Callers must assign +loader+ and +trainer+ (and optionally +val_loader+)
# before training; the features/labels handed to the training hook are not
# used.
class PytorchLightningModel < TorchModel
  attr_accessor :loader, :val_loader, :trainer

  # module_name   - Python module that defines the model class
  # class_name    - name of the class inside that module
  # dir           - model directory (optional)
  # model_options - options hash; :model_args is passed to the constructor
  #                 of the Python class
  def initialize(module_name, class_name, dir = nil, model_options = {})
    super(dir, model_options)
    @module_name = module_name
    @class_name = class_name

    init_model do
      RbbtPython.pyimport @module_name
      RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {})
    end

    train_model do |features,labels|
      model = init
      raise "Use the loader" if @loader.nil?
      raise "Use the trainer" if @trainer.nil?

      trainer.fit(model, @loader, @val_loader)
    end

    eval_model do |features,list|
      # Obtain the model through init, as the training hook does, instead of
      # relying on the accessor having been populated beforehand.
      model = init

      # A single feature vector is wrapped so the model always gets a batch
      batch = list ? features : [features]
      model.call(RbbtPython.call_method(:torch, :tensor, batch))
    end

  end
end
|
33
|
+
|
34
|
+
if __FILE__ == $0
|
35
|
+
end
|
@@ -28,12 +28,12 @@ class SpaCyModel < VectorModel
|
|
28
28
|
|
29
29
|
super(dir)
|
30
30
|
|
31
|
-
@train_model = Proc.new do |
|
31
|
+
@train_model = Proc.new do |features, labels|
|
32
32
|
texts = features
|
33
33
|
docs = []
|
34
34
|
unique_labels = labels.uniq
|
35
|
-
tmpconfig = File.join(
|
36
|
-
tmptrain = File.join(
|
35
|
+
tmpconfig = File.join(@model_path, 'config')
|
36
|
+
tmptrain = File.join(@model_path, 'train.spacy')
|
37
37
|
SpaCy.config(@config, tmpconfig)
|
38
38
|
|
39
39
|
bar = bar(features.length, "Training documents into spacy format")
|
@@ -54,20 +54,22 @@ class SpaCyModel < VectorModel
|
|
54
54
|
end
|
55
55
|
|
56
56
|
gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
|
57
|
-
CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{
|
57
|
+
CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{@model_path} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
|
58
58
|
end
|
59
59
|
|
60
|
-
@eval_model = Proc.new do |
|
60
|
+
@eval_model = Proc.new do |features,list|
|
61
61
|
texts = features
|
62
62
|
texts = [texts] unless list
|
63
63
|
|
64
|
+
model_path = @model_path
|
65
|
+
|
64
66
|
docs = []
|
65
67
|
bar = bar(features.length, "Evaluating model")
|
66
68
|
SpaCyModel.spacy do
|
67
69
|
gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
|
68
70
|
gpu = gpu.to_i if gpu && gpu != ""
|
69
71
|
spacy.require_gpu(gpu) if gpu
|
70
|
-
nlp = spacy.load("#{
|
72
|
+
nlp = spacy.load("#{model_path}/model-best")
|
71
73
|
|
72
74
|
docs = nlp.pipe(texts)
|
73
75
|
RbbtPython.collect docs, :bar => bar do |d|
|
@@ -75,14 +77,6 @@ class SpaCyModel < VectorModel
|
|
75
77
|
d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
|
76
78
|
end
|
77
79
|
end
|
78
|
-
#nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
|
79
|
-
#Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
|
80
|
-
# texts.collect do |text|
|
81
|
-
# cats = nlp.(text).cats
|
82
|
-
# bar.tick
|
83
|
-
# cats.sort_by{|l,v| v.to_f }.last.first
|
84
|
-
# end
|
85
|
-
#end
|
86
80
|
end
|
87
81
|
end
|
88
82
|
end
|
@@ -25,7 +25,7 @@ class TensorFlowModel < VectorModel
|
|
25
25
|
|
26
26
|
super(dir)
|
27
27
|
|
28
|
-
@train_model = Proc.new do |
|
28
|
+
@train_model = Proc.new do |features, labels|
|
29
29
|
tensorflow do
|
30
30
|
features = tensorflow.convert_to_tensor(features)
|
31
31
|
labels = tensorflow.convert_to_tensor(labels)
|
@@ -33,16 +33,17 @@ class TensorFlowModel < VectorModel
|
|
33
33
|
@graph ||= keras_graph
|
34
34
|
@graph.compile(**@compile_options)
|
35
35
|
@graph.fit(features, labels, :epochs => @epochs, :verbose => true)
|
36
|
-
@graph.save(
|
36
|
+
@graph.save(@model_path)
|
37
37
|
end
|
38
38
|
|
39
|
-
@eval_model = Proc.new do |
|
39
|
+
@eval_model = Proc.new do |features|
|
40
40
|
tensorflow do
|
41
41
|
features = tensorflow.convert_to_tensor(features)
|
42
42
|
end
|
43
|
+
model_path = @model_path
|
44
|
+
graph = @graph ||= keras.models.load_model(model_path)
|
43
45
|
keras do
|
44
|
-
|
45
|
-
indices = @graph.predict(features, :verbose => false).tolist()
|
46
|
+
indices = graph.predict(features, :verbose => false).tolist()
|
46
47
|
labels = indices.collect{|p| p.length > 1 ? p.index(p.max): p.first }
|
47
48
|
labels
|
48
49
|
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
require 'rbbt/vector/model'
|
2
|
+
require 'rbbt/util/python'
|
3
|
+
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
5
|
+
RbbtPython.init_rbbt
|
6
|
+
|
7
|
+
# VectorModel subclass with helpers for navigating and freezing the layers
# of a Python model accessed through PyCall.
class TorchModel < VectorModel

  attr_accessor :model

  # Follows a dot-separated attribute path (for instance "a.b.c") starting
  # at +model+ and returns whatever attribute sits at the end of it.
  def self.get_layer(model, layer)
    layer.split(".").inject(model) do |current, attribute|
      PyCall.getattr(current, attribute.to_sym)
    end
  end

  # Returns the +weight+ attribute of the layer found at the dotted path.
  def self.get_weights(model, layer)
    target = get_layer(model, layer)
    PyCall.getattr(target, :weight)
  end

  # Sets requires_grad to false on the layer's +weight+ and then does the
  # same, recursively, for everything yielded by layer.children.
  def self.freeze(layer)
    begin
      weight = PyCall.getattr(layer, :weight)
      weight.requires_grad = false
    rescue
      # Any failure here is deliberately ignored so the recursion below
      # still reaches the children.
    end

    RbbtPython.iterate(layer.children) do |child|
      freeze(child)
    end
  end

  # Resolves the dotted path against +model+ and freezes that layer.
  def self.freeze_layer(model, layer)
    freeze(get_layer(model, layer))
  end

  # Forwards both arguments unchanged to the parent constructor.
  def initialize(dir, model_options = {})
    super
  end
end
|
@@ -9,4 +9,22 @@ class VectorModel
|
|
9
9
|
@bar.init
|
10
10
|
@bar
|
11
11
|
end
|
12
|
+
|
13
|
+
# Downsamples the training data so that no label keeps more samples than
# the least frequent one.
#
# Reads the parallel arrays @labels and @features, shuffles the
# label/feature pairs and keeps at most `min` pairs per label, where `min`
# is the smallest value in the hash returned by Misc.counts(@labels)
# (expected to map each label to its number of occurrences). Both instance
# variables are replaced with the shuffled, balanced selection; label and
# feature positions stay aligned.
def balance_labels
  counts = Misc.counts(@labels)
  min = counts.values.min

  used = Hash.new(0)
  new_labels = []
  new_features = []
  @labels.zip(@features).shuffle.each do |label, features|
    # >= rather than >: a label stops being accepted once it has exactly
    # `min` samples. The strict comparison let larger classes keep
    # min + 1 samples, leaving the set unbalanced.
    next if used[label] >= min
    used[label] += 1
    new_labels << label
    new_features << features
  end
  @labels = new_labels
  @features = new_features
end
|
12
30
|
end
|