torch-rb 0.3.2 → 0.3.7
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +7 -2
- data/ext/torch/ext.cpp +60 -20
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +36 -0
- data/ext/torch/templates.hpp +81 -87
- data/lib/torch.rb +71 -19
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +93 -4
- data/lib/torch/native/generator.rb +45 -41
- data/lib/torch/native/parser.rb +57 -76
- data/lib/torch/nn/functional.rb +112 -2
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +45 -51
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
data/lib/torch/native/parser.rb
CHANGED
@@ -6,14 +6,24 @@ module Torch
         @name = @functions.first.ruby_name
         @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
         @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
+        @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
       end

+      # TODO improve performance
+      # possibly move to C++ (see python_arg_parser.cpp)
       def parse(args, options)
         candidates = @functions.dup

-        #
-
-
+        # TODO check candidates individually to see if they match
+        if @int_array_first
+          int_args = []
+          while args.first.is_a?(Integer)
+            int_args << args.shift
+          end
+          if int_args.any?
+            raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
+            args.unshift(int_args)
+          end
         end

         # TODO account for args passed as options here
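With @int_array_first, a function whose first argument is an int[] can now be called with the sizes splatted out; the parser gathers the leading integers back into a single array before matching. A usage sketch:

    require "torch"

    # equivalent calls: the leading integers are collected into [2, 3]
    Torch.ones(2, 3).shape   # => [2, 3]
    Torch.ones([2, 3]).shape # => [2, 3]

    # a non-integer after the collected run raises, e.g. Torch.ones(2, "x"):
    # ArgumentError: argument 'size' must be array of ints, but found element of type String at pos 2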
@@ -25,89 +35,60 @@ module Torch

         candidates.reject! { |f| args.size > f.args.size }

-        # exclude functions missing required options
-        candidates.reject! do |func|
-          # TODO make more generic
-          func.out? && !options[:out]
-        end
-
         # handle out with multiple
         # there should only be one match, so safe to modify all
-
-
-
-
-
-
-
-        end
-
-        # exclude functions where options don't match
-        options.each do |k, v|
-          candidates.select! do |func|
-            func.args.any? { |a| a[:name] == k.to_s }
+        if options[:out]
+          if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
+            out_args = out_func.args.last(2).map { |a| a[:name] }
+            out_args.zip(options.delete(:out)).each do |k, v|
+              options[k] = v
+            end
+            candidates = [out_func]
           end
-
-
+        else
+          # exclude functions missing required options
+          candidates.reject!(&:out?)
         end

-        final_values =
+        final_values = nil

         # check args
-        candidates.
+        while (func = candidates.shift)
           good = true

-
-
-
-
+          # set values
+          # TODO use array instead of hash?
+          values = {}
+          args.each_with_index do |a, i|
+            values[func.arg_names[i]] = a
+          end
+          options.each do |k, v|
+            values[k] = v
+          end
+          func.arg_defaults.each do |k, v|
+            values[k] = v unless values.key?(k)
+          end
+          func.int_array_lengths.each do |k, len|
+            values[k] = [values[k]] * len if values[k].is_a?(Integer)
          end

-
+          arg_checkers = func.arg_checkers

           values.each_key do |k|
-
-
-
-
-
-            when "Tensor"
-              v.is_a?(Tensor)
-            when "Tensor?"
-              v.nil? || v.is_a?(Tensor)
-            when "Tensor[]", "Tensor?[]"
-              v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
-            when "int"
-              if k == "reduction"
-                v.is_a?(String)
-              else
-                v.is_a?(Integer)
-              end
-            when "float"
-              v.is_a?(Numeric)
-            when /int\[.*\]/
-              if v.is_a?(Integer)
-                size = t[4..-2]
-                raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
-                v = [v] * size.to_i
-                values[k] = v
-              end
-              v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
-            when "Scalar"
-              v.is_a?(Numeric)
-            when "ScalarType?"
-              v.nil?
-            when "bool"
-              v == true || v == false
-            when "str"
-              v.is_a?(String)
-            else
-              raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
+            unless arg_checkers.key?(k)
+              good = false
+              if candidates.empty?
+                # TODO show all bad keywords at once like Ruby?
+                return {error: "unknown keyword: #{k}"}
              end
+              break
+            end

-
-
-
+            unless arg_checkers[k].call(values[k])
+              good = false
+              if candidates.empty?
+                t = func.arg_types[k]
+                k = :input if k == :self
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
@@ -116,19 +97,19 @@ module Torch

           if good
             final_values = values
+            break
           end
-
-          good
         end

-
+        unless final_values
           raise Error, "This should never happen. Please report a bug with #{@name}."
         end

-
+        args = func.arg_names.map { |k| final_values[k] }
+        args << TensorOptions.new.dtype(6) if func.tensor_options
         {
           name: func.cpp_name,
-          args:
+          args: args
         }
       end
     end
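When out: is passed and the matching overload writes multiple outputs, the parser now zips the out: array onto the overload's trailing argument names before matching. The rewrite in isolation, as plain Ruby (the argument names here are illustrative, not taken from a real signature):

    values_tensor = Object.new  # stand-ins for pre-allocated out tensors
    indices_tensor = Object.new

    out_args = [:values, :indices] # names of the overload's last two args
    options = {out: [values_tensor, indices_tensor]}

    out_args.zip(options.delete(:out)).each { |k, v| options[k] = v }
    options # => {values: values_tensor, indices: indices_tensor}

The TensorOptions.new.dtype(6) appended for tensor_options functions pins the default dtype to float32 (6 is Float in libtorch's ScalarType enum).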
data/lib/torch/nn/functional.rb
CHANGED
@@ -178,8 +178,12 @@ module Torch
         Torch.hardshrink(input, lambd)
       end

-      def leaky_relu(input, negative_slope = 0.01)
-
+      def leaky_relu(input, negative_slope = 0.01, inplace: false)
+        if inplace
+          NN.leaky_relu!(input, negative_slope)
+        else
+          NN.leaky_relu(input, negative_slope)
+        end
       end

       def log_sigmoid(input)
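leaky_relu now takes an inplace: keyword and routes to the bang variant. A usage sketch, assuming the F constant used by the modules resolves to this Functional class:

    require "torch"

    x = Torch.tensor([-1.0, 2.0])
    Torch::NN::F.leaky_relu(x, 0.2)                # => tensor([-0.2000, 2.0000]), new tensor
    Torch::NN::F.leaky_relu(x, 0.2, inplace: true) # mutates x itself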
@@ -465,6 +469,77 @@ module Torch
         Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
       end

+      # vision
+
+      def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
+        if ["nearest", "area"].include?(mode)
+          unless align_corners.nil?
+            raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
+          end
+        else
+          if align_corners.nil?
+            align_corners = false
+          end
+        end
+
+        scale_factor_len = input.dim - 2
+        scale_factor_list = [nil] * scale_factor_len
+        # default value of recompute_scale_factor is False
+        if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
+          if scale_factor.is_a?(Array)
+            _scale_factor_repeated = scale_factor
+          else
+            _scale_factor_repeated = [scale_factor] * scale_factor_len
+          end
+          scale_factor_list = _scale_factor_repeated
+        end
+
+        # Give this variable a short name because it has to be repeated multiple times below.
+        sfl = scale_factor_list
+
+        closed_over_args = [input, size, scale_factor, recompute_scale_factor]
+        output_size = _interp_output_size(closed_over_args)
+        if input.dim == 3 && mode == "nearest"
+          NN.upsample_nearest1d(input, output_size, sfl[0])
+        elsif input.dim == 4 && mode == "nearest"
+          NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
+        elsif input.dim == 5 && mode == "nearest"
+          NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
+        elsif input.dim == 3 && mode == "area"
+          adaptive_avg_pool1d(input, output_size)
+        elsif input.dim == 4 && mode == "area"
+          adaptive_avg_pool2d(input, output_size)
+        elsif input.dim == 5 && mode == "area"
+          adaptive_avg_pool3d(input, output_size)
+        elsif input.dim == 3 && mode == "linear"
+          # assert align_corners is not None
+          NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
+        elsif input.dim == 3 && mode == "bilinear"
+          raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
+        elsif input.dim == 3 && mode == "trilinear"
+          raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
+        elsif input.dim == 4 && mode == "linear"
+          raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
+        elsif input.dim == 4 && mode == "bilinear"
+          # assert align_corners is not None
+          NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
+        elsif input.dim == 4 && mode == "trilinear"
+          raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
+        elsif input.dim == 5 && mode == "linear"
+          raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
+        elsif input.dim == 5 && mode == "bilinear"
+          raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
+        elsif input.dim == 5 && mode == "trilinear"
+          # assert align_corners is not None
+          NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
+        elsif input.dim == 4 && mode == "bicubic"
+          # assert align_corners is not None
+          NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
+        else
+          raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
+        end
+      end
+
       private

       def softmax_dim(ndim)
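interpolate mirrors the Python F.interpolate signature, selecting the native upsample kernel from the input rank and mode. A usage sketch (same F alias assumption as above):

    require "torch"

    input = Torch.randn(1, 3, 8, 8) # NCHW

    Torch::NN::F.interpolate(input, scale_factor: 2.0, mode: "nearest").shape
    # => [1, 3, 16, 16]

    Torch::NN::F.interpolate(input, size: [4, 4], mode: "bilinear", align_corners: false).shape
    # => [1, 3, 4, 4]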
@@ -480,6 +555,41 @@ module Torch
           out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
         end
       end
+
+      def _interp_output_size(closed_over_args)
+        input, size, scale_factor, recompute_scale_factor = closed_over_args
+        dim = input.dim - 2
+        if size.nil? && scale_factor.nil?
+          raise ArgumentError, "either size or scale_factor should be defined"
+        end
+        if !size.nil? && !scale_factor.nil?
+          raise ArgumentError, "only one of size or scale_factor should be defined"
+        end
+        if !scale_factor.nil?
+          if scale_factor.is_a?(Array)
+            if scale_factor.length != dim
+              raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
+            end
+          end
+        end
+
+        if !size.nil?
+          if size.is_a?(Array)
+            return size
+          else
+            return [size] * dim
+          end
+        end
+
+        raise "Failed assertion" if scale_factor.nil?
+        if scale_factor.is_a?(Array)
+          scale_factors = scale_factor
+        else
+          scale_factors = [scale_factor] * dim
+        end
+
+        dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
+      end
     end
   end

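When no explicit size is given, each spatial dimension comes out as floor(input.size(i + 2) * scale_factor). A quick check of the arithmetic:

    input_shape = [1, 3, 10, 7] # NCHW; spatial dims start at index 2
    scale = 1.5
    input_shape[2..].map { |s| (s * scale).floor } # => [15, 10]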
data/lib/torch/nn/leaky_relu.rb
CHANGED
@@ -1,14 +1,14 @@
 module Torch
   module NN
     class LeakyReLU < Module
-      def initialize(negative_slope: 1e-2)
+      def initialize(negative_slope: 1e-2, inplace: false)
         super()
         @negative_slope = negative_slope
-
+        @inplace = inplace
       end

       def forward(input)
-        F.leaky_relu(input, @negative_slope)
+        F.leaky_relu(input, @negative_slope, inplace: @inplace)
       end

       def extra_inspect
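Constructing the module with the new keyword; call runs forward:

    require "torch"

    act = Torch::NN::LeakyReLU.new(negative_slope: 0.2, inplace: true)
    x = Torch.tensor([-1.0, 2.0])
    act.call(x) # => tensor([-0.2000, 2.0000]); with inplace: true, x is updated in place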
data/lib/torch/nn/module.rb
CHANGED
@@ -55,7 +55,15 @@ module Torch
             end
           end
         end
-
+
+        @buffers.each_key do |k|
+          buf = @buffers[k]
+          unless buf.nil?
+            @buffers[k] = fn.call(buf)
+            instance_variable_set("@#{k}", @buffers[k])
+          end
+        end
+
         self
       end

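The conversion helper now rewrites each non-nil registered buffer with fn and refreshes the backing instance variable, so buffers (such as batch-norm running stats held in @running_mean and @running_var) follow the parameters through conversions. A sketch, assuming Module#to accepts a device as in the PyTorch API and that a CUDA build is available:

    require "torch"

    bn = Torch::NN::BatchNorm2d.new(4)
    bn.to("cuda") # parameters and running stats now move together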
data/lib/torch/nn/upsample.rb
ADDED
@@ -0,0 +1,31 @@
+module Torch
+  module NN
+    class Upsample < Module
+      def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
+        super()
+        @size = size
+        if scale_factor.is_a?(Array)
+          @scale_factor = scale_factor.map(&:to_f)
+        else
+          @scale_factor = scale_factor ? scale_factor.to_f : nil
+        end
+        @mode = mode
+        @align_corners = align_corners
+      end
+
+      def forward(input)
+        F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
+      end
+
+      def extra_inspect
+        if !@scale_factor.nil?
+          info = "scale_factor: #{@scale_factor.inspect}"
+        else
+          info = "size: #{@size.inspect}"
+        end
+        info += ", mode: #{@mode.inspect}"
+        info
+      end
+    end
+  end
+end
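The new Upsample module is a thin wrapper around F.interpolate. A usage sketch:

    require "torch"

    up = Torch::NN::Upsample.new(scale_factor: 2.0, mode: "nearest")
    x = Torch.ones(1, 1, 2, 2)
    up.call(x).shape # => [1, 1, 4, 4]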
data/lib/torch/tensor.rb
CHANGED
@@ -48,6 +48,11 @@ module Torch
     end

     def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+      if device.is_a?(Symbol) && !dtype
+        dtype = device
+        device = nil
+      end
+
       device ||= self.device
       device = Device.new(device) if device.is_a?(String)

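to now treats a leading symbol as the dtype, so dtype-only conversion no longer needs the keyword:

    x = Torch.tensor([1, 2, 3])
    x.to(:float64)               # same as x.to(dtype: :float64)
    x.to("cpu", dtype: :float32) # device string form unchanged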
@@ -74,10 +79,6 @@ module Torch
       end
     end

-    def shape
-      dim.times.map { |i| size(i) }
-    end
-
     # mirror Python len()
     def length
       size(0)
@@ -103,11 +104,6 @@ module Torch
       Torch.empty(0, dtype: dtype)
     end

-    def backward(gradient = nil, retain_graph: nil, create_graph: false)
-      retain_graph = create_graph if retain_graph.nil?
-      _backward(gradient, retain_graph, create_graph)
-    end
-
     # TODO read directly from memory
     def numo
       cls = Torch._dtype_to_numo[dtype]
@@ -124,9 +120,14 @@ module Torch
     end

     def type(dtype)
-
-
-
+      if dtype.is_a?(Class)
+        raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
+        dtype.new(self)
+      else
+        enum = DTYPE_TO_ENUM[dtype]
+        raise Error, "Invalid type: #{dtype}" unless enum
+        _type(enum)
+      end
     end

     def reshape(*size)
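type accepts either a dtype symbol or a tensor class. A sketch (Torch::FloatTensor assumed to be among TENSOR_TYPE_CLASSES):

    x = Torch.tensor([1, 2, 3])
    x.type(:float32)           # symbol path, mapped through DTYPE_TO_ENUM
    x.type(Torch::FloatTensor) # class path; equivalent result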
@@ -188,49 +189,15 @@ module Torch
     # based on python_variable_indexing.cpp and
     # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
-      result = self
-      dim = 0
-      indexes.each do |index|
-        if index.is_a?(Numeric)
-          result = result._select_int(dim, index)
-        elsif index.is_a?(Range)
-          finish = index.end
-          finish += 1 unless index.exclude_end?
-          result = result._slice_tensor(dim, index.begin, finish, 1)
-          dim += 1
-        elsif index.is_a?(Tensor)
-          result = result.index([index])
-        elsif index.nil?
-          result = result.unsqueeze(dim)
-          dim += 1
-        elsif index == true
-          result = result.unsqueeze(dim)
-        # TODO handle false
-        else
-          raise Error, "Unsupported index type: #{index.class.name}"
-        end
-      end
-      result
+      _index(tensor_indexes(indexes))
     end

     # based on python_variable_indexing.cpp and
     # https://pytorch.org/cppdocs/notes/tensor_indexing.html
-    def []=(index, value)
+    def []=(*indexes, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?
-
       value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-
-      if index.is_a?(Numeric)
-        index_put!([Torch.tensor(index)], value)
-      elsif index.is_a?(Range)
-        finish = index.end
-        finish += 1 unless index.exclude_end?
-        _slice_tensor(0, index.begin, finish, 1).copy!(value)
-      elsif index.is_a?(Tensor)
-        index_put!([index], value)
-      else
-        raise Error, "Unsupported index type: #{index.class.name}"
-      end
+      _index_put_custom(tensor_indexes(indexes), value)
     end

     # native functions that need manually defined
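Indexing now lowers Ruby indexes to native TensorIndex values, and []= takes multiple indexes, so multi-axis assignment works:

    x = Torch.zeros(3, 3)
    x[0, 1..2] = 1 # scalar wrapped via Torch.tensor(1, dtype: x.dtype)
    x[1] = Torch.tensor([4.0, 5.0, 6.0])
    x[0, 2] # => tensor(1.)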
@@ -244,13 +211,13 @@ module Torch
       end
     end

-    #
+    # parser can't handle overlap, so need to handle manually
     def random!(*args)
       case args.size
       when 1
         _random__to(*args)
       when 2
-
+        _random__from(*args)
       else
         _random_(*args)
       end
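random! dispatches on arity, mirroring PyTorch's Tensor#random_ overloads:

    x = Torch.empty(3, dtype: :int64)
    x.random!(10)    # uniform integers in [0, 10)
    x.random!(5, 10) # uniform integers in [5, 10)
    x.random!        # full range of the dtype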
@@ -260,5 +227,32 @@ module Torch
       _clamp_min_(min)
       _clamp_max_(max)
     end
+
+    private
+
+    def tensor_indexes(indexes)
+      indexes.map do |index|
+        case index
+        when Integer
+          TensorIndex.integer(index)
+        when Range
+          finish = index.end || -1
+          if finish == -1 && !index.exclude_end?
+            finish = nil
+          else
+            finish += 1 unless index.exclude_end?
+          end
+          TensorIndex.slice(index.begin, finish)
+        when Tensor
+          TensorIndex.tensor(index)
+        when nil
+          TensorIndex.none
+        when true, false
+          TensorIndex.boolean(index)
+        else
+          raise Error, "Unsupported index type: #{index.class.name}"
+        end
+      end
+    end
   end
 end
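The Range branch maps inclusive, exclusive, and endless ranges onto slices; an endless range, or an inclusive range ending in -1, becomes an open-ended slice:

    x = Torch.arange(5)
    x[1..3]  # inclusive end: finish 3 becomes 4, => tensor([1, 2, 3])
    x[1...3] # exclusive end kept as is,          => tensor([1, 2])
    x[2..]   # endless: open-ended slice,         => tensor([2, 3, 4])
    x[0..-1] # inclusive -1: open-ended slice,    => the whole tensor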