torch-rb 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +5 -3
  4. data/ext/torch/ext.cpp +22 -548
  5. data/ext/torch/extconf.rb +6 -0
  6. data/ext/torch/nn_functions.cpp +595 -0
  7. data/ext/torch/nn_functions.hpp +6 -0
  8. data/ext/torch/templates.hpp +250 -0
  9. data/ext/torch/tensor_functions.cpp +1860 -0
  10. data/ext/torch/tensor_functions.hpp +6 -0
  11. data/ext/torch/torch_functions.cpp +2875 -0
  12. data/ext/torch/torch_functions.hpp +6 -0
  13. data/lib/torch.rb +68 -129
  14. data/lib/torch/ext.bundle +0 -0
  15. data/lib/torch/native/dispatcher.rb +48 -0
  16. data/lib/torch/native/function.rb +78 -0
  17. data/lib/torch/native/generator.rb +149 -0
  18. data/lib/torch/native/native_functions.yaml +6837 -0
  19. data/lib/torch/native/parser.rb +97 -0
  20. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  21. data/lib/torch/nn/conv2d.rb +0 -2
  22. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  23. data/lib/torch/nn/functional.rb +55 -16
  24. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  25. data/lib/torch/nn/identity.rb +1 -0
  26. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  27. data/lib/torch/nn/module.rb +59 -12
  28. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  29. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  30. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  31. data/lib/torch/nn/parameter.rb +4 -0
  32. data/lib/torch/nn/rnn.rb +22 -0
  33. data/lib/torch/nn/rnn_base.rb +154 -0
  34. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  35. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  36. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  37. data/lib/torch/tensor.rb +19 -19
  38. data/lib/torch/version.rb +1 -1
  39. metadata +26 -2
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Module m);
data/lib/torch.rb CHANGED
@@ -1,6 +1,11 @@
1
1
  # ext
2
2
  require "torch/ext"
3
3
 
4
+ # native functions
5
+ require "torch/native/generator"
6
+ require "torch/native/parser"
7
+ require "torch/native/dispatcher"
8
+
4
9
  # modules
5
10
  require "torch/inspector"
6
11
  require "torch/tensor"
@@ -39,6 +44,10 @@ require "torch/nn/max_pool2d"
39
44
  require "torch/nn/avg_poolnd"
40
45
  require "torch/nn/avg_pool2d"
41
46
 
47
+ # nn recurrent layers
48
+ require "torch/nn/rnn_base"
49
+ require "torch/nn/rnn"
50
+
42
51
  # nn linear layers
43
52
  require "torch/nn/bilinear"
44
53
  require "torch/nn/identity"
@@ -77,23 +86,23 @@ require "torch/nn/pairwise_distance"
77
86
  require "torch/nn/loss"
78
87
  require "torch/nn/weighted_loss"
79
88
  require "torch/nn/bce_loss"
80
- # require "torch/nn/bce_with_logits_loss"
81
- # require "torch/nn/cosine_embedding_loss"
89
+ require "torch/nn/bce_with_logits_loss"
90
+ require "torch/nn/cosine_embedding_loss"
82
91
  require "torch/nn/cross_entropy_loss"
83
92
  require "torch/nn/ctc_loss"
84
- # require "torch/nn/hinge_embedding_loss"
93
+ require "torch/nn/hinge_embedding_loss"
85
94
  require "torch/nn/kl_div_loss"
86
95
  require "torch/nn/l1_loss"
87
- # require "torch/nn/margin_ranking_loss"
96
+ require "torch/nn/margin_ranking_loss"
88
97
  require "torch/nn/mse_loss"
89
- # require "torch/nn/multi_label_margin_loss"
90
- # require "torch/nn/multi_label_soft_margin_loss"
91
- # require "torch/nn/multi_margin_loss"
98
+ require "torch/nn/multi_label_margin_loss"
99
+ require "torch/nn/multi_label_soft_margin_loss"
100
+ require "torch/nn/multi_margin_loss"
92
101
  require "torch/nn/nll_loss"
93
102
  require "torch/nn/poisson_nll_loss"
94
- # require "torch/nn/smooth_l1_loss"
95
- # require "torch/nn/soft_margin_loss"
96
- # require "torch/nn/triplet_margin_loss"
103
+ require "torch/nn/smooth_l1_loss"
104
+ require "torch/nn/soft_margin_loss"
105
+ require "torch/nn/triplet_margin_loss"
97
106
 
