torch-rb 0.8.3 → 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +11 -4
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/ruby_arg_parser.h +17 -3
- data/ext/torch/templates.h +0 -37
- data/ext/torch/tensor.cpp +8 -8
- data/lib/torch/tensor.rb +12 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: '09221364dad232f1b76129fe9dc9407675cc2afbd03bd1339c736d4eec752df7'
|
4
|
+
data.tar.gz: a37a0584aed809009ebd74e7c0da9430481ccabf1ac915d2468ee7511c249588
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c220d35971b9ce3e5a7a80f6a5d1ae4324f3524d0e0171c680deced66fe2f29342a46eecb1a4447d84a401a677c7bb1ef910a0c7ee6c925ea4b578b7e5712772
|
7
|
+
data.tar.gz: 807fe2907de1caac92da6dddb0154b7971dda3aa0ee2c53f6b3046732f4bf3c02310e59a6441efa0fdf1ad3d0ddb5dcd7a3c1a946adbfbea73b6b30f10a71487
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -28,6 +28,9 @@ def skip_functions(functions)
|
|
28
28
|
f.base_name.include?("_forward") ||
|
29
29
|
f.base_name == "to" ||
|
30
30
|
f.base_name == "record_stream" ||
|
31
|
+
f.base_name == "is_pinned" ||
|
32
|
+
f.base_name == "pin_memory" ||
|
33
|
+
f.base_name == "fused_moving_avg_obs_fake_quant" ||
|
31
34
|
# in ext.cpp
|
32
35
|
f.base_name == "index" ||
|
33
36
|
f.base_name == "index_put_" ||
|
@@ -387,6 +390,8 @@ def generate_function_params(function, params, remove_self)
|
|
387
390
|
end
|
388
391
|
when "generator", "tensorlist", "intlist"
|
389
392
|
func
|
393
|
+
when "string"
|
394
|
+
"stringViewOptional"
|
390
395
|
else
|
391
396
|
"#{func}Optional"
|
392
397
|
end
|
@@ -424,9 +429,7 @@ def generate_dispatch_params(function, params)
|
|
424
429
|
if function.out?
|
425
430
|
"const Tensor &"
|
426
431
|
else
|
427
|
-
|
428
|
-
# "const c10::optional<at::Tensor> &"
|
429
|
-
"const OptionalTensor &"
|
432
|
+
"const c10::optional<at::Tensor> &"
|
430
433
|
end
|
431
434
|
elsif param[:modifier]
|
432
435
|
if param[:modifier].include?("!") && function.retvals.size > 1
|
@@ -450,7 +453,11 @@ def generate_dispatch_params(function, params)
|
|
450
453
|
when "float[]"
|
451
454
|
"ArrayRef<double>"
|
452
455
|
when "str"
|
453
|
-
|
456
|
+
if param[:optional]
|
457
|
+
"c10::string_view"
|
458
|
+
else
|
459
|
+
"std::string"
|
460
|
+
end
|
454
461
|
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
|
455
462
|
param[:type]
|
456
463
|
else
|