torch-rb 0.2.5 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +32 -2
- data/README.md +4 -1
- data/ext/torch/ext.cpp +23 -6
- data/ext/torch/extconf.rb +3 -4
- data/lib/torch.rb +14 -5
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/nn/module.rb +5 -2
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +46 -15
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
    
        data/lib/torch/native/parser.rb
    CHANGED
    
    
    
        data/lib/torch/nn/conv2d.rb
    CHANGED
    
    | @@ -18,7 +18,6 @@ module Torch | |
| 18 18 | 
             
                    F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
         | 
| 19 19 | 
             
                  end
         | 
| 20 20 |  | 
| 21 | 
            -
                  # TODO add more parameters
         | 
| 22 21 | 
             
                  def extra_inspect
         | 
| 23 22 | 
             
                    s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
         | 
| 24 23 | 
             
                    s += ", padding: %{padding}" if @padding != [0] * @padding.size
         | 
    
        data/lib/torch/nn/functional.rb
    CHANGED
    
    | @@ -373,7 +373,8 @@ module Torch | |
| 373 373 | 
             
                        end
         | 
| 374 374 |  | 
| 375 375 | 
             
                      # weight and input swapped
         | 
| 376 | 
            -
                      Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
         | 
| 376 | 
            +
                      ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
         | 
| 377 | 
            +
                      ret
         | 
| 377 378 | 
             
                    end
         | 
| 378 379 |  | 
| 379 380 | 
             
                    # distance functions
         | 
| @@ -426,6 +427,9 @@ module Torch | |
| 426 427 | 
             
                    end
         | 
| 427 428 |  | 
| 428 429 | 
             
                    def mse_loss(input, target, reduction: "mean")
         | 
| 430 | 
            +
                      if target.size != input.size
         | 
| 431 | 
            +
                        warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
         | 
| 432 | 
            +
                      end
         | 
| 429 433 | 
             
                      NN.mse_loss(input, target, reduction)
         | 
| 430 434 | 
             
                    end
         | 
| 431 435 |  | 
    
        data/lib/torch/nn/module.rb
    CHANGED
    
    | @@ -145,7 +145,7 @@ module Torch | |
| 145 145 | 
             
                    params = {}
         | 
| 146 146 | 
             
                    if recurse
         | 
| 147 147 | 
             
                      named_children.each do |name, mod|
         | 
| 148 | 
            -
                        params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
         | 
| 148 | 
            +
                        params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
         | 
| 149 149 | 
             
                      end
         | 
| 150 150 | 
             
                    end
         | 
| 151 151 | 
             
                    instance_variables.each do |name|
         | 
| @@ -286,8 +286,11 @@ module Torch | |
| 286 286 | 
             
                    str % vars
         | 
| 287 287 | 
             
                  end
         | 
| 288 288 |  | 
| 289 | 
            +
                  # used for format
         | 
| 290 | 
            +
                  # remove tensors for performance
         | 
| 291 | 
            +
                  # so we can skip call to inspect
         | 
| 289 292 | 
             
                  def dict
         | 
| 290 | 
            -
                    instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
         | 
| 293 | 
            +
                    instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
         | 
| 291 294 | 
             
                  end
         | 
| 292 295 | 
             
                end
         | 
| 293 296 | 
             
              end
         | 
    
        data/lib/torch/optim/optimizer.rb
    CHANGED
    
    | @@ -32,9 +32,11 @@ module Torch | |
| 32 32 | 
             
                  end
         | 
| 33 33 |  | 
| 34 34 | 
             
                  def state_dict
         | 
| 35 | 
            +
                    raise NotImplementedYet
         | 
| 36 | 
            +
             | 
| 35 37 | 
             
                    pack_group = lambda do |group|
         | 
| 36 | 
            -
                      packed = group.select { |k, _| k != :params }.to_h
         | 
