torch-rb 0.3.2 → 0.3.7

Sign up to get free protection for your applications and to get access to all the features.
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
174
174
  require "torch/nn/soft_margin_loss"
175
175
  require "torch/nn/triplet_margin_loss"
176
176
 
177
+ # nn vision
178
+ require "torch/nn/upsample"
179
+
177
180
  # nn other
178
181
  require "torch/nn/functional"
179
182
  require "torch/nn/init"
@@ -196,6 +199,32 @@ module Torch
196
199
  end
197
200
  end
198
201
 
202
+ # legacy
203
+ # but may make it easier to port tutorials
204
# Compatibility shim for code written against the old Variable wrapper.
module Autograd
  class Variable
    class << self
      # Checks that +x+ is a Tensor, prints a deprecation warning and hands
      # the very same object back; no wrapper instance is ever created.
      #
      # @param x [Tensor] the tensor that would have been wrapped
      # @return [Tensor] +x+ itself
      # @raise [ArgumentError] when +x+ is not a Tensor
      def new(x)
        unless x.is_a?(Tensor)
          raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}"
        end

        warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
        x
      end
    end
  end
end
213
+
214
+ # TODO move to C++
215
# Minimal holder for a raw byte buffer. The buffer is stored exactly as
# given (no copy, no validation) and exposed through #bytes.
class ByteStorage
  class << self
    # Builds a storage around +bytes+; same as calling ByteStorage.new.
    def from_buffer(bytes)
      new(bytes)
    end
  end

  # @param bytes the buffer to keep a reference to
  def initialize(bytes)
    @bytes = bytes
  end

  # private
  # Reader for the wrapped buffer.
  def bytes
    @bytes
  end
end
227
+
199
228
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
200
229
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
201
230
  DTYPE_TO_ENUM = {
@@ -224,40 +253,43 @@ module Torch
224
253
  }
225
254
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
226
255
 
256
+ TENSOR_TYPE_CLASSES = []
257
+
227
258
  def self._make_tensor_class(dtype, cuda = false)
228
259
  cls = Class.new
229
260
  device = cuda ? "cuda" : "cpu"
230
261
  cls.define_singleton_method("new") do |*args|
231
262
  if args.size == 1 && args.first.is_a?(Tensor)
232
263
  args.first.send(dtype).to(device)
264
+ elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
265
+ bytes = args.first.bytes
266
+ Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
233
267
  elsif args.size == 1 && args.first.is_a?(Array)
234
268
  Torch.tensor(args.first, dtype: dtype, device: device)
235
269
  else
236
270
  Torch.empty(*args, dtype: dtype, device: device)
237
271
  end
238
272
  end
273
+ TENSOR_TYPE_CLASSES << cls
239
274
  cls
240
275
  end
241
276
 
242
- FloatTensor = _make_tensor_class(:float32)
243
- DoubleTensor = _make_tensor_class(:float64)
244
- HalfTensor = _make_tensor_class(:float16)
245
- ByteTensor = _make_tensor_class(:uint8)
246
- CharTensor = _make_tensor_class(:int8)
247
- ShortTensor = _make_tensor_class(:int16)
248
- IntTensor = _make_tensor_class(:int32)
249
- LongTensor = _make_tensor_class(:int64)
250
- BoolTensor = _make_tensor_class(:bool)
251
-
252
- CUDA::FloatTensor = _make_tensor_class(:float32, true)
253
- CUDA::DoubleTensor = _make_tensor_class(:float64, true)
254
- CUDA::HalfTensor = _make_tensor_class(:float16, true)
255
- CUDA::ByteTensor = _make_tensor_class(:uint8, true)
256
- CUDA::CharTensor = _make_tensor_class(:int8, true)
257
- CUDA::ShortTensor = _make_tensor_class(:int16, true)
258
- CUDA::IntTensor = _make_tensor_class(:int32, true)
259
- CUDA::LongTensor = _make_tensor_class(:int64, true)
260
- CUDA::BoolTensor = _make_tensor_class(:bool, true)
277
# Constant name of the typed tensor class generated for each dtype.
DTYPE_TO_CLASS = {
  float32: "FloatTensor",
  float64: "DoubleTensor",
  float16: "HalfTensor",
  uint8: "ByteTensor",
  int8: "CharTensor",
  int16: "ShortTensor",
  int32: "IntTensor",
  int64: "LongTensor",
  bool: "BoolTensor"
}

