torch-rb 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,44 @@
1
module Torch
  module NN
    # Stateless neural-network operations (the Ruby counterpart of
    # torch.nn.functional). Almost every method forwards its arguments
    # unchanged to the native function of the same name on Torch.
    class Functional
      # Element-wise rectified linear unit; delegates to Torch.relu.
      def self.relu(input)
        Torch.relu(input)
      end

      # 2-D convolution; delegates to Torch.conv2d(input, weight, bias).
      def self.conv2d(input, weight, bias)
        Torch.conv2d(input, weight, bias)
      end

      # 2-D max pooling. An Integer kernel_size k is expanded to the square
      # window [k, k]; any other value is forwarded as given.
      def self.max_pool2d(input, kernel_size)
        window =
          case kernel_size
          when Integer then [kernel_size, kernel_size]
          else kernel_size
          end
        Torch.max_pool2d(input, window)
      end

      # Affine transform; delegates to Torch.linear(input, weight, bias).
      def self.linear(input, weight, bias)
        Torch.linear(input, weight, bias)
      end

      # Mean squared error. The reduction keyword (default "mean") is passed
      # to Torch.mse_loss as a positional argument.
      def self.mse_loss(input, target, reduction: "mean")
        Torch.mse_loss(input, target, reduction)
      end

      # Cross-entropy computed as nll_loss over log_softmax along dim 1.
      def self.cross_entropy(input, target)
        log_probs = log_softmax(input, 1)
        nll_loss(log_probs, target)
      end

      # Negative log-likelihood loss; delegates to Torch.nll_loss.
      # TODO fix for non-1d
      def self.nll_loss(input, target)
        Torch.nll_loss(input, target)
      end

      # Log-softmax along dim, using the input's own #log_softmax method.
      def self.log_softmax(input, dim)
        input.log_softmax(dim)
      end
    end

    # shortcut
    F = Functional
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module NN
    module Init
      class << self
        # Computes [fan_in, fan_out] for a weight tensor, following PyTorch's
        # torch.nn.init._calculate_fan_in_and_fan_out.
        #
        # Size 1 is the input-feature-map count and size 0 the output count.
        # Both are multiplied by the receptive field size: the product of all
        # dimensions after the first two (1 for a 2-D tensor).
        #
        # @param tensor [Tensor] tensor with at least 2 dimensions
        # @return [Array(Integer, Integer)] [fan_in, fan_out]
        # @raise [Torch::Error] if the tensor has fewer than 2 dimensions
        def calculate_fan_in_and_fan_out(tensor)
          dimensions = tensor.dim
          if dimensions < 2
            raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
          end

          num_input_fmaps = tensor.size(1)
          num_output_fmaps = tensor.size(0)
          # Read trailing sizes directly rather than materializing the
          # tensor[0][0] sub-tensor; the empty range gives 1 for 2-D input.
          receptive_field_size = (2...dimensions).reduce(1) { |acc, d| acc * tensor.size(d) }

          [num_input_fmaps * receptive_field_size, num_output_fmaps * receptive_field_size]
        end
      end
    end
  end
end
@@ -0,0 +1,36 @@
1
module Torch
  module NN
    # Fully connected layer applying F.linear(input, weight, bias).
    class Linear < Module
      attr_reader :bias, :weight

      # @param in_features [Integer] size of each input sample
      # @param out_features [Integer] size of each output sample
      # @param bias [Boolean] when false, no bias parameter is created
      #   (@bias stays nil and F.linear receives nil)
      def initialize(in_features, out_features, bias: true)
        @in_features = in_features
        @out_features = out_features

        @weight = Parameter.new(Tensor.new(out_features, in_features))
        @bias = Parameter.new(Tensor.new(out_features)) if bias

        reset_parameters
      end

      # Applies the layer. Defined as #forward, like the other modules, so
      # that Module#call dispatches here and #forward is callable directly.
      def forward(input)
        F.linear(input, @weight, @bias)
      end

      # Re-initializes the parameters. The weight uses Init.kaiming_uniform_
      # with a = sqrt(5). The bias, if present, is drawn via Init.uniform_ from
      # [-1/sqrt(fan_in), 1/sqrt(fan_in)].
      def reset_parameters
        Init.kaiming_uniform_(@weight, Math.sqrt(5))
        if @bias
          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
          bound = 1 / Math.sqrt(fan_in)
          Init.uniform_(@bias, -bound, bound)
        end
      end

      def inspect
        "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
      end
    end
  end
