rbbt-dm 1.2.6 → 1.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +57 -38
  10. data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
  11. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  12. data/lib/rbbt/vector/model/spaCy.rb +8 -14
  13. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  14. data/lib/rbbt/vector/model/torch.rb +37 -0
  15. data/lib/rbbt/vector/model/util.rb +18 -0
  16. data/lib/rbbt/vector/model.rb +100 -56
  17. data/python/rbbt_dm/__init__.py +48 -1
  18. data/python/rbbt_dm/atcold/__init__.py +0 -0
  19. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  20. data/python/rbbt_dm/atcold/spiral.py +27 -0
  21. data/python/rbbt_dm/huggingface.py +57 -26
  22. data/python/rbbt_dm/language_model.py +70 -0
  23. data/python/rbbt_dm/util.py +30 -0
  24. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  25. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  26. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  27. data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
  28. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  29. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  30. data/test/rbbt/vector/test_model.rb +25 -26
  31. data/test/test_helper.rb +13 -0
  32. metadata +26 -16
  33. data/lib/rbbt/tensorflow.rb +0 -43
  34. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9744ab9faeaf4f9cc04947eb11103dbf0694dda624f805a5c6be27bb22af81ce
4
- data.tar.gz: d3a3903aa276a69e20cbd71213286449db396ecf5f6a4b4d80a64ab299041fbb
3
+ metadata.gz: db2cbab94e21fd2ca67f7306fa9941b59cbfb2865382e5439edf6313f50309e7
4
+ data.tar.gz: f4acf3651daa90ef23bc454c62df68e208976a977d51e2e85d02558d48897187
5
5
  SHA512:
6
- metadata.gz: 263fb609b37522874426bcd79374760399b4a9aaab443ae6d74c727f2d148474dd71ee0b2cfda7a50131dafbc314f66352f4285562a75f62144d2e05ccd214c7
7
- data.tar.gz: 1e0426429a38028a19b3f8c955e975138199c791dad8691de7fb760a5cbec3304f19341906a4457563a90de965ee2f12a5a639b928866b46258af2507eeb39fa
6
+ metadata.gz: 7786759636450821aabca306cd210ab3d201e094b81bb70052d57d7bfb6e4de73a198576fe4b002487baf7997138f9c53b91644632cd12cc79b40ff62141a70a
7
+ data.tar.gz: 9870745068a897909170f3a6187e520e6530b121f8ad4ab40224c3369a16a8bb1c1e55bb8ca8fd892943ec1fcead9f4661e49c3d97d07d3740236a1ec4f69a34
data/lib/rbbt/matrix/barcode.rb CHANGED
@@ -3,7 +3,7 @@ require 'rbbt/util/R'
3
3
  class RbbtMatrix
4
4
  def barcode(outfile, factor = 2)
5
5
 
6
- FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
6
+ FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
7
7
  cmd =<<-EOF
8
8
  source('#{Rbbt.share.R['barcode.R'].find}')
9
9
  rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })
@@ -49,7 +49,7 @@ rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.r
49
49
 
50
50
  clusters = Array === clusters ? clusters : (2..clusters).to_a
51
51
 
52
- FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
52
+ FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
53
53
  cmd =<<-EOF
54
54
  source('#{Rbbt.share.R['barcode.R'].find}')
