torch-rb 0.3.2 → 0.3.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +7 -2
- data/ext/torch/ext.cpp +60 -20
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +36 -0
- data/ext/torch/templates.hpp +81 -87
- data/lib/torch.rb +71 -19
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +93 -4
- data/lib/torch/native/generator.rb +45 -41
- data/lib/torch/native/parser.rb +57 -76
- data/lib/torch/nn/functional.rb +112 -2
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +45 -51
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
data/lib/torch/native/parser.rb
CHANGED
@@ -6,14 +6,24 @@ module Torch
|
|
6
6
|
@name = @functions.first.ruby_name
|
7
7
|
@min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
|
8
8
|
@max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
|
9
|
+
@int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
|
9
10
|
end
|
10
11
|
|
12
|
+
# TODO improve performance
|
13
|
+
# possibly move to C++ (see python_arg_parser.cpp)
|
11
14
|
def parse(args, options)
|
12
15
|
candidates = @functions.dup
|
13
16
|
|
14
|
-
#
|
15
|
-
|
16
|
-
|
17
|
+
# TODO check candidates individually to see if they match
|
18
|
+
if @int_array_first
|
19
|
+
int_args = []
|
20
|
+
while args.first.is_a?(Integer)
|
21
|
+
int_args << args.shift
|
22
|
+
end
|
23
|
+
if int_args.any?
|
24
|
+
raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
|
25
|
+
args.unshift(int_args)
|
26
|
+
end
|
17
27
|
end
|
18
28
|
|
19
29
|
# TODO account for args passed as options here
|
@@ -25,89 +35,60 @@ module Torch
|
|
25
35
|
|
26
36
|
candidates.reject! { |f| args.size > f.args.size }
|
27
37
|
|
28
|
-
# exclude functions missing required options
|
29
|
-
candidates.reject! do |func|
|
30
|
-
# TODO make more generic
|
31
|
-
func.out? && !options[:out]
|
32
|
-
end
|
33
|
-
|
34
38
|
# handle out with multiple
|
35
39
|
# there should only be one match, so safe to modify all
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
end
|
44
|
-
|
45
|
-
# exclude functions where options don't match
|
46
|
-
options.each do |k, v|
|
47
|
-
candidates.select! do |func|
|
48
|
-
func.args.any? { |a| a[:name] == k.to_s }
|
40
|
+
if options[:out]
|
41
|
+
if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
|
42
|
+
out_args = out_func.args.last(2).map { |a| a[:name] }
|
43
|
+
out_args.zip(options.delete(:out)).each do |k, v|
|
44
|
+
options[k] = v
|
45
|
+
end
|
46
|
+
candidates = [out_func]
|
49
47
|
end
|
50
|
-
|
51
|
-
|
48
|
+
else
|
49
|
+
# exclude functions missing required options
|
50
|
+
candidates.reject!(&:out?)
|
52
51
|
end
|
53
52
|
|
54
|
-
final_values =
|
53
|
+
final_values = nil
|
55
54
|
|
56
55
|
# check args
|
57
|
-
candidates.
|
56
|
+
while (func = candidates.shift)
|
58
57
|
good = true
|
59
58
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
59
|
+
# set values
|
60
|
+
# TODO use array instead of hash?
|
61
|
+
values = {}
|
62
|
+
args.each_with_index do |a, i|
|
63
|
+
values[func.arg_names[i]] = a
|
64
|
+
end
|
65
|
+
options.each do |k, v|
|
66
|
+
values[k] = v
|
67
|
+
end
|
68
|
+
func.arg_defaults.each do |k, v|
|
69
|
+
values[k] = v unless values.key?(k)
|
70
|
+
end
|
71
|
+
func.int_array_lengths.each do |k, len|
|
72
|
+
values[k] = [values[k]] * len if values[k].is_a?(Integer)
|
64
73
|
end
|
65
74
|
|
66
|
-
|
75
|
+
arg_checkers = func.arg_checkers
|
67
76
|
|
68
77
|
values.each_key do |k|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
when "Tensor"
|
75
|
-
v.is_a?(Tensor)
|
76
|
-
when "Tensor?"
|
77
|
-
v.nil? || v.is_a?(Tensor)
|
78
|
-
when "Tensor[]", "Tensor?[]"
|
79
|
-
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
|
80
|
-
when "int"
|
81
|
-
if k == "reduction"
|
82
|
-
v.is_a?(String)
|
83
|
-
else
|
84
|
-
v.is_a?(Integer)
|
85
|
-
end
|
86
|
-
when "float"
|
87
|
-
v.is_a?(Numeric)
|
88
|
-
when /int\[.*\]/
|
89
|
-
if v.is_a?(Integer)
|
90
|
-
size = t[4..-2]
|
91
|
-
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
92
|
-
v = [v] * size.to_i
|
93
|
-
values[k] = v
|
94
|
-
end
|
95
|
-
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
|
96
|
-
when "Scalar"
|
97
|
-
v.is_a?(Numeric)
|
98
|
-
when "ScalarType?"
|
99
|
-
v.nil?
|
100
|
-
when "bool"
|
101
|
-
v == true || v == false
|
102
|
-
when "str"
|
103
|
-
v.is_a?(String)
|
104
|
-
else
|
105
|
-
raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
|
78
|
+
unless arg_checkers.key?(k)
|
79
|
+
good = false
|
80
|
+
if candidates.empty?
|
81
|
+
# TODO show all bad keywords at once like Ruby?
|
82
|
+
return {error: "unknown keyword: #{k}"}
|
106
83
|
end
|
84
|
+
break
|
85
|
+
end
|
107
86
|
|
108
|
-
|
109
|
-
|
110
|
-
|
87
|
+
unless arg_checkers[k].call(values[k])
|
88
|
+
good = false
|
89
|
+
if candidates.empty?
|
90
|
+
t = func.arg_types[k]
|
91
|
+
k = :input if k == :self
|
111
92
|
return {error: "#{@name}(): argument '#{k}' must be #{t}"}
|
112
93
|
end
|
113
94
|
break
|
@@ -116,19 +97,19 @@ module Torch
|
|
116
97
|
|
117
98
|
if good
|
118
99
|
final_values = values
|
100
|
+
break
|
119
101
|
end
|
120
|
-
|
121
|
-
good
|
122
102
|
end
|
123
103
|
|
124
|
-
|
104
|
+
unless final_values
|
125
105
|
raise Error, "This should never happen. Please report a bug with #{@name}."
|
126
106
|
end
|
127
107
|
|
128
|
-
|
108
|
+
args = func.arg_names.map { |k| final_values[k] }
|
109
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
110
|
{
|
130
111
|
name: func.cpp_name,
|
131
|
-
args:
|
112
|
+
args: args
|
132
113
|
}
|
133
114
|
end
|
134
115
|
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -178,8 +178,12 @@ module Torch
|
|
178
178
|
Torch.hardshrink(input, lambd)
|
179
179
|
end
|
180
180
|
|
181
|
-
# Leaky ReLU activation with the given negative_slope.
# Delegates to the native NN.leaky_relu kernel, or to NN.leaky_relu! when
# inplace is truthy (the bang variant presumably writes into +input+ —
# confirm against the native binding).
def leaky_relu(input, negative_slope = 0.01, inplace: false)
  return NN.leaky_relu!(input, negative_slope) if inplace

  NN.leaky_relu(input, negative_slope)
