rubyzero 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.DS_Store +0 -0
- data/.gitignore +9 -0
- data/Gemfile +11 -0
- data/LICENSE +21 -0
- data/README.md +35 -0
- data/Rakefile +4 -0
- data/bin/console +15 -0
- data/bin/setup +8 -0
- data/changelog.md +0 -0
- data/lib/rubyzero/core/cast.rb +25 -0
- data/lib/rubyzero/core/core.rb +15 -0
- data/lib/rubyzero/core/device.rb +32 -0
- data/lib/rubyzero/core/dtypes.rb +115 -0
- data/lib/rubyzero/core/exceptions.rb +7 -0
- data/lib/rubyzero/core/functions/activations.rb +20 -0
- data/lib/rubyzero/core/functions/elementary_functions.rb +29 -0
- data/lib/rubyzero/core/functions/function.rb +29 -0
- data/lib/rubyzero/core/functions/functions.rb +10 -0
- data/lib/rubyzero/core/functions/operators.rb +112 -0
- data/lib/rubyzero/core/functions/tensor_functions.rb +123 -0
- data/lib/rubyzero/core/tensor.rb +56 -0
- data/lib/rubyzero/core/tensor_backward.rb +17 -0
- data/lib/rubyzero/core/tensor_initialize_methods.rb +78 -0
- data/lib/rubyzero/core/tensor_operators.rb +27 -0
- data/lib/rubyzero/data/data.rb +9 -0
- data/lib/rubyzero/data/dataloader.rb +34 -0
- data/lib/rubyzero/data/dataset.rb +19 -0
- data/lib/rubyzero/data/presets/presets.rb +7 -0
- data/lib/rubyzero/data/presets/xor.rb +24 -0
- data/lib/rubyzero/nn/functional.rb +21 -0
- data/lib/rubyzero/nn/layers/affine.rb +21 -0
- data/lib/rubyzero/nn/layers/embedding.rb +7 -0
- data/lib/rubyzero/nn/layers/layer.rb +5 -0
- data/lib/rubyzero/nn/layers/layers.rb +44 -0
- data/lib/rubyzero/nn/layers/modellist.rb +40 -0
- data/lib/rubyzero/nn/layers/modelstack.rb +20 -0
- data/lib/rubyzero/nn/layers/multi_layer_perceptron.rb +26 -0
- data/lib/rubyzero/nn/layers/relu.rb +10 -0
- data/lib/rubyzero/nn/load.rb +5 -0
- data/lib/rubyzero/nn/losses/loss.rb +5 -0
- data/lib/rubyzero/nn/losses/losses.rb +8 -0
- data/lib/rubyzero/nn/losses/mse.rb +13 -0
- data/lib/rubyzero/nn/model.rb +75 -0
- data/lib/rubyzero/nn/nn.rb +11 -0
- data/lib/rubyzero/nn/optimizers/momentum.rb +22 -0
- data/lib/rubyzero/nn/optimizers/optimizer.rb +16 -0
- data/lib/rubyzero/nn/optimizers/optimizers.rb +9 -0
- data/lib/rubyzero/nn/optimizers/sgd.rb +14 -0
- data/lib/rubyzero/nn/parameters.rb +36 -0
- data/lib/rubyzero/utils/hyper_parameter_optimizer.rb +0 -0
- data/lib/rubyzero/utils/trainer.rb +49 -0
- data/lib/rubyzero/utils/utils.rb +8 -0
- data/lib/rubyzero/version.rb +5 -0
- data/lib/rubyzero.rb +7 -0
- data/note.txt +29 -0
- data/rubyzero.gemspec +36 -0
- metadata +101 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 998d9df9776d84a808d54a8e7472d4a9ddea9ee6d0c6732705d8204153016b4e
|
4
|
+
data.tar.gz: 77abc1c5ed65a055339150c4443a72622afd5cb8266c6feb17b707b73678e225
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 1ecf6f49bb0d4a182f204f89fc89d456f51871a7707e4b15bd2d232580b9c62673b1e95ae6159ad3a08535a6d4440d916b6916e33065300ea3674026d75470d3
|
7
|
+
data.tar.gz: 5f5837d4f2ca18403ecafba1309ec8c2fe00d317da7dee746c2d5822c933a6d7817db739df0be988c66f1044113f50b2e1f09356210a9f8a1baf9be46bc7546d
|
data/.DS_Store
ADDED
Binary file
|
data/.gitignore
ADDED
data/Gemfile
ADDED
data/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2021 UThree
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
data/README.md
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
# Rubyzero
|
2
|
+
|
3
|
+
Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the `lib/rubyzero` directory. To experiment with that code, run `bin/console` for an interactive prompt.
|
4
|
+
|
5
|
+
TODO: Delete this and the text above, and describe your gem
|
6
|
+
|
7
|
+
## Installation
|
8
|
+
|
9
|
+
Add this line to your application's Gemfile:
|
10
|
+
|
11
|
+
```ruby
|
12
|
+
gem 'rubyzero'
|
13
|
+
```
|
14
|
+
|
15
|
+
And then execute:
|
16
|
+
|
17
|
+
$ bundle install
|
18
|
+
|
19
|
+
Or install it yourself as:
|
20
|
+
|
21
|
+
$ gem install rubyzero
|
22
|
+
|
23
|
+
## Usage
|
24
|
+
|
25
|
+
TODO: Write usage instructions here
|
26
|
+
|
27
|
+
## Development
|
28
|
+
|
29
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
30
|
+
|
31
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
32
|
+
|
33
|
+
## Contributing
|
34
|
+
|
35
|
+
Bug reports and pull requests are welcome on GitHub at https://github.com/[USERNAME]/rubyzero.
|
data/Rakefile
ADDED
data/bin/console
ADDED
@@ -0,0 +1,15 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "bundler/setup"
|
5
|
+
require "rubyzero"
|
6
|
+
|
7
|
+
# You can add fixtures and/or initialization code here to make experimenting
|
8
|
+
# with your gem easier. You can also use a different console, if you like.
|
9
|
+
|
10
|
+
# (If you use this, don't forget to add pry to your Gemfile!)
|
11
|
+
# require "pry"
|
12
|
+
# Pry.start
|
13
|
+
|
14
|
+
require "irb"
|
15
|
+
IRB.start(__FILE__)
|
data/bin/setup
ADDED
data/changelog.md
ADDED
File without changes
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module RubyZero::Core
  # Tensor class (dtype-casting part).
  class Tensor
    # RubyZero data types ordered from lowest to highest priority.
    # NOTE(review): presumably used to choose the result dtype of mixed-type
    # operations; no user of this constant is visible here — confirm.
    CAST_PRIORITY = [
      RubyZero::Core::DataTypes::Boolean,
      RubyZero::Core::DataTypes::UInt8,
      RubyZero::Core::DataTypes::Int8,
      RubyZero::Core::DataTypes::UInt16,
      RubyZero::Core::DataTypes::Int16,
      RubyZero::Core::DataTypes::UInt32,
      RubyZero::Core::DataTypes::Int32,
      RubyZero::Core::DataTypes::UInt64,
      RubyZero::Core::DataTypes::Int64,
      RubyZero::Core::DataTypes::Float32,
      RubyZero::Core::DataTypes::Float64,
      RubyZero::Core::DataTypes::Complex64,
      RubyZero::Core::DataTypes::Complex128,
    ].freeze

    # Convert this tensor's data in place to +dtype+ and record the new dtype.
    #
    # @param dtype [Class] a RubyZero data type class (responds to .get_type_on_calculator)
    # @return [Tensor] self, so calls can be chained
    def cast_to(dtype)
      # Convert first: if the backend cast raises, @dtype and @data stay consistent.
      converted = dtype.get_type_on_calculator(@device).cast(@data)
      @dtype = dtype
      @data = converted
      self
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Top-level namespace of the RubyZero library.
module RubyZero
  # Namespace for tensors, devices, data types and differentiable functions;
  # its contents are loaded by the require_relative lines below.
  module Core; end