55
55
  rbbt.GE.activity_cluster(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{R.ruby2R key_field}, #{R.ruby2R clusters})
data/lib/rbbt/matrix/differential.rb CHANGED
@@ -45,7 +45,7 @@ class RbbtMatrix
45
45
  end
46
46
 
47
47
  file = file.find if Path === file
48
- FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exists? File.dirname(file)
48
+ FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exist? File.dirname(file)
49
49
 
50
50
  cmd = <<-EOS
51
51
 
@@ -71,10 +71,10 @@ end
71
71
 
72
72
 
73
73
  #def self.analyze(datafile, main, contrast = nil, log2 = false, outfile = nil, key_field = nil, two_channel = nil)
74
- # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
74
+ # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
75
75
  # GE.run_R("rbbt.GE.process(#{ R.ruby2R datafile }, main = #{R.ruby2R(main, :strings => true)}, contrast = #{R.ruby2R(contrast, :strings => true)}, log2=#{ R.ruby2R log2 }, outfile = #{R.ruby2R outfile}, key.field = #{R.ruby2R key_field}, two.channel = #{R.ruby2R two_channel})")
76
76
  #end
77
77
  #def self.barcode(datafile, outfile, factor = 2)
78
- # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
78
+ # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
79
79
  # GE.run_R("rbbt.GE.barcode(#{ R.ruby2R datafile }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })")
80
80
  #end
data/lib/rbbt/matrix/knowledge_base.rb CHANGED
@@ -13,7 +13,7 @@ class KnowledgeBase
13
13
 
14
14
  return matrix if RbbtMatrix === matrix
15
15
 
16
- Path.setup(matrix) if not Path === matrix and File.exists? matrix
16
+ Path.setup(matrix) if not Path === matrix and File.exist? matrix
17
17
 
18
18
  raise "Registered matrix is strange: #{Misc.fingerprint matrix}" unless Path === matrix
19
19
 
data/lib/rbbt/plots/bar.rb CHANGED
@@ -119,7 +119,7 @@ module BarPlot
119
119
  options = Misc.add_defaults options, :width => [options[:total], MAX_WIDTH].min, :height => 20, :background => PNG::Color::White
120
120
  width, height, background, canvas = Misc.process_options options, :width, :height, :background, :canvas
121
121
 
122
- canvas ||= if options[:update] and options[:filename] and File.exists? options[:filename]
122
+ canvas ||= if options[:update] and options[:filename] and File.exist? options[:filename]
123
123
  PNG.load_file options[:filename]
124
124
  else
125
125
  PNG::Canvas.new width, height, get_color(background)
data/lib/rbbt/stan.rb CHANGED
@@ -130,7 +130,7 @@ print(fit)
130
130
  erase = true
131
131
  end
132
132
 
133
- FileUtils.mkdir_p directory unless File.exists? directory
133
+ FileUtils.mkdir_p directory unless File.exist? directory
134
134
  input_directory = File.join(directory, 'inputs')
135
135
  parameter_chains = File.join(directory, 'chains') unless erase
136
136
  summary = File.join(directory, 'summary') unless erase
data/lib/rbbt/statistics/hypergeometric.rb CHANGED
@@ -118,7 +118,7 @@ module TSV
118
118
 
119
119
  field_pos = fields.collect{|f| self.fields.index f}.compact
120
120
  persistence_path = self.respond_to?(:persistence_path)? self.persistence_path : nil
121
- Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:background => background, :rename => rename, :persistence_path => persistence_path}) do
121
+ Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:fields => fields, :background => background, :rename => rename, :persistence_path => persistence_path}) do
122
122
  data ||= {}
123
123
 
124
124
  with_unnamed do
@@ -152,6 +152,7 @@ module TSV
152
152
 
153
153
  end
154
154
 
155
+
155
156
  if rename
156
157
  Log.debug("Using renames during annotation counts")
157
158
  Hash[*data.keys.zip(data.values.collect{|l| l.collect{|e| rename.include?(e)? rename[e] : e }.uniq.length }).flatten]
data/lib/rbbt/vector/model/huggingface/masked_lm.rb ADDED
@@ -0,0 +1,50 @@
1
+ require 'rbbt/vector/model/huggingface'
2
+ class MaskedLMModel < HuggingfaceModel
3
+
4
+ def initialize(checkpoint, dir = nil, model_options = {})
5
+
6
+ model_options = Misc.add_defaults model_options, :max_length => 128
7
+ super("MaskedLM", checkpoint, dir, model_options)
8
+
9
+ train_model do |texts,labels|
10
+ model, tokenizer = self.init
11
+ max_length = @model_options[:max_length]
12
+ mask_id = tokenizer.mask_token_id
13
+
14
+ dataset = []
15
+ texts.zip(labels).each do |text,label_values|
16
+ fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
17
+ label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }
18
+ label_tokens.each do |ids|
19
+ ids = [ids] unless Array === ids
20
+ fixed_text.sub!("[PENDINGMASK]", "[MASK]" * ids.length)
21
+ end
22
+
23
+ tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length")
24
+ input_ids = tokenized_text["input_ids"].to_a
25
+ attention_mask = tokenized_text["attention_mask"].to_a
26
+
27
+ all_label_tokens = label_tokens.flatten
28
+ label_ids = input_ids.collect do |id|
29
+ if id == mask_id
30
+ all_label_tokens.shift
31
+ else
32
+ -100
33
+ end
34
+ end
35
+ dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
36
+ end
37
+
38
+ dataset_file = File.join(@directory, 'dataset.json')
39
+ Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")
40
+
41
+ training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
42
+ data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
43
+ RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)
44
+
45
+ model.save_pretrained(@model_path) if @model_path
46
+ tokenizer.save_pretrained(@model_path) if @model_path
47
+ end
48
+
49
+ end
50
+ end
data/lib/rbbt/vector/model/huggingface.rb CHANGED
@@ -1,50 +1,68 @@
1
- require 'rbbt/vector/model'
2
- require 'rbbt/util/python'
1
+ require 'rbbt/vector/model/torch'
3
2
 
