torch-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/extconf.rb ADDED
@@ -0,0 +1,22 @@
+ require "mkmf-rice"
+
+ abort "Missing stdc++" unless have_library("stdc++")
+
+ $CXXFLAGS << " -std=c++11"
+
+ # silence ruby/intern.h warning
+ $CXXFLAGS << " -Wno-deprecated-register"
+
+ inc, lib = dir_config("torch")
+
+ inc ||= "/usr/local/include"
+ lib ||= "/usr/local/lib"
+
+ $INCFLAGS << " -I#{inc}"
+ $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
+
+ $LDFLAGS << " -Wl,-rpath,#{lib}"
+ $LDFLAGS << " -L#{lib}"
+ $LDFLAGS << " -ltorch -lc10"
+
+ create_makefile("torch/ext")
data/lib/torch-rb.rb ADDED
@@ -0,0 +1 @@
+ require "torch"
data/lib/torch.rb ADDED
@@ -0,0 +1,327 @@
+ # ext
+ require "torch/ext"
+
+ # modules
+ require "torch/inspector"
+ require "torch/tensor"
+ require "torch/version"
+
+ # nn
+ require "torch/nn/module"
+ require "torch/nn/init"
+ require "torch/nn/conv2d"
+ require "torch/nn/functional"
+ require "torch/nn/linear"
+ require "torch/nn/parameter"
+ require "torch/nn/sequential"
+ require "torch/nn/relu"
+ require "torch/nn/mse_loss"
+
+ # utils
+ require "torch/utils/data/data_loader"
+ require "torch/utils/data/tensor_dataset"
+
+ module Torch
+   class Error < StandardError; end
+
+   # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
+   # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
+   # complex and quantized types not supported by PyTorch yet
+   DTYPE_TO_ENUM = {
+     uint8: 0,
+     int8: 1,
+     short: 2,
+     int16: 2,
+     int: 3,
+     int32: 3,
+     long: 4,
+     int64: 4,
+     half: 5,
+     float16: 5,
+     float: 6,
+     float32: 6,
+     double: 7,
+     float64: 7,
+     # complex_half: 8,
+     # complex_float: 9,
+     # complex_double: 10,
+     bool: 11,
+     # qint8: 12,
+     # quint8: 13,
+     # qint32: 14,
+     # bfloat16: 15
+   }
+   ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
+
+   class << self
+     # Torch.float, Torch.long, etc
+     DTYPE_TO_ENUM.each_key do |type|
+       define_method(type) do
+         type
+       end
+     end
+
+     # https://pytorch.org/docs/stable/torch.html
+
+     def tensor?(obj)
+       obj.is_a?(Tensor)
+     end
+
+     # TODO don't copy
+     def from_numo(ndarray)
+       dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
+       raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
+       tensor(ndarray.to_a, dtype: dtype[0])
+     end
+
+     # private
+     # use method for cases when Numo not available
+     # or available after Torch loaded
+     def _dtype_to_numo
+       {
+         uint8: Numo::UInt8,
+         int8: Numo::Int8,
+         int16: Numo::Int16,
+         int32: Numo::Int32,
+         int64: Numo::Int64,
+         float32: Numo::SFloat,
+         float64: Numo::DFloat
+       }
+     end
+
+     # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
+
+     def arange(start, finish = nil, step = 1, **options)
+       # ruby doesn't support start = 0, finish, step = 1, ...
+       if finish.nil?
+         finish = start
+         start = 0
+       end
+       _arange(start, finish, step, tensor_options(**options))
+     end
+
+     def empty(*size, **options)
+       _empty(tensor_size(size), tensor_options(**options))
+     end
+
+     def eye(n, m = nil, **options)
+       _eye(n, m || n, tensor_options(**options))
+     end
+
+     def full(size, fill_value, **options)
+       _full(size, fill_value, tensor_options(**options))
+     end
+
+     def linspace(start, finish, steps = 100, **options)
+       _linspace(start, finish, steps, tensor_options(**options))
+     end
+
+     def logspace(start, finish, steps = 100, base = 10.0, **options)
+       _logspace(start, finish, steps, base, tensor_options(**options))
+     end
+
+     def ones(*size, **options)
+       _ones(tensor_size(size), tensor_options(**options))
+     end
+
+     def rand(*size, **options)
+       _rand(tensor_size(size), tensor_options(**options))
+     end
+
+     def randint(low = 0, high, size, **options)
+       _randint(low, high, size, tensor_options(**options))
+     end
+
+     def randn(*size, **options)
+       _randn(tensor_size(size), tensor_options(**options))
+     end
+
+     def randperm(n, **options)
+       _randperm(n, tensor_options(**options))
+     end
+
+     def zeros(*size, **options)
+       _zeros(tensor_size(size), tensor_options(**options))
+     end
+
+     def tensor(data, **options)
+       size = []
+       if data.respond_to?(:to_a)
+         data = data.to_a
+         d = data
+         while d.is_a?(Array)
+           size << d.size
+           d = d.first
+         end
+         data = data.flatten
+       else
+         data = [data].compact
+       end
+
+       if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
+         options[:dtype] = :int64
+       end
+
+       _tensor(data, size, tensor_options(**options))
+     end
+
+     # --- begin like ---
+
+     def ones_like(input, **options)
+       ones(input.size, like_options(input, options))
+     end
+
+     def empty_like(input, **options)
+       empty(input.size, like_options(input, options))
+     end
+
+     def full_like(input, fill_value, **options)
+       full(input.size, fill_value, like_options(input, options))
+     end
+
+     def rand_like(input, **options)
+       rand(input.size, like_options(input, options))
+     end
+
+     def randint_like(input, low, high = nil, **options)
+       # ruby doesn't support input, low = 0, high, ...
+       if high.nil?
+         high = low
+         low = 0
+       end
+       randint(low, high, input.size, like_options(input, options))
+     end
+
+     def randn_like(input, **options)
+       randn(input.size, like_options(input, options))
+     end
+
+     def zeros_like(input, **options)
+       zeros(input.size, like_options(input, options))
+     end
+
+     # --- begin operations ---
+
+     %w(add sub mul div remainder).each do |op|
+       define_method(op) do |input, other, **options|
+         execute_op(op, input, other, **options)
+       end
+     end
+
+     def neg(input)
+       _neg(input)
+     end
+
+     def no_grad
+       previous_value = grad_enabled?
+       begin
+         _set_grad_enabled(false)
+         yield
+       ensure
+         _set_grad_enabled(previous_value)
+       end
+     end
+
+     # TODO support out
+     def mean(input, dim = nil, keepdim: false)
+       if dim
+         _mean_dim(input, dim, keepdim)
+       else
+         _mean(input)
+       end
+     end
+
+     # TODO support dtype
+     def sum(input, dim = nil, keepdim: false)
+       if dim
+         _sum_dim(input, dim, keepdim)
+       else
+         _sum(input)
+       end
+     end
+
+     def norm(input)
+       _norm(input)
+     end
+
+     def pow(input, exponent)
+       _pow(input, exponent)
+     end
+
+     def min(input)
+       _min(input)
+     end
+
+     def max(input)
+       _max(input)
+     end
+
+     def exp(input)
+       _exp(input)
+     end
+
+     def log(input)
+       _log(input)
+     end
+
+     def unsqueeze(input, dim)
+       _unsqueeze(input, dim)
+     end
+
+     def dot(input, tensor)
+       _dot(input, tensor)
+     end
+
+     def matmul(input, other)
+       _matmul(input, other)
+     end
+
+     private
+
+     def execute_op(op, input, other, out: nil)
+       scalar = other.is_a?(Numeric)
+       if out
+         # TODO make work with scalars
+         raise Error, "out not supported with scalar yet" if scalar
+         send("_#{op}_out", out, input, other)
+       else
+         if scalar
+           send("_#{op}_scalar", input, other)
+         else
+           send("_#{op}", input, other)
+         end
+       end
+     end
+
+     def tensor_size(size)
+       size.flatten
+     end
+
+     def tensor_options(dtype: nil, layout: nil, device: nil, requires_grad: nil)
+       options = TensorOptions.new
+       unless dtype.nil?
+         type = DTYPE_TO_ENUM[dtype]
+         raise Error, "Unknown dtype: #{dtype.inspect}" unless type
+         options = options.dtype(type)
+       end
+       unless device.nil?
+         options = options.device(device.to_s)
+       end
+       unless layout.nil?
+         options = options.layout(layout.to_s)
+       end
+       unless requires_grad.nil?
+         options = options.requires_grad(requires_grad)
+       end
+       options
+     end
+
+     def like_options(input, options)
+       options = options.dup
+       options[:dtype] ||= input.dtype
+       options[:layout] ||= input.layout
+       options[:device] ||= input.device
+       options
+     end
+   end
+ end
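
Taken together, torch.rb mirrors PyTorch's top-level factory and math functions on the Torch module, with the underscore-prefixed native bindings doing the work. A minimal usage sketch based only on the methods shown above (return values are illustrative):

    require "torch"

    # all-integer data defaults to :int64 (see Torch.tensor above)
    x = Torch.tensor([[1, 2], [3, 4]], dtype: :float32)
    y = Torch.ones(2, 2)

    z = Torch.add(x, y)  # tensor-tensor -> _add
    w = Torch.mul(x, 2)  # tensor-scalar -> _mul_scalar

    # no_grad restores the previous grad mode even if the block raises
    Torch.no_grad { Torch.sum(z) }
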
Binary file
data/lib/torch/inspector.rb ADDED
@@ -0,0 +1,62 @@
+ module Torch
+   module Inspector
+     def inspect
+       data =
+         if numel == 0
+           "[]"
+         elsif dim == 0
+           to_a.first
+         else
+           values = to_a.flatten
+           abs = values.select { |v| v != 0 }.map(&:abs)
+           max = abs.max || 1
+           min = abs.min || 1
+
+           total = 0
+           if values.any? { |v| v < 0 }
+             total += 1
+           end
+
+           if floating_point?
+             sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
+
+             all_int = values.all? { |v| v == v.to_i }
+             decimal = all_int ? 1 : 4
+
+             total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
+
+             if sci
+               fmt = "%#{total}.4e"
+             else
+               fmt = "%#{total}.#{decimal}f"
+             end
+           else
+             total += max.to_s.size
+             fmt = "%#{total}d"
+           end
+
+           inspect_level(to_a, fmt, dim - 1)
+         end
+
+       attributes = []
+       if requires_grad
+         attributes << "requires_grad: true"
+       end
+       if ![:float32, :int64, :bool].include?(dtype)
+         attributes << "dtype: #{dtype.inspect}"
+       end
+
+       "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
+     end
+
+     private
+
+     def inspect_level(arr, fmt, total, level = 0)
+       if level == total
+         "[#{arr.map { |v| fmt % v }.join(", ")}]"
+       else
+         "[#{arr.map { |row| inspect_level(row, fmt, total, level + 1) }.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
+       end
+     end
+   end
+ end
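
The inspector picks a single printf-style format for every element: the width budgets for a sign and the integer part, and scientific notation kicks in for wide dynamic ranges. The same rule in standalone Ruby, detached from the Tensor class (a sketch over a hand-picked float array):

    values = [0.25, 3.0, -1.5]
    abs = values.select { |v| v != 0 }.map(&:abs)
    max = abs.max || 1
    min = abs.min || 1

    total = values.any? { |v| v < 0 } ? 1 : 0
    sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
    decimal = values.all? { |v| v == v.to_i } ? 1 : 4
    total += sci ? 10 : decimal + 1 + max.to_i.to_s.size

    fmt = sci ? "%#{total}.4e" : "%#{total}.#{decimal}f"
    values.map { |v| fmt % v }.join(", ")  # => " 0.2500,  3.0000, -1.5000"
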
data/lib/torch/nn/conv2d.rb ADDED
@@ -0,0 +1,50 @@
+ module Torch
+   module NN
+     class Conv2d < Module
+       attr_reader :bias, :weight
+
+       def initialize(in_channels, out_channels, kernel_size) #, stride: 1, padding: 0, dilation: 1, groups: 1)
+         @in_channels = in_channels
+         @out_channels = out_channels
+         @kernel_size = pair(kernel_size)
+         @stride = pair(1)
+         # @stride = pair(stride)
+         # @padding = pair(padding)
+         # @dilation = pair(dilation)
+
+         # TODO divide by groups
+         @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
+         @bias = Parameter.new(Tensor.new(out_channels))
+
+         reset_parameters
+       end
+
+       def reset_parameters
+         Init.kaiming_uniform_(@weight, Math.sqrt(5))
+         if @bias
+           fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
+           bound = 1 / Math.sqrt(fan_in)
+           Init.uniform_(@bias, -bound, bound)
+         end
+       end
+
+       def call(input)
+         F.conv2d(input, @weight, @bias) # @stride, self.padding, self.dilation, self.groups)
+       end
+
+       def inspect
+         "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
+       end
+
+       private
+
+       def pair(value)
+         if value.is_a?(Array)
+           value
+         else
+           [value] * 2
+         end
+       end
+     end
+   end
+ end
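
A usage sketch for the layer above (shapes are illustrative, and this assumes the native conv2d binding follows PyTorch's output-size rule for stride 1 and no padding):

    conv = Torch::NN::Conv2d.new(3, 6, 5)  # in_channels, out_channels, kernel_size
    input = Torch.rand(1, 3, 20, 20)       # NCHW batch of one image
    output = conv.call(input)              # 20 - 5 + 1 = 16 -> shape [1, 6, 16, 16]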