# For every dtype, register a CPU class on the enclosing module and a CUDA
# class under CUDA (CPU class is created first, then the CUDA one).
DTYPE_TO_CLASS.each_pair do |dtype, class_name|
  cpu_class = _make_tensor_class(dtype)
  const_set(class_name, cpu_class)

  cuda_class = _make_tensor_class(dtype, true)
  CUDA.const_set(class_name, cuda_class)
end
261
293
 
262
294
  class << self
263
295
  # Torch.float, Torch.long, etc
@@ -388,6 +420,10 @@ module Torch
388
420
  end
389
421
 
390
422
  def randperm(n, **options)
423
+ # dtype hack in Python
424
+ # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
425
+ options[:dtype] ||= :int64
426
+
391
427
  _randperm(n, tensor_options(**options))
392
428
  end
393
429
 
@@ -460,6 +496,22 @@ module Torch
460
496
  zeros(input.size, **like_options(input, options))
461
497
  end
462
498
 
499
# Forwards to _stft, optionally padding the input first.
#
# When +center+ is truthy the input is viewed with leading singleton
# dimensions (up to 3 dimensions in total), padded by n_fft.div(2) on both
# sides of the last dimension using +pad_mode+, and then viewed back to its
# original number of dimensions.
#
# +hop_length+, +win_length+, +window+, +normalized+ and +onesided+ are
# passed to _stft unchanged.
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
  if center
    rank = input.dim
    # prepend 1s so the tensor has 3 dimensions for the padding call
    batched_shape = [1] * (3 - rank) + input.size
    half_window = n_fft.div(2).to_i
    padded = NN::F.pad(input.view(batched_shape), [half_window, half_window], mode: pad_mode)
    # drop the singleton dimensions that were added above
    input = padded.view(padded.shape[-rank..-1])
  end

  _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
end
509
+
510
# Applies _clamp_min with +min+ and then _clamp_max with +max+ to the
# result, returning whatever _clamp_max returns.
def clamp(tensor, min, max)
  lower_bounded = _clamp_min(tensor, min)
  _clamp_max(lower_bounded, max)
end
514
+
463
515
  private
464
516
 
465
517
  def to_ivalue(obj)
@@ -22,21 +22,43 @@ module Torch
22
22
  end
23
23
 
24
24
  # Defines one Ruby method on +context+ per group of native functions that
  # share a ruby_name (groups are processed in name order).
  #
  # +def_method+ is the message used to define the method; when it is
  # :define_method the functions are re-wrapped in fresh Function objects and
  # their :self argument is dropped. Names that +context+ already defines are
  # skipped, except "clone".
  #
  # Three binding strategies are used (new side of the diff):
  # * one overload without arguments -> alias of its cpp_name
  # * two overloads taking exactly one Tensor / one Scalar -> dispatch on
  #   whether the argument is a Tensor
  # * everything else -> Parser-based dispatch; a parse error is raised as
  #   ArgumentError
  def bind_functions(context, def_method, functions)
25
+ instance_method = def_method == :define_method
25
26
  functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
26
- if def_method == :define_method
27
+ if instance_method
27
28
  funcs.map! { |f| Function.new(f.function) }
28
- funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
29
+ funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
29
30
  end
30
31
 
31
- defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
32
+ defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
32
33
  next if defined && name != "clone"
33
34
 
