torch-rb 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +363 -0
- data/ext/torch/ext.cpp +546 -0
- data/ext/torch/extconf.rb +22 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +327 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +62 -0
- data/lib/torch/nn/conv2d.rb +50 -0
- data/lib/torch/nn/functional.rb +44 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +56 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +10 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/tensor.rb +143 -0
- data/lib/torch/utils/data/data_loader.rb +12 -0
- data/lib/torch/utils/data/tensor_dataset.rb +15 -0
- data/lib/torch/version.rb +3 -0
- metadata +149 -0
@@ -0,0 +1,44 @@
|
|
1
|
+
module Torch
  module NN
    # Stateless functional forms of the NN layers and losses. Every method
    # forwards to the corresponding Torch module function (or, for
    # log_softmax, to the tensor's own method).
    class Functional
      # Rectified linear unit; delegates to Torch.relu.
      def self.relu(input)
        Torch.relu(input)
      end

      # 2-D convolution; delegates to Torch.conv2d with the given weight and bias.
      def self.conv2d(input, weight, bias)
        Torch.conv2d(input, weight, bias)
      end

      # 2-D max pooling. An Integer kernel size k is expanded to the square
      # window [k, k]; any other value is passed through unchanged.
      def self.max_pool2d(input, kernel_size)
        window = kernel_size.is_a?(Integer) ? [kernel_size] * 2 : kernel_size
        Torch.max_pool2d(input, window)
      end

      # Affine transform; delegates to Torch.linear.
      def self.linear(input, weight, bias)
        Torch.linear(input, weight, bias)
      end

      # Mean-squared-error loss. The reduction keyword ("mean" by default) is
      # handed to Torch.mse_loss as a positional argument.
      def self.mse_loss(input, target, reduction: "mean")
        Torch.mse_loss(input, target, reduction)
      end

      # Cross-entropy loss, composed as nll_loss(log_softmax(input, 1), target).
      def self.cross_entropy(input, target)
        log_probs = log_softmax(input, 1)
        nll_loss(log_probs, target)
      end

      # Negative log-likelihood loss; delegates to Torch.nll_loss.
      def self.nll_loss(input, target)
        # TODO fix for non-1d
        Torch.nll_loss(input, target)
      end

      # Log-softmax along +dim+, computed by the input tensor itself.
      def self.log_softmax(input, dim)
        input.log_softmax(dim)
      end
    end

    # shortcut
    F = Functional
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module NN
    module Init
      class << self
        # Computes the fan-in and fan-out of a weight tensor.
        #
        # Dimension 1 is treated as the number of input feature maps and
        # dimension 0 as the number of output feature maps; every dimension
        # after those contributes multiplicatively to the receptive field size
        # (which is 1 for a 2-D tensor).
        #
        # @param tensor [#dim, #size] tensor with at least 2 dimensions
        # @return [Array(Integer, Integer)] [fan_in, fan_out]
        # @raise [Error] if the tensor has fewer than 2 dimensions
        def calculate_fan_in_and_fan_out(tensor)
          dimensions = tensor.dim
          if dimensions < 2
            raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
          end

          num_input_fmaps = tensor.size(1)
          num_output_fmaps = tensor.size(0)

          # Multiply the trailing sizes directly instead of taking
          # tensor[0][0].numel: indexing allocates sub-tensors and fails when a
          # leading dimension has size 0. For 2-D tensors the range is empty,
          # so the receptive field is 1 and the 2-D case needs no special branch.
          receptive_field_size = (2...dimensions).reduce(1) { |acc, i| acc * tensor.size(i) }

          [num_input_fmaps * receptive_field_size, num_output_fmaps * receptive_field_size]
        end
      end
    end
  end
