torch-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +363 -0
- data/ext/torch/ext.cpp +546 -0
- data/ext/torch/extconf.rb +22 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +327 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +62 -0
- data/lib/torch/nn/conv2d.rb +50 -0
- data/lib/torch/nn/functional.rb +44 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +56 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +10 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/tensor.rb +143 -0
- data/lib/torch/utils/data/data_loader.rb +12 -0
- data/lib/torch/utils/data/tensor_dataset.rb +15 -0
- data/lib/torch/version.rb +3 -0
- metadata +149 -0
@@ -0,0 +1,44 @@
|
|
1
|
+
module Torch
  module NN
    # Stateless functional API. Every method is a singleton method that forwards
    # to the matching Torch-level operation; the class is also exposed as
    # Torch::NN::F (see the constant at the bottom).
    class Functional
      # Delegates to Torch.relu.
      def self.relu(input)
        Torch.relu(input)
      end

      # Delegates to Torch.conv2d with the given weight and bias.
      def self.conv2d(input, weight, bias)
        Torch.conv2d(input, weight, bias)
      end

      # Delegates to Torch.max_pool2d. An Integer kernel_size is expanded to a
      # two-element [k, k] window; any other value is passed through unchanged.
      def self.max_pool2d(input, kernel_size)
        window = kernel_size.is_a?(Integer) ? [kernel_size] * 2 : kernel_size
        Torch.max_pool2d(input, window)
      end

      # Delegates to Torch.linear with the given weight and bias.
      def self.linear(input, weight, bias)
        Torch.linear(input, weight, bias)
      end

      # Delegates to Torch.mse_loss; the reduction keyword ("mean" by default)
      # is passed to Torch.mse_loss as a positional argument.
      def self.mse_loss(input, target, reduction: "mean")
        Torch.mse_loss(input, target, reduction)
      end

      # Composes nll_loss over log_softmax taken along dimension 1.
      def self.cross_entropy(input, target)
        log_probs = log_softmax(input, 1)
        nll_loss(log_probs, target)
      end

      # Delegates to Torch.nll_loss.
      # TODO fix for non-1d
      def self.nll_loss(input, target)
        Torch.nll_loss(input, target)
      end

      # Calls log_softmax(dim) on the input object itself.
      def self.log_softmax(input, dim)
        input.log_softmax(dim)
      end
    end

    # shortcut
    F = Functional
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module NN
    module Init
      # Returns [fan_in, fan_out] for a weight tensor.
      #
      # 2-D tensors: fan_in is size(1), fan_out is size(0).
      # Higher-rank tensors: both are scaled by the element count of
      # tensor[0][0] (the "receptive field").
      #
      # @raise [Error] when the tensor has fewer than 2 dimensions
      def self.calculate_fan_in_and_fan_out(tensor)
        ndim = tensor.dim
        raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" if ndim < 2

        return [tensor.size(1), tensor.size(0)] if ndim == 2

        # ndim > 2 from here on, so tensor[0][0] is always a sub-tensor.
        receptive_field_size = tensor[0][0].numel
        [tensor.size(1) * receptive_field_size, tensor.size(0) * receptive_field_size]
      end
    end
  end
