rbbt-dm 1.2.7 → 1.2.10

Files changed (40)
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +39 -52
  10. data/lib/rbbt/vector/model/python.rb +33 -0
  11. data/lib/rbbt/vector/model/pytorch_lightning.rb +31 -0
  12. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  13. data/lib/rbbt/vector/model/spaCy.rb +8 -6
  14. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  15. data/lib/rbbt/vector/model/torch/dataloader.rb +58 -0
  16. data/lib/rbbt/vector/model/torch/helpers.rb +52 -0
  17. data/lib/rbbt/vector/model/torch/introspection.rb +31 -0
  18. data/lib/rbbt/vector/model/torch/load_and_save.rb +30 -0
  19. data/lib/rbbt/vector/model/torch.rb +71 -0
  20. data/lib/rbbt/vector/model.rb +84 -54
  21. data/python/rbbt_dm/__init__.py +31 -1
  22. data/python/rbbt_dm/atcold/__init__.py +0 -0
  23. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  24. data/python/rbbt_dm/atcold/spiral.py +27 -0
  25. data/python/rbbt_dm/huggingface.py +64 -28
  26. data/python/rbbt_dm/language_model.py +70 -0
  27. data/python/rbbt_dm/util.py +32 -0
  28. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  29. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  30. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  31. data/test/rbbt/vector/model/test_python.rb +31 -0
  32. data/test/rbbt/vector/model/test_pytorch_lightning.rb +97 -0
  33. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  34. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  35. data/test/rbbt/vector/model/test_torch.rb +61 -0
  36. data/test/rbbt/vector/test_model.rb +25 -26
  37. data/test/test_helper.rb +13 -0
  38. metadata +35 -16
  39. data/lib/rbbt/tensorflow.rb +0 -43
  40. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 1c55843bf543c88167239f6e182495963e0683c5a7fdd7c3a7ab9bd501a78bc8
-  data.tar.gz: d01aaf45331766eac6d868749b8df72c49d1a6888f44f7a1d4f8cbfefe258c87
+  metadata.gz: 9d53609453e1c3bd589c95071569583bff5f11224a200850c0d9e85775e5a2ce
+  data.tar.gz: 501f436caff07c990c09adec4caa60d32fe577b35209a65e8f96418ab2acc422
 SHA512:
-  metadata.gz: 7b6a225ce0403759ab45f26d371d491c19fc76f6560771868a58b9de921fd3aa03750bd7aec95c34029f61f53e71e382958f2779ca790fde30958cfbd1169a0b
-  data.tar.gz: ae1b6d44072398fbde96a0cb31f9586076dee1a5c7e2ac32726c65ecaaa3d08b59ea627c7a0f9f4a8e87547d5a403452ea5bee1d0736d610bf73b6456cb99be9
+  metadata.gz: 4a53ccc8a5eac633e344beffd36073e1828d848d9743552f12e75a54cbda337050152b72fd95d7d97d4aeed50f811dcc8475c97bc292ba9bcdd57e9d32f0e91b
+  data.tar.gz: 30fef4ee2c023e141c0ef991ae35d1a0913d4444138fda778760b821ccc48d85d67dc44d4c251d7f84e3b1bbe6ecb6f3b64aee13fbdc443cb33bbed2b34a586a
data/lib/rbbt/matrix/barcode.rb CHANGED
@@ -1,7 +1,7 @@ require 'rbbt/util/R'
 class RbbtMatrix
   def barcode(outfile, factor = 2)
 
-    FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
+    FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
 
     cmd =<<-EOF
 source('#{Rbbt.share.R['barcode.R'].find}')
 rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })
@@ -49,7 +49,7 @@ rbbt.GE.barcode.mode(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{ R.r
 
     clusters = Array === clusters ? clusters : (2..clusters).to_a
 
-    FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
+    FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
 
     cmd =<<-EOF
 source('#{Rbbt.share.R['barcode.R'].find}')
 rbbt.GE.activity_cluster(#{ R.ruby2R self.data_file }, #{ R.ruby2R outfile }, #{R.ruby2R key_field}, #{R.ruby2R clusters})
data/lib/rbbt/matrix/differential.rb CHANGED
@@ -45,7 +45,7 @@ class RbbtMatrix
   end
 
   file = file.find if Path === file
-  FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exists? File.dirname(file)
+  FileUtils.mkdir_p File.dirname(file) unless file.nil? or File.exist? File.dirname(file)
 
   cmd = <<-EOS
 
@@ -71,10 +71,10 @@ end
 
 
 #def self.analyze(datafile, main, contrast = nil, log2 = false, outfile = nil, key_field = nil, two_channel = nil)
-#  FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
+#  FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
 #  GE.run_R("rbbt.GE.process(#{ R.ruby2R datafile }, main = #{R.ruby2R(main, :strings => true)}, contrast = #{R.ruby2R(contrast, :strings => true)}, log2=#{ R.ruby2R log2 }, outfile = #{R.ruby2R outfile}, key.field = #{R.ruby2R key_field}, two.channel = #{R.ruby2R two_channel})")
 #end
 #def self.barcode(datafile, outfile, factor = 2)
-#  FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exists? File.dirname(outfile)
+#  FileUtils.mkdir_p File.dirname(outfile) unless outfile.nil? or File.exist? File.dirname(outfile)
 #  GE.run_R("rbbt.GE.barcode(#{ R.ruby2R datafile }, #{ R.ruby2R outfile }, #{ R.ruby2R factor })")
 #end
data/lib/rbbt/matrix/knowledge_base.rb CHANGED
@@ -13,7 +13,7 @@ class KnowledgeBase
 
     return matrix if RbbtMatrix === matrix
 
-    Path.setup(matrix) if not Path === matrix and File.exists? matrix
+    Path.setup(matrix) if not Path === matrix and File.exist? matrix
 
     raise "Registered matrix is strange: #{Misc.fingerprint matrix}" unless Path === matrix
 
data/lib/rbbt/plots/bar.rb CHANGED
@@ -119,7 +119,7 @@ module BarPlot
     options = Misc.add_defaults options, :width => [options[:total], MAX_WIDTH].min, :height => 20, :background => PNG::Color::White
     width, height, background, canvas = Misc.process_options options, :width, :height, :background, :canvas
 
-    canvas ||= if options[:update] and options[:filename] and File.exists? options[:filename]
+    canvas ||= if options[:update] and options[:filename] and File.exist? options[:filename]
                  PNG.load_file options[:filename]
                else
                  PNG::Canvas.new width, height, get_color(background)
data/lib/rbbt/stan.rb CHANGED
@@ -130,7 +130,7 @@ print(fit)
       erase = true
     end
 
-    FileUtils.mkdir_p directory unless File.exists? directory
+    FileUtils.mkdir_p directory unless File.exist? directory
     input_directory = File.join(directory, 'inputs')
     parameter_chains = File.join(directory, 'chains') unless erase
     summary = File.join(directory, 'summary') unless erase
data/lib/rbbt/statistics/hypergeometric.rb CHANGED
@@ -118,7 +118,7 @@ module TSV
 
     field_pos = fields.collect{|f| self.fields.index f}.compact
     persistence_path = self.respond_to?(:persistence_path)? self.persistence_path : nil
-    Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:background => background, :rename => rename, :persistence_path => persistence_path}) do
+    Persist.persist(filename, :yaml, :fields => fields, :persist => persistence, :prefix => "Hyp.Geo.Counts", :other => {:fields => fields, :background => background, :rename => rename, :persistence_path => persistence_path}) do
       data ||= {}
 
       with_unnamed do
@@ -152,6 +152,7 @@ module TSV
 
     end
 
+
     if rename
       Log.debug("Using renames during annotation counts")
       Hash[*data.keys.zip(data.values.collect{|l| l.collect{|e| rename.include?(e)? rename[e] : e }.uniq.length }).flatten]
data/lib/rbbt/vector/model/huggingface/masked_lm.rb ADDED
@@ -0,0 +1,50 @@
+require 'rbbt/vector/model/huggingface'
+class MaskedLMModel < HuggingfaceModel
+
+  def initialize(checkpoint, dir = nil, model_options = {})
+
+    model_options = Misc.add_defaults model_options, :max_length => 128
+    super("MaskedLM", checkpoint, dir, model_options)
+
+    train_model do |texts,labels|
+      model, tokenizer = self.init
+      max_length = @model_options[:max_length]
+      mask_id = tokenizer.mask_token_id
+
+      dataset = []
+      texts.zip(labels).each do |text,label_values|
+        fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
+        label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }
+        label_tokens.each do |ids|
+          ids = [ids] unless Array === ids
+          fixed_text.sub!("[PENDINGMASK]", "[MASK]" * ids.length)
+        end
+
+        tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length")
+        input_ids = tokenized_text["input_ids"].to_a
+        attention_mask = tokenized_text["attention_mask"].to_a
+
+        all_label_tokens = label_tokens.flatten
+        label_ids = input_ids.collect do |id|
+          if id == mask_id
+            all_label_tokens.shift
+          else
+            -100
+          end
+        end
+        dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
+      end
+
+      dataset_file = File.join(@directory, 'dataset.json')
+      Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")
+
+      training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
+      data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
+      RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)
+
+      model.save_pretrained(@model_path) if @model_path
+      tokenizer.save_pretrained(@model_path) if @model_path
+    end
+
+  end
+end
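A minimal usage sketch of the new MaskedLMModel, assuming the usual VectorModel add/train workflow; the checkpoint, directory and example text are hypothetical. Each label array supplies the tokens that fill the corresponding [MASK] slots during fine-tuning.

require 'rbbt/vector/model/huggingface/masked_lm'

# Hypothetical checkpoint and directory
model = MaskedLMModel.new "bert-base-uncased", "/tmp/masked_lm",
  :training_args => {:num_train_epochs => 1}

model.add "Paris is the [MASK] of France.", %w(capital)   # one token per [MASK]
model.train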
data/lib/rbbt/vector/model/huggingface.rb CHANGED
@@ -1,50 +1,36 @@
-require 'rbbt/vector/model'
-require 'rbbt/util/python'
+require 'rbbt/vector/model/torch'
 
-RbbtPython.add_path Rbbt.python.find(:lib)
-RbbtPython.init_rbbt
+class HuggingfaceModel < TorchModel
 
-class HuggingfaceModel < VectorModel
+  def initialize(task, checkpoint, dir = nil, model_options = {})
+    super(dir, nil, model_options)
 
-  def self.tsv_dataset(tsv_dataset_file, elements, labels = nil)
+    @model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint
 
-    if labels
-      Open.write(tsv_dataset_file) do |ffile|
-        ffile.puts ["label", "text"].flatten * "\t"
-        elements.zip(labels).each do |element,label|
-          ffile.puts [label, element].flatten * "\t"
-        end
-      end
-    else
-      Open.write(tsv_dataset_file) do |ffile|
-        ffile.puts ["text"].flatten * "\t"
-        elements.each{|element| ffile.puts element }
-      end
-    end
+    init_model do
+      checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]
 
-    tsv_dataset_file
-  end
+      model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
+        @model_options[:task], checkpoint, **(IndiferentHash.setup(model_options[:model_args]) || {}))
 
-  def initialize(task, checkpoint, *args)
-    options = args.pop if Hash === args.last
-    options = Misc.add_defaults options, :task => task, :checkpoint => checkpoint
-    super(*args)
-    @model_options ||= {}
-    @model_options.merge!(options)
+      tokenizer_checkpoint = @model_options[:tokenizer_checkpoint] || checkpoint
 
-    eval_model do |directory,texts|
-      checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
+      tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
+        @model_options[:task], tokenizer_checkpoint, **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))
 
-      if @model.nil?
-        @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
-      end
-
-      if Array === texts
+      [model, tokenizer]
+    end
+
+    eval_model do |texts,is_list|
+      model, tokenizer = self.init
+
+      if is_list || @model_options[:task] == "MaskedLM"
+        texts = [texts] if ! is_list
 
       if @model_options.include?(:locate_tokens)
         locate_tokens = @model_options[:locate_tokens]
       elsif @model_options[:task] == "MaskedLM"
-        @model_options[:locate_tokens] = locate_tokens = @tokenizer.special_tokens_map["mask_token"]
+        @model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
       end
 
       if @directory
@@ -57,22 +43,21 @@ class HuggingfaceModel < VectorModel
         checkpoint_dir = File.join(tmpdir, 'checkpoints')
       end
 
-      dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
+      dataset_file = TorchModel.text_dataset(tsv_file, texts)
       training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
 
       begin
-        RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, @model, @tokenizer, training_args_obj, dataset_file, locate_tokens)
+        RbbtPython.call_method("rbbt_dm.huggingface", :predict_model, model, tokenizer, training_args_obj, dataset_file, locate_tokens)
       ensure
         Open.rm_rf tmpdir if tmpdir
       end
     else
-      RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, @model, @tokenizer, [texts], locate_tokens)
+      RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
     end
   end
 
-  train_model do |directory,texts,labels|
-    checkpoint = directory && File.directory?(directory) ? directory : @model_options[:checkpoint]
-    @model, @tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_model_and_tokenizer, @model_options[:task], checkpoint)
+    train_model do |texts,labels|
+      model, tokenizer = self.init
 
     if @directory
       tsv_file = File.join(@directory, 'dataset.tsv')
@@ -85,17 +70,19 @@ class HuggingfaceModel < VectorModel
     end
 
     training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, checkpoint_dir, @model_options[:training_args])
-    dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels)
+    dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, @model_options[:class_labels])
 
-    RbbtPython.call_method("rbbt_dm.huggingface", :train_model, @model, @tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
+    RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
 
     Open.rm_rf tmpdir if tmpdir
 
-    @model.save_pretrained(directory) if directory
-    @tokenizer.save_pretrained(directory) if directory
+    model.save_pretrained(@model_path) if @model_path
+    tokenizer.save_pretrained(@model_path) if @model_path
   end
 
-  post_process do |result|
+  post_process do |result,is_list|
+    model, tokenizer = self.init
+
     if result.respond_to?(:predictions)
       single = false
       predictions = result.predictions
@@ -137,25 +124,25 @@ class HuggingfaceModel < VectorModel
           best_token
         end
 
-        best.collect{|b| @tokenizer.decode(b) } * "|"
+        best.collect{|b| tokenizer.decode(b) } * "|"
       end
       Array === locate_tokens ? item_masks : item_masks.first
     end
   else
-    logits
+    predictions
   end
 
-  single ? result.first : result
+  (! is_list || single) && Array === result ? result.first : result
 end
 
 
-    save_models if @directory
+    save_models if @model_path
   end
 
   def reset_model
     @model, @tokenizer = nil
-    Open.rm @model_file
+    Open.rm_rf @model_path
+    init
   end
-
 end
 
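A usage sketch of the reworked HuggingfaceModel, assuming the VectorModel eval / eval_list interface; checkpoint, directory and labels are hypothetical. The task and checkpoint now travel in model_options, and the model/tokenizer pair is built lazily by init_model.

require 'rbbt/vector/model/huggingface'

model = HuggingfaceModel.new "SequenceClassification", "distilbert-base-uncased",
  "/tmp/sentiment", :class_labels => %w(negative positive)

model.eval "This is great"                        # single text
model.eval_list ["This is great", "This is bad"]  # list of texts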
data/lib/rbbt/vector/model/python.rb ADDED
@@ -0,0 +1,33 @@
+require 'rbbt/vector/model'
+require 'rbbt/util/python'
+
+RbbtPython.add_path Rbbt.python.find(:lib)
+RbbtPython.init_rbbt
+
+class PythonModel < VectorModel
+  attr_accessor :python_class, :python_module
+  def initialize(dir, python_class = nil, python_module = nil, model_options = nil)
+    python_module = :model if python_module.nil?
+    model_options, python_module = python_module, :model if model_options.nil? && Hash === python_module
+    model_options = {} if model_options.nil?
+
+    super(dir, model_options)
+
+    @python_class = python_class
+    @python_module = python_module
+
+    init_model do
+      RbbtPython.add_path @directory
+      RbbtPython.class_new_obj(@python_module, @python_class, **model_options)
+    end if python_class
+
+    eval_model do |features,list=false|
+      init
+      if list
+        model.eval(features)
+      else
+        model.eval([features])[0]
+      end
+    end
+  end
+end
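A sketch of how the new PythonModel is meant to be used (assumed, not taken from the gem's tests): the model directory is added to the Python path and python_class is instantiated from python_module, which defaults to model (a model.py next to the saved model).

require 'rbbt/vector/model/python'

# Hypothetical: /tmp/python_model/model.py defines class MyModel with an eval(features) method
model = PythonModel.new "/tmp/python_model", "MyModel"
model.eval [1.0, 2.0, 3.0]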
data/lib/rbbt/vector/model/pytorch_lightning.rb ADDED
@@ -0,0 +1,31 @@
+require 'rbbt/vector/model/torch'
+
+class PytorchLightningModel < TorchModel
+  attr_accessor :loader, :val_loader, :trainer
+  def initialize(...)
+    super(...)
+
+    train_model do |features,labels|
+      model = init
+      loader = self.loader
+      val_loader = self.val_loader
+      if (features && features.any?) && loader.nil?
+        TmpFile.with_file do |tsv_dataset_file|
+          TorchModel.feature_dataset(tsv_dataset_file, features, labels)
+          RbbtPython.pyimport :rbbt_dm
+          loader = RbbtPython.rbbt_dm.tsv(tsv_dataset_file)
+        end
+      end
+      trainer.fit(model, loader, val_loader)
+      TorchModel.save_architecture(model, model_path) if @directory
+      TorchModel.save_state(model, model_path) if @directory
+    end
+  end
+
+  def trainer
+    @trainer ||= begin
+                   options = @model_options[:training_args] || @model_options[:trainer_args]
+                   RbbtPython.class_new_obj("pytorch_lightning", "Trainer", options || {})
+                 end
+  end
+end
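A hedged sketch of PytorchLightningModel usage; the constructor arguments are assumed to follow TorchModel's dir/class/module signature, and SpiralModel plus the spiral module are hypothetical. Loaders can be assigned directly; otherwise one is built from the added features via TorchModel.feature_dataset.

require 'rbbt/vector/model/pytorch_lightning'

model = PytorchLightningModel.new "/tmp/lightning", "SpiralModel", "spiral",
  :training_args => {:max_epochs => 5}
model.loader = my_train_loader      # optional: hand pytorch-lightning DataLoaders directly
model.val_loader = my_val_loader
model.train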
data/lib/rbbt/vector/model/random_forest.rb CHANGED
@@ -27,7 +27,7 @@ label = predict(model, features);
   def importance
     TmpFile.with_file do |tmp|
       tsv = R.run <<-EOF
-load(file="#{model_file}");
+load(file="#{@model_path}");
 rbbt.tsv.write('#{tmp}', model$importance)
 EOF
       TSV.open(tmp)
data/lib/rbbt/vector/model/spaCy.rb CHANGED
@@ -28,12 +28,12 @@ class SpaCyModel < VectorModel
 
     super(dir)
 
-    @train_model = Proc.new do |file, features, labels|
+    @train_model = Proc.new do |features, labels|
       texts = features
       docs = []
       unique_labels = labels.uniq
-      tmpconfig = File.join(file, 'config')
-      tmptrain = File.join(file, 'train.spacy')
+      tmpconfig = File.join(@model_path, 'config')
+      tmptrain = File.join(@model_path, 'train.spacy')
       SpaCy.config(@config, tmpconfig)
 
       bar = bar(features.length, "Training documents into spacy format")
@@ -54,20 +54,22 @@ class SpaCyModel < VectorModel
       end
 
       gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
-      CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{file} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
+      CMD.cmd_log(:spacy, "train #{tmpconfig} --output #{@model_path} --paths.train #{tmptrain} --paths.dev #{tmptrain}", "--gpu-id" => gpu)
     end
 
-    @eval_model = Proc.new do |file, features,list|
+    @eval_model = Proc.new do |features,list|
       texts = features
       texts = [texts] unless list
 
+      model_path = @model_path
+
       docs = []
       bar = bar(features.length, "Evaluating model")
       SpaCyModel.spacy do
         gpu = Rbbt::Config.get('gpu_id', :spacy, :spacy_train, :default => 0)
         gpu = gpu.to_i if gpu && gpu != ""
         spacy.require_gpu(gpu) if gpu
-        nlp = spacy.load("#{file}/model-best")
+        nlp = spacy.load("#{model_path}/model-best")
 
         docs = nlp.pipe(texts)
         RbbtPython.collect docs, :bar => bar do |d|
data/lib/rbbt/vector/model/tensorflow.rb CHANGED
@@ -25,7 +25,7 @@ class TensorFlowModel < VectorModel
 
     super(dir)
 
-    @train_model = Proc.new do |file, features, labels|
+    @train_model = Proc.new do |features, labels|
       tensorflow do
         features = tensorflow.convert_to_tensor(features)
         labels = tensorflow.convert_to_tensor(labels)
@@ -33,16 +33,17 @@ class TensorFlowModel < VectorModel
       @graph ||= keras_graph
       @graph.compile(**@compile_options)
       @graph.fit(features, labels, :epochs => @epochs, :verbose => true)
-      @graph.save(file)
+      @graph.save(@model_path)
     end
 
-    @eval_model = Proc.new do |file, features|
+    @eval_model = Proc.new do |features|
       tensorflow do
         features = tensorflow.convert_to_tensor(features)
       end
+      model_path = @model_path
+      graph = @graph ||= keras.models.load_model(model_path)
       keras do
-        @graph ||= keras.models.load_model(file)
-        indices = @graph.predict(features, :verbose => false).tolist()
+        indices = graph.predict(features, :verbose => false).tolist()
         labels = indices.collect{|p| p.length > 1 ? p.index(p.max): p.first }
         labels
       end
data/lib/rbbt/vector/model/torch/dataloader.rb ADDED
@@ -0,0 +1,58 @@
+class TorchModel
+  def self.feature_tsv(elements, labels = nil, class_labels = nil)
+    tsv = TSV.setup({}, :key_field => "ID", :fields => ["features"], :type => :flat)
+    if labels
+      tsv.fields = tsv.fields + ["label"]
+      labels = case class_labels
+               when Array
+                 labels.collect{|l| class_labels.index l}
+               when Hash
+                 inverse_class_labels = {}
+                 class_labels.each{|c,l| inverse_class_labels[l] = c }
+                 labels.collect{|l| inverse_class_labels[l]}
+               else
+                 labels
+               end
+      elements.zip(labels).each_with_index do |p,i|
+        features, label = p
+        id = i
+        if Array === features
+          tsv[id] = features + [label]
+        else
+          tsv[id] = [features, label]
+        end
+      end
+    else
+      elements.each_with_index do |features,i|
+        id = i
+        if Array === features
+          tsv[id] = features
+        else
+          tsv[id] = [features]
+        end
+      end
+    end
+    tsv
+  end
+
+  def self.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
+    tsv = feature_tsv(elements, labels, class_labels)
+    Open.write(tsv_dataset_file, tsv.to_s)
+    tsv_dataset_file
+  end
+
+  def self.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
+    elements = elements.collect{|e| e.gsub("\n", ' ') }
+    tsv = feature_tsv(elements, labels, class_labels)
+    if labels.nil?
+      tsv.fields[0] = "text"
+      tsv.type = :single
+    else
+      tsv.fields[0] = "text"
+      tsv.type = :list
+    end
+    Open.write(tsv_dataset_file, tsv.to_s)
+    tsv_dataset_file
+  end
+
+end
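To illustrate the TSV layout these helpers produce (the file path and values below are arbitrary): when class_labels is an Array, string labels are stored as their index in that array.

TorchModel.feature_dataset "/tmp/dataset.tsv",
  [[0.1, 0.2], [0.9, 0.8]],   # features
  ["low", "high"],            # labels
  ["low", "high"]             # class_labels: stored as 0 and 1
# Writes one row per element: ID, the feature values, and the label index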
data/lib/rbbt/vector/model/torch/helpers.rb ADDED
@@ -0,0 +1,52 @@
+class TorchModel
+  module Tensor
+    def to_ruby
+      RbbtPython.numpy2ruby(self)
+    end
+    def self.setup(obj)
+      obj.extend Tensor
+    end
+  end
+
+  def self.init_python
+    RbbtPython.pyimport :torch
+    RbbtPython.pyimport :rbbt
+    RbbtPython.pyimport :rbbt_dm
+    RbbtPython.pyfrom :rbbt_dm, import: :util
+    RbbtPython.pyfrom :torch, import: :nn
+  end
+
+  def self.optimizer(model, training_args)
+    begin
+      learning_rate = training_args[:learning_rate] || 0.01
+      RbbtPython.torch.optim.SGD.new(model.parameters(), lr: learning_rate)
+    end
+  end
+
+  def self.device(model_options)
+    case model_options[:device]
+    when String, Symbol
+      RbbtPython.torch.device(model_options[:device].to_s)
+    when nil
+      RbbtPython.rbbt_dm.util.device()
+    else
+      model_options[:device]
+    end
+  end
+
+  def self.dtype(model_options)
+    case model_options[:dtype]
+    when String, Symbol
+      RbbtPython.torch.call(model_options[:dtype])
+    when nil
+      nil
+    else
+      model_options[:dtype]
+    end
+  end
+
+  def self.tensor(obj, device, dtype)
+    RbbtPython.torch.tensor(obj, dtype: dtype, device: device)
+  end
+
+end
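A short sketch of how these helpers combine, assuming torch and rbbt_dm.util are importable through RbbtPython; the values are arbitrary.

TorchModel.init_python
device = TorchModel.device(:device => "cpu")     # omit :device to auto-detect
dtype  = TorchModel.dtype({})                    # nil unless :dtype is given
t = TorchModel.tensor([[1.0, 2.0], [3.0, 4.0]], device, dtype)
TorchModel::Tensor.setup(t)
t.to_ruby                                        # plain nested Ruby arrays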
data/lib/rbbt/vector/model/torch/introspection.rb ADDED
@@ -0,0 +1,31 @@
+class TorchModel
+  def self.get_layer(model, layer = nil)
+    if layer.nil?
+      model
+    else
+      layer.split(".").inject(model){|acc,l| PyCall.getattr(acc, l.to_sym) }
+    end
+  end
+  def get_layer(...); TorchModel.get_layer(model, ...); end
+
+  def self.get_weights(model, layer = nil)
+    Tensor.setup PyCall.getattr(get_layer(model, layer), :weight)
+  end
+  def get_weights(...); TorchModel.get_weights(model, ...); end
+
+  def self.freeze(layer)
+    begin
+      PyCall.getattr(layer, :weight).requires_grad = false
+    rescue
+    end
+    RbbtPython.iterate(layer.children) do |layer|
+      freeze(layer)
+    end
+  end
+  def self.freeze_layer(model, layer)
+    layer = get_layer(model, layer)
+    freeze(layer)
+  end
+  def freeze_layer(...); TorchModel.freeze_layer(model, ...); end
+
+end
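For example (a sketch; the layer name "linear" is hypothetical and depends on the underlying torch model), weights can be inspected as Ruby arrays and a named layer frozen so training only updates the rest.

weights = TorchModel.get_weights(model, "linear")   # a Tensor extended with to_ruby
weights.to_ruby
TorchModel.freeze_layer(model, "linear")            # sets requires_grad = false recursively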
data/lib/rbbt/vector/model/torch/load_and_save.rb ADDED
@@ -0,0 +1,30 @@
+class TorchModel
+  def self.model_architecture(model_path)
+    model_path + '.architecture'
+  end
+
+  def self.save_state(model, model_path)
+    Log.debug "Saving model state into #{model_path}"
+    RbbtPython.torch.save(model.state_dict(), model_path)
+  end
+
+  def self.load_state(model, model_path)
+    return model unless Open.exists?(model_path)
+    Log.debug "Loading model state from #{model_path}"
+    model.load_state_dict(RbbtPython.torch.load(model_path))
+    model
+  end
+
+  def self.save_architecture(model, model_path)
+    model_architecture = model_architecture(model_path)
+    Log.debug "Saving model architecture into #{model_architecture}"
+    RbbtPython.torch.save(model, model_architecture)
+  end
+
+  def self.load_architecture(model_path)
+    model_architecture = model_architecture(model_path)
+    return unless Open.exists?(model_architecture)
+    Log.debug "Loading model architecture from #{model_architecture}"
+    RbbtPython.torch.load(model_architecture)
+  end
+end
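A sketch of the save/load round trip these methods implement (model and model_path are hypothetical): the architecture is stored next to the state file under the '.architecture' suffix.

model_path = "/tmp/torch/model"
TorchModel.save_architecture(model, model_path)   # writes /tmp/torch/model.architecture
TorchModel.save_state(model, model_path)          # writes the state_dict to /tmp/torch/model
restored = TorchModel.load_architecture(model_path)
TorchModel.load_state(restored, model_path) if restored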