torch-rb 0.10.2 → 0.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 935c88d7ce2a1f9d0f12014dd4f79beee6d1c153bc2439f614f00f50ec60c286
- data.tar.gz: d77145dd7944d0ec024e5bef00d6ba1d47e52efb74c1d731101b8b8e67c9a3c9
+ metadata.gz: d2f88c938144476a772fd7c606751e0a6e67a338cb1c37ede9c2011db7fc4579
+ data.tar.gz: f4408457e3bf8c7bf9b42863459248aa592c70e8e3924cdbe72863a979e65106
  SHA512:
- metadata.gz: d9b4df4b90ff11b67d4c203a2e5f59037a6079dafb16ec02bae89a34f364ca86a5530dcb74adce48474c857ce2f2ce49323d96830fb30c843440b842b6e7fb50
- data.tar.gz: a68bf5d89984861c847390a4aaea337bc86a43c01ee93de47f575b52a87b7281e3318fcf10d70f0db5013f9685a126a46d80e42f88bbf51080bd15fac1933f71
+ metadata.gz: 90ebc506942809b02331f7accce6e93e714a57b5d2f58b06ad0d3c60204b947eba61fe0bf43431a19e27890f5c764cb348ccb1785a2ffccfba7969f4726b6f1f
+ data.tar.gz: 410af7b4934f79aaae94f6c1bde1d17ee8b043f02354b4ea9f139271575f43abb90afeba162badfa01a3ec406612ae6cdcb4a2adfb1067b1634982a590d10cc9

data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
+ ## 0.11.0 (2022-07-06)
+
+ - Updated LibTorch to 1.12.0
+ - Dropped support for Ruby < 2.7
+
  ## 0.10.2 (2022-06-14)

  - Improved numeric operations between scalars and tensors
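The "Dropped support for Ruby < 2.7" entry is normally enforced through the gemspec. A minimal sketch of how that declaration usually looks, assuming the conventional `required_ruby_version` field (the actual torch-rb.gemspec is not shown here):

```ruby
# Hypothetical gemspec excerpt, for illustration only; not taken from this diff.
Gem::Specification.new do |spec|
  spec.name    = "torch-rb"
  spec.version = "0.11.0"
  # Dropping Ruby < 2.7 support is typically expressed as a version constraint:
  spec.required_ruby_version = ">= 2.7"
end
```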
data/README.md CHANGED
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.

  Torch.rb | LibTorch
  --- | ---
- 0.10.0+ | 1.11.0+
+ 0.11.0+ | 1.12.0+
+ 0.10.0-0.10.2 | 1.11.0
  0.9.0-0.9.2 | 1.10.0-1.10.2
  0.8.0-0.8.3 | 1.9.0-1.9.1
  0.6.0-0.7.0 | 1.8.0-1.8.1
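Because Torch.rb 0.11.0 tracks LibTorch 1.12+, projects still on LibTorch 1.11 need to stay on the 0.10.x line. A Gemfile sketch of such a pin (constraints chosen for illustration, not taken from this diff):

```ruby
# Gemfile, illustrative example.
source "https://rubygems.org"

# With LibTorch 1.12.x installed, the latest Torch.rb line applies:
gem "torch-rb", "~> 0.11.0"

# If only LibTorch 1.11 is available, pin to the previous line instead:
# gem "torch-rb", "~> 0.10.2"
```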
@@ -39,7 +39,14 @@ def skip_functions(functions)
  f.base_name == "index_put" ||
  # not supported yet
  f.func.include?("Dimname") ||
- f.func.include?("ConstQuantizerPtr")
+ f.func.include?("ConstQuantizerPtr") ||
+ f.func.include?("SymInt") ||
+ # TODO fix LibTorch 1.12 changes
+ f.base_name == "histogramdd" ||
+ f.base_name == "nested_tensor" ||
+ f.base_name == "split_copy" ||
+ f.base_name == "split_with_sizes_copy" ||
+ f.base_name == "unbind_copy"
  end
  end

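The new entries exclude native functions whose LibTorch 1.12 declarations (SymInt arguments, the *_copy split/unbind variants, histogramdd, nested_tensor) are not handled by the binding generator yet. A rough sketch of how such a predicate filters the parsed function list, using stand-in structs rather than the real codegen objects:

```ruby
# Stand-ins for the codegen's parsed function objects; the real ones are
# assumed to expose #base_name and #func (the native declaration string).
Func = Struct.new(:base_name, :func, keyword_init: true)

SKIPPED_BASE_NAMES = %w[histogramdd nested_tensor split_copy split_with_sizes_copy unbind_copy]

def skip?(f)
  f.func.include?("Dimname") ||
    f.func.include?("ConstQuantizerPtr") ||
    f.func.include?("SymInt") ||
    SKIPPED_BASE_NAMES.include?(f.base_name)
end

functions = [
  Func.new(base_name: "add", func: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"),
  Func.new(base_name: "split_copy", func: "split_copy.Tensor(Tensor(a) self, SymInt split_size, int dim=0) -> Tensor[]")
]

functions.reject { |f| skip?(f) }.map(&:base_name)
# => ["add"]
```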
@@ -250,7 +257,7 @@ def generate_dispatch(function, def_method)

  cpp_params = generate_dispatch_params(function, params)
  if opt_index
- cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
+ cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
  end

  retval = generate_dispatch_retval(function)
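Only the generated C++ type string changes here (from a const reference to a by-value `TensorOptions`); the options parameter is still spliced into the parameter list at the same position. A tiny illustration of that insertion, with stand-in values for the surrounding codegen state:

```ruby
# Illustrative values only; opt_index and remove_self come from the real
# codegen in the original file.
cpp_params  = ["const Tensor &self", "ScalarType dtype"]
opt_index   = 1
remove_self = true # self becomes the receiver, shifting the insert position by one

cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
cpp_params
# => ["const Tensor &self", "ScalarType dtype", "TensorOptions options"]
```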
@@ -410,7 +417,7 @@ def generate_function_params(function, params, remove_self)
  else
  "optionalTensor"
  end
- when "generator", "tensorlist", "intlist"
+ when "generator", "tensorlist"
  func
  when "string"
  "stringViewOptional"
@@ -471,7 +478,11 @@ def generate_dispatch_params(function, params)
  when "float"
  "double"
  when /\Aint\[/
- "IntArrayRef"
+ if param[:optional]
+ "at::OptionalIntArrayRef"
+ else
+ "IntArrayRef"
+ end
  when "float[]"
  "ArrayRef<double>"
  when "str"
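LibTorch 1.12 switched many optional `int[]` parameters to `at::OptionalIntArrayRef`, so the generator now branches on `param[:optional]` here (and, in the hunk below, excludes `int[` types from the generic `c10::optional` wrapping). A condensed sketch of this mapping, assuming `param` is a hash shaped like the codegen's:

```ruby
# Re-statement of the int[] branch; assumes param is a {type:, optional:} hash.
def int_array_cpp_type(param)
  raise ArgumentError, "not an int[] type: #{param[:type]}" unless param[:type] =~ /\Aint\[/

  param[:optional] ? "at::OptionalIntArrayRef" : "IntArrayRef"
end

int_array_cpp_type({ type: "int[1]", optional: true })  # => "at::OptionalIntArrayRef"
int_array_cpp_type({ type: "int[]", optional: false })  # => "IntArrayRef"
```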
@@ -480,13 +491,19 @@ def generate_dispatch_params(function, params)
  else
  "std::string"
  end
- when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
+ when "Scalar"
+ if param[:optional]
+ "const c10::optional<Scalar> &"
+ else
+ "const Scalar &"
+ end
+ when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
  param[:type]
  else
  raise "Unknown type: #{param[:type]} (#{function.name})"
  end

- if param[:optional] && param[:type] != "Tensor"
+ if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
  type = "c10::optional<#{type}>"
  end

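Optional Scalar parameters now map straight to `const c10::optional<Scalar> &`, so the generic wrapping step at the bottom has to skip `Scalar` (and the `int[` types handled above); otherwise those types would be wrapped in `c10::optional` twice. A small sketch of the combined effect, again assuming hash-shaped params:

```ruby
# Simplified view of the dispatch-type mapping after this change, limited to a
# few representative cases; param is a {type:, optional:} hash.
def dispatch_cpp_type(param)
  type =
    case param[:type]
    when "Scalar"
      param[:optional] ? "const c10::optional<Scalar> &" : "const Scalar &"
    when /\Aint\[/
      param[:optional] ? "at::OptionalIntArrayRef" : "IntArrayRef"
    when "float"
      "double"
    else
      param[:type]
    end

  # Tensor, Scalar, and int[...] already encode optionality above, so only the
  # remaining types get the generic c10::optional wrapper.
  if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
    type = "c10::optional<#{type}>"
  end

  type
end

dispatch_cpp_type({ type: "Scalar", optional: true })  # => "const c10::optional<Scalar> &"
dispatch_cpp_type({ type: "float", optional: true })   # => "c10::optional<double>"
dispatch_cpp_type({ type: "int[2]", optional: true })  # => "at::OptionalIntArrayRef"
```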
@@ -529,7 +546,7 @@ def generate_dispatch_retval(function)
  when ["float", "float"]
  "std::tuple<double,double>"
  else
- raise "Unknown retvals: #{types}"
+ raise "Unknown retvals: #{types} (#{function.name})"
  end
  end