torch-rb 0.11.2 → 0.12.0

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 618b3c09305777402177e8bc3d536053434b1c94e5294035a8b4a78e645b3873
4
- data.tar.gz: 910ac373619cb43887826d6277bc075b88311abd5dcc2be92bc6e9e15a35971d
3
+ metadata.gz: 24617a6191c4d1e42ff0d5885c0762db75d034ee52817d70004ddc52025890dd
4
+ data.tar.gz: 7036f3c7fac8a1aac22914ec2811b8771464baeb847052e417b069f3897b3977
5
5
  SHA512:
6
- metadata.gz: af100f347769a268f7b45779fd9531f631e341bb0af4c999b1fba075b5d1aedae2d736eff048d9926c005b56e1ae35582a946b71db0cba5e62d83e938c315814
7
- data.tar.gz: 67803295ac4642c4cd32e66e9f3fcdeaaa33a37ae675693236c0561c4dfb4d6d405db6277ab156b0f31415058333ecd2aa06b3788f895a7836fc277b4d3fc084
6
+ metadata.gz: 044166f526bea5fea0c4314abc55dfaf2410cdd5d7bce764ccb78329754fe284675dc0ecb55cf5a4546587221cf2361c8e215f1e9e067668ea5833c445bea961
7
+ data.tar.gz: e086d0dcc731fb3e810f8fb91d04d95bf3c5efd9cd35f053238be60479d8a47f2126ddad7ab2705d18a26175c6c4126f12be3f19962c2a69f24ad3ee551252b1
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.12.0 (2022-11-05)
2
+
3
+ - Updated LibTorch to 1.13.0
4
+
1
5
  ## 0.11.2 (2022-09-25)
2
6
 
3
7
  - Improved LibTorch detection for Homebrew on Mac ARM and Linux
data/README.md CHANGED
@@ -13,10 +13,10 @@ Check out:
13
13
 
14
14
  ## Installation
15
15
 