end
|
184
188
|
|
185
189
|
def log_sigmoid(input)
|
@@ -465,6 +469,77 @@ module Torch
|
|
465
469
|
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
|
466
470
|
end
|
467
471
|
|
472
|
+
# vision
|
473
|
+
|
474
|
+
# Resizes +input+ to an explicit +size+ or by a +scale_factor+, then dispatches
# on the input rank and +mode+. Port of PyTorch's
# torch.nn.functional.interpolate. Validation of size/scale_factor (exactly one
# of them must be given) happens in _interp_output_size.
#
# mode - one of:
#   - "nearest" (3D/4D/5D)
#   - "area" (3D/4D/5D, routed to adaptive_avg_pool1d/2d/3d)
#   - "linear" (3D)
#   - "bilinear" or "bicubic" (4D)
#   - "trilinear" (5D)
# align_corners - must stay nil for "nearest"/"area"; for any other mode nil
#   becomes false.
# recompute_scale_factor - when neither nil nor false, the native upsample
#   kernels receive nil scale factors instead of +scale_factor+.
#
# Raises ArgumentError when align_corners is set for nearest/area, or for an
# unsupported rank/mode combination.
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
  if ["nearest", "area"].include?(mode)
    unless align_corners.nil?
      raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
    end
  elsif align_corners.nil?
    align_corners = false
  end

  spatial_dims = input.dim - 2
  # One (optional) scale per spatial dimension, handed to the native kernels.
  sfl = [nil] * spatial_dims
  # default value of recompute_scale_factor is False (as in PyTorch)
  if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
    sfl = scale_factor.is_a?(Array) ? scale_factor : [scale_factor] * spatial_dims
  end

  output_size = _interp_output_size([input, size, scale_factor, recompute_scale_factor])

  case [input.dim, mode]
  when [3, "nearest"]
    NN.upsample_nearest1d(input, output_size, sfl[0])
  when [4, "nearest"]
    NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
  when [5, "nearest"]
    NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
  when [3, "area"]
    adaptive_avg_pool1d(input, output_size)
  when [4, "area"]
    adaptive_avg_pool2d(input, output_size)
  when [5, "area"]
    adaptive_avg_pool3d(input, output_size)
  when [3, "linear"]
    NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
  when [4, "bilinear"]
    NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
  when [5, "trilinear"]
    NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
  when [4, "bicubic"]
    NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
  when [4, "linear"]
    raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
  when [5, "linear"]
    raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
  when [3, "bilinear"]
    raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
  when [5, "bilinear"]
    raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
  when [3, "trilinear"]
    raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
  when [4, "trilinear"]
    raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
  else
    raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
  end
