torch-rb 0.10.2 → 0.11.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 935c88d7ce2a1f9d0f12014dd4f79beee6d1c153bc2439f614f00f50ec60c286
4
- data.tar.gz: d77145dd7944d0ec024e5bef00d6ba1d47e52efb74c1d731101b8b8e67c9a3c9
3
+ metadata.gz: d2f88c938144476a772fd7c606751e0a6e67a338cb1c37ede9c2011db7fc4579
4
+ data.tar.gz: f4408457e3bf8c7bf9b42863459248aa592c70e8e3924cdbe72863a979e65106
5
5
  SHA512:
6
- metadata.gz: d9b4df4b90ff11b67d4c203a2e5f59037a6079dafb16ec02bae89a34f364ca86a5530dcb74adce48474c857ce2f2ce49323d96830fb30c843440b842b6e7fb50
7
- data.tar.gz: a68bf5d89984861c847390a4aaea337bc86a43c01ee93de47f575b52a87b7281e3318fcf10d70f0db5013f9685a126a46d80e42f88bbf51080bd15fac1933f71
6
+ metadata.gz: 90ebc506942809b02331f7accce6e93e714a57b5d2f58b06ad0d3c60204b947eba61fe0bf43431a19e27890f5c764cb348ccb1785a2ffccfba7969f4726b6f1f
7
+ data.tar.gz: 410af7b4934f79aaae94f6c1bde1d17ee8b043f02354b4ea9f139271575f43abb90afeba162badfa01a3ec406612ae6cdcb4a2adfb1067b1634982a590d10cc9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.11.0 (2022-07-06)
2
+
3
+ - Updated LibTorch to 1.12.0
4
+ - Dropped support for Ruby < 2.7
5
+
1
6
  ## 0.10.2 (2022-06-14)
2
7
 
3
8
  - Improved numeric operations between scalars and tensors
data/README.md CHANGED
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
409
409
 
410
410
  Torch.rb | LibTorch
411
411
  --- | ---
412
- 0.10.0+ | 1.11.0+
412
+ 0.11.0+ | 1.12.0+
413
+ 0.10.0-0.10.2 | 1.11.0
413
414
  0.9.0-0.9.2 | 1.10.0-1.10.2
414
415
  0.8.0-0.8.3 | 1.9.0-1.9.1
415
416
  0.6.0-0.7.0 | 1.8.0-1.8.1
@@ -39,7 +39,14 @@ def skip_functions(functions)
39
39
  f.base_name == "index_put" ||
40
40
  # not supported yet
41
41
  f.func.include?("Dimname") ||
42
- f.func.include?("ConstQuantizerPtr")
42
+ f.func.include?("ConstQuantizerPtr") ||
43
+ f.func.include?("SymInt") ||
44
+ # TODO fix LibTorch 1.12 changes
45
+ f.base_name == "histogramdd" ||
46
+ f.base_name == "nested_tensor" ||
47
+ f.base_name == "split_copy" ||
48
+ f.base_name == "split_with_sizes_copy" ||
49
+ f.base_name == "unbind_copy"
43
50
  end
44
51
  end
45
52
 
@@ -250,7 +257,7 @@ def generate_dispatch(function, def_method)
250
257
 
251
258
  cpp_params = generate_dispatch_params(function, params)
252
259
  if opt_index
253
- cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
260
+ cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
254
261
  end
255
262
 
256
263
  retval = generate_dispatch_retval(function)
@@ -410,7 +417,7 @@ def generate_function_params(function, params, remove_self)
410
417
  else
411
418
  "optionalTensor"
412
419
  end
413
- when "generator", "tensorlist", "intlist"
420
+ when "generator", "tensorlist"
414
421
  func
415
422
  when "string"
416
423
  "stringViewOptional"
@@ -471,7 +478,11 @@ def generate_dispatch_params(function, params)
471
478
  when "float"
472
479
  "double"
473
480
  when /\Aint\[/
474
- "IntArrayRef"
481
+ if param[:optional]
482
+ "at::OptionalIntArrayRef"
483
+ else
484
+ "IntArrayRef"
485
+ end
475
486
  when "float[]"
476
487
  "ArrayRef<double>"
477
488
  when "str"
@@ -480,13 +491,19 @@ def generate_dispatch_params(function, params)
480
491
  else
481
492
  "std::string"
482
493
  end
483
- when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
494
+ when "Scalar"
495
+ if param[:optional]
496
+ "const c10::optional<Scalar> &"
497
+ else
498
+ "const Scalar &"
499
+ end
500
+ when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
484
501
  param[:type]
485
502
  else
486
503
  raise "Unknown type: #{param[:type]} (#{function.name})"
487
504
  end
488
505
 
489
- if param[:optional] && param[:type] != "Tensor"
506
+ if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
490
507
  type = "c10::optional<#{type}>"
491
508
  end
492
509
 
@@ -529,7 +546,7 @@ def generate_dispatch_retval(function)
529
546
  when ["float", "float"]
530
547
  "std::tuple<double,double>"
531
548
  else
532
- raise "Unknown retvals: #{types}"
549
+ raise "Unknown retvals: #{types} (#{function.name})"
533
550
  end
534
551
  end
535
552