torch-rb 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
    
        checksums.yaml
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            ---
         | 
| 2 2 | 
             
            SHA256:
         | 
| 3 | 
            -
              metadata.gz:  | 
| 4 | 
            -
              data.tar.gz:  | 
| 3 | 
            +
              metadata.gz: 6294a5809ed3ea9b3d3f9954b351b8c8e5d81c37e9318ae5044d243320166b20
         | 
| 4 | 
            +
              data.tar.gz: 57d8036a0449497cbf040c8650c4dcdbe3e0f09191951d77e6d907c6b3c1013d
         | 
| 5 5 | 
             
            SHA512:
         | 
| 6 | 
            -
              metadata.gz:  | 
| 7 | 
            -
              data.tar.gz:  | 
| 6 | 
            +
              metadata.gz: c7aaf862b43f72f5ca8e7f6f6aca079195c11420db081f9c4f18f161610598d975fb81a036689dcd3158d200c33e39e67841bded2e728f2a2fb1b5b8d9e9dfad
         | 
| 7 | 
            +
              data.tar.gz: afaf4c344951629248f79bb03246f6879555a58438d07107dea008aa2433d08115d56b1f9464cf20065c276417d3d90a59aecbfefa304208ee81ae313867bad6
         | 
    
        data/CHANGELOG.md
    CHANGED
    
    
    
        data/README.md
    CHANGED
    
    
    
        data/codegen/function.rb
    ADDED
    
    | @@ -0,0 +1,134 @@ | |
| 1 | 
            +
            # Parsed view of one entry from native_functions.yaml.
            #
            # `definition` is the raw YAML hash. Its "func" schema string, e.g.
            #   "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
            # is split at construction time into parameter descriptors (`params`) and
            # return-value descriptors (`retvals`), each an Array of Hashes.
            class Function
              # Parameter names that together describe TensorOptions. They are never
              # flagged optional even when their type carries "?". A function taking
              # all of them gets an implicit keyword-only `requires_grad` (default False).
              TENSOR_OPTION_NAMES = %w[dtype device layout pin_memory].freeze

              attr_reader :definition, :params, :retvals

              # @param definition [Hash] one YAML entry; must contain a "func" key
              def initialize(definition)
                @definition = definition
                @params, @retvals = parse_func
              end

              # Overload name: everything before the first "(" (e.g. "add.out").
              def name
                func.split("(", 2).first
              end

              # Overload name without its ".suffix" (e.g. "add" for "add.out").
              def base_name
                name.split(".").first
              end

              # Raw schema string from the YAML entry.
              def func
                definition["func"]
              end

              # The entry's "python_module" value (e.g. "nn"), or nil when absent.
              def python_module
                definition["python_module"]
              end

              # Variants listed in the entry ("function", "method"); defaults to ["function"].
              def variants
                (definition["variants"] || "function").split(", ")
              end

              # Index of the first param whose alias modifier contains "!" (a written-to
              # output), or nil. In-place functions (base name ending in "_") and
              # functions with no return values never count as out= variants.
              def out_index
                params.index { |v| v[:modifier].to_s.include?("!") } if base_name[-1] != "_" && retvals.any?
              end

              # True when this overload is an out= variant (see #out_index).
              def out?
                !out_index.nil?
              end

              private

              # @return [Array(Array<Hash>, Array<Hash>)] params and retvals
              def parse_func
                input, output = func.split(/\s*->\s*/)
                [generate_params(input), generate_retvals(output)]
              end

              # Parses the "(...)" argument list into param Hashes with keys
              # :name, :type, :default, :keyword_only, :optional, :modifier, :list_size.
              def generate_params(input)
                input = input.split("(", 2).last.chomp(")").split(/\s*,\s+/)

                keyword_only = false
                params = []
                input.each do |i|
                  # a bare "*" marks every following param as keyword-only
                  if i == "*"
                    keyword_only = true
                    next
                  end

                  type, name = i.split(/\s+/)

                  if name.include?("=")
                    name, default = name.split("=", 2)
                  end

                  optional = false
                  if type.include?("?")
                    optional = true unless TENSOR_OPTION_NAMES.include?(name)
                    type = type.delete("?")
                  end

                  type, modifier = extract_modifier(type)

                  # "int[2]" -> "2"; "int[]" -> nil
                  if type.include?("[")
                    list_size = /\[(.*)\]/.match(type)[1]
                    list_size = nil if list_size.empty?
                  end

                  if name == "dtype" && (base_name.start_with?("randperm") || base_name == "tril_indices" || base_name == "triu_indices")
                    # dtype hack
                    # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
                    default = "torch.int64"
                  end

                  params << {
                    name: name,
                    type: type,
                    default: default,
                    keyword_only: keyword_only,
                    optional: optional,
                    modifier: modifier,
                    list_size: list_size
                  }
                end

                # functions taking every TensorOptions field also accept requires_grad
                if (params.map { |v| v[:name] } & TENSOR_OPTION_NAMES).size == TENSOR_OPTION_NAMES.size
                  params << {
                    name: "requires_grad",
                    type: "bool",
                    default: "False",
                    keyword_only: true,
                    optional: false,
                    modifier: nil,
                    list_size: nil
                  }
                end

                params
              end

              # Parses the return part ("()", "Tensor", or "(Tensor a, Tensor b)") into
              # Hashes with keys :name, :type, :modifier.
              def generate_retvals(output)
                output =
                  if output == "()"
                    []
                  elsif output[0] == "("
                    output[1..-2].split(/\s*,\s*/)
                  else
                    [output]
                  end

                retvals = []
                output.each do |o|
                  type, name = o.split(/\s+/)
                  type, modifier = extract_modifier(type)
                  retvals << {name: name, type: type, modifier: modifier}
                end
                retvals
              end

              # Tensor(a), Tensor(a!), Tensor(a)[]
              # Strips the "(...)" alias annotation: "Tensor(a)[]" -> ["Tensor[]", "a"].
              # Returns [type, nil] when there is no annotation.
              def extract_modifier(type)
                if type.include?("(")
                  parts = type.split(/[\(\)]/, 3)
                  modifier = parts.delete_at(1)
                  type = parts.join("")
                end
                [type, modifier]
              end
            end
         | 