end
|
@@ -0,0 +1,36 @@
|
|
1
|
+
module Torch
  module NN
    # Fully connected layer: applies F.linear with a learnable weight of size
    # (out_features, in_features) and, optionally, a learnable bias of size
    # out_features.
    class Linear < Module
      attr_reader :bias, :weight

      # @param in_features [Integer] size of each input sample
      # @param out_features [Integer] size of each output sample
      # @param bias [Boolean] whether to create a bias parameter
      def initialize(in_features, out_features, bias: true)
        @in_features = in_features
        @out_features = out_features

        @weight = Parameter.new(Tensor.new(out_features, in_features))
        @bias = Parameter.new(Tensor.new(out_features)) if bias

        reset_parameters
      end

      # Forward pass. Implemented as #forward (rather than overriding #call) so
      # Module#call dispatches here like every other module; both
      # linear.call(x) and linear.forward(x) work.
      def forward(input)
        F.linear(input, @weight, @bias)
      end

      # Re-initializes the weight via Init.kaiming_uniform_ with Math.sqrt(5),
      # and, when present, the bias uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
      def reset_parameters
        Init.kaiming_uniform_(@weight, Math.sqrt(5))
        return unless @bias

        fan_in, _fan_out = Init.calculate_fan_in_and_fan_out(@weight)
        bound = 1 / Math.sqrt(fan_in)
        Init.uniform_(@bias, -bound, bound)
      end

      def inspect
        "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
      end
    end
  end
end
|
@@ -0,0 +1,56 @@
|
|
1
|
+
module Torch
  module NN
    # Base class for layers. Subclasses implement #forward and keep their
    # Parameters and child Modules in instance variables; those are discovered
    # by reflection for #parameters, #inspect and child-module accessors.
    class Module
      # Multi-line summary listing each child module by instance-variable name.
      def inspect
        str = String.new
        str << "#{self.class.name}(\n"
        modules.each do |name, mod|
          str << " (#{name}): #{mod.inspect}\n"
        end
        str << ")"
      end

      # Runs the forward pass; all arguments are passed through to #forward.
      def call(*input)
        forward(*input)
      end

      # Parameters stored directly in instance variables, followed by the
      # parameters of every child module (recursively).
      # @return [Array<Parameter>]
      def parameters
        params = []
        instance_variables.each do |name|
          param = instance_variable_get(name)
          params << param if param.is_a?(Parameter)
        end
        params + modules.flat_map { |_, mod| mod.parameters }
      end

      # Not implemented yet: raises Error as soon as any parameter has a gradient.
      def zero_grad
        parameters.each do |param|
          if param.grad
            raise Error, "Not supported yet"
            # NOTE: the two lines below are unreachable until zeroing is supported.
            param.grad.detach!
            param.grad.zero!
          end
        end
      end

      # Exposes each child module as a reader named after its instance
      # variable (e.g. @fc1 is reachable as fc1).
      def method_missing(method, *args, &block)
        modules[method.to_s] || super
      end

      # Declares the dynamic readers served by method_missing. Implementing
      # respond_to_missing? (rather than overriding respond_to?) keeps
      # respond_to? working and also makes #method(:name) succeed.
      def respond_to_missing?(method, include_private = false)
        modules.key?(method.to_s) || super
      end

      private

      # @return [Hash{String => Module}] child modules keyed by ivar name without "@"
      def modules
        modules = {}
        instance_variables.each do |name|
          mod = instance_variable_get(name)
          modules[name[1..-1]] = mod if mod.is_a?(Module)
        end
        modules
      end
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
module Torch
  module NN
    # Container that passes its input through each child module in
    # registration order.
    class Sequential < Module
      # @param args [Array<Module>] child modules, registered as "0", "1", ...
      def initialize(*args)
        @modules = {}
        # TODO support hash arg (named modules)
        args.each_with_index do |mod, idx|
          add_module(idx.to_s, mod)
        end
      end

      # Registers +mod+ under +name+ (a later module with the same name
      # replaces the earlier one).
      def add_module(name, mod)
        # TODO add checks
        @modules[name] = mod
      end

      # Feeds the output of each module into the next; returns the input
      # unchanged when there are no modules.
      def forward(input)
        @modules.values.reduce(input) { |acc, mod| mod.call(acc) }
      end

      def parameters
        @modules.flat_map { |_, mod| mod.parameters }
      end

      private

      # Children live in the @modules Hash rather than in individual instance
      # variables, so Module's reflection-based lookup would find none.
      # Returning the Hash here lets Module#inspect and the child-module
      # accessors (method_missing / respond_to?) see them.
      def modules
        @modules
      end
    end
  end
