torch-rb 0.2.0 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +36 -6
- data/ext/torch/ext.cpp +197 -24
- data/ext/torch/extconf.rb +34 -21
- data/lib/torch.rb +102 -6
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +3 -3
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +45 -8
- data/lib/torch/tensor.rb +48 -26
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -13
- data/ext/torch/nn_functions.cpp +0 -560
- data/ext/torch/nn_functions.hpp +0 -6
- data/ext/torch/tensor_functions.cpp +0 -2085
- data/ext/torch/tensor_functions.hpp +0 -6
- data/ext/torch/torch_functions.cpp +0 -3175
- data/ext/torch/torch_functions.hpp +0 -6
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/random.rb +0 -10
data/ext/torch/extconf.rb
CHANGED
@@ -2,28 +2,33 @@ require "mkmf-rice"
|
|
2
2
|
|
3
3
|
abort "Missing stdc++" unless have_library("stdc++")
|
4
4
|
|
5
|
-
$CXXFLAGS
|
5
|
+
$CXXFLAGS += " -std=c++14"
|
6
6
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
|
-
$CXXFLAGS
|
8
|
+
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
10
|
# TODO check compiler name
|
11
11
|
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
12
12
|
|
13
|
+
# check omp first
|
13
14
|
if have_library("omp") || have_library("gomp")
|
14
|
-
$CXXFLAGS
|
15
|
-
$CXXFLAGS
|
16
|
-
$CXXFLAGS
|
15
|
+
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
+
$CXXFLAGS += " -Xclang" if clang
|
17
|
+
$CXXFLAGS += " -fopenmp"
|
17
18
|
end
|
18
19
|
|
19
|
-
# silence ruby/intern.h warning
|
20
|
-
$CXXFLAGS << " -Wno-deprecated-register"
|
21
|
-
|
22
|
-
# silence torch warnings
|
23
20
|
if clang
|
24
|
-
|
21
|
+
# silence ruby/intern.h warning
|
22
|
+
$CXXFLAGS += " -Wno-deprecated-register"
|
23
|
+
|
24
|
+
# silence torch warnings
|
25
|
+
$CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
|
25
26
|
else
|
26
|
-
|
27
|
+
# silence rice warnings
|
28
|
+
$CXXFLAGS += " -Wno-noexcept-type"
|
29
|
+
|
30
|
+
# silence torch warnings
|
31
|
+
$CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
|
27
32
|
end
|
28
33
|
|
29
34
|
inc, lib = dir_config("torch")
|
@@ -34,22 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
|
|
34
39
|
cuda_inc ||= "/usr/local/cuda/include"
|
35
40
|
cuda_lib ||= "/usr/local/cuda/lib64"
|
36
41
|
|
37
|
-
|
42
|
+
$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
|
43
|
+
abort "LibTorch not found" unless have_library("torch")
|
44
|
+
|
45
|
+
have_library("mkldnn")
|
46
|
+
have_library("nnpack")
|
47
|
+
|
48
|
+
with_cuda = false
|
49
|
+
if Dir["#{lib}/*torch_cuda*"].any?
|
50
|
+
$LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
|
51
|
+
with_cuda = have_library("cuda") && have_library("cudnn")
|
52
|
+
end
|
38
53
|
|
39
|
-
$INCFLAGS
|
40
|
-
$INCFLAGS
|
54
|
+
$INCFLAGS += " -I#{inc}"
|
55
|
+
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
41
56
|
|
42
|
-
$LDFLAGS
|
43
|
-
$LDFLAGS
|
44
|
-
$LDFLAGS << " -L#{lib}"
|
45
|
-
$LDFLAGS << " -L#{cuda_lib}" if with_cuda
|
57
|
+
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
58
|
+
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
46
59
|
|
47
60
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
48
|
-
$LDFLAGS
|
61
|
+
$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
|
49
62
|
if with_cuda
|
50
|
-
$LDFLAGS
|
63
|
+
$LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
|
51
64
|
# TODO figure out why this is needed
|
52
|
-
$LDFLAGS
|
65
|
+
$LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
|
53
66
|
end
|
54
67
|
|
55
68
|
# generate C++ functions
|
data/lib/torch.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
# ext
|
2
2
|
require "torch/ext"
|
3
3
|
|
4
|
+
# stdlib
|
5
|
+
require "fileutils"
|
6
|
+
require "net/http"
|
7
|
+
require "tmpdir"
|
8
|
+
|
4
9
|
# native functions
|
5
10
|
require "torch/native/generator"
|
6
11
|
require "torch/native/parser"
|
@@ -174,11 +179,9 @@ require "torch/nn/init"
|
|
174
179
|
|
175
180
|
# utils
|
176
181
|
require "torch/utils/data/data_loader"
|
182
|
+
require "torch/utils/data/dataset"
|
177
183
|
require "torch/utils/data/tensor_dataset"
|
178
184
|
|
179
|
-
# random
|
180
|
-
require "torch/random"
|
181
|
-
|
182
185
|
# hub
|
183
186
|
require "torch/hub"
|
184
187
|
|
@@ -317,12 +320,11 @@ module Torch
|
|
317
320
|
end
|
318
321
|
|
319
322
|
# Serializes +obj+ and writes the resulting bytes to the file path +f+.
# The object is first converted with to_ivalue (which raises Error for
# unsupported types), then encoded by the native _save.
def save(obj, f)
  ivalue = to_ivalue(obj)
  File.binwrite(f, _save(ivalue))
