rbbt-dm 1.2.6 → 1.2.9
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +57 -38
- data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -14
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch.rb +37 -0
- data/lib/rbbt/vector/model/util.rb +18 -0
- data/lib/rbbt/vector/model.rb +100 -56
- data/python/rbbt_dm/__init__.py +48 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +57 -26
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +30 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +26 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: db2cbab94e21fd2ca67f7306fa9941b59cbfb2865382e5439edf6313f50309e7
|
4
|
+
data.tar.gz: f4acf3651daa90ef23bc454c62df68e208976a977d51e2e85d02558d48897187
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7786759636450821aabca306cd210ab3d201e094b81bb70052d57d7bfb6e4de73a198576fe4b002487baf7997138f9c53b91644632cd12cc79b40ff62141a70a
|
7
|
+
data.tar.gz: 9870745068a897909170f3a6187e520e6530b121f8ad4ab40224c3369a16a8bb1c1e55bb8ca8fd892943ec1fcead9f4661e49c3d97d07d3740236a1ec4f69a34
|
data/lib/rbbt/matrix/barcode.rb
CHANGED
@@ -3,7 +3,7 @@ require 'rbbt/util/R'
|
|
3
3
|
class RbbtMatrix
|
4
4
|
def barcode(outfile, factor = 2)
|
5
5
|
|
6
|
-
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
|
6
|
+
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
7
7
|
cmd =<<-EOF
|
8
8
|
source('#{Rbbt.share.R['barcode.R'].find}')
|
9
9
|
rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })
|
@@ -49,7 +49,7 @@ rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.r
|
|
49
49
|
|
50
50
|
clusters = Array === clusters ? clusters : (2..clusters).to_a
|
51
51
|
|
52
|
-
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
|
52
|
+
FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
53
53
|
cmd =<<-EOF
|
54
54
|
source('#{Rbbt.share.R['barcode.R'].find}')
|
55
55
|
rbbt.GE.activity_cluster(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{R.ruby2R key_field}, #{R.ruby2R clusters})
|
data/lib/rbbt/matrix/differential.rb
CHANGED
@@ -45,7 +45,7 @@ class RbbtMatrix
|
|
45
45
|
end
|
46
46
|
|
47
47
|
file = file.find if Path === file
|
48
|
-
FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exists? File.dirname(file)
|
48
|
+
FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exist? File.dirname(file)
|
49
49
|
|
50
50
|
cmd = <<-EOS
|
51
51
|
|
@@ -71,10 +71,10 @@ end
|
|
71
71
|
|
72
72
|
|
73
73
|
#def self.analyze(datafile, main, contrast = nil, log2 = false, outfile = nil, key_field = nil, two_channel = nil)
|
74
|
-
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
|
74
|
+
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
75
75
|
# GE.run_R("rbbt.GE.process(#{ R.ruby2R datafile }, main = #{R.ruby2R(main, :strings => true)}, contrast = #{R.ruby2R(contrast, :strings => true)}, log2=#{ R.ruby2R log2 }, outfile = #{R.ruby2R outfile}, key.field = #{R.ruby2R key_field}, two.channel = #{R.ruby2R two_channel})")
|
76
76
|
#end
|
77
77
|
#def self.barcode(datafile, outfile, factor = 2)
|
78
|
-
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
|
78
|
+
# FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
|
79
79
|
# GE.run_R("rbbt.GE.barcode(#{ R.ruby2R datafile }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })")
|
80
80
|
#end
|
data/lib/rbbt/matrix/knowledge_base.rb
CHANGED
@@ -13,7 +13,7 @@ class KnowledgeBase
|
|
13
13
|
|
14
14
|
return matrix if RbbtMatrix === matrix
|
15
15
|
|
16
|
-
Path.setup(matrix) if not Path === matrix and File.exists? matrix
|
16
|
+
Path.setup(matrix) if not Path === matrix and File.exist? matrix
|
17
17
|
|
18
18
|
raise "Registered matrix is strange: #{Misc.fingerprint matrix}" unless Path === matrix
|
19
19
|
|
data/lib/rbbt/plots/bar.rb
CHANGED
@@ -119,7 +119,7 @@ module BarPlot
|
|
119
119
|
options = Misc.add_defaults options, :width => [options[:total], MAX_WIDTH].min, :height => 20, :background => PNG::Color::White
|
120
120
|
width, height, background, canvas = Misc.process_options options, :width, :height, :background, :canvas
|
121
121
|
|
122
|
-
canvas ||= if options[:update] and options[:filename] and File.exists? options[:filename]
|
122
|
+
canvas ||= if options[:update] and options[:filename] and File.exist? options[:filename]
|
123
123
|
PNG.load_file options[:filename]
|
124
124
|
else
|
125
125
|
PNG::Canvas.new width, height, get_color(background)
|
data/lib/rbbt/stan.rb
CHANGED
@@ -130,7 +130,7 @@ print(fit)
|
|
130
130
|
erase = true
|
131
131
|
end
|
132
132
|
|
133
|
-
FileUtils.mkdir_p directory unless File.exists? directory
|
133
|
+
FileUtils.mkdir_p directory unless File.exist? directory
|
134
134
|
input_directory = File.join(directory, 'inputs')
|
135
135
|
parameter_chains = File.join(directory, 'chains') unless erase
|
136
136
|
summary = File.join(directory, 'summary') unless erase
|
data/lib/rbbt/statistics/hypergeometric.rb
CHANGED
@@ -118,7 +118,7 @@ module TSV
|
|
118
118
|
|
119
119
|
field_pos = fields.collect{|f| self.fields.index f}.compact
|
120
120
|
persistence_path = self.respond_to?(:persistence_path)? self.persistence_path : nil
|
121
|
-
Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:background => background, :rename => rename, :persistence_path => persistence_path}) do
|
121
|
+
Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:fields => fields, :background => background, :rename => rename, :persistence_path => persistence_path}) do
|
122
122
|
data ||= {}
|
123
123
|
|
124
124
|
with_unnamed do
|
@@ -152,6 +152,7 @@ module TSV
|
|
152
152
|
|
153
153
|
end
|
154
154
|
|
155
|
+
|
155
156
|
if rename
|
156
157
|
Log.debug("Using renames during annotation counts")
|
157
158
|
Hash[*data.keys.zip(data.values.collect{|l| l.collect{|e| rename.include?(e)? rename[e] : e }.uniq.length }).flatten]
|
data/lib/rbbt/vector/model/huggingface/masked_lm.rb
ADDED
@@ -0,0 +1,50 @@
|
|
1
|
+
require 'rbbt/vector/model/huggingface'
|
2
|
+
class MaskedLMModel < HuggingfaceModel
|
3
|
+
|
4
|
+
def initialize(checkpoint, dir = nil, model_options = {})
|
5
|
+
|
6
|
+
model_options = Misc.add_defaults model_options, :max_length => 128
|
7
|
+
super("MaskedLM", checkpoint, dir, model_options)
|
8
|
+
|
9
|
+
train_model do |texts,labels|
|
10
|
+
model, tokenizer = self.init
|
11
|
+
max_length = @model_options[:max_length]
|
12
|
+
mask_id = tokenizer.mask_token_id
|
13
|
+
|
14
|
+
dataset = []
|
15
|
+
texts.zip(labels).each do |text,label_values|
|
16
|
+
fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
|
17
|
+
label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }
|
18
|
+
label_tokens.each do |ids|
|
19
|
+
ids = [ids] unless Array === ids
|
20
|
+
fixed_text.sub!("[PENDINGMASK]", "[MASK]" * ids.length)
|
21
|
+
end
|
22
|
+
|
23
|
+
tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length")
|
24
|
+
input_ids = tokenized_text["input_ids"].to_a
|
25
|
+
attention_mask = tokenized_text["attention_mask"].to_a
|
26
|
+
|
27
|
+
all_label_tokens = label_tokens.flatten
|
28
|
+
label_ids = input_ids.collect do |id|
|
29
|
+
if id == mask_id
|
30
|
+
all_label_tokens.shift
|
31
|
+
else
|
32
|
+
-100
|
33
|
+
end
|
34
|
+
end
|
35
|
+
dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
|
36
|
+
end
|
37
|
+
|
38
|
+
dataset_file = File.join(@directory, 'dataset.json')
|
39
|
+
Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")
|
40
|
+
|
41
|
+
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
|
42
|
+
data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
|
43
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)
|
44
|
+
|
45
|
+
model.save_pretrained(@model_path) if @model_path
|
46
|
+
tokenizer.save_pretrained(@model_path) if @model_path
|
47
|
+
end
|
48
|
+
|
49
|
+
end
|
50
|
+
end
|
data/lib/rbbt/vector/model/huggingface.rb
CHANGED
@@ -1,50 +1,68 @@
|
|
1
|
-
require 'rbbt/vector/model'
|
2
|
-
require 'rbbt/util/python'
|
1
|
+
require 'rbbt/vector/model/torch'
|
3
2
|
|
4
|
-
|
5
|
-
RbbtPython.init_rbbt
|
3
|
+
class HuggingfaceModel < TorchModel
|
6
4
|
|
7
|
-
|
8
|
-
|
9
|
-
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
|
5
|
+
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
|
10
6
|
|
11
7
|
if labels
|
8
|
+
labels = case class_labels
|
9
|
+
when Array
|
10
|
+
labels.collect{|l| class_labels.index l}
|
11
|
+
when Hash
|
12
|
+
inverse_class_labels = {}
|
13
|
+
class_labels.each{|c,l| inverse_class_labels[l] = c }
|
14
|
+
labels.collect{|l| inverse_class_labels[l]}
|
15
|
+
else
|
16
|
+
labels
|
17
|
+
end
|
18
|
+
|
12
19
|
Open.write(tsv_dataset_file) do |ffile|
|
13
20
|
ffile.puts ["label", "text"].flatten * "\t"
|
14
21
|
elements.zip(labels).each do |element,label|
|
22
|
+
element = element.gsub("\n", " ")
|
15
23
|
ffile.puts [label, element].flatten * "\t"
|
16
24
|
end
|
25
|
+
ffile.sync
|
17
26
|
end
|
18
27
|
else
|
19
28
|
Open.write(tsv_dataset_file) do |ffile|
|
20
29
|
ffile.puts ["text"].flatten * "\t"
|
21
|
-
elements.each
|
30
|
+
elements.each do |element|
|
31
|
+
element = element.gsub("\n", " ")
|
32
|
+
ffile.puts element
|
33
|
+
end
|
34
|
+
ffile.sync
|
22
35
|
end
|
23
36
|
end
|
24
37
|
|
25
38
|
tsv_dataset_file
|
26
39
|
end
|
27
40
|
|
28
|
-
def initialize(task, checkpoint,
|
29
|
-
|
30
|
-
options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
|
31
|
-
super(*args)
|
32
|
-
@model_options ||= {}
|
33
|
-
@model_options.merge!(options)
|
41
|
+
def initialize(task, checkpoint, dir = nil, model_options = {})
|
42
|
+
super(dir, model_options)
|
34
43
|
|
35
|
-
|
36
|
-
checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
|
44
|
+
@model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint
|
37
45
|
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
46
|
+
init_model do
|
47
|
+
checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]
|
48
|
+
model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
|
49
|
+
@model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:model_args]) || {}))
|
50
|
+
tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
|
51
|
+
@model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))
|
52
|
+
|
53
|
+
[model, tokenizer]
|
54
|
+
end
|
55
|
+
|
56
|
+
eval_model do |texts,is_list|
|
57
|
+
model, tokenizer = self.init
|
58
|
+
|
59
|
+
if is_list || @model_options[:task] == "MaskedLM"
|
60
|
+
texts = [texts] if ! is_list
|
43
61
|
|
44
62
|
if @model_options.include?(:locate_tokens)
|
45
63
|
locate_tokens = @model_options[:locate_tokens]
|
46
64
|
elsif @model_options[:task] == "MaskedLM"
|
47
|
-
@model_options[:locate_tokens] = locate_tokens =
|
65
|
+
@model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
|
48
66
|
end
|
49
67
|
|
50
68
|
if @directory
|
@@ -61,18 +79,17 @@ class HuggingfaceModel < VectorModel
|
|
61
79
|
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
62
80
|
|
63
81
|
begin
|
64
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model,
|
82
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, model, tokenizer, training_args_obj, dataset_file, locate_tokens)
|
65
83
|
ensure
|
66
84
|
Open.rm_rf tmpdir if tmpdir
|
67
85
|
end
|
68
86
|
else
|
69
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model,
|
87
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
|
70
88
|
end
|
71
89
|
end
|
72
90
|
|
73
|
-
train_model do |
|
74
|
-
|
75
|
-
@model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
|
91
|
+
train_model do |texts,labels|
|
92
|
+
model, tokenizer = self.init
|
76
93
|
|
77
94
|
if @directory
|
78
95
|
tsv_file = File.join(@directory, 'dataset.tsv')
|
@@ -85,17 +102,19 @@ class HuggingfaceModel < VectorModel
|
|
85
102
|
end
|
86
103
|
|
87
104
|
training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
|
88
|
-
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
|
105
|
+
dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels, @model_options[:class_labels])
|
89
106
|
|
90
|
-
RbbtPython.call_method("rbbt_dm.huggingface", :train_model,
|
107
|
+
RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
|
91
108
|
|
92
109
|
Open.rm_rf tmpdir if tmpdir
|
93
110
|
|
94
|
-
|
95
|
-
|
111
|
+
model.save_pretrained(@model_path) if @model_path
|
112
|
+
tokenizer.save_pretrained(@model_path) if @model_path
|
96
113
|
end
|
97
114
|
|
98
|
-
post_process do |result|
|
115
|
+
post_process do |result,is_list|
|
116
|
+
model, tokenizer = self.init
|
117
|
+
|
99
118
|
if result.respond_to?(:predictions)
|
100
119
|
single = false
|
101
120
|
predictions = result.predictions
|
@@ -137,25 +156,25 @@ class HuggingfaceModel < VectorModel
|
|
137
156
|
best_token
|
138
157
|
end
|
139
158
|
|
140
|
-
best.collect{|b|
|
159
|
+
best.collect{|b| tokenizer.decode(b) } * "|"
|
141
160
|
end
|
142
161
|
Array === locate_tokens ? item_masks : item_masks.first
|
143
162
|
end
|
144
163
|
else
|
145
|
-
|
164
|
+
predictions
|
146
165
|
end
|
147
166
|
|
148
|
-
single ? result.first : result
|
167
|
+
(! is_list || single) && Array === result ? result.first : result
|
149
168
|
end
|
150
169
|
|
151
170
|
|
152
|
-
save_models if @
|
171
|
+
save_models if @model_path
|
153
172
|
end
|
154
173
|
|
155
174
|
def reset_model
|
156
175
|
@model, @tokenizer = nil
|
157
|
-
Open.
|
176
|
+
Open.rm_rf @model_path
|
177
|
+
init
|
158
178
|
end
|
159
|
-
|
160
179
|
end
|
161
180
|
|
data/lib/rbbt/vector/model/pytorch_lightning.rb
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
require 'rbbt/vector/model/torch'
|
2
|
+
|
3
|
+
class PytorchLightningModel < TorchModel
|
4
|
+
attr_accessor :loader, :val_loader, :trainer
|
5
|
+
def initialize(module_name, class_name, dir = nil, model_options = {})
|
6
|
+
super(dir, model_options)
|
7
|
+
@module_name = module_name
|
8
|
+
@class_name = class_name
|
9
|
+
|
10
|
+
init_model do
|
11
|
+
RbbtPython.pyimport @module_name
|
12
|
+
RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {})
|
13
|
+
end
|
14
|
+
|
15
|
+
train_model do |features,labels|
|
16
|
+
model = init
|
17
|
+
raise "Use the loader" if @loader.nil?
|
18
|
+
raise "Use the trainer" if @trainer.nil?
|
19
|
+
|
20
|
+
trainer.fit(model, @loader, @val_loader)
|
21
|
+
end
|
22
|
+
|
23
|
+
eval_model do |features,list|
|
24
|
+
if list
|
25
|
+
model.call(RbbtPython.call_method(:torch, :tensor, features))
|
26
|
+
else
|
27
|
+
model.call(RbbtPython.call_method(:torch, :tensor, [features]))
|
28
|
+
end
|
29
|
+
end
|
30
|
+
|
31
|
+
end
|
32
|
+
end
|
33
|
+
|
34
|
+
if __FILE__ == $0
|
35
|
+
end
|
data/lib/rbbt/vector/model/spaCy.rb
CHANGED
@@ -28,12 +28,12 @@ class SpaCyModel < VectorModel
|
|
28
28
|
|
29
29
|
super(dir)
|
30
30
|
|
31
|
-
@train_model = Proc.new do |
|
31
|
+
@train_model = Proc.new do |features, labels|
|
32
32
|
texts = features
|
33
33
|
docs = []
|
34
34
|
unique_labels = labels.uniq
|
35
|
-
tmpconfig = File.join(
|
36
|
-
tmptrain = File.join(
|
35
|
+
tmpconfig = File.join(@model_path, 'config')
|
36
|
+
tmptrain = File.join(@model_path, 'train.spacy')
|
37
37
|
SpaCy.config(@config, tmpconfig)
|
38
38
|
|
39
39
|
bar = bar(features.length, "Training documents into spacy format")
|
@@ -54,20 +54,22 @@ class SpaCyModel < VectorModel
|
|
54
54
|
end
|
55
55
|
|
56
56
|
gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
|
57
|
-
CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{
|
57
|
+
CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{@model_path} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
|
58
58
|
end
|
59
59
|
|
60
|
-
@eval_model = Proc.new do |
|
60
|
+
@eval_model = Proc.new do |features,list|
|
61
61
|
texts = features
|
62
62
|
texts = [texts] unless list
|
63
63
|
|
64
|
+
model_path = @model_path
|
65
|
+
|
64
66
|
docs = []
|
65
67
|
bar = bar(features.length, "Evaluating model")
|
66
68
|
SpaCyModel.spacy do
|
67
69
|
gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
|
68
70
|
gpu = gpu.to_i if gpu && gpu != ""
|
69
71
|
spacy.require_gpu(gpu) if gpu
|
70
|
-
nlp = spacy.load("#{
|
72
|
+
nlp = spacy.load("#{model_path}/model-best")
|
71
73
|
|
72
74
|
docs = nlp.pipe(texts)
|
73
75
|
RbbtPython.collect docs, :bar => bar do |d|
|
@@ -75,14 +77,6 @@ class SpaCyModel < VectorModel
|
|
75
77
|
d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
|
76
78
|
end
|
77
79
|
end
|
78
|
-
#nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
|
79
|
-
#Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
|
80
|
-
# texts.collect do |text|
|
81
|
-
# cats = nlp.(text).cats
|
82
|
-
# bar.tick
|
83
|
-
# cats.sort_by{|l,v| v.to_f }.last.first
|
84
|
-
# end
|
85
|
-
#end
|
86
80
|
end
|
87
81
|
end
|
88
82
|
end
|
data/lib/rbbt/vector/model/tensorflow.rb
CHANGED
@@ -25,7 +25,7 @@ class TensorFlowModel < VectorModel
|
|
25
25
|
|
26
26
|
super(dir)
|
27
27
|
|
28
|
-
@train_model = Proc.new do |
|
28
|
+
@train_model = Proc.new do |features, labels|
|
29
29
|
tensorflow do
|
30
30
|
features = tensorflow.convert_to_tensor(features)
|
31
31
|
labels = tensorflow.convert_to_tensor(labels)
|
@@ -33,16 +33,17 @@ class TensorFlowModel < VectorModel
|
|
33
33
|
@graph ||= keras_graph
|
34
34
|
@graph.compile(**@compile_options)
|
35
35
|
@graph.fit(features, labels, :epochs => @epochs, :verbose => true)
|
36
|
-
@graph.save(
|
36
|
+
@graph.save(@model_path)
|
37
37
|
end
|
38
38
|
|
39
|
-
@eval_model = Proc.new do |
|
39
|
+
@eval_model = Proc.new do |features|
|
40
40
|
tensorflow do
|
41
41
|
features = tensorflow.convert_to_tensor(features)
|
42
42
|
end
|
43
|
+
model_path = @model_path
|
44
|
+
graph = @graph ||= keras.models.load_model(model_path)
|
43
45
|
keras do
|
44
|
-
|
45
|
-
indices = @graph.predict(features, :verbose => false).tolist()
|
46
|
+
indices = graph.predict(features, :verbose => false).tolist()
|
46
47
|
labels = indices.collect{|p| p.length > 1 ? p.index(p.max): p.first }
|
47
48
|
labels
|
48
49
|
end
|
data/lib/rbbt/vector/model/torch.rb
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
require 'rbbt/vector/model'
|
2
|
+
require 'rbbt/util/python'
|
3
|
+
|
4
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
5
|
+
RbbtPython.init_rbbt
|
6
|
+
|
7
|
+
class TorchModel < VectorModel
|
8
|
+
|
9
|
+
attr_accessor :model
|
10
|
+
|
11
|
+
def self.get_layer(model, layer)
|
12
|
+
layer.split(".").inject(model){|acc,l| PyCall.getattr(acc, l.to_sym) }
|
13
|
+
end
|
14
|
+
|
15
|
+
def self.get_weights(model, layer)
|
16
|
+
PyCall.getattr(get_layer(model, layer), :weight)
|
17
|
+
end
|
18
|
+
|
19
|
+
def self.freeze(layer)
|
20
|
+
begin
|
21
|
+
PyCall.getattr(layer, :weight).requires_grad = false
|
22
|
+
rescue
|
23
|
+
end
|
24
|
+
RbbtPython.iterate(layer.children) do |layer|
|
25
|
+
freeze(layer)
|
26
|
+
end
|
27
|
+
end
|
28
|
+
|
29
|
+
def self.freeze_layer(model, layer)
|
30
|
+
layer = get_layer(model, layer)
|
31
|
+
freeze(layer)
|
32
|
+
end
|
33
|
+
|
34
|
+
def initialize(dir, model_options = {})
|
35
|
+
super(dir, model_options)
|
36
|
+
end
|
37
|
+
end
|
data/lib/rbbt/vector/model/util.rb
CHANGED
@@ -9,4 +9,22 @@ class VectorModel
|
|
9
9
|
@bar.init
|
10
10
|
@bar
|
11
11
|
end
|
12
|
+
|
13
|
+
def balance_labels
|
14
|
+
counts = Misc.counts(@labels)
|
15
|
+
min = counts.values.min
|
16
|
+
|
17
|
+
used = {}
|
18
|
+
new_labels = []
|
19
|
+
new_features = []
|
20
|
+
@labels.zip(@features).shuffle.each do |label, features|
|
21
|
+
used[label] ||= 0
|
22
|
+
next if used[label] > min
|
23
|
+
used[label] += 1
|
24
|
+
new_labels << label
|
25
|
+
new_features << features
|
26
|
+
end
|
27
|
+
@labels = new_labels
|
28
|
+
@features = new_features
|
29
|
+
end
|
12
30
|
end
|