torch-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +363 -0
- data/ext/torch/ext.cpp +546 -0
- data/ext/torch/extconf.rb +22 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +327 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +62 -0
- data/lib/torch/nn/conv2d.rb +50 -0
- data/lib/torch/nn/functional.rb +44 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +56 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +10 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/tensor.rb +143 -0
- data/lib/torch/utils/data/data_loader.rb +12 -0
- data/lib/torch/utils/data/tensor_dataset.rb +15 -0
- data/lib/torch/version.rb +3 -0
- metadata +149 -0
@@ -0,0 +1,22 @@
|
|
1
|
+
require "mkmf-rice"

abort "Missing stdc++" unless have_library("stdc++")

# compile as C++11; -Wno-deprecated-register silences a ruby/intern.h warning
$CXXFLAGS << " -std=c++11" << " -Wno-deprecated-register"

# dir_config reads --with-torch-dir / --with-torch-include / --with-torch-lib;
# fall back to /usr/local when they are not given
torch_include, torch_lib = dir_config("torch")
torch_include ||= "/usr/local/include"
torch_lib ||= "/usr/local/lib"

# the C++ frontend headers live under torch/csrc/api/include
$INCFLAGS << " -I#{torch_include}" << " -I#{torch_include}/torch/csrc/api/include"

# embed an rpath so the extension finds libtorch/libc10 at load time
$LDFLAGS << " -Wl,-rpath,#{torch_lib}" << " -L#{torch_lib}" << " -ltorch -lc10"

