torch-rb 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +5 -3
- data/ext/torch/ext.cpp +22 -548
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +68 -129
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/conv2d.rb +0 -2
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/functional.rb +55 -16
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +1 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/module.rb +59 -12
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/tensor.rb +19 -19
- data/lib/torch/version.rb +1 -1
- metadata +26 -2
data/lib/torch.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
# ext
|
2
2
|
require "torch/ext"
|
3
3
|
|
4
|
+
# native functions
|
5
|
+
require "torch/native/generator"
|
6
|
+
require "torch/native/parser"
|
7
|
+
require "torch/native/dispatcher"
|
8
|
+
|
4
9
|
# modules
|
5
10
|
require "torch/inspector"
|
6
11
|
require "torch/tensor"
|
@@ -39,6 +44,10 @@ require "torch/nn/max_pool2d"
|
|
39
44
|
require "torch/nn/avg_poolnd"
|
40
45
|
require "torch/nn/avg_pool2d"
|
41
46
|
|
47
|
+
# nn recurrent layers
|
48
|
+
require "torch/nn/rnn_base"
|
49
|
+
require "torch/nn/rnn"
|
50
|
+
|
42
51
|
# nn linear layers
|
43
52
|
require "torch/nn/bilinear"
|
44
53
|
require "torch/nn/identity"
|
@@ -77,23 +86,23 @@ require "torch/nn/pairwise_distance"
|
|
77
86
|
require "torch/nn/loss"
|
78
87
|
require "torch/nn/weighted_loss"
|
79
88
|
require "torch/nn/bce_loss"
|
80
|
-
|
81
|
-
|
89
|
+
require "torch/nn/bce_with_logits_loss"
|
90
|
+
require "torch/nn/cosine_embedding_loss"
|
82
91
|
require "torch/nn/cross_entropy_loss"
|
83
92
|
require "torch/nn/ctc_loss"
|
84
|
-
|
93
|
+
require "torch/nn/hinge_embedding_loss"
|
85
94
|
require "torch/nn/kl_div_loss"
|
86
95
|
require "torch/nn/l1_loss"
|
87
|
-
|
96
|
+
require "torch/nn/margin_ranking_loss"
|
88
97
|
require "torch/nn/mse_loss"
|
89
|
-
|
90
|
-
|
91
|
-
|
98
|
+
require "torch/nn/multi_label_margin_loss"
|
99
|
+
require "torch/nn/multi_label_soft_margin_loss"
|
100
|
+
require "torch/nn/multi_margin_loss"
|
92
101
|
require "torch/nn/nll_loss"
|
93
102
|
require "torch/nn/poisson_nll_loss"
|
94
|
-
|
95
|
-
|
96
|
-
|
103
|
+
require "torch/nn/smooth_l1_loss"
|
104
|
+
require "torch/nn/soft_margin_loss"
|
105
|
+
require "torch/nn/triplet_margin_loss"
|
97
106
|
|
98
107
|
# nn other
|
99
108
|
require "torch/nn/functional"
|
@@ -142,6 +151,41 @@ module Torch
|
|
142
151
|
}
|
143
152
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
144
153
|
|
154
|
+
def self._make_tensor_class(dtype, cuda = false)
|
155
|
+
cls = Class.new
|
156
|
+
device = cuda ? "cuda" : "cpu"
|
157
|
+
cls.define_singleton_method("new") do |*args|
|
158
|
+
if args.size == 1 && args.first.is_a?(Tensor)
|
159
|
+
args.first.send(dtype).to(device)
|
160
|
+
elsif args.size == 1 && args.first.is_a?(Array)
|
161
|
+
Torch.tensor(args.first, dtype: dtype, device: device)
|
162
|
+
else
|
163
|
+
Torch.empty(*args, dtype: dtype, device: device)
|
164
|
+
end
|
165
|
+
end
|
166
|
+
cls
|
167
|
+
end
|
168
|
+
|
169
|
+
FloatTensor = _make_tensor_class(:float32)
|
170
|
+
DoubleTensor = _make_tensor_class(:float64)
|
171
|
+
HalfTensor = _make_tensor_class(:float16)
|
172
|
+
ByteTensor = _make_tensor_class(:uint8)
|
173
|
+
CharTensor = _make_tensor_class(:int8)
|
174
|
+
ShortTensor = _make_tensor_class(:int16)
|
175
|
+
IntTensor = _make_tensor_class(:int32)
|
176
|
+
LongTensor = _make_tensor_class(:int64)
|
177
|
+
BoolTensor = _make_tensor_class(:bool)
|
178
|
+
|
179
|
+
CUDA::FloatTensor = _make_tensor_class(:float32, true)
|
180
|
+
CUDA::DoubleTensor = _make_tensor_class(:float64, true)
|
181
|
+
CUDA::HalfTensor = _make_tensor_class(:float16, true)
|
182
|
+
CUDA::ByteTensor = _make_tensor_class(:uint8, true)
|
183
|
+
CUDA::CharTensor = _make_tensor_class(:int8, true)
|
184
|
+
CUDA::ShortTensor = _make_tensor_class(:int16, true)
|
185
|
+
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
186
|
+
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
187
|
+
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
188
|
+
|
145
189
|
class << self
|
146
190
|
# Torch.float, Torch.long, etc
|
147
191
|
DTYPE_TO_ENUM.each_key do |dtype|
|
@@ -191,6 +235,20 @@ module Torch
|
|
191
235
|
}
|
192
236
|
end
|
193
237
|
|
238
|
+
def no_grad
|
239
|
+
previous_value = grad_enabled?
|
240
|
+
begin
|
241
|
+
_set_grad_enabled(false)
|
242
|
+
yield
|
243
|
+
ensure
|
244
|
+
_set_grad_enabled(previous_value)
|
245
|
+
end
|
246
|
+
end
|
247
|
+
|
248
|
+
def device(str)
|
249
|
+
Device.new(str)
|
250
|
+
end
|
251
|
+
|
194
252
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
195
253
|
|
196
254
|
def arange(start, finish = nil, step = 1, **options)
|
@@ -308,26 +366,6 @@ module Torch
|
|
308
366
|
|
309
367
|
# --- begin operations ---
|
310
368
|
|
311
|
-
%w(add sub mul div remainder).each do |op|
|
312
|
-
define_method(op) do |input, other, **options|
|
313
|
-
execute_op(op, input, other, **options)
|
314
|
-
end
|
315
|
-
end
|
316
|
-
|
317
|
-
def neg(input)
|
318
|
-
_neg(input)
|
319
|
-
end
|
320
|
-
|
321
|
-
def no_grad
|
322
|
-
previous_value = grad_enabled?
|
323
|
-
begin
|
324
|
-
_set_grad_enabled(false)
|
325
|
-
yield
|
326
|
-
ensure
|
327
|
-
_set_grad_enabled(previous_value)
|
328
|
-
end
|
329
|
-
end
|
330
|
-
|
331
369
|
# TODO support out
|
332
370
|
def mean(input, dim = nil, keepdim: false)
|
333
371
|
if dim
|
@@ -346,34 +384,10 @@ module Torch
|
|
346
384
|
end
|
347
385
|
end
|
348
386
|
|
349
|
-
def argmax(input, dim = nil, keepdim: false)
|
350
|
-
if dim
|
351
|
-
_argmax_dim(input, dim, keepdim)
|
352
|
-
else
|
353
|
-
_argmax(input)
|
354
|
-
end
|
355
|
-
end
|
356
|
-
|
357
|
-
def eq(input, other)
|
358
|
-
_eq(input, other)
|
359
|
-
end
|
360
|
-
|
361
|
-
def norm(input)
|
362
|
-
_norm(input)
|
363
|
-
end
|
364
|
-
|
365
|
-
def pow(input, exponent)
|
366
|
-
_pow(input, exponent)
|
367
|
-
end
|
368
|
-
|
369
387
|
def topk(input, k)
|
370
388
|
_topk(input, k)
|
371
389
|
end
|
372
390
|
|
373
|
-
def min(input)
|
374
|
-
_min(input)
|
375
|
-
end
|
376
|
-
|
377
391
|
def max(input, dim = nil, keepdim: false, out: nil)
|
378
392
|
if dim
|
379
393
|
raise NotImplementedYet unless out
|
@@ -383,58 +397,6 @@ module Torch
|
|
383
397
|
end
|
384
398
|
end
|
385
399
|
|
386
|
-
def exp(input)
|
387
|
-
_exp(input)
|
388
|
-
end
|
389
|
-
|
390
|
-
def log(input)
|
391
|
-
_log(input)
|
392
|
-
end
|
393
|
-
|
394
|
-
def sign(input)
|
395
|
-
_sign(input)
|
396
|
-
end
|
397
|
-
|
398
|
-
def sigmoid(input)
|
399
|
-
_sigmoid(input)
|
400
|
-
end
|
401
|
-
|
402
|
-
def gt(input, other)
|
403
|
-
_gt(input, other)
|
404
|
-
end
|
405
|
-
|
406
|
-
def lt(input, other)
|
407
|
-
_lt(input, other)
|
408
|
-
end
|
409
|
-
|
410
|
-
def unsqueeze(input, dim)
|
411
|
-
_unsqueeze(input, dim)
|
412
|
-
end
|
413
|
-
|
414
|
-
def dot(input, tensor)
|
415
|
-
_dot(input, tensor)
|
416
|
-
end
|
417
|
-
|
418
|
-
def cat(tensors, dim = 0)
|
419
|
-
_cat(tensors, dim)
|
420
|
-
end
|
421
|
-
|
422
|
-
def matmul(input, other)
|
423
|
-
_matmul(input, other)
|
424
|
-
end
|
425
|
-
|
426
|
-
def reshape(input, shape)
|
427
|
-
_reshape(input, shape)
|
428
|
-
end
|
429
|
-
|
430
|
-
def flatten(input, start_dim: 0, end_dim: -1)
|
431
|
-
_flatten(input, start_dim, end_dim)
|
432
|
-
end
|
433
|
-
|
434
|
-
def sqrt(input)
|
435
|
-
_sqrt(input)
|
436
|
-
end
|
437
|
-
|
438
400
|
# TODO make dim keyword argument
|
439
401
|
def log_softmax(input, dim)
|
440
402
|
_log_softmax(input, dim)
|
@@ -444,31 +406,8 @@ module Torch
|
|
444
406
|
_softmax(input, dim)
|
445
407
|
end
|
446
408
|
|
447
|
-
def abs(input)
|
448
|
-
_abs(input)
|
449
|
-
end
|
450
|
-
|
451
|
-
def device(str)
|
452
|
-
Device.new(str)
|
453
|
-
end
|
454
|
-
|
455
409
|
private
|
456
410
|
|
457
|
-
def execute_op(op, input, other, out: nil)
|
458
|
-
scalar = other.is_a?(Numeric)
|
459
|
-
if out
|
460
|
-
# TODO make work with scalars
|
461
|
-
raise Error, "out not supported with scalar yet" if scalar
|
462
|
-
send("_#{op}_out", out, input, other)
|
463
|
-
else
|
464
|
-
if scalar
|
465
|
-
send("_#{op}_scalar", input, other)
|
466
|
-
else
|
467
|
-
send("_#{op}", input, other)
|
468
|
-
end
|
469
|
-
end
|
470
|
-
end
|
471
|
-
|
472
411
|
def tensor_size(size)
|
473
412
|
size.flatten
|
474
413
|
end
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/native/dispatcher.rb
ADDED
@@ -0,0 +1,48 @@
|
|
1
|
+
# We use a generic interface for methods (*args, **options)
|
2
|
+
# and this class to determine the C++ method to call
|
3
|
+
#
|
4
|
+
# This is needed since LibTorch uses function overloading,
|
5
|
+
# which isn't available in Ruby or Python
|
6
|
+
#
|
7
|
+
# PyTorch uses this approach, but the parser/dispatcher is written in C++
|
8
|
+
#
|
9
|
+
# We could generate Ruby methods directly, but an advantage of this approach is
|
10
|
+
# arguments and keyword arguments can be used interchangeably like in Python,
|
11
|
+
# making it easier to port code
|
12
|
+
|
13
|
+
module Torch
|
14
|
+
module Native
|
15
|
+
module Dispatcher
|
16
|
+
class << self
|
17
|
+
def bind
|
18
|
+
functions = Generator.grouped_functions
|
19
|
+
bind_functions(::Torch, :define_singleton_method, functions[:torch])
|
20
|
+
bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
|
21
|
+
# NN functions are internal, so no need to bind
|
22
|
+
end
|
23
|
+
|
24
|
+
def bind_functions(context, def_method, functions)
|
25
|
+
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
|
26
|
+
if def_method == :define_method
|
27
|
+
funcs.map! { |f| Function.new(f.function) }
|
28
|
+
funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
|
29
|
+
end
|
30
|
+
|
31
|
+
defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
|
32
|
+
next if defined && name != "clone"
|
33
|
+
|
34
|
+
parser = Parser.new(funcs)
|
35
|
+
|
36
|
+
context.send(def_method, name) do |*args, **options|
|
37
|
+
result = parser.parse(args, options)
|
38
|
+
raise ArgumentError, result[:error] if result[:error]
|
39
|
+
send(result[:name], *result[:args])
|
40
|
+
end
|
41
|
+
end
|
42
|
+
end
|
43
|
+
end
|
44
|
+
end
|
45
|
+
end
|
46
|
+
end
|
47
|
+
|
48
|
+
Torch::Native::Dispatcher.bind
|
data/lib/torch/native/function.rb
ADDED
@@ -0,0 +1,78 @@
|
|
1
|
+
module Torch
|
2
|
+
module Native
|
3
|
+
class Function
|
4
|
+
attr_reader :function
|
5
|
+
|
6
|
+
def initialize(function)
|
7
|
+
@function = function
|
8
|
+
end
|
9
|
+
|
10
|
+
def func
|
11
|
+
@func ||= @function["func"]
|
12
|
+
end
|
13
|
+
|
14
|
+
def name
|
15
|
+
@name ||= func.split("(", 2).first
|
16
|
+
end
|
17
|
+
|
18
|
+
def python_module
|
19
|
+
@python_module ||= @function["python_module"]
|
20
|
+
end
|
21
|
+
|
22
|
+
def variants
|
23
|
+
@variants ||= (@function["variants"] || "function").split(", ")
|
24
|
+
end
|
25
|
+
|
26
|
+
def args
|
27
|
+
@args ||= begin
|
28
|
+
args = []
|
29
|
+
pos = true
|
30
|
+
args_str = func.split("(", 2).last.split(") ->").first
|
31
|
+
args_str.split(", ").each do |a|
|
32
|
+
if a == "*"
|
33
|
+
pos = false
|
34
|
+
next
|
35
|
+
end
|
36
|
+
t, _, k = a.rpartition(" ")
|
37
|
+
k, d = k.split("=")
|
38
|
+
d = d.to_i if d.to_i.to_s == d
|
39
|
+
d = true if d == "True"
|
40
|
+
d = false if d == "False"
|
41
|
+
d = nil if d == "None"
|
42
|
+
args << {name: k, type: t, default: d, pos: pos}
|
43
|
+
end
|
44
|
+
args
|
45
|
+
end
|
46
|
+
end
|
47
|
+
|
48
|
+
def out_size
|
49
|
+
@out_size ||= func.split("->").last.count("!")
|
50
|
+
end
|
51
|
+
|
52
|
+
def out?
|
53
|
+
out_size > 0 && base_name[-1] != "_"
|
54
|
+
end
|
55
|
+
|
56
|
+
def ruby_name
|
57
|
+
@ruby_name ||= begin
|
58
|
+
name = base_name
|
59
|
+
if name.end_with?("_")
|
60
|
+
"#{name[0..-2]}!"
|
61
|
+
elsif name.start_with?("is_")
|
62
|
+
"#{name[3..-1]}?"
|
63
|
+
else
|
64
|
+
name
|
65
|
+
end
|
66
|
+
end
|
67
|
+
end
|
68
|
+
|
69
|
+
def cpp_name
|
70
|
+
@cpp_name ||= "_" + name.downcase.sub(".", "_")
|
71
|
+
end
|
72
|
+
|
73
|
+
def base_name
|
74
|
+
@base_name ||= name.split(".").first
|
75
|
+
end
|
76
|
+
end
|
77
|
+
end
|
78
|
+
end
|
data/lib/torch/native/generator.rb
ADDED
@@ -0,0 +1,149 @@
|
|
1
|
+
require "yaml"
|
2
|
+
# use require_relative for
|
3
|
+
# rake generate:functions (without bundle)
|
4
|
+
require_relative "function"
|
5
|
+
|
6
|
+
module Torch
|
7
|
+
module Native
|
8
|
+
module Generator
|
9
|
+
class << self
|
10
|
+
def generate_cpp_functions
|
11
|
+
functions = grouped_functions
|
12
|
+
generate_cpp_file("torch", :define_singleton_method, functions[:torch])
|
13
|
+
generate_cpp_file("tensor", :define_method, functions[:tensor])
|
14
|
+
generate_cpp_file("nn", :define_singleton_method, functions[:nn])
|
15
|
+
end
|
16
|
+
|
17
|
+
def grouped_functions
|
18
|
+
functions = functions()
|
19
|
+
|
20
|
+
# remove functions
|
21
|
+
skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
|
22
|
+
skip_args = ["bool[3]", "Dimname", "ScalarType", "MemoryFormat", "Storage", "ConstQuantizerPtr"]
|
23
|
+
functions.reject! { |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || skip_binding.include?(f.ruby_name) }
|
24
|
+
todo_functions, functions =
|
25
|
+
functions.partition do |f|
|
26
|
+
f.args.any? do |a|
|
27
|
+
a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
|
28
|
+
skip_args.any? { |sa| a[:type].include?(sa) }
|
29
|
+
end
|
30
|
+
end
|
31
|
+
|
32
|
+
# generate additional functions for optional arguments
|
33
|
+
# there may be a better way to do this
|
34
|
+
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
35
|
+
optional_functions.each do |f|
|
36
|
+
next if f.ruby_name.start_with?("avg_pool") || f.ruby_name == "cross"
|
37
|
+
opt_args = f.args.select { |a| a[:type] == "int?" }
|
38
|
+
if opt_args.size == 1
|
39
|
+
sep = f.name.include?(".") ? "_" : "."
|
40
|
+
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
41
|
+
# TODO only remove some arguments
|
42
|
+
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
43
|
+
functions << f1
|
44
|
+
functions << f2
|
45
|
+
end
|
46
|
+
end
|
47
|
+
|
48
|
+
# todo_functions.each do |f|
|
49
|
+
# puts f.func
|
50
|
+
# puts
|
51
|
+
# end
|
52
|
+
|
53
|
+
nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
|
54
|
+
torch_functions = other_functions.select { |f| f.variants.include?("function") }
|
55
|
+
tensor_functions = other_functions.select { |f| f.variants.include?("method") }
|
56
|
+
|
57
|
+
{torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
|
58
|
+
end
|
59
|
+
|
60
|
+
private
|
61
|
+
|
62
|
+
def generate_cpp_file(type, def_method, functions)
|
63
|
+
hpp_template = <<-TEMPLATE
|
64
|
+
// generated by rake generate:functions
|
65
|
+
// do not edit by hand
|
66
|
+
|
67
|
+
#pragma once
|
68
|
+
|
69
|
+
void add_%{type}_functions(Module m);
|
70
|
+
TEMPLATE
|
71
|
+
|
72
|
+
cpp_template = <<-TEMPLATE
|
73
|
+
// generated by rake generate:functions
|
74
|
+
// do not edit by hand
|
75
|
+
|
76
|
+
#include <torch/torch.h>
|
77
|
+
#include <rice/Module.hpp>
|
78
|
+
#include "templates.hpp"
|
79
|
+
|
80
|
+
void add_%{type}_functions(Module m) {
|
81
|
+
m
|
82
|
+
%{functions};
|
83
|
+
}
|
84
|
+
TEMPLATE
|
85
|
+
|
86
|
+
cpp_defs = []
|
87
|
+
functions.sort_by(&:cpp_name).each do |func|
|
88
|
+
fargs = func.args.select { |a| a[:type] != "Generator?" }
|
89
|
+
|
90
|
+
cpp_args = []
|
91
|
+
fargs.each do |a|
|
92
|
+
t =
|
93
|
+
case a[:type]
|
94
|
+
when "Tensor"
|
95
|
+
"const Tensor &"
|
96
|
+
when "Tensor?"
|
97
|
+
# TODO better signature
|
98
|
+
"OptionalTensor"
|
99
|
+
when "Tensor[]"
|
100
|
+
"TensorList"
|
101
|
+
when "int"
|
102
|
+
"int64_t"
|
103
|
+
when "float"
|
104
|
+
"double"
|
105
|
+
when /\Aint\[/
|
106
|
+
"IntArrayRef"
|
107
|
+
when /Tensor\(\S!?\)/
|
108
|
+
"Tensor &"
|
109
|
+
else
|
110
|
+
a[:type]
|
111
|
+
end
|
112
|
+
|
113
|
+
t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
|
114
|
+
cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
|
115
|
+
end
|
116
|
+
|
117
|
+
dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
|
118
|
+
args = fargs.map { |a| a[:name] }
|
119
|
+
args.unshift(*args.pop(func.out_size)) if func.out?
|
120
|
+
args.delete("self") if def_method == :define_method
|
121
|
+
|
122
|
+
prefix = def_method == :define_method ? "self." : "torch::"
|
123
|
+
|
124
|
+
cpp_defs << ".#{def_method}(
|
125
|
+
\"#{func.cpp_name}\",
|
126
|
+
*[](#{cpp_args.join(", ")}) {
|
127
|
+
return #{prefix}#{dispatch}(#{args.join(", ")});
|
128
|
+
})"
|
129
|
+
end
|
130
|
+
|
131
|
+
hpp_contents = hpp_template % {type: type}
|
132
|
+
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
|
133
|
+
|
134
|
+
path = File.expand_path("../../../ext/torch", __dir__)
|
135
|
+
File.write("#{path}/#{type}_functions.hpp", hpp_contents)
|
136
|
+
File.write("#{path}/#{type}_functions.cpp", cpp_contents)
|
137
|
+
end
|
138
|
+
|
139
|
+
def functions
|
140
|
+
@native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
|
141
|
+
end
|
142
|
+
|
143
|
+
def path
|
144
|
+
File.expand_path("native_functions.yaml", __dir__)
|
145
|
+
end
|
146
|
+
end
|
147
|
+
end
|
148
|
+
end
|
149
|
+
end
|