torch-rb 0.10.1 → 0.11.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: f4665eec43d85fbf02ce75f4b268dbf001bfad7e3ae1ecace0e9911b651e2cc2
4
- data.tar.gz: d11ee1386ce7feeea68333de6c361d8737a7164cfd1626abce3a511deecb2963
3
+ metadata.gz: c4c3936a55fd47a8898d3e678aabc81c158f2f39a89c08af0b36695700bf2043
4
+ data.tar.gz: a68179c7d7bab7547ac3be1d2369abbd5c3c632954996ca2f7fdd089f815edf7
5
5
  SHA512:
6
- metadata.gz: cf346bc03f36d4fc920151b0554c93c33a59d0fecd35c6f110dc862ad8b35c6b8641d306124505ddd012ce9d903363672e99772cc2bd7b981d962c9d00f08d3e
7
- data.tar.gz: 265157846417fdc3c024e0f50d0b0a663d345ca423ad9233a7d0e722bf5134a8e15e1afce30248992155787484e80311e484c64c5787945b5f5a88625115479a
6
+ metadata.gz: 97551aa27f154cade530e58b4ae92286cd18e432acc99646495a197c89fb719cc991b0399aacaf383a2de04e340c04a9e177c44a0169fce91e19435463df4753
7
+ data.tar.gz: 8a5a5decf900a4d93aa56eb5e5b3848d5386be8bbc3657c10c201ef321cf0082c6521700378a5e96f795c09882b5b72498d8fed3931153b6e1f7557630f55547
data/CHANGELOG.md CHANGED
@@ -1,3 +1,17 @@
1
+ ## 0.11.1 (2022-07-06)
2
+
3
+ - Fixed error with `stft` method
4
+
5
+ ## 0.11.0 (2022-07-06)
6
+
7
+ - Updated LibTorch to 1.12.0
8
+ - Dropped support for Ruby < 2.7
9
+
10
+ ## 0.10.2 (2022-06-14)
11
+
12
+ - Improved numeric operations between scalars and tensors
13
+ - Fixed `dtype` of `cumsum` method
14
+
1
15
  ## 0.10.1 (2022-04-12)
2
16
 
3
17
  - Fixed `dtype`, `device`, and `layout` for `new_*` and `like_*` methods
data/README.md CHANGED
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
409
409
 
410
410
  Torch.rb | LibTorch
411
411
  --- | ---
412
- 0.10.0+ | 1.11.0+
412
+ 0.11.0+ | 1.12.0+
413
+ 0.10.0-0.10.2 | 1.11.0
413
414
  0.9.0-0.9.2 | 1.10.0-1.10.2
414
415
  0.8.0-0.8.3 | 1.9.0-1.9.1
415
416
  0.6.0-0.7.0 | 1.8.0-1.8.1
@@ -421,13 +422,25 @@ Torch.rb | LibTorch
421
422
 
422
423
  ### Homebrew
423
424
 
424
- For Mac, you can use Homebrew.
425
+ You can also use Homebrew.
425
426
 
426
427
  ```sh
427
428
  brew install libtorch
428
429
  ```
429
430
 
430
- Then install the gem (no need for `bundle config`).
431
+ For Mac ARM, run:
432
+
433
+ ```sh
434
+ bundle config build.torch-rb --with-torch-dir=/opt/homebrew
435
+ ```
436
+
437
+ And for Linux, run:
438
+
439
+ ```sh
440
+ bundle config build.torch-rb --with-torch-dir=/home/linuxbrew/.linuxbrew
441
+ ```
442
+
443
+ Then install the gem.
431
444
 
432
445
  ## Performance
433
446
 
data/codegen/function.rb CHANGED
@@ -60,7 +60,7 @@ class Function
60
60
 
61
61
  optional = false
62
62
  if type.include?("?")
63
- optional = true unless ["dtype", "device", "layout", "pin_memory"].include?(name)
63
+ optional = true
64
64
  type = type.delete("?")