end
|
6
|
+
|
7
|
+
require_relative './tensor.rb'
|
8
|
+
require_relative './device.rb'
|
9
|
+
require_relative './dtypes.rb'
|
10
|
+
require_relative './tensor_operators.rb'
|
11
|
+
require_relative './tensor_backward.rb'
|
12
|
+
require_relative './tensor_initialize_methods.rb'
|
13
|
+
require_relative './functions/functions.rb'
|
14
|
+
require_relative './exceptions.rb'
|
15
|
+
require_relative './cast.rb'
|
@@ -0,0 +1,32 @@
|
|
1
|
+
module RubyZero
  module Core
    # A compute device. Only the :cpu device is supported; it computes with Numo.
    class Device
      # +caluculator+ (sic) is the backend module used for array math;
      # the misspelled name is kept because callers use it.
      attr_reader :caluculator, :identifier
      # Correctly spelled alias of #caluculator.
      alias calculator caluculator

      # @param identifier [#to_sym] device name, e.g. :cpu or "cpu"
      # @raise [RubyZero::Core::Exceptions::DeviceNotSupported] for any device other than :cpu
      def initialize(identifier)
        sym = identifier.to_sym
        unless sym == :cpu
          raise RubyZero::Core::Exceptions::DeviceNotSupported, "Device #{identifier} is not supported."
        end

        @caluculator = Numo
        @identifier = sym
      end

      # Backend array module of this device (Numo for :cpu). Derived from
      # @caluculator so the device→backend mapping lives only in #initialize.
      def xmo
        @caluculator
      end
    end
  end
