torch-rb 0.8.2 → 0.9.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '05811faa93ab089485bfa213362bfed0462227e6964e726bf1c3f9fc0cdba0c3'
4
- data.tar.gz: 8d25304063db51850e535e71c2b4fe643e038a18d046b8f27ee50269e5e9695d
3
+ metadata.gz: bd10ea76fc9cde319cad48a1b3bd298384c4f47812d9be5ee9016c44151458a0
4
+ data.tar.gz: e8391c2bc4ccae67ca0c2e402e431e99a683023285952edbd3f011a18793dfb6
5
5
  SHA512:
6
- metadata.gz: 8cea906b03b37ec848be7b1c7cfa6bfb0fde4ef7ed384818bcc85826d04611621835bbeecad0fd31c94b497bb527a21c107901a9d991692b6fffa6bb24c23c38
7
- data.tar.gz: 2ded65d614d274afe61e061898268172e6dec85dc28e3094481c2650713a129c0574b7ba17dbafccd8b8dc7ee064602fd261ce5034cec4a184fa32f2965eb476
6
+ metadata.gz: c290b76e1d18e88c46ec2bd544471cc5922de4741a0095786ba00415b2fa6ae52023b0c296365a7af6a81143515c36ccd3dbe69d11aea5138fbcd81b6e687bf6
7
+ data.tar.gz: bc35fd6f9c9ed38b12caf98158f61a6d9827e967c36f8fbcf46f64331cc96c03de3105f4bc0b5b70c911d25ce223ea0085c18537cff677e84bf3d564ce0058ab
data/CHANGELOG.md CHANGED
@@ -1,3 +1,25 @@
1
+ ## 0.9.2 (2022-02-03)
2
+
3
+ - Added support for setting `nil` gradient
4
+ - Added checks when setting gradient
5
+ - Fixed precision with `Torch.tensor` method
6
+ - Fixed memory issue when creating tensor for `ByteStorage`
7
+
8
+ ## 0.9.1 (2022-02-02)
9
+
10
+ - Moved `like` methods to C++
11
+ - Fixed memory issue
12
+
13
+ ## 0.9.0 (2021-10-23)
14
+
15
+ - Updated LibTorch to 1.10.0
16
+ - Added `real` and `imag` methods to tensors
17
+
18
+ ## 0.8.3 (2021-10-17)
19
+
20
+ - Fixed `dup` method for tensors and parameters
21
+ - Fixed issues with transformers
22
+
1
23
  ## 0.8.2 (2021-10-03)
2
24
 
3
25
  - Added transformers
data/README.md CHANGED
@@ -21,10 +21,10 @@ brew install libtorch
21
21
  Add this line to your application’s Gemfile:
22
22
 
23
23
  ```ruby
24
- gem 'torch-rb'
24
+ gem "torch-rb"
25
25
  ```
26
26
 
27
- It can take a few minutes to compile the extension.
27
+ It can take 5-10 minutes to compile the extension.
28
28
 
29
29
  ## Getting Started
30
30
 
@@ -79,7 +79,7 @@ b = Torch.zeros(2, 3)
79
79
 
80
80
  Each tensor has four properties
81
81
 
82
- - `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `float64`, or `:bool`
82
+ - `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
83
83
  - `layout` - `:strided` (dense) or `:sparse`
84
84
  - `device` - the compute device, like CPU or GPU
85
85
  - `requires_grad` - whether or not to record gradients
@@ -408,7 +408,8 @@ Here’s the list of compatible versions.
408
408
 
409
409
  Torch.rb | LibTorch
410
410
  --- | ---
411
- 0.8.0+ | 1.9.0+
411
+ 0.9.0+ | 1.10.0+
412
+ 0.8.0-0.8.3 | 1.9.0-1.9.1
412
413
  0.6.0-0.7.0 | 1.8.0-1.8.1
413
414
  0.5.0-0.5.3 | 1.7.0-1.7.1
414
415
  0.3.0-0.4.2 | 1.6.0
@@ -28,6 +28,9 @@ def skip_functions(functions)
28
28
  f.base_name.include?("_forward") ||
29
29
  f.base_name == "to" ||
30
30
  f.base_name == "record_stream" ||
31
+ f.base_name == "is_pinned" ||
32
+ f.base_name == "pin_memory" ||
33
+ f.base_name == "fused_moving_avg_obs_fake_quant" ||
31
34
  # in ext.cpp
32
35
  f.base_name == "index" ||
33
36
  f.base_name == "index_put_" ||
@@ -387,6 +390,8 @@ def generate_function_params(function, params, remove_self)
387
390
  end
388
391
  when "generator", "tensorlist", "intlist"
389
392
  func
393
+ when "string"
394
+ "stringViewOptional"
390
395
  else
391
396
  "#{func}Optional"
392
397
  end
@@ -424,9 +429,7 @@ def generate_dispatch_params(function, params)
424
429
  if function.out?
425
430
  "const Tensor &"
426
431
  else
427
- # TODO
428
- # "const c10::optional<at::Tensor> &"
429
- "const OptionalTensor &"
432
+ "const c10::optional<at::Tensor> &"
430
433
  end
431
434
  elsif param[:modifier]
432
435
  if param[:modifier].include?("!") && function.retvals.size > 1
@@ -450,7 +453,11 @@ def generate_dispatch_params(function, params)
450
453
  when "float[]"
451
454
  "ArrayRef<double>"
452
455
  when "str"
453
- "std::string"
456
+ if param[:optional]
457
+ "c10::string_view"
458
+ else
459
+ "std::string"
460
+ end
454
461
  when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
455
462
  param[:type]
456
463
  else