rbbt-dm 1.2.7 → 1.2.10

Sign up to get free protection for your applications and to get access to all the features.
Files changed (40) hide show
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +39 -52
  10. data/lib/rbbt/vector/model/python.rb +33 -0
  11. data/lib/rbbt/vector/model/pytorch_lightning.rb +31 -0
  12. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  13. data/lib/rbbt/vector/model/spaCy.rb +8 -6
  14. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  15. data/lib/rbbt/vector/model/torch/dataloader.rb +58 -0
  16. data/lib/rbbt/vector/model/torch/helpers.rb +52 -0
  17. data/lib/rbbt/vector/model/torch/introspection.rb +31 -0
  18. data/lib/rbbt/vector/model/torch/load_and_save.rb +30 -0
  19. data/lib/rbbt/vector/model/torch.rb +71 -0
  20. data/lib/rbbt/vector/model.rb +84 -54
  21. data/python/rbbt_dm/__init__.py +31 -1
  22. data/python/rbbt_dm/atcold/__init__.py +0 -0
  23. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  24. data/python/rbbt_dm/atcold/spiral.py +27 -0
  25. data/python/rbbt_dm/huggingface.py +64 -28
  26. data/python/rbbt_dm/language_model.py +70 -0
  27. data/python/rbbt_dm/util.py +32 -0
  28. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  29. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  30. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  31. data/test/rbbt/vector/model/test_python.rb +31 -0
  32. data/test/rbbt/vector/model/test_pytorch_lightning.rb +97 -0
  33. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  34. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  35. data/test/rbbt/vector/model/test_torch.rb +61 -0
  36. data/test/rbbt/vector/test_model.rb +25 -26
  37. data/test/test_helper.rb +13 -0
  38. metadata +35 -16
  39. data/lib/rbbt/tensorflow.rb +0 -43
  40. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
@@ -3,7 +3,7 @@ require 'rbbt/vector/model/huggingface'
3
3
 
4
4
  class TestHuggingface < Test::Unit::TestCase
5
5
 
6
- def test_options
6
+ def _test_options
7
7
  TmpFile.with_file do |dir|
8
8
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
9
9
  task = "SequenceClassification"
@@ -11,20 +11,20 @@ class TestHuggingface < Test::Unit::TestCase
11
11
  model = HuggingfaceModel.new task, checkpoint, dir, :class_labels => %w(bad good)
12
12
  iii model.eval "This is dog"
13
13
  iii model.eval "This is cat"
14
- iii model.eval(["This is dog", "This is cat"])
14
+ iii model.eval_list(["This is dog", "This is cat"])
15
15
 
16
16
  model = VectorModel.new dir
17
- iii model.eval(["This is dog", "This is cat"])
17
+ iii model.eval_list(["This is dog", "This is cat"])
18
18
  end
19
19
  end
20
20
 
21
- def test_pipeline
21
+ def _test_pipeline
22
22
  require 'rbbt/util/python'
23
23
  model = VectorModel.new
24
24
  model.post_process do |elements|
25
25
  elements.collect{|e| e['label'] }
26
26
  end
27
- model.eval_model do |file, elements|
27
+ model.eval_model do |elements|
28
28
  RbbtPython.run :transformers do
29
29
  classifier ||= transformers.pipeline("sentiment-analysis")
30
30
  classifier.call(elements)
@@ -33,21 +33,53 @@ class TestHuggingface < Test::Unit::TestCase
33
33
 
34
34
  assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
35
35
  end
36
+
37
+ def _test_tokenizer_size
38
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
39
+ tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
40
+ "MaskedLM", checkpoint, :max_length => 5, :model_max_length => 5)
41
+ assert_equal 5, tokenizer.call("This is a sentence that has several words", truncation: true, max_length: 5)["input_ids"].__len__
42
+ assert_equal 5, tokenizer.call("This is a sentence that has several words", truncation: true)["input_ids"].__len__
43
+ end
36
44
 
37
45
  def test_sst_eval
38
46
  TmpFile.with_file do |dir|