65
65
  end
66
66
 
@@ -39,7 +39,14 @@ def skip_functions(functions)
39
39
  f.base_name == "index_put" ||
40
40
  # not supported yet
41
41
  f.func.include?("Dimname") ||
42
- f.func.include?("ConstQuantizerPtr")
42
+ f.func.include?("ConstQuantizerPtr") ||
43
+ f.func.include?("SymInt") ||
44
+ # TODO fix LibTorch 1.12 changes
45
+ f.base_name == "histogramdd" ||
46
+ f.base_name == "nested_tensor" ||
47
+ f.base_name == "split_copy" ||
48
+ f.base_name == "split_with_sizes_copy" ||
49
+ f.base_name == "unbind_copy"
43
50
  end
44
51
  end
45
52
 
@@ -121,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
121
128
  end
122
129
 
123
130
  def write_file(name, contents)
124
- path = File.expand_path("../ext/torch", __dir__)
125
- File.write(File.join(path, name), contents)
131
+ path = File.join(File.expand_path("../ext/torch", __dir__), name)
132
+ # only write if changed to improve compile times in development
133
+ if !File.exist?(path) || File.read(path) != contents
134
+ File.write(path, contents)
135
+ end
126
136
  end
127
137
 
128
138
  def generate_attach_def(name, type, def_method)
@@ -135,14 +145,14 @@ def generate_attach_def(name, type, def_method)
135
145
  name
136
146
  end
137
147
 
138
- ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
148
+ ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
139
149
  ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
140
150
  ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
141
151
  ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
142
152
  ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
143
153
  ruby_name = name if name.start_with?("__")
144
154
 
145
- # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
155
+ # cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
146
156
  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
147
157
 
148
158
  "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
@@ -250,7 +260,7 @@ def generate_dispatch(function, def_method)
250
260
 
251
261
  cpp_params = generate_dispatch_params(function, params)
252
262
  if opt_index
253
- cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
263
+ cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
254
264
  end
255
265
 
256
266
  retval = generate_dispatch_retval(function)
@@ -410,7 +420,7 @@ def generate_function_params(function, params, remove_self)
410
420
  else
411
421
  "optionalTensor"
412
422
  end
413
- when "generator", "tensorlist", "intlist"
423
+ when "generator", "tensorlist"
414
424
  func
415
425
  when "string"
416
426
  "stringViewOptional"
@@ -471,7 +481,11 @@ def generate_dispatch_params(function, params)
471
481
  when "float"
472
482
  "double"
473
483
  when /\Aint\[/
474
- "IntArrayRef"
484
+ if param[:optional]
485
+ "at::OptionalIntArrayRef"
486
+ else
487
+ "IntArrayRef"
488
+ end
475
489
  when "float[]"
476
490
  "ArrayRef<double>"
477
491
  when "str"
@@ -480,13 +494,19 @@ def generate_dispatch_params(function, params)
480
494
  else
481
495
  "std::string"
482
496
  end
483
- when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
497
+ when "Scalar"
498
+ if param[:optional]
499
+ "const c10::optional<Scalar> &"
500
+ else
501
+ "const Scalar &"
502
+ end
503
+ when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
484
504
  param[:type]
485
505
  else
486
506
  raise "Unknown type: #{param[:type]} (#{function.name})"
487
507
  end
488
508
 
489
- if param[:optional] && param[:type] != "Tensor"
509
+ if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
490
510
  type = "c10::optional<#{type}>"
491
511
  end
492
512
 
@@ -529,7 +549,7 @@ def generate_dispatch_retval(function)
529
549
  when ["float", "float"]
530
550
  "std::tuple<double,double>"
531
551
  else
532
- raise "Unknown retvals: #{types}"
552
+ raise "Unknown retvals: #{types} (#{function.name})"
533
553
  end
534
554
  end
535
555