torch-rb 0.8.2 → 0.9.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/README.md +5 -4
- data/codegen/generate_functions.rb +11 -4
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/nn.cpp +4 -1
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +19 -5
- data/ext/torch/templates.h +30 -33
- data/ext/torch/tensor.cpp +33 -33
- data/ext/torch/torch.cpp +38 -28
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/nn/module.rb +28 -0
- data/lib/torch/nn/parameter.rb +9 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +1 -1
- data/lib/torch/nn/utils.rb +1 -5
- data/lib/torch/tensor.rb +22 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +14 -48
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: bd10ea76fc9cde319cad48a1b3bd298384c4f47812d9be5ee9016c44151458a0
|
4
|
+
data.tar.gz: e8391c2bc4ccae67ca0c2e402e431e99a683023285952edbd3f011a18793dfb6
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c290b76e1d18e88c46ec2bd544471cc5922de4741a0095786ba00415b2fa6ae52023b0c296365a7af6a81143515c36ccd3dbe69d11aea5138fbcd81b6e687bf6
|
7
|
+
data.tar.gz: bc35fd6f9c9ed38b12caf98158f61a6d9827e967c36f8fbcf46f64331cc96c03de3105f4bc0b5b70c911d25ce223ea0085c18537cff677e84bf3d564ce0058ab
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,25 @@
|
|
1
|
+
## 0.9.2 (2022-02-03)
|
2
|
+
|
3
|
+
- Added support for setting `nil` gradient
|
4
|
+
- Added checks when setting gradient
|
5
|
+
- Fixed precision with `Torch.tensor` method
|
6
|
+
- Fixed memory issue when creating tensor for `ByteStorage`
|
7
|
+
|
8
|
+
## 0.9.1 (2022-02-02)
|
9
|
+
|
10
|
+
- Moved `like` methods to C++
|
11
|
+
- Fixed memory issue
|
12
|
+
|
13
|
+
## 0.9.0 (2021-10-23)
|
14
|
+
|
15
|
+
- Updated LibTorch to 1.10.0
|
16
|
+
- Added `real` and `imag` methods to tensors
|
17
|
+
|
18
|
+
## 0.8.3 (2021-10-17)
|
19
|
+
|
20
|
+
- Fixed `dup` method for tensors and parameters
|
21
|
+
- Fixed issues with transformers
|
22
|
+
|
1
23
|
## 0.8.2 (2021-10-03)
|
2
24
|
|
3
25
|
- Added transformers
|
data/README.md
CHANGED
@@ -21,10 +21,10 @@ brew install libtorch
|
|
21
21
|
Add this line to your application’s Gemfile:
|
22
22
|
|
23
23
|
```ruby
|
24
|
-
gem
|
24
|
+
gem "torch-rb"
|
25
25
|
```
|
26
26
|
|
27
|
-
It can take
|
27
|
+
It can take 5-10 minutes to compile the extension.
|
28
28
|
|
29
29
|
## Getting Started
|
30
30
|
|
@@ -79,7 +79,7 @@ b = Torch.zeros(2, 3)
|
|
79
79
|
|
80
80
|
Each tensor has four properties
|
81
81
|
|
82
|
-
- `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`,
|
82
|
+
- `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
|
83
83
|
- `layout` - `:strided` (dense) or `:sparse`
|
84
84
|
- `device` - the compute device, like CPU or GPU
|
85
85
|
- `requires_grad` - whether or not to record gradients
|
@@ -408,7 +408,8 @@ Here’s the list of compatible versions.
|
|
408
408
|
|
409
409
|
Torch.rb | LibTorch
|
410
410
|
--- | ---
|
411
|
-
0.
|
411
|
+
0.9.0+ | 1.10.0+
|
412
|
+
0.8.0-0.8.3 | 1.9.0-1.9.1
|
412
413
|
0.6.0-0.7.0 | 1.8.0-1.8.1
|
413
414
|
0.5.0-0.5.3 | 1.7.0-1.7.1
|
414
415
|
0.3.0-0.4.2 | 1.6.0
|
data/codegen/generate_functions.rb
CHANGED
@@ -28,6 +28,9 @@ def skip_functions(functions)
|
|
28
28
|
f.base_name.include?("_forward") ||
|
29
29
|
f.base_name == "to" ||
|
30
30
|
f.base_name == "record_stream" ||
|
31
|
+
f.base_name == "is_pinned" ||
|
32
|
+
f.base_name == "pin_memory" ||
|
33
|
+
f.base_name == "fused_moving_avg_obs_fake_quant" ||
|
31
34
|
# in ext.cpp
|
32
35
|
f.base_name == "index" ||
|
33
36
|
f.base_name == "index_put_" ||
|
@@ -387,6 +390,8 @@ def generate_function_params(function, params, remove_self)
|
|
387
390
|
end
|
388
391
|
when "generator", "tensorlist", "intlist"
|
389
392
|
func
|
393
|
+
when "string"
|
394
|
+
"stringViewOptional"
|
390
395
|
else
|
391
396
|
"#{func}Optional"
|
392
397
|
end
|
@@ -424,9 +429,7 @@ def generate_dispatch_params(function, params)
|
|
424
429
|
if function.out?
|
425
430
|
"const Tensor &"
|
426
431
|
else
|
427
|
-
|
428
|
-
# "const c10::optional<at::Tensor> &"
|
429
|
-
"const OptionalTensor &"
|
432
|
+
"const c10::optional<at::Tensor> &"
|
430
433
|
end
|
431
434
|
elsif param[:modifier]
|
432
435
|
if param[:modifier].include?("!") && function.retvals.size > 1
|
@@ -450,7 +453,11 @@ def generate_dispatch_params(function, params)
|
|
450
453
|
when "float[]"
|
451
454
|
"ArrayRef<double>"
|
452
455
|
when "str"
|
453
|
-
|
456
|
+
if param[:optional]
|
457
|
+
"c10::string_view"
|
458
|
+
else
|
459
|
+
"std::string"
|
460
|
+
end
|
454
461
|
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
|
455
462
|
param[:type]
|
456
463
|
else
|