98
107
  # nn other
99
108
  require "torch/nn/functional"
@@ -142,6 +151,41 @@ module Torch
142
151
  }
143
152
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
144
153
 
154
# Builds an anonymous class whose `new` returns a tensor of the given dtype.
#
# dtype - Symbol naming the dtype (e.g. :float32); it is used both as the
#         dtype: option and as the conversion method sent to an existing tensor.
# cuda  - when truthy the device is "cuda", otherwise "cpu".
#
# The generated `new` accepts:
#   * a single Tensor -> converted to dtype and moved to the device
#   * a single Array  -> passed to Torch.tensor
#   * anything else   -> forwarded to Torch.empty
def self._make_tensor_class(dtype, cuda = false)
  device = cuda ? "cuda" : "cpu"
  Class.new.tap do |cls|
    cls.define_singleton_method("new") do |*args|
      # only a lone argument can select the Tensor/Array paths
      sole = args.size == 1 ? args.first : nil
      case sole
      when Tensor
        sole.send(dtype).to(device)
      when Array
        Torch.tensor(sole, dtype: dtype, device: device)
      else
        Torch.empty(*args, dtype: dtype, device: device)
      end
    end
  end
end
168
+
169
# Legacy-style typed tensor classes (FloatTensor, CUDA::FloatTensor, ...).
# The CPU classes are all defined first, followed by the CUDA ones.
tensor_class_dtypes = {
  FloatTensor: :float32,
  DoubleTensor: :float64,
  HalfTensor: :float16,
  ByteTensor: :uint8,
  CharTensor: :int8,
  ShortTensor: :int16,
  IntTensor: :int32,
  LongTensor: :int64,
  BoolTensor: :bool
}

tensor_class_dtypes.each do |const_name, dtype|
  const_set(const_name, _make_tensor_class(dtype))
end

tensor_class_dtypes.each do |const_name, dtype|
  CUDA.const_set(const_name, _make_tensor_class(dtype, true))
end
188
+
145
189
  class << self
146
190
  # Torch.float, Torch.long, etc
147
191
  DTYPE_TO_ENUM.each_key do |dtype|
@@ -191,6 +235,20 @@ module Torch
191
235
  }
192
236
  end
193
237
 
238
# Runs the given block with gradient tracking switched off, then puts the
# grad-enabled flag back to whatever it was before the call, even if the
# block raises.
def no_grad
  restore_to = grad_enabled?
  begin
    _set_grad_enabled(false)
    yield
  ensure
    _set_grad_enabled(restore_to)
  end
end
247
+
248
# Convenience constructor: wraps the given device string in a Device.
def device(str)
  Device.new(str)
end
251
+
194
252
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
195
253
 
196
254
  def arange(start, finish = nil, step = 1, **options)
@@ -308,26 +366,6 @@ module Torch
308
366
 
309
367
  # --- begin operations ---
310
368
 
311
- %w(add sub mul div remainder).each do |op|
312
- define_method(op) do |input, other, **options|
313
- execute_op(op, input, other, **options)
314
- end
315
- end
316
-
317
- def neg(input)
318
- _neg(input)
319
- end
320
-
321
- def no_grad
322
- previous_value = grad_enabled?
323
- begin
324
- _set_grad_enabled(false)
325
- yield
326
- ensure
327
- _set_grad_enabled(previous_value)
328
- end
329
- end
330
-
331
369
  # TODO support out
332
370
  def mean(input, dim = nil, keepdim: false)
333
371
  if dim
@@ -346,34 +384,10 @@ module Torch
346
384
  end
347
385
  end
348
386
 
349
- def argmax(input, dim = nil, keepdim: false)
350
- if dim
351
- _argmax_dim(input, dim, keepdim)
352
- else
353
- _argmax(input)
354
- end
355
- end
356
-
357
- def eq(input, other)
358
- _eq(input, other)
359
- end
360
-
361
- def norm(input)
362
- _norm(input)
363
- end
364
-
365
- def pow(input, exponent)
366
- _pow(input, exponent)
367
- end
368
-
369
387
  def topk(input, k)
370
388
  _topk(input, k)
371
389
  end
372
390
 
