ruby-dnn 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. checksums.yaml +7 -0
  2. data/.gitignore +8 -0
  3. data/.travis.yml +5 -0
  4. data/CODE_OF_CONDUCT.md +74 -0
  5. data/Gemfile +6 -0
  6. data/LICENSE.txt +21 -0
  7. data/README.md +42 -0
  8. data/Rakefile +10 -0
  9. data/bin/console +14 -0
  10. data/bin/setup +8 -0
  11. data/lib/dnn.rb +14 -0
  12. data/lib/dnn/core/activations.rb +116 -0
  13. data/lib/dnn/core/error.rb +13 -0
  14. data/lib/dnn/core/initializers.rb +46 -0
  15. data/lib/dnn/core/layers.rb +366 -0
  16. data/lib/dnn/core/model.rb +158 -0
  17. data/lib/dnn/core/optimizers.rb +113 -0
  18. data/lib/dnn/core/util.rb +24 -0
  19. data/lib/dnn/core/version.rb +3 -0
  20. data/lib/dnn/ext/cifar10/Makefile +263 -0
  21. data/lib/dnn/ext/cifar10/cifar10_ext.c +52 -0
  22. data/lib/dnn/ext/cifar10/cifar10_ext.o +0 -0
  23. data/lib/dnn/ext/cifar10/cifar10_ext.so +0 -0
  24. data/lib/dnn/ext/cifar10/extconf.rb +3 -0
  25. data/lib/dnn/ext/cifar10/numo/compat.h +23 -0
  26. data/lib/dnn/ext/cifar10/numo/extconf.h +13 -0
  27. data/lib/dnn/ext/cifar10/numo/intern.h +117 -0
  28. data/lib/dnn/ext/cifar10/numo/narray.h +430 -0
  29. data/lib/dnn/ext/cifar10/numo/ndloop.h +94 -0
  30. data/lib/dnn/ext/cifar10/numo/template.h +149 -0
  31. data/lib/dnn/ext/cifar10/numo/types/bit.h +33 -0
  32. data/lib/dnn/ext/cifar10/numo/types/complex.h +409 -0
  33. data/lib/dnn/ext/cifar10/numo/types/complex_macro.h +377 -0
  34. data/lib/dnn/ext/cifar10/numo/types/dcomplex.h +44 -0
  35. data/lib/dnn/ext/cifar10/numo/types/dfloat.h +42 -0
  36. data/lib/dnn/ext/cifar10/numo/types/float_def.h +34 -0
  37. data/lib/dnn/ext/cifar10/numo/types/float_macro.h +186 -0
  38. data/lib/dnn/ext/cifar10/numo/types/int16.h +24 -0
  39. data/lib/dnn/ext/cifar10/numo/types/int32.h +24 -0
  40. data/lib/dnn/ext/cifar10/numo/types/int64.h +24 -0
  41. data/lib/dnn/ext/cifar10/numo/types/int8.h +24 -0
  42. data/lib/dnn/ext/cifar10/numo/types/int_macro.h +41 -0
  43. data/lib/dnn/ext/cifar10/numo/types/real_accum.h +486 -0
  44. data/lib/dnn/ext/cifar10/numo/types/robj_macro.h +75 -0
  45. data/lib/dnn/ext/cifar10/numo/types/robject.h +27 -0
  46. data/lib/dnn/ext/cifar10/numo/types/scomplex.h +44 -0
  47. data/lib/dnn/ext/cifar10/numo/types/sfloat.h +43 -0
  48. data/lib/dnn/ext/cifar10/numo/types/uint16.h +21 -0
  49. data/lib/dnn/ext/cifar10/numo/types/uint32.h +21 -0
  50. data/lib/dnn/ext/cifar10/numo/types/uint64.h +21 -0
  51. data/lib/dnn/ext/cifar10/numo/types/uint8.h +21 -0
  52. data/lib/dnn/ext/cifar10/numo/types/uint_macro.h +32 -0
  53. data/lib/dnn/ext/cifar10/numo/types/xint_macro.h +189 -0
  54. data/lib/dnn/ext/image_io/Makefile +263 -0
  55. data/lib/dnn/ext/image_io/extconf.rb +3 -0
  56. data/lib/dnn/ext/image_io/image_io_ext.c +89 -0
  57. data/lib/dnn/ext/image_io/image_io_ext.so +0 -0
  58. data/lib/dnn/ext/image_io/numo/compat.h +23 -0
  59. data/lib/dnn/ext/image_io/numo/extconf.h +13 -0
  60. data/lib/dnn/ext/image_io/numo/intern.h +117 -0
  61. data/lib/dnn/ext/image_io/numo/narray.h +430 -0
  62. data/lib/dnn/ext/image_io/numo/ndloop.h +94 -0
  63. data/lib/dnn/ext/image_io/numo/template.h +149 -0
  64. data/lib/dnn/ext/image_io/numo/types/bit.h +33 -0
  65. data/lib/dnn/ext/image_io/numo/types/complex.h +409 -0
  66. data/lib/dnn/ext/image_io/numo/types/complex_macro.h +377 -0
  67. data/lib/dnn/ext/image_io/numo/types/dcomplex.h +44 -0
  68. data/lib/dnn/ext/image_io/numo/types/dfloat.h +42 -0
  69. data/lib/dnn/ext/image_io/numo/types/float_def.h +34 -0
  70. data/lib/dnn/ext/image_io/numo/types/float_macro.h +186 -0
  71. data/lib/dnn/ext/image_io/numo/types/int16.h +24 -0
  72. data/lib/dnn/ext/image_io/numo/types/int32.h +24 -0
  73. data/lib/dnn/ext/image_io/numo/types/int64.h +24 -0
  74. data/lib/dnn/ext/image_io/numo/types/int8.h +24 -0
  75. data/lib/dnn/ext/image_io/numo/types/int_macro.h +41 -0
  76. data/lib/dnn/ext/image_io/numo/types/real_accum.h +486 -0
  77. data/lib/dnn/ext/image_io/numo/types/robj_macro.h +75 -0
  78. data/lib/dnn/ext/image_io/numo/types/robject.h +27 -0
  79. data/lib/dnn/ext/image_io/numo/types/scomplex.h +44 -0
  80. data/lib/dnn/ext/image_io/numo/types/sfloat.h +43 -0
  81. data/lib/dnn/ext/image_io/numo/types/uint16.h +21 -0
  82. data/lib/dnn/ext/image_io/numo/types/uint32.h +21 -0
  83. data/lib/dnn/ext/image_io/numo/types/uint64.h +21 -0
  84. data/lib/dnn/ext/image_io/numo/types/uint8.h +21 -0
  85. data/lib/dnn/ext/image_io/numo/types/uint_macro.h +32 -0
  86. data/lib/dnn/ext/image_io/numo/types/xint_macro.h +189 -0
  87. data/lib/dnn/ext/image_io/stb_image.h +7462 -0
  88. data/lib/dnn/ext/image_io/stb_image_write.h +1568 -0
  89. data/lib/dnn/ext/mnist/Makefile +263 -0
  90. data/lib/dnn/ext/mnist/extconf.rb +3 -0
  91. data/lib/dnn/ext/mnist/mnist_ext.c +49 -0
  92. data/lib/dnn/ext/mnist/mnist_ext.o +0 -0
  93. data/lib/dnn/ext/mnist/mnist_ext.so +0 -0
  94. data/lib/dnn/ext/mnist/numo/compat.h +23 -0
  95. data/lib/dnn/ext/mnist/numo/extconf.h +13 -0
  96. data/lib/dnn/ext/mnist/numo/intern.h +117 -0
  97. data/lib/dnn/ext/mnist/numo/narray.h +430 -0
  98. data/lib/dnn/ext/mnist/numo/ndloop.h +94 -0
  99. data/lib/dnn/ext/mnist/numo/template.h +149 -0
  100. data/lib/dnn/ext/mnist/numo/types/bit.h +33 -0
  101. data/lib/dnn/ext/mnist/numo/types/complex.h +409 -0
  102. data/lib/dnn/ext/mnist/numo/types/complex_macro.h +377 -0
  103. data/lib/dnn/ext/mnist/numo/types/dcomplex.h +44 -0
  104. data/lib/dnn/ext/mnist/numo/types/dfloat.h +42 -0
  105. data/lib/dnn/ext/mnist/numo/types/float_def.h +34 -0
  106. data/lib/dnn/ext/mnist/numo/types/float_macro.h +186 -0
  107. data/lib/dnn/ext/mnist/numo/types/int16.h +24 -0
  108. data/lib/dnn/ext/mnist/numo/types/int32.h +24 -0
  109. data/lib/dnn/ext/mnist/numo/types/int64.h +24 -0
  110. data/lib/dnn/ext/mnist/numo/types/int8.h +24 -0
  111. data/lib/dnn/ext/mnist/numo/types/int_macro.h +41 -0
  112. data/lib/dnn/ext/mnist/numo/types/real_accum.h +486 -0
  113. data/lib/dnn/ext/mnist/numo/types/robj_macro.h +75 -0
  114. data/lib/dnn/ext/mnist/numo/types/robject.h +27 -0
  115. data/lib/dnn/ext/mnist/numo/types/scomplex.h +44 -0
  116. data/lib/dnn/ext/mnist/numo/types/sfloat.h +43 -0
  117. data/lib/dnn/ext/mnist/numo/types/uint16.h +21 -0
  118. data/lib/dnn/ext/mnist/numo/types/uint32.h +21 -0
  119. data/lib/dnn/ext/mnist/numo/types/uint64.h +21 -0
  120. data/lib/dnn/ext/mnist/numo/types/uint8.h +21 -0
  121. data/lib/dnn/ext/mnist/numo/types/uint_macro.h +32 -0
  122. data/lib/dnn/ext/mnist/numo/types/xint_macro.h +189 -0
  123. data/lib/dnn/lib/cifar10.rb +26 -0
  124. data/lib/dnn/lib/image_io.rb +33 -0
  125. data/lib/dnn/lib/mnist.rb +61 -0
  126. data/ruby-dnn.gemspec +41 -0
  127. metadata +225 -0
checksums.yaml ADDED
@@ -0,0 +1,7 @@
+ ---
+ SHA256:
+   metadata.gz: f4d8dd255a75d99969d877a06faa150456210f131858a6f08b558dc070f5af9f
+   data.tar.gz: 72fe90225b3deb1a2957af33529ef4f1845749cd53ce0246d75ffd8ed36fded5
+ SHA512:
+   metadata.gz: 96194040842f66d4ea3499fab0eefb7e5b47212ec8ba2981640879446ac72103c531bab05e55aab64ecc40d2e490fc8a5a96b20068167e017d58250f060230cf
+   data.tar.gz: 718a6f11e8a8e647dc0fa3867b72c3f289efa3d8a8efd689425d92ab5003e221f729a326011a46caa3c0d53be10cfe569a7e959197c0b818f44e948842a1b832
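These digests let a consumer verify an unpacked copy of the gem by hand. A minimal sketch (not part of the package), assuming the archive was first unpacked with `tar xf ruby-dnn-0.1.0.gem`, which yields metadata.gz and data.tar.gz:

    require "digest"

    expected = {
      "metadata.gz" => "f4d8dd255a75d99969d877a06faa150456210f131858a6f08b558dc070f5af9f",
      "data.tar.gz" => "72fe90225b3deb1a2957af33529ef4f1845749cd53ce0246d75ffd8ed36fded5",
    }

    # Compare each file's SHA256 digest against checksums.yaml above.
    expected.each do |file, sha256|
      actual = Digest::SHA256.file(file).hexdigest
      puts "#{file}: #{actual == sha256 ? 'OK' : 'MISMATCH'}"
    end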
data/.gitignore ADDED
@@ -0,0 +1,8 @@
+ /.bundle/
+ /.yardoc
+ /_yardoc/
+ /coverage/
+ /doc/
+ /pkg/
+ /spec/reports/
+ /tmp/
data/.travis.yml ADDED
@@ -0,0 +1,5 @@
+ sudo: false
+ language: ruby
+ rvm:
+   - 2.5.1
+ before_install: gem install bundler -v 1.16.2
data/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,74 @@
+ # Contributor Covenant Code of Conduct
+
+ ## Our Pledge
+
+ In the interest of fostering an open and welcoming environment, we as
+ contributors and maintainers pledge to making participation in our project and
+ our community a harassment-free experience for everyone, regardless of age, body
+ size, disability, ethnicity, gender identity and expression, level of experience,
+ nationality, personal appearance, race, religion, or sexual identity and
+ orientation.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to creating a positive environment
+ include:
+
+ * Using welcoming and inclusive language
+ * Being respectful of differing viewpoints and experiences
+ * Gracefully accepting constructive criticism
+ * Focusing on what is best for the community
+ * Showing empathy towards other community members
+
+ Examples of unacceptable behavior by participants include:
+
+ * The use of sexualized language or imagery and unwelcome sexual attention or
+   advances
+ * Trolling, insulting/derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or electronic
+   address, without explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+   professional setting
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying the standards of acceptable
+ behavior and are expected to take appropriate and fair corrective action in
+ response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or
+ reject comments, commits, code, wiki edits, issues, and other contributions
+ that are not aligned to this Code of Conduct, or to ban temporarily or
+ permanently any contributor for other behaviors that they deem inappropriate,
+ threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies both within project spaces and in public spaces
+ when an individual is representing the project or its community. Examples of
+ representing a project or community include using an official project e-mail
+ address, posting via an official social media account, or acting as an appointed
+ representative at an online or offline event. Representation of a project may be
+ further defined and clarified by project maintainers.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported by contacting the project team at ootoro838861@outlook.jp. All
+ complaints will be reviewed and investigated and will result in a response that
+ is deemed necessary and appropriate to the circumstances. The project team is
+ obligated to maintain confidentiality with regard to the reporter of an incident.
+ Further details of specific enforcement policies may be posted separately.
+
+ Project maintainers who do not follow or enforce the Code of Conduct in good
+ faith may face temporary or permanent repercussions as determined by other
+ members of the project's leadership.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+ available at [http://contributor-covenant.org/version/1/4][version]
+
+ [homepage]: http://contributor-covenant.org
+ [version]: http://contributor-covenant.org/version/1/4/
data/Gemfile ADDED
@@ -0,0 +1,6 @@
+ source "https://rubygems.org"
+
+ git_source(:github) {|repo_name| "https://github.com/#{repo_name}" }
+
+ # Specify your gem's dependencies in ruby-dnn.gemspec
+ gemspec
data/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
+ The MIT License (MIT)
+
+ Copyright (c) 2018 unagiootoro
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,42 @@
+ # ruby-dnn
+
+ ruby-dnn is a Ruby deep learning library. It supports fully connected neural
+ networks and convolutional neural networks. Currently, you can get 99% accuracy on MNIST and 60% on CIFAR-10.
+
+ ## Installation
+
+ Add this line to your application's Gemfile:
+
+ ```ruby
+ gem 'ruby-dnn'
+ ```
+
+ And then execute:
+
+     $ bundle
+
+ Or install it yourself as:
+
+     $ gem install ruby-dnn
+
+ ## Usage
+
+ TODO: Write usage instructions here
+
+ ## Development
+
+ After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
+
+ To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
+
+ ## Contributing
+
+ Bug reports and pull requests are welcome on GitHub at https://github.com/[USERNAME]/dnn. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [Contributor Covenant](http://contributor-covenant.org) code of conduct.
+
+ ## License
+
+ The gem is available as open source under the terms of the [MIT License](https://opensource.org/licenses/MIT).
+
+ ## Code of Conduct
+
+ Everyone interacting in the Dnn project's codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/[USERNAME]/dnn/blob/master/CODE_OF_CONDUCT.md).
data/Rakefile ADDED
@@ -0,0 +1,10 @@
+ require "bundler/gem_tasks"
+ require "rake/testtask"
+
+ Rake::TestTask.new(:test) do |t|
+   t.libs << "test"
+   t.libs << "lib"
+   t.test_files = FileList["test/**/*_test.rb"]
+ end
+
+ task :default => :test
data/bin/console ADDED
@@ -0,0 +1,14 @@
+ #!/usr/bin/env ruby
+
+ require "bundler/setup"
+ require "dnn"
+
+ # You can add fixtures and/or initialization code here to make experimenting
+ # with your gem easier. You can also use a different console, if you like.
+
+ # (If you use this, don't forget to add pry to your Gemfile!)
+ # require "pry"
+ # Pry.start
+
+ require "irb"
+ IRB.start(__FILE__)
data/bin/setup ADDED
@@ -0,0 +1,8 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+ IFS=$'\n\t'
+ set -vx
+
+ bundle install
+
+ # Do any other automated setup that you need to do here
data/lib/dnn.rb ADDED
@@ -0,0 +1,14 @@
+ require "numo/narray"
+
+ Numo::SFloat.srand(rand(2**64))
+
+ module DNN; end
+
+ require "dnn/core/version"
+ require "dnn/core/error"
+ require "dnn/core/model"
+ require "dnn/core/initializers"
+ require "dnn/core/layers"
+ require "dnn/core/activations"
+ require "dnn/core/optimizers"
+ require "dnn/core/util"
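Note that dnn.rb seeds Numo's PRNG from Ruby's `rand`, so weight initialization and dropout masks differ on every run. A small sketch (not part of the gem) of pinning both PRNGs when reproducible runs are wanted:

    require "numo/narray"

    srand(42)               # Ruby's PRNG, which dnn.rb uses to derive the Numo seed
    Numo::SFloat.srand(42)  # or pin Numo's PRNG directly
    p Numo::SFloat.new(3).rand_norm  # prints the same values on every run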
data/lib/dnn/core/activations.rb ADDED
@@ -0,0 +1,116 @@
+ module DNN
+   module Activations
+     Layer = Layers::Layer
+     OutputLayer = Layers::OutputLayer
+
+
+     module SigmoidFunction
+       def forward(x)
+         @out = 1.0 / (1 + NMath.exp(-x))
+       end
+     end
+
+
+     class Sigmoid < Layer
+       include SigmoidFunction
+
+       def backward(dout)
+         dout * (1.0 - @out) * @out
+       end
+     end
+
+
+     class Tanh < Layer
+       include Numo
+
+       def forward(x)
+         @x = x
+         NMath.tanh(x)
+       end
+
+       def backward(dout)
+         dout * (1.0 / NMath.cosh(@x)**2)
+       end
+     end
+
+
+     class ReLU < Layer
+       def forward(x)
+         @x = x.clone
+         x[x < 0] = 0
+         x
+       end
+
+       def backward(dout)
+         @x[@x > 0] = 1
+         @x[@x <= 0] = 0
+         dout * @x
+       end
+     end
+
+
+     class LeakyReLU < Layer
+       def initialize(alpha = 0.3)
+         @alpha = alpha
+       end
+
+       def forward(x)
+         @x = x.clone
+         a = Numo::SFloat.ones(x.shape)
+         a[x <= 0] = @alpha
+         x * a
+       end
+
+       def backward(dout)
+         @x[@x > 0] = 1
+         @x[@x <= 0] = @alpha
+         dout * @x
+       end
+     end
+
+
+     class IdentityWithLoss < OutputLayer
+       def forward(x)
+         @out = x
+       end
+
+       def backward(y)
+         @out - y
+       end
+
+       def loss(y)
+         0.5 * ((@out - y) ** 2).sum / @model.batch_size + ridge
+       end
+     end
+
+
+     class SoftmaxWithLoss < OutputLayer
+       def forward(x)
+         @out = NMath.exp(x) / NMath.exp(x).sum(1).reshape(x.shape[0], 1)
+       end
+
+       def backward(y)
+         @out - y
+       end
+
+       def loss(y)
+         -(y * NMath.log(@out + 1e-7)).sum / @model.batch_size + ridge
+       end
+     end
+
+
+     class SigmoidWithLoss < OutputLayer
+       include Numo
+       include SigmoidFunction
+
+       def backward(y)
+         @out - y
+       end
+
+       def loss(y)
+         -(y * NMath.log(@out + 1e-7) + (1 - y) * NMath.log(1 - @out + 1e-7)).sum / @model.batch_size + ridge
+       end
+     end
+
+   end
+ end
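The backward methods above return analytic gradients, which can be sanity-checked against a centered difference. A standalone sketch for Sigmoid (not part of the gem), using only Numo and the same formulas as SigmoidFunction and Sigmoid#backward:

    require "numo/narray"
    include Numo

    x    = DFloat[-2.0, -0.5, 0.0, 0.5, 2.0]
    out  = 1.0 / (1 + NMath.exp(-x))  # forward, as in SigmoidFunction#forward
    grad = (1.0 - out) * out          # Sigmoid#backward with dout = 1

    # Centered numerical gradient of the same forward function.
    eps = 1e-4
    f   = ->(v) { 1.0 / (1 + NMath.exp(-v)) }
    num = (f.(x + eps) - f.(x - eps)) / (2 * eps)
    p (grad - num).abs.max  # on the order of 1e-9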
data/lib/dnn/core/error.rb ADDED
@@ -0,0 +1,13 @@
+ module DNN
+   class DNN_Error < StandardError; end
+
+   class DNN_TypeError < DNN_Error; end
+
+   class DNN_ShapeError < DNN_Error; end
+
+   class DNN_GradUnfairError < DNN_Error
+     def initialize(grad, n_grad)
+       super("gradient is #{grad}, but numerical gradient is #{n_grad}")
+     end
+   end
+ end
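A hypothetical usage sketch (nothing like this ships in the gem): because every error above derives from DNN_Error, callers can rescue the whole family at once:

    require "dnn"

    begin
      raise DNN::DNN_ShapeError, "input shape [10] does not match layer shape [20]"
    rescue DNN::DNN_Error => e
      warn "dnn failed: #{e.message}"
    end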
data/lib/dnn/core/initializers.rb ADDED
@@ -0,0 +1,46 @@
+ module DNN
+   module Initializers
+
+     class Initializer
+       def init_param(layer, param_key, param)
+         layer.params[param_key] = param
+       end
+     end
+
+
+     class Zeros < Initializer
+       def init_param(layer, param_key)
+         super(layer, param_key, layer.params[param_key].fill(0))
+       end
+     end
+
+
+     class RandomNormal < Initializer
+       def initialize(mean = 0, std = 0.05)
+         @mean = mean
+         @std = std
+       end
+
+       def init_param(layer, param_key)
+         super(layer, param_key, layer.params[param_key].rand_norm(@mean, @std))
+       end
+     end
+
+
+     class Xavier < Initializer
+       def init_param(layer, param_key)
+         prev_dim = layer.prev_layer.shape.reduce(:*)
+         super(layer, param_key, layer.params[param_key].rand_norm / Math.sqrt(prev_dim))
+       end
+     end
+
+
+     class He < Initializer
+       def init_param(layer, param_key)
+         prev_dim = layer.prev_layer.shape.reduce(:*)
+         super(layer, param_key, layer.params[param_key].rand_norm / Math.sqrt(prev_dim) * Math.sqrt(2))
+       end
+     end
+
+   end
+ end
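Xavier scales weights by 1/sqrt(fan-in) and He adds a further factor of sqrt(2), which keeps activation variance roughly constant across layers. A quick sketch of the effect outside the Initializer plumbing, with prev_dim standing in for the fan-in:

    require "numo/narray"

    prev_dim = 512
    xavier = Numo::SFloat.new(100_000).rand_norm / Math.sqrt(prev_dim)
    he     = Numo::SFloat.new(100_000).rand_norm / Math.sqrt(prev_dim) * Math.sqrt(2)

    p xavier.stddev  # ~ 1 / Math.sqrt(512)   (about 0.044)
    p he.stddev      # ~ Math.sqrt(2 / 512.0) (about 0.063)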
data/lib/dnn/core/layers.rb ADDED
@@ -0,0 +1,366 @@
+ module DNN
+   module Layers
+
+     # Superclass of all layer classes.
+     class Layer
+       include Numo
+
+       # Initialize the layer when the model is compiled.
+       def init(model)
+         @model = model
+       end
+
+       # Forward propagation.
+       def forward() end
+
+       # Backward propagation.
+       def backward() end
+
+       # Get the shape of the layer.
+       def shape
+         prev_layer.shape
+       end
+
+       # Get the previous layer.
+       def prev_layer
+         @model.layers[@model.layers.index(self) - 1]
+       end
+     end
+
+
+     class HasParamLayer < Layer
+       attr_reader :params # The parameters of the layer.
+       attr_reader :grads  # Gradients of the layer's parameters.
+
+       def initialize
+         @params = {}
+         @grads = {}
+       end
+
+       def init(model)
+         super
+         init_params
+       end
+
+       # Update the parameters.
+       def update
+         @model.optimizer.update(self)
+       end
+
+       private
+
+       # Initialize the parameters.
+       def init_params() end
+     end
+
+
+     class InputLayer < Layer
+       attr_reader :shape
+
+       def initialize(dim_or_shape)
+         @shape = dim_or_shape.is_a?(Array) ? dim_or_shape : [dim_or_shape]
+       end
+
+       def forward(x)
+         x
+       end
+
+       def backward(dout)
+         dout
+       end
+     end
+
+
+     class Dense < HasParamLayer
+       include Initializers
+
+       attr_reader :num_nodes
+       attr_reader :weight_decay
+
+       def initialize(num_nodes,
+                      weight_initializer: nil,
+                      bias_initializer: nil,
+                      weight_decay: 0)
+         super()
+         @num_nodes = num_nodes
+         @weight_initializer = (weight_initializer || RandomNormal.new)
+         @bias_initializer = (bias_initializer || Zeros.new)
+         @weight_decay = weight_decay
+       end
+
+       def forward(x)
+         @x = x
+         @x.dot(@params[:weight]) + @params[:bias]
+       end
+
+       def backward(dout)
+         @grads[:weight] = @x.transpose.dot(dout)
+         if @weight_decay > 0
+           dridge = @weight_decay * @params[:weight]
+           @grads[:weight] += dridge
+         end
+         @grads[:bias] = dout.sum(0)
+         dout.dot(@params[:weight].transpose)
+       end
+
+       def shape
+         [@num_nodes]
+       end
+
+       private
+
+       def init_params
+         num_prev_nodes = prev_layer.shape[0]
+         @params[:weight] = SFloat.new(num_prev_nodes, @num_nodes)
+         @params[:bias] = SFloat.new(@num_nodes)
+         @weight_initializer.init_param(self, :weight)
+         @bias_initializer.init_param(self, :bias)
+       end
+     end
+
+
+     module Convert
+       def im2col(img, out_h, out_w, fh, fw, strides)
+         bs, fn = img.shape[0..1]
+         col = SFloat.zeros(bs, fn, fh, fw, out_h, out_w)
+         (0...fh).each do |i|
+           i_range = (i...(i + strides[0] * out_h)).step(strides[0]).to_a
+           (0...fw).each do |j|
+             j_range = (j...(j + strides[1] * out_w)).step(strides[1]).to_a
+             col[true, true, i, j, true, true] = img[true, true, i_range, j_range]
+           end
+         end
+         col.transpose(0, 4, 5, 1, 2, 3).reshape(bs * out_h * out_w, fn * fh * fw)
+       end
+
+       def col2im(col, img_shape, out_h, out_w, fh, fw, strides)
+         bs, fn, ih, iw = img_shape
+         col = col.reshape(bs, out_h, out_w, fn, fh, fw).transpose(0, 3, 4, 5, 1, 2)
+         img = SFloat.zeros(bs, fn, ih, iw)
+         (0...fh).each do |i|
+           i_range = (i...(i + strides[0] * out_h)).step(strides[0]).to_a
+           (0...fw).each do |j|
+             j_range = (j...(j + strides[1] * out_w)).step(strides[1]).to_a
+             img[true, true, i_range, j_range] += col[true, true, i, j, true, true]
+           end
+         end
+         img
+       end
+
+       def padding(img, pad)
+         bs, c, ih, iw = img.shape
+         ih2 = ih + pad * 2
+         iw2 = iw + pad * 2
+         img2 = SFloat.zeros(bs, c, ih2, iw2)
+         img2[true, true, pad...(ih + pad), pad...(iw + pad)] = img
+         img2
+       end
+     end
+
+
+     class Conv2D < HasParamLayer
+       include Initializers
+       include Convert
+
+       def initialize(num_filters, filter_height, filter_width,
+                      weight_initializer: nil,
+                      bias_initializer: nil,
+                      strides: [1, 1],
+                      padding: 0,
+                      weight_decay: 0)
+         super()
+         @num_filters = num_filters
+         @filter_height = filter_height
+         @filter_width = filter_width
+         @weight_initializer = (weight_initializer || RandomNormal.new)
+         @bias_initializer = (bias_initializer || Zeros.new)
+         @strides = strides
+         @weight_decay = weight_decay
+         @padding = padding
+       end
+
+       def init(model)
+         super
+         prev_height, prev_width = prev_layer.shape[1], prev_layer.shape[2]
+         @out_height = (prev_height + @padding * 2 - @filter_height) / @strides[0] + 1
+         @out_width = (prev_width + @padding * 2 - @filter_width) / @strides[1] + 1
+       end
+
+       def forward(x)
+         x = padding(x, @padding) if @padding > 0
+         @x_shape = x.shape
+         @col = im2col(x, @out_height, @out_width, @filter_height, @filter_width, @strides)
+         out = @col.dot(@params[:weight]) + @params[:bias]
+         out.reshape(@model.batch_size, @out_height, @out_width, out.shape[1]).transpose(0, 3, 1, 2)
+       end
+
+       def backward(dout)
+         dout = dout.transpose(0, 2, 3, 1)
+         dout = dout.reshape(dout.shape[0..2].reduce(:*), dout.shape[3])
+         @grads[:weight] = @col.transpose.dot(dout)
+         if @weight_decay > 0
+           dridge = @weight_decay * @params[:weight]
+           @grads[:weight] += dridge
+         end
+         @grads[:bias] = dout.sum(0)
+         dcol = dout.dot(@params[:weight].transpose)
+         col2im(dcol, @x_shape, @out_height, @out_width, @filter_height, @filter_width, @strides)
+       end
+
+       def shape
+         [@num_filters, @out_height, @out_width]
+       end
+
+       private
+
+       def init_params
+         num_prev_filter = prev_layer.shape[0]
+         @params[:weight] = SFloat.new(num_prev_filter * @filter_height * @filter_width, @num_filters)
+         @params[:bias] = SFloat.new(@num_filters)
+         @weight_initializer.init_param(self, :weight)
+         @bias_initializer.init_param(self, :bias)
+       end
+     end
+
+
+     class MaxPool2D < Layer
+       include Convert
+
+       def initialize(pool_height, pool_width, strides: nil, padding: 0)
+         @pool_height = pool_height
+         @pool_width = pool_width
+         @strides = strides ? strides : [@pool_height, @pool_width]
+         @padding = padding
+       end
+
+       def init(model)
+         super
+         prev_height, prev_width = prev_layer.shape[1], prev_layer.shape[2]
+         @num_channel = prev_layer.shape[0]
+         @out_height = (prev_height - @pool_height) / @strides[0] + 1
+         @out_width = (prev_width - @pool_width) / @strides[1] + 1
+       end
+
+       def forward(x)
+         @x_shape = x.shape
+         col = im2col(x, @out_height, @out_width, @pool_height, @pool_width, @strides)
+         col = col.reshape(x.shape[0] * @out_height * @out_width * x.shape[1], @pool_height * @pool_width)
+         @max_index = col.max_index(1)
+         col.max(1).reshape(x.shape[0], @out_height, @out_width, x.shape[1]).transpose(0, 3, 1, 2)
+       end
+
+       def backward(dout)
+         dout = dout.transpose(0, 2, 3, 1)
+         pool_size = @pool_height * @pool_width
+         dmax = SFloat.zeros(dout.size * pool_size)
+         dmax[@max_index] = dout.flatten
+         dcol = dmax.reshape(dout.shape[0..2].reduce(:*), dout.shape[3] * pool_size)
+         col2im(dcol, @x_shape, @out_height, @out_width, @pool_height, @pool_width, @strides)
+       end
+
+       def shape
+         [@num_channel, @out_height, @out_width]
+       end
+     end
+
+
+     class Flatten < Layer
+       def forward(x)
+         @shape = x.shape
+         x.reshape(x.shape[0], x.shape[1..-1].reduce(:*))
+       end
+
+       def backward(dout)
+         dout.reshape(*@shape)
+       end
+
+       def shape
+         [prev_layer.shape.reduce(:*)]
+       end
+     end
+
+
+     class Reshape < Layer
+       attr_reader :shape
+
+       def initialize(shape)
+         @shape = shape
+         @x_shape = nil
+       end
+
+       def forward(x)
+         @x_shape = x.shape
+         x.reshape(*@shape)
+       end
+
+       def backward(dout)
+         dout.reshape(*@x_shape)
+       end
+     end
+
+
+     class OutputLayer < Layer
+       private
+
+       def ridge
+         @model.layers.select { |layer| layer.is_a?(Dense) }
+                      .reduce(0) { |sum, layer| sum + layer.weight_decay * (layer.params[:weight]**2).sum }
+       end
+     end
+
+
+     class Dropout < Layer
+       def initialize(dropout_ratio)
+         @dropout_ratio = dropout_ratio
+         @mask = nil
+       end
+
+       def forward(x)
+         if @model.training
+           @mask = SFloat.ones(*x.shape).rand < @dropout_ratio
+           x[@mask] = 0
+         else
+           x *= (1 - @dropout_ratio)
+         end
+         x
+       end
+
+       def backward(dout)
+         dout[@mask] = 0 if @model.training
+         dout
+       end
+     end
+
+
+     class BatchNormalization < HasParamLayer
+       def forward(x)
+         @mean = x.mean(0)
+         @xc = x - @mean
+         @var = (@xc**2).mean(0)
+         @std = NMath.sqrt(@var + 1e-7)
+         @xn = @xc / @std
+         @params[:gamma] * @xn + @params[:beta]
+       end
+
+       def backward(dout)
+         @grads[:beta] = dout.sum(0)
+         @grads[:gamma] = (@xn * dout).sum(0)
+         dxn = @params[:gamma] * dout
+         dxc = dxn / @std
+         dstd = -((dxn * @xc) / (@std**2)).sum(0)
+         dvar = 0.5 * dstd / @std
+         dxc += (2.0 / @model.batch_size) * @xc * dvar
+         dmean = dxc.sum(0)
+         dxc - dmean / @model.batch_size
+       end
+
+       private
+
+       def init_params
+         @params[:gamma] = 1
+         @params[:beta] = 0
+       end
+     end
+   end
+
+ end
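The im2col trick in the Convert module above is what turns convolution and pooling into a single matrix multiply or row-wise max. A toy walk-through (not part of the gem) of the same loop structure on a 1x1x4x4 image with a 2x2 window and stride 2:

    require "numo/narray"
    include Numo

    img = SFloat.new(1, 1, 4, 4).seq  # batch=1, channel=1, pixel values 0..15
    out_h = out_w = 2
    fh = fw = 2
    strides = [2, 2]

    col = SFloat.zeros(1, 1, fh, fw, out_h, out_w)
    (0...fh).each do |i|
      i_range = (i...(i + strides[0] * out_h)).step(strides[0]).to_a
      (0...fw).each do |j|
        j_range = (j...(j + strides[1] * out_w)).step(strides[1]).to_a
        col[true, true, i, j, true, true] = img[true, true, i_range, j_range]
      end
    end

    # One row per output position, one column per window element:
    # [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
    p col.transpose(0, 4, 5, 1, 2, 3).reshape(4, 4)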