torch-rb 0.3.2 → 0.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,14 +6,24 @@ module Torch
6
6
  @name = @functions.first.ruby_name
7
7
  @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
8
8
  @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
9
+ @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
9
10
  end
10
11
 
12
+ # TODO improve performance
13
+ # possibly move to C++ (see python_arg_parser.cpp)
11
14
  def parse(args, options)
12
15
  candidates = @functions.dup
13
16
 
14
- # remove nil
15
- while args.any? && args.last.nil?
16
- args.pop
17
+ # TODO check candidates individually to see if they match
18
+ if @int_array_first
19
+ int_args = []
20
+ while args.first.is_a?(Integer)
21
+ int_args << args.shift
22
+ end
23
+ if int_args.any?
24
+ raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
25
+ args.unshift(int_args)
26
+ end
17
27
  end
18
28
 
19
29
  # TODO account for args passed as options here
@@ -25,89 +35,60 @@ module Torch
25
35
 
26
36
  candidates.reject! { |f| args.size > f.args.size }
27
37
 
28
- # exclude functions missing required options
29
- candidates.reject! do |func|
30
- # TODO make more generic
31
- func.out? && !options[:out]
32
- end
33
-
34
38
  # handle out with multiple
35
39
  # there should only be one match, so safe to modify all
36
- out_func = candidates.find { |f| f.out? }
37
- if out_func && out_func.out_size > 1 && options[:out]
38
- out_args = out_func.args.last(2).map { |a| a[:name] }
39
- out_args.zip(options.delete(:out)).each do |k, v|
40
- options[k.to_sym] = v
41
- end
42
- candidates = [out_func]
43
- end
44
-
45
- # exclude functions where options don't match
46
- options.each do |k, v|
47
- candidates.select! do |func|
48
- func.args.any? { |a| a[:name] == k.to_s }
40
+ if options[:out]
41
+ if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
42
+ out_args = out_func.args.last(2).map { |a| a[:name] }
43
+ out_args.zip(options.delete(:out)).each do |k, v|
44
+ options[k] = v
45
+ end
46
+ candidates = [out_func]
49
47
  end
50
- # TODO show all bad keywords at once like Ruby?
51
- return {error: "unknown keyword: #{k}"} if candidates.empty?
48
+ else
49
+ # exclude functions missing required options
50
+ candidates.reject!(&:out?)
52
51
  end
53
52
 
54
- final_values = {}
53
+ final_values = nil
55
54
 
56
55
  # check args
57
- candidates.select! do |func|
56
+ while (func = candidates.shift)
58
57
  good = true
59
58
 
60
- values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
61
- values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
62
- func.args.each do |fa|
63
- values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
59
+ # set values
60
+ # TODO use array instead of hash?
61
+ values = {}
62
+ args.each_with_index do |a, i|
63
+ values[func.arg_names[i]] = a
64
+ end
65
+ options.each do |k, v|
66
+ values[k] = v
67
+ end
68
+ func.arg_defaults.each do |k, v|
69
+ values[k] = v unless values.key?(k)
70
+ end
71
+ func.int_array_lengths.each do |k, len|
72
+ values[k] = [values[k]] * len if values[k].is_a?(Integer)
64
73
  end
65
74
 
66
- arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
75
+ arg_checkers = func.arg_checkers
67
76
 
68
77
  values.each_key do |k|
69
- v = values[k]
70
- t = arg_types[k].split("(").first
71
-
72
- good =
73
- case t
74
- when "Tensor"
75
- v.is_a?(Tensor)
76
- when "Tensor?"
77
- v.nil? || v.is_a?(Tensor)
78
- when "Tensor[]", "Tensor?[]"
79
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
80
- when "int"
81
- if k == "reduction"
82
- v.is_a?(String)
83
- else
84
- v.is_a?(Integer)
85
- end
86
- when "float"
87
- v.is_a?(Numeric)
88
- when /int\[.*\]/
89
- if v.is_a?(Integer)
90
- size = t[4..-2]
91
- raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
92
- v = [v] * size.to_i
93
- values[k] = v
94
- end
95
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
96
- when "Scalar"
97
- v.is_a?(Numeric)
98
- when "ScalarType?"
99
- v.nil?
100
- when "bool"
101
- v == true || v == false
102
- when "str"
103
- v.is_a?(String)
104
- else
105
- raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
78
+ unless arg_checkers.key?(k)
79
+ good = false
80
+ if candidates.empty?
81
+ # TODO show all bad keywords at once like Ruby?
82
+ return {error: "unknown keyword: #{k}"}
106
83
  end