373
- def min(input)
374
- _min(input)
375
- end
376
-
377
391
  def max(input, dim = nil, keepdim: false, out: nil)
378
392
  if dim
379
393
  raise NotImplementedYet unless out
@@ -383,58 +397,6 @@ module Torch
383
397
  end
384
398
  end
385
399
 
386
- def exp(input)
387
- _exp(input)
388
- end
389
-
390
- def log(input)
391
- _log(input)
392
- end
393
-
394
- def sign(input)
395
- _sign(input)
396
- end
397
-
398
- def sigmoid(input)
399
- _sigmoid(input)
400
- end
401
-
402
- def gt(input, other)
403
- _gt(input, other)
404
- end
405
-
406
- def lt(input, other)
407
- _lt(input, other)
408
- end
409
-
410
- def unsqueeze(input, dim)
411
- _unsqueeze(input, dim)
412
- end
413
-
414
- def dot(input, tensor)
415
- _dot(input, tensor)
416
- end
417
-
418
- def cat(tensors, dim = 0)
419
- _cat(tensors, dim)
420
- end
421
-
422
- def matmul(input, other)
423
- _matmul(input, other)
424
- end
425
-
426
- def reshape(input, shape)
427
- _reshape(input, shape)
428
- end
429
-
430
- def flatten(input, start_dim: 0, end_dim: -1)
431
- _flatten(input, start_dim, end_dim)
432
- end
433
-
434
- def sqrt(input)
435
- _sqrt(input)
436
- end
437
-
438
400
  # TODO make dim keyword argument
439
401
  def log_softmax(input, dim)
440
402
  _log_softmax(input, dim)
@@ -444,31 +406,8 @@ module Torch
444
406
  _softmax(input, dim)
445
407
  end
446
408
 
447
- def abs(input)
448
- _abs(input)
449
- end
450
-
451
- def device(str)
452
- Device.new(str)
453
- end
454
-
455
409
  private
456
410
 
457
- def execute_op(op, input, other, out: nil)
458
- scalar = other.is_a?(Numeric)
459
- if out
460
- # TODO make work with scalars
461
- raise Error, "out not supported with scalar yet" if scalar
462
- send("_#{op}_out", out, input, other)
463
- else
464
- if scalar
465
- send("_#{op}_scalar", input, other)
466
- else
467
- send("_#{op}", input, other)
468
- end
469
- end
470
- end
471
-
472
411
  def tensor_size(size)
473
412
  size.flatten
474
413
  end
data/lib/torch/ext.bundle CHANGED
Binary file
@@ -0,0 +1,48 @@
1
+ # We use a generic interface for methods (*args, **options)
2
+ # and this class to determine the C++ method to call
3
+ #
4
+ # This is needed since LibTorch uses function overloading,
5
+ # which isn't available in Ruby or Python
6
+ #
7
+ # PyTorch uses this approach, but the parser/dispatcher is written in C++
8
+ #
9
+ # We could generate Ruby methods directly, but an advantage of this approach is
10
+ # arguments and keyword arguments can be used interchangeably like in Python,
11
+ # making it easier to port code
12
+
13
module Torch
  module Native
    # Defines the public Ruby methods for the native functions. Each defined
    # method accepts (*args, **options), asks a Parser built from all
    # overloads sharing that Ruby name which one matches, and then sends the
    # resolved method name with the resolved arguments.
    module Dispatcher
      class << self
        # Binds the :torch group as singleton methods on ::Torch and the
        # :tensor group as instance methods on ::Torch::Tensor.
        def bind
          groups = Generator.grouped_functions
          bind_functions(::Torch, :define_singleton_method, groups[:torch])
          bind_functions(::Torch::Tensor, :define_method, groups[:tensor])
          # NN functions are internal, so no need to bind
        end

        # context    - object receiving the methods
        # def_method - :define_singleton_method or :define_method
        # functions  - function objects, grouped here by their ruby_name
        #
        # Names already available on the context are left alone, with the
        # single exception of "clone", which is always (re)defined.
        def bind_functions(context, def_method, functions)
          instance_level = def_method == :define_method

          functions.group_by(&:ruby_name).sort_by(&:first).each do |name, funcs|
            if instance_level
              # work on fresh copies so dropping the "self" argument does not
              # touch the argument lists of the original function objects
              funcs = funcs.map do |f|
                copy = Function.new(f.function)
                copy.args.reject! { |a| a[:name] == "self" }
                copy
              end
            end

            taken = instance_level ? context.method_defined?(name) : context.respond_to?(name)
            next if taken && name != "clone"

            parser = Parser.new(funcs)

            context.send(def_method, name) do |*args, **options|
              result = parser.parse(args, options)
              raise ArgumentError, result[:error] if result[:error]
              send(result[:name], *result[:args])
            end
          end
        end
      end
    end
  end
