torch-rb 0.11.2 → 0.12.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +5 -4
- data/codegen/function.rb +8 -0
- data/codegen/generate_functions.rb +30 -5
- data/codegen/native_functions.yaml +2067 -652
- data/ext/torch/ruby_arg_parser.cpp +88 -3
- data/ext/torch/ruby_arg_parser.h +39 -3
- data/ext/torch/utils.h +5 -1
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 24617a6191c4d1e42ff0d5885c0762db75d034ee52817d70004ddc52025890dd
|
4
|
+
data.tar.gz: 7036f3c7fac8a1aac22914ec2811b8771464baeb847052e417b069f3897b3977
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 044166f526bea5fea0c4314abc55dfaf2410cdd5d7bce764ccb78329754fe284675dc0ecb55cf5a4546587221cf2361c8e215f1e9e067668ea5833c445bea961
|
7
|
+
data.tar.gz: e086d0dcc731fb3e810f8fb91d04d95bf3c5efd9cd35f053238be60479d8a47f2126ddad7ab2705d18a26175c6c4126f12be3f19962c2a69f24ad3ee551252b1
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -13,10 +13,10 @@ Check out:
|
|
13
13
|
|
14
14
|
## Installation
|
15
15
|
|
16
|
-
First, [install LibTorch](#libtorch-installation).
|
16
|
+
First, [install LibTorch](#libtorch-installation). With Homebrew, it’s part of the PyTorch package:
|
17
17
|
|
18
18
|
```sh
|
19
|
-
brew install libtorch
|
19
|
+
brew install pytorch
|
20
20
|
```
|
21
21
|
|
22
22
|
Add this line to your application’s Gemfile:
|
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
|
|
409
409
|
|
410
410
|
Torch.rb | LibTorch
|
411
411
|
--- | ---
|
412
|
-
0.11.0+ | 1.12.0+
|
412
|
+
0.12.0+ | 1.13.0+
|
413
|
+
0.11.0-0.11.2 | 1.12.0-1.12.1
|
413
414
|
0.10.0-0.10.2 | 1.11.0
|
414
415
|
0.9.0-0.9.2 | 1.10.0-1.10.2
|
415
416
|
0.8.0-0.8.3 | 1.9.0-1.9.1
|
@@ -425,7 +426,7 @@ Torch.rb | LibTorch
|
|
425
426
|
You can also use Homebrew.
|
426
427
|
|
427
428
|
```sh
|
428
|
-
brew install libtorch
|
429
|
+
brew install pytorch
|
429
430
|
```
|
430
431
|
|
431
432
|
## Performance
|
data/codegen/function.rb
CHANGED
@@ -34,6 +34,10 @@ class Function
|
|
34
34
|
!out_index.nil?
|
35
35
|
end
|
36
36
|
|
37
|
+
def dispatch_name
|
38
|
+
definition.dig("dispatch", "CompositeImplicitAutograd")
|
39
|
+
end
|
40
|
+
|
37
41
|
private
|
38
42
|
|
39
43
|
def parse_func
|
@@ -77,6 +81,10 @@ class Function
|
|
77
81
|
default = "torch.int64"
|
78
82
|
end
|
79
83
|
|
84
|
+
if name == "dtype" && base_name == "randint"
|
85
|
+
default = "None"
|
86
|
+
end
|
87
|
+
|
80
88
|
default = nil if definition["cpp_no_default_args"].to_a.include?(name)
|
81
89
|
|
82
90
|
params << {
|
data/codegen/generate_functions.rb
CHANGED
@@ -15,6 +15,7 @@ def generate_functions
|
|
15
15
|
generate_files("linalg", :define_singleton_method, functions[:linalg])
|
16
16
|
generate_files("special", :define_singleton_method, functions[:special])
|
17
17
|
generate_files("sparse", :define_singleton_method, functions[:sparse])
|
18
|
+
# TODO generate nested
|
18
19
|
end
|
19
20
|
|
20
21
|
def load_functions
|
@@ -40,13 +41,14 @@ def skip_functions(functions)
|
|
40
41
|
# not supported yet
|
41
42
|
f.func.include?("Dimname") ||
|
42
43
|
f.func.include?("ConstQuantizerPtr") ||
|
43
|
-
f.func.include?("SymInt") ||
|
44
44
|
# TODO fix LibTorch 1.12 changes
|
45
45
|
f.base_name == "histogramdd" ||
|
46
46
|
f.base_name == "nested_tensor" ||
|
47
47
|
f.base_name == "split_copy" ||
|
48
48
|
f.base_name == "split_with_sizes_copy" ||
|
49
|
-
f.base_name == "unbind_copy"
|
49
|
+
f.base_name == "unbind_copy" ||
|
50
|
+
# TODO fix LibTorch 1.13 changes
|
51
|
+
f.base_name == "native_channel_shuffle"
|
50
52
|
end
|
51
53
|
end
|
52
54
|
|
@@ -56,6 +58,7 @@ def group_functions(functions)
|
|
56
58
|
fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
|
57
59
|
special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
|
58
60
|
sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" }
|
61
|
+
nested_functions, other_functions = other_functions.partition { |f| f.python_module == "nested" }
|
59
62
|
unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
|
60
63
|
torch_functions = other_functions.select { |f| f.variants.include?("function") }
|
61
64
|
tensor_functions = other_functions.select { |f| f.variants.include?("method") }
|
@@ -72,7 +75,8 @@ def group_functions(functions)
|
|
72
75
|
linalg: linalg_functions,
|
73
76
|
fft: fft_functions,
|
74
77
|
special: special_functions,
|
75
|
-
sparse: sparse_functions
|
78
|
+
sparse: sparse_functions,
|
79
|
+
nested: nested_functions
|
76
80
|
}
|
77
81
|
end
|
78
82
|
|
@@ -387,6 +391,8 @@ def generate_function_params(function, params, remove_self)
|
|
387
391
|
"scalarlist"
|
388
392
|
when /\Aint\[/
|
389
393
|
"intlist"
|
394
|
+
when /\ASymInt\[/
|
395
|
+
"symintlist"
|
390
396
|
when "float[]"
|
391
397
|
"doublelist"
|
392
398
|
when "Scalar"
|
@@ -395,6 +401,8 @@ def generate_function_params(function, params, remove_self)
|
|
395
401
|
"toBool"
|
396
402
|
when "int"
|
397
403
|
"toInt64"
|
404
|
+
when "SymInt"
|
405
|
+
"toSymInt"
|
398
406
|
when "float"
|
399
407
|
"toDouble"
|
400
408
|
when "ScalarType"
|
@@ -437,7 +445,12 @@ def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
|
|
437
445
|
# torch::empty sets requires_grad but at::empty doesn't
|
438
446
|
# https://github.com/pytorch/pytorch/issues/36455
|
439
447
|
prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
|
440
|
-
dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
|
448
|
+
dispatch = function.dispatch_name
|
449
|
+
unless dispatch
|
450
|
+
dispatch = function.base_name
|
451
|
+
dispatch += "_symint" if function.func.include?("SymInt")
|
452
|
+
dispatch += "_out" if function.out?
|
453
|
+
end
|
441
454
|
|
442
455
|
params = params.map { |v| v[:name] }
|
443
456
|
params.reject! { |v| v == "self" } if remove_self
|
@@ -478,6 +491,8 @@ def generate_dispatch_params(function, params)
|
|
478
491
|
"ScalarList"
|
479
492
|
when "int"
|
480
493
|
"int64_t"
|
494
|
+
when "SymInt"
|
495
|
+
"c10::SymInt"
|
481
496
|
when "float"
|
482
497
|
"double"
|
483
498
|
when /\Aint\[/
|
@@ -486,6 +501,12 @@ def generate_dispatch_params(function, params)
|
|
486
501
|
else
|
487
502
|
"IntArrayRef"
|
488
503
|
end
|
504
|
+
when /\ASymInt\[/
|
505
|
+
if param[:optional]
|
506
|
+
"at::OptionalSymIntArrayRef"
|
507
|
+
else
|
508
|
+
"c10::SymIntArrayRef"
|
509
|
+
end
|
489
510
|
when "float[]"
|
490
511
|
"ArrayRef<double>"
|
491
512
|
when "str"
|
@@ -506,7 +527,7 @@ def generate_dispatch_params(function, params)
|
|
506
527
|
raise "Unknown type: #{param[:type]} (#{function.name})"
|
507
528
|
end
|
508
529
|
|
509
|
-
if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
|
530
|
+
if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") && !param[:type].start_with?("SymInt[")
|
510
531
|
type = "c10::optional<#{type}>"
|
511
532
|
end
|
512
533
|
|
@@ -614,8 +635,12 @@ def signature_type(param)
|
|
614
635
|
"DirnameList"
|
615
636
|
when /\Aint\[\d*\]\z/
|
616
637
|
"IntArrayRef"
|
638
|
+
when /\ASymInt\[\d*\]\z/
|
639
|
+
"SymIntArrayRef"
|
617
640
|
when "int"
|
618
641
|
"int64_t"
|
642
|
+
when "SymInt"
|
643
|
+
"c10::SymInt"
|
619
644
|
when "float"
|
620
645
|
"double"
|
621
646
|
when "str"
|