4
- RbbtPython.add_path Rbbt.python.find(:lib)
5
- RbbtPython.init_rbbt
3
+ class HuggingfaceModel < TorchModel
6
4
 
7
- class HuggingfaceModel < VectorModel
8
-
9
- def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
5
+ def self.tsv_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
10
6
 
11
7
  if labels
8
+ labels = case class_labels
9
+ when Array
10
+ labels.collect{|l| class_labels.index l}
11
+ when Hash
12
+ inverse_class_labels = {}
13
+ class_labels.each{|c,l| inverse_class_labels[l] = c }
14
+ labels.collect{|l| inverse_class_labels[l]}
15
+ else
16
+ labels
17
+ end
18
+
12
19
  Open.write(tsv_dataset_file) do |ffile|
13
20
  ffile.puts ["label", "text"].flatten * "\t"
14
21
  elements.zip(labels).each do |element,label|
22
+ element = element.gsub("\n", " ")
15
23
  ffile.puts [label, element].flatten * "\t"
16
24
  end
25
+ ffile.sync
17
26
  end
18
27
  else
19
28
  Open.write(tsv_dataset_file) do |ffile|
20
29
  ffile.puts ["text"].flatten * "\t"
21
- elements.each{|element| ffile.puts element }
30
+ elements.each do |element|
31
+ element = element.gsub("\n", " ")
32
+ ffile.puts element
33
+ end
34
+ ffile.sync
22
35
  end
23
36
  end
24
37
 
25
38
  tsv_dataset_file
26
39
  end
27
40
 
28
- def initialize(task, checkpoint, *args)
29
- options = args.pop if Hash === args.last
30
- options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
31
- super(*args)
32
- @model_options ||= {}
33
- @model_options.merge!(options)
41
+ def initialize(task, checkpoint, dir = nil, model_options = {})
42
+ super(dir, model_options)
34
43
 
35
- eval_model do |directory,texts|
36
- checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
44
+ @model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint
37
45
 
38
- if @model.nil?
39
- @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
40
- end
41
-
42
- if Array === texts
46
+ init_model do
47
+ checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]
48
+ model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
49
+ @model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:model_args]) || {}))
50
+ tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
51
+ @model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))
52
+
53
+ [model, tokenizer]
54
+ end
55
+
56
+ eval_model do |texts,is_list|
57
+ model, tokenizer = self.init
58
+
59
+ if is_list || @model_options[:task] == "MaskedLM"
60
+ texts = [texts] if ! is_list
43
61
 
44
62
  if @model_options.include?(:locate_tokens)
45
63
  locate_tokens = @model_options[:locate_tokens]
46
64
  elsif @model_options[:task] == "MaskedLM"
47
- @model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
65
+ @model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
48
66
  end
49
67
 
50
68
  if @directory
@@ -61,18 +79,17 @@ class HuggingfaceModel < VectorModel
61
79
  training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
62
80
 
63
81
  begin
64
- RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
82
+ RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, model, tokenizer, training_args_obj, dataset_file, locate_tokens)
65
83
  ensure
66
84
  Open.rm_rf tmpdir if tmpdir
67
85
  end
68
86
  else
69
- RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
87
+ RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
70
88
  end
71
89
  end
72
90
 
73
- train_model do |directory,texts,labels|
74
- checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
75
- @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
91
+ train_model do |texts,labels|
92
+ model, tokenizer = self.init
76
93
 
77
94
  if @directory
78
95
  tsv_file = File.join(@directory, 'dataset.tsv')
@@ -85,17 +102,19 @@ class HuggingfaceModel < VectorModel
85
102
  end
86
103
 
87
104
  training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
88
- dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
105
+ dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels, @model_options[:class_labels])
89
106
 
90
- RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
107
+ RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
91
108
 
92
109
  Open.rm_rf tmpdir if tmpdir
93
110
 
94
- @model.save_pretrained(directory) if directory
95
- @tokenizer.save_pretrained(directory) if directory
111
+ model.save_pretrained(@model_path) if @model_path
112
+ tokenizer.save_pretrained(@model_path) if @model_path
96
113
  end