| @@ -0,0 +1,546 @@ | |
| 1 | 
            +
            require "yaml"
         | 
| 2 | 
            +
            # use require_relative for
         | 
| 3 | 
            +
            # rake generate:functions (without bundle)
         | 
| 4 | 
            +
            require_relative "function"
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Entry point: loads native_functions.yaml, drops unsupported entries,
            # buckets the rest by binding target, then emits one header/source pair
            # per target (torch module functions, Tensor methods, nn functions).
            def generate_functions
              grouped = group_functions(skip_functions(load_functions))

              generate_files("torch", :define_singleton_method, grouped[:torch])
              generate_files("tensor", :define_method, grouped[:tensor])
              generate_files("nn", :define_singleton_method, grouped[:nn])
            end
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # Reads native_functions.yaml from this script's directory and wraps each
            # entry in a Function, ordered by overload name.
            def load_functions
              yaml_path = File.expand_path("native_functions.yaml", __dir__)
              definitions = YAML.load_file(yaml_path)
              definitions.map { |definition| Function.new(definition) }.sort_by(&:name)
            end
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            # Removes functions the generator does not bind: private ("_"-prefixed),
            # backward/forward helpers, a few hand-written or pending bindings, and
            # signatures using types not supported yet.
            def skip_functions(functions)
              # "index" and "index_put_" live in ext.cpp; "index_put" still needs adding
              # there; "to" is also excluded
              hand_written = ["to", "index", "index_put_", "index_put"]
              # not supported yet
              unsupported_types = ["Dimname", "ConstQuantizerPtr"]

              functions.reject do |f|
                base = f.base_name
                base.start_with?("_") ||
                  base.include?("_backward") ||
                  base.include?("_forward") ||
                  hand_written.include?(base) ||
                  unsupported_types.any? { |t| f.func.include?(t) }
              end
            end
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            # Buckets functions by binding target. Entries whose python_module is "nn"
            # go only to :nn. Every other entry goes to :torch when it has a "function"
            # variant and to :tensor when it has a "method" variant (possibly both).
            def group_functions(functions)
              nn_module = ->(f) { f.python_module == "nn" }
              nn = functions.select(&nn_module)
              rest = functions.reject(&nn_module)

              {
                torch: rest.select { |f| f.variants.include?("function") },
                tensor: rest.select { |f| f.variants.include?("method") },
                nn: nn
              }
            end
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            # Generates one C++ binding per base name (all overloads share it) and writes
            # "<type>_functions.h" / "<type>_functions.cpp".
            def generate_files(type, def_method, functions)
              generated = functions.group_by(&:base_name).map do |name, overloads|
                [
                  generate_method_def(name, overloads, type, def_method),
                  generate_attach_def(name, type, def_method)
                ]
              end
              method_defs = generated.map(&:first)
              attach_defs = generated.map(&:last)

              write_header(type)
              write_body(type, method_defs, attach_defs)
            end
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            # Writes "<type>_functions.h", which declares add_<type>_functions(Module).
            def write_header(type)
              header_lines = [
                "// generated by rake generate:functions",
                "// do not edit by hand",
                "",
                "#pragma once",
                "",
                "void add_#{type}_functions(Module m);"
              ]
              write_file("#{type}_functions.h", header_lines.join("\n") + "\n")
            end
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            # Writes "<type>_functions.cpp": the includes, every generated method
            # definition, and add_<type>_functions(Module), which registers them.
            # Every type except "nn" also includes the CUDA lazy-init header.
            def write_body(type, method_defs, attach_defs)
              cuda_include = type == "nn" ? "" : %{\n#include "torch/csrc/utils/cuda_lazy_init.h"\n}
              definitions = method_defs.join("\n")
              registrations = attach_defs.join("\n  ")

              contents = <<~EOS
                // generated by rake generate:functions
                // do not edit by hand

                #include <torch/torch.h>
                #include <rice/Module.hpp>

                #include "ruby_arg_parser.h"
                #include "templates.h"
                #include "wrap_outputs.h"
                #{cuda_include}
                #{definitions}
                void add_#{type}_functions(Module m) {
                  #{registrations}
                }
              EOS

              write_file("#{type}_functions.cpp", contents)
            end
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            # Writes generated source to ext/torch, resolved relative to this script's
            # directory ("../ext/torch"). File.write replaces any existing file.
            def write_file(name, contents)
              ext_dir = File.expand_path("../ext/torch", __dir__)
              target = File.join(ext_dir, name)
              File.write(target, contents)
            end
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            # Builds the rb_define_* call that registers the binding under a Ruby name:
            # a trailing "_" becomes "!", a leading "is_" becomes a "?" suffix, and
            # size/stride/random! get a "_" prefix.
            def generate_attach_def(name, type, def_method)
              ruby_name =
                case name
                when /_\z/ then name.sub(/_\z/, "!")
                when /\Ais_/ then "#{name.sub(/\Ais_/, "")}?"
                else name
                end

              prefixed = %w[size stride random!]
              ruby_name = "_#{ruby_name}" if prefixed.include?(ruby_name)

              "rb_#{def_method}(m, \"#{ruby_name}\", #{type}_#{name}, -1);"
            end
         | 
