torch-rb 0.11.2 → 0.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +5 -4
- data/codegen/function.rb +8 -0
- data/codegen/generate_functions.rb +30 -5
- data/codegen/native_functions.yaml +2067 -652
- data/ext/torch/ruby_arg_parser.cpp +88 -3
- data/ext/torch/ruby_arg_parser.h +39 -3
- data/ext/torch/utils.h +5 -1
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 24617a6191c4d1e42ff0d5885c0762db75d034ee52817d70004ddc52025890dd
|
4
|
+
data.tar.gz: 7036f3c7fac8a1aac22914ec2811b8771464baeb847052e417b069f3897b3977
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 044166f526bea5fea0c4314abc55dfaf2410cdd5d7bce764ccb78329754fe284675dc0ecb55cf5a4546587221cf2361c8e215f1e9e067668ea5833c445bea961
|
7
|
+
data.tar.gz: e086d0dcc731fb3e810f8fb91d04d95bf3c5efd9cd35f053238be60479d8a47f2126ddad7ab2705d18a26175c6c4126f12be3f19962c2a69f24ad3ee551252b1
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -13,10 +13,10 @@ Check out:
|
|
13
13
|
|
14
14
|
## Installation
|
15
15
|
|
16
|
-
First, [install LibTorch](#libtorch-installation).
|
16
|
+
First, [install LibTorch](#libtorch-installation). With Homebrew, it’s part of the PyTorch package:
|
17
17
|
|
18
18
|
```sh
|
19
|
-
brew install
|
19
|
+
brew install pytorch
|
20
20
|
```
|
21
21
|
|
22
22
|
Add this line to your application’s Gemfile:
|
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
|
|
409
409
|
|
410
410
|
Torch.rb | LibTorch
|
411
411
|
--- | ---
|
412
|
-
0.
|
412
|
+
0.12.0+ | 1.13.0+
|
413
|
+
0.11.0-0.11.2 | 1.12.0-1.12.1
|
413
414
|
0.10.0-0.10.2 | 1.11.0
|
414
415
|
0.9.0-0.9.2 | 1.10.0-1.10.2
|
415
416
|
0.8.0-0.8.3 | 1.9.0-1.9.1
|
@@ -425,7 +426,7 @@ Torch.rb | LibTorch
|
|
425
426
|
You can also use Homebrew.
|
426
427
|
|
427
428
|
```sh
|
428
|
-
brew install
|
429
|
+
brew install pytorch
|
429
430
|
```
|
430
431
|
|
431
432
|
## Performance
|
data/codegen/function.rb
CHANGED
@@ -34,6 +34,10 @@ class Function
|
|
34
34
|
!out_index.nil?
|
35
35
|
end
|
36
36
|
|
37
|
+
def dispatch_name
|
38
|
+
definition.dig("dispatch", "CompositeImplicitAutograd")
|
39
|
+
end
|
40
|
+
|
37
41
|
private
|
38
42
|
|
39
43
|
def parse_func
|
@@ -77,6 +81,10 @@ class Function
|
|
77
81
|
default = "torch.int64"
|
78
82
|
end
|
79
83
|
|
84
|
+
if name == "dtype" && base_name == "randint"
|
85
|
+
default = "None"
|
86
|
+
end
|
87
|
+
|
80
88
|
default = nil if definition["cpp_no_default_args"].to_a.include?(name)
|
81
89
|
|
82
90
|
params << {
|
@@ -15,6 +15,7 @@ def generate_functions
|
|
15
15
|
generate_files("linalg", :define_singleton_method, functions[:linalg])
|
16
16
|
generate_files("special", :define_singleton_method, functions[:special])
|
17
17
|
generate_files("sparse", :define_singleton_method, functions[:sparse])
|
18
|
+
# TODO generate nested
|
18
19
|
end
|
19
20
|
|
20
21
|
def load_functions
|
@@ -40,13 +41,14 @@ def skip_functions(functions)
|
|
40
41
|
# not supported yet
|
41
42
|
f.func.include?("Dimname") ||
|
42
43
|
f.func.include?("ConstQuantizerPtr") ||
|
43
|
-
f.func.include?("SymInt") ||
|
44
44
|
# TODO fix LibTorch 1.12 changes
|
45
45
|
f.base_name == "histogramdd" ||
|
46
46
|
f.base_name == "nested_tensor" ||
|
47
47
|
f.base_name == "split_copy" ||
|
48
48
|
f.base_name == "split_with_sizes_copy" ||
|
49
|
-
f.base_name == "unbind_copy"
|
49
|
+
f.base_name == "unbind_copy" ||
|
50
|
+
# TODO fix LibTorch 1.13 changes
|
51
|
+
f.base_name == "native_channel_shuffle"
|
50
52
|
end
|
51
53
|
end
|
52
54
|
|
@@ -56,6 +58,7 @@ def group_functions(functions)
|
|
56
58
|
fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
|
57
59
|
special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
|
58
60
|
sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" }
|
61
|
+
nested_functions, other_functions = other_functions.partition { |f| f.python_module == "nested" }
|
59
62
|
unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
|
60
63
|
torch_functions = other_functions.select { |f| f.variants.include?("function") }
|
61
64
|
tensor_functions = other_functions.select { |f| f.variants.include?("method") }
|
@@ -72,7 +75,8 @@ def group_functions(functions)
|
|
72
75
|
linalg: linalg_functions,
|
73
76
|
fft: fft_functions,
|
74
77
|
special: special_functions,
|
75
|
-
sparse: sparse_functions
|
78
|
+
sparse: sparse_functions,
|
79
|
+
nested: nested_functions
|
76
80
|
}
|
77
81
|
end
|
78
82
|
|
@@ -387,6 +391,8 @@ def generate_function_params(function, params, remove_self)
|
|
387
391
|
"scalarlist"
|
388
392
|
when /\Aint\[/
|
389
393
|
"intlist"
|
394
|
+
when /\ASymInt\[/
|
395
|
+
"symintlist"
|
390
396
|
when "float[]"
|
391
397
|
"doublelist"
|
392
398
|
when "Scalar"
|
@@ -395,6 +401,8 @@ def generate_function_params(function, params, remove_self)
|
|
395
401
|
"toBool"
|
396
402
|
when "int"
|
397
403
|
"toInt64"
|
404
|
+
when "SymInt"
|
405
|
+
"toSymInt"
|
398
406
|
when "float"
|
399
407
|
"toDouble"
|
400
408
|
when "ScalarType"
|
@@ -437,7 +445,12 @@ def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
|
|
437
445
|
# torch::empty sets requires_grad by at::empty doesn't
|
438
446
|
# https://github.com/pytorch/pytorch/issues/36455
|
439
447
|
prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
|
440
|
-
dispatch = function.
|
448
|
+
dispatch = function.dispatch_name
|
449
|
+
unless dispatch
|
450
|
+
dispatch = function.base_name
|
451
|
+
dispatch += "_symint" if function.func.include?("SymInt")
|
452
|
+
dispatch += "_out" if function.out?
|
453
|
+
end
|
441
454
|
|
442
455
|
params = params.map { |v| v[:name] }
|
443
456
|
params.reject! { |v| v == "self" } if remove_self
|
@@ -478,6 +491,8 @@ def generate_dispatch_params(function, params)
|
|
478
491
|
"ScalarList"
|
479
492
|
when "int"
|
480
493
|
"int64_t"
|
494
|
+
when "SymInt"
|
495
|
+
"c10::SymInt"
|
481
496
|
when "float"
|
482
497
|
"double"
|
483
498
|
when /\Aint\[/
|
@@ -486,6 +501,12 @@ def generate_dispatch_params(function, params)
|
|
486
501
|
else
|
487
502
|
"IntArrayRef"
|
488
503
|
end
|
504
|
+
when /\ASymInt\[/
|
505
|
+
if param[:optional]
|
506
|
+
"at::OptionalSymIntArrayRef"
|
507
|
+
else
|
508
|
+
"c10::SymIntArrayRef"
|
509
|
+
end
|
489
510
|
when "float[]"
|
490
511
|
"ArrayRef<double>"
|
491
512
|
when "str"
|
@@ -506,7 +527,7 @@ def generate_dispatch_params(function, params)
|
|
506
527
|
raise "Unknown type: #{param[:type]} (#{function.name})"
|
507
528
|
end
|
508
529
|
|
509
|
-
if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
|
530
|
+
if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") && !param[:type].start_with?("SymInt[")
|
510
531
|
type = "c10::optional<#{type}>"
|
511
532
|
end
|
512
533
|
|
@@ -614,8 +635,12 @@ def signature_type(param)
|
|
614
635
|
"DirnameList"
|
615
636
|
when /\Aint\[\d*\]\z/
|
616
637
|
"IntArrayRef"
|
638
|
+
when /\ASymInt\[\d*\]\z/
|
639
|
+
"SymIntArrayRef"
|
617
640
|
when "int"
|
618
641
|
"int64_t"
|
642
|
+
when "SymInt"
|
643
|
+
"c10::SymInt"
|
619
644
|
when "float"
|
620
645
|
"double"
|
621
646
|
when "str"
|