torch-rb 0.10.1 → 0.11.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/README.md +16 -3
- data/codegen/function.rb +1 -1
- data/codegen/generate_functions.rb +31 -11
- data/codegen/native_functions.yaml +1362 -199
- data/ext/torch/extconf.rb +1 -13
- data/ext/torch/ruby_arg_parser.cpp +64 -2
- data/ext/torch/ruby_arg_parser.h +18 -3
- data/ext/torch/utils.h +1 -1
- data/lib/torch/tensor.rb +8 -5
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -12
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c4c3936a55fd47a8898d3e678aabc81c158f2f39a89c08af0b36695700bf2043
|
4
|
+
data.tar.gz: a68179c7d7bab7547ac3be1d2369abbd5c3c632954996ca2f7fdd089f815edf7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 97551aa27f154cade530e58b4ae92286cd18e432acc99646495a197c89fb719cc991b0399aacaf383a2de04e340c04a9e177c44a0169fce91e19435463df4753
|
7
|
+
data.tar.gz: 8a5a5decf900a4d93aa56eb5e5b3848d5386be8bbc3657c10c201ef321cf0082c6521700378a5e96f795c09882b5b72498d8fed3931153b6e1f7557630f55547
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,17 @@
|
|
1
|
+
## 0.11.1 (2022-07-06)
|
2
|
+
|
3
|
+
- Fixed error with `stft` method
|
4
|
+
|
5
|
+
## 0.11.0 (2022-07-06)
|
6
|
+
|
7
|
+
- Updated LibTorch to 1.12.0
|
8
|
+
- Dropped support for Ruby < 2.7
|
9
|
+
|
10
|
+
## 0.10.2 (2022-06-14)
|
11
|
+
|
12
|
+
- Improved numeric operations between scalars and tensors
|
13
|
+
- Fixed `dtype` of `cumsum` method
|
14
|
+
|
1
15
|
## 0.10.1 (2022-04-12)
|
2
16
|
|
3
17
|
- Fixed `dtype`, `device`, and `layout` for `new_*` and `like_*` methods
|
data/README.md
CHANGED
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
|
|
409
409
|
|
410
410
|
Torch.rb | LibTorch
|
411
411
|
--- | ---
|
412
|
-
0.10.0+ | 1.11.0+
|
412
|
+
0.11.0+ | 1.12.0+
|
413
|
+
0.10.0-0.10.2 | 1.11.0
|
413
414
|
0.9.0-0.9.2 | 1.10.0-1.10.2
|
414
415
|
0.8.0-0.8.3 | 1.9.0-1.9.1
|
415
416
|
0.6.0-0.7.0 | 1.8.0-1.8.1
|
@@ -421,13 +422,25 @@ Torch.rb | LibTorch
|
|
421
422
|
|
422
423
|
### Homebrew
|
423
424
|
|
424
|
-
|
425
|
+
You can also use Homebrew.
|
425
426
|
|
426
427
|
```sh
|
427
428
|
brew install libtorch
|
428
429
|
```
|
429
430
|
|
430
|
-
|
431
|
+
For Mac ARM, run:
|
432
|
+
|
433
|
+
```sh
|
434
|
+
bundle config build.torch-rb --with-torch-dir=/opt/homebrew
|
435
|
+
```
|
436
|
+
|
437
|
+
And for Linux, run:
|
438
|
+
|
439
|
+
```sh
|
440
|
+
bundle config build.torch-rb --with-torch-dir=/home/linuxbrew/.linuxbrew
|
441
|
+
```
|
442
|
+
|
443
|
+
Then install the gem.
|
431
444
|
|
432
445
|
## Performance
|
433
446
|
|
data/codegen/generate_functions.rb
CHANGED
@@ -39,7 +39,14 @@ def skip_functions(functions)
|
|
39
39
|
f.base_name == "index_put" ||
|
40
40
|
# not supported yet
|
41
41
|
f.func.include?("Dimname") ||
|
42
|
-
f.func.include?("ConstQuantizerPtr")
|
42
|
+
f.func.include?("ConstQuantizerPtr") ||
|
43
|
+
f.func.include?("SymInt") ||
|
44
|
+
# TODO fix LibTorch 1.12 changes
|
45
|
+
f.base_name == "histogramdd" ||
|
46
|
+
f.base_name == "nested_tensor" ||
|
47
|
+
f.base_name == "split_copy" ||
|
48
|
+
f.base_name == "split_with_sizes_copy" ||
|
49
|
+
f.base_name == "unbind_copy"
|
43
50
|
end
|
44
51
|
end
|
45
52
|
|
@@ -121,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
|
|
121
128
|
end
|
122
129
|
|
123
130
|
def write_file(name, contents)
|
124
|
-
path = File.expand_path("../ext/torch", __dir__)
|
125
|
-
|
131
|
+
path = File.join(File.expand_path("../ext/torch", __dir__), name)
|
132
|
+
# only write if changed to improve compile times in development
|
133
|
+
if !File.exist?(path) || File.read(path) != contents
|
134
|
+
File.write(path, contents)
|
135
|
+
end
|
126
136
|
end
|
127
137
|
|
128
138
|
def generate_attach_def(name, type, def_method)
|
@@ -135,14 +145,14 @@ def generate_attach_def(name, type, def_method)
|
|
135
145
|
name
|
136
146
|
end
|
137
147
|
|
138
|
-
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"
|
148
|
+
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
|
139
149
|
ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
|
140
150
|
ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
|
141
151
|
ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
|
142
152
|
ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
|
143
153
|
ruby_name = name if name.start_with?("__")
|
144
154
|
|
145
|
-
# cast for Ruby <
|
155
|
+
# cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
|
146
156
|
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
|
147
157
|
|
148
158
|
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
|
@@ -250,7 +260,7 @@ def generate_dispatch(function, def_method)
|
|
250
260
|
|
251
261
|
cpp_params = generate_dispatch_params(function, params)
|
252
262
|
if opt_index
|
253
|
-
cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "
|
263
|
+
cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
|
254
264
|
end
|
255
265
|
|
256
266
|
retval = generate_dispatch_retval(function)
|
@@ -410,7 +420,7 @@ def generate_function_params(function, params, remove_self)
|
|
410
420
|
else
|
411
421
|
"optionalTensor"
|
412
422
|
end
|
413
|
-
when "generator", "tensorlist"
|
423
|
+
when "generator", "tensorlist"
|
414
424
|
func
|
415
425
|
when "string"
|
416
426
|
"stringViewOptional"
|
@@ -471,7 +481,11 @@ def generate_dispatch_params(function, params)
|
|
471
481
|
when "float"
|
472
482
|
"double"
|
473
483
|
when /\Aint\[/
|
474
|
-
|
484
|
+
if param[:optional]
|
485
|
+
"at::OptionalIntArrayRef"
|
486
|
+
else
|
487
|
+
"IntArrayRef"
|
488
|
+
end
|
475
489
|
when "float[]"
|
476
490
|
"ArrayRef<double>"
|
477
491
|
when "str"
|
@@ -480,13 +494,19 @@ def generate_dispatch_params(function, params)
|
|
480
494
|
else
|
481
495
|
"std::string"
|
482
496
|
end
|
483
|
-
when "Scalar"
|
497
|
+
when "Scalar"
|
498
|
+
if param[:optional]
|
499
|
+
"const c10::optional<Scalar> &"
|
500
|
+
else
|
501
|
+
"const Scalar &"
|
502
|
+
end
|
503
|
+
when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
|
484
504
|
param[:type]
|
485
505
|
else
|
486
506
|
raise "Unknown type: #{param[:type]} (#{function.name})"
|
487
507
|
end
|
488
508
|
|
489
|
-
if param[:optional] && param[:type]
|
509
|
+
if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
|
490
510
|
type = "c10::optional<#{type}>"
|
491
511
|
end
|
492
512
|
|
@@ -529,7 +549,7 @@ def generate_dispatch_retval(function)
|
|
529
549
|
when ["float", "float"]
|
530
550
|
"std::tuple<double,double>"
|
531
551
|
else
|
532
|
-
raise "Unknown retvals: #{types}"
|
552
|
+
raise "Unknown retvals: #{types} (#{function.name})"
|
533
553
|
end
|
534
554
|
end
|
535
555
|
|