torch-rb 0.10.2 → 0.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +24 -7
- data/codegen/native_functions.yaml +1362 -199
- data/ext/torch/extconf.rb +1 -13
- data/ext/torch/ruby_arg_parser.h +11 -3
- data/ext/torch/utils.h +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -0
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d2f88c938144476a772fd7c606751e0a6e67a338cb1c37ede9c2011db7fc4579
|
4
|
+
data.tar.gz: f4408457e3bf8c7bf9b42863459248aa592c70e8e3924cdbe72863a979e65106
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 90ebc506942809b02331f7accce6e93e714a57b5d2f58b06ad0d3c60204b947eba61fe0bf43431a19e27890f5c764cb348ccb1785a2ffccfba7969f4726b6f1f
|
7
|
+
data.tar.gz: 410af7b4934f79aaae94f6c1bde1d17ee8b043f02354b4ea9f139271575f43abb90afeba162badfa01a3ec406612ae6cdcb4a2adfb1067b1634982a590d10cc9
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -39,7 +39,14 @@ def skip_functions(functions)
|
|
39
39
|
f.base_name == "index_put" ||
|
40
40
|
# not supported yet
|
41
41
|
f.func.include?("Dimname") ||
|
42
|
-
f.func.include?("ConstQuantizerPtr")
|
42
|
+
f.func.include?("ConstQuantizerPtr") ||
|
43
|
+
f.func.include?("SymInt") ||
|
44
|
+
# TODO fix LibTorch 1.12 changes
|
45
|
+
f.base_name == "histogramdd" ||
|
46
|
+
f.base_name == "nested_tensor" ||
|
47
|
+
f.base_name == "split_copy" ||
|
48
|
+
f.base_name == "split_with_sizes_copy" ||
|
49
|
+
f.base_name == "unbind_copy"
|
43
50
|
end
|
44
51
|
end
|
45
52
|
|
@@ -250,7 +257,7 @@ def generate_dispatch(function, def_method)
|
|
250
257
|
|
251
258
|
cpp_params = generate_dispatch_params(function, params)
|
252
259
|
if opt_index
|
253
|
-
cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "
|
260
|
+
cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
|
254
261
|
end
|
255
262
|
|
256
263
|
retval = generate_dispatch_retval(function)
|
@@ -410,7 +417,7 @@ def generate_function_params(function, params, remove_self)
|
|
410
417
|
else
|
411
418
|
"optionalTensor"
|
412
419
|
end
|
413
|
-
when "generator", "tensorlist"
|
420
|
+
when "generator", "tensorlist"
|
414
421
|
func
|
415
422
|
when "string"
|
416
423
|
"stringViewOptional"
|
@@ -471,7 +478,11 @@ def generate_dispatch_params(function, params)
|
|
471
478
|
when "float"
|
472
479
|
"double"
|
473
480
|
when /\Aint\[/
|
474
|
-
|
481
|
+
if param[:optional]
|
482
|
+
"at::OptionalIntArrayRef"
|
483
|
+
else
|
484
|
+
"IntArrayRef"
|
485
|
+
end
|
475
486
|
when "float[]"
|
476
487
|
"ArrayRef<double>"
|
477
488
|
when "str"
|
@@ -480,13 +491,19 @@ def generate_dispatch_params(function, params)
|
|
480
491
|
else
|
481
492
|
"std::string"
|
482
493
|
end
|
483
|
-
when "Scalar"
|
494
|
+
when "Scalar"
|
495
|
+
if param[:optional]
|
496
|
+
"const c10::optional<Scalar> &"
|
497
|
+
else
|
498
|
+
"const Scalar &"
|
499
|
+
end
|
500
|
+
when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
|
484
501
|
param[:type]
|
485
502
|
else
|
486
503
|
raise "Unknown type: #{param[:type]} (#{function.name})"
|
487
504
|
end
|
488
505
|
|
489
|
-
if param[:optional] && param[:type]
|
506
|
+
if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
|
490
507
|
type = "c10::optional<#{type}>"
|
491
508
|
end
|
492
509
|
|
@@ -529,7 +546,7 @@ def generate_dispatch_retval(function)
|
|
529
546
|
when ["float", "float"]
|
530
547
|
"std::tuple<double,double>"
|
531
548
|
else
|
532
|
-
raise "Unknown retvals: #{types}"
|
549
|
+
raise "Unknown retvals: #{types} (#{function.name})"
|
533
550
|
end
|
534
551
|
end
|
535
552
|
|