end
|
542
|
+
|
468
543
|
private
|
469
544
|
|
470
545
|
def softmax_dim(ndim)
|
@@ -480,6 +555,41 @@ module Torch
|
|
480
555
|
out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
|
481
556
|
end
|
482
557
|
end
|
558
|
+
|
559
|
+
# Computes the spatial output size used by interpolate. Ported from the
# closure of the same name in PyTorch's torch.nn.functional.interpolate.
#
# closed_over_args - Array of [input, size, scale_factor, recompute_scale_factor].
#   recompute_scale_factor is accepted for parity but not used here.
#   input must respond to #dim and #size(i).
#
# Returns an Array with one entry per spatial dimension (input.dim - 2):
# - +size+ itself when it is an Array;
# - +size+ repeated when it is a scalar;
# - otherwise floor(input.size(i + 2) * scale) for each spatial dimension i.
#
# Raises ArgumentError when neither or both of size/scale_factor are given, or
# when an Array size/scale_factor does not have one entry per spatial dimension.
def _interp_output_size(closed_over_args)
  input, size, scale_factor, _recompute_scale_factor = closed_over_args
  dim = input.dim - 2

  if size.nil? && scale_factor.nil?
    raise ArgumentError, "either size or scale_factor should be defined"
  end
  if !size.nil? && !scale_factor.nil?
    raise ArgumentError, "only one of size or scale_factor should be defined"
  end

  if scale_factor.is_a?(Array) && scale_factor.length != dim
    raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
  end

  unless size.nil?
    return [size] * dim unless size.is_a?(Array)

    # Reject a size array with the wrong number of spatial dimensions here,
    # mirroring the scale_factor check, instead of letting it reach the kernels.
    if size.length != dim
      raise ArgumentError, "size shape must match input shape. Input is #{dim}D, size has #{size.length} elements"
    end
    return size
  end

  # Only scale_factor remains (the checks above guarantee it is non-nil).
  scale_factors = scale_factor.is_a?(Array) ? scale_factor : [scale_factor] * dim
  # Spatial dimensions start at index 2 (after batch and channel).
  dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
