torch-rb 0.3.2 → 0.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +7 -2
- data/ext/torch/ext.cpp +60 -20
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +36 -0
- data/ext/torch/templates.hpp +81 -87
- data/lib/torch.rb +71 -19
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +93 -4
- data/lib/torch/native/generator.rb +45 -41
- data/lib/torch/native/parser.rb +57 -76
- data/lib/torch/nn/functional.rb +112 -2
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +45 -51
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
    
data/lib/torch.rb CHANGED
    
```diff
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
 require "torch/nn/soft_margin_loss"
 require "torch/nn/triplet_margin_loss"
 
+# nn vision
+require "torch/nn/upsample"
+
 # nn other
 require "torch/nn/functional"
 require "torch/nn/init"
```
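The newly required `Torch::NN::Upsample` module (data/lib/torch/nn/upsample.rb in the file list above) is then available alongside the other nn modules. A minimal usage sketch, assuming the constructor mirrors PyTorch's keyword API:

```ruby
require "torch"

# Hypothetical usage; keyword names assumed to mirror PyTorch's Upsample.
up = Torch::NN::Upsample.new(scale_factor: 2, mode: "nearest")
x = Torch.ones(1, 1, 2, 2)  # NCHW input
up.call(x).shape            # => [1, 1, 4, 4]
```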
```diff
@@ -196,6 +199,32 @@ module Torch
     end
   end
 
+  # legacy
+  # but may make it easier to port tutorials
+  module Autograd
+    class Variable
+      def self.new(x)
+        raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
+        warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
+        x
+      end
+    end
+  end
+
+  # TODO move to C++
+  class ByteStorage
+    # private
+    attr_reader :bytes
+
+    def initialize(bytes)
+      @bytes = bytes
+    end
+
+    def self.from_buffer(bytes)
+      new(bytes)
+    end
+  end
+
   # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
   # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
   DTYPE_TO_ENUM = {
```
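As the hunk shows, the legacy `Variable` wrapper just validates, warns, and returns the tensor unchanged, so code ported from older PyTorch tutorials keeps running without a wrapper object:

```ruby
x = Torch::Autograd::Variable.new(Torch.ones(2, 2))
# warns: [torch] The Variable API is deprecated. Use tensors with requires_grad: true instead.
x.is_a?(Torch::Tensor)                    # => true (the "Variable" is the tensor itself)

Torch::Autograd::Variable.new([1.0, 2.0]) # => ArgumentError (data must be a tensor)

# the preferred modern form
x = Torch.ones(2, 2, requires_grad: true)
```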
```diff
@@ -224,40 +253,43 @@ module Torch
   }
   ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
 
+  TENSOR_TYPE_CLASSES = []
+
   def self._make_tensor_class(dtype, cuda = false)
     cls = Class.new
     device = cuda ? "cuda" : "cpu"
     cls.define_singleton_method("new") do |*args|
       if args.size == 1 && args.first.is_a?(Tensor)
         args.first.send(dtype).to(device)
+      elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
+        bytes = args.first.bytes
+        Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
       elsif args.size == 1 && args.first.is_a?(Array)
         Torch.tensor(args.first, dtype: dtype, device: device)
       else
         Torch.empty(*args, dtype: dtype, device: device)
       end
     end
+    TENSOR_TYPE_CLASSES << cls
     cls
   end
 
-  FloatTensor = _make_tensor_class(:float32)
-  DoubleTensor = _make_tensor_class(:float64)
-  HalfTensor = _make_tensor_class(:float16)
-  ByteTensor = _make_tensor_class(:uint8)
-  CharTensor = _make_tensor_class(:int8)
-  ShortTensor = _make_tensor_class(:int16)
-  IntTensor = _make_tensor_class(:int32)
-  LongTensor = _make_tensor_class(:int64)
-  BoolTensor = _make_tensor_class(:bool)
-
-  CUDA::FloatTensor = _make_tensor_class(:float32, true)
-  CUDA::DoubleTensor = _make_tensor_class(:float64, true)
-  CUDA::HalfTensor = _make_tensor_class(:float16, true)
-  CUDA::ByteTensor = _make_tensor_class(:uint8, true)
-  CUDA::CharTensor = _make_tensor_class(:int8, true)
-  CUDA::ShortTensor = _make_tensor_class(:int16, true)
-  CUDA::IntTensor = _make_tensor_class(:int32, true)
-  CUDA::LongTensor = _make_tensor_class(:int64, true)
-  CUDA::BoolTensor = _make_tensor_class(:bool, true)
+  DTYPE_TO_CLASS = {
+    float32: "FloatTensor",
+    float64: "DoubleTensor",
+    float16: "HalfTensor",
+    uint8: "ByteTensor",
+    int8: "CharTensor",
+    int16: "ShortTensor",
+    int32: "IntTensor",
+    int64: "LongTensor",
+    bool: "BoolTensor"
+  }
+
+  DTYPE_TO_CLASS.each do |dtype, class_name|
+    const_set(class_name, _make_tensor_class(dtype))
+    CUDA.const_set(class_name, _make_tensor_class(dtype, true))
+  end
 
   class << self
     # Torch.float, Torch.long, etc
```
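With `DTYPE_TO_CLASS`, the nine CPU and nine CUDA tensor classes are generated in one loop instead of eighteen assignments, and each class registers itself in `TENSOR_TYPE_CLASSES`. Going by the constructor logic above, usage is unchanged, with one new path for raw bytes:

```ruby
Torch::FloatTensor.new([1, 2, 3])    # from an array, via Torch.tensor
Torch::LongTensor.new(2, 3)          # uninitialized 2x3, via Torch.empty
Torch::IntTensor.new(Torch.ones(2))  # casts an existing tensor

# new in this release: build a ByteTensor from raw bytes
storage = Torch::ByteStorage.from_buffer("\x01\x02\x03")
Torch::ByteTensor.new(storage)       # 1-D uint8 tensor of size 3

# Torch::CUDA::FloatTensor.new([1, 2, 3])  # same API on GPU (requires CUDA)
```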
```diff
@@ -388,6 +420,10 @@ module Torch
     end
 
     def randperm(n, **options)
+      # dtype hack in Python
+      # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+      options[:dtype] ||= :int64
+
      _randperm(n, tensor_options(**options))
    end
 
```
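This ports the PyTorch Python-binding hack linked in the comment: without it, `tensor_options` would apply the global floating-point default dtype to a permutation, which should be integral. In effect:

```ruby
Torch.randperm(5)                   # => int64 tensor, e.g. tensor([2, 0, 4, 1, 3])
Torch.randperm(5, dtype: :float32)  # an explicit dtype still overrides the default
```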
```diff
@@ -460,6 +496,22 @@ module Torch
       zeros(input.size, **like_options(input, options))
     end
 
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+      if center
+        signal_dim = input.dim
+        extended_shape = [1] * (3 - signal_dim) + input.size
+        pad = n_fft.div(2).to_i
+        input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
+        input = input.view(input.shape[-signal_dim..-1])
+      end
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
+    end
+
+    def clamp(tensor, min, max)
+      tensor = _clamp_min(tensor, min)
+      _clamp_max(tensor, max)
+    end
+
     private
 
     def to_ivalue(obj)
```
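`stft` reproduces PyTorch's Python-side centering logic (reflect-padding by `n_fft / 2`) before handing off to the native `_stft`, and `clamp` is simply composed from the two single-sided native ops:

```ruby
x = Torch.tensor([-1.0, 0.5, 2.0])
Torch.clamp(x, 0, 1)  # => tensor([0.0, 0.5, 1.0]); _clamp_min then _clamp_max
```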
data/lib/torch/native/dispatcher.rb CHANGED

```diff
@@ -22,21 +22,43 @@ module Torch
         end
 
         def bind_functions(context, def_method, functions)
+          instance_method = def_method == :define_method
           functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
-            if def_method == :define_method
+            if instance_method
               funcs.map! { |f| Function.new(f.function) }
-              funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
+              funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
             end
 
-            defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
+            defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
             next if defined && name != "clone"
 
-            parser = Parser.new(funcs)
+            # skip parser when possible for performance
+            if funcs.size == 1 && funcs.first.args.size == 0
+              # functions with no arguments
+              if instance_method
+                context.send(:alias_method, name, funcs.first.cpp_name)
+              else
+                context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
+              end
+            elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
+              # functions that take a tensor or scalar
+              scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
+              context.send(def_method, name) do |other|
+                case other
+                when Tensor
+                  send(tensor_name, other)
+                else
+                  send(scalar_name, other)
+                end
+              end
+            else
+              parser = Parser.new(funcs)
 
-            context.send(def_method, name) do |*args, **options|
-              result = parser.parse(args, options)
-              raise ArgumentError, result[:error] if result[:error]
-              send(result[:name], *result[:args])
+              context.send(def_method, name) do |*args, **options|
+                result = parser.parse(args, options)
+                raise ArgumentError, result[:error] if result[:error]
+                send(result[:name], *result[:args])
+              end
             end
           end
         end
```
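The two new fast paths avoid building a `Parser` for the most common overload shapes: zero-argument functions become plain aliases of their C++ binding, and Scalar/Tensor pairs become a single `case` dispatch. A sketch of what the second path defines, with illustrative names (the real method names come from `cpp_name`):

```ruby
# Sketch only: roughly what the Scalar/Tensor fast path generates for "mul".
class Torch::Tensor
  def mul(other)
    case other
    when Torch::Tensor
      _mul_tensor(other)  # illustrative cpp_name of the Tensor overload
    else
      _mul_scalar(other)  # illustrative cpp_name of the Scalar overload
    end
  end
end
```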
data/lib/torch/native/function.rb CHANGED

```diff
@@ -1,10 +1,15 @@
 module Torch
   module Native
     class Function
-      attr_reader :function
+      attr_reader :function, :tensor_options
 
       def initialize(function)
         @function = function
+
+        # note: don't modify function in-place
+        @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
+        @tensor_options = @function["func"].include?(@tensor_options_str)
+        @out = out_size > 0 && base_name[-1] != "_"
       end
 
       def func
```
```diff
@@ -27,7 +32,7 @@ module Torch
         @args ||= begin
           args = []
           pos = true
-          args_str = func.split("(", 2).last.split(") ->").first
+          args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
           args_str.split(", ").each do |a|
             if a == "*"
               pos = false
```
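Stripping the fixed tensor-options suffix before parsing lets factory functions expose a single `options` argument on the C++ side rather than four separate keyword arguments. A small illustration of the substitution (the `arange` schema here is just an example of the pattern found in native_functions.yaml):

```ruby
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
func = "arange(Scalar end#{tensor_options_str} -> Tensor"

args_str = func.sub(tensor_options_str, ")").split("(", 2).last.split(") ->").first
args_str  # => "Scalar end" -- the options block is gone from the parsed args
```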
```diff
@@ -68,12 +73,88 @@ module Torch
             next if t == "Generator?"
             next if t == "MemoryFormat"
             next if t == "MemoryFormat?"
-            args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
+            args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
           end
           args
         end
       end
 
+      def arg_checkers
+        @arg_checkers ||= begin
+          checkers = {}
+          arg_types.each do |k, t|
+            checker =
+              case t
+              when "Tensor"
+                ->(v) { v.is_a?(Tensor) }
+              when "Tensor?"
+                ->(v) { v.nil? || v.is_a?(Tensor) }
+              when "Tensor[]", "Tensor?[]"
+                ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
+              when "int"
+                if k == :reduction
+                  ->(v) { v.is_a?(String) }
+                else
+                  ->(v) { v.is_a?(Integer) }
+                end
+              when "int?"
+                ->(v) { v.is_a?(Integer) || v.nil? }
+              when "float?"
+                ->(v) { v.is_a?(Numeric) || v.nil? }
+              when "bool?"
+                ->(v) { v == true || v == false || v.nil? }
+              when "float"
+                ->(v) { v.is_a?(Numeric) }
+              when /int\[.*\]/
+                ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
+              when "Scalar"
+                ->(v) { v.is_a?(Numeric) }
+              when "Scalar?"
+                ->(v) { v.is_a?(Numeric) || v.nil? }
+              when "ScalarType"
+                ->(v) { false } # not supported yet
+              when "ScalarType?"
+                ->(v) { v.nil? }
+              when "bool"
+                ->(v) { v == true || v == false }
+              when "str"
+                ->(v) { v.is_a?(String) }
+              else
+                raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
+              end
+            checkers[k] = checker
+          end
+          checkers
+        end
+      end
+
+      def int_array_lengths
+        @int_array_lengths ||= begin
+          ret = {}
+          arg_types.each do |k, t|
+            if t.match?(/\Aint\[.+\]\z/)
+              size = t[4..-2]
+              raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
+              ret[k] = size.to_i
+            end
+          end
+          ret
+        end
+      end
+
+      def arg_names
+        @arg_names ||= args.map { |a| a[:name] }
+      end
+
+      def arg_types
+        @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
+      end
+
+      def arg_defaults
+        # TODO find out why can't use select here
+        @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
+      end
+
       def out_size
         @out_size ||= func.split("->").last.count("!")
       end
```
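These memoized lambdas give the rewritten parser (data/lib/torch/native/parser.rb in the file list) cheap type tests when choosing between overloads. A hedged sketch of applying a checker, feeding `Function.new` a hand-written schema hash in the same shape it reads from native_functions.yaml:

```ruby
fn = Torch::Native::Function.new(
  {"func" => "add.Scalar(Tensor self, Scalar other) -> Tensor"}
)
checker = fn.arg_checkers[:other]
checker.call(1.5)            # => true  ("Scalar" accepts any Numeric)
checker.call(Torch.ones(1))  # => false (a Tensor is not a Scalar here)
```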
```diff
@@ -82,8 +163,16 @@ module Torch
         @ret_size ||= func.split("->").last.split(", ").size
       end
 
+      def ret_array?
+        @ret_array ||= func.split("->").last.include?('[]')
+      end
+
+      def ret_void?
+        func.split("->").last.strip == "()"
+      end
+
       def out?
-        out_size > 0 && base_name[-1] != "_"
+        @out
       end
 
       def ruby_name
```
data/lib/torch/native/generator.rb CHANGED

```diff
@@ -18,12 +18,12 @@ module Torch
          functions = functions()
 
          # skip functions
-          skip_args = ["
+          skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
 
          # remove functions
          functions.reject! do |f|
            f.ruby_name.start_with?("_") ||
-            f.ruby_name.end_with?("_backward") ||
+            f.ruby_name.include?("_backward") ||
            f.args.any? { |a| a[:type].include?("Dimname") }
          end
 
```
| @@ -31,32 +31,15 @@ module Torch | |
| 31 31 | 
             
                      todo_functions, functions =
         | 
| 32 32 | 
             
                        functions.partition do |f|
         | 
| 33 33 | 
             
                          f.args.any? do |a|
         | 
| 34 | 
            -
                            a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
         | 
| 35 34 | 
             
                            skip_args.any? { |sa| a[:type].include?(sa) } ||
         | 
| 35 | 
            +
                            # call to 'range' is ambiguous
         | 
| 36 | 
            +
                            f.cpp_name == "_range" ||
         | 
| 36 37 | 
             
                            # native_functions.yaml is missing size argument for normal
         | 
| 37 38 | 
             
                            # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
         | 
| 38 39 | 
             
                            (f.base_name == "normal" && !f.out?)
         | 
| 39 40 | 
             
                          end
         | 
| 40 41 | 
             
                        end
         | 
| 41 42 |  | 
| 42 | 
            -
                      # generate additional functions for optional arguments
         | 
| 43 | 
            -
                      # there may be a better way to do this
         | 
| 44 | 
            -
                      optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
         | 
| 45 | 
            -
                      optional_functions.each do |f|
         | 
| 46 | 
            -
                        next if f.ruby_name == "cross"
         | 
| 47 | 
            -
                        next if f.ruby_name.start_with?("avg_pool") && f.out?
         | 
| 48 | 
            -
             | 
| 49 | 
            -
                        opt_args = f.args.select { |a| a[:type] == "int?" }
         | 
| 50 | 
            -
                        if opt_args.size == 1
         | 
| 51 | 
            -
                          sep = f.name.include?(".") ? "_" : "."
         | 
| 52 | 
            -
                          f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
         | 
| 53 | 
            -
                          # TODO only remove some arguments
         | 
| 54 | 
            -
                          f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
         | 
| 55 | 
            -
                          functions << f1
         | 
| 56 | 
            -
                          functions << f2
         | 
| 57 | 
            -
                        end
         | 
| 58 | 
            -
                      end
         | 
| 59 | 
            -
             | 
| 60 43 | 
             
                      # todo_functions.each do |f|
         | 
| 61 44 | 
             
                      #   puts f.func
         | 
| 62 45 | 
             
                      #   puts
         | 
```diff
@@ -89,15 +72,18 @@ void add_%{type}_functions(Module m);
 #include <rice/Module.hpp>
 #include "templates.hpp"
 
+%{functions}
+
 void add_%{type}_functions(Module m) {
-  m
-  %{functions};
+  %{add_functions}
 }
        TEMPLATE
 
        cpp_defs = []
+        add_defs = []
        functions.sort_by(&:cpp_name).each do |func|
-          fargs = func.args #.select { |a| a[:type] != "Generator?" }
+          fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
+          fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
 
          cpp_args = []
          fargs.each do |a|
```
```diff
@@ -109,52 +95,70 @@ void add_%{type}_functions(Module m) {
                # TODO better signature
                "OptionalTensor"
              when "ScalarType?"
-                "OptionalScalarType"
-              when "Tensor[]"
-                "TensorList"
-              when "Tensor?[]"
+                "torch::optional<ScalarType>"
+              when "Tensor[]", "Tensor?[]"
                # TODO make optional
-                "TensorList"
+                "std::vector<Tensor>"
              when "int"
                "int64_t"
+              when "int?"
+                "torch::optional<int64_t>"
+              when "float?"
+                "torch::optional<double>"
+              when "bool?"
+                "torch::optional<bool>"
+              when "Scalar?"
+                "torch::optional<torch::Scalar>"
              when "float"
                "double"
              when /\Aint\[/
-                "IntArrayRef"
+                "std::vector<int64_t>"
              when /Tensor\(\S!?\)/
                "Tensor &"
              when "str"
                "std::string"
-              else
+              when "TensorOptions"
+                "const torch::TensorOptions &"
+              when "Layout?"
+                "torch::optional<Layout>"
+              when "Device?"
+                "torch::optional<Device>"
+              when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
                a[:type]
+              else
+                raise "Unknown type: #{a[:type]}"
              end
 
-            t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
+            t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
            cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
          end
 
          dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
          args = fargs.map { |a| a[:name] }
          args.unshift(*args.pop(func.out_size)) if func.out?
-          args.delete("self") if def_method == :define_method
+          args.delete(:self) if def_method == :define_method
 
          prefix = def_method == :define_method ? "self." : "torch::"
 
          body = "#{prefix}#{dispatch}(#{args.join(", ")})"
 
-          if func.ret_size > 1
+          if func.cpp_name == "_fill_diagonal_"
+            body = "to_ruby<torch::Tensor>(#{body})"
+          elsif !func.ret_void?
            body = "wrap(#{body})"
          end
 
-          cpp_defs << ".#{def_method}(
-            \"#{func.cpp_name}\",
-            *[](#{cpp_args.join(", ")}) {
-              return #{body};
-            })"
+          cpp_defs << "// #{func.func}
+static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
+{
+  return #{body};
+}"
+
+          add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
        end
 
        hpp_contents = hpp_template % {type: type}
-        cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n  ")}
+        cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n  ")}
 
        path = File.expand_path("../../../ext/torch", __dir__)
        File.write("#{path}/#{type}_functions.hpp", hpp_contents)
```
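The net effect of the template changes: each binding is now emitted as a named static function above `add_%{type}_functions`, with a one-line registration inside it, instead of one long chain of inline lambdas. A toy rendering of the new `cpp_template` (the values below are illustrative, not real generator output):

```ruby
cpp_template = <<~TEMPLATE
  %{functions}

  void add_%{type}_functions(Module m) {
    %{add_functions}
  }
TEMPLATE

puts cpp_template % {
  type: "tensor",
  functions: "// abs(Tensor self) -> Tensor\nstatic Object tensor_abs(Tensor &self)\n{\n  return wrap(self.abs());\n}",
  add_functions: "m.define_method(\"abs\", tensor_abs);"
}
```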