torch-rb 0.11.2 → 0.12.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 618b3c09305777402177e8bc3d536053434b1c94e5294035a8b4a78e645b3873
4
- data.tar.gz: 910ac373619cb43887826d6277bc075b88311abd5dcc2be92bc6e9e15a35971d
3
+ metadata.gz: ce853372191c85509a65417abeaa05c484976f24681246babf9bd00f8db16df1
4
+ data.tar.gz: 3327004180566a194c7de8288c260b8fe9487d80c25493d82e041cf4fc0062e2
5
5
  SHA512:
6
- metadata.gz: af100f347769a268f7b45779fd9531f631e341bb0af4c999b1fba075b5d1aedae2d736eff048d9926c005b56e1ae35582a946b71db0cba5e62d83e938c315814
7
- data.tar.gz: 67803295ac4642c4cd32e66e9f3fcdeaaa33a37ae675693236c0561c4dfb4d6d405db6277ab156b0f31415058333ecd2aa06b3788f895a7836fc277b4d3fc084
6
+ metadata.gz: d8b6b0f7bd8b79963931b6b28b6a6cee59be18b8f185e2acadc22487a27b5793edabf0fb8b80857c5dfc0eb036a27b67a55750b8dd3c8eb1d199f06323e2b919
7
+ data.tar.gz: ffcafd2e9e99d6654f9689dd021c874cf53847833e26207853b0051056ae7cfe0b5ff202036e9a38264198facfa2c66f83b6ee51815131b63aa350415139b614
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
1
+ ## 0.12.1 (2023-01-29)
2
+
3
+ - Added `Generator` class
4
+
5
+ ## 0.12.0 (2022-11-05)
6
+
7
+ - Updated LibTorch to 1.13.0
8
+
1
9
  ## 0.11.2 (2022-09-25)
2
10
 
3
11
  - Improved LibTorch detection for Homebrew on Mac ARM and Linux
data/README.md CHANGED
@@ -13,10 +13,10 @@ Check out:
13
13
 
14
14
  ## Installation
15
15
 
16
- First, [install LibTorch](#libtorch-installation). For Homebrew, use:
16
+ First, [install LibTorch](#libtorch-installation). With Homebrew, it’s part of the PyTorch package:
17
17
 
18
18
  ```sh
19
- brew install libtorch
19
+ brew install pytorch
20
20
  ```
21
21
 
22
22
  Add this line to your application’s Gemfile:
@@ -409,7 +409,8 @@ Here’s the list of compatible versions.
409
409
 
410
410
  Torch.rb | LibTorch
411
411
  --- | ---
412
- 0.11.0+ | 1.12.0+
412
+ 0.12.0+ | 1.13.0+
413
+ 0.11.0-0.11.2 | 1.12.0-1.12.1
413
414
  0.10.0-0.10.2 | 1.11.0
414
415
  0.9.0-0.9.2 | 1.10.0-1.10.2
415
416
  0.8.0-0.8.3 | 1.9.0-1.9.1
@@ -425,7 +426,7 @@ Torch.rb | LibTorch
425
426
  You can also use Homebrew.
426
427
 
427
428
  ```sh
428
- brew install libtorch
429
+ brew install pytorch
429
430
  ```
430
431
 
431
432
  ## Performance
data/codegen/function.rb CHANGED
@@ -34,6 +34,10 @@ class Function
34
34
  !out_index.nil?
35
35
  end
36
36
 
37
+ def dispatch_name
38
+ definition.dig("dispatch", "CompositeImplicitAutograd")
39
+ end
40
+
37
41
  private
38
42
 
39
43
  def parse_func
@@ -77,6 +81,10 @@ class Function
77
81
  default = "torch.int64"
78
82
  end
79
83
 
84
+ if name == "dtype" && base_name == "randint"
85
+ default = "None"
86
+ end
87
+
80
88
  default = nil if definition["cpp_no_default_args"].to_a.include?(name)
81
89
 
82
90
  params << {
@@ -15,6 +15,7 @@ def generate_functions
15
15
  generate_files("linalg", :define_singleton_method, functions[:linalg])
16
16
  generate_files("special", :define_singleton_method, functions[:special])
17
17
  generate_files("sparse", :define_singleton_method, functions[:sparse])
18
+ # TODO generate nested
18
19
  end
19
20
 
20
21
  def load_functions
@@ -40,13 +41,14 @@ def skip_functions(functions)
40
41
  # not supported yet
41
42
  f.func.include?("Dimname") ||
42
43
  f.func.include?("ConstQuantizerPtr") ||
43
- f.func.include?("SymInt") ||
44
44
  # TODO fix LibTorch 1.12 changes
45
45
  f.base_name == "histogramdd" ||
46
46
  f.base_name == "nested_tensor" ||
47
47
  f.base_name == "split_copy" ||
48
48
  f.base_name == "split_with_sizes_copy" ||
49
- f.base_name == "unbind_copy"
49
+ f.base_name == "unbind_copy" ||
50
+ # TODO fix LibTorch 1.13 changes
51
+ f.base_name == "native_channel_shuffle"
50
52
  end
51
53
  end
52
54
 
@@ -56,6 +58,7 @@ def group_functions(functions)
56
58
  fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
57
59
  special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
58
60
  sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" }
61
+ nested_functions, other_functions = other_functions.partition { |f| f.python_module == "nested" }
59
62
  unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
60
63
  torch_functions = other_functions.select { |f| f.variants.include?("function") }
61
64
  tensor_functions = other_functions.select { |f| f.variants.include?("method") }
@@ -72,7 +75,8 @@ def group_functions(functions)
72
75
  linalg: linalg_functions,
73
76
  fft: fft_functions,
74
77
  special: special_functions,
75
- sparse: sparse_functions
78
+ sparse: sparse_functions,
79
+ nested: nested_functions
76
80
  }
77
81
  end
78
82
 
@@ -387,6 +391,8 @@ def generate_function_params(function, params, remove_self)
387
391
  "scalarlist"
388
392
  when /\Aint\[/
389
393
  "intlist"
394
+ when /\ASymInt\[/
395
+ "symintlist"
390
396
  when "float[]"
391
397
  "doublelist"
392
398
  when "Scalar"
@@ -395,6 +401,8 @@ def generate_function_params(function, params, remove_self)
395
401
  "toBool"
396
402
  when "int"
397
403
  "toInt64"
404
+ when "SymInt"
405
+ "toSymInt"
398
406
  when "float"
399
407
  "toDouble"
400
408
  when "ScalarType"
@@ -437,7 +445,12 @@ def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
437
445
  # torch::empty sets requires_grad by at::empty doesn't
438
446
  # https://github.com/pytorch/pytorch/issues/36455
439
447
  prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
440
- dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
448
+ dispatch = function.dispatch_name
449
+ unless dispatch
450
+ dispatch = function.base_name
451
+ dispatch += "_symint" if function.func.include?("SymInt")
452
+ dispatch += "_out" if function.out?
453
+ end
441
454
 
442
455
  params = params.map { |v| v[:name] }
443
456
  params.reject! { |v| v == "self" } if remove_self
@@ -478,6 +491,8 @@ def generate_dispatch_params(function, params)
478
491
  "ScalarList"
479
492
  when "int"
480
493
  "int64_t"
494
+ when "SymInt"
495
+ "c10::SymInt"
481
496
  when "float"
482
497
  "double"
483
498
  when /\Aint\[/
@@ -486,6 +501,12 @@ def generate_dispatch_params(function, params)
486
501
  else
487
502
  "IntArrayRef"
488
503
  end
504
+ when /\ASymInt\[/
505
+ if param[:optional]
506
+ "at::OptionalSymIntArrayRef"
507
+ else
508
+ "c10::SymIntArrayRef"
509
+ end
489
510
  when "float[]"
490
511
  "ArrayRef<double>"
491
512
  when "str"
@@ -506,7 +527,7 @@ def generate_dispatch_params(function, params)
506
527
  raise "Unknown type: #{param[:type]} (#{function.name})"
507
528
  end
508
529
 
509
- if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
530
+ if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") && !param[:type].start_with?("SymInt[")
510
531
  type = "c10::optional<#{type}>"
511
532
  end
512
533
 
@@ -614,8 +635,12 @@ def signature_type(param)
614
635
  "DirnameList"
615
636
  when /\Aint\[\d*\]\z/
616
637
  "IntArrayRef"
638
+ when /\ASymInt\[\d*\]\z/
639
+ "SymIntArrayRef"
617
640
  when "int"
618
641
  "int64_t"
642
+ when "SymInt"
643
+ "c10::SymInt"
619
644
  when "float"
620
645
  "double"
621
646
  when "str"