97
114
 
98
- post_process do |result|
115
+ post_process do |result,is_list|
116
+ model, tokenizer = self.init
117
+
99
118
  if result.respond_to?(:predictions)
100
119
  single = false
101
120
  predictions = result.predictions
@@ -137,25 +156,25 @@ class HuggingfaceModel < VectorModel
137
156
  best_token
138
157
  end
139
158
 
140
- best.collect{|b| @tokenizer.decode(b) } * "|"
159
+ best.collect{|b| tokenizer.decode(b) } * "|"
141
160
  end
142
161
  Array === locate_tokens ? item_masks : item_masks.first
143
162
  end
144
163
  else
145
- logits
164
+ predictions
146
165
  end
147
166
 
148
- single ? result.first : result
167
+ (! is_list || single) && Array === result ? result.first : result
149
168
  end
150
169
 
151
170
 
152
- save_models if @directory
171
+ save_models if @model_path
153
172
  end
154
173
 
155
174
  def reset_model
156
175
  @model, @tokenizer = nil
157
- Open.rm @model_file
176
+ Open.rm_rf @model_path
177
+ init
158
178
  end
159
-
160
179
  end
161
180
 
data/lib/rbbt/vector/model/pytorch_lightning.rb ADDED
@@ -0,0 +1,35 @@
1
+ require 'rbbt/vector/model/torch'
2
+
3
+ class PytorchLightningModel < TorchModel
4
+ attr_accessor :loader, :val_loader, :trainer
5
+ def initialize(module_name, class_name, dir = nil, model_options = {})
6
+ super(dir, model_options)
7
+ @module_name = module_name
8
+ @class_name = class_name
9
+
10
+ init_model do
11
+ RbbtPython.pyimport @module_name
12
+ RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {})
13
+ end
14
+
15
+ train_model do |features,labels|
16
+ model = init
17
+ raise "Use the loader" if @loader.nil?
18
+ raise "Use the trainer" if @trainer.nil?
19
+
20
+ trainer.fit(model, @loader, @val_loader)
21
+ end
22
+
23
+ eval_model do |features,list|
24
+ if list
25
+ model.call(RbbtPython.call_method(:torch, :tensor, features))
26
+ else
27
+ model.call(RbbtPython.call_method(:torch, :tensor, [features]))
28
+ end
29
+ end
30
+
31
+ end
32
+ end
33
+
34
+ if __FILE__ == $0
35
+ end
data/lib/rbbt/vector/model/random_forest.rb CHANGED
@@ -27,7 +27,7 @@ label = predict(model, features);
27
27
  def importance
28
28
  TmpFile.with_file do |tmp|
29
29
  tsv = R.run <<-EOF
30
- load(file="#{model_file}");
30
+ load(file="#{@model_path}");
31
31
  rbbt.tsv.write('#{tmp}', model$importance)
32
32
  EOF
33
33
  TSV.open(tmp)
data/lib/rbbt/vector/model/spaCy.rb CHANGED
@@ -28,12 +28,12 @@ class SpaCyModel < VectorModel
28
28
 
29
29
  super(dir)
30
30
 
31
- @train_model = Proc.new do |file, features, labels|
31
+ @train_model = Proc.new do |features, labels|
32
32
  texts = features
33
33
  docs = []
34
34
  unique_labels = labels.uniq
35
- tmpconfig = File.join(file, 'config')
36
- tmptrain = File.join(file, 'train.spacy')
35
+ tmpconfig = File.join(@model_path, 'config')
36
+ tmptrain = File.join(@model_path, 'train.spacy')
37
37
  SpaCy.config(@config, tmpconfig)
38
38
 
39
39
  bar = bar(features.length, "Training documents into spacy format")
@@ -54,20 +54,22 @@ class SpaCyModel < VectorModel
54
54
  end
55
55
 
56
56
  gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
57
- CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{file} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
57
+ CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{@model_path} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
58
58
  end
59
59
 
60
- @eval_model = Proc.new do |file, features,list|
60
+ @eval_model = Proc.new do |features,list|
61
61
  texts = features
62
62
  texts = [texts] unless list
63
63
 
64
+ model_path = @model_path
65
+
64
66
  docs = []
65
67
  bar = bar(features.length, "Evaluating model")
66
68
  SpaCyModel.spacy do
67
69
  gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
68
70
  gpu = gpu.to_i if gpu && gpu != ""
69
71
  spacy.require_gpu(gpu) if gpu