end
|
323
325
|
|
324
326
|
# Reads the raw bytes stored at file path +f+, decodes them with the native
# _load and converts the resulting IValue back into Ruby objects via to_ruby.
def load(f)
  bytes = File.binread(f)
  to_ruby(_load(bytes))
end
|
327
329
|
|
328
330
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
@@ -447,6 +449,100 @@ module Torch
|
|
447
449
|
|
448
450
|
private
|
449
451
|
|
452
|
+
# Converts a Ruby object into an IValue so it can be passed to the native _save.
#
# Handled: String, Integer, Tensor, Float, true/false, nil (an empty IValue),
# Hash (keys and values converted recursively) and Array whose elements are
# all Tensors. Raises Error for any other list or object type.
def to_ivalue(obj)
  case obj
  when String then IValue.from_string(obj)
  when Integer then IValue.from_int(obj)
  when Tensor then IValue.from_tensor(obj)
  when Float then IValue.from_double(obj)
  when Hash
    converted = {}
    obj.each { |key, value| converted[to_ivalue(key)] = to_ivalue(value) }
    IValue.from_dict(converted)
  when true, false then IValue.from_bool(obj)
  when nil then IValue.new
  when Array
    # only tensor lists are supported for now
    raise Error, "Unknown list type" unless obj.all? { |item| item.is_a?(Tensor) }

    IValue.from_list(obj.map { |tensor| IValue.from_tensor(tensor) })
  else
    raise Error, "Unknown type: #{obj.class.name}"
  end
end
|
482
|
+
|
483
|
+
# Converts an IValue (e.g. one returned by the native _load) back into Ruby.
#
# bool, double, int, none, string and tensor map to the matching Ruby value.
# Generic dicts and lists are converted recursively. Any other IValue kind
# raises Error naming the first matching kind, or "Unknown" when none match.
def to_ruby(ivalue)
  return ivalue.to_bool if ivalue.bool?
  return ivalue.to_double if ivalue.double?
  return ivalue.to_int if ivalue.int?
  return nil if ivalue.none?
  return ivalue.to_string_ref if ivalue.string?
  return ivalue.to_tensor if ivalue.tensor?

  if ivalue.generic_dict?
    hash = {}
    ivalue.to_generic_dict.each { |key, value| hash[to_ruby(key)] = to_ruby(value) }
    return hash
  end

  return ivalue.to_list.map { |element| to_ruby(element) } if ivalue.list?

  # checked in order; the first predicate that holds names the type in the error
  unsupported = [
    [:capsule?, "Capsule"],
    [:custom_class?, "CustomClass"],
    [:tuple?, "Tuple"],
    [:future?, "Future"],
    [:r_ref?, "RRef"],
    [:int_list?, "IntList"],
    [:double_list?, "DoubleList"],
    [:bool_list?, "BoolList"],
    [:tensor_list?, "TensorList"],
    [:object?, "Object"],
    [:module?, "Module"],
    [:py_object?, "PyObject"],
    [:scalar?, "Scalar"],
    [:device?, "Device"],
    # [:generator?, "Generator"],
    [:ptr_type?, "PtrType"]
  ].find { |predicate, _name| ivalue.public_send(predicate) }

  raise Error, "Unsupported type: #{unsupported ? unsupported[1] : "Unknown"}"