end
|
@@ -0,0 +1,36 @@
|
|
1
|
+
module Torch
  module NN
    # Fully connected layer: output = F.linear(input, weight, bias).
    class Linear < Module
      attr_reader :bias, :weight

      # @param in_features [Integer] width of each input sample
      # @param out_features [Integer] width of each output sample
      # @param bias [Boolean] when false no bias Parameter is created and #bias is nil
      def initialize(in_features, out_features, bias: true)
        @in_features = in_features
        @out_features = out_features

        @weight = Parameter.new(Tensor.new(out_features, in_features))
        @bias = Parameter.new(Tensor.new(out_features)) if bias

        reset_parameters
      end

      # Applies the layer. Defined as #forward so Linear honours the same
      # contract as other modules (Module#call delegates to #forward); previously
      # only #call existed and layer.forward(x) raised NoMethodError.
      def forward(input)
        F.linear(input, @weight, @bias)
      end

      # Kept as an explicit override so #call behaves exactly as before.
      def call(input)
        forward(input)
      end

      # Re-initializes the weight via Init.kaiming_uniform_ (with a = sqrt(5)).
      # When a bias exists, passes bounds -1/sqrt(fan_in) and 1/sqrt(fan_in) to
      # Init.uniform_.
      def reset_parameters
        Init.kaiming_uniform_(@weight, Math.sqrt(5))
        return unless @bias

        fan_in, _fan_out = Init.calculate_fan_in_and_fan_out(@weight)
        bound = 1 / Math.sqrt(fan_in)
        Init.uniform_(@bias, -bound, bound)
      end

      def inspect
        "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
      end
    end
  end
end
|
@@ -0,0 +1,56 @@
|
|
1
|
+
module Torch
  module NN
    # Base class for neural-network layers. Subclasses implement #forward.
    # Child modules and Parameters are discovered by scanning the object's
    # instance variables.
    class Module
      # Multi-line summary listing each child module held in an instance variable.
      def inspect
        str = String.new
        str << "#{self.class.name}(\n"
        modules.each do |name, mod|
          str << " (#{name}): #{mod.inspect}\n"
        end
        str << ")"
      end

      # Invokes the module by forwarding all arguments to #forward.
      def call(*input)
        forward(*input)
      end

      # Parameters stored directly in instance variables, followed by the
      # parameters of every child module (recursively).
      def parameters
        own = instance_variables
          .map { |name| instance_variable_get(name) }
          .select { |value| value.is_a?(Parameter) }
        own + modules.flat_map { |_, mod| mod.parameters }
      end

      # NOTE(review): zeroing gradients is not implemented yet. The raise fires
      # for any parameter whose grad is set, so the detach!/zero! calls after it
      # are currently unreachable placeholders.
      def zero_grad
        parameters.each do |param|
          if param.grad
            raise Error, "Not supported yet"
            param.grad.detach!
            param.grad.zero!
          end
        end
      end

      # Exposes child modules as reader methods, e.g. net.fc1 for an @fc1 child.
      def method_missing(method, *args, &block)
        modules[method.to_s] || super
      end

      # Pairs with #method_missing (instead of overriding respond_to?) so that
      # respond_to? and #method both recognise child-module readers.
      def respond_to_missing?(method, include_private = false)
        modules.key?(method.to_s) || super
      end

      private

      # Hash of instance-variable name (without the leading "@") => child Module.
      def modules
        instance_variables.each_with_object({}) do |name, found|
          value = instance_variable_get(name)
          found[name[1..-1]] = value if value.is_a?(Module)
        end
      end
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
module Torch
  module NN
    # Container that chains modules: #forward feeds the input through every
    # registered module in insertion order.
    class Sequential < Module
      # @param args [Array] modules, registered under the names "0", "1", ...
      def initialize(*args)
        @modules = {}
        # TODO support hash arg (named modules)
        args.each_with_index do |mod, idx|
          add_module(idx.to_s, mod)
        end
      end

      # Registers +mod+ under +name+, replacing any module already stored under
      # that name. Returns +mod+.
      def add_module(name, mod)
        # TODO add checks
        @modules[name] = mod
      end

      # Threads +input+ through each module's #call; returns the last output.
      def forward(input)
        @modules.values.reduce(input) { |acc, mod| mod.call(acc) }
      end

      def parameters
        @modules.flat_map { |_, mod| mod.parameters }
      end

      private

      # Children live in the @modules Hash rather than in individual instance
      # variables, so the base class's instance-variable scan would find none.
      # Exposing them here lets the inherited #inspect and #method_missing see
      # the registered children.
      def modules
        @modules
      end
    end
  end