34
- parser = Parser.new(funcs)
35
+ # skip parser when possible for performance
36
+ if funcs.size == 1 && funcs.first.args.size == 0
37
+ # functions with no arguments
38
+ if instance_method
39
+ context.send(:alias_method, name, funcs.first.cpp_name)
40
+ else
41
+ context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
42
+ end
43
+ elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
44
+ # functions that take a tensor or scalar
45
+ scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
46
+ context.send(def_method, name) do |other|
47
+ case other
48
+ when Tensor
49
+ send(tensor_name, other)
50
+ else
51
+ send(scalar_name, other)
52
+ end
53
+ end
54
+ else
55
+ parser = Parser.new(funcs)
35
56
 
36
- context.send(def_method, name) do |*args, **options|
37
- result = parser.parse(args, options)
38
- raise ArgumentError, result[:error] if result[:error]
39
- send(result[:name], *result[:args])
57
+ context.send(def_method, name) do |*args, **options|
58
+ result = parser.parse(args, options)
59
+ raise ArgumentError, result[:error] if result[:error]
60
+ send(result[:name], *result[:args])
61
+ end
40
62
  end
41
63
  end
42
64
  end
@@ -1,10 +1,15 @@
1
1
  module Torch
2
2
  module Native
3
3
  class Function
4
- attr_reader :function
4
+ attr_reader :function, :tensor_options
5
5
 
6
6
  def initialize(function)
7
7
  @function = function
8
+
9
+ # note: don't modify function in-place
10
+ @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
11
+ @tensor_options = @function["func"].include?(@tensor_options_str)
12
+ @out = out_size > 0 && base_name[-1] != "_"
8
13
  end
9
14
 
10
15
  def func
@@ -27,7 +32,7 @@ module Torch
27
32
  @args ||= begin
28
33
  args = []
29
34
  pos = true
30
- args_str = func.split("(", 2).last.split(") ->").first
35
+ args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
31
36
  args_str.split(", ").each do |a|
32
37
  if a == "*"
33
38
  pos = false
@@ -68,12 +73,88 @@ module Torch
68
73
  next if t == "Generator?"
69
74
  next if t == "MemoryFormat"
70
75
  next if t == "MemoryFormat?"
71
- args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
76
+ args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
72
77
  end
73
78
  args
74
79
  end
75
80
  end
76
81
 
82
# Builds (once) a hash from argument name to a lambda that tells whether a
# Ruby value is acceptable for that argument's declared type.
#
# Types are taken from arg_types. Special cases: an "int" argument named
# :reduction is checked as a String, "ScalarType" is never accepted (not
# supported yet) and "ScalarType?" only accepts nil.
#
# @return [Hash] argument name => one-argument lambda returning true/false
# @raise [Error] when an argument has a type that is not handled here
def arg_checkers
  @arg_checkers ||= arg_types.each_with_object({}) do |(k, t), checkers|
    checkers[k] =
      case t
      when "Tensor"
        ->(v) { v.is_a?(Tensor) }
      when "Tensor?"
        ->(v) { v.nil? || v.is_a?(Tensor) }
      when "Tensor[]", "Tensor?[]"
        ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
      when "int"
        if k == :reduction
          ->(v) { v.is_a?(String) }
        else
          ->(v) { v.is_a?(Integer) }
        end
      when "int?"
        ->(v) { v.is_a?(Integer) || v.nil? }
      when "float?"
        ->(v) { v.is_a?(Numeric) || v.nil? }
      when "bool?"
        ->(v) { v == true || v == false || v.nil? }
      when "float"
        ->(v) { v.is_a?(Numeric) }
      when /int\[.*\]/
        ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
      when "Scalar"
        ->(v) { v.is_a?(Numeric) }
      when "Scalar?"
        ->(v) { v.is_a?(Numeric) || v.nil? }
      when "ScalarType"
        ->(_v) { false } # not supported yet
      when "ScalarType?"
        ->(v) { v.nil? }
      when "bool"
        ->(v) { v == true || v == false }
      when "str"
        ->(v) { v.is_a?(String) }
      else
        # Go through the name reader instead of reading @name directly, so
        # the message always identifies the function being reported.
        raise Error, "Unknown argument type: #{t}. Please report a bug with #{name}."
      end
  end