end
@@ -0,0 +1,56 @@
1
module Torch
  module NN
    # Base class for network layers. Subclasses implement #forward. Child
    # modules and Parameters are discovered from instance variables.
    class Module
      # Multi-line summary: the class name, then one "(name): inspect" line
      # per child module found in an instance variable.
      def inspect
        str = String.new
        str << "#{self.class.name}(\n"
        modules.each do |name, mod|
          str << "  (#{name}): #{mod.inspect}\n"
        end
        str << ")"
      end

      # Runs the module by forwarding all arguments to #forward.
      def call(*input)
        forward(*input)
      end

      # Parameters stored directly in instance variables, followed by the
      # parameters of each child module (recursively).
      def parameters
        own = instance_variables.map { |name| instance_variable_get(name) }.select { |value| value.is_a?(Parameter) }
        own + modules.flat_map { |_, mod| mod.parameters }
      end

      # Not implemented yet: raises Torch::Error as soon as any parameter has
      # a gradient.
      def zero_grad
        parameters.each do |param|
          next unless param.grad

          # TODO once supported, replace the raise with:
          #   param.grad.detach!
          #   param.grad.zero!
          raise Error, "Not supported yet"
        end
      end

      # Exposes each child module as a method named after its instance
      # variable (e.g. @fc1 -> #fc1). Any other missing method falls through
      # to super.
      def method_missing(method, *args, &block)
        modules[method.to_s] || super
      end

      # Paired with #method_missing so that respond_to?, #method, etc. also
      # recognize child-module accessors.
      def respond_to_missing?(method, include_private = false)
        modules.key?(method.to_s) || super
      end

      private

      # Hash of instance-variable name (without the leading "@") => child
      # Module, in instance-variable order.
      def modules
        instance_variables.each_with_object({}) do |name, children|
          value = instance_variable_get(name)
          children[name[1..-1]] = value if value.is_a?(Module)
        end
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module NN
    # Mean squared error loss as a module; wraps F.mse_loss.
    class MSELoss < Module
      # @param reduction [String] reduction mode handed to F.mse_loss
      #   (defaults to "mean")
      def initialize(reduction: "mean")
        @reduction = reduction
      end

      # Loss between input and target using the configured reduction.
      def forward(input, target)
        mode = @reduction
        F.mse_loss(input, target, reduction: mode)
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Torch
  module NN
    # Tensor subclass used to mark learnable module parameters
    # (Module#parameters collects instances of this class).
    class Parameter < Tensor
      # Builds a parameter via Tensor._make_subclass(data, requires_grad).
      # When data is nil or false, a fresh Tensor.new is used instead.
      def self.new(data = nil, requires_grad: true)
        data ||= Tensor.new
        Tensor._make_subclass(data, requires_grad)
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module NN
    # Rectified linear unit activation as a module; wraps F.relu.
    class ReLU < Module
      # Takes no options. An inplace: flag is not supported yet.
      def initialize
      end

      # Applies ReLU to input.
      def forward(input)
        F.relu(input)
      end
    end
  end
end
@@ -0,0 +1,29 @@
1
module Torch
  module NN
    # Container that feeds its input through child modules in insertion order.
    class Sequential < Module
      # @param args [Array] child modules, registered under "0", "1", ...
      def initialize(*args)
        @modules = {}
        # TODO support hash arg (named modules)
        args.each_with_index do |mod, idx|
          add_module(idx.to_s, mod)
        end
      end

      # Registers mod under name. Reusing a name replaces the earlier module
      # in its original position (Hash insertion order is kept).
      def add_module(name, mod)
        # TODO add checks
        @modules[name] = mod
      end

      # Threads input through each child's #call and returns the final output.
      def forward(input)
        @modules.each_value.reduce(input) { |acc, mod| mod.call(acc) }
      end

      # Parameters of all children, in order.
      def parameters
        @modules.flat_map { |_, mod| mod.parameters }
      end

      private

      # Children live in @modules rather than in individual instance
      # variables. Exposing them here lets Module#inspect, #method_missing
      # and #respond_to? see them. A copy is returned, matching the fresh
      # hash Module#modules builds.
      def modules
        @modules ? @modules.dup : {}
      end
    end
  end