84
+ break
85
+ end
107
86
 
108
- if !good
109
- if candidates.size == 1
110
- k = "input" if k == "self"
87
+ unless arg_checkers[k].call(values[k])
88
+ good = false
89
+ if candidates.empty?
90
+ t = func.arg_types[k]
91
+ k = :input if k == :self
111
92
  return {error: "#{@name}(): argument '#{k}' must be #{t}"}
112
93
  end
113
94
  break
@@ -116,19 +97,19 @@ module Torch
116
97
 
117
98
  if good
118
99
  final_values = values
100
+ break
119
101
  end
120
-
121
- good
122
102
  end
123
103
 
124
- if candidates.size != 1
104
+ unless final_values
125
105
  raise Error, "This should never happen. Please report a bug with #{@name}."
126
106
  end
127
107
 
128
- func = candidates.first
108
+ args = func.arg_names.map { |k| final_values[k] }
109
+ args << TensorOptions.new.dtype(6) if func.tensor_options
129
110
  {
130
111
  name: func.cpp_name,
131
- args: func.args.map { |a| final_values[a[:name]] }
112
+ args: args
132
113
  }
133
114
  end
134
115
  end
@@ -178,8 +178,12 @@ module Torch
178
178
  Torch.hardshrink(input, lambd)
179
179
  end
180
180
 
181
- def leaky_relu(input, negative_slope = 0.01)
182
- NN.leaky_relu(input, negative_slope)
181
+ def leaky_relu(input, negative_slope = 0.01, inplace: false)
182
+ if inplace
183
+ NN.leaky_relu!(input, negative_slope)
184
+ else
185
+ NN.leaky_relu(input, negative_slope)
186
+ end
183
187
  end
184
188
 
185
189
  def log_sigmoid(input)
@@ -465,6 +469,77 @@ module Torch
465
469
  Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
466
470
  end
467
471
 
472
+ # vision
473
+
474
+ def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
475
+ if ["nearest", "area"].include?(mode)
476
+ unless align_corners.nil?
477
+ raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
478
+ end
479
+ else
480
+ if align_corners.nil?
481
+ align_corners = false
482
+ end
483
+ end
484
+
485
+ scale_factor_len = input.dim - 2
486
+ scale_factor_list = [nil] * scale_factor_len
487
+ # default value of recompute_scale_factor is False
488
+ if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
489
+ if scale_factor.is_a?(Array)
490
+ _scale_factor_repeated = scale_factor
491
+ else
492
+ _scale_factor_repeated = [scale_factor] * scale_factor_len
493
+ end
494
+ scale_factor_list = _scale_factor_repeated
495
+ end
496
+
497
+ # Give this variable a short name because it has to be repeated multiple times below.
498
+ sfl = scale_factor_list
499
+
500
+ closed_over_args = [input, size, scale_factor, recompute_scale_factor]
501
+ output_size = _interp_output_size(closed_over_args)
502
+ if input.dim == 3 && mode == "nearest"
503
+ NN.upsample_nearest1d(input, output_size, sfl[0])
504
+ elsif input.dim == 4 && mode == "nearest"
505
+ NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
506
+ elsif input.dim == 5 && mode == "nearest"
507
+ NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
508
+ elsif input.dim == 3 && mode == "area"
509
+ adaptive_avg_pool1d(input, output_size)
510
+ elsif input.dim == 4 && mode == "area"
511
+ adaptive_avg_pool2d(input, output_size)
512
+ elsif input.dim == 5 && mode == "area"
513
+ adaptive_avg_pool3d(input, output_size)
514
+ elsif input.dim == 3 && mode == "linear"
515
+ # assert align_corners is not None
516
+ NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
517
+ elsif input.dim == 3 && mode == "bilinear"
518
+ raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
519
+ elsif input.dim == 3 && mode == "trilinear"
520
+ raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
521
+ elsif input.dim == 4 && mode == "linear"
522
+ raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
523
+ elsif input.dim == 4 && mode == "bilinear"
524
+ # assert align_corners is not None
525
+ NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
526
+ elsif input.dim == 4 && mode == "trilinear"
527
+ raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
528
+ elsif input.dim == 5 && mode == "linear"
529
+ raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
530
+ elsif input.dim == 5 && mode == "bilinear"
531
+ raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
532
+ elsif input.dim == 5 && mode == "trilinear"
533
+ # assert align_corners is not None
534
+ NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
535
+ elsif input.dim == 4 && mode == "bicubic"
536
+ # assert align_corners is not None
537
+ NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
538
+ else
539
+ raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
540
+ end
541
+ end
542
+
468
543
  private