end
130
+
131
# Builds (once) a hash from argument name to the declared length of its
# fixed-size int array type, e.g. "int[2]" => 2. Arguments whose type is
# not of the form int[...] with something between the brackets are left out.
#
# @return [Hash] argument name => Integer length
# @raise [Error] when the text between the brackets is not a plain number
def int_array_lengths
  @int_array_lengths ||= begin
    ret = {}
    arg_types.each do |k, t|
      next unless t.match?(/\Aint\[.+\]\z/)

      # strip the leading "int[" and the trailing "]"
      size = t[4..-2]
      # Go through the name reader instead of reading @name directly, so
      # the message always identifies the function being reported.
      raise Error, "Unknown size: #{size}. Please report a bug with #{name}." unless size.match?(/\A\d+\z/)

      ret[k] = size.to_i
    end
    ret
  end
end
144
+
145
# Names of all arguments, in the order given by args (computed once).
def arg_names
  return @arg_names if @arg_names

  @arg_names = args.map { |arg| arg[:name] }
end
148
+
149
# Hash from argument name to its type with anything from the first "("
# onwards cut off (computed once).
def arg_types
  @arg_types ||= args.each_with_object({}) do |arg, types|
    types[arg[:name]] = arg[:type].split("(").first
  end
end
152
+
153
# Hash from argument name to its :default entry (computed once).
def arg_defaults
  # TODO find out why can't use select here
  @arg_defaults ||= args.each_with_object({}) do |arg, defaults|
    defaults[arg[:name]] = arg[:default]
  end
end
157
+
77
158
  def out_size
78
159
  @out_size ||= func.split("->").last.count("!")
79
160
  end
@@ -82,8 +163,16 @@ module Torch
82
163
  @ret_size ||= func.split("->").last.split(", ").size
83
164
  end
84
165
 
166
# Whether the return part of the signature (the text after the last "->")
# contains "[]".
#
# The answer is cached with a defined? guard: `||=` would never cache a
# false result and would split the signature again on every call.
#
# @return [Boolean]
def ret_array?
  return @ret_array if defined?(@ret_array)

  @ret_array = func.split("->").last.include?('[]')
end
169
+
170
# Whether the return part of the signature (the text after the last "->"),
# with surrounding whitespace removed, is exactly "()".
def ret_void?
  return_spec = func.split("->").last
  return_spec.strip == "()"
end
173
+
85
174
  def out?
86
- out_size > 0 && base_name[-1] != "_"
175
+ @out
87
176
  end
88
177
 
89
178
  def ruby_name
@@ -18,12 +18,12 @@ module Torch
18
18
  functions = functions()
19
19
 
20
20
  # skip functions
21
- skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
21
+ skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
22
22
 
23
23
  # remove functions
24
24
  functions.reject! do |f|
25
25
  f.ruby_name.start_with?("_") ||
26
- f.ruby_name.end_with?("_backward") ||
26
+ f.ruby_name.include?("_backward") ||
27
27
  f.args.any? { |a| a[:type].include?("Dimname") }
28
28
  end
29
29
 
@@ -31,32 +31,15 @@ module Torch
31
31
  todo_functions, functions =
32
32
  functions.partition do |f|
33
33
  f.args.any? do |a|
34
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
35
34
  skip_args.any? { |sa| a[:type].include?(sa) } ||
35
+ # call to 'range' is ambiguous
36
+ f.cpp_name == "_range" ||
36
37
  # native_functions.yaml is missing size argument for normal
37
38
  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
38
39
  (f.base_name == "normal" && !f.out?)
39
40
  end
40
41
  end
41
42
 
42
- # generate additional functions for optional arguments
43
- # there may be a better way to do this
44
- optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
45
- optional_functions.each do |f|
46
- next if f.ruby_name == "cross"
47
- next if f.ruby_name.start_with?("avg_pool") && f.out?
48
-
49
- opt_args = f.args.select { |a| a[:type] == "int?" }
50
- if opt_args.size == 1
51
- sep = f.name.include?(".") ? "_" : "."
52
- f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
53
- # TODO only remove some arguments
54
- f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
55
- functions << f1
56
- functions << f2
57
- end
58
- end
59
-
60
43
  # todo_functions.each do |f|
