rbbt-dm 1.2.7 → 1.2.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +39 -52
  10. data/lib/rbbt/vector/model/python.rb +33 -0
  11. data/lib/rbbt/vector/model/pytorch_lightning.rb +31 -0
  12. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  13. data/lib/rbbt/vector/model/spaCy.rb +8 -6
  14. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  15. data/lib/rbbt/vector/model/torch/dataloader.rb +58 -0
  16. data/lib/rbbt/vector/model/torch/helpers.rb +52 -0
  17. data/lib/rbbt/vector/model/torch/introspection.rb +31 -0
  18. data/lib/rbbt/vector/model/torch/load_and_save.rb +30 -0
  19. data/lib/rbbt/vector/model/torch.rb +71 -0
  20. data/lib/rbbt/vector/model.rb +84 -54
  21. data/python/rbbt_dm/__init__.py +31 -1
  22. data/python/rbbt_dm/atcold/__init__.py +0 -0
  23. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  24. data/python/rbbt_dm/atcold/spiral.py +27 -0
  25. data/python/rbbt_dm/huggingface.py +64 -28
  26. data/python/rbbt_dm/language_model.py +70 -0
  27. data/python/rbbt_dm/util.py +32 -0
  28. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  29. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  30. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  31. data/test/rbbt/vector/model/test_python.rb +31 -0
  32. data/test/rbbt/vector/model/test_pytorch_lightning.rb +97 -0
  33. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  34. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  35. data/test/rbbt/vector/model/test_torch.rb +61 -0
  36. data/test/rbbt/vector/test_model.rb +25 -26
  37. data/test/test_helper.rb +13 -0
  38. metadata +35 -16
  39. data/lib/rbbt/tensorflow.rb +0 -43
  40. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
4
- data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
3
+ metadata.gz: 9d53609453e1c3bd589c95071569583bff5f11224a200850c0d9e85775e5a2ce
4
+ data.tar.gz: 501f436caff07c990c09adec4caa60d32fe577b35209a65e8f96418ab2acc422
5
5
  SHA512:
6
- metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
7
- data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
6
+ metadata.gz: 4a53ccc8a5eac633e344beffd36073e1828d848d9743552f12e75a54cbda337050152b72fd95d7d97d4aeed50f811dcc8475c97bc292ba9bcdd57e9d32f0e91b
7
+ data.tar.gz: 30fef4ee2c023e141c0ef991ae35d1a0913d4444138fda778760b821ccc48d85d67dc44d4c251d7f84e3b1bbe6ecb6f3b64aee13fbdc443cb33bbed2b34a586a
@@ -3,7 +3,7 @@ require 'rbbt/util/R'
3
3
  class RbbtMatrix
4
4
  def barcode(outfile, factor = 2)
5
5
 
6
- FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
6
+ FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
7
7
  cmd =<<-EOF
8
8
  source('#{Rbbt.share.R['barcode.R'].find}')
9
9
  rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })
@@ -49,7 +49,7 @@ rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.r
49
49
 
50
50
  clusters = Array === clusters ? clusters : (2..clusters).to_a
51
51
 
52
- FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
52
+ FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
53
53
  cmd =<<-EOF
54
54
  source('#{Rbbt.share.R['barcode.R'].find}')