create_makefile("torch/ext")
|
data/lib/torch-rb.rb
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
require "torch"
|
data/lib/torch.rb
ADDED
@@ -0,0 +1,327 @@
|
|
1
|
+
# ext
|
2
|
+
require "torch/ext"
|
3
|
+
|
4
|
+
# modules
|
5
|
+
require "torch/inspector"
|
6
|
+
require "torch/tensor"
|
7
|
+
require "torch/version"
|
8
|
+
|
9
|
+
# nn
|
10
|
+
require "torch/nn/module"
|
11
|
+
require "torch/nn/init"
|
12
|
+
require "torch/nn/conv2d"
|
13
|
+
require "torch/nn/functional"
|
14
|
+
require "torch/nn/linear"
|
15
|
+
require "torch/nn/parameter"
|
16
|
+
require "torch/nn/sequential"
|
17
|
+
require "torch/nn/relu"
|
18
|
+
require "torch/nn/mse_loss"
|
19
|
+
|
20
|
+
# utils
|
21
|
+
require "torch/utils/data/data_loader"
|
22
|
+
require "torch/utils/data/tensor_dataset"
|
23
|
+
|
24
|
+
module Torch
  class Error < StandardError; end

  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
  # complex and quantized types not supported by PyTorch yet
  # Several keys are aliases sharing one enum value (e.g. :float and :float32).
  DTYPE_TO_ENUM = {
    uint8: 0,
    int8: 1,
    short: 2,
    int16: 2,
    int: 3,
    int32: 3,
    long: 4,
    int64: 4,
    half: 5,
    float16: 5,
    float: 6,
    float32: 6,
    double: 7,
    float64: 7,
    # complex_half: 8,
    # complex_float: 9,
    # complex_double: 10,
    bool: 11,
    # qint8: 12,
    # quint8: 13,
    # qint32: 14,
    # bfloat16: 15
  }.freeze

  # Enum value => dtype symbol. For aliased values the key listed last in
  # DTYPE_TO_ENUM wins (Hash#invert semantics), so 6 => :float32, 4 => :int64, etc.
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.invert.freeze

  class << self
    # Torch.float, Torch.long, etc. simply return the dtype symbol,
    # so `dtype: Torch.float` is the same as `dtype: :float`.
    DTYPE_TO_ENUM.each_key do |type|
      define_method(type) do
        type
      end
    end

    # https://pytorch.org/docs/stable/torch.html

    # True when obj is a Torch::Tensor.
    def tensor?(obj)
      obj.is_a?(Tensor)
    end

    # Creates a tensor from a Numo::NArray by copying its values (via to_a).
    # Raises Torch::Error when the array class has no entry in _dtype_to_numo.
    # TODO don't copy
    def from_numo(ndarray)
      dtype = _dtype_to_numo.find { |_dtype, klass| ndarray.is_a?(klass) }
      raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype

      tensor(ndarray.to_a, dtype: dtype[0])
    end

    # private
    # use method for cases when Numo not available
    # or available after Torch loaded
    def _dtype_to_numo
      {
        uint8: Numo::UInt8,
        int8: Numo::Int8,
        int16: Numo::Int16,
        int32: Numo::Int32,
        int64: Numo::Int64,
        float32: Numo::SFloat,
        float64: Numo::DFloat
      }
    end

    # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
    # Methods taking *size accept separate integers or one Array:
    # zeros(2, 3) and zeros([2, 3]) are equivalent (see tensor_size).
    # **options accepts dtype:, layout:, device:, requires_grad: (see tensor_options).

    # arange(finish) or arange(start, finish, step = 1).
    def arange(start, finish = nil, step = 1, **options)
      # ruby doesn't support start = 0, finish, step = 1, ...
      if finish.nil?
        finish = start
        start = 0
      end
      _arange(start, finish, step, tensor_options(**options))
    end

    def empty(*size, **options)
      _empty(tensor_size(size), tensor_options(**options))
    end

    # m defaults to n.
    def eye(n, m = nil, **options)
      _eye(n, m || n, tensor_options(**options))
    end

    # size is passed through as given (not flattened).
    def full(size, fill_value, **options)
      _full(size, fill_value, tensor_options(**options))
    end

    def linspace(start, finish, steps = 100, **options)
      _linspace(start, finish, steps, tensor_options(**options))
    end

    def logspace(start, finish, steps = 100, base = 10.0, **options)
      _logspace(start, finish, steps, base, tensor_options(**options))
    end

    def ones(*size, **options)
      _ones(tensor_size(size), tensor_options(**options))
    end

    def rand(*size, **options)
      _rand(tensor_size(size), tensor_options(**options))
    end

    # randint(high, size) or randint(low, high, size); low defaults to 0.
    def randint(low = 0, high, size, **options)
      _randint(low, high, size, tensor_options(**options))
    end

    def randn(*size, **options)
      _randn(tensor_size(size), tensor_options(**options))
    end

    def randperm(n, **options)
      _randperm(n, tensor_options(**options))
    end

    def zeros(*size, **options)
      _zeros(tensor_size(size), tensor_options(**options))
    end

    # Builds a tensor from a (possibly nested) Array, anything responding to
    # to_a, or a single scalar. The shape is read from the first element at
    # each nesting level (ragged input is not checked here). When no dtype is
    # given and every value is an Integer, dtype defaults to :int64.
    def tensor(data, **options)
      size = []
      if data.respond_to?(:to_a)
        data = data.to_a
        d = data
        while d.is_a?(Array)
          size << d.size
          d = d.first
        end
        data = data.flatten
      else
        data = [data].compact
      end

      if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
        options[:dtype] = :int64
      end

      _tensor(data, size, tensor_options(**options))
    end

    # --- begin like ---
    # Each *_like method takes its size from input.size and defaults dtype,
    # layout and device to the input's; explicit options override them.
    # The merged options are forwarded with ** so they arrive as keyword
    # arguments (a positional Hash is not converted to keywords on Ruby 3).

    def ones_like(input, **options)
      ones(input.size, **like_options(input, options))
    end

    def empty_like(input, **options)
      empty(input.size, **like_options(input, options))
    end

    def full_like(input, fill_value, **options)
      full(input.size, fill_value, **like_options(input, options))
    end

    def rand_like(input, **options)
      rand(input.size, **like_options(input, options))
    end

    # randint_like(input, high) or randint_like(input, low, high).
    def randint_like(input, low, high = nil, **options)
      # ruby doesn't support input, low = 0, high, ...
      if high.nil?
        high = low
        low = 0
      end
      # previously this called rand and silently dropped low/high
      randint(low, high, input.size, **like_options(input, options))
    end

    def randn_like(input, **options)
      randn(input.size, **like_options(input, options))
    end

    def zeros_like(input, **options)
      zeros(input.size, **like_options(input, options))
    end

    # --- begin operations ---

    # add/sub/mul/div/remainder(input, other, out: nil); see execute_op.
    %w(add sub mul div remainder).each do |op|
      define_method(op) do |input, other, **options|
        execute_op(op, input, other, **options)
      end
    end

    def neg(input)
      _neg(input)
    end

    # Runs the block with gradient tracking disabled and returns the block's
    # value; the previous setting is restored even if the block raises.
    def no_grad
      previous_value = grad_enabled?
      begin
        _set_grad_enabled(false)
        yield
      ensure
        _set_grad_enabled(previous_value)
      end
    end

    # TODO support out
    def mean(input, dim = nil, keepdim: false)
      if dim
        _mean_dim(input, dim, keepdim)
      else
        _mean(input)
      end
    end

    # TODO support dtype
    def sum(input, dim = nil, keepdim: false)
      if dim
        _sum_dim(input, dim, keepdim)
      else
        _sum(input)
      end
    end

    def norm(input)
      _norm(input)
    end

    def pow(input, exponent)
      _pow(input, exponent)
    end

    def min(input)
      _min(input)
    end

    def max(input)
      _max(input)
    end

    def exp(input)
      _exp(input)
    end

    def log(input)
      _log(input)
    end

    def unsqueeze(input, dim)
      _unsqueeze(input, dim)
    end

    def dot(input, tensor)
      _dot(input, tensor)
    end

    def matmul(input, other)
      _matmul(input, other)
    end

    private

    # Dispatches to the native _<op>, _<op>_scalar (Numeric other) or
    # _<op>_out (out: given) variant. op comes from the fixed list above.
    # Raises Torch::Error for out: combined with a scalar.
    def execute_op(op, input, other, out: nil)
      scalar = other.is_a?(Numeric)
      if out
        # TODO make work with scalars
        raise Error, "out not supported with scalar yet" if scalar
        send("_#{op}_out", out, input, other)
      else
        if scalar
          send("_#{op}_scalar", input, other)
        else
          send("_#{op}", input, other)
        end
      end
    end

    # Flattens *size so both f(2, 3) and f([2, 3]) yield [2, 3].
    def tensor_size(size)
      size.flatten
    end

    # Builds a TensorOptions, applying only the options that were given.
    # Raises Torch::Error for a dtype not in DTYPE_TO_ENUM.
    def tensor_options(dtype: nil, layout: nil, device: nil, requires_grad: nil)
      options = TensorOptions.new
      unless dtype.nil?
        type = DTYPE_TO_ENUM[dtype]
        raise Error, "Unknown dtype: #{dtype.inspect}" unless type
        options = options.dtype(type)
      end
      unless device.nil?
        options = options.device(device.to_s)
      end
      unless layout.nil?
        options = options.layout(layout.to_s)
      end
      unless requires_grad.nil?
        options = options.requires_grad(requires_grad)
      end
      options
    end

    # Copy of options with dtype/layout/device defaulted from input.
    def like_options(input, options)
      options = options.dup
      options[:dtype] ||= input.dtype
      options[:layout] ||= input.layout
      options[:device] ||= input.device
      options
    end
  end
