torch-rb 0.3.2 → 0.3.7

Sign up to get free protection for your applications and to get access to all the features.
@@ -6,14 +6,24 @@ module Torch
6
6
  @name = @functions.first.ruby_name
7
7
  @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
8
8
  @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
9
+ @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
9
10
  end
10
11
 
12
+ # TODO improve performance
13
+ # possibly move to C++ (see python_arg_parser.cpp)
11
14
  def parse(args, options)
12
15
  candidates = @functions.dup
13
16
 
14
- # remove nil
15
- while args.any? && args.last.nil?
16
- args.pop
17
+ # TODO check candidates individually to see if they match
18
+ if @int_array_first
19
+ int_args = []
20
+ while args.first.is_a?(Integer)
21
+ int_args << args.shift
22
+ end
23
+ if int_args.any?
24
+ raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
25
+ args.unshift(int_args)
26
+ end
17
27
  end
18
28
 
19
29
  # TODO account for args passed as options here
@@ -25,89 +35,60 @@ module Torch
25
35
 
26
36
  candidates.reject! { |f| args.size > f.args.size }
27
37
 
28
- # exclude functions missing required options
29
- candidates.reject! do |func|
30
- # TODO make more generic
31
- func.out? && !options[:out]
32
- end
33
-
34
38
  # handle out with multiple
35
39
  # there should only be one match, so safe to modify all
36
- out_func = candidates.find { |f| f.out? }
37
- if out_func && out_func.out_size > 1 && options[:out]
38
- out_args = out_func.args.last(2).map { |a| a[:name] }
39
- out_args.zip(options.delete(:out)).each do |k, v|
40
- options[k.to_sym] = v
41
- end
42
- candidates = [out_func]
43
- end
44
-
45
- # exclude functions where options don't match
46
- options.each do |k, v|
47
- candidates.select! do |func|
48
- func.args.any? { |a| a[:name] == k.to_s }
40
+ if options[:out]
41
+ if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
42
+ out_args = out_func.args.last(2).map { |a| a[:name] }
43
+ out_args.zip(options.delete(:out)).each do |k, v|
44
+ options[k] = v
45
+ end
46
+ candidates = [out_func]
49
47
  end
50
- # TODO show all bad keywords at once like Ruby?
51
- return {error: "unknown keyword: #{k}"} if candidates.empty?
48
+ else
49
+ # exclude functions missing required options
50
+ candidates.reject!(&:out?)
52
51
  end
53
52
 
54
- final_values = {}
53
+ final_values = nil
55
54
 
56
55
  # check args
57
- candidates.select! do |func|
56
+ while (func = candidates.shift)
58
57
  good = true
59
58
 
60
- values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
61
- values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
62
- func.args.each do |fa|
63
- values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
59
+ # set values
60
+ # TODO use array instead of hash?
61
+ values = {}
62
+ args.each_with_index do |a, i|
63
+ values[func.arg_names[i]] = a
64
+ end
65
+ options.each do |k, v|
66
+ values[k] = v
67
+ end
68
+ func.arg_defaults.each do |k, v|
69
+ values[k] = v unless values.key?(k)
70
+ end
71
+ func.int_array_lengths.each do |k, len|
72
+ values[k] = [values[k]] * len if values[k].is_a?(Integer)
64
73
  end
65
74
 
66
- arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
75
+ arg_checkers = func.arg_checkers
67
76
 
68
77
  values.each_key do |k|
69
- v = values[k]
70
- t = arg_types[k].split("(").first
71
-
72
- good =
73
- case t
74
- when "Tensor"
75
- v.is_a?(Tensor)
76
- when "Tensor?"
77
- v.nil? || v.is_a?(Tensor)
78
- when "Tensor[]", "Tensor?[]"
79
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
80
- when "int"
81
- if k == "reduction"
82
- v.is_a?(String)
83
- else
84
- v.is_a?(Integer)
85
- end
86
- when "float"
87
- v.is_a?(Numeric)
88
- when /int\[.*\]/
89
- if v.is_a?(Integer)
90
- size = t[4..-2]
91
- raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
92
- v = [v] * size.to_i
93
- values[k] = v
94
- end
95
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
96
- when "Scalar"
97
- v.is_a?(Numeric)
98
- when "ScalarType?"
99
- v.nil?
100
- when "bool"
101
- v == true || v == false
102
- when "str"
103
- v.is_a?(String)
104
- else
105
- raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
78
+ unless arg_checkers.key?(k)
79
+ good = false
80
+ if candidates.empty?
81
+ # TODO show all bad keywords at once like Ruby?
82
+ return {error: "unknown keyword: #{k}"}
106
83
  end
