torch-rb 0.3.2 → 0.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
174
174
  require "torch/nn/soft_margin_loss"
175
175
  require "torch/nn/triplet_margin_loss"
176
176
 
177
+ # nn vision
178
+ require "torch/nn/upsample"
179
+
177
180
  # nn other
178
181
  require "torch/nn/functional"
179
182
  require "torch/nn/init"
@@ -196,6 +199,32 @@ module Torch
196
199
  end
197
200
  end
198
201
 
202
+ # legacy
203
+ # but may make it easier to port tutorials
204
+ module Autograd
205
+ class Variable
206
+ def self.new(x)
207
+ raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
208
+ warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
209
+ x
210
+ end
211
+ end
212
+ end
213
+
214
+ # TODO move to C++
215
+ class ByteStorage
216
+ # private
217
+ attr_reader :bytes
218
+
219
+ def initialize(bytes)
220
+ @bytes = bytes
221
+ end
222
+
223
+ def self.from_buffer(bytes)
224
+ new(bytes)
225
+ end
226
+ end
227
+
199
228
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
200
229
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
201
230
  DTYPE_TO_ENUM = {
@@ -224,40 +253,43 @@ module Torch
224
253
  }
225
254
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
226
255
 
256
+ TENSOR_TYPE_CLASSES = []
257
+
227
258
  def self._make_tensor_class(dtype, cuda = false)
228
259
  cls = Class.new
229
260
  device = cuda ? "cuda" : "cpu"
230
261
  cls.define_singleton_method("new") do |*args|
231
262
  if args.size == 1 && args.first.is_a?(Tensor)
232
263
  args.first.send(dtype).to(device)
264
+ elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
265
+ bytes = args.first.bytes
266
+ Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
233
267
  elsif args.size == 1 && args.first.is_a?(Array)
234
268
  Torch.tensor(args.first, dtype: dtype, device: device)
235
269
  else
236
270
  Torch.empty(*args, dtype: dtype, device: device)
237
271
  end
238
272
  end
273
+ TENSOR_TYPE_CLASSES << cls
239
274
  cls
240
275
  end
241
276
 
242
- FloatTensor = _make_tensor_class(:float32)
243
- DoubleTensor = _make_tensor_class(:float64)
244
- HalfTensor = _make_tensor_class(:float16)
245
- ByteTensor = _make_tensor_class(:uint8)
246
- CharTensor = _make_tensor_class(:int8)
247
- ShortTensor = _make_tensor_class(:int16)
248
- IntTensor = _make_tensor_class(:int32)
249
- LongTensor = _make_tensor_class(:int64)
250
- BoolTensor = _make_tensor_class(:bool)
251
-
252
- CUDA::FloatTensor = _make_tensor_class(:float32, true)
253
- CUDA::DoubleTensor = _make_tensor_class(:float64, true)
254
- CUDA::HalfTensor = _make_tensor_class(:float16, true)
255
- CUDA::ByteTensor = _make_tensor_class(:uint8, true)
256
- CUDA::CharTensor = _make_tensor_class(:int8, true)
257
- CUDA::ShortTensor = _make_tensor_class(:int16, true)
258
- CUDA::IntTensor = _make_tensor_class(:int32, true)
259
- CUDA::LongTensor = _make_tensor_class(:int64, true)
260
- CUDA::BoolTensor = _make_tensor_class(:bool, true)
277
+ DTYPE_TO_CLASS = {
278
+ float32: "FloatTensor",
279
+ float64: "DoubleTensor",
280
+ float16: "HalfTensor",
281
+ uint8: "ByteTensor",
282
+ int8: "CharTensor",
283
+ int16: "ShortTensor",
284
+ int32: "IntTensor",
285
+ int64: "LongTensor",
286
+ bool: "BoolTensor"
287
+ }
288
+
289
+ DTYPE_TO_CLASS.each do |dtype, class_name|
290
+ const_set(class_name, _make_tensor_class(dtype))
291
+ CUDA.const_set(class_name, _make_tensor_class(dtype, true))
292
+ end
261
293
 
262
294
  class << self
263
295
  # Torch.float, Torch.long, etc