| 119 | 
            +
             | 
| 120 | 
            +
            # Emits the C++ `static VALUE <type>_<name>(argc, argv, self_)` wrapper for
            # all overloads of `name`: a RubyArgParser built from every overload
            # signature, followed by the dispatch code.
            def generate_method_def(name, functions, type, def_method)
              # Tensor methods unwrap the receiver into `self`
              self_binding = type == "tensor" ? "\n  Tensor& self = from_ruby<Tensor&>(self_);" : ""

              overloads = group_overloads(functions, type)
              signatures = overloads.map { |overload| overload["signature"] }
              # parsed-arg slots: commas + 1, minus one per "*" (presumably the "*,"
              # keyword-only marker emitted by generate_signature)
              arg_slots = signatures.map { |sig| sig.count(",") - sig.count("*") }.max + 1
              parser_signatures = signatures.map(&:inspect).join(",\n    ")
              dispatch_body = add_dispatches(overloads, def_method)

              <<~EOS
                // #{name}
                static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
                {
                  HANDLE_TH_ERRORS#{self_binding}
                  static RubyArgParser parser({
                    #{parser_signatures}
                  });
                  std::vector<VALUE> parsed_args(#{arg_slots});
                  auto _r = parser.parse(self_, argc, argv, parsed_args);
                  #{dispatch_body}
                  END_HANDLE_TH_ERRORS
                }
              EOS
            end
         | 
| 142 | 
            +
             | 
| 143 | 
            +
            # Prefixes every line except the first with two spaces; the caller places
            # the first line. Trailing newlines are dropped (String#split discards
            # trailing empty fields).
            def indent(code)
              code.split("\n").map.with_index { |line, idx| idx.zero? ? line : "  #{line}" }.join("\n")
            end
         | 
| 146 | 
            +
             | 
| 147 | 
            +
            # Emits dispatch code for a set of overload groups. A single group is
            # dispatched directly. Otherwise each group becomes a `case` on the index
            # of the parser signature that matched (`_r.idx`), with RETURN_NIL after
            # the switch.
            def add_dispatches(functions, def_method)
              return add_dispatch(functions.first, def_method) if functions.size == 1

              cases = functions.each_with_index.map do |function, idx|
                case_body = add_dispatch(function, def_method).split("\n").join("\n    ")
                "case #{idx}: {\n      #{case_body}\n    }"
              end

              "switch (_r.idx) {\n    #{cases.join("\n    ")}\n  }\n  RETURN_NIL"
            end
         | 
| 164 | 
            +
             | 
| 165 | 
            +
            # Emits dispatch code for one overload group (a Hash with "base" and an
            # optional "out" Function). When a distinct out= variant exists, the code
            # branches on whether the out argument was passed (`_r.isNone`).
            def add_dispatch(function, def_method)
              base = function["base"]
              out = function["out"]
              return generate_dispatch(base, def_method) unless out && out != base

              base_code = generate_dispatch(base, def_method)
              out_code = generate_dispatch(out, def_method)
              "if (_r.isNone(#{out.out_index})) {\n    #{indent(base_code)}\n  } else {\n    #{indent(out_code)}\n  }"
            end
         | 
| 180 | 
            +
             | 
| 181 | 
            +
            # Groups overloads whose signatures are identical once the out parameter
            # is ignored. Each group is a Hash with "base", optional "out", and the
            # "signature" handed to the parser (the out variant's signature when present).
            def group_overloads(functions, type)
              slots = Hash.new { |table, sig| table[sig] = {} }

              functions.each do |function|
                key = generate_signature(function, type, skip_out: true)
                slot = slots[key]

                unless function.out?
                  slot["base"] = function
                  slot["signature"] ||= key
                  next
                end

                slot["out"] = function
                slot["signature"] = generate_signature(function, type)
                # for now: the out variant stands in as base until a real base shows up
                slot["base"] ||= function
              end

              puts "Missing base: #{functions.first.name}" if slots.values.any? { |slot| !slot["base"] }
              sort_functions(slots.values)
            end
         | 
| 202 | 
            +
             | 
| 203 | 
            +
            # Orders overload groups so groups without an out= variant come first.
            # Enumerable#sort_by is not stable, so the input index is used as a
            # tiebreaker. This keeps the relative order of equal-key groups, and with
            # it the signature/case order in the generated parser.
            # TODO: smarter overload ordering
            def sort_functions(functions)
              functions.each_with_index
                       .sort_by { |f, idx| [f["out"] ? 1 : 0, idx] }
                       .map(&:first)
            end
         | 
| 207 | 
            +
             | 
| 208 | 
            +
            # Emits C++ for one Function: a comment with the schema, optional
            # TensorOptions / out-list setup, a `dispatch_<cpp_name>` lambda wrapping
            # the ATen call, and the code that invokes it with parsed arguments.
            def generate_dispatch(function, def_method)
              cpp_name = function.out? ? "#{function.base_name}_out" : function.base_name
              # Tensor methods take self from the receiver (set_param_position skips it)
              remove_self = def_method == :define_method

              # copies, because set_param_position writes :position into each Hash
              params = function.params.map(&:dup)
              set_param_position(params, remove_self)
              params, opt_params = split_opt_params(params)
              opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any?

              cpp_params = generate_dispatch_params(function, params)
              cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options") if opt_index

              retval = generate_dispatch_retval(function)
              dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
              function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)

              tensor_list_out = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
              out_var = tensor_list_out ? generate_out_var(function.out_index, function.retvals.size) : nil
              tensor_options = opt_params.any? ? generate_tensor_options(function, opt_params) : nil

              "// #{function.func}#{tensor_options}#{out_var}\n" \
                "  auto dispatch_#{cpp_name} = [](#{cpp_params.join(", ")}) -> #{retval} {\n" \
                "    // in future, release GVL\n" \
                "    #{dispatch_code}\n" \
                "  };\n" \
                "  #{function_code}"
            end
         | 
| 238 | 
            +
             | 
| 239 | 
            +
            # C++ line binding `out` to `size` tensors parsed from argument `out_index`.
            def generate_out_var(out_index, size)
              declaration = "auto out = _r.tensorlist_n<#{size}>(#{out_index});"
              "\n  #{declaration}"
            end
         | 
| 242 | 
            +
             | 
| 243 | 
            +
            # Assigns consecutive :position indexes (in place) to the param Hashes.
            # When remove_self is true, the "self" param is skipped and gets no
            # position. Returns `params`.
            def set_param_position(params, remove_self)
              positioned = remove_self ? params.reject { |v| v[:name] == "self" } : params
              positioned.each_with_index { |v, idx| v[:position] = idx }
              params
            end
         | 
| 251 | 
            +
             | 
| 252 | 
            +
            # Separates TensorOptions-related params (dtype/device/layout/
            # requires_grad/pin_memory) from the rest.
            #
            # @return [Array(Array<Hash>, Array<Hash>)] [other_params, opt_params] when
            #   at least four option params are present; otherwise the original
            #   `params` array unchanged and an empty list.
            def split_opt_params(params)
              option_names = %w[dtype device layout requires_grad pin_memory]

              # partition yields one element per call; no index is available here
              opt_params, other_params = params.partition { |v| option_names.include?(v[:name]) }
              if opt_params.size >= 4
                [other_params, opt_params]
              else
                [params, []]
              end
            end
         | 