end
|
data/lib/torch/tensor.rb
ADDED
@@ -0,0 +1,143 @@
|
|
1
|
+
module Torch
  class Tensor
    include Comparable
    include Inspector

    alias_method :requires_grad?, :requires_grad

    # Returns the first argument unchanged when it is already a Tensor;
    # otherwise builds a tensor of the given dimensions via Torch.rand.
    def self.new(*size)
      if size.first.is_a?(Tensor)
        size.first
      else
        Torch.rand(*size)
      end
    end

    # Looks up the native dtype enum in ENUM_TO_DTYPE.
    # @raise [Error] when the enum value has no mapping
    def dtype
      dtype = ENUM_TO_DTYPE[_dtype]
      raise Error, "Unknown type: #{_dtype}" unless dtype
      dtype
    end

    def layout
      _layout.downcase.to_sym
    end

    def to_s
      inspect
    end

    # Nested Ruby Array of the elements, nested according to #shape.
    def to_a
      reshape(_data, shape)
    end

    # With +dim+, the length of that dimension; without it, the full #shape.
    def size(dim = nil)
      if dim
        _size(dim)
      else
        shape
      end
    end

    def shape
      dim.times.map { |i| size(i) }
    end

    def view(*size)
      _view(size)
    end

    # @raise [Error] unless the tensor holds exactly one element
    def item
      if numel != 1
        raise Error, "only one element tensors can be converted to Ruby scalars"
      end
      _data.first
    end

    def data
      Torch.tensor(to_a)
    end

    # TODO read directly from memory
    def numo
      raise Error, "Numo not found" unless defined?(Numo::NArray)
      cls = Torch._dtype_to_numo[dtype]
      raise Error, "Cannot convert #{dtype} to Numo" unless cls
      cls.cast(_data).reshape(*shape)
    end

    def new_ones(*size, **options)
      Torch.ones_like(Torch.empty(*size), **options)
    end

    def requires_grad!(requires_grad = true)
      _requires_grad!(requires_grad)
    end

    # operations
    # Each op forwards to Torch.<op>(self, ...). Keyword options are only
    # splatted when present, so Torch.<op> never receives an empty options hash.
    %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
      define_method(op) do |*args, **options, &block|
        if options.any?
          Torch.send(op, self, *args, **options, &block)
        else
          Torch.send(op, self, *args, &block)
        end
      end
    end

    def +(other)
      add(other)
    end

    def -(other)
      sub(other)
    end

    def *(other)
      mul(other)
    end

    def /(other)
      div(other)
    end

    def %(other)
      remainder(other)
    end

    def **(other)
      pow(other)
    end

    def -@
      neg
    end

    # Comparison goes through #item, so it only works for one-element tensors.
    def <=>(other)
      item <=> other
    end

    # TODO use accessor C++ method
    def [](index, *args)
      v = _access(index)
      args.each do |i|
        v = v._access(i)
      end
      v
    end

    private

    # Rebuilds a nested Array with dimensions +dims+ from +arr+. For an empty
    # +dims+ (0-d tensor), +arr+ is returned as-is.
    # NOTE(review): assumes _data yields the elements as a flat list, which
    # is how #item and #numo also treat it.
    def reshape(arr, dims)
      return arr if dims.empty?

      nest(arr.flatten, dims)
    end

    # Recursively splits +flat+ into dims[0] groups of prod(dims[1..]) elements.
    # Unlike Enumerable#each_slice, which raises ArgumentError for a slice size
    # of 0, this handles zero-length dimensions. For example, dims [2, 0, 3]
    # yields [[], []].
    def nest(flat, dims)
      return flat if dims.size == 1

      inner = dims[1..-1]
      stride = inner.reduce(1, :*)
      Array.new(dims[0]) { |i| nest(flat[i * stride, stride] || [], inner) }
    end
  end
