rbbt-dm 1.2.6 → 1.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +57 -38
  10. data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
  11. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  12. data/lib/rbbt/vector/model/spaCy.rb +8 -14
  13. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  14. data/lib/rbbt/vector/model/torch.rb +37 -0
  15. data/lib/rbbt/vector/model/util.rb +18 -0
  16. data/lib/rbbt/vector/model.rb +100 -56
  17. data/python/rbbt_dm/__init__.py +48 -1
  18. data/python/rbbt_dm/atcold/__init__.py +0 -0
  19. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  20. data/python/rbbt_dm/atcold/spiral.py +27 -0
  21. data/python/rbbt_dm/huggingface.py +57 -26
  22. data/python/rbbt_dm/language_model.py +70 -0
  23. data/python/rbbt_dm/util.py +30 -0
  24. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  25. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  26. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  27. data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
  28. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  29. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  30. data/test/rbbt/vector/test_model.rb +25 -26
  31. data/test/test_helper.rb +13 -0
  32. metadata +26 -16
  33. data/lib/rbbt/tensorflow.rb +0 -43
  34. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
data/test/rbbt/vector/model/test_huggingface.rb

@@ -3,7 +3,7 @@ require 'rbbt/vector/model/huggingface'
 
 class TestHuggingface < Test::Unit::TestCase
 
-  def test_options
+  def _test_options
     TmpFile.with_file do |dir|
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
       task = "SequenceClassification"
@@ -11,20 +11,20 @@ class TestHuggingface < Test::Unit::TestCase
       model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
       iii model.eval "This is dog"
       iii model.eval "This is cat"
-      iii model.eval(["This is dog", "This is cat"])
+      iii model.eval_list(["This is dog", "This is cat"])
 
       model = VectorModel.new dir
-      iii model.eval(["This is dog", "This is cat"])
+      iii model.eval_list(["This is dog", "This is cat"])
     end
   end
 
-  def test_pipeline
+  def _test_pipeline
     require 'rbbt/util/python'
     model = VectorModel.new
     model.post_process do |elements|
       elements.collect{|e| e['label'] }
     end
-    model.eval_model do |file, elements|
+    model.eval_model do |elements|
       RbbtPython.run :transformers do
         classifier ||= transformers.pipeline("sentiment-analysis")
         classifier.call(elements)
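Taken together, the two hunks above show the central API change in this release: `eval` is now reserved for a single element and `eval_list` for batches. A minimal sketch of the resulting calling convention, using only calls that appear in these tests, with return values as the assertions below expect them:

    model = HuggingfaceModel.new "SequenceClassification",
      "distilbert-base-uncased-finetuned-sst-2-english"
    model.model_options[:class_labels] = %w(Bad Good)

    model.eval "This is dog"                        # single element => "Bad"
    model.eval_list ["This is dog", "This is cat"]  # batch => ["Bad", "Good"]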
@@ -33,21 +33,53 @@ class TestHuggingface < Test::Unit::TestCase
 
     assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
   end
+
+  def _test_tokenizer_size
+    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+    tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
+                                       "MaskedLM", checkpoint, :max_length => 5, :model_max_length => 5)
+    assert_equal 5, tokenizer.call("This is a sentence that has several words", truncation: true, max_length: 5)["input_ids"].__len__
+    assert_equal 5, tokenizer.call("This is a sentence that has several words", truncation: true)["input_ids"].__len__
+  end
 
-  def test_sst_eval
+  def _test_sst_eval
     TmpFile.with_file do |dir|
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
-      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
+      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir, :tokenizer_args => {:max_length => 16}
 
       model.model_options[:class_labels] = ["Bad", "Good"]
 
-      assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
+      assert_equal "Bad", model.eval("This is dog")
+      assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
     end
   end
 
 
   def test_sst_train
+    TmpFile.with_file do |dir|
+      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+
+      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir, max_length: 128
+
+      model.model_options[:class_labels] = %w(Bad Good)
+
+      assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
+
+      100.times do
+        model.add "Dog is good", "Good"
+      end
+
+      model.train
+
+      assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
+
+      model = VectorModel.new dir
+      assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
+    end
+  end
+
+  def _test_sst_train_with_labels
     TmpFile.with_file do |dir|
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
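The `_test_tokenizer_size` and `_test_sst_eval` hunks above also introduce tokenizer configuration. A short sketch of the two entry points they exercise; the option names are exactly those in the tests:

    # Options forwarded to the tokenizer when building a model
    model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir,
      :tokenizer_args => {:max_length => 16}

    # Loading a tokenizer directly through the rbbt_dm python bridge;
    # per the assertions above, model_max_length makes truncation apply
    # even when no explicit max_length is passed to the call
    tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
      "MaskedLM", checkpoint, :max_length => 5, :model_max_length => 5)
    tokenizer.call("This is a sentence that has several words",
      truncation: true)["input_ids"].__len__  # => 5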
@@ -55,28 +87,29 @@ class TestHuggingface < Test::Unit::TestCase
 
       model.model_options[:class_labels] = %w(Bad Good)
 
-      assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
+      assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
 
       100.times do
-        model.add "Dog is good", 1
+        model.add "Dog is good", "Good"
       end
 
       model.train
 
-      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
+      assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
 
       model = VectorModel.new dir
-      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
+      assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
     end
   end
 
-  def test_sst_train_no_save
+
+  def _test_sst_train_no_save
     checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
     model = HuggingfaceModel.new "SequenceClassification", checkpoint
     model.model_options[:class_labels] = ["Bad", "Good"]
 
-    assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
+    assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
 
     100.times do
       model.add "Dog is good", 1
@@ -84,48 +117,50 @@ class TestHuggingface < Test::Unit::TestCase
 
     model.train
 
-    assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
+    assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
   end
 
-  def test_sst_train_save_and_load
+  def _test_sst_train_save_and_load
     TmpFile.with_file do |dir|
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
       model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
       model.model_options[:class_labels] = ["Bad", "Good"]
 
-      assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
+      assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
 
       100.times do
-        model.add "Dog is good", 1
+        model.add "Dog is good", "Good"
      end
 
      model.train
 
      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
 
-      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
+      assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
 
-      model_file = model.model_file
+      model_path = model.model_path
 
-      model = HuggingfaceModel.new "SequenceClassification", model_file
+      model = HuggingfaceModel.new "SequenceClassification", model_path
      model.model_options[:class_labels] = ["Bad", "Good"]
 
-      assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
+      assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
 
      model = VectorModel.new dir
 
-      assert_equal "Good", model.eval("This is dog")
+      assert_equal "Good", model.eval_list("This is dog")
 
    end
  end
 
-  def test_sst_stress_test
+  def _test_sst_stress_test
    TmpFile.with_file do |dir|
      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
 
+      assert_equal 0, model.eval("This is dog")
+
      100.times do
        model.add "Dog is good", 1
        model.add "Cat is bad", 0
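The train-and-reload cycle exercised by `_test_sst_train_save_and_load` condenses to the sketch below; all calls are as in the test. Note that a trained directory can later be reopened as a generic VectorModel, which restores the saved model class and options:

    model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
    model.model_options[:class_labels] = %w(Bad Good)

    100.times { model.add "Dog is good", "Good" }  # queue labeled examples
    model.train                                    # fine-tune and save under dir

    restored = VectorModel.new dir                 # reload later, generically
    restored.eval_list(["This is dog", "This is cat"])  # => ["Good", "Good"]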
@@ -136,18 +171,214 @@ class TestHuggingface < Test::Unit::TestCase
     end
 
     Misc.benchmark 1000 do
-      model.eval(["This is good", "This is terrible", "This is dog", "This is cat", "Very different stuff", "Dog is bad", "Cat is good"])
+      model.eval_list(["This is good", "This is terrible", "This is dog", "This is cat", "Very different stuff", "Dog is bad", "Cat is good"])
     end
   end
 
-  def test_mask_eval
+  def _test_mask_eval
     checkpoint = "bert-base-uncased"
 
     model = HuggingfaceModel.new "MaskedLM", checkpoint
-    assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
+    assert_equal 3, model.eval_list(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
       reject{|v| v.empty?}.length
   end
 
+  def _test_mask_eval_tokenizer
+    checkpoint = "bert-base-uncased"
+
+    model = HuggingfaceModel.new "MaskedLM", checkpoint
+
+    mod, tokenizer = model.init
+
+    orig = tokenizer.call("Hi [GENE]")["input_ids"]
+    tokenizer.add_tokens(["[GENE]"])
+    mod.resize_token_embeddings(tokenizer.__len__)
+    new = tokenizer.call("Hi [GENE]")["input_ids"]
+
+    assert orig.length > new.length
+  end
+
+
+  def _test_custom_class
+    TmpFile.with_file do |dir|
+      Open.write File.join(dir, "mypkg/__init__.py"), ""
+
+      Open.write File.join(dir, "mypkg/mymodel.py"), <<~EOF
+
+        # This class mirrors RobertaForTokenClassification
+        # Import the required modules
+        import torch.nn as nn
+        from transformers import RobertaConfig
+        from transformers.modeling_outputs import TokenClassifierOutput
+        from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
+
+        # Define a class that inherits from RobertaPreTrainedModel
+        class RobertaForTokenClassification_NER(RobertaPreTrainedModel):
+            config_class = RobertaConfig
+
+            def __init__(self, config):
+                # Initializes the underlying Roberta model
+                super().__init__(config)
+                # Number of labels to classify (the number of corpus labels * 2:
+                # one for the I tag and one for the B tag).
+                self.num_labels = config.num_labels
+                # No pooling layer, so the hidden states of every token are returned (not just CLS)
+                self.roberta = RobertaModel(config, add_pooling_layer=False)
+                self.dropout = nn.Dropout(config.hidden_dropout_prob)
+                self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+                self.init_weights()
+
+            def forward(self, input_ids = None, attention_mask = None, token_type_ids = None, labels = None,
+                        **kwargs):
+                # Encode the input (the hidden states)
+                outputs = self.roberta(input_ids, attention_mask = attention_mask,
+                                       token_type_ids = token_type_ids, **kwargs)
+
+                # Apply the dropout layer to the hidden states,
+                # then the classification layer.
+                sequence_output = self.dropout(outputs[0])
+                logits = self.classifier(sequence_output)
+                # If labels are given (as they are during training), compute the loss
+                # used to adjust the weights during backprop.
+                loss = None
+                if labels is not None:
+                    loss_fct = nn.CrossEntropyLoss()
+                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+                return TokenClassifierOutput(loss=loss, logits=logits,
+                                             hidden_states=outputs.hidden_states,
+                                             attentions=outputs.attentions)
+      EOF
+
+      RbbtPython.add_path dir
+
+      biomedical_roberta = "PlanTL-GOB-ES/bsc-bio-ehr-es-cantemist"
+      model = HuggingfaceModel.new "mypkg.mymodel:RobertaForTokenClassification_NER", biomedical_roberta
+
+      model.post_process do |result,is_list|
+        if is_list
+          RbbtPython.numpy2ruby result.predictions
+        else
+          result["logits"][0]
+        end
+      end
+
+      texto = "El paciente tiene un cáncer del pulmon"
+      assert model.eval(texto)[5][1] > 0
+    end
+  end
+
+  def _test_sst_train_word_embeddings
+    TmpFile.with_file do |dir|
+      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+
+      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
+      model.model_options[:class_labels] = %w(Bad Good)
+
+      mod, tokenizer = model.init
+
+      orig = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
+      orig = RbbtPython.numpy2ruby(orig.cpu.detach.numpy)
+
+      100.times do
+        model.add "Dog is good", "Good"
+      end
+
+      model.train
+
+      new = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
+      new = RbbtPython.numpy2ruby(new.cpu.detach.numpy)
+
+      diff = []
+      new.each_with_index do |row,i|
+        diff << i if row != orig[i]
+      end
+
+      assert diff.length > 0
+    end
+  end
+
+  def _test_sst_freeze_word_embeddings
+    TmpFile.with_file do |dir|
+      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+
+      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
+      model.model_options[:class_labels] = %w(Bad Good)
+
+      mod, tokenizer = model.init
+
+      layer = HuggingfaceModel.freeze_layer(mod, 'distilbert')
+
+      orig = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
+      orig = RbbtPython.numpy2ruby(orig.cpu.detach.numpy)
+
+      100.times do
+        model.add "Dog is good", "Good"
+      end
+
+      model.train
+
+      new = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
+      new = RbbtPython.numpy2ruby(new.cpu.detach.numpy)
+
+      diff = []
+      new.each_with_index do |row,i|
+        diff << i if row != orig[i]
+      end
+
+      assert diff.length == 0
+    end
+  end
+
+  def _test_sst_save_word_embeddings
+    TmpFile.with_file do |dir|
+      checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+
+      model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
+      model.model_options[:class_labels] = %w(Bad Good)
+
+      mod, tokenizer = model.init
+
+      100.times do
+        model.add "Dog is good", "Good"
+      end
+
+      model.train
+
+      orig = RbbtPython.numpy2ruby(
+        HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings').cpu.detach.numpy)
+
+      model = HuggingfaceModel.new "MaskedLM", checkpoint, dir
+
+      mod, tokenizer = model.init
+
+      new = RbbtPython.numpy2ruby(
+        HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings').cpu.detach.numpy)
+
+      diff = []
+      new.each_with_index do |row,i|
+        diff << i if row != orig[i]
+      end
+
+      assert diff.length == 0
+
+      model = HuggingfaceModel.new "MaskedLM", checkpoint
+
+      mod, tokenizer = model.init
+
+      new = RbbtPython.numpy2ruby(
+        HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings').cpu.detach.numpy)
+
+      diff = []
+      new.each_with_index do |row,i|
+        diff << i if row != orig[i]
+      end
+
+      assert diff.length > 0
+    end
+  end
 end
 
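The three word-embedding tests above rely on a pair of new helpers. A sketch of how they combine, with names taken verbatim from the tests (the parameter path shown is DistilBERT-specific):

    mod, tokenizer = model.init  # python model and tokenizer proxies

    # Exclude a submodule from gradient updates; per the freeze test,
    # its embeddings are unchanged after #train
    HuggingfaceModel.freeze_layer(mod, 'distilbert')

    # Read a named parameter tensor and bring it into Ruby for comparison
    weights = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
    rows = RbbtPython.numpy2ruby(weights.cpu.detach.numpy)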
data/test/rbbt/vector/model/test_pytorch_lightning.rb

@@ -0,0 +1,83 @@
+require File.join(File.expand_path(File.dirname(__FILE__)), '../../..', 'test_helper.rb')
+require 'rbbt/vector/model/pytorch_lightning'
+
+class TestPytorchLightning < Test::Unit::TestCase
+  def test_clustering
+    nsamples = 10
+    ngenes = 10000
+    samples = nsamples.times.collect{|i| "Sample-#{i}" }
+    data = TSV.setup({}, :key_field => "Gene", :fields => samples + ["cluster"], :type => :list, :cast => :to_f)
+
+    profiles = []
+    p0 = 3
+    p1 = 7
+    profiles[0] = nsamples.times.collect{ rand() + p0 }
+    profiles[1] = nsamples.times.collect{ rand() + p1 }
+
+    ngenes.times do |genen|
+      gene = "Gene-#{genen}"
+      cluster = genen % 2
+      values = profiles[cluster].collect do |m|
+        rand() + m
+      end
+      data[gene] = values + [cluster]
+    end
+
+    python = <<~EOF
+      import torch
+      from torch import nn
+      from torch.nn import functional as F
+      from torch.utils.data import DataLoader
+      from torch.utils.data import random_split
+      from torchvision.datasets import MNIST
+      from torchvision import transforms
+      import pytorch_lightning as pl
+
+      class TestPytorchLightningModel(pl.LightningModule):
+          def __init__(self, input_size=10, internal_dim=1):
+              super().__init__()
+              self.model = nn.Tanh()
+
+          def configure_optimizers(self):
+              optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+              return optimizer
+
+          @torch.cuda.amp.autocast(True)
+          def forward(self, x):
+              x = x.to(self.dtype)
+              return self.model(x).squeeze()
+
+          @torch.cuda.amp.autocast(True)
+          def training_step(self, train_batch, batch_idx):
+              x, y = train_batch
+              x = x.to(self.dtype)
+              y = y.to(self.dtype)
+              y_hat = self.model(x).squeeze()
+              loss = F.mse_loss(y, y_hat)
+              self.log('train_loss', loss)
+              return loss
+
+          @torch.cuda.amp.custom_fwd(cast_inputs=torch.float64)
+          def validation_step(self, val_batch, batch_idx):
+              x, y = val_batch
+              y_hat = self.model(x)
+              loss = F.mse_loss(y, y_hat)
+              self.log('val_loss', loss)
+
+    EOF
+
+    with_python(python) do |pkg|
+      model = PytorchLightningModel.new pkg, "TestPytorchLightningModel", nil, model_args: {internal_dim: 1}
+      TmpFile.with_file(data.to_s) do |data_file|
+        ds = RbbtPython.call_method "rbbt_dm", :tsv, filename: data_file
+        model.loader = RbbtPython.class_new_obj("torch.utils.data", :DataLoader, dataset: ds, batch_size: 64)
+        model.trainer = RbbtPython.class_new_obj("pytorch_lightning", "Trainer", gpus: 1, max_epochs: 5, precision: 16)
+      end
+      model.train
+      encoding = model.eval_list(data.values.collect{|v| v[0..-2] }).detach().cpu().numpy()
+      iii encoding[0..10]
+    end
+  end
+
+end
+
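The TSV-to-DataLoader bridge in this test comes from the `rbbt_dm.tsv` helper (added in `data/python/rbbt_dm/__init__.py`, per the file list above). A minimal sketch of just that bridge, assuming only the calls visible in the test:

    TmpFile.with_file(data.to_s) do |data_file|
      # rbbt_dm.tsv loads the dumped TSV as a torch-compatible dataset
      ds = RbbtPython.call_method "rbbt_dm", :tsv, filename: data_file
      model.loader = RbbtPython.class_new_obj("torch.utils.data", :DataLoader,
        dataset: ds, batch_size: 64)
    end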
data/test/rbbt/vector/model/test_spaCy.rb

@@ -100,7 +100,7 @@ class TestSpaCyModel < Test::Unit::TestCase
     )
 
 
-    Rbbt::Config.set 'gpu_id', nil, :spacy
+    Rbbt::Config.set 'gpu_id', 0, :spacy
     require 'rbbt/tsv/csv'
     url = "https://raw.githubusercontent.com/hanzhang0420/Women-Clothing-E-commerce/master/Womens%20Clothing%20E-Commerce%20Reviews.csv"
     tsv = TSV.csv(Open.open(url))
data/test/rbbt/vector/model/test_tensorflow.rb

@@ -1,5 +1,6 @@
 require File.join(File.expand_path(File.dirname(__FILE__)), '../../..', 'test_helper.rb')
 require 'rbbt/vector/model/tensorflow'
+require 'rbbt/util/python'
 
 class TestTensorflowModel < Test::Unit::TestCase
 
@@ -10,6 +11,7 @@ class TestTensorflowModel < Test::Unit::TestCase
 
     model = TensorFlowModel.new(
       dir,
+      jit_compile: true,
       optimizer: 'adam',
       loss: 'sparse_categorical_crossentropy',
       metrics: ['accuracy']
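For context, the constructor assembled from this hunk now reads as below; whether jit_compile is forwarded to Keras compile (where TensorFlow defines that flag) or handled elsewhere is not visible in this diff:

    model = TensorFlowModel.new(
      dir,
      jit_compile: true,   # new in this release
      optimizer: 'adam',
      loss: 'sparse_categorical_crossentropy',
      metrics: ['accuracy']
    )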
@@ -53,5 +55,6 @@ class TestTensorflowModel < Test::Unit::TestCase
       assert sum.to_f / predictions.length > 0.7
     end
   end
+
 end
 
data/test/rbbt/vector/test_model.rb

@@ -26,7 +26,7 @@ class TestVectorModel < Test::Unit::TestCase
       element.split(";")
     }
 
-    model.train_model = Proc.new{|model_file,features,labels|
+    model.train_model = Proc.new{|features,labels|
       TmpFile.with_file do |feature_file|
         Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
         Open.write(feature_file + '.class', labels * "\n")
@@ -36,23 +36,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
 features = cbind(features, class = labels);
 rbbt.require('e1071')
 model = svm(class ~ ., data = features)
-save(model, file="#{ model_file }");
+save(model, file="#{ @model_path }");
         EOF
       end
     }
 
-    model.eval_model = Proc.new{|model_file,features|
+    model.eval_model = Proc.new{|features|
       TmpFile.with_file do |feature_file|
         TmpFile.with_file do |results|
           Open.write(feature_file, features * "\t")
-          puts R.run(<<-EOF
+          R.run <<-EOF
 features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
 library(e1071)
-load(file="#{ model_file }")
+load(file="#{ @model_path }")
 label = predict(model, features);
 cat(label, file="#{results}");
           EOF
-          ).read
+
           Open.read(results)
         end
       end
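Across the rest of this file the custom-hook signature changes the same way: `train_model` and `eval_model` procs no longer receive the model file as their first argument, and the save location is reached as `@model_path`, which suggests the procs are now instance-evaluated against the model. A condensed sketch of the new shape (the R bodies are elided):

    model.train_model = Proc.new{|features, labels|
      # fit on features/labels, then persist to @model_path
    }

    model.eval_model = Proc.new{|features|
      # load @model_path and return predictions
    }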
@@ -96,7 +96,7 @@ cat(label, file="#{results}");
       end
     }
 
-    model.train_model = Proc.new{|model_file,features,labels|
+    model.train_model = Proc.new{|features,labels|
       TmpFile.with_file do |feature_file|
         Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
         Open.write(feature_file + '.class', labels * "\n")
@@ -106,23 +106,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
 features = cbind(features, class = labels);
 rbbt.require('e1071')
 model = svm(class ~ ., data = features)
-save(model, file="#{ model_file }");
+save(model, file="#{ @model_path }");
         EOF
       end
     }
 
-    model.eval_model = Proc.new{|model_file,features|
+    model.eval_model = Proc.new{|features|
       TmpFile.with_file do |feature_file|
         TmpFile.with_file do |results|
           Open.write(feature_file, features * "\t")
-          puts R.run(<<-EOF
+          R.run <<-EOF
 features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
 library(e1071)
-load(file="#{ model_file }")
+load(file="#{ @model_path }")
 label = predict(model, features);
 cat(label, file="#{results}");
           EOF
-          ).read
+
           Open.read(results)
         end
       end
@@ -164,7 +164,7 @@ cat(label, file="#{results}");
       element.split(";")
     }
 
-    model.train_model = Proc.new{|model_file,features,labels|
+    model.train_model = Proc.new{|features,labels|
       TmpFile.with_file do |feature_file|
         Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
         Open.write(feature_file + '.class', labels * "\n")
@@ -174,23 +174,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
 features = cbind(features, class = labels);
 rbbt.require('e1071')
 model = svm(class ~ ., data = features)
-save(model, file="#{ model_file }");
+save(model, file="#{ @model_path }");
         EOF
       end
     }
 
-    model.eval_model = Proc.new{|model_file,features|
+    model.eval_model = Proc.new{|features|
       TmpFile.with_file do |feature_file|
         TmpFile.with_file do |results|
           Open.write(feature_file, features * "\t")
-          puts R.run(<<-EOF
+          R.run <<-EOF
 features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
 library(e1071)
-load(file="#{ model_file }")
+load(file="#{ @model_path }")
 label = predict(model, features);
 cat(label, file="#{results}");
           EOF
-          ).read
+
           Open.read(results)
         end
       end
@@ -236,7 +236,7 @@ cat(label, file="#{results}");
       end
     }
 
-    model.train_model = Proc.new{|model_file,features,labels|
+    model.train_model = Proc.new{|features,labels|
       TmpFile.with_file do |feature_file|
         Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
         Open.write(feature_file + '.class', labels * "\n")
@@ -246,23 +246,23 @@ labels = scan("#{ feature_file }.class", what=numeric());
 features = cbind(features, label = labels);
 rbbt.require('e1071')
 model = svm(label ~ ., data = features)
-save(model, file="#{ model_file }");
+save(model, file="#{ @model_path }");
         EOF
       end
     }
 
-    model.eval_model = Proc.new{|model_file,features|
+    model.eval_model = Proc.new{|features|
       TmpFile.with_file do |feature_file|
         TmpFile.with_file do |results|
           Open.write(feature_file, features * "\t")
-          puts R.run(<<-EOF
+          R.run <<-EOF
 features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=FALSE);
 library(e1071)
-load(file="#{ model_file }")
+load(file="#{ @model_path }")
 label = predict(model, features);
 cat(label, file="#{results}");
           EOF
-          ).read
+
           Open.read(results)
         end
       end
@@ -282,7 +282,6 @@ cat(label, file="#{results}");
       model.add features, label
     end
 
-    iii model.eval("1;1;1")
     assert model.eval("1;1;1").to_f > 0.5
     assert model.eval("0;0;0").to_f < 0.5
   end
@@ -515,7 +514,7 @@ label = predict(model, features);
     TmpFile.with_file do |dir|
       model = VectorModel.new dir
 
-      model.eval_model do |file, elements|
+      model.eval_model do |elements|
        elements = [elements] unless Array === elements
        RbbtPython.binding_run do
          pyimport :torch