70
- nlp = spacy.load("#{file}/model-best")
72
+ nlp = spacy.load("#{model_path}/model-best")
71
73
 
72
74
  docs = nlp.pipe(texts)
73
75
  RbbtPython.collect docs, :bar => bar do |d|
@@ -75,14 +77,6 @@ class SpaCyModel < VectorModel
75
77
  d.cats.sort_by{|l,v| v.to_f || 0 }.last.first
76
78
  end
77
79
  end
78
- #nlp.(docs).cats.collect{|cats| cats.sort_by{|l,v| v.to_f }.last.first }
79
- #Log::ProgressBar.with_bar texts.length, :desc => "Evaluating documents" do |bar|
80
- # texts.collect do |text|
81
- # cats = nlp.(text).cats
82
- # bar.tick
83
- # cats.sort_by{|l,v| v.to_f }.last.first
84
- # end
85
- #end
86
80
  end
87
81
  end
88
82
  end
data/lib/rbbt/vector/model/tensorflow.rb CHANGED
@@ -25,7 +25,7 @@ class TensorFlowModel < VectorModel
25
25
 
26
26
  super(dir)
27
27
 
28
- @train_model = Proc.new do |file, features, labels|
28
+ @train_model = Proc.new do |features, labels|
29
29
  tensorflow do
30
30
  features = tensorflow.convert_to_tensor(features)
31
31
  labels = tensorflow.convert_to_tensor(labels)
@@ -33,16 +33,17 @@ class TensorFlowModel < VectorModel
33
33
  @graph ||= keras_graph
34
34
  @graph.compile(**@compile_options)
35
35
  @graph.fit(features, labels, :epochs => @epochs, :verbose => true)
36
- @graph.save(file)
36
+ @graph.save(@model_path)
37
37
  end
38
38
 
39
- @eval_model = Proc.new do |file, features|
39
+ @eval_model = Proc.new do |features|
40
40
  tensorflow do
41
41
  features = tensorflow.convert_to_tensor(features)
42
42
  end
43
+ model_path = @model_path
44
+ graph = @graph ||= keras.models.load_model(model_path)
43
45
  keras do
44
- @graph ||= keras.models.load_model(file)
45
- indices = @graph.predict(features, :verbose => false).tolist()
46
+ indices = graph.predict(features, :verbose => false).tolist()
46
47
  labels = indices.collect{|p| p.length > 1 ? p.index(p.max): p.first }
47
48
  labels
48
49
  end
data/lib/rbbt/vector/model/torch.rb ADDED
@@ -0,0 +1,37 @@
1
+ require 'rbbt/vector/model'
2
+ require 'rbbt/util/python'
3
+
4
+ RbbtPython.add_path Rbbt.python.find(:lib)
5
+ RbbtPython.init_rbbt
6
+
7
+ class TorchModel < VectorModel
8
+
9
+ attr_accessor :model
10
+
11
+ def self.get_layer(model, layer)
12
+ layer.split(".").inject(model){|acc,l| PyCall.getattr(acc, l.to_sym) }
13
+ end
14
+
15
+ def self.get_weights(model, layer)
16
+ PyCall.getattr(get_layer(model, layer), :weight)
17
+ end
18
+
19
+ def self.freeze(layer)
20
+ begin
21
+ PyCall.getattr(layer, :weight).requires_grad = false
22
+ rescue
23
+ end
24
+ RbbtPython.iterate(layer.children) do |layer|
25
+ freeze(layer)
26
+ end
27
+ end
28
+
29
+ def self.freeze_layer(model, layer)
30
+ layer = get_layer(model, layer)
31
+ freeze(layer)
32
+ end
33
+
34
+ def initialize(dir, model_options = {})
35
+ super(dir, model_options)
36
+ end
37
+ end
data/lib/rbbt/vector/model/util.rb CHANGED
@@ -9,4 +9,22 @@ class VectorModel
9
9
  @bar.init
10
10
  @bar
11
11
  end
12
+
13
+ def balance_labels
14
+ counts = Misc.counts(@labels)
15
+ min = counts.values.min
16
+
17
+ used = {}
18
+ new_labels = []
19
+ new_features = []
20
+ @labels.zip(@features).shuffle.each do |label, features|
21
+ used[label] ||= 0
22
+ next if used[label] > min
23
+ used[label] += 1
24
+ new_labels << label
25
+ new_features << features
26
+ end
27
+ @labels = new_labels
28
+ @features = new_features
29
+ end
12
30
  end