39
47
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
40
48
 
41
- model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
49
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir, :tokenizer_args => {:max_length => 16}
42
50
 
43
51
  model.model_options[:class_labels] = ["Bad", "Good"]
44
52
 
45
- assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
53
+ assert_equal "Bad", model.eval("This is dog")
54
+ assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
46
55
  end
47
56
  end
48
57
 
49
58
 
50
- def test_sst_train
59
+ def _test_sst_train
60
+ TmpFile.with_file do |dir|
61
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
62
+
63
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir, max_length: 128
64
+
65
+ model.model_options[:class_labels] = %w(Bad Good)
66
+
67
+ assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
68
+
69
+ 100.times do
70
+ model.add "Dog is good", "Good"
71
+ end
72
+
73
+ model.train
74
+
75
+ assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
76
+
77
+ model = VectorModel.new dir
78
+ assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
79
+ end
80
+ end
81
+
82
+ def _test_sst_train_with_labels
51
83
  TmpFile.with_file do |dir|
52
84
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
53
85
 
@@ -55,28 +87,29 @@ class TestHuggingface < Test::Unit::TestCase
55
87
 
56
88
  model.model_options[:class_labels] = %w(Bad Good)
57
89
 
58
- assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
90
+ assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
59
91
 
60
92
  100.times do
61
- model.add "Dog is good", 1
93
+ model.add "Dog is good", "Good"
62
94
  end
63
95
 
64
96
  model.train
65
97
 
66
- assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
98
+ assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
67
99
 
68
100
  model = VectorModel.new dir
69
- assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
101
+ assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
70
102
  end
71
103
  end
72
104
 
73
- def test_sst_train_no_save
105
+
106
+ def _test_sst_train_no_save
74
107
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
75
108
 
76
109
  model = HuggingfaceModel.new "SequenceClassification", checkpoint
77
110
  model.model_options[:class_labels] = ["Bad", "Good"]
78
111
 
79
- assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
112
+ assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
80
113
 
81
114
  100.times do
82
115
  model.add "Dog is good", 1
@@ -84,48 +117,50 @@ class TestHuggingface < Test::Unit::TestCase
84
117
 
85
118
  model.train
86
119
 
87
- assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
120
+ assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
88
121
  end
89
122
 
90
- def test_sst_train_save_and_load
123
+ def _test_sst_train_save_and_load
91
124
  TmpFile.with_file do |dir|
92
125
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
93
126
 
94
127
  model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
95
128
  model.model_options[:class_labels] = ["Bad", "Good"]
96
129
 
97
- assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
130
+ assert_equal ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"])
98
131
 
99
132
  100.times do
100
- model.add "Dog is good", 1
133
+ model.add "Dog is good", "Good"
101
134
  end
102
135
 
103
136
  model.train
104
137
 
105
138
  model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
106
139
 
107
- assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
140
+ assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
108
141
 
109
- model_file = model.model_file
142
+ model_path = model.model_path
110
143
 
111
- model = HuggingfaceModel.new "SequenceClassification", model_file
144
+ model = HuggingfaceModel.new "SequenceClassification", model_path
112
145
  model.model_options[:class_labels] = ["Bad", "Good"]
113
146
 
114
- assert_equal ["Good", "Good"], model.eval(["This is dog", "This is cat"])
147
+ assert_equal ["Good", "Good"], model.eval_list(["This is dog", "This is cat"])
115
148
 
116
149
  model = VectorModel.new dir
117
150
 
118
- assert_equal "Good", model.eval("This is dog")
151
+ assert_equal "Good", model.eval_list("This is dog")
119
152
 
120
153
  end
121
154
  end
122
155
 
123
- def test_sst_stress_test
156
+ def _test_sst_stress_test
124
157
  TmpFile.with_file do |dir|
125
158
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
126
159
 
127
160
  model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
128
161
 
162
+ assert_equal 0, model.eval("This is dog")
163
+
129
164
  100.times do
130
165
  model.add "Dog is good", 1
