torch-rb 0.1.4 → 0.1.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +5 -3
- data/ext/torch/ext.cpp +22 -548
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +68 -129
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/conv2d.rb +0 -2
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/functional.rb +55 -16
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +1 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/module.rb +59 -12
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/tensor.rb +19 -19
- data/lib/torch/version.rb +1 -1
- metadata +26 -2
data/lib/torch.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
# ext
|
2
2
|
require "torch/ext"
|
3
3
|
|
4
|
+
# native functions
|
5
|
+
require "torch/native/generator"
|
6
|
+
require "torch/native/parser"
|
7
|
+
require "torch/native/dispatcher"
|
8
|
+
|
4
9
|
# modules
|
5
10
|
require "torch/inspector"
|
6
11
|
require "torch/tensor"
|
@@ -39,6 +44,10 @@ require "torch/nn/max_pool2d"
|
|
39
44
|
require "torch/nn/avg_poolnd"
|
40
45
|
require "torch/nn/avg_pool2d"
|
41
46
|
|
47
|
+
# nn recurrent layers
|
48
|
+
require "torch/nn/rnn_base"
|
49
|
+
require "torch/nn/rnn"
|
50
|
+
|
42
51
|
# nn linear layers
|
43
52
|
require "torch/nn/bilinear"
|
44
53
|
require "torch/nn/identity"
|
@@ -77,23 +86,23 @@ require "torch/nn/pairwise_distance"
|
|
77
86
|
require "torch/nn/loss"
|
78
87
|
require "torch/nn/weighted_loss"
|
79
88
|
require "torch/nn/bce_loss"
|
80
|
-
|
81
|
-
|
89
|
+
require "torch/nn/bce_with_logits_loss"
|
90
|
+
require "torch/nn/cosine_embedding_loss"
|
82
91
|
require "torch/nn/cross_entropy_loss"
|
83
92
|
require "torch/nn/ctc_loss"
|
84
|
-
|
93
|
+
require "torch/nn/hinge_embedding_loss"
|
85
94
|
require "torch/nn/kl_div_loss"
|
86
95
|
require "torch/nn/l1_loss"
|
87
|
-
|
96
|
+
require "torch/nn/margin_ranking_loss"
|
88
97
|
require "torch/nn/mse_loss"
|
89
|
-
|
90
|
-
|
91
|
-
|
98
|
+
require "torch/nn/multi_label_margin_loss"
|
99
|
+
require "torch/nn/multi_label_soft_margin_loss"
|
100
|
+
require "torch/nn/multi_margin_loss"
|
92
101
|
require "torch/nn/nll_loss"
|
93
102
|
require "torch/nn/poisson_nll_loss"
|
94
|
-
|
95
|
-
|
96
|
-
|
103
|
+
require "torch/nn/smooth_l1_loss"
|
104
|
+
require "torch/nn/soft_margin_loss"
|
105
|
+
require "torch/nn/triplet_margin_loss"
|
97
106
|
|
98
107
|
# nn other
|
99
108
|
require "torch/nn/functional"
|
@@ -142,6 +151,41 @@ module Torch
|
|
142
151
|
}
|
143
152
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
144
153
|
|
154
|
+
def self._make_tensor_class(dtype, cuda = false)
|
155
|
+
cls = Class.new
|
156
|
+
device = cuda ? "cuda" : "cpu"
|
157
|
+
cls.define_singleton_method("new") do |*args|
|
158
|
+
if args.size == 1 && args.first.is_a?(Tensor)
|
159
|
+
args.first.send(dtype).to(device)
|
160
|
+
elsif args.size == 1 && args.first.is_a?(Array)
|
161
|
+
Torch.tensor(args.first, dtype: dtype, device: device)
|
162
|
+
else
|
163
|
+
Torch.empty(*args, dtype: dtype, device: device)
|
164
|
+
end
|
165
|
+
end
|
166
|
+
cls
|
167
|
+
end
|
168
|
+
|
169
|
+
FloatTensor = _make_tensor_class(:float32)
|
170
|
+
DoubleTensor = _make_tensor_class(:float64)
|
171
|
+
HalfTensor = _make_tensor_class(:float16)
|
172
|
+
ByteTensor = _make_tensor_class(:uint8)
|
173
|
+
CharTensor = _make_tensor_class(:int8)
|
174
|
+
ShortTensor = _make_tensor_class(:int16)
|
175
|
+
IntTensor = _make_tensor_class(:int32)
|
176
|
+
LongTensor = _make_tensor_class(:int64)
|
177
|
+
BoolTensor = _make_tensor_class(:bool)
|
178
|
+
|
179
|
+
CUDA::FloatTensor = _make_tensor_class(:float32, true)
|
180
|
+
CUDA::DoubleTensor = _make_tensor_class(:float64, true)
|
181
|
+
CUDA::HalfTensor = _make_tensor_class(:float16, true)
|
182
|
+
CUDA::ByteTensor = _make_tensor_class(:uint8, true)
|
183
|
+
CUDA::CharTensor = _make_tensor_class(:int8, true)
|
184
|
+
CUDA::ShortTensor = _make_tensor_class(:int16, true)
|
185
|
+
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
186
|
+
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
187
|
+
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
188
|
+
|
145
189
|
class << self
|
146
190
|
# Torch.float, Torch.long, etc
|
147
191
|
DTYPE_TO_ENUM.each_key do |dtype|
|
@@ -191,6 +235,20 @@ module Torch
|
|
191
235
|
}
|
192
236
|
end
|
193
237
|
|
238
|
+
def no_grad
|
239
|
+
previous_value = grad_enabled?
|
240
|
+
begin
|
241
|
+
_set_grad_enabled(false)
|
242
|
+
yield
|
243
|
+
ensure
|
244
|
+
_set_grad_enabled(previous_value)
|
245
|
+
end
|
246
|
+
end
|
247
|
+
|
248
|
+
def device(str)
|
249
|
+
Device.new(str)
|
250
|
+
end
|
251
|
+
|
194
252
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
195
253
|
|
196
254
|
def arange(start, finish = nil, step = 1, **options)
|
@@ -308,26 +366,6 @@ module Torch
|
|
308
366
|
|
309
367
|
# --- begin operations ---
|
310
368
|
|
311
|
-
%w(add sub mul div remainder).each do |op|
|
312
|
-
define_method(op) do |input, other, **options|
|
313
|
-
execute_op(op, input, other, **options)
|
314
|
-
end
|
315
|
-
end
|
316
|
-
|
317
|
-
def neg(input)
|
318
|
-
_neg(input)
|
319
|
-
end
|
320
|
-
|
321
|
-
def no_grad
|
322
|
-
previous_value = grad_enabled?
|
323
|
-
begin
|
324
|
-
_set_grad_enabled(false)
|
325
|
-
yield
|
326
|
-
ensure
|
327
|
-
_set_grad_enabled(previous_value)
|
328
|
-
end
|
329
|
-
end
|
330
|
-
|
331
369
|
# TODO support out
|
332
370
|
def mean(input, dim = nil, keepdim: false)
|
333
371
|
if dim
|
@@ -346,34 +384,10 @@ module Torch
|
|
346
384
|
end
|
347
385
|
end
|
348
386
|
|
349
|
-
def argmax(input, dim = nil, keepdim: false)
|
350
|
-
if dim
|
351
|
-
_argmax_dim(input, dim, keepdim)
|
352
|
-
else
|
353
|
-
_argmax(input)
|
354
|
-
end
|
355
|
-
end
|
356
|
-
|
357
|
-
def eq(input, other)
|
358
|
-
_eq(input, other)
|
359
|
-
end
|
360
|
-
|
361
|
-
def norm(input)
|
362
|
-
_norm(input)
|
363
|
-
end
|
364
|
-
|
365
|
-
def pow(input, exponent)
|
366
|
-
_pow(input, exponent)
|
367
|
-
end
|
368
|
-
|
369
387
|
def topk(input, k)
|
370
388
|
_topk(input, k)
|
371
389
|
end
|
372
390
|
|
373
|
-
def min(input)
|
374
|
-
_min(input)
|
375
|
-
end
|
376
|
-
|
377
391
|
def max(input, dim = nil, keepdim: false, out: nil)
|
378
392
|
if dim
|
379
393
|
raise NotImplementedYet unless out
|
@@ -383,58 +397,6 @@ module Torch
|
|
383
397
|
end
|
384
398
|
end
|
385
399
|
|
386
|
-
def exp(input)
|
387
|
-
_exp(input)
|
388
|
-
end
|
389
|
-
|
390
|
-
def log(input)
|
391
|
-
_log(input)
|
392
|
-
end
|
393
|
-
|
394
|
-
def sign(input)
|
395
|
-
_sign(input)
|
396
|
-
end
|
397
|
-
|
398
|
-
def sigmoid(input)
|
399
|
-
_sigmoid(input)
|
400
|
-
end
|
401
|
-
|
402
|
-
def gt(input, other)
|
403
|
-
_gt(input, other)
|
404
|
-
end
|
405
|
-
|
406
|
-
def lt(input, other)
|
407
|
-
_lt(input, other)
|
408
|
-
end
|
409
|
-
|
410
|
-
def unsqueeze(input, dim)
|
411
|
-
_unsqueeze(input, dim)
|
412
|
-
end
|
413
|
-
|
414
|
-
def dot(input, tensor)
|
415
|
-
_dot(input, tensor)
|
416
|
-
end
|
417
|
-
|
418
|
-
def cat(tensors, dim = 0)
|
419
|
-
_cat(tensors, dim)
|
420
|
-
end
|
421
|
-
|
422
|
-
def matmul(input, other)
|
423
|
-
_matmul(input, other)
|
424
|
-
end
|
425
|
-
|
426
|
-
def reshape(input, shape)
|
427
|
-
_reshape(input, shape)
|
428
|
-
end
|
429
|
-
|
430
|
-
def flatten(input, start_dim: 0, end_dim: -1)
|
431
|
-
_flatten(input, start_dim, end_dim)
|
432
|
-
end
|
433
|
-
|
434
|
-
def sqrt(input)
|
435
|
-
_sqrt(input)
|
436
|
-
end
|
437
|
-
|
438
400
|
# TODO make dim keyword argument
|
439
401
|
def log_softmax(input, dim)
|
440
402
|
_log_softmax(input, dim)
|
@@ -444,31 +406,8 @@ module Torch
|
|
444
406
|
_softmax(input, dim)
|
445
407
|
end
|
446
408
|
|
447
|
-
def abs(input)
|
448
|
-
_abs(input)
|
449
|
-
end
|
450
|
-
|
451
|
-
def device(str)
|
452
|
-
Device.new(str)
|
453
|
-
end
|
454
|
-
|
455
409
|
private
|
456
410
|
|
457
|
-
def execute_op(op, input, other, out: nil)
|
458
|
-
scalar = other.is_a?(Numeric)
|
459
|
-
if out
|
460
|
-
# TODO make work with scalars
|
461
|
-
raise Error, "out not supported with scalar yet" if scalar
|
462
|
-
send("_#{op}_out", out, input, other)
|
463
|
-
else
|
464
|
-
if scalar
|
465
|
-
send("_#{op}_scalar", input, other)
|
466
|
-
else
|
467
|
-
send("_#{op}", input, other)
|
468
|
-
end
|
469
|
-
end
|
470
|
-
end
|
471
|
-
|
472
411
|
def tensor_size(size)
|
473
412
|
size.flatten
|
474
413
|
end
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/native/dispatcher.rb
@@ -0,0 +1,48 @@
|
|
1
|
+
# We use a generic interface for methods (*args, **options)
|
2
|
+
# and this class to determine the C++ method to call
|
3
|
+
#
|
4
|
+
# This is needed since LibTorch uses function overloading,
|
5
|
+
# which isn't available in Ruby or Python
|
6
|
+
#
|
7
|
+
# PyTorch uses this approach, but the parser/dispatcher is written in C++
|
8
|
+
#
|
9
|
+
# We could generate Ruby methods directly, but an advantage of this approach is
|
10
|
+
# arguments and keyword arguments can be used interchangeably like in Python,
|
11
|
+
# making it easier to port code
|
12
|
+
|
13
|
+
module Torch
|
14
|
+
module Native
|
15
|
+
module Dispatcher
|
16
|
+
class << self
|
17
|
+
def bind
|
18
|
+
functions = Generator.grouped_functions
|
19
|
+
bind_functions(::Torch, :define_singleton_method, functions[:torch])
|
20
|
+
bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
|
21
|
+
# NN functions are internal, so no need to bind
|
22
|
+
end
|
23
|
+
|
24
|
+
def bind_functions(context, def_method, functions)
|
25
|
+
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
|
26
|
+
if def_method == :define_method
|
27
|
+
funcs.map! { |f| Function.new(f.function) }
|
28
|
+
funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
|
29
|
+
end
|
30
|
+
|
31
|
+
defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
|
32
|
+
next if defined && name != "clone"
|
33
|
+
|
34
|
+
parser = Parser.new(funcs)
|
35
|
+
|
36
|
+
context.send(def_method, name) do |*args, **options|
|
37
|
+
result = parser.parse(args, options)
|
38
|
+
raise ArgumentError, result[:error] if result[:error]
|
39
|
+
send(result[:name], *result[:args])
|
40
|
+
end
|
41
|
+
end
|
42
|
+
end
|
43
|
+
end
|
44
|
+
end
|
45
|
+
end
|
46
|
+
end
|
47
|
+
|
48
|
+
Torch::Native::Dispatcher.bind
|
data/lib/torch/native/function.rb
@@ -0,0 +1,78 @@
|
|
1
|
+
module Torch
|
2
|
+
module Native
|
3
|
+
class Function
|
4
|
+
attr_reader :function
|
5
|
+
|
6
|
+
def initialize(function)
|
7
|
+
@function = function
|
8
|
+
end
|
9
|
+
|
10
|
+
def func
|
11
|
+
@func ||= @function["func"]
|
12
|
+
end
|
13
|
+
|
14
|
+
def name
|
15
|
+
@name ||= func.split("(", 2).first
|
16
|
+
end
|
17
|
+
|
18
|
+
def python_module
|
19
|
+
@python_module ||= @function["python_module"]
|
20
|
+
end
|
21
|
+
|
22
|
+
def variants
|
23
|
+
@variants ||= (@function["variants"] || "function").split(", ")
|
24
|
+
end
|
25
|
+
|
26
|
+
def args
|
27
|
+
@args ||= begin
|
28
|
+
args = []
|
29
|
+
pos = true
|
30
|
+
args_str = func.split("(", 2).last.split(") ->").first
|
31
|
+
args_str.split(", ").each do |a|
|
32
|
+
if a == "*"
|
33
|
+
pos = false
|
34
|
+
next
|
35
|
+
end
|
36
|
+
t, _, k = a.rpartition(" ")
|
37
|
+
k, d = k.split("=")
|
38
|
+
d = d.to_i if d.to_i.to_s == d
|
39
|
+
d = true if d == "True"
|
40
|
+
d = false if d == "False"
|
41
|
+
d = nil if d == "None"
|
42
|
+
args << {name: k, type: t, default: d, pos: pos}
|
43
|
+
end
|
44
|
+
args
|
45
|
+
end
|
46
|
+
end
|
47
|
+
|
48
|
+
def out_size
|
49
|
+
@out_size ||= func.split("->").last.count("!")
|
50
|
+
end
|
51
|
+
|
52
|
+
def out?
|
53
|
+
out_size > 0 && base_name[-1] != "_"
|
54
|
+
end
|
55
|
+
|
56
|
+
def ruby_name
|
57
|
+
@ruby_name ||= begin
|
58
|
+
name = base_name
|
59
|
+
if name.end_with?("_")
|
60
|
+
"#{name[0..-2]}!"
|
61
|
+
elsif name.start_with?("is_")
|
62
|
+
"#{name[3..-1]}?"
|
63
|
+
else
|
64
|
+
name
|
65
|
+
end
|
66
|
+
end
|
67
|
+
end
|
68
|
+
|
69
|
+
def cpp_name
|
70
|
+
@cpp_name ||= "_" + name.downcase.sub(".", "_")
|
71
|
+
end
|
72
|
+
|
73
|
+
def base_name
|
74
|
+
@base_name ||= name.split(".").first
|
75
|
+
end
|
76
|
+
end
|
77
|
+
end
|
78
|
+
end
|
data/lib/torch/native/generator.rb
@@ -0,0 +1,149 @@
|
|
1
|
+
require "yaml"
|
2
|
+
# use require_relative for
|
3
|
+
# rake generate:functions (without bundle)
|
4
|
+
require_relative "function"
|
5
|
+
|
6
|
+
module Torch
|
7
|
+
module Native
|
8
|
+
module Generator
|
9
|
+
class << self
|
10
|
+
def generate_cpp_functions
|
11
|
+
functions = grouped_functions
|
12
|
+
generate_cpp_file("torch", :define_singleton_method, functions[:torch])
|
13
|
+
generate_cpp_file("tensor", :define_method, functions[:tensor])
|
14
|
+
generate_cpp_file("nn", :define_singleton_method, functions[:nn])
|
15
|
+
end
|
16
|
+
|
17
|
+
def grouped_functions
|
18
|
+
functions = functions()
|
19
|
+
|
20
|
+
# remove functions
|
21
|
+
skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
|
22
|
+
skip_args = ["bool[3]", "Dimname", "ScalarType", "MemoryFormat", "Storage", "ConstQuantizerPtr"]
|
23
|
+
functions.reject! { |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || skip_binding.include?(f.ruby_name) }
|
24
|
+
todo_functions, functions =
|
25
|
+
functions.partition do |f|
|
26
|
+
f.args.any? do |a|
|
27
|
+
a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
|
28
|
+
skip_args.any? { |sa| a[:type].include?(sa) }
|
29
|
+
end
|
30
|
+
end
|
31
|
+
|
32
|
+
# generate additional functions for optional arguments
|
33
|
+
# there may be a better way to do this
|
34
|
+
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
35
|
+
optional_functions.each do |f|
|
36
|
+
next if f.ruby_name.start_with?("avg_pool") || f.ruby_name == "cross"
|
37
|
+
opt_args = f.args.select { |a| a[:type] == "int?" }
|
38
|
+
if opt_args.size == 1
|
39
|
+
sep = f.name.include?(".") ? "_" : "."
|
40
|
+
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
41
|
+
# TODO only remove some arguments
|
42
|
+
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
43
|
+
functions << f1
|
44
|
+
functions << f2
|
45
|
+
end
|
46
|
+
end
|
47
|
+
|
48
|
+
# todo_functions.each do |f|
|
49
|
+
# puts f.func
|
50
|
+
# puts
|
51
|
+
# end
|
52
|
+
|
53
|
+
nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
|
54
|
+
torch_functions = other_functions.select { |f| f.variants.include?("function") }
|
55
|
+
tensor_functions = other_functions.select { |f| f.variants.include?("method") }
|
56
|
+
|
57
|
+
{torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
|
58
|
+
end
|
59
|
+
|
60
|
+
private
|
61
|
+
|
62
|
+
def generate_cpp_file(type, def_method, functions)
|
63
|
+
hpp_template = <<-TEMPLATE
|
64
|
+
// generated by rake generate:functions
|
65
|
+
// do not edit by hand
|
66
|
+
|
67
|
+
#pragma once
|
68
|
+
|
69
|
+
void add_%{type}_functions(Module m);
|
70
|
+
TEMPLATE
|
71
|
+
|
72
|
+
cpp_template = <<-TEMPLATE
|
73
|
+
// generated by rake generate:functions
|
74
|
+
// do not edit by hand
|
75
|
+
|
76
|
+
#include <torch/torch.h>
|
77
|
+
#include <rice/Module.hpp>
|
78
|
+
#include "templates.hpp"
|
79
|
+
|
80
|
+
void add_%{type}_functions(Module m) {
|
81
|
+
m
|
82
|
+
%{functions};
|
83
|
+
}
|
84
|
+
TEMPLATE
|
85
|
+
|
86
|
+
cpp_defs = []
|
87
|
+
functions.sort_by(&:cpp_name).each do |func|
|
88
|
+
fargs = func.args.select { |a| a[:type] != "Generator?" }
|
89
|
+
|
90
|
+
cpp_args = []
|
91
|
+
fargs.each do |a|
|
92
|
+
t =
|
93
|
+
case a[:type]
|
94
|
+
when "Tensor"
|
95
|
+
"const Tensor &"
|
96
|
+
when "Tensor?"
|
97
|
+
# TODO better signature
|
98
|
+
"OptionalTensor"
|
99
|
+
when "Tensor[]"
|
100
|
+
"TensorList"
|
101
|
+
when "int"
|
102
|
+
"int64_t"
|
103
|
+
when "float"
|
104
|
+
"double"
|
105
|
+
when /\Aint\[/
|
106
|
+
"IntArrayRef"
|
107
|
+
when /Tensor\(\S!?\)/
|
108
|
+
"Tensor &"
|
109
|
+
else
|
110
|
+
a[:type]
|
111
|
+
end
|
112
|
+
|
113
|
+
t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
|
114
|
+
cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
|
115
|
+
end
|
116
|
+
|
117
|
+
dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
|
118
|
+
args = fargs.map { |a| a[:name] }
|
119
|
+
args.unshift(*args.pop(func.out_size)) if func.out?
|
120
|
+
args.delete("self") if def_method == :define_method
|
121
|
+
|
122
|
+
prefix = def_method == :define_method ? "self." : "torch::"
|
123
|
+
|
124
|
+
cpp_defs << ".#{def_method}(
|
125
|
+
\"#{func.cpp_name}\",
|
126
|
+
*[](#{cpp_args.join(", ")}) {
|
127
|
+
return #{prefix}#{dispatch}(#{args.join(", ")});
|
128
|
+
})"
|
129
|
+
end
|
130
|
+
|
131
|
+
hpp_contents = hpp_template % {type: type}
|
132
|
+
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
|
133
|
+
|
134
|
+
path = File.expand_path("../../../ext/torch", __dir__)
|
135
|
+
File.write("#{path}/#{type}_functions.hpp", hpp_contents)
|
136
|
+
File.write("#{path}/#{type}_functions.cpp", cpp_contents)
|
137
|
+
end
|
138
|
+
|
139
|
+
def functions
|
140
|
+
@native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
|
141
|
+
end
|
142
|
+
|
143
|
+
def path
|
144
|
+
File.expand_path("native_functions.yaml", __dir__)
|
145
|
+
end
|
146
|
+
end
|
147
|
+
end
|
148
|
+
end
|
149
|
+
end
|