rbbt-dm 1.2.6 → 1.2.9

Sign up to get free protection for your applications and to get access to all the features.
Files changed (34) hide show
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +57 -38
  10. data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
  11. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  12. data/lib/rbbt/vector/model/spaCy.rb +8 -14
  13. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  14. data/lib/rbbt/vector/model/torch.rb +37 -0
  15. data/lib/rbbt/vector/model/util.rb +18 -0
  16. data/lib/rbbt/vector/model.rb +100 -56
  17. data/python/rbbt_dm/__init__.py +48 -1
  18. data/python/rbbt_dm/atcold/__init__.py +0 -0
  19. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  20. data/python/rbbt_dm/atcold/spiral.py +27 -0
  21. data/python/rbbt_dm/huggingface.py +57 -26
  22. data/python/rbbt_dm/language_model.py +70 -0
  23. data/python/rbbt_dm/util.py +30 -0
  24. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  25. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  26. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  27. data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
  28. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  29. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  30. data/test/rbbt/vector/test_model.rb +25 -26
  31. data/test/test_helper.rb +13 -0
  32. metadata +26 -16
  33. data/lib/rbbt/tensorflow.rb +0 -43
  34. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9744ab9faeaf4f9cc04947eb11103dbf0694dda624f805a5c6be27bb22af81ce
4
- data.tar.gz: d3a3903aa276a69e20cbd71213286449db396ecf5f6a4b4d80a64ab299041fbb
3
+ metadata.gz: db2cbab94e21fd2ca67f7306fa9941b59cbfb2865382e5439edf6313f50309e7
4
+ data.tar.gz: f4acf3651daa90ef23bc454c62df68e208976a977d51e2e85d02558d48897187
5
5
  SHA512:
6
- metadata.gz: 263fb609b37522874426bcd79374760399b4a9aaab443ae6d74c727f2d148474dd71ee0b2cfda7a50131dafbc314f66352f4285562a75f62144d2e05ccd214c7
7
- data.tar.gz: 1e0426429a38028a19b3f8c955e975138199c791dad8691de7fb760a5cbec3304f19341906a4457563a90de965ee2f12a5a639b928866b46258af2507eeb39fa
6
+ metadata.gz: 7786759636450821aabca306cd210ab3d201e094b81bb70052d57d7bfb6e4de73a198576fe4b002487baf7997138f9c53b91644632cd12cc79b40ff62141a70a
7
+ data.tar.gz: 9870745068a897909170f3a6187e520e6530b121f8ad4ab40224c3369a16a8bb1c1e55bb8ca8fd892943ec1fcead9f4661e49c3d97d07d3740236a1ec4f69a34
@@ -3,7 +3,7 @@ require 'rbbt/util/R'
3
3
  class RbbtMatrix
4
4
  def barcode(outfile, factor = 2)
5
5
 
6
- FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
6
+ FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
7
7
  cmd =<<-EOF
8
8
  source('#{Rbbt.share.R['barcode.R'].find}')
9
9
  rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })
@@ -49,7 +49,7 @@ rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.r
49
49
 
50
50
  clusters = Array === clusters ? clusters : (2..clusters).to_a
51
51
 
52
- FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
52
+ FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
53
53
  cmd =<<-EOF
54
54
  source('#{Rbbt.share.R['barcode.R'].find}')
55
55
  rbbt.GE.activity_cluster(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{R.ruby2R key_field}, #{R.ruby2R clusters})
@@ -45,7 +45,7 @@ class RbbtMatrix
45
45
  end
46
46
 
47
47
  file = file.find if Path === file
48
- FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exists? File.dirname(file)
48
+ FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exist? File.dirname(file)
49
49
 
50
50
  cmd = <<-EOS
51
51
 
@@ -71,10 +71,10 @@ end
71
71
 
72
72
 
73
73
  #def self.analyze(datafile, main, contrast = nil, log2 = false, outfile = nil, key_field = nil, two_channel = nil)
74
- # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
74
+ # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
75
75
  # GE.run_R("rbbt.GE.process(#{ R.ruby2R datafile }, main = #{R.ruby2R(main, :strings => true)}, contrast = #{R.ruby2R(contrast, :strings => true)}, log2=#{ R.ruby2R log2 }, outfile = #{R.ruby2R outfile}, key.field = #{R.ruby2R key_field}, two.channel = #{R.ruby2R two_channel})")
76
76
  #end
77
77
  #def self.barcode(datafile, outfile, factor = 2)
78
- # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
78
+ # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
79
79
  # GE.run_R("rbbt.GE.barcode(#{ R.ruby2R datafile }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })")
80
80
  #end
@@ -13,7 +13,7 @@ class KnowledgeBase
13
13
 
14
14
  return matrix if RbbtMatrix === matrix
15
15
 
16
- Path.setup(matrix) if not Path === matrix and File.exists? matrix
16
+ Path.setup(matrix) if not Path === matrix and File.exist? matrix
17
17
 
18
18
  raise "Registered matrix is strange: #{Misc.fingerprint matrix}" unless Path === matrix
19
19
 
@@ -119,7 +119,7 @@ module BarPlot
119
119
  options = Misc.add_defaults options, :width => [options[:total], MAX_WIDTH].min, :height => 20, :background => PNG::Color::White
120
120
  width, height, background, canvas = Misc.process_options options, :width, :height, :background, :canvas
121
121
 
122
- canvas ||= if options[:update] and options[:filename] and File.exists? options[:filename]
122
+ canvas ||= if options[:update] and options[:filename] and File.exist? options[:filename]
123
123
  PNG.load_file options[:filename]
124
124
  else
125
125
  PNG::Canvas.new width, height, get_color(background)
data/lib/rbbt/stan.rb CHANGED
@@ -130,7 +130,7 @@ print(fit)
130
130
  erase = true
131
131
  end
132
132
 
133
- FileUtils.mkdir_p directory unless File.exists? directory
133
+ FileUtils.mkdir_p directory unless File.exist? directory
134
134
  input_directory = File.join(directory, 'inputs')
135
135
  parameter_chains = File.join(directory, 'chains') unless erase
136
136
  summary = File.join(directory, 'summary') unless erase
@@ -118,7 +118,7 @@ module TSV
118
118
 
119
119
  field_pos = fields.collect{|f| self.fields.index f}.compact
120
120
  persistence_path = self.respond_to?(:persistence_path)? self.persistence_path : nil
121
- Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:background => background, :rename => rename, :persistence_path => persistence_path}) do
121
+ Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:fields => fields, :background => background, :rename => rename, :persistence_path => persistence_path}) do
122
122
  data ||= {}
123
123
 
124
124
  with_unnamed do
@@ -152,6 +152,7 @@ module TSV
152
152
 
153
153
  end
154
154
 
155
+
155
156
  if rename
156
157
  Log.debug("Using renames during annotation counts")
157
158
  Hash[*data.keys.zip(data.values.collect{|l| l.collect{|e| rename.include?(e)? rename[e] : e }.uniq.length }).flatten]
@@ -0,0 +1,50 @@
1
+ require 'rbbt/vector/model/huggingface'
2
+ class MaskedLMModel < HuggingfaceModel
3
+
4
+ def initialize(checkpoint, dir = nil, model_options = {})
5
+
6
+ model_options = Misc.add_defaults model_options, :max_length => 128
7
+ super("MaskedLM", checkpoint, dir, model_options)
8
+
9
+ train_model do |texts,labels|
10
+ model, tokenizer = self.init
11
+ max_length = @model_options[:max_length]
12
+ mask_id = tokenizer.mask_token_id
13
+
14
+ dataset = []
15
+ texts.zip(labels).each do |text,label_values|
16
+ fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
17
+ label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }
18
+ label_tokens.each do |ids|
19
+ ids = [ids] unless Array === ids
20
+ fixed_text.sub!("[PENDINGMASK]", "[MASK]" * ids.length)
21
+ end
22
+
23
+ tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length")
24
+ input_ids = tokenized_text["input_ids"].to_a
25
+ attention_mask = tokenized_text["attention_mask"].to_a
26
+
27
+ all_label_tokens = label_tokens.flatten
28
+ label_ids = input_ids.collect do |id|
29
+ if id == mask_id
30
+ all_label_tokens.shift
31
+ else
32
+ -100
33
+ end
34
+ end
35
+ dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
36
+ end
37
+
38
+ dataset_file = File.join(@directory, 'dataset.json')
39
+ Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")
40
+
41
+ training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
42
+ data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
43
+ RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)
44
+
45
+ model.save_pretrained(@model_path) if @model_path
46
+ tokenizer.save_pretrained(@model_path) if @model_path
47
+ end
48
+
49
+ end
50
+ end
@@ -1,50 +1,68 @@
1
- require 'rbbt/vector/model'
2
- require 'rbbt/util/python'
1
+ require 'rbbt/vector/model/torch'
3
2
 
4
- RbbtPython.add_path Rbbt.python.find(:lib)
5
- RbbtPython.init_rbbt
3
+ class HuggingfaceModel < TorchModel
6
4
 
7
- class HuggingfaceModel < VectorModel
8
-
9
- def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
5
+ def self.tsv_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
10
6
 
11
7
  if labels
8
+ labels = case class_labels
9
+ when Array
10
+ labels.collect{|l| class_labels.index l}
11
+ when Hash
12
+ inverse_class_labels = {}
13
+ class_labels.each{|c,l| inverse_class_labels[l] = c }
14
+ labels.collect{|l| inverse_class_labels[l]}
15
+ else
16
+ labels
17
+ end
18
+
12
19
  Open.write(tsv_dataset_file) do |ffile|
13
20
  ffile.puts ["label", "text"].flatten * "\t"
14
21
  elements.zip(labels).each do |element,label|
22
+ element = element.gsub("\n", " ")
15
23
  ffile.puts [label, element].flatten * "\t"
16
24
  end
25
+ ffile.sync
17
26
  end
18
27
  else
19
28
  Open.write(tsv_dataset_file) do |ffile|
20
29
  ffile.puts ["text"].flatten * "\t"
21
- elements.each{|element| ffile.puts element }
30
+ elements.each do |element|
31
+ element = element.gsub("\n", " ")
32
+ ffile.puts element
33
+ end
34
+ ffile.sync
22
35
  end
23
36
  end
24
37
 
25
38
  tsv_dataset_file
26
39
  end
27
40
 
28
- def initialize(task, checkpoint, *args)
29
- options = args.pop if Hash === args.last
30
- options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
31
- super(*args)
32
- @model_options ||= {}
33
- @model_options.merge!(options)
41
+ def initialize(task, checkpoint, dir = nil, model_options = {})
42
+ super(dir, model_options)
34
43
 
35
- eval_model do |directory,texts|
36
- checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
44
+ @model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint
37
45
 
38
- if @model.nil?
39
- @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
40
- end
41
-
42
- if Array === texts
46
+ init_model do
47
+ checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]
48
+ model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
49
+ @model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:model_args]) || {}))
50
+ tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
51
+ @model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))
52
+
53
+ [model, tokenizer]
54
+ end
55
+
56
+ eval_model do |texts,is_list|
57
+ model, tokenizer = self.init
58
+
59
+ if is_list || @model_options[:task] == "MaskedLM"
60
+ texts = [texts] if ! is_list
43
61
 
44
62
  if @model_options.include?(:locate_tokens)
45
63
  locate_tokens = @model_options[:locate_tokens]
46
64
  elsif @model_options[:task] == "MaskedLM"
47
- @model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
65
+ @model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
48
66
  end
49
67
 
50
68
  if @directory
@@ -61,18 +79,17 @@ class HuggingfaceModel < VectorModel
61
79
  training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
62
80
 
63
81
  begin
64
- RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
82
+ RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, model, tokenizer, training_args_obj, dataset_file, locate_tokens)
65
83
  ensure
66
84
  Open.rm_rf tmpdir if tmpdir
67
85
  end
68
86
  else
69
- RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
87
+ RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
70
88
  end
71
89
  end
72
90
 
73
- train_model do |directory,texts,labels|
74
- checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
75
- @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
91
+ train_model do |texts,labels|
92
+ model, tokenizer = self.init
76
93
 
77
94
  if @directory
78
95
  tsv_file = File.join(@directory, 'dataset.tsv')
@@ -85,17 +102,19 @@ class HuggingfaceModel < VectorModel
85
102
  end
86
103
 
87
104
  training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
88
- dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
105
+ dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels, @model_options[:class_labels])
89
106
 
90
- RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
107
+ RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
91
108
 
92
109
  Open.rm_rf tmpdir if tmpdir
