torch-rb 0.8.3 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +11 -4
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/ruby_arg_parser.h +17 -3
- data/ext/torch/templates.h +0 -37
- data/ext/torch/tensor.cpp +8 -8
- data/lib/torch/tensor.rb +12 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: '09221364dad232f1b76129fe9dc9407675cc2afbd03bd1339c736d4eec752df7'
|
4
|
+
data.tar.gz: a37a0584aed809009ebd74e7c0da9430481ccabf1ac915d2468ee7511c249588
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c220d35971b9ce3e5a7a80f6a5d1ae4324f3524d0e0171c680deced66fe2f29342a46eecb1a4447d84a401a677c7bb1ef910a0c7ee6c925ea4b578b7e5712772
|
7
|
+
data.tar.gz: 807fe2907de1caac92da6dddb0154b7971dda3aa0ee2c53f6b3046732f4bf3c02310e59a6441efa0fdf1ad3d0ddb5dcd7a3c1a946adbfbea73b6b30f10a71487
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -28,6 +28,9 @@ def skip_functions(functions)
|
|
28
28
|
f.base_name.include?("_forward") ||
|
29
29
|
f.base_name == "to" ||
|
30
30
|
f.base_name == "record_stream" ||
|
31
|
+
f.base_name == "is_pinned" ||
|
32
|
+
f.base_name == "pin_memory" ||
|
33
|
+
f.base_name == "fused_moving_avg_obs_fake_quant" ||
|
31
34
|
# in ext.cpp
|
32
35
|
f.base_name == "index" ||
|
33
36
|
f.base_name == "index_put_" ||
|
@@ -387,6 +390,8 @@ def generate_function_params(function, params, remove_self)
|
|
387
390
|
end
|
388
391
|
when "generator", "tensorlist", "intlist"
|
389
392
|
func
|
393
|
+
when "string"
|
394
|
+
"stringViewOptional"
|
390
395
|
else
|
391
396
|
"#{func}Optional"
|
392
397
|
end
|
@@ -424,9 +429,7 @@ def generate_dispatch_params(function, params)
|
|
424
429
|
if function.out?
|
425
430
|
"const Tensor &"
|
426
431
|
else
|
427
|
-
|
428
|
-
# "const c10::optional<at::Tensor> &"
|
429
|
-
"const OptionalTensor &"
|
432
|
+
"const c10::optional<at::Tensor> &"
|
430
433
|
end
|
431
434
|
elsif param[:modifier]
|
432
435
|
if param[:modifier].include?("!") && function.retvals.size > 1
|
@@ -450,7 +453,11 @@ def generate_dispatch_params(function, params)
|
|
450
453
|
when "float[]"
|
451
454
|
"ArrayRef<double>"
|
452
455
|
when "str"
|
453
|
-
|
456
|
+
if param[:optional]
|
457
|
+
"c10::string_view"
|
458
|
+
else
|
459
|
+
"std::string"
|
460
|
+
end
|
454
461
|
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
|
455
462
|
param[:type]
|
456
463
|
else
|