torch-rb 0.10.1 → 0.11.1

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: f4665eec43d85fbf02ce75f4b268dbf001bfad7e3ae1ecace0e9911b651e2cc2
-  data.tar.gz: d11ee1386ce7feeea68333de6c361d8737a7164cfd1626abce3a511deecb2963
+  metadata.gz: c4c3936a55fd47a8898d3e678aabc81c158f2f39a89c08af0b36695700bf2043
+  data.tar.gz: a68179c7d7bab7547ac3be1d2369abbd5c3c632954996ca2f7fdd089f815edf7
 SHA512:
-  metadata.gz: cf346bc03f36d4fc920151b0554c93c33a59d0fecd35c6f110dc862ad8b35c6b8641d306124505ddd012ce9d903363672e99772cc2bd7b981d962c9d00f08d3e
-  data.tar.gz: 265157846417fdc3c024e0f50d0b0a663d345ca423ad9233a7d0e722bf5134a8e15e1afce30248992155787484e80311e484c64c5787945b5f5a88625115479a
+  metadata.gz: 97551aa27f154cade530e58b4ae92286cd18e432acc99646495a197c89fb719cc991b0399aacaf383a2de04e340c04a9e177c44a0169fce91e19435463df4753
+  data.tar.gz: 8a5a5decf900a4d93aa56eb5e5b3848d5386be8bbc3657c10c201ef321cf0082c6521700378a5e96f795c09882b5b72498d8fed3931153b6e1f7557630f55547
data/CHANGELOG.md CHANGED
@@ -1,3 +1,17 @@
+## 0.11.1 (2022-07-06)
+
+- Fixed error with `stft` method
+
+## 0.11.0 (2022-07-06)
+
+- Updated LibTorch to 1.12.0
+- Dropped support for Ruby < 2.7
+
+## 0.10.2 (2022-06-14)
+
+- Improved numeric operations between scalars and tensors
+- Fixed `dtype` of `cumsum` method
+
 ## 0.10.1 (2022-04-12)
 
 - Fixed `dtype`, `device`, and `layout` for `new_*` and `like_*` methods
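As a rough illustration of the 0.10.2 entries, a minimal sketch in torch-rb (result comments are indicative; the exact dtype symbol may vary by platform):

```ruby
require "torch"

t = Torch.tensor([1, 2, 3])

# numeric operations between scalars and tensors (improved in 0.10.2)
(2 * t).to_a      # => [2, 4, 6]

# cumsum dtype (fixed in 0.10.2): integer input keeps an integer dtype
t.cumsum(0).dtype # => :int64 on most platforms
```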
data/README.md CHANGED
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.10.0+ | 1.11.0+
+0.11.0+ | 1.12.0+
+0.10.0-0.10.2 | 1.11.0
 0.9.0-0.9.2 | 1.10.0-1.10.2
 0.8.0-0.8.3 | 1.9.0-1.9.1
 0.6.0-0.7.0 | 1.8.0-1.8.1
@@ -421,13 +422,25 @@ Torch.rb | LibTorch
 
 ### Homebrew
 
-For Mac, you can use Homebrew.
+You can also use Homebrew.
 
 ```sh
 brew install libtorch
 ```
 
-Then install the gem (no need for `bundle config`).
+For Mac ARM, run:
+
+```sh
+bundle config build.torch-rb --with-torch-dir=/opt/homebrew
+```
+
+And for Linux, run:
+
+```sh
+bundle config build.torch-rb --with-torch-dir=/home/linuxbrew/.linuxbrew
+```
+
+Then install the gem.
 
 ## Performance
 
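For context, the install step the README refers to is the standard Bundler flow; a minimal sketch (the version constraint is illustrative):

```ruby
# Gemfile
source "https://rubygems.org"

gem "torch-rb", ">= 0.11.1"
```

Running `bundle install` afterwards picks up any `build.torch-rb` flags set via `bundle config` above.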
data/codegen/function.rb CHANGED
@@ -60,7 +60,7 @@ class Function
 
       optional = false
       if type.include?("?")
-        optional = true unless ["dtype", "device", "layout", "pin_memory"].include?(name)
+        optional = true
         type = type.delete("?")
       end
 
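With the special case gone, any `?`-suffixed type in the function schema is treated as optional, including the `dtype`/`device`/`layout`/`pin_memory` parameters that were previously forced to required. A condensed, standalone sketch of the parse step (not the actual codegen):

```ruby
def parse_type(type)
  optional = type.include?("?")
  [type.delete("?"), optional]
end

parse_type("ScalarType?") # => ["ScalarType", true]  (a "dtype" param was previously forced to false)
parse_type("int[1]")      # => ["int[1]", false]
```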
@@ -39,7 +39,14 @@ def skip_functions(functions)
       f.base_name == "index_put" ||
       # not supported yet
       f.func.include?("Dimname") ||
-      f.func.include?("ConstQuantizerPtr")
+      f.func.include?("ConstQuantizerPtr") ||
+      f.func.include?("SymInt") ||
+      # TODO fix LibTorch 1.12 changes
+      f.base_name == "histogramdd" ||
+      f.base_name == "nested_tensor" ||
+      f.base_name == "split_copy" ||
+      f.base_name == "split_with_sizes_copy" ||
+      f.base_name == "unbind_copy"
   end
 end
 
@@ -121,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
 end
 
 def write_file(name, contents)
-  path = File.expand_path("../ext/torch", __dir__)
-  File.write(File.join(path, name), contents)
+  path = File.join(File.expand_path("../ext/torch", __dir__), name)
+  # only write if changed to improve compile times in development
+  if !File.exist?(path) || File.read(path) != contents
+    File.write(path, contents)
+  end
 end
 
 def generate_attach_def(name, type, def_method)
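The new guard in `write_file` leaves a file's mtime untouched when the generated contents are identical, so the extension's make step treats the corresponding objects as up to date. A standalone sketch of the idiom (file name is hypothetical):

```ruby
def write_if_changed(path, contents)
  return if File.exist?(path) && File.read(path) == contents
  File.write(path, contents)
end

write_if_changed("demo.txt", "a")
before = File.mtime("demo.txt")
write_if_changed("demo.txt", "a")  # second call is a no-op
File.mtime("demo.txt") == before   # => true, so make sees nothing to rebuild
```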
@@ -135,14 +145,14 @@ def generate_attach_def(name, type, def_method)
       name
     end
 
-  ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
+  ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
   ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
   ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
   ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
   ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
   ruby_name = name if name.start_with?("__")
 
-  # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
+  # cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
   cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
 
   "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
@@ -250,7 +260,7 @@ def generate_dispatch(function, def_method)
 
   cpp_params = generate_dispatch_params(function, params)
   if opt_index
-    cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
+    cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
   end
 
   retval = generate_dispatch_retval(function)
@@ -410,7 +420,7 @@ def generate_function_params(function, params, remove_self)
       else
         "optionalTensor"
       end
-    when "generator", "tensorlist", "intlist"
+    when "generator", "tensorlist"
      func
     when "string"
       "stringViewOptional"
@@ -471,7 +481,11 @@ def generate_dispatch_params(function, params)
     when "float"
       "double"
     when /\Aint\[/
-      "IntArrayRef"
+      if param[:optional]
+        "at::OptionalIntArrayRef"
+      else
+        "IntArrayRef"
+      end
     when "float[]"
       "ArrayRef<double>"
     when "str"
@@ -480,13 +494,19 @@
       else
         "std::string"
       end
-    when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
+    when "Scalar"
+      if param[:optional]
+        "const c10::optional<Scalar> &"
+      else
+        "const Scalar &"
+      end
+    when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
       param[:type]
     else
       raise "Unknown type: #{param[:type]} (#{function.name})"
     end
 
-    if param[:optional] && param[:type] != "Tensor"
+    if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
       type = "c10::optional<#{type}>"
     end
 
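Between the `int[` and `Scalar` branches and the narrowed `c10::optional` wrapper, the two hunks above change how optional native types are rendered in dispatch signatures; a condensed, standalone sketch of the resulting mapping (not the actual codegen):

```ruby
def cpp_param_type(type, optional)
  case type
  when /\Aint\[/
    optional ? "at::OptionalIntArrayRef" : "IntArrayRef"
  when "Scalar"
    optional ? "const c10::optional<Scalar> &" : "const Scalar &"
  else
    # everything else still gets the generic optional wrapper
    optional ? "c10::optional<#{type}>" : type
  end
end

cpp_param_type("int[1]", true)      # => "at::OptionalIntArrayRef"
cpp_param_type("Scalar", true)      # => "const c10::optional<Scalar> &"
cpp_param_type("ScalarType", true)  # => "c10::optional<ScalarType>"
```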
@@ -529,7 +549,7 @@ def generate_dispatch_retval(function)
     when ["float", "float"]
       "std::tuple<double,double>"
     else
-      raise "Unknown retvals: #{types}"
+      raise "Unknown retvals: #{types} (#{function.name})"
     end
 end