93
110
 
94
- @model.save_pretrained(directory) if directory
95
- @tokenizer.save_pretrained(directory) if directory
111
+ model.save_pretrained(@model_path) if @model_path
112
+ tokenizer.save_pretrained(@model_path) if @model_path
96
113
  end
97
114
 
98
- post_process do |result|
115
+ post_process do |result,is_list|
116
+ model, tokenizer = self.init
117
+
99
118
  if result.respond_to?(:predictions)
100
119
  single = false
101
120
  predictions = result.predictions
@@ -137,25 +156,25 @@ class HuggingfaceModel < VectorModel
137
156
  best_token
138
157
  end
139
158
 
140
- best.collect{|b| @tokenizer.decode(b) } * "|"
159
+ best.collect{|b| tokenizer.decode(b) } * "|"
141
160
  end
142
161
  Array === locate_tokens ? item_masks : item_masks.first
143
162
  end
144
163
  else
145
- logits
164
+ predictions
146
165
  end
147
166
 
148
- single ? result.first : result
167
+ (! is_list || single) && Array === result ? result.first : result
149
168
  end
150
169
 
151
170
 
152
- save_models if @directory
171
+ save_models if @model_path
153
172
  end
154
173
 
155
174
  def reset_model
156
175
  @model, @tokenizer = nil
157
- Open.rm @model_file
176
+ Open.rm_rf @model_path
177
+ init
158
178
  end
159
-
160
179
  end
161
180
 
@@ -0,0 +1,35 @@
1
+ require 'rbbt/vector/model/torch'
2
+
3
+ class PytorchLightningModel < TorchModel
4
+ attr_accessor :loader, :val_loader, :trainer
5
+ def initialize(module_name, class_name, dir = nil, model_options = {})
6
+ super(dir, model_options)
7
+ @module_name = module_name
8
+ @class_name = class_name
9
+
10
+ init_model do
11
+ RbbtPython.pyimport @module_name
12
+ RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {})
13
+ end
14
+
15
+ train_model do |features,labels|
16
+ model = init
17
+ raise "Use the loader" if @loader.nil?
18
+ raise "Use the trainer" if @trainer.nil?
19
+
20
+ trainer.fit(model, @loader, @val_loader)
21
+ end
22
+
23
+ eval_model do |features,list|
24
+ if list
25
+ model.call(RbbtPython.call_method(:torch, :tensor, features))
26
+ else
27
+ model.call(RbbtPython.call_method(:torch, :tensor, [features]))
28
+ end
29
+ end
30
+
31
+ end
32
+ end
33
+
34
+ if __FILE__ == $0
35
+ end
@@ -27,7 +27,7 @@ label = predict(model, features);
27
27
  def importance
28
28
  TmpFile.with_file do |tmp|
29
29
  tsv = R.run <<-EOF
30
- load(file="#{model_file}");
30
+ load(file="#{@model_path}");
31
31
  rbbt.tsv.write('#{tmp}', model$importance)
32
32
  EOF
33
33
  TSV.open(tmp)
@@ -28,12 +28,12 @@ class SpaCyModel < VectorModel
28
28
 
29
29
  super(dir)
30
30
 
31
- @train_model = Proc.new do |file, features, labels|
31
+ @train_model = Proc.new do |features, labels|
32
32
  texts = features
33
33
  docs = []
34
34
  unique_labels = labels.uniq
35
- tmpconfig = File.join(file, 'config')
36
- tmptrain = File.join(file, 'train.spacy')
35
+ tmpconfig = File.join(@model_path, 'config')
36
+ tmptrain = File.join(@model_path, 'train.spacy')
37
37
  SpaCy.config(@config, tmpconfig)
38
38
 
39
39
  bar = bar(features.length, "Training documents into spacy format")
@@ -54,20 +54,22 @@ class SpaCyModel < VectorModel
54
54
  end
55
55
 
56
56
  gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
57
- CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{file} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
57
+ CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{@model_path} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
58
58
  end
59
59
 
60
- @eval_model = Proc.new do |file, features,list|
60
+ @eval_model = Proc.new do |features,list|
61
61
  texts = features
62
62
  texts = [texts] unless list
63
63
 
64
+ model_path = @model_path
65
+
64
66
  docs = []