@@ -388,6 +420,10 @@ module Torch
388
420
  end
389
421
 
390
422
  def randperm(n, **options)
423
+ # dtype hack in Python
424
+ # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
425
+ options[:dtype] ||= :int64
426
+
391
427
  _randperm(n, tensor_options(**options))
392
428
  end
393
429
 
@@ -460,6 +496,22 @@ module Torch
460
496
  zeros(input.size, **like_options(input, options))
461
497
  end
462
498
 
499
+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
500
+ if center
501
+ signal_dim = input.dim
502
+ extended_shape = [1] * (3 - signal_dim) + input.size
503
+ pad = n_fft.div(2).to_i
504
+ input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
505
+ input = input.view(input.shape[-signal_dim..-1])
506
+ end
507
+ _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
508
+ end
509
+
510
+ def clamp(tensor, min, max)
511
+ tensor = _clamp_min(tensor, min)
512
+ _clamp_max(tensor, max)
513
+ end
514
+
463
515
  private
464
516
 
465
517
  def to_ivalue(obj)
@@ -22,21 +22,43 @@ module Torch
22
22
  end
23
23
 
24
24
  def bind_functions(context, def_method, functions)
25
+ instance_method = def_method == :define_method
25
26
  functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
26
- if def_method == :define_method
27
+ if instance_method
27
28
  funcs.map! { |f| Function.new(f.function) }
28
- funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
29
+ funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
29
30
  end
30
31
 
31
- defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
32
+ defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
32
33
  next if defined && name != "clone"
33
34
 
34
- parser = Parser.new(funcs)
35
+ # skip parser when possible for performance
36
+ if funcs.size == 1 && funcs.first.args.size == 0
37
+ # functions with no arguments
38
+ if instance_method
39
+ context.send(:alias_method, name, funcs.first.cpp_name)
40
+ else
41
+ context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
42
+ end
43
+ elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
44
+ # functions that take a tensor or scalar
45
+ scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
46
+ context.send(def_method, name) do |other|
47
+ case other
48
+ when Tensor
49
+ send(tensor_name, other)
50
+ else
51
+ send(scalar_name, other)
52
+ end
53
+ end
54
+ else
55
+ parser = Parser.new(funcs)
35
56
 
36
- context.send(def_method, name) do |*args, **options|
37
- result = parser.parse(args, options)
38
- raise ArgumentError, result[:error] if result[:error]
39
- send(result[:name], *result[:args])
57
+ context.send(def_method, name) do |*args, **options|
58
+ result = parser.parse(args, options)
59
+ raise ArgumentError, result[:error] if result[:error]
60
+ send(result[:name], *result[:args])
61
+ end
40
62
  end
41
63
  end
42
64
  end
@@ -1,10 +1,15 @@
1
1
  module Torch
2
2
  module Native
3
3
  class Function
4
- attr_reader :function
4
+ attr_reader :function, :tensor_options
5
5
 
6
6
  def initialize(function)
7
7
  @function = function
8
+
9
+ # note: don't modify function in-place
10
+ @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
11
+ @tensor_options = @function["func"].include?(@tensor_options_str)
12
+ @out = out_size > 0 && base_name[-1] != "_"
8
13
  end
9
14
 
10
15
  def func
@@ -27,7 +32,7 @@ module Torch
27
32
  @args ||= begin
28
33
  args = []
29
34
  pos = true
30
- args_str = func.split("(", 2).last.split(") ->").first
35
+ args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
31
36
  args_str.split(", ").each do |a|
32
37
  if a == "*"
33
38
  pos = false
@@ -68,12 +73,88 @@ module Torch
68
73
  next if t == "Generator?"
69
74
  next if t == "MemoryFormat"
70
75
  next if t == "MemoryFormat?"
71
- args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
76
+ args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
72
77
  end
73
78
  args
74
79
  end
75
80
  end
76
81
 