84
+ break
85
+ end
107
86
 
108
- if !good
109
- if candidates.size == 1
110
- k = "input" if k == "self"
87
+ unless arg_checkers[k].call(values[k])
88
+ good = false
89
+ if candidates.empty?
90
+ t = func.arg_types[k]
91
+ k = :input if k == :self
111
92
  return {error: "#{@name}(): argument '#{k}' must be #{t}"}
112
93
  end
113
94
  break
@@ -116,19 +97,19 @@ module Torch
116
97
 
117
98
  if good
118
99
  final_values = values
100
+ break
119
101
  end
120
-
121
- good
122
102
  end
123
103
 
124
- if candidates.size != 1
104
+ unless final_values
125
105
  raise Error, "This should never happen. Please report a bug with #{@name}."
126
106
  end
127
107
 
128
- func = candidates.first
108
+ args = func.arg_names.map { |k| final_values[k] }
109
+ args << TensorOptions.new.dtype(6) if func.tensor_options
129
110
  {
130
111
  name: func.cpp_name,
131
- args: func.args.map { |a| final_values[a[:name]] }
112
+ args: args
132
113
  }
133
114
  end
134
115
  end
@@ -178,8 +178,12 @@ module Torch
178
178
  Torch.hardshrink(input, lambd)
179
179
  end
180
180
 
181
# Leaky ReLU activation.
# Delegates to NN.leaky_relu, or to the bang variant NN.leaky_relu! when
# inplace: true (presumably the in-place kernel — confirm against NN).
def leaky_relu(input, negative_slope = 0.01, inplace: false)
  inplace ? NN.leaky_relu!(input, negative_slope) : NN.leaky_relu(input, negative_slope)
end
184
188
 
185
189
  def log_sigmoid(input)
@@ -465,6 +469,77 @@ module Torch
465
469
  Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
466
470
  end
467
471
 
472
+ # vision
473
+
474
# Resizes `input` to `size` (or by `scale_factor`), dispatching on the input's
# rank and `mode`:
#   "nearest"                         -> NN.upsample_nearest{1,2,3}d
#   "area"                            -> adaptive_avg_pool{1,2,3}d
#   "linear" / "bilinear" / "trilinear" / "bicubic" -> the matching NN.upsample_* kernel
# The output size is computed by _interp_output_size, which also validates that
# exactly one of size / scale_factor is given.
# Raises ArgumentError when align_corners is set for nearest/area, when the rank
# does not fit the mode, or for unsupported rank/mode combinations.
# NOTE(review): the structure appears to mirror PyTorch's F.interpolate.
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
  if %w[nearest area].include?(mode)
    unless align_corners.nil?
      raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
    end
  elsif align_corners.nil?
    # interpolating modes default align_corners to false
    align_corners = false
  end

  spatial_dims = input.dim - 2
  # Per-dimension scale factors handed to the upsample kernels. They stay nil
  # unless a scale_factor was given and recompute_scale_factor is falsy.
  sfl = [nil] * spatial_dims
  if !scale_factor.nil? && !recompute_scale_factor
    sfl = scale_factor.is_a?(Array) ? scale_factor : [scale_factor] * spatial_dims
  end

  output_size = _interp_output_size([input, size, scale_factor, recompute_scale_factor])

  rank = input.dim
  case [rank, mode]
  when [3, "nearest"]
    NN.upsample_nearest1d(input, output_size, sfl[0])
  when [4, "nearest"]
    NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
  when [5, "nearest"]
    NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
  when [3, "area"]
    adaptive_avg_pool1d(input, output_size)
  when [4, "area"]
    adaptive_avg_pool2d(input, output_size)
  when [5, "area"]
    adaptive_avg_pool3d(input, output_size)
  when [3, "linear"]
    NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
  when [4, "bilinear"]
    NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
  when [5, "trilinear"]
    NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
  when [4, "bicubic"]
    NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
  when [3, "bilinear"]
    raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
  when [3, "trilinear"]
    raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
  when [4, "linear"]
    raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
  when [4, "trilinear"]
    raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
  when [5, "linear"]
    raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
  when [5, "bilinear"]
    raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
  else
    raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{rank}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
  end
end
542
+
468
543
  private
469
544
 
470
545
  def softmax_dim(ndim)
@@ -480,6 +555,41 @@ module Torch
480
555
  out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
481
556
  end
482
557
  end