end
@@ -0,0 +1,143 @@
1
module Torch
  class Tensor
    include Comparable
    include Inspector

    alias_method :requires_grad?, :requires_grad

    # Tensor.new(tensor) returns that tensor unchanged; otherwise the
    # arguments are treated as a size and passed to Torch.rand.
    def self.new(*size)
      if size.first.is_a?(Tensor)
        size.first
      else
        Torch.rand(*size)
      end
    end

    # Looks up the native dtype enum in ENUM_TO_DTYPE.
    # @raise [Torch::Error] when the enum value has no mapping
    def dtype
      dtype = ENUM_TO_DTYPE[_dtype]
      raise Error, "Unknown type: #{_dtype}" unless dtype
      dtype
    end

    # Native layout name as a lowercase Symbol.
    def layout
      _layout.downcase.to_sym
    end

    def to_s
      inspect
    end

    # Element data as nested Ruby arrays matching #shape.
    def to_a
      reshape(_data, shape)
    end

    # Size along dim when given, otherwise the full #shape.
    def size(dim = nil)
      if dim
        _size(dim)
      else
        shape
      end
    end

    def shape
      dim.times.map { |i| size(i) }
    end

    def view(*size)
      _view(size)
    end

    # Ruby scalar for a one-element tensor.
    # @raise [Torch::Error] when numel is not 1
    def item
      if numel != 1
        raise Error, "only one element tensors can be converted to Ruby scalars"
      end
      _data.first
    end

    # New tensor built from #to_a.
    def data
      Torch.tensor(to_a)
    end

    # TODO read directly from memory
    def numo
      raise Error, "Numo not found" unless defined?(Numo::NArray)
      cls = Torch._dtype_to_numo[dtype]
      raise Error, "Cannot convert #{dtype} to Numo" unless cls
      cls.cast(_data).reshape(*shape)
    end

    def new_ones(*size, **options)
      Torch.ones_like(Torch.empty(*size), **options)
    end

    def requires_grad!(requires_grad = true)
      _requires_grad!(requires_grad)
    end

    # operations
    # Each name becomes an instance method forwarding to Torch.<name>(self, ...).
    # Keywords are only passed when present.
    # NOTE(review): "num" has no obvious Torch counterpart here; confirm it exists.
    %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
      define_method(op) do |*args, **options, &block|
        if options.any?
          Torch.send(op, self, *args, **options, &block)
        else
          Torch.send(op, self, *args, &block)
        end
      end
    end

    def +(other)
      add(other)
    end

    def -(other)
      sub(other)
    end

    def *(other)
      mul(other)
    end

    def /(other)
      div(other)
    end

    def %(other)
      remainder(other)
    end

    def **(other)
      pow(other)
    end

    def -@
      neg
    end

    # Compares by scalar value (#item), so this is only usable on
    # one-element tensors; Comparable derives <, >, == etc. from it.
    def <=>(other)
      item <=> other
    end

    # TODO use accessor C++ method
    # Chained indexing: t[i, j] is t._access(i)._access(j).
    def [](index, *args)
      v = _access(index)
      args.each do |i|
        v = v._access(i)
      end
      v
    end

    private

    # Rebuilds nested arrays of shape dims from arr (flattened first).
    # Empty dims returns arr unchanged. Zero-length dimensions yield empty
    # arrays; the previous each_slice approach raised ArgumentError on
    # each_slice(0), e.g. for shape [2, 0].
    def reshape(arr, dims)
      return arr if dims.empty?

      flat = arr.flatten
      # strides[i] = number of flat elements spanned by one step along dim i
      strides = Array.new(dims.size, 1)
      (dims.size - 2).downto(0) { |i| strides[i] = strides[i + 1] * dims[i + 1] }
      nest_flat(flat, dims, strides, 0, 0)
    end

    # Builds the sub-array for dimension level, starting at flat[offset].
    def nest_flat(flat, dims, strides, level, offset)
      if level == dims.size - 1
        flat[offset, dims[level]] || []
      else
        Array.new(dims[level]) do |i|
          nest_flat(flat, dims, strides, level + 1, offset + i * strides[level])
        end
      end
    end
  end