65
67
  bar = bar(features.length, "Evaluating model")
66
68
  SpaCyModel.spacy do
67
69
  gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
68
70
  gpu = gpu.to_i if gpu && gpu != ""
69
71
  spacy.require_gpu(gpu) if gpu
70
- nlp = spacy.load("#{file}/model-best")
72
+ nlp = spacy.load("#{model_path}/model-best")
71
73
 
72
74
  docs = nlp.pipe(texts)
73
75
  RbbtPython.collect docs, :bar => bar do |d|
@@ -75,14 +77,6 @@ class SpaCyModel < VectorModel
75
77
  d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
76
78
  end
77
79
  end
78
- #nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
79
- #Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
80
- # texts.collect do |text|
81
- # cats = nlp.(text).cats
82
- # bar.tick
83
- # cats.sort_by{|l,v| v.to_f }.last.first
84
- # end
85
- #end
86
80
  end
87
81
  end
88
82
  end
@@ -25,7 +25,7 @@ class TensorFlowModel < VectorModel
25
25
 
26
26
  super(dir)
27
27
 
28
- @train_model = Proc.new do |file, features, labels|
28
+ @train_model = Proc.new do |features, labels|
29
29
  tensorflow do
30
30
  features = tensorflow.convert_to_tensor(features)
31
31
  labels = tensorflow.convert_to_tensor(labels)
@@ -33,16 +33,17 @@ class TensorFlowModel < VectorModel
33
33
  @graph ||= keras_graph
34
34
  @graph.compile(**@compile_options)
35
35
  @graph.fit(features, labels, :epochs => @epochs, :verbose => true)
36
- @graph.save(file)
36
+ @graph.save(@model_path)
37
37
  end
38
38
 
39
- @eval_model = Proc.new do |file, features|
39
+ @eval_model = Proc.new do |features|
40
40
  tensorflow do
41
41
  features = tensorflow.convert_to_tensor(features)
42
42
  end
43
+ model_path = @model_path
44
+ graph = @graph ||= keras.models.load_model(model_path)
43
45
  keras do
44
- @graph ||= keras.models.load_model(file)
45
- indices = @graph.predict(features, :verbose => false).tolist()
46
+ indices = graph.predict(features, :verbose => false).tolist()
46
47
  labels = indices.collect{|p| p.length > 1 ? p.index(p.max): p.first }
47
48
  labels
48
49
  end
@@ -0,0 +1,37 @@
1
+ require 'rbbt/vector/model'
2
+ require 'rbbt/util/python'
3
+
4
+ RbbtPython.add_path Rbbt.python.find(:lib)
5
+ RbbtPython.init_rbbt
6
+
7
+ class TorchModel < VectorModel
8
+
9
+ attr_accessor :model
10
+
11
+ def self.get_layer(model, layer)
12
+ layer.split(".").inject(model){|acc,l| PyCall.getattr(acc, l.to_sym) }
13
+ end
14
+
15
+ def self.get_weights(model, layer)
16
+ PyCall.getattr(get_layer(model, layer), :weight)
17
+ end
18
+
19
+ def self.freeze(layer)
20
+ begin
21
+ PyCall.getattr(layer, :weight).requires_grad = false
22
+ rescue
23
+ end
24
+ RbbtPython.iterate(layer.children) do |layer|
25
+ freeze(layer)
26
+ end
27
+ end
28
+
29
+ def self.freeze_layer(model, layer)
30
+ layer = get_layer(model, layer)
31
+ freeze(layer)
32
+ end
33
+
34
+ def initialize(dir, model_options = {})
35
+ super(dir, model_options)
36
+ end
37
+ end
@@ -9,4 +9,22 @@ class VectorModel
9
9
  @bar.init
10
10
  @bar
11
11
  end
12
+
13
+ def balance_labels
14
+ counts = Misc.counts(@labels)
15
+ min = counts.values.min
16
+
17
+ used = {}
18
+ new_labels = []
19
+ new_features = []
20
+ @labels.zip(@features).shuffle.each do |label, features|
21
+ used[label] ||= 0
22
+ next if used[label] > min
23
+ used[label] += 1
24
+ new_labels << label
25
+ new_features << features
26
+ end
27
+ @labels = new_labels
28
+ @features = new_features
29
+ end
12
30
  end