558
+
559
# Computes the spatial output size used by interpolate.
# closed_over_args is [input, size, scale_factor, recompute_scale_factor]; the
# last entry is not used here. Exactly one of size / scale_factor must be
# non-nil, and an Array scale_factor needs one entry per spatial dimension
# (input.dim - 2). Returns an Array: size itself (an Array size is passed
# through unchecked) or size repeated per spatial dimension. Otherwise it
# returns floor(input.size(i + 2) * factor_i) for each spatial index i.
def _interp_output_size(closed_over_args)
  input, size, scale_factor, _recompute_scale_factor = closed_over_args
  spatial_dims = input.dim - 2

  raise ArgumentError, "either size or scale_factor should be defined" if size.nil? && scale_factor.nil?
  raise ArgumentError, "only one of size or scale_factor should be defined" if !size.nil? && !scale_factor.nil?

  if scale_factor.is_a?(Array) && scale_factor.length != spatial_dims
    raise ArgumentError, "scale_factor shape must match input shape. Input is #{spatial_dims}D, scale_factor size is #{scale_factor.length}"
  end

  unless size.nil?
    return size.is_a?(Array) ? size : [size] * spatial_dims
  end

  raise "Failed assertion" if scale_factor.nil?

  factors = scale_factor.is_a?(Array) ? scale_factor : [scale_factor] * spatial_dims
  spatial_dims.times.map { |i| (input.size(i + 2) * factors[i]).floor }
end
483
593
  end
484
594
  end
485
595
 
@@ -1,14 +1,14 @@
1
1
  module Torch
2
2
  module NN
3
3
  class LeakyReLU < Module
4
# negative_slope: slope applied to negative inputs (default 1e-2)
# inplace:        forwarded to F.leaky_relu by #forward
def initialize(negative_slope: 1e-2, inplace: false)
  super()
  @negative_slope, @inplace = negative_slope, inplace
end
9
9
 
10
10
  def forward(input)
11
- F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
11
+ F.leaky_relu(input, @negative_slope, inplace: @inplace)
12
12
  end
13
13
 
14
14
  def extra_inspect
@@ -55,7 +55,15 @@ module Torch
55
55
  end
56
56
  end
57
57
  end
58
- # TODO apply to more objects
58
+
59
+ @buffers.each_key do |k|
60
+ buf = @buffers[k]
61
+ unless buf.nil?
62
+ @buffers[k] = fn.call(buf)
63
+ instance_variable_set("@#{k}", @buffers[k])
64
+ end
65
+ end
66
+
59
67
  self
60
68
  end
61
69
 
@@ -0,0 +1,31 @@
1
module Torch
  module NN
    # Module form of F.interpolate: the resize configuration is captured at
    # construction and applied to every input in #forward.
    class Upsample < Module
      # size:          target output size, stored unchanged
      # scale_factor:  multiplier(s), converted to Float (element-wise for an Array)
      # mode:          interpolation mode string, default "nearest"
      # align_corners: stored unchanged and forwarded to F.interpolate
      def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
        super()
        @size = size
        @scale_factor =
          case scale_factor
          when Array then scale_factor.map(&:to_f)
          when nil, false then nil
          else scale_factor.to_f
          end
        @mode = mode
        @align_corners = align_corners
      end

      def forward(input)
        F.interpolate(
          input,
          size: @size,
          scale_factor: @scale_factor,
          mode: @mode,
          align_corners: @align_corners
        )
      end

      # Extra text for inspect output (presumably consumed by the base Module):
      # the scale factor when one is set, otherwise the size, then the mode.
      def extra_inspect
        target = @scale_factor.nil? ? "size: #{@size.inspect}" : "scale_factor: #{@scale_factor.inspect}"
        "#{target}, mode: #{@mode.inspect}"
      end
    end
  end
end
@@ -48,6 +48,11 @@ module Torch
48
48
  end
49
49
 
50
50
  def to(device = nil, dtype: nil, non_blocking: false, copy: false)
51
+ if device.is_a?(Symbol) && !dtype
52
+ dtype = device
53
+ device = nil
54
+ end
55
+
51
56
  device ||= self.device
52
57
  device = Device.new(device) if device.is_a?(String)
53
58
 
@@ -74,10 +79,6 @@ module Torch
74
79
  end
75
80
  end
76
81
 
77
- def shape
78
- dim.times.map { |i| size(i) }
79
- end
80
-
81
82
  # mirror Python len()
82
83
  def length
83
84
  size(0)
@@ -103,11 +104,6 @@ module Torch
103
104
  Torch.empty(0, dtype: dtype)
104
105
  end
105
106
 