end
|
483
593
|
end
|
484
594
|
end
|
485
595
|
|
data/lib/torch/nn/leaky_relu.rb
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class LeakyReLU < Module
|
4
|
-
# negative_slope - forwarded to F.leaky_relu as its slope argument (default 1e-2)
# inplace - forwarded to F.leaky_relu as its inplace: option
def initialize(negative_slope: 1e-2, inplace: false)
  super()
  @negative_slope, @inplace = negative_slope, inplace
end
|
9
9
|
|
10
10
|
# Applies F.leaky_relu with the slope and inplace flag given at construction.
def forward(input)
  in_place = @inplace
  F.leaky_relu(input, @negative_slope, inplace: in_place)
end
|
13
13
|
|
14
14
|
def extra_inspect
|
data/lib/torch/nn/module.rb
CHANGED
@@ -55,7 +55,15 @@ module Torch
|
|
55
55
|
end
|
56
56
|
end
|
57
57
|
end
|
58
|
-
|
58
|
+
|
59
|
+
@buffers.each_key do |k|
|
60
|
+
buf = @buffers[k]
|
61
|
+
unless buf.nil?
|
62
|
+
@buffers[k] = fn.call(buf)
|
63
|
+
instance_variable_set("@#{k}", @buffers[k])
|
64
|
+
end
|
65
|
+
end
|
66
|
+
|
59
67
|
self
|
60
68
|
end
|
61
69
|
|
@@ -0,0 +1,31 @@
|
|
1
|
+
module Torch
  module NN
    # Upsampling layer: stores its options at construction time and forwards
    # them to F.interpolate on every call.
    class Upsample < Module
      # size - output size (scalar or Array), forwarded unchanged
      # scale_factor - multiplier (scalar or Array); coerced to Float(s) here
      # mode - interpolation mode name, e.g. "nearest"
      # align_corners - forwarded unchanged to F.interpolate
      def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
        super()
        @size = size
        @scale_factor =
          case scale_factor
          when Array then scale_factor.map(&:to_f)
          when nil, false then nil
          else scale_factor.to_f
          end
        @mode = mode
        @align_corners = align_corners
      end

      def forward(input)
        options = {size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners}
        F.interpolate(input, **options)
      end

      # Summary shown when inspecting the layer: the scale factor when one was
      # given, otherwise the size, followed by the mode.
      def extra_inspect
        target = @scale_factor.nil? ? "size: #{@size.inspect}" : "scale_factor: #{@scale_factor.inspect}"
        "#{target}, mode: #{@mode.inspect}"
      end
    end
  end
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -48,6 +48,11 @@ module Torch
|
|
48
48
|
end
|
49
49
|
|
50
50
|
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
if device.is_a?(Symbol) && !dtype
|
52
|
+
dtype = device
|
53
|
+
device = nil
|
54
|
+
end
|
55
|
+
|
51
56
|
device ||= self.device
|
52
57
|
device = Device.new(device) if device.is_a?(String)
|
53
58
|
|
@@ -74,10 +79,6 @@ module Torch
|
|
74
79
|
end
|
75
80
|
end
|
76
81
|
|
77
|
-
def shape
|
78
|
-
dim.times.map { |i| size(i) }
|
79
|
-
end
|
80
|
-
|
81
82
|
# mirror Python len()
|
82
83
|
def length
|
83
84
|
size(0)
|
@@ -103,11 +104,6 @@ module Torch
|
|
103
104
|
Torch.empty(0, dtype: dtype)
|
104
105
|
end
|
105
106
|
|
106
|
-
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
-
retain_graph = create_graph if retain_graph.nil?
|
108
|
-
_backward(gradient, retain_graph, create_graph)
|
109
|
-
end
|
110
|
-
|
111
107
|
# TODO read directly from memory
|
112
108
|
def numo
|
113
109
|
cls = Torch._dtype_to_numo[dtype]
|
@@ -124,9 +120,14 @@ module Torch
|
|
124
120
|
end
|
125
121
|
|
126
122
|
# Converts the tensor to another type.
# dtype may be either:
# - a tensor-type class listed in TENSOR_TYPE_CLASSES, which is instantiated
#   with self; or
# - a key of DTYPE_TO_ENUM, whose enum value is passed to the native _type.
# Raises Error for anything else.
def type(dtype)
  unless dtype.is_a?(Class)
    enum = DTYPE_TO_ENUM[dtype]
    raise Error, "Invalid type: #{dtype}" unless enum
    return _type(enum)
  end

  raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
  dtype.new(self)