end
|
545
|
+
|
450
546
|
# Flattens +size+ so that nested and flat forms agree,
# e.g. [[2, 3]] and [2, 3] both become [2, 3].
def tensor_size(size)
  flat = size.flatten
  flat
end
|
data/lib/torch/hub.rb
CHANGED
@@ -4,6 +4,58 @@ module Torch
|
|
4
4
|
# Listing the entrypoints of a hub repository is not supported yet;
# always raises NotImplementedYet regardless of arguments.
def list(github, force_reload: false)
  raise NotImplementedYet
end
|
7
|
+
|
8
|
+
# Downloads +url+ to the local path +dst+, following HTTP redirects.
#
# The body is streamed into a temporary file and only moved to +dst+ once the
# whole response has been read, so +dst+ never holds a partial download. The
# temporary file is removed on redirect or on any error. Returns nil.
#
# Raises Error on a non-success, non-redirect response, on a redirect without
# a Location header, or after more than +max_redirects+ redirects.
def download_url_to_file(url, dst, max_redirects: 10)
  uri = URI(url)
  # pid + timestamp + random suffix keeps concurrent downloads from colliding
  tmp = File.join(Dir.tmpdir, "torch-#{Process.pid}-#{Time.now.to_f}-#{rand(0x100000000).to_s(36)}")
  location = nil

  begin
    Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
      request = Net::HTTP::Get.new(uri)

      puts "Downloading #{url}..."
      File.open(tmp, "wb") do |f|
        http.request(request) do |response|
          case response
          when Net::HTTPRedirection
            location = response["location"]
            raise Error, "Redirect without location" unless location
          when Net::HTTPSuccess
            response.read_body do |chunk|
              f.write(chunk)
            end
          else
            raise Error, "Bad response"
          end
        end
      end
    end

    unless location
      FileUtils.mv(tmp, dst)
      return nil
    end
  ensure
    # drops the empty/partial temp file after a redirect or an error
    # (no-op after a successful move)
    FileUtils.rm_f(tmp)
  end

  raise Error, "Too many redirects" if max_redirects <= 0

  # Location may be relative (RFC 7231 section 7.1.2); resolve it against the current URL
  download_url_to_file((uri + location).to_s, dst, max_redirects: max_redirects - 1)
end
|
40
|
+
|
41
|
+
# Loads a serialized object (typically a state dict) from +url+, caching it on disk.
#
# When +model_dir+ is not given the cache lives in
# $TORCH_HOME/checkpoints, falling back to $XDG_CACHE_HOME/torch/checkpoints
# and then ~/.cache/torch/checkpoints. The cached file is named after the last
# segment of the URL path and is only downloaded when it does not exist yet.
def load_state_dict_from_url(url, model_dir: nil)
  model_dir ||= begin
    cache_home = ENV.fetch("XDG_CACHE_HOME") { "#{ENV["HOME"]}/.cache" }
    torch_home = ENV.fetch("TORCH_HOME") { "#{cache_home}/torch" }
    File.join(torch_home, "checkpoints")
  end

  FileUtils.mkdir_p(model_dir)

  cached_file = File.join(model_dir, File.basename(URI(url).path))
  # TODO support hash_prefix
  download_url_to_file(url, cached_file) unless File.exist?(cached_file)

  Torch.load(cached_file)
