rubyzero 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. checksums.yaml +7 -0
  2. data/.DS_Store +0 -0
  3. data/.gitignore +9 -0
  4. data/Gemfile +11 -0
  5. data/LICENSE +21 -0
  6. data/README.md +35 -0
  7. data/Rakefile +4 -0
  8. data/bin/console +15 -0
  9. data/bin/setup +8 -0
  10. data/changelog.md +0 -0
  11. data/lib/rubyzero/core/cast.rb +25 -0
  12. data/lib/rubyzero/core/core.rb +15 -0
  13. data/lib/rubyzero/core/device.rb +32 -0
  14. data/lib/rubyzero/core/dtypes.rb +115 -0
  15. data/lib/rubyzero/core/exceptions.rb +7 -0
  16. data/lib/rubyzero/core/functions/activations.rb +20 -0
  17. data/lib/rubyzero/core/functions/elementary_functions.rb +29 -0
  18. data/lib/rubyzero/core/functions/function.rb +29 -0
  19. data/lib/rubyzero/core/functions/functions.rb +10 -0
  20. data/lib/rubyzero/core/functions/operators.rb +112 -0
  21. data/lib/rubyzero/core/functions/tensor_functions.rb +123 -0
  22. data/lib/rubyzero/core/tensor.rb +56 -0
  23. data/lib/rubyzero/core/tensor_backward.rb +17 -0
  24. data/lib/rubyzero/core/tensor_initialize_methods.rb +78 -0
  25. data/lib/rubyzero/core/tensor_operators.rb +27 -0
  26. data/lib/rubyzero/data/data.rb +9 -0
  27. data/lib/rubyzero/data/dataloader.rb +34 -0
  28. data/lib/rubyzero/data/dataset.rb +19 -0
  29. data/lib/rubyzero/data/presets/presets.rb +7 -0
  30. data/lib/rubyzero/data/presets/xor.rb +24 -0
  31. data/lib/rubyzero/nn/functional.rb +21 -0
  32. data/lib/rubyzero/nn/layers/affine.rb +21 -0
  33. data/lib/rubyzero/nn/layers/embedding.rb +7 -0
  34. data/lib/rubyzero/nn/layers/layer.rb +5 -0
  35. data/lib/rubyzero/nn/layers/layers.rb +44 -0
  36. data/lib/rubyzero/nn/layers/modellist.rb +40 -0
  37. data/lib/rubyzero/nn/layers/modelstack.rb +20 -0
  38. data/lib/rubyzero/nn/layers/multi_layer_perceptron.rb +26 -0
  39. data/lib/rubyzero/nn/layers/relu.rb +10 -0
  40. data/lib/rubyzero/nn/load.rb +5 -0
  41. data/lib/rubyzero/nn/losses/loss.rb +5 -0
  42. data/lib/rubyzero/nn/losses/losses.rb +8 -0
  43. data/lib/rubyzero/nn/losses/mse.rb +13 -0
  44. data/lib/rubyzero/nn/model.rb +75 -0
  45. data/lib/rubyzero/nn/nn.rb +11 -0
  46. data/lib/rubyzero/nn/optimizers/momentum.rb +22 -0
  47. data/lib/rubyzero/nn/optimizers/optimizer.rb +16 -0
  48. data/lib/rubyzero/nn/optimizers/optimizers.rb +9 -0
  49. data/lib/rubyzero/nn/optimizers/sgd.rb +14 -0
  50. data/lib/rubyzero/nn/parameters.rb +36 -0
  51. data/lib/rubyzero/utils/hyper_parameter_optimizer.rb +0 -0
  52. data/lib/rubyzero/utils/trainer.rb +49 -0
  53. data/lib/rubyzero/utils/utils.rb +8 -0
  54. data/lib/rubyzero/version.rb +5 -0
  55. data/lib/rubyzero.rb +7 -0
  56. data/note.txt +29 -0
  57. data/rubyzero.gemspec +36 -0
  58. metadata +101 -0
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 998d9df9776d84a808d54a8e7472d4a9ddea9ee6d0c6732705d8204153016b4e
4
+ data.tar.gz: 77abc1c5ed65a055339150c4443a72622afd5cb8266c6feb17b707b73678e225
5
+ SHA512:
6
+ metadata.gz: 1ecf6f49bb0d4a182f204f89fc89d456f51871a7707e4b15bd2d232580b9c62673b1e95ae6159ad3a08535a6d4440d916b6916e33065300ea3674026d75470d3
7
+ data.tar.gz: 5f5837d4f2ca18403ecafba1309ec8c2fe00d317da7dee746c2d5822c933a6d7817db739df0be988c66f1044113f50b2e1f09356210a9f8a1baf9be46bc7546d
data/.DS_Store ADDED
Binary file
data/.gitignore ADDED
@@ -0,0 +1,9 @@
1
+ /.bundle/
2
+ /.yardoc
3
+ /_yardoc/
4
+ /coverage/
5
+ /doc/
6
+ /pkg/
7
+ /spec/reports/
8
+ /tmp/
9
+ Gemfile.lock
data/Gemfile ADDED
@@ -0,0 +1,11 @@
1
+ # frozen_string_literal: true
2
+
3
+ source "https://rubygems.org"
4
+
5
+ # Specify your gem's dependencies in rubyzero.gemspec
6
+ gemspec
7
+
8
+ gem "rake", "~> 13.0"
9
+
10
+ gem "unicode_plot"
11
+ gem "numo-narray"
data/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 UThree
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,35 @@
1
+ # Rubyzero
2
+
3
+ Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/rubyzero`. To experiment with that code, run `bin/console` for an interactive prompt.
4
+
5
+ TODO: Delete this and the text above, and describe your gem
6
+
7
+ ## Installation
8
+
9
+ Add this line to your application's Gemfile:
10
+
11
+ ```ruby
12
+ gem 'rubyzero'
13
+ ```
14
+
15
+ And then execute:
16
+
17
+ $ bundle install
18
+
19
+ Or install it yourself as:
20
+
21
+ $ gem install rubyzero
22
+
23
+ ## Usage
24
+
25
+ TODO: Write usage instructions here
26
+
27
+ ## Development
28
+
29
+ After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
30
+
31
+ To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
32
+
33
+ ## Contributing
34
+
35
+ Bug reports and pull requests are welcome on GitHub at https://github.com/[USERNAME]/rubyzero.
data/Rakefile ADDED
@@ -0,0 +1,4 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "bundler/gem_tasks"
4
+ task default: %i[]
data/bin/console ADDED
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/env ruby
2
+ # frozen_string_literal: true
3
+
4
+ require "bundler/setup"
5
+ require "rubyzero"
6
+
7
+ # You can add fixtures and/or initialization code here to make experimenting
8
+ # with your gem easier. You can also use a different console, if you like.
9
+
10
+ # (If you use this, don't forget to add pry to your Gemfile!)
11
+ # require "pry"
12
+ # Pry.start
13
+
14
+ require "irb"
15
+ IRB.start(__FILE__)
data/bin/setup ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+ IFS=$'\n\t'
4
+ set -vx
5
+
6
+ bundle install
7
+
8
+ # Do any other automated setup that you need to do here
data/changelog.md ADDED
File without changes
@@ -0,0 +1,25 @@
1
+ module RubyZero::Core
2
+ # Tensor class
3
+ class Tensor
4
+ CAST_PRIORITY = [
5
+ RubyZero::Core::DataTypes::Boolean,
6
+ RubyZero::Core::DataTypes::UInt8,
7
+ RubyZero::Core::DataTypes::Int8,
8
+ RubyZero::Core::DataTypes::UInt16,
9
+ RubyZero::Core::DataTypes::Int16,
10
+ RubyZero::Core::DataTypes::UInt32,
11
+ RubyZero::Core::DataTypes::Int32,
12
+ RubyZero::Core::DataTypes::UInt64,
13
+ RubyZero::Core::DataTypes::Int64,
14
+ RubyZero::Core::DataTypes::Float32,
15
+ RubyZero::Core::DataTypes::Float64,
16
+ RubyZero::Core::DataTypes::Complex64,
17
+ RubyZero::Core::DataTypes::Complex128,
18
+ ]
19
+ def cast_to(dtype)
20
+ @dtype = dtype
21
+ @data = dtype.get_type_on_calculator(@device).cast(@data)
22
+ return self
23
+ end
24
+ end
25
+ end
@@ -0,0 +1,15 @@
1
+ module RubyZero
2
+ module Core
3
+
4
+ end
5
+ end
6
+
7
+ require_relative './tensor.rb'
8
+ require_relative './device.rb'
9
+ require_relative './dtypes.rb'
10
+ require_relative './tensor_operators.rb'
11
+ require_relative './tensor_backward.rb'
12
+ require_relative './tensor_initialize_methods.rb'
13
+ require_relative './functions/functions.rb'
14
+ require_relative './exceptions.rb'
15
+ require_relative './cast.rb'
@@ -0,0 +1,32 @@
1
+ module RubyZero::Core
2
+ class Device
3
+ attr_reader :caluculator, :identifier
4
+ def initialize(identifier)
5
+ sym = identifier.to_sym
6
+ if sym == :cpu
7
+ @caluculator = Numo
8
+ @identifier = sym
9
+ else
10
+ raise RubyZero::Core::Exceptions::DeviceNotSupported, "Device #{identifier} is not supported."
11
+ end
12
+ end
13
+ def xmo
14
+ if @identifier == :cpu
15
+ return Numo
16
+ end
17
+ end
18
+ end
19
+ end
20
+
21
+ module RubyZero
22
+ @@ident_device = {}
23
+ def self.device(identifier)
24
+ if @@ident_device[identifier]
25
+ return @@ident_device[identifier]
26
+ else
27
+ d = Core::Device.new(identifier)
28
+ @@ident_device[identifier] = d
29
+ return d
30
+ end
31
+ end
32
+ end
@@ -0,0 +1,115 @@
1
+ module RubyZero::Core
2
+ module DataTypes
3
+ # Convert a Numo/Cumo array class into the matching RubyZero datatype.
4
+ def self.from_xmo_dtype(klass)
5
+ case klass.name
6
+ when "Numo::NArray"
7
+ return Float64
8
+ when "Numo::Bit"
9
+ return Boolean
10
+ when "Numo::Int8"
11
+ return Int8
12
+ when "Numo::Int16"
13
+ return Int16
14
+ when "Numo::Int32"
15
+ return Int32
16
+ when "Numo::Int64"
17
+ return Int64
18
+ when "Numo::UInt8"
19
+ return UInt8
20
+ when "Numo::UInt16"
21
+ return UInt16
22
+ when "Numo::UInt32"
23
+ return UInt32
24
+ when "Numo::UInt64"
25
+ return UInt64
26
+ when "Numo::SFloat"
27
+ return Float32
28
+ when "Numo::DFloat"
29
+ return Float64
30
+ when "Numo::SComplex"
31
+ return Complex64
32
+ when "Numo::DComplex"
33
+ return Complex128
34
+ end
35
+ end
36
+
37
+ class DataType
38
+ def self.get_type_on_calculator(device)
39
+ device.caluculator
40
+ end
41
+ def self.[](*data)
42
+ return Tensor.new(data, dtype:self, device: RubyZero.device(:cpu))
43
+ end
44
+ end
45
+ class Boolean < DataType
46
+ def self.get_type_on_calculator(device)
47
+ device.caluculator::Bit
48
+ end
49
+ end
50
+ class Int8 < DataType
51
+ def self.get_type_on_calculator(device)
52
+ device.caluculator::Int8
53
+ end
54
+ end
55
+ class Int16 < DataType
56
+ def self.get_type_on_calculator(device)
57
+ device.caluculator::Int16
58
+ end
59
+ end
60
+ class Int32 < DataType
61
+ def self.get_type_on_calculator(device)
62
+ device.caluculator::Int32
63
+ end
64
+ end
65
+ class Int64 < DataType
66
+ def self.get_type_on_calculator
67
+ device.caluculator::Int64
68
+ end
69
+ end
70
+ class UInt8 < DataType
71
+ def self.get_type_on_calculator(device)
72
+ device.caluculator::UInt8
73
+ end
74
+ end
75
+ class UInt16 < DataType
76
+ def self.get_type_on_calculator(device)
77
+ device.caluculator::UInt16
78
+ end
79
+ end
80
+ class UInt32 < DataType
81
+ def self.get_type_on_calculator(device)
82
+ device.caluculator::UInt32
83
+ end
84
+ end
85
+ class UInt64 < DataType
86
+ def self.get_type_on_calculator(device)
87
+ device.caluculator::UInt64
88
+ end
89
+ end
90
+ class Float32 < DataType
91
+ def self.get_type_on_calculator(device)
92
+ device.caluculator::SFloat
93
+ end
94
+ end
95
+ class Float64 < DataType
96
+ def self.get_type_on_calculator(device)
97
+ device.caluculator::DFloat
98
+ end
99
+ end
100
+ class Complex64 < DataType
101
+ def self.get_type_on_calculator(device)
102
+ device.caluculator::SComplex
103
+ end
104
+ end
105
+ class Complex128 < DataType
106
+ def self.get_type_on_calculator(device)
107
+ device.caluculator::DComplex
108
+ end
109
+ end
110
+ end
111
+ end
112
+
113
+ module RubyZero
114
+ include Core::DataTypes
115
+ end
@@ -0,0 +1,7 @@
1
+ module RubyZero::Core
2
+ module Exceptions
3
+ class NoInplementError < StandardError ; end
4
+ class DeviceNotSupported < StandardError ; end
5
+ class TypeNotSupported < StandardError ; end
6
+ end
7
+ end
@@ -0,0 +1,20 @@
1
+ module RubyZero::Core::Functions
2
+ class ReLU < Function
3
+ def forward(x)
4
+ @path_through = RubyZero::Core::Tensor.new(x.data < 0)
5
+ @path_through.cast_to(RubyZero::FloatTensor)
6
+ return x * @path_through
7
+ end
8
+ def backward(dy)
9
+ return [ @path_through * dy ]
10
+ end
11
+ end
12
+
13
+ class Sigmoid < Function
14
+ def forward(x)
15
+ nmath = x.device.caluculator::NMath
16
+ data = 1.0 / (1.0 + nmath.exp(-x.data))
17
+ return RubyZero::Core::Tensor.new(data)
18
+ end
19
+ end
20
+ end
@@ -0,0 +1,29 @@
1
+ module RubyZero::Core::Functions
2
+ class Log < Function
3
+ def forward(x1)
4
+ nmath = x1.device.xmo::NMath
5
+ new_arr = nmath.log(x1.data)
6
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
7
+ return new_t
8
+ end
9
+
10
+ def backward(dy)
11
+ x1 = @inputs[0]
12
+ return [dy / x1]
13
+ end
14
+ end
15
+
16
+ class Exp < Function
17
+ def forward(x1)
18
+ nmath = x1.device.xmo::NMath
19
+ new_arr = nmath.exp(x1.data)
20
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
21
+ return new_t
22
+ end
23
+
24
+ def backward(dy)
25
+ x1 = @inputs[0]
26
+ return [dy * self.new.call(x1)]
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,29 @@
1
+ require_relative '../exceptions.rb'
2
+
3
+ module RubyZero::Core::Functions
4
+ # Function class
5
+ class Function
6
+ attr_reader :inputs, :output
7
+ def initialize(*args, **kwargs, &block)
8
+
9
+ end
10
+ def forward(*args, **kwargs, &block)
11
+ raise Execptions::NotImplementedError, "#{self.class}#forward() not implemented"
12
+ end
13
+ def backward(*args, **kwargs, &block)
14
+ raise Execptions::NotImplementedError, "#{self.class}#backward() not implemented"
15
+ end
16
+ def call(*args)
17
+ @inputs = args
18
+ @output = forward(*args)
19
+ if @inputs.any?{|t| t.requires_grad?}
20
+ @output.grad_fn = self
21
+ @output.requires_grad = true
22
+ end
23
+ return @output
24
+ end
25
+ def inspect
26
+ return "#<#{self.class}>"
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,10 @@
1
+ module RubyZero::Core
2
+ module Functions
3
+
4
+ end
5
+ end
6
+
7
+ require_relative './function.rb'
8
+ require_relative './operators.rb'
9
+ require_relative './tensor_functions.rb'
10
+ require_relative './activations.rb'
@@ -0,0 +1,112 @@
1
+ module RubyZero::Core::Functions
2
+ class Neg < Function
3
+ def forward(x1)
4
+ new_arr = -(x1.data)
5
+ new_t = RubyZero::Core::Tensor.new(new_arr)
6
+ return new_t
7
+ end
8
+ def backward(dy)
9
+ return [-dy]
10
+ end
11
+ end
12
+
13
+ class Add < Function
14
+ def forward(x1, x2)
15
+ new_arr = x1.data + x2.data
16
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
17
+ return new_t
18
+ end
19
+ def backward(dy)
20
+ return dy, dy
21
+ end
22
+ end
23
+
24
+ class Sub < Function
25
+ def forward(x1, x2)
26
+ new_arr = x1.data - x2.data
27
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
28
+ return new_t
29
+ end
30
+ def backward(dy)
31
+ return dy, -dy
32
+ end
33
+ end
34
+
35
+ class Mul < Function
36
+ def forward(x1, x2)
37
+ new_arr = x1.data * x2.data
38
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
39
+ return new_t
40
+ end
41
+ def backward(dy)
42
+ x1, x2 = @inputs
43
+ return dy * x2, dy * x1
44
+ end
45
+ end
46
+
47
+ class Div < Function
48
+ def forward(x1, x2)
49
+ new_arr = x1.data / x2.data
50
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
51
+ return new_t
52
+ end
53
+ def backward(dy)
54
+ x1, x2 = @inputs
55
+ return dy / x2, -dy * x1 / x2 ** 2
56
+ end
57
+ end
58
+
59
+ class Pow < Function
60
+ def forward(x1, x2)
61
+ new_arr = x1.data ** x2.data
62
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
63
+ return new_t
64
+ end
65
+ def backward(dy)
66
+ x1, x2 = @inputs
67
+ return dy * x2 * x1 ** (x2 - 1), dy * x1 ** x2 * Log.new().call(x1)
68
+ end
69
+ end
70
+
71
+ class Log < Function
72
+ def forward(x1)
73
+ nmath = x1.device.xmo::NMath
74
+ new_arr = nmath.log(x1.data)
75
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
76
+ return new_t
77
+ end
78
+
79
+ def backward(dy)
80
+ x1 = @inputs[0]
81
+ return [dy / x1]
82
+ end
83
+ end
84
+
85
+ class MulScalar < Function
86
+ def initialize(scalar)
87
+ @scalar = scalar
88
+ end
89
+ def forward(x1)
90
+ new_arr = x1.data * @scalar
91
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
92
+ return new_t
93
+ end
94
+ def backward(dy)
95
+ return [dy * @scalar]
96
+ end
97
+ end
98
+
99
+ class DivScalar < Function
100
+ def initialize(scalar)
101
+ @scalar = scalar
102
+ end
103
+ def forward(x1)
104
+ new_arr = x1.data / @scalar
105
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
106
+ return new_t
107
+ end
108
+ def backward(dy)
109
+ return [dy / @scalar]
110
+ end
111
+ end
112
+ end
@@ -0,0 +1,123 @@
1
+ module RubyZero::Core::Functions
2
+ class Reshape < Function
3
+ def initialize(shape)
4
+ @dist_shape = shape
5
+ end
6
+ def forward(x1)
7
+ new_arr = x1.data.reshape(@dist_shape)
8
+ @prev_shape = x1.shape
9
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
10
+ return new_t
11
+ end
12
+ def backward(dy)
13
+ return [dy.reshape(@prev_shape)]
14
+ end
15
+ end
16
+ class SwapAxes < Function
17
+ def initialize(axis1, axis2)
18
+ @axis1 = axis1
19
+ @axis2 = axis2
20
+ end
21
+ def forward(x1)
22
+ new_arr = x1.data.swapaxes(@axis1, @axis2)
23
+ @prev_shape = x1.shape
24
+ new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
25
+ return new_t
26
+ end
27
+ def backward(dy)
28
+ return [dy.swapaxes(@axis1, @axis2)]
29
+ end
30
+ end
31
+ class Repeat < Function
32
+ def initialize(axis, repeats)
33
+ @axis = axis
34
+ @repeats = repeats
35
+ end
36
+ def forward(x1)
37
+ arr = x1.data
38
+ arr = arr.reshape(*([1] + arr.shape))
39
+ arr = arr.repeat(@repeats, axis:0)
40
+ arr = arr.swapaxes(0, @axis)
41
+ new_t = RubyZero::Core::Tensor.new(arr, device: x1.device)
42
+ return new_t
43
+ end
44
+ def backward(dy)
45
+ return [ dy.sum(axis: @axis) ]
46
+ end
47
+ end
48
+
49
+ class Sum < Function
50
+ def initialize(axis)
51
+ @axis = axis
52
+ end
53
+ def forward(x1)
54
+ @repeats = x1.shape[@axis]
55
+ arr = x1.data
56
+ arr = arr.sum(axis: @axis)
57
+ new_t = RubyZero::Core::Tensor.new(arr, device: x1.device)
58
+ return new_t
59
+ end
60
+ def backward(dy)
61
+ return [dy.repeat(@repeats, axis: @axis) / @repeats]
62
+ end
63
+ end
64
+
65
+ class Mean < Function
66
+ def initialize(axis)
67
+ @axis = axis
68
+ end
69
+ def forward(x1)
70
+ @repeats = x1.shape[@axis]
71
+ arr = x1.data
72
+ arr = arr.mean(axis: @axis)
73
+ new_t = RubyZero::Core::Tensor.new(arr, device: x1.device)
74
+ return new_t
75
+ end
76
+ def backward(dy)
77
+ return [dy.repeat(@repeats, axis: @axis)]
78
+ end
79
+ end
80
+
81
+ class DotProduct < Function
82
+ def initialize()
83
+ end
84
+ def forward(x1, x2)
85
+ arr = x1.data.dot(x2.data)
86
+ new_t = RubyZero::Core::Tensor.new(arr, device: x1.device)
87
+ return new_t
88
+ end
89
+ def backward(dy)
90
+ x1, x2 = @inputs
91
+ dx, dy = [dy.dot(x2.swapaxes(0,1)), x1.swapaxes(0,1).dot(dy)]
92
+ return dx, dy
93
+ end
94
+ end
95
+ end
96
+
97
+ # apply Tensor class
98
+ module RubyZero::Core
99
+ class Tensor
100
+ def reshape(shape)
101
+ return RubyZero::Core::Functions::Reshape.new(shape).call(self)
102
+ end
103
+ def swapaxes(axis1, axis2)
104
+ return RubyZero::Core::Functions::SwapAxes.new(axis1, axis2).call(self)
105
+ end
106
+ def repeat(repeats, axis:0)
107
+ return RubyZero::Core::Functions::Repeat.new(axis, repeats).call(self)
108
+ end
109
+ def sum(axis: 0)
110
+ return RubyZero::Core::Functions::Sum.new(axis).call(self)
111
+ end
112
+ def mean(axis: 0)
113
+ return RubyZero::Core::Functions::Mean.new(axis).call(self)
114
+ end
115
+ def dot(other)
116
+ if other.is_a?(RubyZero::Core::Tensor)
117
+ return RubyZero::Core::Functions::DotProduct.new().call(self, other)
118
+ else
119
+ return self*other
120
+ end
121
+ end
122
+ end
123
+ end