end
|
20
|
+
|
21
|
+
module RubyZero
  # Cache of Device instances keyed by normalized (Symbol) identifier.
  # A module instance variable avoids the inheritance pitfalls of @@class_vars.
  @device_cache = {}

  # Return the shared Device for +identifier+, creating it on first use.
  #
  # @param identifier [#to_sym] device name, e.g. :cpu or "cpu"
  # @return [RubyZero::Core::Device]
  def self.device(identifier)
    # Normalize the key so :cpu and "cpu" share one Device instance.
    key = identifier.to_sym
    @device_cache[key] ||= Core::Device.new(identifier)
  end
end
|
@@ -0,0 +1,115 @@
|
|
1
|
+
module RubyZero
  module Core
    # RubyZero data types.
    #
    # Every type is a DataType subclass whose .get_type_on_calculator(device)
    # returns the matching array class of the device's calculator module
    # (e.g. SFloat for Float32).
    module DataTypes
      # Convert a Numo/Cumo array class into the matching RubyZero data type.
      #
      # @param klass [Class] backend array class, e.g. Numo::DFloat or Cumo::DFloat
      # @return [Class, nil] the RubyZero data type, or nil if +klass+ is not recognised
      def self.from_xmo_dtype(klass)
        # Accept either backend namespace; to_s guards anonymous classes (name is nil).
        match = /\A(?:Numo|Cumo)::(\w+)\z/.match(klass.name.to_s)
        return nil unless match

        case match[1]
        when "NArray"   then Float64
        when "Bit"      then Boolean
        when "Int8"     then Int8
        when "Int16"    then Int16
        when "Int32"    then Int32
        when "Int64"    then Int64
        when "UInt8"    then UInt8
        when "UInt16"   then UInt16
        when "UInt32"   then UInt32
        when "UInt64"   then UInt64
        when "SFloat"   then Float32
        when "DFloat"   then Float64
        when "SComplex" then Complex64
        when "DComplex" then Complex128
        end
      end

      # Base class of all RubyZero data types.
      class DataType
        # @param device [RubyZero::Core::Device]
        # @return [Module] the device's calculator module itself (subclasses return a concrete array class)
        def self.get_type_on_calculator(device)
          device.caluculator
        end

        # Build a CPU tensor of this type, e.g. RubyZero::Float32[1, 2, 3].
        def self.[](*data)
          Tensor.new(data, dtype: self, device: RubyZero.device(:cpu))
        end
      end

      class Boolean < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::Bit
        end
      end

      class Int8 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::Int8
        end
      end

      class Int16 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::Int16
        end
      end

      class Int32 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::Int32
        end
      end

      class Int64 < DataType
        # Takes +device+ like every other data type.
        def self.get_type_on_calculator(device)
          device.caluculator::Int64
        end
      end

      class UInt8 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::UInt8
        end
      end

      class UInt16 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::UInt16
        end
      end

      class UInt32 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::UInt32
        end
      end

      class UInt64 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::UInt64
        end
      end

      class Float32 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::SFloat
        end
      end

      class Float64 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::DFloat
        end
      end

      class Complex64 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::SComplex
        end
      end

      class Complex128 < DataType
        def self.get_type_on_calculator(device)
          device.caluculator::DComplex
        end
      end
    end
  end

  # Expose the data types directly under RubyZero (e.g. RubyZero::Float32).
  include Core::DataTypes
end
|
@@ -0,0 +1,20 @@
|
|
1
|
+
module RubyZero::Core::Functions
  # Rectified linear unit: y = max(x, 0), element-wise.
  class ReLU < Function
    # @param x [RubyZero::Core::Tensor]
    # @return [RubyZero::Core::Tensor] x with non-positive elements zeroed
    def forward(x)
      # 1 where the input is positive (gradient passes), 0 elsewhere.
      @path_through = RubyZero::Core::Tensor.new(x.data > 0, device: x.device)
      # NOTE(review): RubyZero::FloatTensor is not among the visible data types
      # (Float32/Float64); confirm it is defined elsewhere.
      @path_through.cast_to(RubyZero::FloatTensor)
      x * @path_through
    end

    # Gradient flows only through the elements that were positive.
    def backward(dy)
      [@path_through * dy]
    end
  end

  # Logistic sigmoid: y = 1 / (1 + exp(-x)), element-wise.
  class Sigmoid < Function
    # @param x [RubyZero::Core::Tensor]
    # @return [RubyZero::Core::Tensor] sigmoid(x) on x's device
    def forward(x)
      nmath = x.device.caluculator::NMath
      # Keep the raw result; backward reuses it for the local gradient.
      @sigmoid = 1.0 / (1.0 + nmath.exp(-x.data))
      RubyZero::Core::Tensor.new(@sigmoid, device: x.device)
    end

    # d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x))
    def backward(dy)
      local_grad = RubyZero::Core::Tensor.new(@sigmoid * (1.0 - @sigmoid), device: dy.device)
      [dy * local_grad]
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
module RubyZero::Core::Functions
  # Natural logarithm, element-wise.
  class Log < Function
    # @param x1 [RubyZero::Core::Tensor]
    # @return [RubyZero::Core::Tensor] log(x1) on x1's device
    def forward(x1)
      nmath = x1.device.xmo::NMath
      RubyZero::Core::Tensor.new(nmath.log(x1.data), device: x1.device)
    end

    # d log(x)/dx = 1 / x
    def backward(dy)
      x1 = @inputs[0]
      [dy / x1]
    end
  end

  # Exponential, element-wise.
  class Exp < Function
    # @param x1 [RubyZero::Core::Tensor]
    # @return [RubyZero::Core::Tensor] exp(x1) on x1's device
    def forward(x1)
      nmath = x1.device.xmo::NMath
      RubyZero::Core::Tensor.new(nmath.exp(x1.data), device: x1.device)
    end

    # d exp(x)/dx = exp(x). Recomputed through a fresh Exp node (an instance
    # has no #new, so the class is used) so the result is tracked like any
    # other Function#call output.
    def backward(dy)
      x1 = @inputs[0]
      [dy * Exp.new.call(x1)]
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
require_relative '../exceptions.rb'
|
2
|
+
|
3
|
+
module RubyZero
  module Core
    module Functions
      # Base class for differentiable operations.
      #
      # Subclasses implement #forward (compute the output tensor from the
      # inputs) and #backward (map the upstream gradient to per-input
      # gradients). #call runs #forward and, when any input requires a
      # gradient, records this function as the output's grad_fn.
      class Function
        attr_reader :inputs, :output

        # Accepts and ignores any arguments so subclasses can define their
        # own constructor signatures.
        def initialize(*args, **kwargs, &block); end

        # @raise [RubyZero::Core::Exceptions::NotImplementedError] unless overridden
        def forward(*args, **kwargs, &block)
          raise RubyZero::Core::Exceptions::NotImplementedError, "#{self.class}#forward() not implemented"
        end

        # @raise [RubyZero::Core::Exceptions::NotImplementedError] unless overridden
        def backward(*args, **kwargs, &block)
          raise RubyZero::Core::Exceptions::NotImplementedError, "#{self.class}#backward() not implemented"
        end

        # Run #forward on +args+ and hook the output into the autograd graph.
        #
        # @param args [Array] input tensors (each must respond to #requires_grad?)
        # @return the tensor returned by #forward
        def call(*args)
          @inputs = args
          @output = forward(*args)
          if @inputs.any?(&:requires_grad?)
            @output.grad_fn = self
            @output.requires_grad = true
          end
          @output
        end

        def inspect
          "#<#{self.class}>"
        end
      end
    end
  end