end
|
Binary file
|
@@ -0,0 +1,62 @@
|
|
1
|
+
module Torch
  # Mixin providing a PyTorch-style #inspect. The including class must
  # respond to numel, dim, to_a, floating_point?, requires_grad and dtype.
  module Inspector
    # Returns a string such as "tensor([1.0000, 2.5000])"; rows of
    # higher-rank tensors are placed on separate lines. requires_grad and any
    # dtype other than :float32, :int64 and :bool are appended as attributes.
    #
    # Column width and precision are computed from finite values only, so
    # NaN and +/-Infinity no longer raise (Float#to_i raises FloatDomainError
    # on them, and NaN cannot be compared by max/min). Those values are
    # printed by format as "NaN" / "Inf".
    def inspect
      finite = ->(v) { !v.is_a?(Float) || v.finite? }

      data =
        if numel == 0
          "[]"
        elsif dim == 0
          to_a.first
        else
          values = to_a.flatten
          abs = values.select { |v| v != 0 && finite.call(v) }.map(&:abs)
          max = abs.max || 1
          min = abs.min || 1

          # one extra column for a minus sign
          total = 0
          if values.any? { |v| v < 0 }
            total += 1
          end

          if floating_point?
            # scientific notation for a wide magnitude range or very large/small values
            sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4

            # 1 decimal place when every finite value is integral, otherwise 4
            all_int = values.all? { |v| !finite.call(v) || v == v.to_i }
            decimal = all_int ? 1 : 4

            # fixed width of 10 for %e; otherwise integer digits + "." + decimals
            total += sci ? 10 : decimal + 1 + max.to_i.to_s.size

            if sci
              fmt = "%#{total}.4e"
            else
              fmt = "%#{total}.#{decimal}f"
            end
          else
            total += max.to_s.size
            fmt = "%#{total}d"
          end

          inspect_level(to_a, fmt, dim - 1)
        end

      attributes = []
      if requires_grad
        attributes << "requires_grad: true"
      end
      if ![:float32, :int64, :bool].include?(dtype)
        attributes << "dtype: #{dtype.inspect}"
      end

      "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
    end

    private

    # Recursively formats nested arrays. total is the index of the innermost
    # level (dim - 1); innermost rows are formatted element-wise with fmt.
    # Rows are separated by (total - level) newlines and indented by
    # level + 8 spaces, 8 being the width of the "tensor([" prefix.
    def inspect_level(arr, fmt, total, level = 0)
      if level == total
        "[#{arr.map { |v| fmt % v }.join(", ")}]"
      else
        "[#{arr.map { |row| inspect_level(row, fmt, total, level + 1) }.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
      end
    end
  end