end
|
7
59
|
end
|
8
60
|
end
|
9
61
|
end
|
data/lib/torch/inspector.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
module Torch
|
2
2
|
module Inspector
|
3
|
-
# TODO make more
|
3
|
+
# TODO make more performant, especially when summarizing
|
4
4
|
# how? only read data that will be displayed
|
5
5
|
def inspect
|
6
6
|
data =
|
@@ -14,7 +14,7 @@ module Torch
|
|
14
14
|
if dtype == :bool
|
15
15
|
fmt = "%s"
|
16
16
|
else
|
17
|
-
values =
|
17
|
+
values = _flat_data
|
18
18
|
abs = values.select { |v| v != 0 }.map(&:abs)
|
19
19
|
max = abs.max || 1
|
20
20
|
min = abs.min || 1
|
@@ -25,7 +25,7 @@ module Torch
|
|
25
25
|
end
|
26
26
|
|
27
27
|
if floating_point?
|
28
|
-
sci = max
|
28
|
+
sci = max > 1e8 || max < 1e-4
|
29
29
|
|
30
30
|
all_int = values.all? { |v| v.finite? && v == v.to_i }
|
31
31
|
decimal = all_int ? 1 : 4
|
data/lib/torch/nn/batch_norm.rb
CHANGED
@@ -70,6 +70,11 @@ module Torch
|
|
70
70
|
momentum: exponential_average_factor, eps: @eps
|
71
71
|
)
|
72
72
|
end
|
73
|
+
|
74
|
+
# Summary shown inside inspect: feature count followed by eps, momentum,
# affine and track_running_stats, all filled from dict.
def extra_inspect
  template = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
  format(template, **dict)
end
|
73
78
|
end
|
74
79
|
end
|
75
80
|
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -20,7 +20,14 @@ module Torch
|
|
20
20
|
|
21
21
|
# TODO add more parameters
# Summary shown inside inspect. Channels, kernel_size and stride are always
# listed. padding, dilation, output_padding, groups, bias and padding_mode are
# only appended when they differ from the baseline value each line checks.
def extra_inspect
  # String.new returns a mutable buffer, so append in place with << rather
  # than reallocating a new string on every += (safe under frozen_string_literal)
  s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
  s << ", padding: %{padding}" if @padding != [0] * @padding.size
  s << ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
  s << ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
  s << ", groups: %{groups}" if @groups != 1
  s << ", bias: false" unless @bias
  s << ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
  format(s, **dict)
end
|
25
32
|
end
|
26
33
|
end
|
data/lib/torch/nn/convnd.rb
CHANGED
data/lib/torch/nn/max_poolnd.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -67,8 +67,9 @@ module Torch
|
|
67
67
|
self
|
68
68
|
end
|
69
69
|
|
70
|
-
|
71
|
-
|
70
|
+
# TODO add device
# Applies Tensor#cuda through _apply (presumably to this module's tensors --
# confirm against _apply) and returns whatever _apply returns.
def cuda
  _apply(->(tensor) { tensor.cuda })
end
|
73
74
|
|
74
75
|
def cpu
|
@@ -112,8 +113,28 @@ module Torch
|
|
112
113
|
destination
|
113
114
|
end
|
114
115
|
|
116
|
+
# TODO add strict option
# TODO match PyTorch behavior
# Copies each value of +state_dict+ into the matching parameter of this module tree.
#
# A key "<name>.<rest>" selects named_modules[name] and looks up <rest> in that
# module's named_parameters. A key without a "." names a parameter of this
# module itself (registered by named_modules under the empty prefix). Copies
# run inside Torch.no_grad. Raises Error for an unknown module or parameter.
# Returns nil.
def load_state_dict(state_dict)
  # the module tree does not change while copying, so walk it once
  modules = named_modules

  state_dict.each do |k, input_param|
    k1, k2 = k.include?(".") ? k.split(".", 2) : ["", k]

    mod = modules[k1]
    raise Error, "Unknown module: #{k1}" unless mod.is_a?(Module)

    param = mod.named_parameters[k2]
    # report the full key so the offending entry can be found in state_dict
    raise Error, "Unknown parameter: #{k}" unless param.is_a?(Parameter)

    Torch.no_grad do
      param.copy!(input_param)
    end
  end

  # TODO return missing keys and unexpected keys
  nil
