torch-rb 0.8.2 → 0.9.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '05811faa93ab089485bfa213362bfed0462227e6964e726bf1c3f9fc0cdba0c3'
4
- data.tar.gz: 8d25304063db51850e535e71c2b4fe643e038a18d046b8f27ee50269e5e9695d
3
+ metadata.gz: bd10ea76fc9cde319cad48a1b3bd298384c4f47812d9be5ee9016c44151458a0
4
+ data.tar.gz: e8391c2bc4ccae67ca0c2e402e431e99a683023285952edbd3f011a18793dfb6
5
5
  SHA512:
6
- metadata.gz: 8cea906b03b37ec848be7b1c7cfa6bfb0fde4ef7ed384818bcc85826d04611621835bbeecad0fd31c94b497bb527a21c107901a9d991692b6fffa6bb24c23c38
7
- data.tar.gz: 2ded65d614d274afe61e061898268172e6dec85dc28e3094481c2650713a129c0574b7ba17dbafccd8b8dc7ee064602fd261ce5034cec4a184fa32f2965eb476
6
+ metadata.gz: c290b76e1d18e88c46ec2bd544471cc5922de4741a0095786ba00415b2fa6ae52023b0c296365a7af6a81143515c36ccd3dbe69d11aea5138fbcd81b6e687bf6
7
+ data.tar.gz: bc35fd6f9c9ed38b12caf98158f61a6d9827e967c36f8fbcf46f64331cc96c03de3105f4bc0b5b70c911d25ce223ea0085c18537cff677e84bf3d564ce0058ab
data/CHANGELOG.md CHANGED
@@ -1,3 +1,25 @@
1
+ ## 0.9.2 (2022-02-03)
2
+
3
+ - Added support for setting `nil` gradient
4
+ - Added checks when setting gradient
5
+ - Fixed precision with `Torch.tensor` method
6
+ - Fixed memory issue when creating tensor for `ByteStorage`
7
+
8
+ ## 0.9.1 (2022-02-02)
9
+
10
+ - Moved `like` methods to C++
11
+ - Fixed memory issue
12
+
13
+ ## 0.9.0 (2021-10-23)
14
+
15
+ - Updated LibTorch to 1.10.0
16
+ - Added `real` and `imag` methods to tensors
17
+
18
+ ## 0.8.3 (2021-10-17)
19
+
20
+ - Fixed `dup` method for tensors and parameters
21
+ - Fixed issues with transformers
22
+
1
23
  ## 0.8.2 (2021-10-03)
2
24
 
3
25
  - Added transformers
data/README.md CHANGED
@@ -21,10 +21,10 @@ brew install libtorch
21
21
  Add this line to your application’s Gemfile:
22
22
 
23
23
  ```ruby
24
- gem 'torch-rb'
24
+ gem "torch-rb"
25
25
  ```
26
26
 
27
- It can take a few minutes to compile the extension.
27
+ It can take 5-10 minutes to compile the extension.
28
28
 
29
29
  ## Getting Started
30
30
 
@@ -79,7 +79,7 @@ b = Torch.zeros(2, 3)
79
79
 
80
80
  Each tensor has four properties:
81
81
 
82
- - `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `float64`, or `:bool`
82
+ - `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
83
83
  - `layout` - `:strided` (dense) or `:sparse`
84
84
  - `device` - the compute device, like CPU or GPU
85
85
  - `requires_grad` - whether or not to record gradients
@@ -408,7 +408,8 @@ Here’s the list of compatible versions.
408
408
 
409
409
  Torch.rb | LibTorch
410
410
  --- | ---
411
- 0.8.0+ | 1.9.0+
411
+ 0.9.0+ | 1.10.0+
412
+ 0.8.0-0.8.3 | 1.9.0-1.9.1
412
413
  0.6.0-0.7.0 | 1.8.0-1.8.1
413
414
  0.5.0-0.5.3 | 1.7.0-1.7.1
414
415
  0.3.0-0.4.2 | 1.6.0
@@ -28,6 +28,9 @@ def skip_functions(functions)
28
28
  f.base_name.include?("_forward") ||
29
29
  f.base_name == "to" ||
30
30
  f.base_name == "record_stream" ||
31
+ f.base_name == "is_pinned" ||
32
+ f.base_name == "pin_memory" ||
33
+ f.base_name == "fused_moving_avg_obs_fake_quant" ||
31
34
  # in ext.cpp
32
35
  f.base_name == "index" ||
33
36
  f.base_name == "index_put_" ||
@@ -387,6 +390,8 @@ def generate_function_params(function, params, remove_self)
387
390
  end
388
391
  when "generator", "tensorlist", "intlist"
389
392
  func
393
+ when "string"
394
+ "stringViewOptional"
390
395
  else
391
396
  "#{func}Optional"
392
397
  end
@@ -424,9 +429,7 @@ def generate_dispatch_params(function, params)
424
429
  if function.out?
425
430
  "const Tensor &"
426
431
  else
427
- # TODO
428
- # "const c10::optional<at::Tensor> &"
429
- "const OptionalTensor &"
432
+ "const c10::optional<at::Tensor> &"
430
433
  end
431
434
  elsif param[:modifier]
432
435
  if param[:modifier].include?("!") && function.retvals.size > 1
@@ -450,7 +453,11 @@ def generate_dispatch_params(function, params)
450
453
  when "float[]"
451
454
  "ArrayRef<double>"
452
455
  when "str"
453
- "std::string"
456
+ if param[:optional]
457
+ "c10::string_view"
458
+ else
459
+ "std::string"
460
+ end
454
461
  when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
455
462
  param[:type]
456
463
  else