ruby-dnn 1.2.0 → 1.2.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d152e645981ddb8f244abc88719c4c68a154679fa18abf6b02fa8852510e0c27
4
- data.tar.gz: 2987a4ce59a8f7d60a160ac192125ab97a4e14735ae29e151d23412a22f3c13b
3
+ metadata.gz: 8475052a51f8b81f176c8b3a77c953d116fb1098b7f24fe8a2ad9135b2f2a33b
4
+ data.tar.gz: f291c99be16b9a35fee59a09cefda3b9aece1bbaef5d6b482fad6d238043e68b
5
5
  SHA512:
6
- metadata.gz: 13e9cb2995d1d2851850931a0aedb01ca4fb71b6b03fa81a0ac9e78176b8ed4f520ac6284a39e9c71aaf2629c55475eac33987e6f49535bfe27153729f37624c
7
- data.tar.gz: a492f778d4094fce114149841a54e2b0d8a4eff77778b8074198c8df009916598cb96bfd67494c88520e10fd99618591862854773a1f20baf00247fe2bb1f1ea
6
+ metadata.gz: '08f061510f73297b12b53cd3da5748dd47e67605d1b40c35388779a94e023e37450d2c9ebd9ec1bf0970c1580c7481f46c6cef05d54af5748fe175522b198c5e'
7
+ data.tar.gz: fdfe99ac9b5a63e4914aaab4f71a37e56b097c6b00560245a82438c6fd70ab658683de95c1ddef0283f4cc2f0b9e261dda798158cd374cf19cc17966ed338b16
@@ -12,7 +12,7 @@ end
12
12
  def mnist_predict(img, width, height)
13
13
  load_model
14
14
  img = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
15
- img = img[true, true, 0...DNN::Image::RGB]
15
+ img = DNN::Image.to_rgb(img)
16
16
  img = DNN::Image.to_gray_scale(img)
17
17
  x = Numo::SFloat.cast(img) / 255
18
18
  out = $model.predict1(x)
@@ -70,77 +70,6 @@ class Generator < Model
70
70
  end
71
71
  end
72
72
 
73
- class Discriminator < Model
74
- def initialize(gen_input_shape, gen_output_shape)
75
- super()
76
- @gen_input_shape = gen_input_shape
77
- @gen_output_shape = gen_output_shape
78
- @l1_1 = Conv2D.new(32, 4, padding: true)
79
- @l1_2 = Conv2D.new(32, 4, padding: true)
80
- @l2 = Conv2D.new(32, 4, strides: 2, padding: true)
81
- @l3 = Conv2D.new(32, 4, padding: true)
82
- @l4 = Conv2D.new(64, 4, strides: 2, padding: true)
83
- @l5 = Conv2D.new(64, 4, padding: true)
84
- @l6 = Dense.new(1024)
85
- @l7 = Dense.new(1)
86
- @bn1 = BatchNormalization.new
87
- @bn2 = BatchNormalization.new
88
- @bn3 = BatchNormalization.new
89
- @bn4 = BatchNormalization.new
90
- @bn5 = BatchNormalization.new
91
- @bn6 = BatchNormalization.new
92
- end
93
-
94
- def forward(inputs)
95
- input, images = *inputs
96
- x = InputLayer.new(@gen_input_shape).(input)
97
- x = @l1_1.(x)
98
- x = @bn1.(x)
99
- x1 = LeakyReLU.(x, 0.2)
100
-
101
- x = InputLayer.new(@gen_output_shape).(images)
102
- x = @l1_2.(x)
103
- x = @bn2.(x)
104
- x2 = LeakyReLU.(x, 0.2)
105
-
106
- x = Concatenate.(x1, x2)
107
- x = @l2.(x)
108
- x = @bn3.(x)
109
- x = LeakyReLU.(x, 0.2)
110
-
111
- x = @l3.(x)
112
- x = @bn4.(x)
113
- x = LeakyReLU.(x, 0.2)
114
-
115
- x = @l4.(x)
116
- x = @bn5.(x)
117
- x = LeakyReLU.(x, 0.2)
118
-
119
- x = @l5.(x)
120
- x = @bn6.(x)
121
- x = LeakyReLU.(x, 0.2)
122
-
123
- x = Flatten.(x)
124
- x = @l6.(x)
125
- x = LeakyReLU.(x, 0.2)
126
-
127
- x = @l7.(x)
128
- x
129
- end
130
-
131
- def enable_training
132
- trainable_layers.each do |layer|
133
- layer.trainable = true
134
- end
135
- end
136
-
137
- def disable_training
138
- trainable_layers.each do |layer|
139
- layer.trainable = false
140
- end
141
- end
142
- end
143
-
144
73
  class Discriminator < Model
145
74
  def initialize(gen_input_shape, gen_output_shape, base_num_filters)
146
75
  super()
@@ -436,8 +436,8 @@ module DNN
436
436
  def forward_node(x)
437
437
  if DNN.learning_phase
438
438
  Xumo::SFloat.srand(@rnd.rand(1 << 31))
439
- @mask = Xumo::SFloat.new(*x.shape).rand < @dropout_ratio
440
- x[@mask] = 0
439
+ @mask = Xumo::SFloat.cast(Xumo::SFloat.new(*x.shape).rand >= @dropout_ratio)
440
+ x = x * @mask
441
441
  elsif @use_scale
442
442
  x *= (1 - @dropout_ratio)
443
443
  end
@@ -445,8 +445,7 @@ module DNN
445
445
  end
446
446
 
447
447
  def backward_node(dy)
448
- dy[@mask] = 0
449
- dy
448
+ dy * @mask
450
449
  end
451
450
 
452
451
  def to_hash
data/lib/dnn/image.rb CHANGED
@@ -114,6 +114,41 @@ module DNN
114
114
  Numo::UInt8.cast(x)
115
115
  end
116
116
 
117
+ # Image convert image channel to RGB.
118
+ # @param [Numo::UInt8] img Image to RGB.
119
+ def self.to_rgb(img)
120
+ img_check(img)
121
+ case img.shape[2]
122
+ when 1
123
+ return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
124
+ when 2
125
+ img = img[true, true, 0...1]
126
+ return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
127
+ when 4
128
+ return img[true, true, 0...3].clone
129
+ end
130
+ img
131
+ end
132
+
133
+ # Image convert image channel to RGBA.
134
+ # @param [Numo::UInt8] img Image to RGBA.
135
+ def self.to_rgba(img)
136
+ img_check(img)
137
+ case img.shape[2]
138
+ when 1
139
+ alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
140
+ return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
141
+ when 2
142
+ alpha = img[true, true, 1...2]
143
+ img = img[true, true, 0...1]
144
+ return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
145
+ when 3
146
+ alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
147
+ return img.concatenate(alpha, axis: 2)
148
+ end
149
+ img
150
+ end
151
+
117
152
  private_class_method def self.img_check(img)
118
153
  raise TypeError, "img: #{img.class} is not an instance of the Numo::UInt8 class." unless img.is_a?(Numo::UInt8)
119
154
  if img.shape.length != 3
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "1.2.0"
2
+ VERSION = "1.2.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.0
4
+ version: 1.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-05-16 00:00:00.000000000 Z
11
+ date: 2020-05-24 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray