torch-rb 0.1.8 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Module m);
Binary file
data/lib/torch/hub.rb ADDED
@@ -0,0 +1,9 @@
module Torch
  # Counterpart of Python's +torch.hub+ API. Currently a placeholder:
  # every entry point raises until the feature is implemented.
  module Hub
    # Lists entrypoints published by a GitHub repo ("owner/repo").
    #
    # Not implemented yet — always raises NotImplementedYet (defined
    # elsewhere in the gem).
    def self.list(github, force_reload: false)
      raise NotImplementedYet
    end
  end
end
@@ -18,14 +18,12 @@ module Torch
18
18
  functions = functions()
19
19
 
20
20
  # skip functions
21
- skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
22
21
  skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
23
22
 
24
23
  # remove functions
25
24
  functions.reject! do |f|
26
25
  f.ruby_name.start_with?("_") ||
27
26
  f.ruby_name.end_with?("_backward") ||
28
- skip_binding.include?(f.ruby_name) ||
29
27
  f.args.any? { |a| a[:type].include?("Dimname") }
30
28
  end
31
29
 
@@ -34,7 +32,10 @@ module Torch
34
32
  functions.partition do |f|
35
33
  f.args.any? do |a|
36
34
  a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
37
- skip_args.any? { |sa| a[:type].include?(sa) }
35
+ skip_args.any? { |sa| a[:type].include?(sa) } ||
36
+ # native_functions.yaml is missing size argument for normal
37
+ # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
38
+ (f.base_name == "normal" && !f.out?)
38
39
  end
39
40
  end
40
41
 
@@ -119,6 +120,8 @@ void add_%{type}_functions(Module m) {
119
120
  "IntArrayRef"
120
121
  when /Tensor\(\S!?\)/
121
122
  "Tensor &"
123
+ when "str"
124
+ "std::string"
122
125
  else
123
126
  a[:type]
124
127
  end