torch-rb 0.10.1 → 0.11.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/README.md +16 -3
- data/codegen/function.rb +1 -1
- data/codegen/generate_functions.rb +31 -11
- data/codegen/native_functions.yaml +1362 -199
- data/ext/torch/extconf.rb +1 -13
- data/ext/torch/ruby_arg_parser.cpp +64 -2
- data/ext/torch/ruby_arg_parser.h +18 -3
- data/ext/torch/utils.h +1 -1
- data/lib/torch/tensor.rb +8 -5
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -12
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c4c3936a55fd47a8898d3e678aabc81c158f2f39a89c08af0b36695700bf2043
|
4
|
+
data.tar.gz: a68179c7d7bab7547ac3be1d2369abbd5c3c632954996ca2f7fdd089f815edf7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 97551aa27f154cade530e58b4ae92286cd18e432acc99646495a197c89fb719cc991b0399aacaf383a2de04e340c04a9e177c44a0169fce91e19435463df4753
|
7
|
+
data.tar.gz: 8a5a5decf900a4d93aa56eb5e5b3848d5386be8bbc3657c10c201ef321cf0082c6521700378a5e96f795c09882b5b72498d8fed3931153b6e1f7557630f55547
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,17 @@
|
|
1
|
+
## 0.11.1 (2022-07-06)
|
2
|
+
|
3
|
+
- Fixed error with `stft` method
|
4
|
+
|
5
|
+
## 0.11.0 (2022-07-06)
|
6
|
+
|
7
|
+
- Updated LibTorch to 1.12.0
|
8
|
+
- Dropped support for Ruby < 2.7
|
9
|
+
|
10
|
+
## 0.10.2 (2022-06-14)
|
11
|
+
|
12
|
+
- Improved numeric operations between scalars and tensors
|
13
|
+
- Fixed `dtype` of `cumsum` method
|
14
|
+
|
1
15
|
## 0.10.1 (2022-04-12)
|
2
16
|
|
3
17
|
- Fixed `dtype`, `device`, and `layout` for `new_*` and `like_*` methods
|
data/README.md
CHANGED
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
|
|
409
409
|
|
410
410
|
Torch.rb | LibTorch
|
411
411
|
--- | ---
|
412
|
-
0.
|
412
|
+
0.11.0+ | 1.12.0+
|
413
|
+
0.10.0-0.10.2 | 1.11.0
|
413
414
|
0.9.0-0.9.2 | 1.10.0-1.10.2
|
414
415
|
0.8.0-0.8.3 | 1.9.0-1.9.1
|
415
416
|
0.6.0-0.7.0 | 1.8.0-1.8.1
|
@@ -421,13 +422,25 @@ Torch.rb | LibTorch
|
|
421
422
|
|
422
423
|
### Homebrew
|
423
424
|
|
424
|
-
|
425
|
+
You can also use Homebrew.
|
425
426
|
|
426
427
|
```sh
|
427
428
|
brew install libtorch
|
428
429
|
```
|
429
430
|
|
430
|
-
|
431
|
+
For Mac ARM, run:
|
432
|
+
|
433
|
+
```sh
|
434
|
+
bundle config build.torch-rb --with-torch-dir=/opt/homebrew
|
435
|
+
```
|
436
|
+
|
437
|
+
And for Linux, run:
|
438
|
+
|
439
|
+
```sh
|
440
|
+
bundle config build.torch-rb --with-torch-dir=/home/linuxbrew/.linuxbrew
|
441
|
+
```
|
442
|
+
|
443
|
+
Then install the gem.
|
431
444
|
|
432
445
|
## Performance
|
433
446
|
|
data/codegen/function.rb
CHANGED
@@ -39,7 +39,14 @@ def skip_functions(functions)
|
|
39
39
|
f.base_name == "index_put" ||
|
40
40
|
# not supported yet
|
41
41
|
f.func.include?("Dimname") ||
|
42
|
-
f.func.include?("ConstQuantizerPtr")
|
42
|
+
f.func.include?("ConstQuantizerPtr") ||
|
43
|
+
f.func.include?("SymInt") ||
|
44
|
+
# TODO fix LibTorch 1.12 changes
|
45
|
+
f.base_name == "histogramdd" ||
|
46
|
+
f.base_name == "nested_tensor" ||
|
47
|
+
f.base_name == "split_copy" ||
|
48
|
+
f.base_name == "split_with_sizes_copy" ||
|
49
|
+
f.base_name == "unbind_copy"
|
43
50
|
end
|
44
51
|
end
|
45
52
|
|
@@ -121,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
|
|
121
128
|
end
|
122
129
|
|
123
130
|
def write_file(name, contents)
|
124
|
-
path = File.expand_path("../ext/torch", __dir__)
|
125
|
-
|
131
|
+
path = File.join(File.expand_path("../ext/torch", __dir__), name)
|
132
|
+
# only write if changed to improve compile times in development
|
133
|
+
if !File.exist?(path) || File.read(path) != contents
|
134
|
+
File.write(path, contents)
|
135
|
+
end
|
126
136
|
end
|
127
137
|
|
128
138
|
def generate_attach_def(name, type, def_method)
|
@@ -135,14 +145,14 @@ def generate_attach_def(name, type, def_method)
|
|
135
145
|
name
|
136
146
|
end
|
137
147
|
|
138
|
-
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"
|
148
|
+
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
|
139
149
|
ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
|
140
150
|
ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
|
141
151
|
ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
|
142
152
|
ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
|
143
153
|
ruby_name = name if name.start_with?("__")
|
144
154
|
|
145
|
-
# cast for Ruby <
|
155
|
+
# cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
|
146
156
|
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
|
147
157
|
|
148
158
|
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
|
@@ -250,7 +260,7 @@ def generate_dispatch(function, def_method)
|
|
250
260
|
|
251
261
|
cpp_params = generate_dispatch_params(function, params)
|
252
262
|
if opt_index
|
253
|
-
cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "
|
263
|
+
cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
|
254
264
|
end
|
255
265
|
|
256
266
|
retval = generate_dispatch_retval(function)
|
@@ -410,7 +420,7 @@ def generate_function_params(function, params, remove_self)
|
|
410
420
|
else
|
411
421
|
"optionalTensor"
|
412
422
|
end
|
413
|
-
when "generator", "tensorlist"
|
423
|
+
when "generator", "tensorlist"
|
414
424
|
func
|
415
425
|
when "string"
|
416
426
|
"stringViewOptional"
|
@@ -471,7 +481,11 @@ def generate_dispatch_params(function, params)
|
|
471
481
|
when "float"
|
472
482
|
"double"
|
473
483
|
when /\Aint\[/
|
474
|
-
|
484
|
+
if param[:optional]
|
485
|
+
"at::OptionalIntArrayRef"
|
486
|
+
else
|
487
|
+
"IntArrayRef"
|
488
|
+
end
|
475
489
|
when "float[]"
|
476
490
|
"ArrayRef<double>"
|
477
491
|
when "str"
|
@@ -480,13 +494,19 @@ def generate_dispatch_params(function, params)
|
|
480
494
|
else
|
481
495
|
"std::string"
|
482
496
|
end
|
483
|
-
when "Scalar"
|
497
|
+
when "Scalar"
|
498
|
+
if param[:optional]
|
499
|
+
"const c10::optional<Scalar> &"
|
500
|
+
else
|
501
|
+
"const Scalar &"
|
502
|
+
end
|
503
|
+
when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
|
484
504
|
param[:type]
|
485
505
|
else
|
486
506
|
raise "Unknown type: #{param[:type]} (#{function.name})"
|
487
507
|
end
|
488
508
|
|
489
|
-
if param[:optional] && param[:type]
|
509
|
+
if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
|
490
510
|
type = "c10::optional<#{type}>"
|
491
511
|
end
|
492
512
|
|
@@ -529,7 +549,7 @@ def generate_dispatch_retval(function)
|
|
529
549
|
when ["float", "float"]
|
530
550
|
"std::tuple<double,double>"
|
531
551
|
else
|
532
|
-
raise "Unknown retvals: #{types}"
|
552
|
+
raise "Unknown retvals: #{types} (#{function.name})"
|
533
553
|
end
|
534
554
|
end
|
535
555
|
|