ruby-dnn 1.2.0 → 1.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d152e645981ddb8f244abc88719c4c68a154679fa18abf6b02fa8852510e0c27
4
- data.tar.gz: 2987a4ce59a8f7d60a160ac192125ab97a4e14735ae29e151d23412a22f3c13b
3
+ metadata.gz: 8475052a51f8b81f176c8b3a77c953d116fb1098b7f24fe8a2ad9135b2f2a33b
4
+ data.tar.gz: f291c99be16b9a35fee59a09cefda3b9aece1bbaef5d6b482fad6d238043e68b
5
5
  SHA512:
6
- metadata.gz: 13e9cb2995d1d2851850931a0aedb01ca4fb71b6b03fa81a0ac9e78176b8ed4f520ac6284a39e9c71aaf2629c55475eac33987e6f49535bfe27153729f37624c
7
- data.tar.gz: a492f778d4094fce114149841a54e2b0d8a4eff77778b8074198c8df009916598cb96bfd67494c88520e10fd99618591862854773a1f20baf00247fe2bb1f1ea
6
+ metadata.gz: '08f061510f73297b12b53cd3da5748dd47e67605d1b40c35388779a94e023e37450d2c9ebd9ec1bf0970c1580c7481f46c6cef05d54af5748fe175522b198c5e'
7
+ data.tar.gz: fdfe99ac9b5a63e4914aaab4f71a37e56b097c6b00560245a82438c6fd70ab658683de95c1ddef0283f4cc2f0b9e261dda798158cd374cf19cc17966ed338b16
@@ -12,7 +12,7 @@ end
12
12
  def mnist_predict(img, width, height)
13
13
  load_model
14
14
  img = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
15
- img = img[true, true, 0...DNN::Image::RGB]
15
+ img = DNN::Image.to_rgb(img)
16
16
  img = DNN::Image.to_gray_scale(img)
17
17
  x = Numo::SFloat.cast(img) / 255
18
18
  out = $model.predict1(x)
@@ -70,77 +70,6 @@ class Generator < Model
70
70
  end
71
71
  end
72
72
 
73
- class Discriminator < Model
74
- def initialize(gen_input_shape, gen_output_shape)
75
- super()
76
- @gen_input_shape = gen_input_shape
77
- @gen_output_shape = gen_output_shape
78
- @l1_1 = Conv2D.new(32, 4, padding: true)
79
- @l1_2 = Conv2D.new(32, 4, padding: true)
80
- @l2 = Conv2D.new(32, 4, strides: 2, padding: true)
81
- @l3 = Conv2D.new(32, 4, padding: true)
82
- @l4 = Conv2D.new(64, 4, strides: 2, padding: true)
83
- @l5 = Conv2D.new(64, 4, padding: true)
84
- @l6 = Dense.new(1024)
85
- @l7 = Dense.new(1)
86
- @bn1 = BatchNormalization.new
87
- @bn2 = BatchNormalization.new
88
- @bn3 = BatchNormalization.new
89
- @bn4 = BatchNormalization.new
90
- @bn5 = BatchNormalization.new
91
- @bn6 = BatchNormalization.new
92
- end
93
-
94
- def forward(inputs)
95
- input, images = *inputs
96
- x = InputLayer.new(@gen_input_shape).(input)
97
- x = @l1_1.(x)
98
- x = @bn1.(x)
99
- x1 = LeakyReLU.(x, 0.2)
100
-
101
- x = InputLayer.new(@gen_output_shape).(images)
102
- x = @l1_2.(x)
103
- x = @bn2.(x)
104
- x2 = LeakyReLU.(x, 0.2)
105
-
106
- x = Concatenate.(x1, x2)
107
- x = @l2.(x)
108
- x = @bn3.(x)
109
- x = LeakyReLU.(x, 0.2)
110
-
111
- x = @l3.(x)
112
- x = @bn4.(x)
113
- x = LeakyReLU.(x, 0.2)
114
-
115
- x = @l4.(x)
116
- x = @bn5.(x)
117
- x = LeakyReLU.(x, 0.2)
118
-
119
- x = @l5.(x)
120
- x = @bn6.(x)
121
- x = LeakyReLU.(x, 0.2)
122
-
123
- x = Flatten.(x)
124
- x = @l6.(x)
125
- x = LeakyReLU.(x, 0.2)
126
-
127
- x = @l7.(x)
128
- x
129
- end
130
-
131
- def enable_training
132
- trainable_layers.each do |layer|
133
- layer.trainable = true
134
- end
135
- end
136
-
137
- def disable_training
138
- trainable_layers.each do |layer|
139
- layer.trainable = false
140
- end
141
- end
142
- end
143
-
144
73
  class Discriminator < Model
145
74
  def initialize(gen_input_shape, gen_output_shape, base_num_filters)
146
75
  super()
@@ -436,8 +436,8 @@ module DNN
436
436
  def forward_node(x)
437
437
  if DNN.learning_phase
438
438
  Xumo::SFloat.srand(@rnd.rand(1 << 31))
439
- @mask = Xumo::SFloat.new(*x.shape).rand < @dropout_ratio
440
- x[@mask] = 0
439
+ @mask = Xumo::SFloat.cast(Xumo::SFloat.new(*x.shape).rand >= @dropout_ratio)
440
+ x = x * @mask
441
441
  elsif @use_scale
442
442
  x *= (1 - @dropout_ratio)
443
443
  end
@@ -445,8 +445,7 @@ module DNN
445
445
  end
446
446
 
447
447
  def backward_node(dy)
448
- dy[@mask] = 0
449
- dy
448
+ dy * @mask
450
449
  end
451
450
 
452
451
  def to_hash
data/lib/dnn/image.rb CHANGED
@@ -114,6 +114,41 @@ module DNN
114
114
  Numo::UInt8.cast(x)
115
115
  end
116
116
 
117
+ # Convert the channels of an image to RGB.
118
+ # @param [Numo::UInt8] img Image to convert to RGB.
119
+ def self.to_rgb(img)
120
+ img_check(img)
121
+ case img.shape[2]
122
+ when 1
123
+ return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
124
+ when 2
125
+ img = img[true, true, 0...1]
126
+ return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
127
+ when 4
128
+ return img[true, true, 0...3].clone
129
+ end
130
+ img
131
+ end
132
+
133
+ # Convert the channels of an image to RGBA.
134
+ # @param [Numo::UInt8] img Image to convert to RGBA.
135
+ def self.to_rgba(img)
136
+ img_check(img)
137
+ case img.shape[2]
138
+ when 1
139
+ alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
140
+ return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
141
+ when 2
142
+ alpha = img[true, true, 1...2]
143
+ img = img[true, true, 0...1]
144
+ return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
145
+ when 3
146
+ alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
147
+ return img.concatenate(alpha, axis: 2)
148
+ end
149
+ img
150
+ end
151
+
117
152
  private_class_method def self.img_check(img)
118
153
  raise TypeError, "img: #{img.class} is not an instance of the Numo::UInt8 class." unless img.is_a?(Numo::UInt8)
119
154
  if img.shape.length != 3
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "1.2.0"
2
+ VERSION = "1.2.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.0
4
+ version: 1.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-05-16 00:00:00.000000000 Z
11
+ date: 2020-05-24 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray