torch-rb 0.11.2 → 0.12.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 618b3c09305777402177e8bc3d536053434b1c94e5294035a8b4a78e645b3873
-  data.tar.gz: 910ac373619cb43887826d6277bc075b88311abd5dcc2be92bc6e9e15a35971d
+  metadata.gz: 24617a6191c4d1e42ff0d5885c0762db75d034ee52817d70004ddc52025890dd
+  data.tar.gz: 7036f3c7fac8a1aac22914ec2811b8771464baeb847052e417b069f3897b3977
 SHA512:
-  metadata.gz: af100f347769a268f7b45779fd9531f631e341bb0af4c999b1fba075b5d1aedae2d736eff048d9926c005b56e1ae35582a946b71db0cba5e62d83e938c315814
-  data.tar.gz: 67803295ac4642c4cd32e66e9f3fcdeaaa33a37ae675693236c0561c4dfb4d6d405db6277ab156b0f31415058333ecd2aa06b3788f895a7836fc277b4d3fc084
+  metadata.gz: 044166f526bea5fea0c4314abc55dfaf2410cdd5d7bce764ccb78329754fe284675dc0ecb55cf5a4546587221cf2361c8e215f1e9e067668ea5833c445bea961
+  data.tar.gz: e086d0dcc731fb3e810f8fb91d04d95bf3c5efd9cd35f053238be60479d8a47f2126ddad7ab2705d18a26175c6c4126f12be3f19962c2a69f24ad3ee551252b1
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
+## 0.12.0 (2022-11-05)
+
+- Updated LibTorch to 1.13.0
+
 ## 0.11.2 (2022-09-25)
 
 - Improved LibTorch detection for Homebrew on Mac ARM and Linux
data/README.md CHANGED
@@ -13,10 +13,10 @@ Check out:
 
 ## Installation
 
-First, [install LibTorch](#libtorch-installation). For Homebrew, use:
+First, [install LibTorch](#libtorch-installation). With Homebrew, it’s part of the PyTorch package:
 
 ```sh
-brew install libtorch
+brew install pytorch
 ```
 
 Add this line to your application’s Gemfile:
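For context, the Gemfile line the README goes on to show is the standard declaration for this gem:

```ruby
# Gemfile
gem "torch-rb"
```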
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.11.0+ | 1.12.0+
+0.12.0+ | 1.13.0+
+0.11.0-0.11.2 | 1.12.0-1.12.1
 0.10.0-0.10.2 | 1.11.0
 0.9.0-0.9.2 | 1.10.0-1.10.2
 0.8.0-0.8.3 | 1.9.0-1.9.1
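To see which row of the table applies to an installed gem, the version constant can be checked at runtime; a minimal check, assuming the gem is already installed:

```ruby
require "torch"

puts Torch::VERSION # => "0.12.0" for this release
```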
@@ -425,7 +426,7 @@ Torch.rb | LibTorch
 You can also use Homebrew.
 
 ```sh
-brew install libtorch
+brew install pytorch
 ```
 
 ## Performance
data/codegen/function.rb CHANGED
@@ -34,6 +34,10 @@ class Function
     !out_index.nil?
   end
 
+  def dispatch_name
+    definition.dig("dispatch", "CompositeImplicitAutograd")
+  end
+
   private
 
   def parse_func
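The new `dispatch_name` digs into the function definition parsed from `native_functions.yaml`. A minimal sketch of the hash shape it expects; the entry below is illustrative, not copied from the YAML:

```ruby
# hypothetical parsed entry for a composite (non-kernel) op
definition = {
  "func" => "example_op(Tensor self) -> Tensor",
  "dispatch" => {"CompositeImplicitAutograd" => "example_op"}
}

definition.dig("dispatch", "CompositeImplicitAutograd") # => "example_op"
definition.dig("dispatch", "CPU")                       # => nil (callers fall back)
```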
@@ -77,6 +81,10 @@ class Function
       default = "torch.int64"
     end
 
+    if name == "dtype" && base_name == "randint"
+      default = "None"
+    end
+
     default = nil if definition["cpp_no_default_args"].to_a.include?(name)
 
     params << {
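This override keeps `randint` callable without an explicit dtype, matching PyTorch, where `torch.randint` defaults to int64 rather than the global floating-point default. A quick smoke test, assuming a working install (call shape follows PyTorch's `randint(high, size)`):

```ruby
require "torch"

x = Torch.randint(10, [2, 3]) # no dtype given; integers in [0, 10)
p x.dtype                     # expected: :int64, as in PyTorch
```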
data/codegen/generate_functions.rb CHANGED
@@ -15,6 +15,7 @@ def generate_functions
   generate_files("linalg", :define_singleton_method, functions[:linalg])
   generate_files("special", :define_singleton_method, functions[:special])
   generate_files("sparse", :define_singleton_method, functions[:sparse])
+  # TODO generate nested
 end
 
 def load_functions
@@ -40,13 +41,14 @@ def skip_functions(functions)
       # not supported yet
       f.func.include?("Dimname") ||
       f.func.include?("ConstQuantizerPtr") ||
-      f.func.include?("SymInt") ||
       # TODO fix LibTorch 1.12 changes
       f.base_name == "histogramdd" ||
       f.base_name == "nested_tensor" ||
       f.base_name == "split_copy" ||
       f.base_name == "split_with_sizes_copy" ||
-      f.base_name == "unbind_copy"
+      f.base_name == "unbind_copy" ||
+      # TODO fix LibTorch 1.13 changes
+      f.base_name == "native_channel_shuffle"
   end
 end
 
@@ -56,6 +58,7 @@ def group_functions(functions)
   fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
   special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
   sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" }
+  nested_functions, other_functions = other_functions.partition { |f| f.python_module == "nested" }
   unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
   torch_functions = other_functions.select { |f| f.variants.include?("function") }
   tensor_functions = other_functions.select { |f| f.variants.include?("method") }
@@ -72,7 +75,8 @@ def group_functions(functions)
     linalg: linalg_functions,
     fft: fft_functions,
     special: special_functions,
-    sparse: sparse_functions
+    sparse: sparse_functions,
+    nested: nested_functions
   }
 end
 
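The grouping above relies on `Array#partition`: each call splits off the functions for one module and hands the leftovers to the next call, so every function lands in exactly one group. An illustrative reduction of the pattern with stand-in data:

```ruby
fns = [{mod: "nested"}, {mod: "fft"}, {mod: nil}]

nested, rest = fns.partition { |f| f[:mod] == "nested" }
nested # => [{mod: "nested"}]
rest   # => [{mod: "fft"}, {mod: nil}], fed to the next partition
```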
@@ -387,6 +391,8 @@ def generate_function_params(function, params, remove_self)
         "scalarlist"
       when /\Aint\[/
         "intlist"
+      when /\ASymInt\[/
+        "symintlist"
       when "float[]"
         "doublelist"
       when "Scalar"
@@ -395,6 +401,8 @@ def generate_function_params(function, params, remove_self)
         "toBool"
       when "int"
         "toInt64"
+      when "SymInt"
+        "toSymInt"
       when "float"
         "toDouble"
       when "ScalarType"
@@ -437,7 +445,12 @@ def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
   # torch::empty sets requires_grad but at::empty doesn't
   # https://github.com/pytorch/pytorch/issues/36455
   prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
-  dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
+  dispatch = function.dispatch_name
+  unless dispatch
+    dispatch = function.base_name
+    dispatch += "_symint" if function.func.include?("SymInt")
+    dispatch += "_out" if function.out?
+  end
 
   params = params.map { |v| v[:name] }
   params.reject! { |v| v == "self" } if remove_self
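The fallback reflects LibTorch 1.13's naming scheme, in which SymInt-taking overloads are registered under a `_symint` suffix and out variants under `_out`. A sketch of the derivation with hypothetical inputs:

```ruby
# hypothetical values, mirroring the fallback in generate_dispatch_code
base_name = "narrow_copy"
func = "narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor"
out_variant = false

dispatch = base_name
dispatch += "_symint" if func.include?("SymInt")
dispatch += "_out" if out_variant
dispatch # => "narrow_copy_symint"
```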
@@ -478,6 +491,8 @@ def generate_dispatch_params(function, params)
         "ScalarList"
       when "int"
         "int64_t"
+      when "SymInt"
+        "c10::SymInt"
       when "float"
         "double"
       when /\Aint\[/
@@ -486,6 +501,12 @@ def generate_dispatch_params(function, params)
         else
           "IntArrayRef"
         end
+      when /\ASymInt\[/
+        if param[:optional]
+          "at::OptionalSymIntArrayRef"
+        else
+          "c10::SymIntArrayRef"
+        end
       when "float[]"
         "ArrayRef<double>"
       when "str"
@@ -506,7 +527,7 @@ def generate_dispatch_params(function, params)
         raise "Unknown type: #{param[:type]} (#{function.name})"
       end
 
-      if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
+      if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") && !param[:type].start_with?("SymInt[")
         type = "c10::optional<#{type}>"
       end
 
@@ -614,8 +635,12 @@ def signature_type(param)
       "DirnameList"
     when /\Aint\[\d*\]\z/
       "IntArrayRef"
+    when /\ASymInt\[\d*\]\z/
+      "SymIntArrayRef"
     when "int"
       "int64_t"
+    when "SymInt"
+      "c10::SymInt"
     when "float"
       "double"
     when "str"