55
55
  rbbt.GE.activity_cluster(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{R.ruby2R key_field}, #{R.ruby2R clusters})
@@ -45,7 +45,7 @@ class RbbtMatrix
45
45
  end
46
46
 
47
47
  file = file.find if Path === file
48
- FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exists? File.dirname(file)
48
+ FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exist? File.dirname(file)
49
49
 
50
50
  cmd = <<-EOS
51
51
 
@@ -71,10 +71,10 @@ end
71
71
 
72
72
 
73
73
  #def self.analyze(datafile, main, contrast = nil, log2 = false, outfile = nil, key_field = nil, two_channel = nil)
74
- # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
74
+ # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
75
75
  # GE.run_R("rbbt.GE.process(#{ R.ruby2R datafile }, main = #{R.ruby2R(main, :strings => true)}, contrast = #{R.ruby2R(contrast, :strings => true)}, log2=#{ R.ruby2R log2 }, outfile = #{R.ruby2R outfile}, key.field = #{R.ruby2R key_field}, two.channel = #{R.ruby2R two_channel})")
76
76
  #end
77
77
  #def self.barcode(datafile, outfile, factor = 2)
78
- # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
78
+ # FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
79
79
  # GE.run_R("rbbt.GE.barcode(#{ R.ruby2R datafile }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })")
80
80
  #end
@@ -13,7 +13,7 @@ class KnowledgeBase
13
13
 
14
14
  return matrix if RbbtMatrix === matrix
15
15
 
16
- Path.setup(matrix) if not Path === matrix and File.exists? matrix
16
+ Path.setup(matrix) if not Path === matrix and File.exist? matrix
17
17
 
18
18
  raise "Registered matrix is strange: #{Misc.fingerprint matrix}" unless Path === matrix
19
19
 
@@ -119,7 +119,7 @@ module BarPlot
119
119
  options = Misc.add_defaults options, :width => [options[:total], MAX_WIDTH].min, :height => 20, :background => PNG::Color::White
120
120
  width, height, background, canvas = Misc.process_options options, :width, :height, :background, :canvas
121
121
 
122
- canvas ||= if options[:update] and options[:filename] and File.exists? options[:filename]
122
+ canvas ||= if options[:update] and options[:filename] and File.exist? options[:filename]
123
123
  PNG.load_file options[:filename]
124
124
  else
125
125
  PNG::Canvas.new width, height, get_color(background)
data/lib/rbbt/stan.rb CHANGED
@@ -130,7 +130,7 @@ print(fit)
130
130
  erase = true
131
131
  end
132
132
 
133
- FileUtils.mkdir_p directory unless File.exists? directory
133
+ FileUtils.mkdir_p directory unless File.exist? directory
134
134
  input_directory = File.join(directory, 'inputs')
135
135
  parameter_chains = File.join(directory, 'chains') unless erase
136
136
  summary = File.join(directory, 'summary') unless erase
@@ -118,7 +118,7 @@ module TSV
118
118
 
119
119
  field_pos = fields.collect{|f| self.fields.index f}.compact
120
120
  persistence_path = self.respond_to?(:persistence_path)? self.persistence_path : nil
121
- Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:background => background, :rename => rename, :persistence_path => persistence_path}) do
121
+ Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:fields => fields, :background => background, :rename => rename, :persistence_path => persistence_path}) do
122
122
  data ||= {}
123
123
 
124
124
  with_unnamed do
@@ -152,6 +152,7 @@ module TSV
152
152
 
153
153
  end
154
154
 
155
+
155
156
  if rename
156
157
  Log.debug("Using renames during annotation counts")
157
158
  Hash[*data.keys.zip(data.values.collect{|l| l.collect{|e| rename.include?(e)? rename[e] : e }.uniq.length }).flatten]
@@ -0,0 +1,50 @@
1
+ require 'rbbt/vector/model/huggingface'
2
class MaskedLMModel < HuggingfaceModel

  # Huggingface model for masked-language-model fine tuning.
  #
  # Training texts contain "[MASK]" placeholders and are paired with a list
  # of label tokens, one entry per placeholder. Each placeholder is expanded
  # into as many mask tokens as ids its label converts to. Label ids are then
  # aligned with the mask positions of the tokenized text; every other
  # position gets -100.
  #
  # @param checkpoint [String] Huggingface checkpoint to start from
  # @param dir [String, nil] model directory (passed to HuggingfaceModel)
  # @param model_options [Hash] model options; :max_length (default 128) is
  #   the length training texts are truncated/padded to
  def initialize(checkpoint, dir = nil, model_options = {})

    model_options = Misc.add_defaults model_options, :max_length => 128
    super("MaskedLM", checkpoint, dir, model_options)

    train_model do |texts,labels|
      model, tokenizer = self.init
      max_length = @model_options[:max_length]
      mask_id = tokenizer.mask_token_id
      # Expand placeholders with the tokenizer's own mask token: not every
      # tokenizer spells it "[MASK]", and a foreign spelling would never
      # tokenize to mask_id (leaving every label at -100)
      mask_token = tokenizer.mask_token || "[MASK]"

      dataset = []
      texts.zip(labels).each do |text,label_values|
        fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
        label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }
        label_tokens.each do |ids|
          ids = [ids] unless Array === ids
          fixed_text.sub!("[PENDINGMASK]", mask_token * ids.length)
        end

        # Pass max_length explicitly; otherwise padding: "max_length" pads to
        # the tokenizer's model maximum and the :max_length option is ignored
        tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length", max_length: max_length)
        input_ids = tokenized_text["input_ids"].to_a
        attention_mask = tokenized_text["attention_mask"].to_a

        all_label_tokens = label_tokens.flatten
        label_ids = input_ids.collect do |id|
          if id == mask_id
            # If there are more mask positions than label ids, mark the
            # extra positions as ignored (-100) instead of emitting nil
            all_label_tokens.shift || -100
          else
            -100
          end
        end
        dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
      end

      # NOTE(review): assumes @directory is set when training
      dataset_file = File.join(@directory, 'dataset.json')
      Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")

      training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
      data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
      RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)

      model.save_pretrained(@model_path) if @model_path
      tokenizer.save_pretrained(@model_path) if @model_path
    end

  end