end
|
data/lib/torch/tensor.rb
ADDED
@@ -0,0 +1,143 @@
|
|
1
|
+
module Torch
  class Tensor
    include Comparable
    include Inspector

    alias_method :requires_grad?, :requires_grad

    # Returns an existing Tensor unchanged; otherwise creates a tensor of the
    # given size via Torch.rand.
    def self.new(*size)
      if size.first.is_a?(Tensor)
        size.first
      else
        Torch.rand(*size)
      end
    end

    # @return the dtype mapped from the native enum value via ENUM_TO_DTYPE
    # @raise [Error] if the native value has no mapping
    def dtype
      dtype = ENUM_TO_DTYPE[_dtype]
      raise Error, "Unknown type: #{_dtype}" unless dtype
      dtype
    end

    # @return [Symbol] the native layout name, downcased
    def layout
      _layout.downcase.to_sym
    end

    def to_s
      inspect
    end

    # @return [Array] the tensor's values as nested Ruby arrays matching #shape
    def to_a
      reshape(_data, shape)
    end

    # @param dim [Integer, nil] dimension to query
    # @return [Integer, Array<Integer>] size of +dim+, or the full shape when nil
    def size(dim = nil)
      if dim
        _size(dim)
      else
        shape
      end
    end

    # @return [Array<Integer>] the size of every dimension
    def shape
      dim.times.map { |i| size(i) }
    end

    def view(*size)
      _view(size)
    end

    # @return the single value of a one-element tensor
    # @raise [Error] unless the tensor has exactly one element
    def item
      if numel != 1
        raise Error, "only one element tensors can be converted to Ruby scalars"
      end
      _data.first
    end

    # A new tensor built from this tensor's values (via #to_a).
    def data
      Torch.tensor(to_a)
    end

    # TODO read directly from memory
    def numo
      raise Error, "Numo not found" unless defined?(Numo::NArray)
      cls = Torch._dtype_to_numo[dtype]
      raise Error, "Cannot convert #{dtype} to Numo" unless cls
      cls.cast(_data).reshape(*shape)
    end

    def new_ones(*size, **options)
      Torch.ones_like(Torch.empty(*size), **options)
    end

    def requires_grad!(requires_grad = true)
      _requires_grad!(requires_grad)
    end

    # operations
    %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
      define_method(op) do |*args, **options, &block|
        # Keywords are only forwarded when present — presumably so ops without
        # keyword parameters never receive an empty options hash.
        if options.any?
          Torch.send(op, self, *args, **options, &block)
        else
          Torch.send(op, self, *args, &block)
        end
      end
    end

    def +(other)
      add(other)
    end

    def -(other)
      sub(other)
    end

    def *(other)
      mul(other)
    end

    def /(other)
      div(other)
    end

    def %(other)
      remainder(other)
    end

    def **(other)
      pow(other)
    end

    def -@
      neg
    end

    # Compares by scalar value (via #item). A one-element Tensor on the
    # right-hand side is unwrapped as well, so scalar tensors can be compared
    # with each other; previously item <=> tensor always returned nil.
    def <=>(other)
      other = other.item if other.is_a?(Tensor) && other.numel == 1
      item <=> other
    end

    # TODO use accessor C++ method
    def [](index, *args)
      v = _access(index)
      args.each do |i|
        v = v._access(i)
      end
      v
    end

    private

    # Rebuilds nested arrays of shape +dims+ from the values in +arr+
    # (flattened, row-major). Offsets are computed explicitly instead of using
    # Enumerable#each_slice, which raises ArgumentError for a slice size of 0
    # and so broke #to_a for zero-sized dimensions such as shape [2, 0].
    def reshape(arr, dims)
      return arr if dims.empty?

      flat = arr.flatten
      build = lambda do |level, offset|
        count = dims[level]
        if level == dims.size - 1
          flat[offset, count]
        else
          stride = dims[(level + 1)..-1].reduce(1, :*)
          Array.new(count) { |i| build.call(level + 1, offset + i * stride) }
        end
      end
      build.call(0, 0)
    end
  end
