torch-rb 0.1.4 → 0.1.5

Files changed (39)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +5 -3
  4. data/ext/torch/ext.cpp +22 -548
  5. data/ext/torch/extconf.rb +6 -0
  6. data/ext/torch/nn_functions.cpp +595 -0
  7. data/ext/torch/nn_functions.hpp +6 -0
  8. data/ext/torch/templates.hpp +250 -0
  9. data/ext/torch/tensor_functions.cpp +1860 -0
  10. data/ext/torch/tensor_functions.hpp +6 -0
  11. data/ext/torch/torch_functions.cpp +2875 -0
  12. data/ext/torch/torch_functions.hpp +6 -0
  13. data/lib/torch.rb +68 -129
  14. data/lib/torch/ext.bundle +0 -0
  15. data/lib/torch/native/dispatcher.rb +48 -0
  16. data/lib/torch/native/function.rb +78 -0
  17. data/lib/torch/native/generator.rb +149 -0
  18. data/lib/torch/native/native_functions.yaml +6837 -0
  19. data/lib/torch/native/parser.rb +97 -0
  20. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  21. data/lib/torch/nn/conv2d.rb +0 -2
  22. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  23. data/lib/torch/nn/functional.rb +55 -16
  24. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  25. data/lib/torch/nn/identity.rb +1 -0
  26. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  27. data/lib/torch/nn/module.rb +59 -12
  28. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  29. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  30. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  31. data/lib/torch/nn/parameter.rb +4 -0
  32. data/lib/torch/nn/rnn.rb +22 -0
  33. data/lib/torch/nn/rnn_base.rb +154 -0
  34. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  35. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  36. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  37. data/lib/torch/tensor.rb +19 -19
  38. data/lib/torch/version.rb +1 -1
  39. metadata +26 -2
data/ext/torch/torch_functions.hpp ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_torch_functions(Module m);
data/lib/torch.rb CHANGED
@@ -1,6 +1,11 @@
  # ext
  require "torch/ext"

+ # native functions
+ require "torch/native/generator"
+ require "torch/native/parser"
+ require "torch/native/dispatcher"
+
  # modules
  require "torch/inspector"
  require "torch/tensor"
@@ -39,6 +44,10 @@ require "torch/nn/max_pool2d"
  require "torch/nn/avg_poolnd"
  require "torch/nn/avg_pool2d"

+ # nn recurrent layers
+ require "torch/nn/rnn_base"
+ require "torch/nn/rnn"
+
  # nn linear layers
  require "torch/nn/bilinear"
  require "torch/nn/identity"
@@ -77,23 +86,23 @@ require "torch/nn/pairwise_distance"
  require "torch/nn/loss"
  require "torch/nn/weighted_loss"
  require "torch/nn/bce_loss"
- # require "torch/nn/bce_with_logits_loss"
- # require "torch/nn/cosine_embedding_loss"
+ require "torch/nn/bce_with_logits_loss"
+ require "torch/nn/cosine_embedding_loss"
  require "torch/nn/cross_entropy_loss"
  require "torch/nn/ctc_loss"
- # require "torch/nn/hinge_embedding_loss"
+ require "torch/nn/hinge_embedding_loss"
  require "torch/nn/kl_div_loss"
  require "torch/nn/l1_loss"
- # require "torch/nn/margin_ranking_loss"
+ require "torch/nn/margin_ranking_loss"
  require "torch/nn/mse_loss"
- # require "torch/nn/multi_label_margin_loss"
- # require "torch/nn/multi_label_soft_margin_loss"
- # require "torch/nn/multi_margin_loss"
+ require "torch/nn/multi_label_margin_loss"
+ require "torch/nn/multi_label_soft_margin_loss"
+ require "torch/nn/multi_margin_loss"
  require "torch/nn/nll_loss"
  require "torch/nn/poisson_nll_loss"
- # require "torch/nn/smooth_l1_loss"
- # require "torch/nn/soft_margin_loss"
- # require "torch/nn/triplet_margin_loss"
+ require "torch/nn/smooth_l1_loss"
+ require "torch/nn/soft_margin_loss"
+ require "torch/nn/triplet_margin_loss"

  # nn other
  require "torch/nn/functional"
@@ -142,6 +151,41 @@ module Torch
  }
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h

+   def self._make_tensor_class(dtype, cuda = false)
+     cls = Class.new
+     device = cuda ? "cuda" : "cpu"
+     cls.define_singleton_method("new") do |*args|
+       if args.size == 1 && args.first.is_a?(Tensor)
+         args.first.send(dtype).to(device)
+       elsif args.size == 1 && args.first.is_a?(Array)
+         Torch.tensor(args.first, dtype: dtype, device: device)
+       else
+         Torch.empty(*args, dtype: dtype, device: device)
+       end
+     end
+     cls
+   end
+
+   FloatTensor = _make_tensor_class(:float32)
+   DoubleTensor = _make_tensor_class(:float64)
+   HalfTensor = _make_tensor_class(:float16)
+   ByteTensor = _make_tensor_class(:uint8)
+   CharTensor = _make_tensor_class(:int8)
+   ShortTensor = _make_tensor_class(:int16)
+   IntTensor = _make_tensor_class(:int32)
+   LongTensor = _make_tensor_class(:int64)
+   BoolTensor = _make_tensor_class(:bool)
+
+   CUDA::FloatTensor = _make_tensor_class(:float32, true)
+   CUDA::DoubleTensor = _make_tensor_class(:float64, true)
+   CUDA::HalfTensor = _make_tensor_class(:float16, true)
+   CUDA::ByteTensor = _make_tensor_class(:uint8, true)
+   CUDA::CharTensor = _make_tensor_class(:int8, true)
+   CUDA::ShortTensor = _make_tensor_class(:int16, true)
+   CUDA::IntTensor = _make_tensor_class(:int32, true)
+   CUDA::LongTensor = _make_tensor_class(:int64, true)
+   CUDA::BoolTensor = _make_tensor_class(:bool, true)
+
  class << self
    # Torch.float, Torch.long, etc
    DTYPE_TO_ENUM.each_key do |dtype|
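
Note: this factory backs the new PyTorch-style tensor class shorthands. An illustrative sketch of the three branches above (not from the diff):

  Torch::FloatTensor.new([1, 2, 3])    # Array arg  -> Torch.tensor(..., dtype: :float32)
  Torch::LongTensor.new(2, 3)          # size args  -> Torch.empty(2, 3, dtype: :int64)
  x = Torch.tensor([1.5, 2.5])
  Torch::IntTensor.new(x)              # Tensor arg -> cast to int32 and moved to "cpu"
  Torch::CUDA::FloatTensor.new([1, 2]) # same branches, but with device: "cuda"
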
@@ -191,6 +235,20 @@ module Torch
      }
    end

+   def no_grad
+     previous_value = grad_enabled?
+     begin
+       _set_grad_enabled(false)
+       yield
+     ensure
+       _set_grad_enabled(previous_value)
+     end
+   end
+
+   def device(str)
+     Device.new(str)
+   end
+
    # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

    def arange(start, finish = nil, step = 1, **options)
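
Note: no_grad and device move up here from the operations section removed below. A short sketch of the block form, assuming tensor creation accepts the requires_grad option:

  x = Torch.ones(2, 2, requires_grad: true)
  y = Torch.no_grad { x * 2 }   # gradient tracking is disabled inside the block
  y.requires_grad               # => false
  Torch.device("cuda:0")        # => Torch::Device wrapping the string
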
@@ -308,26 +366,6 @@ module Torch

    # --- begin operations ---

-   %w(add sub mul div remainder).each do |op|
-     define_method(op) do |input, other, **options|
-       execute_op(op, input, other, **options)
-     end
-   end
-
-   def neg(input)
-     _neg(input)
-   end
-
-   def no_grad
-     previous_value = grad_enabled?
-     begin
-       _set_grad_enabled(false)
-       yield
-     ensure
-       _set_grad_enabled(previous_value)
-     end
-   end
-
    # TODO support out
    def mean(input, dim = nil, keepdim: false)
      if dim
@@ -346,34 +384,10 @@ module Torch
      end
    end

-   def argmax(input, dim = nil, keepdim: false)
-     if dim
-       _argmax_dim(input, dim, keepdim)
-     else
-       _argmax(input)
-     end
-   end
-
-   def eq(input, other)
-     _eq(input, other)
-   end
-
-   def norm(input)
-     _norm(input)
-   end
-
-   def pow(input, exponent)
-     _pow(input, exponent)
-   end
-
    def topk(input, k)
      _topk(input, k)
    end

-   def min(input)
-     _min(input)
-   end
-
    def max(input, dim = nil, keepdim: false, out: nil)
      if dim
        raise NotImplementedYet unless out
@@ -383,58 +397,6 @@ module Torch
      end
    end

-   def exp(input)
-     _exp(input)
-   end
-
-   def log(input)
-     _log(input)
-   end
-
-   def sign(input)
-     _sign(input)
-   end
-
-   def sigmoid(input)
-     _sigmoid(input)
-   end
-
-   def gt(input, other)
-     _gt(input, other)
-   end
-
-   def lt(input, other)
-     _lt(input, other)
-   end
-
-   def unsqueeze(input, dim)
-     _unsqueeze(input, dim)
-   end
-
-   def dot(input, tensor)
-     _dot(input, tensor)
-   end
-
-   def cat(tensors, dim = 0)
-     _cat(tensors, dim)
-   end
-
-   def matmul(input, other)
-     _matmul(input, other)
-   end
-
-   def reshape(input, shape)
-     _reshape(input, shape)
-   end
-
-   def flatten(input, start_dim: 0, end_dim: -1)
-     _flatten(input, start_dim, end_dim)
-   end
-
-   def sqrt(input)
-     _sqrt(input)
-   end
-
    # TODO make dim keyword argument
    def log_softmax(input, dim)
      _log_softmax(input, dim)
@@ -444,31 +406,8 @@ module Torch
      _softmax(input, dim)
    end

-   def abs(input)
-     _abs(input)
-   end
-
-   def device(str)
-     Device.new(str)
-   end
-
    private

-   def execute_op(op, input, other, out: nil)
-     scalar = other.is_a?(Numeric)
-     if out
-       # TODO make work with scalars
-       raise Error, "out not supported with scalar yet" if scalar
-       send("_#{op}_out", out, input, other)
-     else
-       if scalar
-         send("_#{op}_scalar", input, other)
-       else
-         send("_#{op}", input, other)
-       end
-     end
-   end
-
    def tensor_size(size)
      size.flatten
    end
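
Note: the hand-written wrappers removed throughout this file (add, neg, argmax, exp, matmul, execute_op, and the rest) are not lost functionality. They are now defined at load time by Torch::Native::Dispatcher from native_functions.yaml (new files below), so calls such as the following should keep working. Illustrative sketch only:

  a = Torch.tensor([1.0, 4.0, 9.0])
  Torch.sqrt(a)        # dispatched to a generated binding instead of a hand-written wrapper
  Torch.matmul(a, a)   # overload resolution handled by Torch::Native::Parser
  a.argmax             # tensor methods are bound the same way via define_method
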
data/lib/torch/ext.bundle CHANGED
Binary file
data/lib/torch/native/dispatcher.rb ADDED
@@ -0,0 +1,48 @@
+ # We use a generic interface for methods (*args, **options)
+ # and this class to determine the C++ method to call
+ #
+ # This is needed since LibTorch uses function overloading,
+ # which isn't available in Ruby or Python
+ #
+ # PyTorch uses this approach, but the parser/dispatcher is written in C++
+ #
+ # We could generate Ruby methods directly, but an advantage of this approach is
+ # arguments and keyword arguments can be used interchangeably like in Python,
+ # making it easier to port code
+
+ module Torch
+   module Native
+     module Dispatcher
+       class << self
+         def bind
+           functions = Generator.grouped_functions
+           bind_functions(::Torch, :define_singleton_method, functions[:torch])
+           bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
+           # NN functions are internal, so no need to bind
+         end
+
+         def bind_functions(context, def_method, functions)
+           functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
+             if def_method == :define_method
+               funcs.map! { |f| Function.new(f.function) }
+               funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
+             end
+
+             defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
+             next if defined && name != "clone"
+
+             parser = Parser.new(funcs)
+
+             context.send(def_method, name) do |*args, **options|
+               result = parser.parse(args, options)
+               raise ArgumentError, result[:error] if result[:error]
+               send(result[:name], *result[:args])
+             end
+           end
+         end
+       end
+     end
+   end
+ end
+
+ Torch::Native::Dispatcher.bind
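
Note: the effect is that every bound torch/tensor function becomes a thin Ruby method that delegates to the parser and then to the matching generated C++ binding. Roughly equivalent to the following (names are illustrative, not literal generated source):

  # def Torch.add(*args, **options)
  #   result = parser.parse(args, options)     # picks the matching overload
  #   raise ArgumentError, result[:error] if result[:error]
  #   send(result[:name], *result[:args])      # e.g. _add_tensor or _add_scalar
  # end

  a = Torch.tensor([1, 2, 3])
  b = Torch.tensor([10, 20, 30])
  Torch.add(a, b)             # positional arguments...
  Torch.add(a, b, alpha: 2)   # ...and keyword arguments are interchangeable, as in Python
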
data/lib/torch/native/function.rb ADDED
@@ -0,0 +1,78 @@
+ module Torch
+   module Native
+     class Function
+       attr_reader :function
+
+       def initialize(function)
+         @function = function
+       end
+
+       def func
+         @func ||= @function["func"]
+       end
+
+       def name
+         @name ||= func.split("(", 2).first
+       end
+
+       def python_module
+         @python_module ||= @function["python_module"]
+       end
+
+       def variants
+         @variants ||= (@function["variants"] || "function").split(", ")
+       end
+
+       def args
+         @args ||= begin
+           args = []
+           pos = true
+           args_str = func.split("(", 2).last.split(") ->").first
+           args_str.split(", ").each do |a|
+             if a == "*"
+               pos = false
+               next
+             end
+             t, _, k = a.rpartition(" ")
+             k, d = k.split("=")
+             d = d.to_i if d.to_i.to_s == d
+             d = true if d == "True"
+             d = false if d == "False"
+             d = nil if d == "None"
+             args << {name: k, type: t, default: d, pos: pos}
+           end
+           args
+         end
+       end
+
+       def out_size
+         @out_size ||= func.split("->").last.count("!")
+       end
+
+       def out?
+         out_size > 0 && base_name[-1] != "_"
+       end
+
+       def ruby_name
+         @ruby_name ||= begin
+           name = base_name
+           if name.end_with?("_")
+             "#{name[0..-2]}!"
+           elsif name.start_with?("is_")
+             "#{name[3..-1]}?"
+           else
+             name
+           end
+         end
+       end
+
+       def cpp_name
+         @cpp_name ||= "_" + name.downcase.sub(".", "_")
+       end
+
+       def base_name
+         @base_name ||= name.split(".").first
+       end
+     end
+   end
+ end
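
Note: each Function wraps one entry from native_functions.yaml and derives everything else from its func schema string. For a schema like the one below (an assumed example mirroring PyTorch's YAML format), the accessors work out to:

  f = Torch::Native::Function.new(
    "func" => "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
  )
  f.name       # => "add.Tensor"
  f.base_name  # => "add"
  f.ruby_name  # => "add"
  f.cpp_name   # => "_add_tensor"
  f.args.map { |a| a[:name] }   # => ["self", "other", "alpha"]
  f.args.last[:pos]             # => false (alpha is keyword-only, after the "*")
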
data/lib/torch/native/generator.rb ADDED
@@ -0,0 +1,149 @@
+ require "yaml"
+ # use require_relative for
+ # rake generate:function (without bundle)
+ require_relative "function"
+
+ module Torch
+   module Native
+     module Generator
+       class << self
+         def generate_cpp_functions
+           functions = grouped_functions
+           generate_cpp_file("torch", :define_singleton_method, functions[:torch])
+           generate_cpp_file("tensor", :define_method, functions[:tensor])
+           generate_cpp_file("nn", :define_singleton_method, functions[:nn])
+         end
+
+         def grouped_functions
+           functions = functions()
+
+           # remove functions
+           skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
+           skip_args = ["bool[3]", "Dimname", "ScalarType", "MemoryFormat", "Storage", "ConstQuantizerPtr"]
+           functions.reject! { |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || skip_binding.include?(f.ruby_name) }
+           todo_functions, functions =
+             functions.partition do |f|
+               f.args.any? do |a|
+                 a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
+                 skip_args.any? { |sa| a[:type].include?(sa) }
+               end
+             end
+
+           # generate additional functions for optional arguments
+           # there may be a better way to do this
+           optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
+           optional_functions.each do |f|
+             next if f.ruby_name.start_with?("avg_pool") || f.ruby_name == "cross"
+             opt_args = f.args.select { |a| a[:type] == "int?" }
+             if opt_args.size == 1
+               sep = f.name.include?(".") ? "_" : "."
+               f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
+               # TODO only remove some arguments
+               f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
+               functions << f1
+               functions << f2
+             end
+           end
+
+           # todo_functions.each do |f|
+           #   puts f.func
+           #   puts
+           # end
+
+           nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
+           torch_functions = other_functions.select { |f| f.variants.include?("function") }
+           tensor_functions = other_functions.select { |f| f.variants.include?("method") }
+
+           {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
+         end
+
+         private
+
+         def generate_cpp_file(type, def_method, functions)
+           hpp_template = <<-TEMPLATE
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_%{type}_functions(Module m);
+           TEMPLATE
+
+           cpp_template = <<-TEMPLATE
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #include <torch/torch.h>
+ #include <rice/Module.hpp>
+ #include "templates.hpp"
+
+ void add_%{type}_functions(Module m) {
+   m
+     %{functions};
+ }
+           TEMPLATE
+
+           cpp_defs = []
+           functions.sort_by(&:cpp_name).each do |func|
+             fargs = func.args.select { |a| a[:type] != "Generator?" }
+
+             cpp_args = []
+             fargs.each do |a|
+               t =
+                 case a[:type]
+                 when "Tensor"
+                   "const Tensor &"
+                 when "Tensor?"
+                   # TODO better signature
+                   "OptionalTensor"
+                 when "Tensor[]"
+                   "TensorList"
+                 when "int"
+                   "int64_t"
+                 when "float"
+                   "double"
+                 when /\Aint\[/
+                   "IntArrayRef"
+                 when /Tensor\(\S!?\)/
+                   "Tensor &"
+                 else
+                   a[:type]
+                 end
+
+               t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
+               cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
+             end
+
+             dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
+             args = fargs.map { |a| a[:name] }
+             args.unshift(*args.pop(func.out_size)) if func.out?
+             args.delete("self") if def_method == :define_method
+
+             prefix = def_method == :define_method ? "self." : "torch::"
+
+             cpp_defs << ".#{def_method}(
+ \"#{func.cpp_name}\",
+ *[](#{cpp_args.join(", ")}) {
+ return #{prefix}#{dispatch}(#{args.join(", ")});
+ })"
+           end
+
+           hpp_contents = hpp_template % {type: type}
+           cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
+
+           path = File.expand_path("../../../ext/torch", __dir__)
+           File.write("#{path}/#{type}_functions.hpp", hpp_contents)
+           File.write("#{path}/#{type}_functions.cpp", cpp_contents)
+         end
+
+         def functions
+           @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
+         end
+
+         def path
+           File.expand_path("native_functions.yaml", __dir__)
+         end
+       end
+     end
+   end
+ end
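
Note: this generator is what produces the large *_functions.cpp/hpp files listed at the top of this diff (via rake generate:functions). A quick way to inspect how the entries in the 6,837-line native_functions.yaml are partitioned; the snippet below is a sketch, and the groups' contents are not quoted from the release:

  require "torch/native/generator"

  groups = Torch::Native::Generator.grouped_functions
  groups.keys                # => [:torch, :tensor, :nn]
  groups[:torch].first.func  # schema string for one torch-level function

  # Regenerating the C++ sources (what the rake task runs):
  # Torch::Native::Generator.generate_cpp_functions
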