ruby-dnn 1.2.2 → 1.3.0

Files changed (86)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -0
  3. data/.travis.yml +0 -0
  4. data/CODE_OF_CONDUCT.md +0 -0
  5. data/Gemfile +0 -0
  6. data/LICENSE.txt +0 -0
  7. data/README.md +0 -0
  8. data/Rakefile +5 -0
  9. data/examples/api-examples/early_stopping_example.rb +0 -0
  10. data/examples/api-examples/initializer_example.rb +0 -0
  11. data/examples/api-examples/regularizer_example.rb +0 -0
  12. data/examples/api-examples/save_example.rb +0 -0
  13. data/examples/cifar100_example.rb +0 -0
  14. data/examples/cifar10_example.rb +0 -0
  15. data/examples/dcgan/dcgan.rb +1 -1
  16. data/examples/dcgan/imgen.rb +0 -0
  17. data/examples/dcgan/train.rb +0 -0
  18. data/examples/iris_example.rb +17 -41
  19. data/examples/iris_example_unused_model.rb +57 -0
  20. data/examples/judge-number/README.md +0 -0
  21. data/examples/judge-number/capture.PNG +0 -0
  22. data/examples/judge-number/convnet8.rb +0 -0
  23. data/examples/judge-number/make_weights.rb +0 -0
  24. data/examples/judge-number/mnist_predict.rb +0 -0
  25. data/examples/judge-number/mnist_train.rb +0 -0
  26. data/examples/judge-number/public/httpRequest.js +0 -0
  27. data/examples/judge-number/public/judgeNumber.js +0 -0
  28. data/examples/judge-number/server.rb +0 -0
  29. data/examples/judge-number/trained_mnist_params.marshal +0 -0
  30. data/examples/judge-number/views/index.erb +0 -0
  31. data/examples/mnist_conv2d_example.rb +0 -0
  32. data/examples/mnist_define_by_run.rb +0 -0
  33. data/examples/mnist_example.rb +0 -0
  34. data/examples/mnist_gpu.rb +0 -0
  35. data/examples/mnist_lstm_example.rb +0 -0
  36. data/examples/pix2pix/dcgan.rb +0 -0
  37. data/examples/pix2pix/imgen.rb +0 -0
  38. data/examples/pix2pix/train.rb +0 -0
  39. data/examples/vae.rb +1 -1
  40. data/examples/xor_example.rb +0 -0
  41. data/ext/rb_stb_image/extconf.rb +0 -0
  42. data/ext/rb_stb_image/rb_stb_image.c +0 -0
  43. data/img/cart-pole.gif +0 -0
  44. data/img/cycle-gan.PNG +0 -0
  45. data/img/facade-pix2pix.png +0 -0
  46. data/lib/dnn/core/callbacks.rb +18 -8
  47. data/lib/dnn/core/error.rb +0 -0
  48. data/lib/dnn/core/global.rb +0 -0
  49. data/lib/dnn/core/initializers.rb +0 -0
  50. data/lib/dnn/core/iterator.rb +20 -4
  51. data/lib/dnn/core/layers/activations.rb +0 -0
  52. data/lib/dnn/core/layers/basic_layers.rb +2 -2
  53. data/lib/dnn/core/layers/cnn_layers.rb +0 -0
  54. data/lib/dnn/core/layers/embedding.rb +0 -0
  55. data/lib/dnn/core/layers/math_layers.rb +0 -0
  56. data/lib/dnn/core/layers/merge_layers.rb +2 -2
  57. data/lib/dnn/core/layers/normalizations.rb +0 -0
  58. data/lib/dnn/core/layers/rnn_layers.rb +20 -24
  59. data/lib/dnn/core/layers/split_layers.rb +0 -0
  60. data/lib/dnn/core/link.rb +0 -0
  61. data/lib/dnn/core/losses.rb +2 -2
  62. data/lib/dnn/core/models.rb +474 -149
  63. data/lib/dnn/core/monkey_patch.rb +0 -0
  64. data/lib/dnn/core/optimizers.rb +0 -0
  65. data/lib/dnn/core/param.rb +0 -0
  66. data/lib/dnn/core/regularizers.rb +0 -0
  67. data/lib/dnn/core/savers.rb +4 -12
  68. data/lib/dnn/core/tensor.rb +0 -0
  69. data/lib/dnn/core/utils.rb +14 -0
  70. data/lib/dnn/datasets/cifar10.rb +0 -0
  71. data/lib/dnn/datasets/cifar100.rb +0 -0
  72. data/lib/dnn/datasets/downloader.rb +12 -3
  73. data/lib/dnn/datasets/fashion-mnist.rb +0 -0
  74. data/lib/dnn/datasets/iris.rb +5 -1
  75. data/lib/dnn/datasets/mnist.rb +0 -0
  76. data/lib/dnn/datasets/stl-10.rb +0 -0
  77. data/lib/dnn/image.rb +1 -1
  78. data/lib/dnn/keras-model-convertor.rb +0 -0
  79. data/lib/dnn/numo2numpy.rb +0 -0
  80. data/lib/dnn/version.rb +1 -1
  81. data/lib/dnn.rb +32 -26
  82. data/ruby-dnn.gemspec +1 -0
  83. data/third_party/stb_image.h +0 -0
  84. data/third_party/stb_image_resize.h +0 -0
  85. data/third_party/stb_image_write.h +0 -0
  86. metadata +21 -6
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 2aa12b717ef532b8afe44de7cb388c7d87cb271bf38a25adaad2c335c8817d4b
- data.tar.gz: e1322f86b06c11ac3728e948e18469ccc6e454eeefb180108a0a9bfc2dbd8143
+ metadata.gz: df187618941592ff119bab49757f039544d806a6dab5b8c9962040b714a7873e
+ data.tar.gz: 263c5a9d1d366ad4782c0fa30b130a081edae1d6b744daf0b2f359560c83e0fd
  SHA512:
- metadata.gz: 6c9c53ca73a5ab7fc53935f37e0804d761e36735aa06972f63e151ce76a44f01878a2c01dfe0622bf39baa5c98b16ee8d06db212356fcc329a9dd40ae2d78f1c
- data.tar.gz: 5a195c0afd677127afad2433df2fda3a09c711464c46cb174713d00d695026d2b86faa64f82cc96f2b417b468fc372c04efee22761800c900e2f6b0ca05ac57d
+ metadata.gz: f2b53307f3a90d6fa3caaa01a9f57e7ea3d76a378fa90ecaba4c4e46f052fc1380718a89f066eba9df28a822b49236f10e03c65bb113004b523017412e5cbbf2
+ data.tar.gz: fc6d85c0f8de928f97fad6debf480fa683cad780d70855bedbe8ed989ad31efc92cb60d55a6e456107f76093d461ddfc328c806b51f69095ec08111632438191
data/.gitignore CHANGED
File without changes
data/.travis.yml CHANGED
File without changes
data/CODE_OF_CONDUCT.md CHANGED
File without changes
data/Gemfile CHANGED
File without changes
data/LICENSE.txt CHANGED
File without changes
data/README.md CHANGED
File without changes
data/Rakefile CHANGED
@@ -1,5 +1,6 @@
  require "bundler/gem_tasks"
  require "rake/testtask"
+ require "rake/extensiontask"
  require "yard"
  require "yard/rake/yardoc_task"

@@ -10,6 +11,10 @@ Rake::TestTask.new(:test) do |t|
  t.test_files = FileList["test/*_test.rb", "test/layers_test/*_test.rb"]
  end

+ Rake::ExtensionTask.new "rb_stb_image" do |ext|
+ ext.lib_dir = "lib/rb_stb_image"
+ end
+
  task :build_rb_stb_image do
  sh "cd ext/rb_stb_image; ruby extconf.rb; make"
  end
data/examples/api-examples/early_stopping_example.rb CHANGED
File without changes
data/examples/api-examples/initializer_example.rb CHANGED
File without changes
data/examples/api-examples/regularizer_example.rb CHANGED
File without changes
data/examples/api-examples/save_example.rb CHANGED
File without changes
data/examples/cifar100_example.rb CHANGED
File without changes
data/examples/cifar10_example.rb CHANGED
File without changes
data/examples/dcgan/dcgan.rb CHANGED
@@ -121,7 +121,7 @@ class DCGAN < Model
  x
  end

- def train_step(x_batch, y_batch)
+ def train_step(x_batch, y_batch, need_accuracy: false)
  batch_size = x_batch.shape[0]
  noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
  images = @gen.predict(noise)
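Note for custom models: Model#train_step now takes a need_accuracy: keyword, so overrides such as the DCGAN example above must accept it. A minimal sketch of such an override (the class name and the simple delegation to super are illustrative only, not from the gem):

    require "dnn"

    # Hypothetical subclass whose only purpose is to show the new signature.
    class MyModel < DNN::Models::Sequential
      def train_step(x_batch, y_batch, need_accuracy: false)
        # Always skip the per-batch accuracy computation, otherwise defer to
        # the default training step.
        super(x_batch, y_batch, need_accuracy: false)
      end
    end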
data/examples/dcgan/imgen.rb CHANGED
File without changes
data/examples/dcgan/train.rb CHANGED
File without changes
data/examples/iris_example.rb CHANGED
@@ -3,6 +3,7 @@ require "dnn/datasets/iris"
  # If you use numo/linalg then please uncomment out.
  # require "numo/linalg/autoloader"

+ include DNN::Models
  include DNN::Layers
  include DNN::Optimizers
  include DNN::Losses
@@ -14,44 +15,19 @@ x_test, y_test = x[100...150, true], y[100...150]
  y_train = DNN::Utils.to_categorical(y_train, 3, Numo::SFloat)
  y_test = DNN::Utils.to_categorical(y_test, 3, Numo::SFloat)

- epochs = 1000
- batch_size = 32
-
- opt = Adam.new
- lf = SoftmaxCrossEntropy.new
-
- train_iter = DNN::Iterator.new(x_train, y_train)
- test_iter = DNN::Iterator.new(x_test, y_test, random: false)
-
- w1 = DNN::Param.new(Numo::SFloat.new(4, 16).rand_norm)
- b1 = DNN::Param.new(Numo::SFloat.zeros(16))
- w2 = DNN::Param.new(Numo::SFloat.new(16, 3).rand_norm)
- b2 = DNN::Param.new(Numo::SFloat.zeros(3))
-
- net = -> x, y do
- h = Dot.(x, w1) + b1
- h = Sigmoid.(h)
- out = Dot.(h, w2) + b2
- out
- end
-
- (1..epochs).each do |epoch|
- train_iter.foreach(batch_size) do |x_batch, y_batch, step|
- x = DNN::Tensor.convert(x_batch)
- y = DNN::Tensor.convert(y_batch)
- out = net.(x, y)
- loss = lf.(out, y)
- loss.link.backward
- puts "epoch: #{epoch}, step: #{step}, loss = #{loss.data.to_f}"
- opt.update([w1, b1, w2, b2])
- end
- end
-
- correct = 0
- test_iter.foreach(batch_size) do |x_batch, y_batch, step|
- x = DNN::Tensor.convert(x_batch)
- y = DNN::Tensor.convert(y_batch)
- out = net.(x, y)
- correct += out.data.max_index(axis: 1).eq(y_batch.max_index(axis: 1)).count
- end
- puts "correct = #{correct}"
+ model = Sequential.new
+
+ model << InputLayer.new(4)
+
+ model << Dense.new(16)
+ model << Sigmoid.new
+
+ model << Dense.new(3)
+
+ model.setup(Adam.new, SoftmaxCrossEntropy.new)
+
+ model.train(x_train, y_train, 1000, batch_size: 32, test: [x_test, y_test])
+
+ accuracy, loss = model.evaluate(x_test, y_test)
+ puts "accuracy: #{accuracy}"
+ puts "loss: #{loss}"
data/examples/iris_example_unused_model.rb ADDED
@@ -0,0 +1,57 @@
+ require "dnn"
+ require "dnn/datasets/iris"
+ # If you use numo/linalg then please uncomment out.
+ # require "numo/linalg/autoloader"
+
+ include DNN::Layers
+ include DNN::Optimizers
+ include DNN::Losses
+
+ x, y = DNN::Iris.load(true)
+ x_train, y_train = x[0...100, true], y[0...100]
+ x_test, y_test = x[100...150, true], y[100...150]
+
+ y_train = DNN::Utils.to_categorical(y_train, 3, Numo::SFloat)
+ y_test = DNN::Utils.to_categorical(y_test, 3, Numo::SFloat)
+
+ epochs = 1000
+ batch_size = 32
+
+ opt = Adam.new
+ lf = SoftmaxCrossEntropy.new
+
+ train_iter = DNN::Iterator.new(x_train, y_train)
+ test_iter = DNN::Iterator.new(x_test, y_test, random: false)
+
+ w1 = DNN::Param.new(Numo::SFloat.new(4, 16).rand_norm)
+ b1 = DNN::Param.new(Numo::SFloat.zeros(16))
+ w2 = DNN::Param.new(Numo::SFloat.new(16, 3).rand_norm)
+ b2 = DNN::Param.new(Numo::SFloat.zeros(3))
+
+ net = -> x, y do
+ h = Dot.(x, w1) + b1
+ h = Sigmoid.(h)
+ out = Dot.(h, w2) + b2
+ out
+ end
+
+ (1..epochs).each do |epoch|
+ train_iter.foreach(batch_size) do |x_batch, y_batch, step|
+ x = DNN::Tensor.convert(x_batch)
+ y = DNN::Tensor.convert(y_batch)
+ out = net.(x, y)
+ loss = lf.(out, y)
+ loss.link.backward
+ puts "epoch: #{epoch}, step: #{step}, loss = #{loss.data.to_f}"
+ opt.update([w1, b1, w2, b2])
+ end
+ end
+
+ correct = 0
+ test_iter.foreach(batch_size) do |x_batch, y_batch, step|
+ x = DNN::Tensor.convert(x_batch)
+ y = DNN::Tensor.convert(y_batch)
+ out = net.(x, y)
+ correct += out.data.max_index(axis: 1).eq(y_batch.max_index(axis: 1)).count
+ end
+ puts "correct = #{correct}"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
data/examples/vae.rb CHANGED
@@ -97,7 +97,7 @@ model = VAE.new
  dec = model.dec
  model.setup(Adam.new, VAELoss.new)

- model.train(x_train, x_train, 10, batch_size: 128)
+ model.train(x_train, x_train, 10, batch_size: 128, need_accuracy: false)

  images = []
  10.times do |i|
data/examples/xor_example.rb CHANGED
File without changes
data/ext/rb_stb_image/extconf.rb CHANGED
File without changes
data/ext/rb_stb_image/rb_stb_image.c CHANGED
File without changes
data/img/cart-pole.gif CHANGED
File without changes
data/img/cycle-gan.PNG CHANGED
File without changes
data/img/facade-pix2pix.png CHANGED
File without changes
data/lib/dnn/core/callbacks.rb CHANGED
@@ -6,6 +6,12 @@ module DNN

  # Please implement the method used for callback event.

+ # Process performed before all training.
+ # def before_train; end
+
+ # Process performed after all training.
+ # def after_train; end
+
  # Process performed before one training.
  # def before_epoch; end

@@ -57,7 +63,7 @@ module DNN

  # A callback to stop training the model early after test on batch.
  # @param [Symbol] trigger A log that triggers early stopping.
- # Specify one of train_loss, test_loss, test_accuracy.
+ # Specify one of :loss, :test_loss, :test_accuracy
  # @param [Float] tolerance Tolerance value for early stopping.
  class EarlyStopping < Callback
  def initialize(trigger, tolerance)
@@ -66,19 +72,21 @@ module DNN
  end

  def after_train_on_batch
- throw :stop, "Early stopped." if judge_early_stopping_train
+ @model.request_early_stop if judge_early_stopping_train
  end

  def after_epoch
- throw :stop, "Early stopped." if judge_early_stopping_test
+ @model.request_early_stop if judge_early_stopping_test
  end

  private

  def judge_early_stopping_train
  case @trigger
- when :train_loss
+ when :loss
  return true if model.last_log[@trigger] <= @tolerance
+ when :accuracy
+ return true if model.last_log[@trigger] >= @tolerance
  end
  false
  end
@@ -97,7 +105,7 @@ module DNN
  # A callback to stop training the model if loss is NaN by after train on batch.
  class NaNStopping < Callback
  def after_train_on_batch
- throw :stop, "loss is NaN." if model.last_log[:train_loss].nan?
+ throw :stop, "loss is NaN." if model.last_log[:loss].nan?
  end
  end

@@ -105,7 +113,8 @@ module DNN
  # The following logs will be recorded.
  # epoch: Current epoch.
  # step: Current step in epoch.
- # train_loss: Batch training loss.
+ # loss: Batch training loss.
+ # accuracy: Batch training accuracy.
  # test_loss: Mean test loss.
  # test_accuracy: Test accuracy.
  class Logger < Callback
@@ -113,7 +122,8 @@ module DNN
  @log = {
  epoch: [],
  step: [],
- train_loss: [],
+ loss: [],
+ accuracy: [],
  test_loss: [],
  test_accuracy: [],
  }
@@ -124,7 +134,7 @@ module DNN
  end

  def after_train_on_batch
- logging(:train_loss, :step)
+ logging(:loss, :step)
  end

  # Get a log.
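In practice, the new before_train/after_train hooks and the renamed log keys (:loss, :accuracy) combine roughly as follows. This is a sketch: the timer callback, the 0.05 tolerance, and the pre-built model variable are assumptions for illustration.

    require "dnn"

    class TimerCallback < DNN::Callbacks::Callback
      def before_train
        @started_at = Time.now
      end

      def after_train
        puts "training took #{Time.now - @started_at} sec"
      end

      def after_train_on_batch
        # Batch training loss is now logged under :loss (formerly :train_loss).
        puts "batch loss: #{model.last_log[:loss]}"
      end
    end

    # model is assumed to be an already set-up DNN::Models::Model.
    model.add_callback(TimerCallback.new)
    # Request an early stop once the batch training loss drops to 0.05 or below.
    model.add_callback(DNN::Callbacks::EarlyStopping.new(:loss, 0.05))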
data/lib/dnn/core/error.rb CHANGED
File without changes
data/lib/dnn/core/global.rb CHANGED
File without changes
data/lib/dnn/core/initializers.rb CHANGED
File without changes
data/lib/dnn/core/iterator.rb CHANGED
@@ -4,11 +4,13 @@ module DNN
  attr_reader :num_datas
  attr_reader :last_round_down

- # @param [Numo::SFloat | Array] x_datas input datas.
- # @param [Numo::SFloat | Array] y_datas output datas.
+ # @param [Numo::NArray | Array] x_datas input datas.
+ # @param [Numo::NArray | Array] y_datas output datas.
  # @param [Boolean] random Set true to return batches randomly. Setting false returns batches in order of index.
  # @param [Boolean] last_round_down Set true to round down for last batch data when call foreach.
  def initialize(x_datas, y_datas, random: true, last_round_down: false)
+ Utils.check_input_data_type("x_datas", x_datas, Xumo::NArray)
+ Utils.check_input_data_type("y_datas", y_datas, Xumo::NArray)
  @x_datas = x_datas
  @y_datas = y_datas
  @random = random
@@ -64,12 +66,26 @@ module DNN
  # @param [Integer] batch_size Batch size.
  # @yield Executes block by receiving the specified arguments (x_batch, y_batch).
  def foreach(batch_size, &block)
- steps = @last_round_down ? @num_datas / batch_size : (@num_datas.to_f / batch_size).ceil
- steps.times do |step|
+ max_steps(batch_size).times do |step|
  x_batch, y_batch = next_batch(batch_size)
  block.call(x_batch, y_batch, step)
  end
  reset
  end
+
+ # Return the number of available data considering last_round_down.
+ def num_usable_datas(batch_size)
+ if @last_round_down
+ max_steps(batch_size) * batch_size
+ else
+ @num_datas
+ end
+ end
+
+ # Get max steps for iteration.
+ # @param [Integer] batch_size Batch size.
+ def max_steps(batch_size)
+ @last_round_down ? @num_datas / batch_size : (@num_datas.to_f / batch_size).ceil
+ end
  end
  end
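The new Iterator helpers behave as follows (an illustrative sketch with made-up array sizes):

    require "dnn"

    x = Numo::SFloat.new(100, 4).rand
    y = Numo::SFloat.new(100, 3).rand

    iter = DNN::Iterator.new(x, y, last_round_down: true)
    p iter.max_steps(32)        # => 3  (100 / 32, rounded down)
    p iter.num_usable_datas(32) # => 96 (3 steps * batch size 32)

    iter.foreach(32) do |x_batch, y_batch, step|
      # Runs 3 times; the last 4 samples are dropped because last_round_down is true.
    end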
data/lib/dnn/core/layers/activations.rb CHANGED
File without changes
data/lib/dnn/core/layers/basic_layers.rb CHANGED
@@ -26,8 +26,8 @@ module DNN
  attr_reader :input_shape
  attr_reader :output_shape

- def self.call(x, *args)
- new(*args).(x)
+ def self.call(x, *args, **kwargs)
+ new(*args, **kwargs).(x)
  end

  def self.from_hash(hash)
data/lib/dnn/core/layers/cnn_layers.rb CHANGED
File without changes
data/lib/dnn/core/layers/embedding.rb CHANGED
File without changes
data/lib/dnn/core/layers/math_layers.rb CHANGED
File without changes
data/lib/dnn/core/layers/merge_layers.rb CHANGED
@@ -2,8 +2,8 @@ module DNN
  module Layers

  class MergeLayer < Layer
- def self.call(x1, x2, *args)
- new(*args).call(x1, x2)
+ def self.call(x1, x2, *args, **kwargs)
+ new(*args, **kwargs).call(x1, x2)
  end

  def call(input1, input2)
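Together with the identical change to Layer.call above (and Loss.call below), this lets keyword arguments pass through the shorthand .() form used in the define-by-run examples. A rough sketch; treating use_bias: as a valid Dense option is an assumption:

    require "dnn"
    include DNN::Layers

    x = DNN::Tensor.convert(Numo::SFloat.new(8, 4).rand)

    h1 = Dense.(x, 16)                  # as before: Dense.new(16).(x)
    h2 = Dense.(x, 16, use_bias: false) # keyword arguments are now forwarded too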
data/lib/dnn/core/layers/normalizations.rb CHANGED
File without changes
data/lib/dnn/core/layers/rnn_layers.rb CHANGED
@@ -1,5 +1,16 @@
  module DNN
  module Layers
+ # Super class of all RNN cells.
+ class RNNCell
+ attr_accessor :trainable
+
+ def initialize(weight, recurrent_weight, bias)
+ @weight = weight
+ @recurrent_weight = recurrent_weight
+ @bias = bias
+ @trainable = true
+ end
+ end

  # Super class of all RNN classes.
  class RNN < Connection
@@ -136,15 +147,10 @@ module DNN
  end
  end

- class SimpleRNNDense
- attr_accessor :trainable
-
+ class SimpleRNNCell < RNNCell
  def initialize(weight, recurrent_weight, bias, activation)
- @weight = weight
- @recurrent_weight = recurrent_weight
- @bias = bias
+ super(weight, recurrent_weight, bias)
  @activation = activation.clone
- @trainable = true
  end

  def forward(x, h)
@@ -206,7 +212,7 @@ module DNN
  end

  def create_hidden_layer
- @hidden_layers = Array.new(@time_length) { SimpleRNNDense.new(@weight, @recurrent_weight, @bias, @activation) }
+ @hidden_layers = Array.new(@time_length) { SimpleRNNCell.new(@weight, @recurrent_weight, @bias, @activation) }
  end

  def to_hash
@@ -228,19 +234,14 @@ module DNN
  end
  end

- class LSTMDense
- attr_accessor :trainable
-
+ class LSTMCell < RNNCell
  def initialize(weight, recurrent_weight, bias)
- @weight = weight
- @recurrent_weight = recurrent_weight
- @bias = bias
+ super(weight, recurrent_weight, bias)
  @tanh = Layers::Tanh.new
  @g_tanh = Layers::Tanh.new
  @forget_sigmoid = Layers::Sigmoid.new
  @in_sigmoid = Layers::Sigmoid.new
  @out_sigmoid = Layers::Sigmoid.new
- @trainable = true
  end

  def forward(x, h, c)
@@ -312,7 +313,7 @@ module DNN
  end

  def create_hidden_layer
- @hidden_layers = Array.new(@time_length) { LSTMDense.new(@weight, @recurrent_weight, @bias) }
+ @hidden_layers = Array.new(@time_length) { LSTMCell.new(@weight, @recurrent_weight, @bias) }
  end

  def forward_node(xs)
@@ -365,17 +366,12 @@ module DNN
  end
  end

- class GRUDense < Layer
- attr_accessor :trainable
-
+ class GRUCell < RNNCell
  def initialize(weight, recurrent_weight, bias)
- @weight = weight
- @recurrent_weight = recurrent_weight
- @bias = bias
+ super(weight, recurrent_weight, bias)
  @update_sigmoid = Layers::Sigmoid.new
  @reset_sigmoid = Layers::Sigmoid.new
  @tanh = Layers::Tanh.new
- @trainable = true
  end

  def forward(x, h)
@@ -457,7 +453,7 @@ module DNN
  end

  def create_hidden_layer
- @hidden_layers = Array.new(@time_length) { GRUDense.new(@weight, @recurrent_weight, @bias) }
+ @hidden_layers = Array.new(@time_length) { GRUCell.new(@weight, @recurrent_weight, @bias) }
  end
  end

data/lib/dnn/core/layers/split_layers.rb CHANGED
File without changes
data/lib/dnn/core/link.rb CHANGED
File without changes
data/lib/dnn/core/losses.rb CHANGED
@@ -2,8 +2,8 @@ module DNN
  module Losses

  class Loss
- def self.call(y, t, *args)
- new(*args).(y, t)
+ def self.call(y, t, *args, **kwargs)
+ new(*args, **kwargs).(y, t)
  end

  def self.from_hash(hash)