61
44
  # puts f.func
62
45
  # puts
@@ -89,15 +72,18 @@ void add_%{type}_functions(Module m);
89
72
  #include <rice/Module.hpp>
90
73
  #include "templates.hpp"
91
74
 
75
+ %{functions}
76
+
92
77
  void add_%{type}_functions(Module m) {
93
- m
94
- %{functions};
78
+ %{add_functions}
95
79
  }
96
80
  TEMPLATE
97
81
 
98
82
  cpp_defs = []
83
+ add_defs = []
99
84
  functions.sort_by(&:cpp_name).each do |func|
100
- fargs = func.args #.select { |a| a[:type] != "Generator?" }
85
+ fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
86
+ fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
101
87
 
102
88
  cpp_args = []
103
89
  fargs.each do |a|
@@ -109,52 +95,70 @@ void add_%{type}_functions(Module m) {
109
95
  # TODO better signature
110
96
  "OptionalTensor"
111
97
  when "ScalarType?"
112
- "OptionalScalarType"
113
- when "Tensor[]"
114
- "TensorList"
115
- when "Tensor?[]"
98
+ "torch::optional<ScalarType>"
99
+ when "Tensor[]", "Tensor?[]"
116
100
  # TODO make optional
117
- "TensorList"
101
+ "std::vector<Tensor>"
118
102
  when "int"
119
103
  "int64_t"
104
+ when "int?"
105
+ "torch::optional<int64_t>"
106
+ when "float?"
107
+ "torch::optional<double>"
108
+ when "bool?"
109
+ "torch::optional<bool>"
110
+ when "Scalar?"
111
+ "torch::optional<torch::Scalar>"
120
112
  when "float"
121
113
  "double"
122
114
  when /\Aint\[/
123
- "IntArrayRef"
115
+ "std::vector<int64_t>"
124
116
  when /Tensor\(\S!?\)/
125
117
  "Tensor &"
126
118
  when "str"
127
119
  "std::string"
128
- else
120
+ when "TensorOptions"
121
+ "const torch::TensorOptions &"
122
+ when "Layout?"
123
+ "torch::optional<Layout>"
124
+ when "Device?"
125
+ "torch::optional<Device>"
126
+ when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
129
127
  a[:type]
128
+ else
129
+ raise "Unknown type: #{a[:type]}"
130
130
  end
131
131
 
132
- t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
132
+ t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
133
133
  cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
134
134
  end
135
135
 
136
136
  dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
137
137
  args = fargs.map { |a| a[:name] }
138
138
  args.unshift(*args.pop(func.out_size)) if func.out?
139
- args.delete("self") if def_method == :define_method
139
+ args.delete(:self) if def_method == :define_method
140
140
 
141
141
  prefix = def_method == :define_method ? "self." : "torch::"
142
142
 
143
143
  body = "#{prefix}#{dispatch}(#{args.join(", ")})"
144
- # TODO check type as well
145
- if func.ret_size > 1
144
+
145
+ if func.cpp_name == "_fill_diagonal_"
146
+ body = "to_ruby<torch::Tensor>(#{body})"
147
+ elsif !func.ret_void?
146
148
  body = "wrap(#{body})"
147
149
  end
148
150
 
149
- cpp_defs << ".#{def_method}(
150
- \"#{func.cpp_name}\",
151
- *[](#{cpp_args.join(", ")}) {
152
- return #{body};
153
- })"
151
+ cpp_defs << "// #{func.func}
152
+ static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
153
+ {
154
+ return #{body};
155
+ }"
156
+
157
+ add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
154
158
  end
155
159
 
156
160
  hpp_contents = hpp_template % {type: type}
157
- cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
161
+ cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
158
162
 
159
163
  path = File.expand_path("../../../ext/torch", __dir__)
160
164
  File.write("#{path}/#{type}_functions.hpp", hpp_contents)