106
- def backward(gradient = nil, retain_graph: nil, create_graph: false)
107
- retain_graph = create_graph if retain_graph.nil?
108
- _backward(gradient, retain_graph, create_graph)
109
- end
110
-
111
107
  # TODO read directly from memory
112
108
  def numo
113
109
  cls = Torch._dtype_to_numo[dtype]
@@ -124,9 +120,14 @@ module Torch
124
120
  end
125
121
 
126
122
  def type(dtype)
127
- enum = DTYPE_TO_ENUM[dtype]
128
- raise Error, "Unknown type: #{dtype}" unless enum
129
- _type(enum)
123
+ if dtype.is_a?(Class)
124
+ raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
125
+ dtype.new(self)
126
+ else
127
+ enum = DTYPE_TO_ENUM[dtype]
128
+ raise Error, "Invalid type: #{dtype}" unless enum
129
+ _type(enum)
130
+ end
130
131
  end
131
132
 
132
133
  def reshape(*size)
@@ -188,49 +189,15 @@ module Torch
188
189
  # based on python_variable_indexing.cpp and
189
190
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
190
191
  def [](*indexes)
191
- result = self
192
- dim = 0
193
- indexes.each do |index|
194
- if index.is_a?(Numeric)
195
- result = result._select_int(dim, index)
196
- elsif index.is_a?(Range)
197
- finish = index.end
198
- finish += 1 unless index.exclude_end?
199
- result = result._slice_tensor(dim, index.begin, finish, 1)
200
- dim += 1
201
- elsif index.is_a?(Tensor)
202
- result = result.index([index])
203
- elsif index.nil?
204
- result = result.unsqueeze(dim)
205
- dim += 1
206
- elsif index == true
207
- result = result.unsqueeze(dim)
208
- # TODO handle false
209
- else
210
- raise Error, "Unsupported index type: #{index.class.name}"
211
- end
212
- end
213
- result
192
+ _index(tensor_indexes(indexes))
214
193
  end
215
194
 
216
195
  # based on python_variable_indexing.cpp and
217
196
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
218
- def []=(index, value)
197
+ def []=(*indexes, value)
219
198
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
220
-
221
199
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
222
-
223
- if index.is_a?(Numeric)
224
- index_put!([Torch.tensor(index)], value)
225
- elsif index.is_a?(Range)
226
- finish = index.end
227
- finish += 1 unless index.exclude_end?
228
- _slice_tensor(0, index.begin, finish, 1).copy!(value)
229
- elsif index.is_a?(Tensor)
230
- index_put!([index], value)
231
- else
232
- raise Error, "Unsupported index type: #{index.class.name}"
233
- end
200
+ _index_put_custom(tensor_indexes(indexes), value)
234
201
  end
235
202
 
236
203
  # native functions that need manually defined
@@ -244,13 +211,13 @@ module Torch
244
211
  end
245
212
  end
246
213
 
247
- # native functions overlap, so need to handle manually
214
+ # parser can't handle overlap, so need to handle manually
248
215
  def random!(*args)
249
216
  case args.size
250
217
  when 1
251
218
  _random__to(*args)
252
219
  when 2
253
- _random__from_to(*args)
220
+ _random__from(*args)
254
221
  else
255
222
  _random_(*args)
256
223
  end
@@ -260,5 +227,32 @@ module Torch
260
227
  _clamp_min_(min)
261
228
  _clamp_max_(max)
262
229
  end
230
+
231
+ private
232
+
233
# Converts the Ruby index arguments given to Tensor#[] / Tensor#[]= into
# TensorIndex values:
#   Integer      -> TensorIndex.integer
#   Range        -> TensorIndex.slice(begin, exclusive_end), where a nil end
#                   means "through the last element"
#   Tensor       -> TensorIndex.tensor
#   nil          -> TensorIndex.none
#   true / false -> TensorIndex.boolean
# Raises Error for any other index type.
def tensor_indexes(indexes)
  indexes.map do |index|
    case index
    when Integer
      TensorIndex.integer(index)
    when Range
      finish = index.end
      if finish.nil? || (finish == -1 && !index.exclude_end?)
        # Endless ranges (1.. and 1...) and inclusive ranges ending at -1 both
        # run through the last element. An open (nil) slice end expresses
        # that. Previously `1...` mapped to -1 and silently dropped the last
        # element.
        finish = nil
      elsif !index.exclude_end?
        # slice ends are exclusive, so an inclusive range end moves up by one
        finish += 1
      end
      TensorIndex.slice(index.begin, finish)
    when Tensor
      TensorIndex.tensor(index)
    when nil
      TensorIndex.none
    when true, false
      TensorIndex.boolean(index)
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
end
263
257
  end
264
258
  end