end
|
118
139
|
|
119
140
|
def parameters
|
@@ -165,8 +186,22 @@ module Torch
|
|
165
186
|
named_modules.values
|
166
187
|
end
|
167
188
|
|
168
|
-
|
169
|
-
|
189
|
+
# TODO return enumerator?
# Returns a Hash of dotted path => module for this module and every Module
# reachable through named_children, visited depth-first in child order. This
# module is stored under +prefix+ ("" by default). +memo+ collects modules
# already visited, so a module reachable through several parents is listed
# only once (under the first path found).
def named_modules(memo: nil, prefix: "")
  memo ||= Set.new
  found = {}
  return found if memo.include?(self)

  memo << self
  found[prefix] = self
  named_children.each do |child_name, child|
    next unless child.is_a?(Module)

    separator = prefix.empty? ? "" : "."
    found.merge!(child.named_modules(memo: memo, prefix: prefix + separator + child_name))
  end
  found
end
|
171
206
|
|
172
207
|
def train(mode = true)
|
@@ -203,13 +238,15 @@ module Torch
|
|
203
238
|
|
204
239
|
def inspect
|
205
240
|
name = self.class.name.split("::").last
|
206
|
-
if
|
241
|
+
if named_children.empty?
|
207
242
|
"#{name}(#{extra_inspect})"
|
208
243
|
else
|
209
244
|
str = String.new
|
210
245
|
str << "#{name}(\n"
|
211
|
-
|
212
|
-
|
246
|
+
named_children.each do |name, mod|
|
247
|
+
mod_str = mod.inspect
|
248
|
+
mod_str = mod_str.lines.join(" ")
|
249
|
+
str << " (#{name}): #{mod_str}\n"
|
213
250
|
end
|
214
251
|
str << ")"
|
215
252
|
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -4,6 +4,8 @@ module Torch
|
|
4
4
|
include Inspector
|
5
5
|
|
6
6
|
alias_method :requires_grad?, :requires_grad
|
7
|
+
alias_method :ndim, :dim
|
8
|
+
alias_method :ndimension, :dim
|
7
9
|
|
8
10
|
def self.new(*args)
|
9
11
|
FloatTensor.new(*args)
|
@@ -23,8 +25,17 @@ module Torch
|
|
23
25
|
inspect
|
24
26
|
end
|
25
27
|
|
28
|
+
# TODO make more performant
# Converts the tensor to a nested Ruby array that mirrors its shape.
#
# A 0-dim tensor returns a one-element flat array (item relies on this via
# to_a.first). A tensor with a zero-sized dimension returns the matching
# nested empty arrays, e.g. shape [2, 0] => [[], []].
def to_a
  arr = _flat_data
  dims = shape
  return arr if dims.empty?

  # each_slice(0) raises ArgumentError, and there is no data to slice anyway,
  # so build the empty nesting directly from the dimensions before the zero
  zero_at = dims.index(0)
  if zero_at
    nest = lambda do |sizes|
      sizes.empty? ? [] : Array.new(sizes.first) { nest.call(sizes.drop(1)) }
    end
    return nest.call(dims[0...zero_at])
  end

  # group the flat data from the innermost dimension outward
  dims[1..-1].reverse_each do |dim|
    arr = arr.each_slice(dim)
  end
  arr.to_a
end
|
29
40
|
|
30
41
|
# TODO support dtype
|
@@ -37,6 +48,10 @@ module Torch
|
|
37
48
|
to("cpu")
|
38
49
|
end
|
39
50
|
|
51
|
+
# Returns the result of moving this tensor to the "cuda" device via to.
def cuda
  device = "cuda"
  to(device)