end
@@ -1,50 +1,36 @@
1
- require 'rbbt/vector/model'
2
- require 'rbbt/util/python'
1
+ require 'rbbt/vector/model/torch'
3
2
 
4
- RbbtPython.add_path Rbbt.python.find(:lib)
5
- RbbtPython.init_rbbt
3
+ class HuggingfaceModel < TorchModel
6
4
 
7
- class HuggingfaceModel < VectorModel
5
+ def initialize(task, checkpoint, dir = nil, model_options = {})
6
+ super(dir, nil, model_options)
8
7
 
9
- def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
8
+ @model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint
10
9
 
11
- if labels
12
- Open.write(tsv_dataset_file) do |ffile|
13
- ffile.puts ["label", "text"].flatten * "\t"
14
- elements.zip(labels).each do |element,label|
15
- ffile.puts [label, element].flatten * "\t"
16
- end
17
- end
18
- else
19
- Open.write(tsv_dataset_file) do |ffile|
20
- ffile.puts ["text"].flatten * "\t"
21
- elements.each{|element| ffile.puts element }
22
- end
23
- end
10
+ init_model do
11
+ checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]
24
12
 
25
- tsv_dataset_file
26
- end
13
+ model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
14
+ @model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:model_args]) || {}))
27
15
 
28
- def initialize(task, checkpoint, *args)
29
- options = args.pop if Hash === args.last
30
- options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
31
- super(*args)
32
- @model_options ||= {}
33
- @model_options.merge!(options)
16
+ tokenizer_checkpoint = @model_options[:tokenizer_checkpoint] || checkpoint
34
17
 
35
- eval_model do |directory,texts|
36
- checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
18
+ tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
19
+ @model_options[:task], tokenizer_checkpoint, **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))
37
20
 
38
- if @model.nil?
39
- @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
40
- end
41
-
42
- if Array === texts
21
+ [model, tokenizer]
22
+ end
23
+
24
+ eval_model do |texts,is_list|
25
+ model, tokenizer = self.init
26
+
27
+ if is_list || @model_options[:task] == "MaskedLM"
28
+ texts = [texts] if ! is_list
43
29
 
44
30
  if @model_options.include?(:locate_tokens)
45
31
  locate_tokens = @model_options[:locate_tokens]
46
32
  elsif @model_options[:task] == "MaskedLM"
47
- @model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
33
+ @model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
48
34
  end
49
35
 
50
36
  if @directory
@@ -57,22 +43,21 @@ class HuggingfaceModel < VectorModel
57
43
  checkpoint_dir = File.join(tmpdir, 'checkpoints')
58
44
  end
59
45
 
60
- dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
46
+ dataset_file = TorchModel.text_dataset(tsv_file, texts)
61
47
  training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
62
48
 
63
49
  begin
64
- RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
50
+ RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, model, tokenizer, training_args_obj, dataset_file, locate_tokens)
65
51
  ensure
66
52
  Open.rm_rf tmpdir if tmpdir
67
53
  end
68
54
  else
69
- RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
55
+ RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
70
56
  end
71
57
  end
72
58
 