end
|
@@ -0,0 +1,112 @@
|
|
1
|
+
module RubyZero::Core::Functions
  # Element-wise negation: y = -x1.
  class Neg < Function
    def forward(x1)
      RubyZero::Core::Tensor.new(-(x1.data), device: x1.device)
    end

    def backward(dy)
      [-dy]
    end
  end

  # Element-wise addition: y = x1 + x2.
  class Add < Function
    def forward(x1, x2)
      RubyZero::Core::Tensor.new(x1.data + x2.data, device: x1.device)
    end

    # Both inputs receive the upstream gradient unchanged.
    # NOTE(review): no reduction is applied when x1/x2 were broadcast.
    def backward(dy)
      [dy, dy]
    end
  end

  # Element-wise subtraction: y = x1 - x2.
  class Sub < Function
    def forward(x1, x2)
      RubyZero::Core::Tensor.new(x1.data - x2.data, device: x1.device)
    end

    def backward(dy)
      [dy, -dy]
    end
  end

  # Element-wise multiplication: y = x1 * x2.
  class Mul < Function
    def forward(x1, x2)
      RubyZero::Core::Tensor.new(x1.data * x2.data, device: x1.device)
    end

    def backward(dy)
      x1, x2 = @inputs
      [dy * x2, dy * x1]
    end
  end

  # Element-wise division: y = x1 / x2.
  class Div < Function
    def forward(x1, x2)
      RubyZero::Core::Tensor.new(x1.data / x2.data, device: x1.device)
    end

    # dy/dx1 = 1 / x2 ; dy/dx2 = -x1 / x2^2.
    # x2^2 is written as x2 * x2 so only Tensor-Tensor ops are needed
    # (Pow#forward reads .data from its exponent, so a bare Integer exponent is unsafe).
    def backward(dy)
      x1, x2 = @inputs
      [dy / x2, -dy * x1 / (x2 * x2)]
    end
  end

  # Element-wise power: y = x1 ** x2.
  class Pow < Function
    def forward(x1, x2)
      RubyZero::Core::Tensor.new(x1.data ** x2.data, device: x1.device)
    end

    # dy/dx1 = x2 * x1^(x2-1) ; dy/dx2 = x1^x2 * log(x1).
    # NOTE(review): `x2 - 1` relies on Tensor - Integer being supported — confirm.
    def backward(dy)
      x1, x2 = @inputs
      [dy * x2 * x1 ** (x2 - 1), dy * x1 ** x2 * Log.new.call(x1)]
    end
  end

  # Natural logarithm, element-wise.
  # NOTE(review): elementary_functions.rb defines an identical Log; one copy could be removed.
  class Log < Function
    def forward(x1)
      nmath = x1.device.xmo::NMath
      RubyZero::Core::Tensor.new(nmath.log(x1.data), device: x1.device)
    end

    def backward(dy)
      x1 = @inputs[0]
      [dy / x1]
    end
  end

  # Multiply a tensor by a fixed scalar given at construction.
  class MulScalar < Function
    def initialize(scalar)
      @scalar = scalar
    end

    def forward(x1)
      RubyZero::Core::Tensor.new(x1.data * @scalar, device: x1.device)
    end

    def backward(dy)
      [dy * @scalar]
    end
  end

  # Divide a tensor by a fixed scalar given at construction.
  class DivScalar < Function
    def initialize(scalar)
      @scalar = scalar
    end

    def forward(x1)
      RubyZero::Core::Tensor.new(x1.data / @scalar, device: x1.device)
    end

    def backward(dy)
      [dy / @scalar]
    end
  end