| 262 | 
            +
             | 
| 263 | 
            +
            # Emits C++ that builds `options` (a TensorOptions) from the parsed option
            # arguments, in the fixed order dtype, device, layout, requires_grad,
            # pin_memory, followed by a maybe_initialize_cuda(options) call.
            # "arange" reads dtype with scalartypeOptional.
            def generate_tensor_options(function, opt_params)
              order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]

              setters = opt_params.sort_by { |v| order.index(v[:name]) }.map do |opt|
                pos = opt[:position]
                kind = opt[:name]
                if kind == "dtype"
                  reader = function.base_name == "arange" ? "scalartypeOptional" : "scalartype"
                  "dtype(_r.#{reader}(#{pos}))"
                elsif kind == "device"
                  "device(_r.device(#{pos}))"
                elsif kind == "layout"
                  "layout(_r.layoutOptional(#{pos}))"
                elsif kind == "requires_grad"
                  "requires_grad(_r.toBool(#{pos}))"
                elsif kind == "pin_memory"
                  "pinned_memory(_r.toBool(#{pos}))"
                end
              end

              chain = setters.map { |setter| "\n      .#{setter}" }.join
              "\n  const auto options = TensorOptions()#{chain};\n  torch::utils::maybe_initialize_cuda(options);"
            end
         | 