end
|
metadata
ADDED
@@ -0,0 +1,149 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: torch-rb
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Andrew Kane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2019-11-26 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: rice
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: bundler
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rake
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: rake-compiler
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - ">="
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: '0'
|
62
|
+
type: :development
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - ">="
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '0'
|
69
|
+
- !ruby/object:Gem::Dependency
|
70
|
+
name: minitest
|
71
|
+
requirement: !ruby/object:Gem::Requirement
|
72
|
+
requirements:
|
73
|
+
- - ">="
|
74
|
+
- !ruby/object:Gem::Version
|
75
|
+
version: '5'
|
76
|
+
type: :development
|
77
|
+
prerelease: false
|
78
|
+
version_requirements: !ruby/object:Gem::Requirement
|
79
|
+
requirements:
|
80
|
+
- - ">="
|
81
|
+
- !ruby/object:Gem::Version
|
82
|
+
version: '5'
|
83
|
+
- !ruby/object:Gem::Dependency
|
84
|
+
name: numo-narray
|
85
|
+
requirement: !ruby/object:Gem::Requirement
|
86
|
+
requirements:
|
87
|
+
- - ">="
|
88
|
+
- !ruby/object:Gem::Version
|
89
|
+
version: '0'
|
90
|
+
type: :development
|
91
|
+
prerelease: false
|
92
|
+
version_requirements: !ruby/object:Gem::Requirement
|
93
|
+
requirements:
|
94
|
+
- - ">="
|
95
|
+
- !ruby/object:Gem::Version
|
96
|
+
version: '0'
|
97
|
+
description:
|
98
|
+
email: andrew@chartkick.com
|
99
|
+
executables: []
|
100
|
+
extensions:
|
101
|
+
- ext/torch/extconf.rb
|
102
|
+
extra_rdoc_files: []
|
103
|
+
files:
|
104
|
+
- CHANGELOG.md
|
105
|
+
- LICENSE.txt
|
106
|
+
- README.md
|
107
|
+
- ext/torch/ext.cpp
|
108
|
+
- ext/torch/extconf.rb
|
109
|
+
- lib/torch-rb.rb
|
110
|
+
- lib/torch.rb
|
111
|
+
- lib/torch/ext.bundle
|
112
|
+
- lib/torch/inspector.rb
|
113
|
+
- lib/torch/nn/conv2d.rb
|
114
|
+
- lib/torch/nn/functional.rb
|
115
|
+
- lib/torch/nn/init.rb
|
116
|
+
- lib/torch/nn/linear.rb
|
117
|
+
- lib/torch/nn/module.rb
|
118
|
+
- lib/torch/nn/mse_loss.rb
|
119
|
+
- lib/torch/nn/parameter.rb
|
120
|
+
- lib/torch/nn/relu.rb
|
121
|
+
- lib/torch/nn/sequential.rb
|
122
|
+
- lib/torch/tensor.rb
|
123
|
+
- lib/torch/utils/data/data_loader.rb
|
124
|
+
- lib/torch/utils/data/tensor_dataset.rb
|
125
|
+
- lib/torch/version.rb
|
126
|
+
homepage: https://github.com/ankane/torch-rb
|
127
|
+
licenses:
|
128
|
+
- MIT
|
129
|
+
metadata: {}
|
130
|
+
post_install_message:
|
131
|
+
rdoc_options: []
|
132
|
+
require_paths:
|
133
|
+
- lib
|
134
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
135
|
+
requirements:
|
136
|
+
- - ">="
|
137
|
+
- !ruby/object:Gem::Version
|
138
|
+
version: '2.4'
|
139
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
140
|
+
requirements:
|
141
|
+
- - ">="
|
142
|
+
- !ruby/object:Gem::Version
|
143
|
+
version: '0'
|
144
|
+
requirements: []
|
145
|
+
rubygems_version: 3.0.3
|
146
|
+
signing_key:
|
147
|
+
specification_version: 4
|
148
|
+
summary: Deep learning for Ruby, powered by LibTorch
|
149
|
+
test_files: []
|