469
544
 
470
545
  def softmax_dim(ndim)
@@ -480,6 +555,41 @@ module Torch
480
555
  out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
481
556
  end
482
557
  end
558
+
559
+ def _interp_output_size(closed_over_args)
560
+ input, size, scale_factor, recompute_scale_factor = closed_over_args
561
+ dim = input.dim - 2
562
+ if size.nil? && scale_factor.nil?
563
+ raise ArgumentError, "either size or scale_factor should be defined"
564
+ end
565
+ if !size.nil? && !scale_factor.nil?
566
+ raise ArgumentError, "only one of size or scale_factor should be defined"
567
+ end
568
+ if !scale_factor.nil?
569
+ if scale_factor.is_a?(Array)
570
+ if scale_factor.length != dim
571
+ raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
572
+ end
573
+ end
574
+ end
575
+
576
+ if !size.nil?
577
+ if size.is_a?(Array)
578
+ return size
579
+ else
580
+ return [size] * dim
581
+ end
582
+ end
583
+
584
+ raise "Failed assertion" if scale_factor.nil?
585
+ if scale_factor.is_a?(Array)
586
+ scale_factors = scale_factor
587
+ else
588
+ scale_factors = [scale_factor] * dim
589
+ end
590
+
591
+ dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
592
+ end
483
593
  end
484
594
  end
485
595
 
@@ -1,14 +1,14 @@
1
1
  module Torch
2
2
  module NN
3
3
  class LeakyReLU < Module
4
- def initialize(negative_slope: 1e-2) #, inplace: false)
4
+ def initialize(negative_slope: 1e-2, inplace: false)
5
5
  super()
6
6
  @negative_slope = negative_slope
7
- # @inplace = inplace
7
+ @inplace = inplace
8
8
  end
9
9
 
10
10
  def forward(input)
11
- F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
11
+ F.leaky_relu(input, @negative_slope, inplace: @inplace)
12
12
  end
13
13
 
14
14
  def extra_inspect
@@ -55,7 +55,15 @@ module Torch
55
55
  end
56
56
  end
57
57
  end
58
- # TODO apply to more objects
58
+
59
+ @buffers.each_key do |k|
60
+ buf = @buffers[k]
61
+ unless buf.nil?
62
+ @buffers[k] = fn.call(buf)
63
+ instance_variable_set("@#{k}", @buffers[k])
64
+ end
65
+ end
66
+
59
67
  self
60
68
  end
61
69
 
@@ -0,0 +1,31 @@
1
+ module Torch
2
+ module NN
3
+ class Upsample < Module
4
+ def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
5
+ super()
6
+ @size = size
7
+ if scale_factor.is_a?(Array)
8
+ @scale_factor = scale_factor.map(&:to_f)
9
+ else
10
+ @scale_factor = scale_factor ? scale_factor.to_f : nil
11
+ end
12
+ @mode = mode
13
+ @align_corners = align_corners
14
+ end
15
+
16
+ def forward(input)
17
+ F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
18
+ end
19
+
20
+ def extra_inspect
21
+ if !@scale_factor.nil?
22
+ info = "scale_factor: #{@scale_factor.inspect}"
23
+ else
24
+ info = "size: #{@size.inspect}"
25
+ end
26
+ info += ", mode: #{@mode.inspect}"
27
+ info
28
+ end
29
+ end
30
+ end
31
+ end
@@ -48,6 +48,11 @@ module Torch
48
48
  end
49
49
 
50
50
  def to(device = nil, dtype: nil, non_blocking: false, copy: false)
51
+ if device.is_a?(Symbol) && !dtype
52
+ dtype = device
53
+ device = nil
54
+ end
55
+
51
56
  device ||= self.device
52
57
  device = Device.new(device) if device.is_a?(String)
53
58
 
@@ -74,10 +79,6 @@ module Torch
74
79
  end
75
80
  end
76
81
 
77
- def shape
78
- dim.times.map { |i| size(i) }
79
- end
80
-
81
82
  # mirror Python len()
82
83
  def length
