torch-rb 0.10.2 → 0.11.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +24 -7
- data/codegen/native_functions.yaml +1362 -199
- data/ext/torch/extconf.rb +1 -13
- data/ext/torch/ruby_arg_parser.h +11 -3
- data/ext/torch/utils.h +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -0
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d2f88c938144476a772fd7c606751e0a6e67a338cb1c37ede9c2011db7fc4579
|
4
|
+
data.tar.gz: f4408457e3bf8c7bf9b42863459248aa592c70e8e3924cdbe72863a979e65106
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 90ebc506942809b02331f7accce6e93e714a57b5d2f58b06ad0d3c60204b947eba61fe0bf43431a19e27890f5c764cb348ccb1785a2ffccfba7969f4726b6f1f
|
7
|
+
data.tar.gz: 410af7b4934f79aaae94f6c1bde1d17ee8b043f02354b4ea9f139271575f43abb90afeba162badfa01a3ec406612ae6cdcb4a2adfb1067b1634982a590d10cc9
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -39,7 +39,14 @@ def skip_functions(functions)
|
|
39
39
|
f.base_name == "index_put" ||
|
40
40
|
# not supported yet
|
41
41
|
f.func.include?("Dimname") ||
|
42
|
-
f.func.include?("ConstQuantizerPtr")
|
42
|
+
f.func.include?("ConstQuantizerPtr") ||
|
43
|
+
f.func.include?("SymInt") ||
|
44
|
+
# TODO fix LibTorch 1.12 changes
|
45
|
+
f.base_name == "histogramdd" ||
|
46
|
+
f.base_name == "nested_tensor" ||
|
47
|
+
f.base_name == "split_copy" ||
|
48
|
+
f.base_name == "split_with_sizes_copy" ||
|
49
|
+
f.base_name == "unbind_copy"
|
43
50
|
end
|
44
51
|
end
|
45
52
|
|
@@ -250,7 +257,7 @@ def generate_dispatch(function, def_method)
|
|
250
257
|
|
251
258
|
cpp_params = generate_dispatch_params(function, params)
|
252
259
|
if opt_index
|
253
|
-
cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "
|
260
|
+
cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
|
254
261
|
end
|
255
262
|
|
256
263
|
retval = generate_dispatch_retval(function)
|
@@ -410,7 +417,7 @@ def generate_function_params(function, params, remove_self)
|
|
410
417
|
else
|
411
418
|
"optionalTensor"
|
412
419
|
end
|
413
|
-
when "generator", "tensorlist"
|
420
|
+
when "generator", "tensorlist"
|
414
421
|
func
|
415
422
|
when "string"
|
416
423
|
"stringViewOptional"
|
@@ -471,7 +478,11 @@ def generate_dispatch_params(function, params)
|
|
471
478
|
when "float"
|
472
479
|
"double"
|
473
480
|
when /\Aint\[/
|
474
|
-
|
481
|
+
if param[:optional]
|
482
|
+
"at::OptionalIntArrayRef"
|
483
|
+
else
|
484
|
+
"IntArrayRef"
|
485
|
+
end
|
475
486
|
when "float[]"
|
476
487
|
"ArrayRef<double>"
|
477
488
|
when "str"
|
@@ -480,13 +491,19 @@ def generate_dispatch_params(function, params)
|
|
480
491
|
else
|
481
492
|
"std::string"
|
482
493
|
end
|
483
|
-
when "Scalar"
|
494
|
+
when "Scalar"
|
495
|
+
if param[:optional]
|
496
|
+
"const c10::optional<Scalar> &"
|
497
|
+
else
|
498
|
+
"const Scalar &"
|
499
|
+
end
|
500
|
+
when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
|
484
501
|
param[:type]
|
485
502
|
else
|
486
503
|
raise "Unknown type: #{param[:type]} (#{function.name})"
|
487
504
|
end
|
488
505
|
|
489
|
-
if param[:optional] && param[:type]
|
506
|
+
if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
|
490
507
|
type = "c10::optional<#{type}>"
|
491
508
|
end
|
492
509
|
|
@@ -529,7 +546,7 @@ def generate_dispatch_retval(function)
|
|
529
546
|
when ["float", "float"]
|
530
547
|
"std::tuple<double,double>"
|
531
548
|
else
|
532
|
-
raise "Unknown retvals: #{types}"
|
549
|
+
raise "Unknown retvals: #{types} (#{function.name})"
|
533
550
|
end
|
534
551
|
end
|
535
552
|
|