end
@@ -0,0 +1,12 @@
1
module Torch
  module Utils
    module Data
      # Stores a dataset and a batch size. No iteration methods are defined
      # in this class yet.
      class DataLoader
        # @param dataset object to load samples from
        # @param batch_size [Integer] samples per batch (default 1)
        def initialize(dataset, batch_size: 1)
          @dataset, @batch_size = dataset, batch_size
        end
      end
    end
  end
end
@@ -0,0 +1,15 @@
1
module Torch
  module Utils
    module Data
      # Dataset over several tensors indexed in parallel: sample i is the
      # i-th entry of each wrapped tensor.
      class TensorDataset
        # @param tensors objects supporting #[] with the same index space
        def initialize(*tensors)
          @tensors = tensors
        end

        # @return [Array] tensor[index] for each wrapped tensor, in order.
        #   This previously called an undefined `tensors` method, which
        #   raised NameError on every access.
        def [](index)
          @tensors.map { |t| t[index] }
        end
      end
    end
  end
end
@@ -0,0 +1,3 @@
1
module Torch
  # Gem version string, frozen so the constant cannot be mutated in place.
  VERSION = "0.1.0".freeze
end
metadata ADDED
@@ -0,0 +1,149 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: torch-rb
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Andrew Kane
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2019-11-26 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: rice
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: bundler
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rake
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ - !ruby/object:Gem::Dependency
56
+ name: rake-compiler
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - ">="
60
+ - !ruby/object:Gem::Version
61
+ version: '0'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '0'
69
+ - !ruby/object:Gem::Dependency
70
+ name: minitest
71
+ requirement: !ruby/object:Gem::Requirement
72
+ requirements:
73
+ - - ">="
74
+ - !ruby/object:Gem::Version
75
+ version: '5'
76
+ type: :development
77
+ prerelease: false
78
+ version_requirements: !ruby/object:Gem::Requirement
79
+ requirements:
80
+ - - ">="
81
+ - !ruby/object:Gem::Version
82
+ version: '5'
83
+ - !ruby/object:Gem::Dependency
84
+ name: numo-narray
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: '0'
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - ">="
95
+ - !ruby/object:Gem::Version
96
+ version: '0'
97
+ description:
98
+ email: andrew@chartkick.com
99
+ executables: []
100
+ extensions:
101
+ - ext/torch/extconf.rb
102
+ extra_rdoc_files: []
103
+ files:
104
+ - CHANGELOG.md
105
+ - LICENSE.txt
106
+ - README.md
107
+ - ext/torch/ext.cpp
108
+ - ext/torch/extconf.rb
109
+ - lib/torch-rb.rb
110
+ - lib/torch.rb
111
+ - lib/torch/ext.bundle
112
+ - lib/torch/inspector.rb
113
+ - lib/torch/nn/conv2d.rb
114
+ - lib/torch/nn/functional.rb
115
+ - lib/torch/nn/init.rb
116
+ - lib/torch/nn/linear.rb
117
+ - lib/torch/nn/module.rb
118
+ - lib/torch/nn/mse_loss.rb
119
+ - lib/torch/nn/parameter.rb
120
+ - lib/torch/nn/relu.rb
121
+ - lib/torch/nn/sequential.rb
122
+ - lib/torch/tensor.rb
123
+ - lib/torch/utils/data/data_loader.rb
124
+ - lib/torch/utils/data/tensor_dataset.rb
125
+ - lib/torch/version.rb
126
+ homepage: https://github.com/ankane/torch-rb
127
+ licenses:
128
+ - MIT
129
+ metadata: {}
130
+ post_install_message:
131
+ rdoc_options: []
132
+ require_paths:
133
+ - lib
134
+ required_ruby_version: !ruby/object:Gem::Requirement
135
+ requirements:
136
+ - - ">="
137
+ - !ruby/object:Gem::Version
138
+ version: '2.4'
139
+ required_rubygems_version: !ruby/object:Gem::Requirement
140
+ requirements:
141
+ - - ">="
142
+ - !ruby/object:Gem::Version
143
+ version: '0'
144
+ requirements: []
145
+ rubygems_version: 3.0.3
146
+ signing_key:
147
+ specification_version: 4
148
+ summary: Deep learning for Ruby, powered by LibTorch
149
+ test_files: []