16
- First, [install LibTorch](#libtorch-installation). For Homebrew, use:
16
+ First, [install LibTorch](#libtorch-installation). With Homebrew, it’s part of the PyTorch package:
17
17
 
18
18
  ```sh
19
- brew install libtorch
19
+ brew install pytorch
20
20
  ```
21
21
 
22
22
  Add this line to your application’s Gemfile:
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
409
409
 
410
410
  Torch.rb | LibTorch
411
411
  --- | ---
412
- 0.11.0+ | 1.12.0+
412
+ 0.12.0+ | 1.13.0+
413
+ 0.11.0-0.11.2 | 1.12.0-1.12.1
413
414
  0.10.0-0.10.2 | 1.11.0
414
415
  0.9.0-0.9.2 | 1.10.0-1.10.2
415
416
  0.8.0-0.8.3 | 1.9.0-1.9.1
@@ -425,7 +426,7 @@ Torch.rb | LibTorch
425
426
  You can also use Homebrew.
426
427
 
427
428
  ```sh
428
- brew install libtorch
429
+ brew install pytorch
429
430
  ```
430
431
 
431
432
  ## Performance
data/codegen/function.rb CHANGED
@@ -34,6 +34,10 @@ class Function
34
34
  !out_index.nil?
35
35
  end
36
36
 
37
+ def dispatch_name
38
+ definition.dig("dispatch", "CompositeImplicitAutograd")
39
+ end
40
+
37
41
  private
38
42
 
39
43
  def parse_func
@@ -77,6 +81,10 @@ class Function
77
81
  default = "torch.int64"
78
82
  end
79
83
 
84
+ if name == "dtype" && base_name == "randint"
85
+ default = "None"
86
+ end
87
+
80
88
  default = nil if definition["cpp_no_default_args"].to_a.include?(name)
81
89
 
82
90
  params << {
data/codegen/generate_functions.rb CHANGED
@@ -15,6 +15,7 @@ def generate_functions
15
15
  generate_files("linalg", :define_singleton_method, functions[:linalg])
16
16
  generate_files("special", :define_singleton_method, functions[:special])
17
17
  generate_files("sparse", :define_singleton_method, functions[:sparse])
18
+ # TODO generate nested
18
19
  end
19
20
 
20
21
  def load_functions
@@ -40,13 +41,14 @@ def skip_functions(functions)
40
41
  # not supported yet
41
42
  f.func.include?("Dimname") ||
42
43
  f.func.include?("ConstQuantizerPtr") ||
43
- f.func.include?("SymInt") ||
44
44
  # TODO fix LibTorch 1.12 changes
45
45
  f.base_name == "histogramdd" ||
46
46
  f.base_name == "nested_tensor" ||
47
47
  f.base_name == "split_copy" ||
48
48
  f.base_name == "split_with_sizes_copy" ||
49
- f.base_name == "unbind_copy"
49
+ f.base_name == "unbind_copy" ||
50
+ # TODO fix LibTorch 1.13 changes
51
+ f.base_name == "native_channel_shuffle"
50
52
  end
51
53
  end
52
54
 
@@ -56,6 +58,7 @@ def group_functions(functions)
56
58
  fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
57
59
  special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
58
60
  sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" }
61
+ nested_functions, other_functions = other_functions.partition { |f| f.python_module == "nested" }
59
62
  unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
60
63
  torch_functions = other_functions.select { |f| f.variants.include?("function") }
61
64
  tensor_functions = other_functions.select { |f| f.variants.include?("method") }
@@ -72,7 +75,8 @@ def group_functions(functions)
72
75
  linalg: linalg_functions,
73
76
  fft: fft_functions,
74
77
  special: special_functions,
75
- sparse: sparse_functions
78
+ sparse: sparse_functions,
79
+ nested: nested_functions
76
80
  }
77
81
  end
78
82
 
@@ -387,6 +391,8 @@ def generate_function_params(function, params, remove_self)
387
391
  "scalarlist"
388
392
  when /\Aint\[/
389
393
  "intlist"
394
+ when /\ASymInt\[/
395
+ "symintlist"
390
396
  when "float[]"
391
397
  "doublelist"
392
398
  when "Scalar"
@@ -395,6 +401,8 @@ def generate_function_params(function, params, remove_self)
395
401
  "toBool"
396
402
  when "int"
397
403
  "toInt64"
404
+ when "SymInt"
405
+ "toSymInt"
398
406
  when "float"
399
407
  "toDouble"
400
408
  when "ScalarType"
@@ -437,7 +445,12 @@ def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
437
445
  # torch::empty sets requires_grad by at::empty doesn't
438
446
  # https://github.com/pytorch/pytorch/issues/36455
439
447
  prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
440
- dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
448
+ dispatch = function.dispatch_name
449
+ unless dispatch
450
+ dispatch = function.base_name
451
+ dispatch += "_symint" if function.func.include?("SymInt")
452
+ dispatch += "_out" if function.out?
453
+ end
441
454
 
442
455
  params = params.map { |v| v[:name] }
443
456
  params.reject! { |v| v == "self" } if remove_self
@@ -478,6 +491,8 @@ def generate_dispatch_params(function, params)
478
491
  "ScalarList"
479
492
  when "int"
480
493
  "int64_t"
494
+ when "SymInt"
495
+ "c10::SymInt"
481
496
  when "float"
482
497
  "double"
483
498
  when /\Aint\[/
@@ -486,6 +501,12 @@ def generate_dispatch_params(function, params)
486
501
  else
487
502
  "IntArrayRef"
488
503
  end
504
+ when /\ASymInt\[/
505
+ if param[:optional]
506
+ "at::OptionalSymIntArrayRef"
507
+ else
508
+ "c10::SymIntArrayRef"
509
+ end
489
510
  when "float[]"
490
511
  "ArrayRef<double>"
491
512
  when "str"
@@ -506,7 +527,7 @@ def generate_dispatch_params(function, params)
506
527
  raise "Unknown type: #{param[:type]} (#{function.name})"
507
528
  end
508
529
 
509
- if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
530
+ if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") && !param[:type].start_with?("SymInt[")
510
531
  type = "c10::optional<#{type}>"
511
532
  end
512
533
 
@@ -614,8 +635,12 @@ def signature_type(param)
614
635
  "DirnameList"
615
636
  when /\Aint\[\d*\]\z/
616
637
  "IntArrayRef"
638
+ when /\ASymInt\[\d*\]\z/
639
+ "SymIntArrayRef"
617
640
  when "int"
618
641
  "int64_t"
642
+ when "SymInt"
643
+ "c10::SymInt"
619
644
  when "float"
620
645
  "double"
621
646
  when "str"