83
84
  size(0)
@@ -103,11 +104,6 @@ module Torch
103
104
  Torch.empty(0, dtype: dtype)
104
105
  end
105
106
 
106
- def backward(gradient = nil, retain_graph: nil, create_graph: false)
107
- retain_graph = create_graph if retain_graph.nil?
108
- _backward(gradient, retain_graph, create_graph)
109
- end
110
-
111
107
  # TODO read directly from memory
112
108
  def numo
113
109
  cls = Torch._dtype_to_numo[dtype]
@@ -124,9 +120,14 @@ module Torch
124
120
  end
125
121
 
126
122
  def type(dtype)
127
- enum = DTYPE_TO_ENUM[dtype]
128
- raise Error, "Unknown type: #{dtype}" unless enum
129
- _type(enum)
123
+ if dtype.is_a?(Class)
124
+ raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
125
+ dtype.new(self)
126
+ else
127
+ enum = DTYPE_TO_ENUM[dtype]
128
+ raise Error, "Invalid type: #{dtype}" unless enum
129
+ _type(enum)
130
+ end
130
131
  end
131
132
 
132
133
  def reshape(*size)
@@ -188,49 +189,15 @@ module Torch
188
189
  # based on python_variable_indexing.cpp and
189
190
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
190
191
  def [](*indexes)
191
- result = self
192
- dim = 0
193
- indexes.each do |index|
194
- if index.is_a?(Numeric)
195
- result = result._select_int(dim, index)
196
- elsif index.is_a?(Range)
197
- finish = index.end
198
- finish += 1 unless index.exclude_end?
199
- result = result._slice_tensor(dim, index.begin, finish, 1)
200
- dim += 1
201
- elsif index.is_a?(Tensor)
202
- result = result.index([index])
203
- elsif index.nil?
204
- result = result.unsqueeze(dim)
205
- dim += 1
206
- elsif index == true
207
- result = result.unsqueeze(dim)
208
- # TODO handle false
209
- else
210
- raise Error, "Unsupported index type: #{index.class.name}"
211
- end
212
- end
213
- result
192
+ _index(tensor_indexes(indexes))
214
193
  end
215
194
 
216
195
  # based on python_variable_indexing.cpp and
217
196
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
218
- def []=(index, value)
197
+ def []=(*indexes, value)
219
198
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
220
-
221
199
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
222
-
223
- if index.is_a?(Numeric)
224
- index_put!([Torch.tensor(index)], value)
225
- elsif index.is_a?(Range)
226
- finish = index.end
227
- finish += 1 unless index.exclude_end?
228
- _slice_tensor(0, index.begin, finish, 1).copy!(value)
229
- elsif index.is_a?(Tensor)
230
- index_put!([index], value)
231
- else
232
- raise Error, "Unsupported index type: #{index.class.name}"
233
- end
200
+ _index_put_custom(tensor_indexes(indexes), value)
234
201
  end
235
202
 
236
203
  # native functions that need manually defined
@@ -244,13 +211,13 @@ module Torch
244
211
  end
245
212
  end
246
213
 
247
- # native functions overlap, so need to handle manually
214
+ # parser can't handle overlap, so need to handle manually
248
215
  def random!(*args)
249
216
  case args.size
250
217
  when 1
251
218
  _random__to(*args)
252
219
  when 2
253
- _random__from_to(*args)
220
+ _random__from(*args)
254
221
  else
255
222
  _random_(*args)
256
223
  end
@@ -260,5 +227,32 @@ module Torch
260
227
  _clamp_min_(min)
261
228
  _clamp_max_(max)
262
229
  end
230
+
231
+ private
232
+
233
+ def tensor_indexes(indexes)
234
+ indexes.map do |index|
235
+ case index
236
+ when Integer
237
+ TensorIndex.integer(index)
238
+ when Range
239
+ finish = index.end || -1
240
+ if finish == -1 && !index.exclude_end?
241
+ finish = nil
242
+ else
243
+ finish += 1 unless index.exclude_end?
244
+ end
245
+ TensorIndex.slice(index.begin, finish)
246
+ when Tensor
247
+ TensorIndex.tensor(index)
248
+ when nil
249
+ TensorIndex.none
250
+ when true, false
251
+ TensorIndex.boolean(index)
252
+ else
253
+ raise Error, "Unsupported index type: #{index.class.name}"
254
+ end
255
+ end
256
+ end
263
257
  end
264
258
  end