82
+ def arg_checkers
83
+ @arg_checkers ||= begin
84
+ checkers = {}
85
+ arg_types.each do |k, t|
86
+ checker =
87
+ case t
88
+ when "Tensor"
89
+ ->(v) { v.is_a?(Tensor) }
90
+ when "Tensor?"
91
+ ->(v) { v.nil? || v.is_a?(Tensor) }
92
+ when "Tensor[]", "Tensor?[]"
93
+ ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
94
+ when "int"
95
+ if k == :reduction
96
+ ->(v) { v.is_a?(String) }
97
+ else
98
+ ->(v) { v.is_a?(Integer) }
99
+ end
100
+ when "int?"
101
+ ->(v) { v.is_a?(Integer) || v.nil? }
102
+ when "float?"
103
+ ->(v) { v.is_a?(Numeric) || v.nil? }
104
+ when "bool?"
105
+ ->(v) { v == true || v == false || v.nil? }
106
+ when "float"
107
+ ->(v) { v.is_a?(Numeric) }
108
+ when /int\[.*\]/
109
+ ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
110
+ when "Scalar"
111
+ ->(v) { v.is_a?(Numeric) }
112
+ when "Scalar?"
113
+ ->(v) { v.is_a?(Numeric) || v.nil? }
114
+ when "ScalarType"
115
+ ->(v) { false } # not supported yet
116
+ when "ScalarType?"
117
+ ->(v) { v.nil? }
118
+ when "bool"
119
+ ->(v) { v == true || v == false }
120
+ when "str"
121
+ ->(v) { v.is_a?(String) }
122
+ else
123
+ raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
124
+ end
125
+ checkers[k] = checker
126
+ end
127
+ checkers
128
+ end
129
+ end
130
+
131
+ def int_array_lengths
132
+ @int_array_lengths ||= begin
133
+ ret = {}
134
+ arg_types.each do |k, t|
135
+ if t.match?(/\Aint\[.+\]\z/)
136
+ size = t[4..-2]
137
+ raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
138
+ ret[k] = size.to_i
139
+ end
140
+ end
141
+ ret
142
+ end
143
+ end
144
+
145
+ def arg_names
146
+ @arg_names ||= args.map { |a| a[:name] }
147
+ end
148
+
149
+ def arg_types
150
+ @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
151
+ end
152
+
153
+ def arg_defaults
154
+ # TODO find out why can't use select here
155
+ @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
156
+ end
157
+
77
158
  def out_size
78
159
  @out_size ||= func.split("->").last.count("!")
79
160
  end
@@ -82,8 +163,16 @@ module Torch
82
163
  @ret_size ||= func.split("->").last.split(", ").size
83
164
  end
84
165
 
166
+ def ret_array?
167
+ @ret_array ||= func.split("->").last.include?('[]')
168
+ end
169
+
170
+ def ret_void?
171
+ func.split("->").last.strip == "()"
172
+ end
173
+
85
174
  def out?
86
- out_size > 0 && base_name[-1] != "_"
175
+ @out
87
176
  end
88
177
 
89
178
  def ruby_name
@@ -18,12 +18,12 @@ module Torch
18
18
  functions = functions()
19
19
 
20
20
  # skip functions
21
- skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
21
+ skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
22
22
 
23
23
  # remove functions
24
24
  functions.reject! do |f|
25
25
  f.ruby_name.start_with?("_") ||
26
- f.ruby_name.end_with?("_backward") ||
26
+ f.ruby_name.include?("_backward") ||
27
27
  f.args.any? { |a| a[:type].include?("Dimname") }
28
28
  end
29
29
 
@@ -31,32 +31,15 @@ module Torch
31
31
  todo_functions, functions =
32
32
  functions.partition do |f|
33
33
  f.args.any? do |a|
34
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
35
34
  skip_args.any? { |sa| a[:type].include?(sa) } ||
35
+ # call to 'range' is ambiguous
36
+ f.cpp_name == "_range" ||
36
37
  # native_functions.yaml is missing size argument for normal
37
38
  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
38
39
  (f.base_name == "normal" && !f.out?)
39
40
  end
40
41
  end
41
42
 
42
- # generate additional functions for optional arguments
43
- # there may be a better way to do this
44
- optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
45
- optional_functions.each do |f|
46
- next if f.ruby_name == "cross"
47
- next if f.ruby_name.start_with?("avg_pool") && f.out?
48
-
49
- opt_args = f.args.select { |a| a[:type] == "int?" }
50
- if opt_args.size == 1
51
- sep = f.name.include?(".") ? "_" : "."
52
- f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
53
- # TODO only remove some arguments
54
- f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
55
- functions << f1
56
- functions << f2
57
- end
58
- end
59
-
60
43
  # todo_functions.each do |f|