73
- train_model do |directory,texts,labels|
74
- checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
75
- @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
59
+ train_model do |texts,labels|
60
+ model, tokenizer = self.init
76
61
 
77
62
  if @directory
78
63
  tsv_file = File.join(@directory, 'dataset.tsv')
@@ -85,17 +70,19 @@ class HuggingfaceModel < VectorModel
85
70
  end
86
71
 
87
72
  training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
88
- dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
73
+ dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, @model_options[:class_labels])
89
74
 
90
- RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
75
+ RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
91
76
 
92
77
  Open.rm_rf tmpdir if tmpdir
93
78
 
94
- @model.save_pretrained(directory) if directory
95
- @tokenizer.save_pretrained(directory) if directory
79
+ model.save_pretrained(@model_path) if @model_path
80
+ tokenizer.save_pretrained(@model_path) if @model_path
96
81
  end
97
82
 
98
- post_process do |result|
83
+ post_process do |result,is_list|
84
+ model, tokenizer = self.init
85
+
99
86
  if result.respond_to?(:predictions)
100
87
  single = false
101
88
  predictions = result.predictions
@@ -137,25 +124,25 @@ class HuggingfaceModel < VectorModel
137
124
  best_token
138
125
  end
139
126
 
140
- best.collect{|b| @tokenizer.decode(b) } * "|"
127
+ best.collect{|b| tokenizer.decode(b) } * "|"
141
128
  end
142
129
  Array === locate_tokens ? item_masks : item_masks.first
143
130
  end
144
131
  else
145
- logits
132
+ predictions
146
133
  end
147
134
 
148
- single ? result.first : result
135
+ (! is_list || single) && Array === result ? result.first : result
149
136
  end
150
137
 
151
138
 
152
- save_models if @directory
139
+ save_models if @model_path
153
140
  end
154
141
 
155
142
  def reset_model
156
143
  @model, @tokenizer = nil
157
- Open.rm @model_file
144
+ Open.rm_rf @model_path
145
+ init
158
146
  end
159
-
160
147
  end
161
148
 
@@ -0,0 +1,33 @@
1
+ require 'rbbt/vector/model'
2
+ require 'rbbt/util/python'
3
+
4
+ RbbtPython.add_path Rbbt.python.find(:lib)
5
+ RbbtPython.init_rbbt
6
+
7
class PythonModel < VectorModel
  attr_accessor :python_class, :python_module

  # Vector model whose predictions come from an instance of a Python class.
  #
  # @param dir [String] model directory; it is added to the Python path
  #   before the Python class is instantiated
  # @param python_class [String, nil] Python class name; when nil, no
  #   init_model block is registered by this constructor
  # @param python_module [Symbol, String, Hash, nil] module holding the
  #   class, :model by default. A Hash given here (with no model_options)
  #   is taken as the model options, and the module defaults to :model.
  # @param model_options [Hash, nil] model options, also forwarded as
  #   keyword arguments to the Python constructor
  def initialize(dir, python_class = nil, python_module = nil, model_options = nil)
    if python_module.nil?
      python_module = :model
    elsif model_options.nil? && Hash === python_module
      # Shorthand form: PythonModel.new(dir, klass, options)
      model_options = python_module
      python_module = :model
    end
    model_options = {} if model_options.nil?

    super(dir, model_options)

    @python_class = python_class
    @python_module = python_module

    if python_class
      init_model do
        RbbtPython.add_path @directory
        RbbtPython.class_new_obj(@python_module, @python_class, **model_options)
      end
    end

    eval_model do |features, list = false|
      init
      # A single item is sent as a one-element batch and unwrapped
      list ? model.eval(features) : model.eval([features])[0]
    end
  end