end
|
@@ -0,0 +1,50 @@
|
|
1
|
+
module Torch
  module NN
    # 2D convolution layer.
    #
    # Only in_channels, out_channels and kernel_size are configurable so far:
    # stride is fixed at 1 (and only reported by #inspect), while padding,
    # dilation and groups are not forwarded to F.conv2d yet.
    class Conv2d < Module
      attr_reader :bias, :weight

      # kernel_size: an Integer (used for both dimensions) or an Array used as-is.
      def initialize(in_channels, out_channels, kernel_size) #, stride: 1, padding: 0, dilation: 1, groups: 1)
        @in_channels = in_channels
        @out_channels = out_channels
        @kernel_size = pair(kernel_size)
        @stride = pair(1)
        # @stride = pair(stride)
        # @padding = pair(padding)
        # @dilation = pair(dilation)

        # TODO divide by groups
        weight_shape = [out_channels, in_channels, *@kernel_size]
        @weight = Parameter.new(Tensor.new(*weight_shape))
        @bias = Parameter.new(Tensor.new(out_channels))

        reset_parameters
      end

      # Applies Init.kaiming_uniform_ to the weight (a = sqrt(5)), then
      # Init.uniform_ to the bias with bounds +/- 1/sqrt(fan_in).
      def reset_parameters
        Init.kaiming_uniform_(@weight, Math.sqrt(5))
        return unless @bias

        fan_in, _fan_out = Init.calculate_fan_in_and_fan_out(@weight)
        bound = 1 / Math.sqrt(fan_in)
        Init.uniform_(@bias, -bound, bound)
      end

      # Forward pass; stride/padding/dilation/groups are not passed yet.
      def call(input)
        F.conv2d(input, @weight, @bias) # @stride, self.padding, self.dilation, self.groups)
      end

      def inspect
        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
      end

      private

      # Expands a scalar into [value, value]; Arrays are returned unchanged.
      def pair(value)
        value.is_a?(Array) ? value : [value, value]
      end
    end
  end
end
|