61
44
  # puts f.func
62
45
  # puts
@@ -89,15 +72,18 @@ void add_%{type}_functions(Module m);
89
72
  #include <rice/Module.hpp>
90
73
  #include "templates.hpp"
91
74
 
75
+ %{functions}
76
+
92
77
  void add_%{type}_functions(Module m) {
93
- m
94
- %{functions};
78
+ %{add_functions}
95
79
  }
96
80
  TEMPLATE
97
81
 
98
82
  cpp_defs = []
83
+ add_defs = []
99
84
  functions.sort_by(&:cpp_name).each do |func|
100
- fargs = func.args #.select { |a| a[:type] != "Generator?" }
85
+ fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
86
+ fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
101
87
 
102
88
  cpp_args = []
103
89
  fargs.each do |a|
@@ -109,52 +95,70 @@ void add_%{type}_functions(Module m) {
109
95
  # TODO better signature
110
96
  "OptionalTensor"
111
97
  when "ScalarType?"
112
- "OptionalScalarType"
113
- when "Tensor[]"
114
- "TensorList"
115
- when "Tensor?[]"
98
+ "torch::optional<ScalarType>"
99
+ when "Tensor[]", "Tensor?[]"
116
100
  # TODO make optional
117
- "TensorList"
101
+ "std::vector<Tensor>"
118
102
  when "int"
119
103
  "int64_t"
104
+ when "int?"
105
+ "torch::optional<int64_t>"
106
+ when "float?"
107
+ "torch::optional<double>"
108
+ when "bool?"
109
+ "torch::optional<bool>"
110
+ when "Scalar?"
111
+ "torch::optional<torch::Scalar>"
120
112
  when "float"
121
113
  "double"
122
114
  when /\Aint\[/
123
- "IntArrayRef"
115
+ "std::vector<int64_t>"
124
116
  when /Tensor\(\S!?\)/
125
117
  "Tensor &"
126
118
  when "str"
127
119
  "std::string"
128
- else
120
+ when "TensorOptions"
121
+ "const torch::TensorOptions &"
122
+ when "Layout?"
123
+ "torch::optional<Layout>"
124
+ when "Device?"
125
+ "torch::optional<Device>"
126
+ when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
129
127
  a[:type]
128
+ else
129
+ raise "Unknown type: #{a[:type]}"
130
130
  end
131
131
 
132
- t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
132
+ t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
133
133
  cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
134
134
  end
135
135
 
136
136
  dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
137
137
  args = fargs.map { |a| a[:name] }
138
138
  args.unshift(*args.pop(func.out_size)) if func.out?
139
- args.delete("self") if def_method == :define_method
139
+ args.delete(:self) if def_method == :define_method
140
140
 
141
141
  prefix = def_method == :define_method ? "self." : "torch::"
142
142
 
143
143
  body = "#{prefix}#{dispatch}(#{args.join(", ")})"
144
- # TODO check type as well
145
- if func.ret_size > 1
144
+
145
+ if func.cpp_name == "_fill_diagonal_"
146
+ body = "to_ruby<torch::Tensor>(#{body})"
147
+ elsif !func.ret_void?
146
148
  body = "wrap(#{body})"
147
149
  end
148
150
 
149
- cpp_defs << ".#{def_method}(
150
- \"#{func.cpp_name}\",
151
- *[](#{cpp_args.join(", ")}) {
152
- return #{body};
153
- })"
151
+ cpp_defs << "// #{func.func}
152
+ static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
153
+ {
154
+ return #{body};
155
+ }"
156
+
157
+ add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
154
158
  end
155
159
 
156
160
  hpp_contents = hpp_template % {type: type}
157
- cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
161
+ cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
158
162
 
159
163
  path = File.expand_path("../../../ext/torch", __dir__)
160
164
  File.write("#{path}/#{type}_functions.hpp", hpp_contents)