end
@@ -0,0 +1,31 @@
1
+ require 'rbbt/vector/model/torch'
2
+
3
class PytorchLightningModel < TorchModel
  attr_accessor :loader, :val_loader, :trainer

  # Torch model trained through a pytorch_lightning Trainer.
  #
  # Training uses the loaders assigned through #loader and #val_loader. When
  # no training loader is assigned and features are given, a loader is built
  # from a temporary TSV dataset (TorchModel.feature_dataset +
  # rbbt_dm.tsv). After fitting, the architecture and state are saved when
  # the model has a directory.
  def initialize(...)
    super(...)

    train_model do |features,labels|
      model = init
      train_loader = self.loader
      val_loader = self.val_loader

      if train_loader.nil? && features && features.any?
        TmpFile.with_file do |tsv_dataset_file|
          TorchModel.feature_dataset(tsv_dataset_file, features, labels)
          RbbtPython.pyimport :rbbt_dm
          train_loader = RbbtPython.rbbt_dm.tsv(tsv_dataset_file)
          # Fit while the temporary dataset file still exists: the loader
          # may read it lazily, after with_file has cleaned it up
          trainer.fit(model, train_loader, val_loader)
        end
      else
        trainer.fit(model, train_loader, val_loader)
      end

      TorchModel.save_architecture(model, model_path) if @directory
      TorchModel.save_state(model, model_path) if @directory
    end
  end

  # Memoized pytorch_lightning.Trainer built from the :training_args
  # option (or :trainer_args when that is absent); an empty option hash is
  # used when neither is set.
  def trainer
    @trainer ||= begin
      options = @model_options[:training_args] || @model_options[:trainer_args]
      RbbtPython.class_new_obj("pytorch_lightning", "Trainer", options || {})
    end
  end
end
@@ -27,7 +27,7 @@ label = predict(model, features);
27
27
  def importance
28
28
  TmpFile.with_file do |tmp|
29
29
  tsv = R.run <<-EOF
30
- load(file="#{model_file}");
30
+ load(file="#{@model_path}");
31
31
  rbbt.tsv.write('#{tmp}', model$importance)
32
32
  EOF
33
33
  TSV.open(tmp)
@@ -28,12 +28,12 @@ class SpaCyModel < VectorModel
28
28
 
29
29
  super(dir)
30
30
 
31
- @train_model = Proc.new do |file, features, labels|
31
+ @train_model = Proc.new do |features, labels|
32
32
  texts = features
33
33
  docs = []
34
34
  unique_labels = labels.uniq
35
- tmpconfig = File.join(file, 'config')
36
- tmptrain = File.join(file, 'train.spacy')
35
+ tmpconfig = File.join(@model_path, 'config')
36
+ tmptrain = File.join(@model_path, 'train.spacy')
37
37
  SpaCy.config(@config, tmpconfig)
38
38
 
39
39
  bar = bar(features.length, "Training documents into spacy format")
@@ -54,20 +54,22 @@ class SpaCyModel < VectorModel
54
54
  end
55
55
 
56
56
  gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
57
- CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{file} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
57
+ CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{@model_path} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
58
58
  end
59
59
 
60
- @eval_model = Proc.new do |file, features,list|
60
+ @eval_model = Proc.new do |features,list|
61
61
  texts = features
62
62
  texts = [texts] unless list
63
63
 
64
+ model_path = @model_path
65
+
64
66
  docs = []
65
67
  bar = bar(features.length, "Evaluating model")
66
68
  SpaCyModel.spacy do
67
69
  gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
68
70
  gpu = gpu.to_i if gpu && gpu != ""
69
71
  spacy.require_gpu(gpu) if gpu
70
- nlp = spacy.load("#{file}/model-best")
72
+ nlp = spacy.load("#{model_path}/model-best")
71
73
 
72
74
  docs = nlp.pipe(texts)
73
75
  RbbtPython.collect docs, :bar => bar do |d|
@@ -25,7 +25,7 @@ class TensorFlowModel < VectorModel
25
25
 
26
26
  super(dir)
27
27
 
28
- @train_model = Proc.new do |file, features, labels|
28
+ @train_model = Proc.new do |features, labels|
29
29
  tensorflow do
30
30
  features = tensorflow.convert_to_tensor(features)
31
31
  labels = tensorflow.convert_to_tensor(labels)
@@ -33,16 +33,17 @@ class TensorFlowModel < VectorModel
33
33
  @graph ||= keras_graph
34
34
  @graph.compile(**@compile_options)
35
35
  @graph.fit(features, labels, :epochs => @epochs, :verbose => true)
36
- @graph.save(file)
36
+ @graph.save(@model_path)
37
37
  end
38
38
 
39
- @eval_model = Proc.new do |file, features|
39
+ @eval_model = Proc.new do |features|
40
40
  tensorflow do
