torch-rb 0.8.2 → 0.9.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/README.md +5 -4
- data/codegen/generate_functions.rb +11 -4
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/nn.cpp +4 -1
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +19 -5
- data/ext/torch/templates.h +30 -33
- data/ext/torch/tensor.cpp +33 -33
- data/ext/torch/torch.cpp +38 -28
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/nn/module.rb +28 -0
- data/lib/torch/nn/parameter.rb +9 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +1 -1
- data/lib/torch/nn/utils.rb +1 -5
- data/lib/torch/tensor.rb +22 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +14 -48
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: bd10ea76fc9cde319cad48a1b3bd298384c4f47812d9be5ee9016c44151458a0
|
4
|
+
data.tar.gz: e8391c2bc4ccae67ca0c2e402e431e99a683023285952edbd3f011a18793dfb6
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c290b76e1d18e88c46ec2bd544471cc5922de4741a0095786ba00415b2fa6ae52023b0c296365a7af6a81143515c36ccd3dbe69d11aea5138fbcd81b6e687bf6
|
7
|
+
data.tar.gz: bc35fd6f9c9ed38b12caf98158f61a6d9827e967c36f8fbcf46f64331cc96c03de3105f4bc0b5b70c911d25ce223ea0085c18537cff677e84bf3d564ce0058ab
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,25 @@
|
|
1
|
+
## 0.9.2 (2022-02-03)
|
2
|
+
|
3
|
+
- Added support for setting `nil` gradient
|
4
|
+
- Added checks when setting gradient
|
5
|
+
- Fixed precision with `Torch.tensor` method
|
6
|
+
- Fixed memory issue when creating tensor for `ByteStorage`
|
7
|
+
|
8
|
+
## 0.9.1 (2022-02-02)
|
9
|
+
|
10
|
+
- Moved `like` methods to C++
|
11
|
+
- Fixed memory issue
|
12
|
+
|
13
|
+
## 0.9.0 (2021-10-23)
|
14
|
+
|
15
|
+
- Updated LibTorch to 1.10.0
|
16
|
+
- Added `real` and `imag` methods to tensors
|
17
|
+
|
18
|
+
## 0.8.3 (2021-10-17)
|
19
|
+
|
20
|
+
- Fixed `dup` method for tensors and parameters
|
21
|
+
- Fixed issues with transformers
|
22
|
+
|
1
23
|
## 0.8.2 (2021-10-03)
|
2
24
|
|
3
25
|
- Added transformers
|
data/README.md
CHANGED
@@ -21,10 +21,10 @@ brew install libtorch
|
|
21
21
|
Add this line to your application’s Gemfile:
|
22
22
|
|
23
23
|
```ruby
|
24
|
-
gem
|
24
|
+
gem "torch-rb"
|
25
25
|
```
|
26
26
|
|
27
|
-
It can take
|
27
|
+
It can take 5-10 minutes to compile the extension.
|
28
28
|
|
29
29
|
## Getting Started
|
30
30
|
|
@@ -79,7 +79,7 @@ b = Torch.zeros(2, 3)
|
|
79
79
|
|
80
80
|
Each tensor has four properties
|
81
81
|
|
82
|
-
- `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`,
|
82
|
+
- `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
|
83
83
|
- `layout` - `:strided` (dense) or `:sparse`
|
84
84
|
- `device` - the compute device, like CPU or GPU
|
85
85
|
- `requires_grad` - whether or not to record gradients
|
@@ -408,7 +408,8 @@ Here’s the list of compatible versions.
|
|
408
408
|
|
409
409
|
Torch.rb | LibTorch
|
410
410
|
--- | ---
|
411
|
-
0.
|
411
|
+
0.9.0+ | 1.10.0+
|
412
|
+
0.8.0-0.8.3 | 1.9.0-1.9.1
|
412
413
|
0.6.0-0.7.0 | 1.8.0-1.8.1
|
413
414
|
0.5.0-0.5.3 | 1.7.0-1.7.1
|
414
415
|
0.3.0-0.4.2 | 1.6.0
|
data/codegen/generate_functions.rb
CHANGED
@@ -28,6 +28,9 @@ def skip_functions(functions)
|
|
28
28
|
f.base_name.include?("_forward") ||
|
29
29
|
f.base_name == "to" ||
|
30
30
|
f.base_name == "record_stream" ||
|
31
|
+
f.base_name == "is_pinned" ||
|
32
|
+
f.base_name == "pin_memory" ||
|
33
|
+
f.base_name == "fused_moving_avg_obs_fake_quant" ||
|
31
34
|
# in ext.cpp
|
32
35
|
f.base_name == "index" ||
|
33
36
|
f.base_name == "index_put_" ||
|
@@ -387,6 +390,8 @@ def generate_function_params(function, params, remove_self)
|
|
387
390
|
end
|
388
391
|
when "generator", "tensorlist", "intlist"
|
389
392
|
func
|
393
|
+
when "string"
|
394
|
+
"stringViewOptional"
|
390
395
|
else
|
391
396
|
"#{func}Optional"
|
392
397
|
end
|
@@ -424,9 +429,7 @@ def generate_dispatch_params(function, params)
|
|
424
429
|
if function.out?
|
425
430
|
"const Tensor &"
|
426
431
|
else
|
427
|
-
|
428
|
-
# "const c10::optional<at::Tensor> &"
|
429
|
-
"const OptionalTensor &"
|
432
|
+
"const c10::optional<at::Tensor> &"
|
430
433
|
end
|
431
434
|
elsif param[:modifier]
|
432
435
|
if param[:modifier].include?("!") && function.retvals.size > 1
|
@@ -450,7 +453,11 @@ def generate_dispatch_params(function, params)
|
|
450
453
|
when "float[]"
|
451
454
|
"ArrayRef<double>"
|
452
455
|
when "str"
|
453
|
-
|
456
|
+
if param[:optional]
|
457
|
+
"c10::string_view"
|
458
|
+
else
|
459
|
+
"std::string"
|
460
|
+
end
|
454
461
|
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
|
455
462
|
param[:type]
|
456
463
|
else
|