131
166
  model.add "Cat is bad", 0
@@ -136,18 +171,214 @@ class TestHuggingface < Test::Unit::TestCase
136
171
  end
137
172
 
138
173
  Misc.benchmark 1000 do
139
- model.eval(["This is good", "This is terrible", "This is dog", "This is cat", "Very different stuff", "Dog is bad", "Cat is good"])
174
+ model.eval_list(["This is good", "This is terrible", "This is dog", "This is cat", "Very different stuff", "Dog is bad", "Cat is good"])
140
175
  end
141
176
  end
142
177
  end
143
178
 
144
- def test_mask_eval
179
+ def _test_mask_eval
145
180
  checkpoint = "bert-base-uncased"
146
181
 
147
182
  model = HuggingfaceModel.new "MaskedLM", checkpoint
148
- assert_equal 3, model.eval(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
183
+ assert_equal 3, model.eval_list(["Paris is the [MASK] of the France.", "The [MASK] worked very hard all the time.", "The [MASK] arrested the dangerous [MASK]."]).
149
184
  reject{|v| v.empty?}.length
150
185
  end
151
186
 
187
+ def _test_mask_eval_tokenizer
188
+ checkpoint = "bert-base-uncased"
189
+
190
+ model = HuggingfaceModel.new "MaskedLM", checkpoint
191
+
192
+ mod, tokenizer = model.init
193
+
194
+ orig = tokenizer.call("Hi [GENE]")["input_ids"]
195
+ tokenizer.add_tokens(["[GENE]"])
196
+ mod.resize_token_embeddings(tokenizer.__len__)
197
+ new = tokenizer.call("Hi [GENE]")["input_ids"]
198
+
199
+ assert orig.length > new.length
200
+ end
201
+
202
+
203
+ def _test_custom_class
204
+ TmpFile.with_file do |dir|
205
+ Open.write File.join(dir, "mypkg/__init__.py"), ""
206
+
207
+ Open.write File.join(dir, "mypkg/mymodel.py"), <<~EOF
208
+
209
+ # Esta clase es igual que la de RobertaForTokenClassification
210
+ # Importamos los métodos necesarios
211
+ import torch.nn as nn
212
+ from transformers import RobertaConfig
213
+ from transformers.modeling_outputs import TokenClassifierOutput
214
+ from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
215
+
216
+ # Creamos una clase que herede de RobertaPreTrainedModel
217
+ class RobertaForTokenClassification_NER(RobertaPreTrainedModel):
218
+ config_class = RobertaConfig
219
+
220
+ def __init__(self, config):
221
+ # Se usa para inicializar el modelo Roberta
222
+ super().__init__(config)
223
+ # Numero de etiquetas que se van a clasificar (sería el número de etiquetas del corpus*2)
224
+ # Una correspondiente a la etiqueta I y otra a la B.
225
+ self.num_labels = config.num_labels
226
+ # No incorporamos pooling layer para devolver los hidden states de cada token (no sólo el CLS)
227
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
228
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
229
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
230
+ self.init_weights()
231
+
232
+ def forward(self, input_ids = None, attention_mask = None, token_type_ids = None, labels = None,
233
+ **kwargs):
234
+ # Obtenemos una codificación del input (los hidden states)
235
+ outputs = self.roberta(input_ids, attention_mask = attention_mask,
236
+ token_type_ids = token_type_ids, **kwargs)
237
+
238
+ # A la salida de los hidden states le aplicamos la capa de dropout
239
+ sequence_output = self.dropout(outputs[0])
240
+ # Y posteriormente la capa de clasificación.
241
+ logits = self.classifier(sequence_output)
242
+ # Si labels tiene algún valor (lo que se hará durante el proceso de entrenamiento), se calculan las Loss
243
+ # para justar los pesos en el backprop.
244
+ loss = None
245
+ if labels is not None:
246
+ loss_fct = nn.CrossEntropyLoss()
247
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
248
+
249
+ return TokenClassifierOutput(loss=loss, logits=logits,
250
+ hidden_states=outputs.hidden_states,
251
+ attentions=outputs.attentions)
252
+ EOF
253
+
254
+ RbbtPython.add_path dir
255
+
256
+ biomedical_roberta = "PlanTL-GOB-ES/bsc-bio-ehr-es-cantemist"
257
+ model = HuggingfaceModel.new "mypkg.mymodel:RobertaForTokenClassification_NER", biomedical_roberta
258
+
259
+ model.post_process do |result,is_list|
260
+ if is_list
261
+ RbbtPython.numpy2ruby result.predictions
262
+ else
263
+ result["logits"][0]
264
+ end
265
+ end
266
+
267
+ texto = "El paciente tiene un cáncer del pulmon"
268
+ assert model.eval(texto)[5][1] > 0
269
+ end
270
+ end
271
+
272
+ def _test_sst_train_word_embeddings
273
+ TmpFile.with_file do |dir|
274
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
275
+
276
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
277
+ model.model_options[:class_labels] = %w(Bad Good)
278
+
279
+ mod, tokenizer = model.init
280
+
281
+ orig = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
282
+ orig = RbbtPython.numpy2ruby(orig.cpu.detach.numpy)
283
+
284
+ 100.times do
285
+ model.add "Dog is good", "Good"
286
+ end
287
+
288
+ model.train
289
+
290
+ new = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
291
+ new = RbbtPython.numpy2ruby(new.cpu.detach.numpy)
292
+
293
+ diff = []
294
+ new.each_with_index do |row,i|
295
+ diff << i if row != orig[i]
296
+ end
297
+
298
+ assert diff.length > 0
299
+ end
300
+ end
301
+
302
+ def _test_sst_freeze_word_embeddings
303
+ TmpFile.with_file do |dir|
304
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
305
+
306
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
307
+ model.model_options[:class_labels] = %w(Bad Good)
308
+
309
+ mod, tokenizer = model.init
310
+
311
+ layer = HuggingfaceModel.freeze_layer(mod, 'distilbert')
312
+
313
+ orig = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
314
+ orig = RbbtPython.numpy2ruby(orig.cpu.detach.numpy)
315
+
316
+ 100.times do
317
+ model.add "Dog is good", "Good"
318
+ end
319
+
320
+ model.train
321
+
322
+ new = HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings')
323
+ new = RbbtPython.numpy2ruby(new.cpu.detach.numpy)
324
+
325
+ diff = []
326
+ new.each_with_index do |row,i|
327
+ diff << i if row != orig[i]
328
+ end
329
+
330
+ assert diff.length == 0
331
+ end
332
+ end
333
+
334
+ def _test_sst_save_word_embeddings
335
+ TmpFile.with_file do |dir|
336
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
337
+
338
+ model = HuggingfaceModel.new "SequenceClassification", checkpoint, dir
339
+ model.model_options[:class_labels] = %w(Bad Good)
340
+
341
+ mod, tokenizer = model.init
342
+
343
+ 100.times do
344
+ model.add "Dog is good", "Good"
345
+ end
346
+
347
+ model.train
348
+
349
+ orig = RbbtPython.numpy2ruby(
350
+ HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings').cpu.detach.numpy)
351
+
352
+ model = HuggingfaceModel.new "MaskedLM", checkpoint, dir
353
+
354
+ mod, tokenizer = model.init
355
+
356
+ new = RbbtPython.numpy2ruby(
357
+ HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings').cpu.detach.numpy)
358
+
359
+
360
+ diff = []
361
+ new.each_with_index do |row,i|
362
+ diff << i if row != orig[i]
363
+ end
364
+
365
+ assert diff.length == 0
366
+
367
+ model = HuggingfaceModel.new "MaskedLM", checkpoint
368
+
369
+ mod, tokenizer = model.init
370
+
371
+ new = RbbtPython.numpy2ruby(
372
+ HuggingfaceModel.get_weights(mod, 'distilbert.embeddings.word_embeddings').cpu.detach.numpy)
373
+
374
+
375
+ diff = []
376
+ new.each_with_index do |row,i|
377
+ diff << i if row != orig[i]
378
+ end
379
+
380
+ assert diff.length > 0
381
+ end
382
+ end
152
383
  end
153
384
 
@@ -0,0 +1,31 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestPythonModel < Test::Unit::TestCase
5
+ def test_linear
6
+ model = nil
7
+
8
+ TmpFile.with_dir do |dir|
9
+
10
+ Misc.in_dir dir do
11
+ Open.write 'model.py', <<-EOF
12
+ class TestModel:
13
+ def __init__(self, delta):
14
+ self.delta = delta
15
+
16
+ def eval(self, x):
17
+ return [e + self.delta for e in x]
18
+ EOF
19
+ model = PythonModel.new dir, 'TestModel', :model, delta: 1
20
+
21
+ assert_equal 2, model.eval(1)
22
+ assert_equal [4, 6], model.eval_list([3, 5])
23
+
24
+ model = PythonModel.new dir, 'TestModel', :model, delta: 2
25
+
26
+ assert_equal 3, model.eval(1)
27
+ end
28
+ end
29
+ end
30
+ end
31
+
@@ -0,0 +1,97 @@
1
+ require File.join(File.expand_path(File.dirname(__FILE__)), '../../..', 'test_helper.rb')
2
+ require 'rbbt/vector/model/pytorch_lightning'
3
+
4
+ class TestPytorchLightning < Test::Unit::TestCase
5
+ def test_regresion
6
+ points = 10
7
+ a = 1
8
+ b = 1
9
+
10
+ x = (0..points - 1)
11
+ y = points.times.collect{|p| p }
12
+
13
+ python = <<~EOF
14
+ import pytorch_lightning as pl
15
+ import numpy as np
16
+ import torch
17
+ from torch.nn import MSELoss
18
+ from torch.optim import Adam
19
+ from torch.utils.data import DataLoader, Dataset
20
+ import torch.nn as nn
21
+
22
+
23
+ class SimpleDataset(Dataset):
24
+ def __init__(self):
25
+ X = np.arange(10000)
26
+ y = X * 2
27
+ X = [[_] for _ in X]
28
+ y = [[_] for _ in y]
29
+ self.X = torch.Tensor(X)
30
+ self.y = torch.Tensor(y)
31
+
32
+ def __len__(self):
33
+ return len(self.y)
34
+
35
+ def __getitem__(self, idx):
36
+ return {"X": self.X[idx], "y": self.y[idx]}
37
+
38
+
39
+ class TestPytorchLightningModel(pl.LightningModule):
40
+ def __init__(self):
41
+ super().__init__()
42
+ self.fc = nn.Linear(1, 1)
43
+ self.criterion = MSELoss()
44
+
45
+ def forward(self, inputs, labels=None):
46
+ outputs = self.fc(inputs)
47
+ loss = 0
48
+ if labels is not None:
49
+ loss = self.criterion(outputs, labels)
50
+ return loss, outputs
51
+
52
+ def train_dataloader(self):
53
+ dataset = SimpleDataset()
54
+ return DataLoader(dataset, batch_size=1000)
55
+
56
+ def training_step(self, batch, batch_idx):
57
+ input_ids = batch["X"]
58
+ labels = batch["y"]
59
+ loss, outputs = self(input_ids, labels)
60
+ return {"loss": loss}
61
+
62
+ def configure_optimizers(self):
63
+ optimizer = Adam(self.parameters(), lr=0.1)
64
+ return optimizer
65
+ EOF
66
+
67
+ TmpFile.with_dir do |dir|
68
+ Open.write(File.join(dir, 'model.py'), python)
69
+ model = PytorchLightningModel.new dir, "TestPytorchLightningModel"
70
+ model.init
71
+
72
+ model.trainer = RbbtPython.class_new_obj("pytorch_lightning", "Trainer", max_epochs: 10, precision: 16)
73
+ model.init
74
+
75
+ model.train
76
+
77
+ w = model.get_weights('fc').to_ruby.first.first
78
+
79
+ assert w > 1.8
80
+ assert w < 2.2
81
+
82
+ res = model.eval(10.0)
83
+ assert_equal res, (10 * w)
84
+ assert res > 1.8 * 10.0
85
+ assert res < 2.2 * 10.0
86
+
87
+ res = model.eval([10.0])
88
+ res = model.eval_list([[10.0], [11.2], [14.3]])
89
+ assert_equal 3, RbbtPython.numpy2ruby(res).length
90
+
91
+ model = VectorModel.new dir
92
+ model.init
93
+
94
+ end
95
+ end
96
+ end
97
+
@@ -100,7 +100,7 @@ class TestSpaCyModel < Test::Unit::TestCase
100
100
  )
101
101
 
102
102
 
103
- Rbbt::Config.set 'gpu_id', nil, :spacy
103
+ Rbbt::Config.set 'gpu_id', 0, :spacy
104
104
  require 'rbbt/tsv/csv'
105
105
  url = "https://raw.githubusercontent.com/hanzhang0420/Women-Clothing-E-commerce/master/Womens%20Clothing%20E-Commerce%20Reviews.csv"
106
106
  tsv = TSV.csv(Open.open(url))
@@ -1,5 +1,6 @@
1
1
  require File.join(File.expand_path(File.dirname(__FILE__)), '../../..', 'test_helper.rb')
2
2
  require 'rbbt/vector/model/tensorflow'
3
+ require 'rbbt/util/python'
3
4
 
4
5
  class TestTensorflowModel < Test::Unit::TestCase
5
6
 
@@ -10,6 +11,7 @@ class TestTensorflowModel < Test::Unit::TestCase
10
11
 
11
12
  model = TensorFlowModel.new(
12
13
  dir,
14
+ jit_compile: true,
13
15
  optimizer: 'adam',
14
16
  loss: 'sparse_categorical_crossentropy',
15
17
  metrics: ['accuracy']
@@ -53,5 +55,6 @@ class TestTensorflowModel < Test::Unit::TestCase
53
55
  assert sum.to_f / predictions.length > 0.7
54
56
  end
55
57
  end
58
+
56
59
  end
57
60
 
@@ -0,0 +1,61 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestTorch < Test::Unit::TestCase
5
+ def test_linear
6
+ model = nil
7
+
8
+ TmpFile.with_dir do |dir|
9
+
10
+ # Create model
11
+
12
+ model = TorchModel.new dir
13
+ model.model = RbbtPython.torch.nn.Linear.new(1, 1)
14
+ model.criterion = RbbtPython.torch.nn.MSELoss.new()
15
+
16
+ model.extract_features do |f|
17
+ [f]
18
+ end
19
+
20
+ model.post_process do |v,list|
21
+ list ? v.to_ruby.collect{|vv| vv.first } : v.to_ruby.first
22
+ end
23
+
24
+ # Train model
25
+
26
+ model.add 5.0, [10.0]
27
+ model.add 10.0, [20.0]
28
+
29
+ model.training_args[:epochs] = 1000
30
+ model.train
31
+
32
+ w = model.get_weights.to_ruby.first.first
33
+
34
+ assert w > 1.8
35
+ assert w < 2.2
36
+
37
+ # Load the model again
38
+
39
+ model = VectorModel.new dir
40
+
41
+ # Test model
42
+
43
+ y = model.eval(100.0)
44
+
45
+ assert(y > 150.0)
46
+ assert(y < 250.0)
47
+
48
+ test = [1.0, 5.0, 10.0, 20.0]
49
+ input_sum = Misc.sum(test)
50
+ sum = Misc.sum(model.eval_list(test))
51
+ assert sum > 0.8 * input_sum * 2
52
+ assert sum < 1.2 * input_sum * 2
53
+
54
+ w = TorchModel.get_weights(model.model).to_ruby.first.first
55
+
56
+ assert w > 1.8
57
+ assert w < 2.2
58
+ end
59
+ end
60
+ end
61
+