end
|
131
132
|
|
132
133
|
def reshape(*size)
|
@@ -188,49 +189,15 @@ module Torch
|
|
188
189
|
# based on python_variable_indexing.cpp and
|
189
190
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
190
191
|
# Reads the elements selected by +indexes+ (Integers, Ranges, Tensors, nil or
# booleans): converts them with tensor_indexes and delegates to the native _index.
def [](*indexes)
  native_indexes = tensor_indexes(indexes)
  _index(native_indexes)
end
|
215
194
|
|
216
195
|
# based on python_variable_indexing.cpp and
|
217
196
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
218
|
-
# Writes +value+ into the elements selected by +indexes+.
# A non-Tensor value is first converted with Torch.tensor using this tensor's
# dtype. Assigning nil raises ArgumentError, since deletion is not supported.
def []=(*indexes, value)
  if value.nil?
    raise ArgumentError, "Tensor does not support deleting items"
  end

  source = value.is_a?(Tensor) ? value : Torch.tensor(value, dtype: dtype)
  _index_put_custom(tensor_indexes(indexes), source)
end
|
235
202
|
|
236
203
|
# native functions that need manually defined
|
@@ -244,13 +211,13 @@ module Torch
|
|
244
211
|
end
|
245
212
|
end
|
246
213
|
|
247
|
-
#
|
214
|
+
# parser can't handle overlap, so need to handle manually
|
248
215
|
def random!(*args)
|
249
216
|
case args.size
|
250
217
|
when 1
|
251
218
|
_random__to(*args)
|
252
219
|
when 2
|
253
|
-
|
220
|
+
_random__from(*args)
|
254
221
|
else
|
255
222
|
_random_(*args)
|
256
223
|
end
|
@@ -260,5 +227,32 @@ module Torch
|
|
260
227
|
_clamp_min_(min)
|
261
228
|
_clamp_max_(max)
|
262
229
|
end
|
230
|
+
|
231
|
+
private
|
232
|
+
|
233
|
+
# Converts Ruby index values into native TensorIndex objects, one per entry:
# Integer -> integer, Range -> slice, Tensor -> tensor, nil -> none,
# true/false -> boolean. Raises Error for any other index type.
#
# Range stops are made exclusive:
# - a missing end (endless range) is treated as -1;
# - an inclusive range ending at -1 is passed a nil stop;
# - any other inclusive range gets end + 1;
# - exclusive ranges keep their end.
# NOTE(review): an endless *exclusive* range such as 1... therefore yields stop -1
# and leaves out the last element, unlike Array#[] — confirm this is intended.
def tensor_indexes(indexes)
  indexes.map do |index|
    case index
    when Integer then TensorIndex.integer(index)
    when Range
      stop = index.end || -1
      unless index.exclude_end?
        stop = stop == -1 ? nil : stop + 1
      end
      TensorIndex.slice(index.begin, stop)
    when Tensor then TensorIndex.tensor(index)
    when nil then TensorIndex.none
    when true, false then TensorIndex.boolean(index)
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
end
|
263
257
|
end
|
264
258
|
end
|