41
41
  features = tensorflow.convert_to_tensor(features)
42
42
  end
43
+ model_path = @model_path
44
+ graph = @graph ||= keras.models.load_model(model_path)
43
45
  keras do
44
- @graph ||= keras.models.load_model(file)
45
- indices = @graph.predict(features, :verbose => false).tolist()
46
+ indices = graph.predict(features, :verbose => false).tolist()
46
47
  labels = indices.collect{|p| p.length > 1 ? p.index(p.max): p.first }
47
48
  labels
48
49
  end
@@ -0,0 +1,58 @@
1
class TorchModel
  # Build a TSV of examples keyed by their position in elements.
  #
  # @param elements [Array] one entry per example; an Array entry is used as
  #   the feature vector, any other value as a single feature
  # @param labels [Array, nil] labels aligned with elements; when given, a
  #   "label" field is added and each label is appended to its features
  # @param class_labels [Array, Hash, nil] optional label encoding. With an
  #   Array, each label becomes its index in the array. With a Hash, each
  #   label becomes the key whose value it is. Otherwise labels are kept.
  # @return [TSV] keyed by "ID", first field "features", type :flat
  def self.feature_tsv(elements, labels = nil, class_labels = nil)
    tsv = TSV.setup({}, :key_field => "ID", :fields => ["features"], :type => :flat)

    unless labels
      elements.each_with_index do |features, index|
        tsv[index] = Array === features ? features : [features]
      end
      return tsv
    end

    tsv.fields = tsv.fields + ["label"]

    encoded_labels =
      if Array === class_labels
        labels.collect { |label| class_labels.index label }
      elsif Hash === class_labels
        label_to_class = {}
        class_labels.each { |klass, label| label_to_class[label] = klass }
        labels.collect { |label| label_to_class[label] }
      else
        labels
      end

    elements.zip(encoded_labels).each_with_index do |(features, label), index|
      tsv[index] = Array === features ? features + [label] : [features, label]
    end

    tsv
  end

  # Write feature_tsv(elements, labels, class_labels) to tsv_dataset_file.
  #
  # @return [String] tsv_dataset_file
  def self.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
    Open.write(tsv_dataset_file, feature_tsv(elements, labels, class_labels).to_s)
    tsv_dataset_file
  end

  # Write a text dataset to tsv_dataset_file.
  #
  # Newlines inside each text are replaced by spaces and the first field is
  # named "text". The TSV type is :single without labels and :list with them.
  #
  # @return [String] tsv_dataset_file
  def self.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
    one_line_texts = elements.collect { |text| text.gsub("\n", ' ') }
    tsv = feature_tsv(one_line_texts, labels, class_labels)
    tsv.fields[0] = "text"
    tsv.type = labels.nil? ? :single : :list
    Open.write(tsv_dataset_file, tsv.to_s)
    tsv_dataset_file
  end

end
@@ -0,0 +1,52 @@
1
class TorchModel
  # Mixin for Python tensor objects handed to Ruby.
  module Tensor
    # Convert the tensor through RbbtPython.numpy2ruby.
    def to_ruby
      RbbtPython.numpy2ruby(self)
    end

    # Extend obj with this module.
    def self.setup(obj)
      obj.extend Tensor
    end
  end

  # Import the Python modules the Torch helpers use: torch, rbbt, rbbt_dm,
  # rbbt_dm.util and torch.nn.
  def self.init_python
    RbbtPython.pyimport :torch
    RbbtPython.pyimport :rbbt
    RbbtPython.pyimport :rbbt_dm
    RbbtPython.pyfrom :rbbt_dm, import: :util
    RbbtPython.pyfrom :torch, import: :nn
  end

  # SGD optimizer over model.parameters().
  #
  # @param training_args [Hash, nil] :learning_rate (default 0.01); nil is
  #   treated as an empty hash
  def self.optimizer(model, training_args)
    learning_rate = (training_args || {})[:learning_rate] || 0.01
    RbbtPython.torch.optim.SGD.new(model.parameters(), lr: learning_rate)
  end

  # Resolve the torch device from model_options[:device].
  #
  # A String or Symbol builds torch.device(name). nil asks
  # rbbt_dm.util.device(). Any other value is returned as given.
  def self.device(model_options)
    case model_options[:device]
    when String, Symbol
      RbbtPython.torch.device(model_options[:device].to_s)
    when nil
      RbbtPython.rbbt_dm.util.device()
    else
      model_options[:device]
    end
  end

  # Resolve the torch dtype from model_options[:dtype].
  #
  # A String or Symbol names an attribute of the torch module: "float32",
  # :float32 and "torch.float32" all give torch.float32. nil means no
  # explicit dtype. Any other value is returned as given.
  def self.dtype(model_options)
    case model_options[:dtype]
    when String, Symbol
      # dtypes are attributes of the torch module; the module itself is not
      # callable, so it must not be invoked with the name
      name = model_options[:dtype].to_s.sub(/\Atorch\./, '')
      PyCall.getattr(RbbtPython.torch, name.to_sym)
    when nil
      nil
    else
      model_options[:dtype]
    end
  end

  # Build a torch tensor from obj with the given dtype and device.
  def self.tensor(obj, device, dtype)
    RbbtPython.torch.tensor(obj, dtype: dtype, device: device)
  end