| 37 | 
            -
                      packed[ | 
| 38 | 
            +
                      packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
         | 
| 39 | 
            +
                      packed["params"] = group[:params].map { |p| p.object_id }
         | 
| 38 40 | 
             
                      packed
         | 
| 39 41 | 
             
                    end
         | 
| 40 42 |  | 
| @@ -42,8 +44,8 @@ module Torch | |
| 42 44 | 
             
                    packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
         | 
| 43 45 |  | 
| 44 46 | 
             
                    {
         | 
| 45 | 
            -
                      state | 
| 46 | 
            -
                      param_groups | 
| 47 | 
            +
                      "state" => packed_state,
         | 
| 48 | 
            +
                      "param_groups" => param_groups
         | 
| 47 49 | 
             
                    }
         | 
| 48 50 | 
             
                  end
         | 
| 49 51 |  | 
    
        data/lib/torch/optim/rprop.rb
    CHANGED
    
    
    
        data/lib/torch/tensor.rb
    CHANGED
    
    | @@ -1,6 +1,7 @@ | |
| 1 1 | 
             
            module Torch
         | 
| 2 2 | 
             
              class Tensor
         | 
| 3 3 | 
             
                include Comparable
         | 
| 4 | 
            +
                include Enumerable
         | 
| 4 5 | 
             
                include Inspector
         | 
| 5 6 |  | 
| 6 7 | 
             
                alias_method :requires_grad?, :requires_grad
         | 
| @@ -25,6 +26,14 @@ module Torch | |
| 25 26 | 
             
                  inspect
         | 
| 26 27 | 
             
                end
         | 
| 27 28 |  | 
| 29 | 
            +
                def each
         | 
| 30 | 
            +
                  return enum_for(:each) unless block_given?
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                  size(0).times do |i|
         | 
| 33 | 
            +
                    yield self[i]
         | 
| 34 | 
            +
                  end
         | 
| 35 | 
            +
                end
         | 
| 36 | 
            +
             | 
| 28 37 | 
             
                # TODO make more performant
         | 
| 29 38 | 
             
                def to_a
         | 
| 30 39 | 
             
                  arr = _flat_data
         | 
| @@ -38,10 +47,15 @@ module Torch | |
| 38 47 | 
             
                  end
         | 
| 39 48 | 
             
                end
         | 
| 40 49 |  | 
| 41 | 
            -
                 | 
| 42 | 
            -
             | 
| 50 | 
            +
                def to(device = nil, dtype: nil, non_blocking: false, copy: false)
         | 
| 51 | 
            +
                  device ||= self.device
         | 
| 43 52 | 
             
                  device = Device.new(device) if device.is_a?(String)
         | 
| 44 | 
            -
             | 
| 53 | 
            +
             | 
| 54 | 
            +
                  dtype ||= self.dtype
         | 
| 55 | 
            +
                  enum = DTYPE_TO_ENUM[dtype]
         | 
| 56 | 
            +
                  raise Error, "Unknown type: #{dtype}" unless enum
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                  _to(device, enum, non_blocking, copy)
         | 
| 45 59 | 
             
                end
         | 
| 46 60 |  | 
| 47 61 | 
             
                def cpu
         | 
| @@ -89,8 +103,9 @@ module Torch | |
| 89 103 | 
             
                  Torch.empty(0, dtype: dtype)
         | 
| 90 104 | 
             
                end
         | 
| 91 105 |  | 
| 92 | 
            -
                def backward(gradient = nil)
         | 
| 93 | 
            -
                   | 
| 106 | 
            +
                def backward(gradient = nil, retain_graph: nil, create_graph: false)
         | 
| 107 | 
            +
                  retain_graph = create_graph if retain_graph.nil?
         | 
| 108 | 
            +
                  _backward(gradient, retain_graph, create_graph)
         | 
| 94 109 | 
             
                end
         | 
| 95 110 |  | 
| 96 111 | 
             
                # TODO read directly from memory
         | 
| @@ -153,12 +168,25 @@ module Torch | |
| 153 168 | 
             
                  neg
         | 
| 154 169 | 
             
                end
         | 
| 155 170 |  | 
| 171 | 
            +
                def &(other)
         | 
| 172 | 
            +
                  logical_and(other)
         | 
| 173 | 
            +
                end
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def |(other)
         | 
| 176 | 
            +
                  logical_or(other)
         | 
| 177 | 
            +
                end
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                def ^(other)
         | 
| 180 | 
            +
                  logical_xor(other)
         | 
| 181 | 
            +
                end
         | 
| 182 | 
            +
             | 
| 156 183 | 
             
                # TODO better compare?
         | 
| 157 184 | 
             
                def <=>(other)
         | 
| 158 185 | 
             
                  item <=> other
         | 
| 159 186 | 
             
                end
         | 
| 160 187 |  | 
| 161 | 
            -
                # based on python_variable_indexing.cpp
         | 
| 188 | 
            +
                # based on python_variable_indexing.cpp and
         | 
| 189 | 
            +
                # https://pytorch.org/cppdocs/notes/tensor_indexing.html
         | 
| 162 190 | 
             
                def [](*indexes)
         | 
| 163 191 | 
             
                  result = self
         | 
| 164 192 | 
             
                  dim = 0
         | 
| @@ -170,6 +198,8 @@ module Torch | |
| 170 198 | 
             
                      finish += 1 unless index.exclude_end?
         | 
| 171 199 | 
             
                      result = result._slice_tensor(dim, index.begin, finish, 1)
         | 
| 172 200 | 
             
                      dim += 1
         | 
| 201 | 
            +
                    elsif index.is_a?(Tensor)
         | 
| 202 | 
            +
                      result = result.index([index])
         | 
| 173 203 | 
             
                    elsif index.nil?
         | 
| 174 204 | 
             
                      result = result.unsqueeze(dim)
         | 
| 175 205 | 
             
                      dim += 1
         | 
| @@ -183,19 +213,21 @@ module Torch | |
| 183 213 | 
             
                  result
         | 
| 184 214 | 
             
                end
         | 
| 185 215 |  | 
| 186 | 
            -
                #  | 
| 187 | 
            -
                #  | 
| 216 | 
            +
                # based on python_variable_indexing.cpp and
         | 
| 217 | 
            +
                # https://pytorch.org/cppdocs/notes/tensor_indexing.html
         | 
| 188 218 | 
             
                def []=(index, value)
         | 
| 189 219 | 
             
                  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
         | 
| 190 220 |  | 
| 191 | 
            -
                  value = Torch.tensor(value) unless value.is_a?(Tensor)
         | 
| 221 | 
            +
                  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
         | 
| 192 222 |  | 
| 193 223 | 
             
                  if index.is_a?(Numeric)
         | 
| 194 | 
            -
                     | 
| 224 | 
            +
                    index_put!([Torch.tensor(index)], value)
         | 
| 195 225 | 
             
                  elsif index.is_a?(Range)
         | 
| 196 226 | 
             
                    finish = index.end
         | 
| 197 227 | 
             
                    finish += 1 unless index.exclude_end?
         | 
| 198 | 
            -
                     | 
| 228 | 
            +
                    _slice_tensor(0, index.begin, finish, 1).copy!(value)
         | 
| 229 | 
            +
                  elsif index.is_a?(Tensor)
         | 
| 230 | 
            +
                    index_put!([index], value)
         | 
| 199 231 | 
             
                  else
         | 
| 200 232 | 
             
                    raise Error, "Unsupported index type: #{index.class.name}"
         | 
| 201 233 | 
             
                  end
         | 
| @@ -224,10 +256,9 @@ module Torch | |
| 224 256 | 
             
                  end
         | 
| 225 257 | 
             
                end
         | 
| 226 258 |  | 
| 227 | 
            -
                 | 
| 228 | 
            -
             | 
| 229 | 
            -
             | 
| 230 | 
            -
                  dst.copy!(src)
         | 
| 259 | 
            +
                def clamp!(min, max)
         | 
| 260 | 
            +
                  _clamp_min_(min)
         | 
| 261 | 
            +
                  _clamp_max_(max)
         | 
| 231 262 | 
             
                end
         | 
| 232 263 | 
             
              end
         | 
| 233 264 | 
             
            end
         | 
    
        data/lib/torch/utils/data.rb
    ADDED
    
    | @@ -0,0 +1,23 @@ | |
| 1 | 
            +
            module Torch
         | 
| 2 | 
            +
              module Utils
         | 
| 3 | 
            +
                module Data
         | 
| 4 | 
            +
                  class << self
         | 
| 5 | 
            +
                    def random_split(dataset, lengths)
         | 
| 6 | 
            +
                      if lengths.sum != dataset.length
         | 
| 7 | 
            +
                        raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
         | 
| 8 | 
            +
                      end
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                      indices = Torch.randperm(lengths.sum).to_a
         | 
| 11 | 
            +
                      _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
         | 
| 12 | 
            +
                    end
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                    private
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                    def _accumulate(iterable)
         | 
| 17 | 
            +
                      sum = 0
         | 
| 18 | 
            +
                      iterable.map { |x| sum += x }
         | 
| 19 | 
            +
                    end
         | 
| 20 | 
            +
                  end
         | 
| 21 | 
            +
                end
         | 
| 22 | 
            +
              end
         | 
| 23 | 
            +
            end
         | 
    
        data/lib/torch/utils/data/data_loader.rb
    CHANGED
    
    | @@ -6,10 +6,22 @@ module Torch | |
| 6 6 |  | 
| 7 7 | 
             
                    attr_reader :dataset
         | 
| 8 8 |  | 
| 9 | 
            -
                    def initialize(dataset, batch_size: 1, shuffle: false)
         | 
| 9 | 
            +
                    def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
         | 
| 10 10 | 
             
                      @dataset = dataset
         | 
| 11 11 | 
             
                      @batch_size = batch_size
         | 
| 12 12 | 
             
                      @shuffle = shuffle
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                      @batch_sampler = nil
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                      if collate_fn.nil?
         | 
| 17 | 
            +
                        if auto_collation?
         | 
| 18 | 
            +
                          collate_fn = method(:default_collate)
         | 
| 19 | 
            +
                        else
         | 
| 20 | 
            +
                          collate_fn = method(:default_convert)
         | 
| 21 | 
            +
                        end
         | 
| 22 | 
            +
                      end
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                      @collate_fn = collate_fn
         | 
| 13 25 | 
             
                    end
         | 
| 14 26 |  | 
| 15 27 | 
             
                    def each
         | 
| @@ -25,8 +37,8 @@ module Torch | |
| 25 37 | 
             
                        end
         | 
| 26 38 |  | 
| 27 39 | 
             
                      indexes.each_slice(@batch_size) do |idx|
         | 
| 28 | 
            -
                         | 
| 29 | 
            -
                        yield  | 
| 40 | 
            +
                        # TODO improve performance
         | 
| 41 | 
            +
                        yield @collate_fn.call(idx.map { |i| @dataset[i] })
         | 
| 30 42 | 
             
                      end
         | 
| 31 43 | 
             
                    end
         | 
| 32 44 |  | 
| @@ -36,7 +48,7 @@ module Torch | |
| 36 48 |  | 
| 37 49 | 
             
                    private
         | 
| 38 50 |  | 
| 39 | 
            -
                    def  | 
| 51 | 
            +
                    def default_convert(batch)
         | 
| 40 52 | 
             
                      elem = batch[0]
         | 
| 41 53 | 
             
                      case elem
         | 
| 42 54 | 
             
                      when Tensor
         | 
| @@ -44,11 +56,15 @@ module Torch | |
| 44 56 | 
             
                      when Integer
         | 
| 45 57 | 
             
                        Torch.tensor(batch)
         | 
| 46 58 | 
             
                      when Array
         | 
| 47 | 
            -
                        batch.transpose.map { |v|  | 
| 59 | 
            +
                        batch.transpose.map { |v| default_convert(v) }
         | 
| 48 60 | 
             
                      else
         | 
| 49 | 
            -
                        raise  | 
| 61 | 
            +
                        raise NotImplementedYet
         | 
| 50 62 | 
             
                      end
         | 
| 51 63 | 
             
                    end
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    def auto_collation?
         | 
| 66 | 
            +
                      !@batch_sampler.nil?
         | 
| 67 | 
            +
                    end
         | 
| 52 68 | 
             
                  end
         | 
| 53 69 | 
             
                end
         | 
| 54 70 | 
             
              end
         | 
    
        data/lib/torch/utils/data/subset.rb
    ADDED
    
    | @@ -0,0 +1,25 @@ | |
| 1 | 
            +
            module Torch
         | 
| 2 | 
            +
              module Utils
         | 
| 3 | 
            +
                module Data
         | 
| 4 | 
            +
                  class Subset < Dataset
         | 
| 5 | 
            +
                    def initialize(dataset, indices)
         | 
| 6 | 
            +
                      @dataset = dataset
         | 
| 7 | 
            +
                      @indices = indices
         | 
| 8 | 
            +
                    end
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                    def [](idx)
         | 
| 11 | 
            +
                      @dataset[@indices[idx]]
         | 
| 12 | 
            +
                    end
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                    def length
         | 
| 15 | 
            +
                      @indices.length
         | 
| 16 | 
            +
                    end
         | 
| 17 | 
            +
                    alias_method :size, :length
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                    def to_a
         | 
| 20 | 
            +
                      @indices.map { |i| @dataset[i] }
         | 
| 21 | 
            +
                    end
         | 
| 22 | 
            +
                  end
         | 
| 23 | 
            +
                end
         | 
| 24 | 
            +
              end
         | 
| 25 | 
            +
            end
         | 
    
        data/lib/torch/version.rb
    CHANGED
    
    
    
        metadata
    CHANGED
    
    | @@ -1,14 +1,14 @@ | |
| 1 1 | 
             
            --- !ruby/object:Gem::Specification
         | 
| 2 2 | 
             
            name: torch-rb
         | 
| 3 3 | 
             
            version: !ruby/object:Gem::Version
         | 
| 4 | 
            -
              version: 0.2.5 | 
| 4 | 
            +
              version: 0.3.2
         | 
| 5 5 | 
             
            platform: ruby
         | 
| 6 6 | 
             
            authors:
         | 
| 7 7 | 
             
            - Andrew Kane
         | 
| 8 8 | 
             
            autorequire: 
         | 
| 9 9 | 
             
            bindir: bin
         | 
| 10 10 | 
             
            cert_chain: []
         | 
| 11 | 
            -
            date: 2020- | 
| 11 | 
            +
            date: 2020-08-24 00:00:00.000000000 Z
         | 
| 12 12 | 
             
            dependencies:
         | 
| 13 13 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 14 14 | 
             
              name: rice
         | 
| @@ -259,8 +259,10 @@ files: | |
| 259 259 | 
             
            - lib/torch/optim/rprop.rb
         | 
| 260 260 | 
             
            - lib/torch/optim/sgd.rb
         | 
| 261 261 | 
             
            - lib/torch/tensor.rb
         | 
| 262 | 
            +
            - lib/torch/utils/data.rb
         | 
| 262 263 | 
             
            - lib/torch/utils/data/data_loader.rb
         | 
| 263 264 | 
             
            - lib/torch/utils/data/dataset.rb
         | 
| 265 | 
            +
            - lib/torch/utils/data/subset.rb
         | 
| 264 266 | 
             
            - lib/torch/utils/data/tensor_dataset.rb
         | 
| 265 267 | 
             
            - lib/torch/version.rb
         | 
| 266 268 | 
             
            homepage: https://github.com/ankane/torch.rb
         |