end
|
@@ -0,0 +1,123 @@
|
|
1
|
+
module RubyZero::Core::Functions
  # Reshape to a target shape; the gradient is reshaped back.
  class Reshape < Function
    # @param shape [Array<Integer>, Integer] target shape
    def initialize(shape)
      @dist_shape = shape
    end

    def forward(x1)
      @prev_shape = x1.shape
      # Splat the shape, as Repeat#forward does for the backend reshape call;
      # a single Integer shape splats to a one-element list.
      new_arr = x1.data.reshape(*@dist_shape)
      RubyZero::Core::Tensor.new(new_arr, device: x1.device)
    end

    def backward(dy)
      [dy.reshape(@prev_shape)]
    end
  end

  # Swap two axes; swapping the same axes again undoes it in backward.
  class SwapAxes < Function
    def initialize(axis1, axis2)
      @axis1 = axis1
      @axis2 = axis2
    end

    def forward(x1)
      new_arr = x1.data.swapaxes(@axis1, @axis2)
      RubyZero::Core::Tensor.new(new_arr, device: x1.device)
    end

    def backward(dy)
      [dy.swapaxes(@axis1, @axis2)]
    end
  end

  # Stack +repeats+ copies of the input along a new axis inserted at +axis+.
  class Repeat < Function
    def initialize(axis, repeats)
      @axis = axis
      @repeats = repeats
    end

    def forward(x1)
      arr = x1.data
      # Insert a length-1 axis at @axis and tile it, keeping the other axes in
      # their original order (swapaxes(0, axis) misorders them for axis >= 2).
      expanded_shape = arr.shape.dup.insert(@axis, 1)
      arr = arr.reshape(*expanded_shape).repeat(@repeats, axis: @axis)
      RubyZero::Core::Tensor.new(arr, device: x1.device)
    end

    # The copies are summed back into the single original slice.
    def backward(dy)
      [dy.sum(axis: @axis)]
    end
  end

  # Sum over +axis+, removing it.
  class Sum < Function
    def initialize(axis)
      @axis = axis
    end

    def forward(x1)
      @repeats = x1.shape[@axis]
      RubyZero::Core::Tensor.new(x1.data.sum(axis: @axis), device: x1.device)
    end

    # Every summed element contributes with weight 1, so the gradient is
    # broadcast back unscaled.
    def backward(dy)
      [dy.repeat(@repeats, axis: @axis)]
    end
  end

  # Mean over +axis+, removing it.
  class Mean < Function
    def initialize(axis)
      @axis = axis
    end

    def forward(x1)
      @repeats = x1.shape[@axis]
      RubyZero::Core::Tensor.new(x1.data.mean(axis: @axis), device: x1.device)
    end

    # Each element contributes 1/n to the mean, so the broadcast gradient is scaled by 1/n.
    def backward(dy)
      [dy.repeat(@repeats, axis: @axis) / @repeats]
    end
  end

  # Matrix product x1 · x2 (backward assumes 2-D operands).
  class DotProduct < Function
    def forward(x1, x2)
      RubyZero::Core::Tensor.new(x1.data.dot(x2.data), device: x1.device)
    end

    # d/dx1 = dy · x2ᵀ ; d/dx2 = x1ᵀ · dy
    def backward(dy)
      x1, x2 = @inputs
      grad_x1 = dy.dot(x2.swapaxes(0, 1))
      grad_x2 = x1.swapaxes(0, 1).dot(dy)
      [grad_x1, grad_x2]
    end
  end
end
|
96
|
+
|
97
|
+
# Expose the functions above as Tensor instance methods.
|
98
|
+
module RubyZero::Core
  # Tensor-level entry points for the shape and reduction Functions.
  # Each method wraps the matching Function so the op is recorded for autograd.
  class Tensor
    # @param shape [Array<Integer>] target shape
    def reshape(shape)
      Functions::Reshape.new(shape).call(self)
    end

    def swapaxes(axis1, axis2)
      Functions::SwapAxes.new(axis1, axis2).call(self)
    end

    # Repeat this tensor +repeats+ times along +axis+ (see Functions::Repeat).
    def repeat(repeats, axis: 0)
      Functions::Repeat.new(axis, repeats).call(self)
    end

    def sum(axis: 0)
      Functions::Sum.new(axis).call(self)
    end

    def mean(axis: 0)
      Functions::Mean.new(axis).call(self)
    end

    # Matrix product with another Tensor; any other operand falls back to #*.
    def dot(other)
      return self * other unless other.is_a?(Tensor)

      Functions::DotProduct.new.call(self, other)
    end
  end
end
|