torch-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lib/torch/nn/functional.rb ADDED
@@ -0,0 +1,44 @@
+ module Torch
+   module NN
+     class Functional
+       class << self
+         def relu(input)
+           Torch.relu(input)
+         end
+
+         def conv2d(input, weight, bias)
+           Torch.conv2d(input, weight, bias)
+         end
+
+         def max_pool2d(input, kernel_size)
+           kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
+           Torch.max_pool2d(input, kernel_size)
+         end
+
+         def linear(input, weight, bias)
+           Torch.linear(input, weight, bias)
+         end
+
+         def mse_loss(input, target, reduction: "mean")
+           Torch.mse_loss(input, target, reduction)
+         end
+
+         def cross_entropy(input, target)
+           nll_loss(log_softmax(input, 1), target)
+         end
+
+         def nll_loss(input, target)
+           # TODO fix for non-1d
+           Torch.nll_loss(input, target)
+         end
+
+         def log_softmax(input, dim)
+           input.log_softmax(dim)
+         end
+       end
+     end
+
+     # shortcut
+     F = Functional
+   end
+ end
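Functional is a stateless wrapper over the native Torch module, and F is the conventional shorthand. A minimal usage sketch (Torch.rand and Torch.tensor are native functions used elsewhere in this release; exact values will vary):

    x = Torch.rand(2, 3)              # random 2x3 tensor
    y = Torch::NN::F.relu(x)          # element-wise max(0, x)

    # an Integer kernel_size is expanded to [k, k] before dispatch
    imgs = Torch.rand(1, 1, 4, 4)     # NCHW-shaped input
    pooled = Torch::NN::F.max_pool2d(imgs, 2)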
lib/torch/nn/init.rb ADDED
@@ -0,0 +1,30 @@
+ module Torch
+   module NN
+     module Init
+       class << self
+         def calculate_fan_in_and_fan_out(tensor)
+           dimensions = tensor.dim
+           if dimensions < 2
+             raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
+           end
+
+           if dimensions == 2
+             fan_in = tensor.size(1)
+             fan_out = tensor.size(0)
+           else
+             num_input_fmaps = tensor.size(1)
+             num_output_fmaps = tensor.size(0)
+             receptive_field_size = 1
+             if tensor.dim > 2
+               receptive_field_size = tensor[0][0].numel
+             end
+             fan_in = num_input_fmaps * receptive_field_size
+             fan_out = num_output_fmaps * receptive_field_size
+           end
+
+           [fan_in, fan_out]
+         end
+       end
+     end
+   end
+ end
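The calculation mirrors PyTorch's torch.nn.init: for a 2-D weight, fan in and fan out are simply the column and row counts; for convolution weights, both are multiplied by the receptive field size (kernel height times width). A quick sketch of the arithmetic:

    w = Torch.rand(8, 3)               # out_features x in_features
    Torch::NN::Init.calculate_fan_in_and_fan_out(w)
    # => [3, 8]

    conv_w = Torch.rand(16, 3, 5, 5)   # out_channels x in_channels x kH x kW
    Torch::NN::Init.calculate_fan_in_and_fan_out(conv_w)
    # => [75, 400]   (3 * 5 * 5 and 16 * 5 * 5)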
lib/torch/nn/linear.rb ADDED
@@ -0,0 +1,36 @@
+ module Torch
+   module NN
+     class Linear < Module
+       attr_reader :bias, :weight
+
+       def initialize(in_features, out_features, bias: true)
+         @in_features = in_features
+         @out_features = out_features
+
+         @weight = Parameter.new(Tensor.new(out_features, in_features))
+         if bias
+           @bias = Parameter.new(Tensor.new(out_features))
+         end
+
+         reset_parameters
+       end
+
+       def call(input)
+         F.linear(input, @weight, @bias)
+       end
+
+       def reset_parameters
+         Init.kaiming_uniform_(@weight, Math.sqrt(5))
+         if @bias
+           fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
+           bound = 1 / Math.sqrt(fan_in)
+           Init.uniform_(@bias, -bound, bound)
+         end
+       end
+
+       def inspect
+         "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
+       end
+     end
+   end
+ end
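The weight is allocated as out_features x in_features and initialized Kaiming-style; kaiming_uniform_ and uniform_ are assumed to come from the native extension, as they are not defined anywhere in this diff. A sketch:

    layer = Torch::NN::Linear.new(10, 5)
    layer.weight.size  # => [5, 10]
    layer.bias.size    # => [5]

    out = layer.call(Torch.rand(2, 10))  # batch of 2 -> output of shape [2, 5]
    layer.inspect      # => "Linear(in_features: 10, out_features: 5, bias: true)"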
lib/torch/nn/module.rb ADDED
@@ -0,0 +1,56 @@
+ module Torch
+   module NN
+     class Module
+       def inspect
+         str = String.new
+         str << "#{self.class.name}(\n"
+         modules.each do |name, mod|
+           str << "  (#{name}): #{mod.inspect}\n"
+         end
+         str << ")"
+       end
+
+       def call(*input)
+         forward(*input)
+       end
+
+       def parameters
+         params = []
+         instance_variables.each do |name|
+           param = instance_variable_get(name)
+           params << param if param.is_a?(Parameter)
+         end
+         params + modules.flat_map { |_, mod| mod.parameters }
+       end
+
+       def zero_grad
+         parameters.each do |param|
+           if param.grad
+             raise Error, "Not supported yet"
+             # TODO once supported: param.grad.detach!
+             # TODO once supported: param.grad.zero!
+           end
+         end
+       end
+
+       def method_missing(method, *args, &block)
+         modules[method.to_s] || super
+       end
+
+       def respond_to_missing?(method, include_private = false)
+         modules.key?(method.to_s) || super
+       end
+
+       private
+
+       def modules
+         modules = {}
+         instance_variables.each do |name|
+           mod = instance_variable_get(name)
+           modules[name[1..-1]] = mod if mod.is_a?(Module)
+         end
+         modules
+       end
+     end
+   end
+ end
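Module discovers parameters and child modules by reflecting over instance variables, and method_missing exposes children as reader methods. A subclass only has to assign ivars and define forward; a hypothetical sketch:

    class Net < Torch::NN::Module
      def initialize
        @fc1 = Torch::NN::Linear.new(4, 8)
        @fc2 = Torch::NN::Linear.new(8, 2)
      end

      def forward(x)
        @fc2.call(Torch::NN::F.relu(@fc1.call(x)))
      end
    end

    net = Net.new
    net.parameters.size  # => 4 (two weights and two biases, found via ivars)
    net.fc1              # the Linear child, resolved through method_missing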
lib/torch/nn/mse_loss.rb ADDED
@@ -0,0 +1,13 @@
+ module Torch
+   module NN
+     class MSELoss < Module
+       def initialize(reduction: "mean")
+         @reduction = reduction
+       end
+
+       def forward(input, target)
+         F.mse_loss(input, target, reduction: @reduction)
+       end
+     end
+   end
+ end
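Since MSELoss subclasses Module, invoking it through call dispatches to forward. Sketch:

    loss_fn = Torch::NN::MSELoss.new
    pred    = Torch.tensor([1.0, 2.0, 3.0])
    target  = Torch.tensor([1.0, 2.0, 5.0])
    loss_fn.call(pred, target)  # mean squared error, here (0 + 0 + 4) / 3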
lib/torch/nn/parameter.rb ADDED
@@ -0,0 +1,10 @@
+ module Torch
+   module NN
+     class Parameter < Tensor
+       def self.new(data = nil, requires_grad: true)
+         data = Tensor.new unless data
+         Tensor._make_subclass(data, requires_grad)
+       end
+     end
+   end
+ end
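Parameter overrides new rather than initialize so the native _make_subclass helper can wrap an existing tensor with its requires_grad flag. Sketch:

    p = Torch::NN::Parameter.new(Torch.rand(3))
    p.requires_grad?  # => true (the default)

    frozen = Torch::NN::Parameter.new(Torch.rand(3), requires_grad: false)
    frozen.requires_grad?  # => false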
lib/torch/nn/relu.rb ADDED
@@ -0,0 +1,13 @@
+ module Torch
+   module NN
+     class ReLU < Module
+       def initialize #(inplace: false)
+         # @inplace = inplace
+       end
+
+       def forward(input)
+         F.relu(input) #, inplace: @inplace)
+       end
+     end
+   end
+ end
lib/torch/nn/sequential.rb ADDED
@@ -0,0 +1,29 @@
+ module Torch
+   module NN
+     class Sequential < Module
+       def initialize(*args)
+         @modules = {}
+         # TODO support hash arg (named modules)
+         args.each_with_index do |mod, idx|
+           add_module(idx.to_s, mod)
+         end
+       end
+
+       def add_module(name, mod)
+         # TODO add checks
+         @modules[name] = mod
+       end
+
+       def forward(input)
+         @modules.values.each do |mod|
+           input = mod.call(input)
+         end
+         input
+       end
+
+       def parameters
+         @modules.flat_map { |_, mod| mod.parameters }
+       end
+     end
+   end
+ end
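Sequential threads its input through each child in insertion order, keying children by their index as a string until named (hash) arguments are supported. Sketch:

    model = Torch::NN::Sequential.new(
      Torch::NN::Linear.new(4, 8),
      Torch::NN::ReLU.new,
      Torch::NN::Linear.new(8, 1)
    )
    model.call(Torch.rand(2, 4))  # => tensor of shape [2, 1]
    model.parameters.size         # => 4 (ReLU contributes no parameters)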
lib/torch/tensor.rb ADDED
@@ -0,0 +1,143 @@
+ module Torch
+   class Tensor
+     include Comparable
+     include Inspector
+
+     alias_method :requires_grad?, :requires_grad
+
+     def self.new(*size)
+       if size.first.is_a?(Tensor)
+         size.first
+       else
+         Torch.rand(*size)
+       end
+     end
+
+     def dtype
+       dtype = ENUM_TO_DTYPE[_dtype]
+       raise Error, "Unknown type: #{_dtype}" unless dtype
+       dtype
+     end
+
+     def layout
+       _layout.downcase.to_sym
+     end
+
+     def to_s
+       inspect
+     end
+
+     def to_a
+       reshape(_data, shape)
+     end
+
+     def size(dim = nil)
+       if dim
+         _size(dim)
+       else
+         shape
+       end
+     end
+
+     def shape
+       dim.times.map { |i| size(i) }
+     end
+
+     def view(*size)
+       _view(size)
+     end
+
+     def item
+       if numel != 1
+         raise Error, "only one element tensors can be converted to Ruby scalars"
+       end
+       _data.first
+     end
+
+     def data
+       Torch.tensor(to_a)
+     end
+
+     # TODO read directly from memory
+     def numo
+       raise Error, "Numo not found" unless defined?(Numo::NArray)
+       cls = Torch._dtype_to_numo[dtype]
+       raise Error, "Cannot convert #{dtype} to Numo" unless cls
+       cls.cast(_data).reshape(*shape)
+     end
+
+     def new_ones(*size, **options)
+       Torch.ones_like(Torch.empty(*size), **options)
+     end
+
+     def requires_grad!(requires_grad = true)
+       _requires_grad!(requires_grad)
+     end
+
+     # operations
+     %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
+       define_method(op) do |*args, **options, &block|
+         if options.any?
+           Torch.send(op, self, *args, **options, &block)
+         else
+           Torch.send(op, self, *args, &block)
+         end
+       end
+     end
+
+     def +(other)
+       add(other)
+     end
+
+     def -(other)
+       sub(other)
+     end
+
+     def *(other)
+       mul(other)
+     end
+
+     def /(other)
+       div(other)
+     end
+
+     def %(other)
+       remainder(other)
+     end
+
+     def **(other)
+       pow(other)
+     end
+
+     def -@
+       neg
+     end
+
+     def <=>(other)
+       item <=> other
+     end
+
+     # TODO use accessor C++ method
+     def [](index, *args)
+       v = _access(index)
+       args.each do |i|
+         v = v._access(i)
+       end
+       v
+     end
+
+     private
+
+     def reshape(arr, dims)
+       if dims.empty?
+         arr
+       else
+         arr = arr.flatten
+         dims[1..-1].reverse.each do |dim|
+           arr = arr.each_slice(dim)
+         end
+         arr.to_a
+       end
+     end
+   end
+ end
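Arithmetic is metaprogrammed: each named op delegates to the matching module-level Torch function, and the operator methods alias those, while item and to_a copy data back into plain Ruby objects. A sketch (scalar broadcasting in the native ops is assumed here):

    a = Torch.tensor([1.0, 2.0, 3.0])
    b = Torch.tensor([10.0, 20.0, 30.0])

    (a + b).to_a      # => [11.0, 22.0, 33.0]  (+ delegates to Torch.add)
    (a * 2).sum.item  # => 12.0
    a.shape           # => [3]

    m = Torch.rand(2, 3)
    m[0]              # first row, fetched via _access
    m.view(3, 2)      # same data, new shape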
lib/torch/utils/data/data_loader.rb ADDED
@@ -0,0 +1,12 @@
+ module Torch
+   module Utils
+     module Data
+       class DataLoader
+         def initialize(dataset, batch_size: 1)
+           @dataset = dataset
+           @batch_size = batch_size
+         end
+       end
+     end
+   end
+ end
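At 0.1.0, DataLoader only records its arguments; no iteration method appears in this diff. Construction works, typically over the TensorDataset defined in the next file:

    dataset = Torch::Utils::Data::TensorDataset.new(Torch.rand(10, 4), Torch.rand(10))
    loader  = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 5)
    # no #each here yet, so batching has to be done by hand at this version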
lib/torch/utils/data/tensor_dataset.rb ADDED
@@ -0,0 +1,15 @@
+ module Torch
+   module Utils
+     module Data
+       class TensorDataset
+         def initialize(*tensors)
+           @tensors = tensors
+         end
+
+         def [](index)
+           @tensors.map { |t| t[index] }
+         end
+       end
+     end
+   end
+ end
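Indexing zips the stored tensors at a given row, yielding one sample with its companion values. Sketch:

    x  = Torch.rand(3, 2)
    y  = Torch.tensor([0.0, 1.0, 0.0])
    ds = Torch::Utils::Data::TensorDataset.new(x, y)
    ds[1]  # => [x[1], y[1]], the second sample and its label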
lib/torch/version.rb ADDED
@@ -0,0 +1,3 @@
+ module Torch
+   VERSION = "0.1.0"
+ end
metadata ADDED
@@ -0,0 +1,149 @@
+ --- !ruby/object:Gem::Specification
+ name: torch-rb
+ version: !ruby/object:Gem::Version
+   version: 0.1.0
+ platform: ruby
+ authors:
+ - Andrew Kane
+ autorequire:
+ bindir: bin
+ cert_chain: []
+ date: 2019-11-26 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: rice
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ - !ruby/object:Gem::Dependency
+   name: bundler
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ - !ruby/object:Gem::Dependency
+   name: rake
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ - !ruby/object:Gem::Dependency
+   name: rake-compiler
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ - !ruby/object:Gem::Dependency
+   name: minitest
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '5'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '5'
+ - !ruby/object:Gem::Dependency
+   name: numo-narray
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ description:
+ email: andrew@chartkick.com
+ executables: []
+ extensions:
+ - ext/torch/extconf.rb
+ extra_rdoc_files: []
+ files:
+ - CHANGELOG.md
+ - LICENSE.txt
+ - README.md
+ - ext/torch/ext.cpp
+ - ext/torch/extconf.rb
+ - lib/torch-rb.rb
+ - lib/torch.rb
+ - lib/torch/ext.bundle
+ - lib/torch/inspector.rb
+ - lib/torch/nn/conv2d.rb
+ - lib/torch/nn/functional.rb
+ - lib/torch/nn/init.rb
+ - lib/torch/nn/linear.rb
+ - lib/torch/nn/module.rb
+ - lib/torch/nn/mse_loss.rb
+ - lib/torch/nn/parameter.rb
+ - lib/torch/nn/relu.rb
+ - lib/torch/nn/sequential.rb
+ - lib/torch/tensor.rb
+ - lib/torch/utils/data/data_loader.rb
+ - lib/torch/utils/data/tensor_dataset.rb
+ - lib/torch/version.rb
+ homepage: https://github.com/ankane/torch-rb
+ licenses:
+ - MIT
+ metadata: {}
+ post_install_message:
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '2.4'
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 3.0.3
+ signing_key:
+ specification_version: 4
+ summary: Deep learning for Ruby, powered by LibTorch
+ test_files: []