| 292 | 
            +
             | 
| 293 | 
            +
            # Builds the C++ statement that calls the generated dispatch_<cpp_name>
            # helper with arguments read from the Ruby arg parser.
            #
            # When opt_index is given, an "options" argument (presumably the
            # TensorOptions emitted by generate_tensor_options) is spliced into the
            # argument list. generate_function_params keeps "self" in the list when
            # remove_self is set, so the insertion point shifts by one in that case.
            #
            # @return [String] "return wrap(...);" or, when the function has no
            #   return values, the bare call followed by RETURN_NIL
            def generate_function_code(function, cpp_name, params, opt_index, remove_self)
              args = generate_function_params(function, params, remove_self)
              if opt_index
                insert_at = remove_self ? opt_index + 1 : opt_index
                args.insert(insert_at, "options")
              end

              call = "dispatch_#{cpp_name}(#{args.join(", ")})"
              return "#{call};\nRETURN_NIL" if function.retvals.empty?

              "return wrap(#{call});"
            end
         | 
| 307 | 
            +
             | 
| 308 | 
            +
            # Produces the C++ argument expressions used to call a dispatch helper.
            # Each param is read from the parsed Ruby args as `_r.<accessor>(<position>)`.
            #
            # - With remove_self, the "self" param is passed through as plain `self`.
            # - For out functions with several Tensor return values, every param past
            #   out_index is read from the `out` list as out[0], out[1], ...
            # - Optional params use the "...Optional" accessor variant, except
            #   generator/tensorlist/intlist (unchanged) and tensor (plain `tensor`
            #   for out functions, `optionalTensor` otherwise).
            #
            # @raise [RuntimeError] for a param type with no accessor
            def generate_function_params(function, params, remove_self)
              bundle_out = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }

              accessor_for = {
                "Tensor" => "tensor",
                "Tensor[]" => "tensorlist",
                "Scalar" => "scalar",
                "bool" => "toBool",
                "int" => "toInt64",
                "float" => "toDouble",
                "ScalarType" => "scalartype",
                "str" => "string",
                "Generator" => "generator",
                "MemoryFormat" => "memoryformat",
                "Storage" => "storage"
              }

              # ordinal is 1-based, matching the comparison against out_index below
              params.map.with_index(1) do |param, ordinal|
                next "self" if remove_self && param[:name] == "self"
                next "out[#{ordinal - function.out_index - 1}]" if bundle_out && ordinal > function.out_index

                accessor = accessor_for[param[:type]]
                accessor ||= "intlist" if /\Aint\[/.match?(param[:type].to_s)
                raise "Unknown type: #{param[:type]} (#{function.name})" unless accessor

                if param[:optional]
                  accessor =
                    if accessor == "tensor"
                      function.out? ? "tensor" : "optionalTensor"
                    elsif %w[generator tensorlist intlist].include?(accessor)
                      accessor
                    else
                      "#{accessor}Optional"
                    end
                end

                "_r.#{accessor}(#{param[:position]})"
              end
            end
         | 
| 369 | 
            +
             | 
| 370 | 
            +
            # Builds the C++ body of a dispatch helper: a single call into
            # self./torch::/at:: that forwards the parameter names.
            #
            # Functions taking TensorOptions go through torch:: rather than at::
            # because torch::empty sets requires_grad but at::empty doesn't
            # https://github.com/pytorch/pytorch/issues/36455
            #
            # For out functions the call targets <base_name>_out, and the out args
            # are moved to the front of the argument list.
            #
            # @return [String] the call statement, prefixed with "return " unless the
            #   function has no return values
            def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
              prefix =
                if remove_self
                  "self."
                elsif opt_index
                  "torch::"
                else
                  "at::"
                end
              target = function.out? ? "#{function.base_name}_out" : function.base_name

              names = params.map { |p| p[:name] }
              names = names.reject { |n| n == "self" } if remove_self
              names.insert(opt_index, "options") if opt_index
              if function.out_index
                # the out args are unshifted as one nested array; Array#join flattens it
                names.unshift(names.slice!(function.out_index, function.retvals.size))
              end

              statement = "#{prefix}#{target}(#{names.join(", ")});"
              function.retvals.empty? ? statement : "return #{statement}"
            end
         | 
| 388 | 
            +
             | 
| 389 | 
            +
            # Builds the C++ parameter declarations ("<type> <name>") for a
            # function's generated dispatch helper.
            #
            # Tensor params:
            # - optional: `const Tensor &` for out functions, `const OptionalTensor &` otherwise
            # - with a modifier: `Tensor &` when the modifier contains "!" and the
            #   function has more than one return value, plain `Tensor` otherwise
            # - all others: `const Tensor &`
            # Every other optional param is wrapped in c10::optional<...>.
            #
            # @raise [RuntimeError] for a param type with no C++ mapping
            def generate_dispatch_params(function, params)
              params.map do |param|
                type =
                  case param[:type]
                  when "Tensor"
                    if param[:optional]
                      if function.out?
                        "const Tensor &"
                      else
                        # TODO
                        # "const c10::optional<at::Tensor> &"
                        "const OptionalTensor &"
                      end
                    elsif param[:modifier]
                      if param[:modifier].include?("!") && function.retvals.size > 1
                        "Tensor &"
                      else
                        "Tensor"
                      end
                    else
                      "const Tensor &"
                    end
                  when "Tensor[]"
                    "TensorList"
                  when "int"
                    "int64_t"
                  when "float"
                    "double"
                  when /\Aint\[/
                    "IntArrayRef"
                  when "str"
                    "std::string"
                  # types whose C++ name is the same as the schema name
                  when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
                    param[:type]
                  else
                    raise "Unknown type: #{param[:type]} (#{function.name})"
                  end

                # optional Tensors were handled above; everything else gets c10::optional
                type = "c10::optional<#{type}>" if param[:optional] && param[:type] != "Tensor"

                "#{type} #{param[:name]}"
              end
            end
         | 
| 434 | 
            +
             | 
| 435 | 
            +
            # Maps a function's return types to the C++ return type of its
            # generated dispatch helper.
            #
            # Schema scalars use the same C++ types as in parameters ("float" ->
            # double, "int" -> int64_t), and tuple elements follow the same mapping.
            #
            # @return [String] a C++ type name ("void" when nothing is returned)
            # @raise [RuntimeError] for an unsupported combination of return types
            def generate_dispatch_retval(function)
              types = function.retvals.map { |r| r[:type] }

              case types
              when []
                "void"
              when ["bool"]
                "bool"
              when ["int"]
                "int64_t"
              when ["float"]
                "double"
              when ["Scalar"]
                "Scalar"
              when ["ScalarType"]
                "ScalarType"
              when ["QScheme"]
                "QScheme"
              when ["Tensor"]
                "Tensor"
              when ["Tensor[]"]
                "std::vector<Tensor>"
              when ["Tensor", "Tensor"]
                "std::tuple<Tensor,Tensor>"
              when ["Tensor", "Tensor", "Tensor"]
                "std::tuple<Tensor,Tensor,Tensor>"
              when ["Tensor", "Tensor", "Tensor", "Tensor"]
                "std::tuple<Tensor,Tensor,Tensor,Tensor>"
              when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"]
                "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>"
              when ["Tensor", "Tensor", "float", "int"]
                # keep the double/int64_t mapping used above; declaring float/int
                # here would silently narrow both values on return
                "std::tuple<Tensor,Tensor,double,int64_t>"
              else
                raise "Unknown retvals: #{types}"
              end
            end
         | 
| 471 | 
            +
             | 
| 472 | 
            +
            # Renders a function's signature string as "<base_name>(<params>)".
            # Keyword-only params follow a "*" marker, and each param is rendered
            # by signature_param.
            #
            # For out functions, the out params are either dropped (skip_out) or,
            # when there are several of them and all are Tensors, collapsed into a
            # single keyword-only "out" param of type Tensor[N].
            # For type == "tensor", the "self" param is omitted from the positional list.
            #
            # function.params is duplicated first, so the caller's list is not modified.
            def generate_signature(function, type, skip_out: false)
              sig_params = function.params.dup

              if function.out?
                out_at = function.out_index
                out_count = function.retvals.size
                if skip_out
                  # drop the out params entirely
                  sig_params.slice!(out_at, out_count)
                elsif out_count > 1 && sig_params[out_at, out_count].all? { |r| r[:type] == "Tensor" }
                  # collapse several Tensor outputs into one keyword-only tensor list
                  sig_params.slice!(out_at, out_count)
                  sig_params.insert(out_at, {name: "out", type: "Tensor[#{out_count}]", list_size: out_count, keyword_only: true})
                end
              end

              keyword_only = sig_params.select { |p| p[:keyword_only] }
              positional = sig_params.reject { |p| p[:keyword_only] || (type == "tensor" && p[:name] == "self") }
              positional += ["*", *keyword_only] unless keyword_only.empty?

              rendered = positional.map { |p| signature_param(p) }
              "#{function.base_name}(#{rendered.join(", ")})"
            end
         | 
| 495 | 
            +
             | 
| 496 | 
            +
            # Renders one param of a signature string as "<type> <name>[=default]".
            #
            # - "self" is shown as "input".
            # - A "[]" default is shown as None, and "Mean" as at::Reduction::Mean.
            # - A literal "*" (the keyword-only marker) is passed through unchanged.
            # - Params named "out" always get "=None" appended, even when they
            #   already carry a default.
            def signature_param(param)
              return "*" if param == "*"

              rendered_type = signature_type(param)
              shown_name = param[:name] == "self" ? "input" : param[:name]

              default = param[:default]
              default_suffix =
                if default.nil?
                  ""
                elsif default == "[]"
                  "=None"
                elsif default == "Mean"
                  "=at::Reduction::Mean"
                else
                  "=#{default}"
                end
              # hack
              default_suffix += "=None" if param[:name] == "out"

              "#{rendered_type} #{shown_name}#{default_suffix}"
            end
         | 
| 519 | 
            +
             | 
| 520 | 
            +
            # Maps a parsed param to the type name used in signature strings
            # (e.g. "int" -> "int64_t", "Tensor[]" -> "TensorList").
            # A fixed list size is appended as "[N]", and optional params get a trailing "?".
            #
            # @raise [RuntimeError] for an unrecognized param type
            def signature_type(param)
              type =
                case param[:type]
                when "Tensor", /\ATensor\([a-z]!?\)\z/
                  "Tensor"
                # anchored at both ends like the other list patterns
                when /\ATensor\[\d*\]\z/
                  "TensorList"
                when /\ADimname\[\d*\]\z/
                  "DimnameList"
                when /\Aint\[\d*\]\z/
                  "IntArrayRef"
                when "int"
                  "int64_t"
                when "float"
                  "double"
                when "str"
                  "std::string"
                when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
                  param[:type]
                else
                  raise "Unknown type: #{param[:type]}"
                end

              type += "[#{param[:list_size]}]" if param[:list_size]
              type += "?" if param[:optional]
              type
            end
         |