end
|
54
|
+
|
40
55
|
def size(dim = nil)
|
41
56
|
if dim
|
42
57
|
_size_int(dim)
|
@@ -58,7 +73,15 @@ module Torch
|
|
58
73
|
if numel != 1
|
59
74
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
60
75
|
end
|
61
|
-
|
76
|
+
to_a.first
|
77
|
+
end
|
78
|
+
|
79
|
+
# Integer value of a one-element tensor (item raises Error otherwise).
def to_i
  scalar = item
  scalar.to_i
end
|
82
|
+
|
83
|
+
# Float value of a one-element tensor (item raises Error otherwise).
def to_f
  scalar = item
  scalar.to_f
end
|
63
86
|
|
64
87
|
# unsure if this is correct
|
@@ -74,7 +97,7 @@ module Torch
|
|
74
97
|
# Converts the tensor to a Numo array of the class mapped from its dtype,
# reshaped to the tensor's shape. Raises Error when the dtype has no Numo
# counterpart in Torch._dtype_to_numo.
def numo
  numo_class = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless numo_class

  flat = numo_class.from_string(_data_str)
  flat.reshape(*shape)
end
|
79
102
|
|
80
103
|
def new_ones(*size, **options)
|
@@ -102,15 +125,6 @@ module Torch
|
|
102
125
|
_view(size)
|
103
126
|
end
|
104
127
|
|
105
|
-
# value and other are swapped for some methods
|
106
|
-
def add!(value = 1, other)
|
107
|
-
if other.is_a?(Numeric)
|
108
|
-
_add__scalar(other, value)
|
109
|
-
else
|
110
|
-
_add__tensor(other, value)
|
111
|
-
end
|
112
|
-
end
|
113
|
-
|
114
128
|
# tensor + other; delegates to add.
def +(other)
  sum = add(other)
  sum
end
|
@@ -139,6 +153,7 @@ module Torch
|
|
139
153
|
neg
|
140
154
|
end
|
141
155
|
|
156
|
+
# TODO better compare?
# Compares the tensor's scalar value (item) with +other+.
def <=>(other)
  value = item
  value <=> other
end
|
@@ -186,8 +201,27 @@ module Torch
|
|
186
201
|
end
|
187
202
|
end
|
188
203
|
|
189
|
-
|
190
|
-
|
204
|
+
# native functions that need to be defined manually
|
205
|
+
|
206
|
+
# value and other are swapped for some methods
# In-place add. The optional leading +value+ (default 1) is forwarded as the
# second argument of the native call; Numeric +other+ uses the scalar variant,
# anything else the tensor variant.
def add!(value = 1, other)
  scalar = other.is_a?(Numeric)
  scalar ? _add__scalar(other, value) : _add__tensor(other, value)
end
|
214
|
+
|
215
|
+
# native functions overlap, so they need to be dispatched manually by arity:
# one argument => _random__to, two => _random__from_to, otherwise _random_.
def random!(*args)
  native = {1 => :_random__to, 2 => :_random__from_to}.fetch(args.size, :_random_)
  send(native, *args)
end
|
192
226
|
|
193
227
|
private
|
@@ -195,17 +229,5 @@ module Torch
|
|
195
229
|
# Copies +src+ into +dst+ in place and returns the result of dst.copy!.
def copy_to(dst, src)
  target = dst
  target.copy!(src)
end
|
198
|
-
|
199
|
-
def reshape_arr(arr, dims)
|
200
|
-
if dims.empty?
|
201
|
-
arr
|
202
|
-
else
|
203
|
-
arr = arr.flatten
|
204
|
-
dims[1..-1].reverse.each do |dim|
|
205
|
-
arr = arr.each_slice(dim)
|
206
|
-
end
|
207
|
-
arr.to_a
|
208
|
-
end
|
209
|
-
end
|
210
232
|
end
|
211
233
|
end
|