end
|
metadata
ADDED
@@ -0,0 +1,149 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: torch-rb
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Andrew Kane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2019-11-26 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: rice
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: bundler
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rake
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: rake-compiler
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - ">="
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: '0'
|
62
|
+
type: :development
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - ">="
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '0'
|
69
|
+
- !ruby/object:Gem::Dependency
|
70
|
+
name: minitest
|
71
|
+
requirement: !ruby/object:Gem::Requirement
|
72
|
+
requirements:
|
73
|
+
- - ">="
|
74
|
+
- !ruby/object:Gem::Version
|
75
|
+
version: '5'
|
76
|
+
type: :development
|
77
|
+
prerelease: false
|
78
|
+
version_requirements: !ruby/object:Gem::Requirement
|
79
|
+
requirements:
|
80
|
+
- - ">="
|
81
|
+
- !ruby/object:Gem::Version
|
82
|
+
version: '5'
|
83
|
+
- !ruby/object:Gem::Dependency
|
84
|
+
name: numo-narray
|
85
|
+
requirement: !ruby/object:Gem::Requirement
|
86
|
+
requirements:
|
87
|
+
- - ">="
|
88
|
+
- !ruby/object:Gem::Version
|
89
|
+
version: '0'
|
90
|
+
type: :development
|
91
|
+
prerelease: false
|
92
|
+
version_requirements: !ruby/object:Gem::Requirement
|
93
|
+
requirements:
|
94
|
+
- - ">="
|
95
|
+
- !ruby/object:Gem::Version
|
96
|
+
version: '0'
|
97
|
+
description:
|
98
|
+
email: andrew@chartkick.com
|
99
|
+
executables: []
|
100
|
+
extensions:
|
101
|
+
- ext/torch/extconf.rb
|
102
|
+
extra_rdoc_files: []
|
103
|
+
files:
|
104
|
+
- CHANGELOG.md
|
105
|
+
- LICENSE.txt
|
106
|
+
- README.md
|
107
|
+
- ext/torch/ext.cpp
|
108
|
+
- ext/torch/extconf.rb
|
109
|
+
- lib/torch-rb.rb
|
110
|
+
- lib/torch.rb
|
111
|
+
- lib/torch/ext.bundle
|
112
|
+
- lib/torch/inspector.rb
|
113
|
+
- lib/torch/nn/conv2d.rb
|
114
|
+
- lib/torch/nn/functional.rb
|
115
|
+
- lib/torch/nn/init.rb
|
116
|
+
- lib/torch/nn/linear.rb
|
117
|
+
- lib/torch/nn/module.rb
|
118
|
+
- lib/torch/nn/mse_loss.rb
|
119
|
+
- lib/torch/nn/parameter.rb
|
120
|
+
- lib/torch/nn/relu.rb
|
121
|
+
- lib/torch/nn/sequential.rb
|
122
|
+
- lib/torch/tensor.rb
|
123
|
+
- lib/torch/utils/data/data_loader.rb
|
124
|
+
- lib/torch/utils/data/tensor_dataset.rb
|
125
|
+
- lib/torch/version.rb
|
126
|
+
homepage: https://github.com/ankane/torch-rb
|
127
|
+
licenses:
|
128
|
+
- MIT
|
129
|
+
metadata: {}
|
130
|
+
post_install_message:
|
131
|
+
rdoc_options: []
|
132
|
+
require_paths:
|
133
|
+
- lib
|
134
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
135
|
+
requirements:
|
136
|
+
- - ">="
|
137
|
+
- !ruby/object:Gem::Version
|
138
|
+
version: '2.4'
|
139
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
140
|
+
requirements:
|
141
|
+
- - ">="
|
142
|
+
- !ruby/object:Gem::Version
|
143
|
+
version: '0'
|
144
|
+
requirements: []
|
145
|
+
rubygems_version: 3.0.3
|
146
|
+
signing_key:
|
147
|
+
specification_version: 4
|
148
|
+
summary: Deep learning for Ruby, powered by LibTorch
|
149
|
+
test_files: []
|