end
@@ -0,0 +1,31 @@
1
class TorchModel
  # Follow a dotted attribute path (e.g. "encoder.layer") from model.
  # Without a path, the model itself is returned.
  def self.get_layer(model, layer = nil)
    return model if layer.nil?
    layer.split(".").inject(model) { |current, name| PyCall.getattr(current, name.to_sym) }
  end

  # Instance shortcut for TorchModel.get_layer on this object's model.
  def get_layer(...)
    TorchModel.get_layer(model, ...)
  end

  # The `weight` attribute of a layer, extended with TorchModel::Tensor.
  def self.get_weights(model, layer = nil)
    weights = PyCall.getattr(get_layer(model, layer), :weight)
    Tensor.setup weights
  end

  # Instance shortcut for TorchModel.get_weights on this object's model.
  def get_weights(...)
    TorchModel.get_weights(model, ...)
  end

  # Set requires_grad = false on the `weight` of layer and, recursively, of
  # all its children. Only `weight` is touched (not `bias`). Any error
  # raised while setting it, e.g. for a layer without a weight, is ignored.
  def self.freeze(layer)
    begin
      PyCall.getattr(layer, :weight).requires_grad = false
    rescue StandardError
      # best effort: not every layer has a weight to freeze
    end
    RbbtPython.iterate(layer.children) { |child| freeze(child) }
  end

  # Freeze the layer found at the dotted path inside model.
  def self.freeze_layer(model, layer)
    freeze(get_layer(model, layer))
  end

  # Instance shortcut for TorchModel.freeze_layer on this object's model.
  def freeze_layer(...)
    TorchModel.freeze_layer(model, ...)
  end

end
@@ -0,0 +1,30 @@
1
class TorchModel
  # File holding the saved model architecture: model_path + '.architecture'.
  def self.model_architecture(model_path)
    model_path + '.architecture'
  end

  # Save model.state_dict() to model_path with torch.save.
  def self.save_state(model, model_path)
    Log.debug "Saving model state into #{model_path}"
    RbbtPython.torch.save(model.state_dict(), model_path)
  end

  # Load the state stored at model_path into model, when that file exists.
  #
  # @return the model, whether or not a state was loaded
  def self.load_state(model, model_path)
    if Open.exists?(model_path)
      Log.debug "Loading model state from #{model_path}"
      model.load_state_dict(RbbtPython.torch.load(model_path))
    end
    model
  end

  # Save the whole model object next to model_path (see model_architecture).
  def self.save_architecture(model, model_path)
    architecture_file = model_architecture(model_path)
    Log.debug "Saving model architecture into #{architecture_file}"
    RbbtPython.torch.save(model, architecture_file)
  end

  # Load the model object saved by save_architecture.
  #
  # @return the loaded model, or nil when no architecture file exists
  def self.load_architecture(model_path)
    architecture_file = model_architecture(model_path)
    return nil unless Open.exists?(architecture_file)
    Log.debug "Loading model architecture from #{architecture_file}"
    RbbtPython.torch.load(architecture_file)
  end
end