end
47
+
48
+ Torch::Native::Dispatcher.bind
@@ -0,0 +1,78 @@
1
module Torch
  module Native
    # Wraps one function entry: a Hash with a "func" signature string and
    # optional "python_module" / "variants" keys. Exposes the parts of the
    # signature that the generator and dispatcher need.
    class Function
      # Default literals containing a decimal point and/or an exponent,
      # e.g. "0.5", "-1.5", "1e-05".
      FLOAT_DEFAULT = /\A-?(?:\d+\.\d+(?:e[+-]?\d+)?|\d+e[+-]?\d+)\z/i

      # The raw Hash this object was built from.
      attr_reader :function

      def initialize(function)
        @function = function
      end

      # Full signature string, e.g. "add.Tensor(Tensor self, ...) -> Tensor".
      def func
        @func ||= @function["func"]
      end

      # Everything before the first "(" of the signature (may contain a
      # ".overload" suffix).
      def name
        @name ||= func.split("(", 2).first
      end

      # Value of the "python_module" key, or nil when absent.
      def python_module
        @python_module ||= @function["python_module"]
      end

      # Array of variant names; defaults to ["function"] when the key is absent.
      def variants
        @variants ||= (@function["variants"] || "function").split(", ")
      end

      # Parses the argument list of the signature.
      #
      # Returns an Array of Hashes with keys:
      #   :name    - argument name (String)
      #   :type    - declared type (String), everything before the last space
      #   :default - parsed default (see #parse_default), nil when none is given
      #   :pos     - true for arguments before the bare "*" marker, false after
      #
      # The Array is memoized and returned as-is, so callers that mutate it
      # change what later calls see.
      def args
        @args ||= begin
          parsed = []
          pos = true
          args_str = func.split("(", 2).last.split(") ->").first
          args_str.split(", ").each do |a|
            if a == "*"
              # everything after "*" is keyword-only
              pos = false
              next
            end
            t, _, k = a.rpartition(" ")
            k, d = k.split("=")
            parsed << {name: k, type: t, default: parse_default(d), pos: pos}
          end
          parsed
        end
      end

      # Number of "!" markers in the return annotation (after the last "->").
      def out_size
        @out_size ||= func.split("->").last.count("!")
      end

      # True when the return annotation has a "!" marker and the base name
      # does not end with "_".
      def out?
        out_size > 0 && base_name[-1] != "_"
      end

      # Ruby-facing name: "foo_" becomes "foo!", "is_foo" becomes "foo?",
      # anything else is the base name unchanged.
      def ruby_name
        @ruby_name ||= begin
          name = base_name
          if name.end_with?("_")
            "#{name[0..-2]}!"
          elsif name.start_with?("is_")
            "#{name[3..-1]}?"
          else
            name
          end
        end
      end

      # Name of the generated binding: "_" + lowercased name with the first
      # "." replaced by "_".
      def cpp_name
        @cpp_name ||= "_" + name.downcase.sub(".", "_")
      end

      # Name without the ".overload" suffix.
      def base_name
        @base_name ||= name.split(".").first
      end

      private

      # Converts a default literal from the signature into a Ruby value:
      # "True"/"False" -> true/false, "None" or no default -> nil,
      # integer literals -> Integer, float literals -> Float.
      # Any other literal (e.g. "[0,1]") is returned unchanged as a String.
      def parse_default(literal)
        case literal
        when "True"
          true
        when "False"
          false
        when "None", nil
          nil
        when FLOAT_DEFAULT
          # previously float defaults were left as Strings such as "0.5"
          literal.to_f
        else
          literal.to_i.to_s == literal ? literal.to_i : literal
        end
      end
    end
  end
end
@@ -0,0 +1,149 @@
1
+ require "yaml"
2
+ # use require_relative for
3
+ # rake generate:functions (without bundle)
4
+ require_relative "function"
5
+
6
module Torch
  module Native
    # Loads the function definitions from native_functions.yaml, groups them
    # into torch / tensor / nn sets, and can write the C++ binding sources
    # for each set.
    module Generator
      class << self
        # Writes <type>_functions.hpp/.cpp for the torch, tensor and nn groups.
        def generate_cpp_functions
          functions = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
          generate_cpp_file("tensor", :define_method, functions[:tensor])
          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
        end

        # Returns {torch: [...], tensor: [...], nn: [...]} of Function objects.
        #
        # Functions whose Ruby name starts with "_" or ends with "_backward",
        # functions listed in skip_binding, and functions with unsupported
        # argument types are left out.
        def grouped_functions
          # remove functions
          skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
          skip_args = ["bool[3]", "Dimname", "ScalarType", "MemoryFormat", "Storage", "ConstQuantizerPtr"]
          # reject (not reject!) so the memoized list behind #functions is not
          # filtered in place as a side effect of grouping
          functions = functions().reject { |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || skip_binding.include?(f.ruby_name) }
          todo_functions, functions =
            functions.partition do |f|
              f.args.any? do |a|
                a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
                  skip_args.any? { |sa| a[:type].include?(sa) }
              end
            end

          # generate additional functions for optional arguments
          # there may be a better way to do this
          optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
          optional_functions.each do |f|
            next if f.ruby_name.start_with?("avg_pool") || f.ruby_name == "cross"
            opt_args = f.args.select { |a| a[:type] == "int?" }
            if opt_args.size == 1
              sep = f.name.include?(".") ? "_" : "."
              # f1: overload named after the optional argument, taking it as a plain int
              f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
              # f2: overload with the optional argument (and what follows it) removed
              # TODO only remove some arguments
              f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
              functions << f1
              functions << f2
            end
          end

          # todo_functions.each do |f|
          #   puts f.func
          #   puts
          # end

          nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
          torch_functions = other_functions.select { |f| f.variants.include?("function") }
          tensor_functions = other_functions.select { |f| f.variants.include?("method") }

          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
        end

        private

        # Renders and writes the header and source file for one group.
        #
        # type       - "torch", "tensor" or "nn"; used in file and function names
        # def_method - :define_singleton_method or :define_method; with
        #              :define_method the call is dispatched on `self.` and the
        #              "self" argument is not forwarded
        # functions  - Function objects to emit, sorted by cpp_name
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#pragma once

void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#include <torch/torch.h>
#include <rice/Module.hpp>
#include "templates.hpp"

void add_%{type}_functions(Module m) {
m
%{functions};
}
          TEMPLATE

          cpp_defs = []
          functions.sort_by(&:cpp_name).each do |func|
            # Generator? arguments are neither declared nor forwarded
            fargs = func.args.select { |a| a[:type] != "Generator?" }

            cpp_args = []
            fargs.each do |a|
              t =
                case a[:type]
                when "Tensor"
                  "const Tensor &"
                when "Tensor?"
                  # TODO better signature
                  "OptionalTensor"
                when "Tensor[]"
                  "TensorList"
                when "int"
                  "int64_t"
                when "float"
                  "double"
                when /\Aint\[/
                  "IntArrayRef"
                when /Tensor\(\S!?\)/
                  "Tensor &"
                else
                  a[:type]
                end

              t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
              cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
            end

            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
            args = fargs.map { |a| a[:name] }
            # out variants take their output arguments first
            args.unshift(*args.pop(func.out_size)) if func.out?
            args.delete("self") if def_method == :define_method

            prefix = def_method == :define_method ? "self." : "torch::"

            cpp_defs << ".#{def_method}(
\"#{func.cpp_name}\",
*[](#{cpp_args.join(", ")}) {
return #{prefix}#{dispatch}(#{args.join(", ")});
})"
          end

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}

          path = File.expand_path("../../../ext/torch", __dir__)
          File.write("#{path}/#{type}_functions.hpp", hpp_contents)
          File.write("#{path}/#{type}_functions.cpp", cpp_contents)
        end

        # All functions from the YAML file, wrapped in Function; memoized.
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        # Absolute path of native_functions.yaml next to this file.
        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end