torch-rb 0.1.0 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
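
Taken together, the list shows where 0.1.5 grew: generated native bindings (`tensor_functions.cpp`, `torch_functions.cpp`, `nn_functions.cpp`), a full `Torch::Optim` namespace with LR schedulers, a `DataLoader`, and a few dozen new `Torch::NN` modules. A minimal training-loop sketch against that new surface (network shape and data are invented for illustration; the `Sequential`/`SGD`/`MSELoss` names come from the file list above):

```ruby
require "torch"

net = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(8, 4),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(4, 1)
)

optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
criterion = Torch::NN::MSELoss.new

x = Torch.randn(32, 8)
y = Torch.randn(32, 1)

10.times do
  optimizer.zero_grad                    # clear accumulated gradients
  loss = criterion.call(net.call(x), y)  # Module#call dispatches to forward
  loss.backward
  optimizer.step
end
```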
data/lib/torch/nn/margin_ranking_loss.rb
ADDED

```diff
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class MarginRankingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
```
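
A usage sketch for the new loss module (tensors invented; `call` comes from `Module`, which dispatches to `forward` as shown in the module.rb diff below):

```ruby
input1 = Torch.randn(4)
input2 = Torch.randn(4)
target = Torch.tensor([1.0, -1.0, 1.0, -1.0]) # +1: input1 should rank higher; -1: input2

loss_fn = Torch::NN::MarginRankingLoss.new(margin: 0.5)
loss_fn.call(input1, input2, target)
```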
data/lib/torch/nn/max_poolnd.rb
ADDED

```diff
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class MaxPoolNd < Module
+      def initialize(kernel_size) #, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
+        super()
+        @kernel_size = kernel_size
+        # @stride = stride || kernel_size
+        # @padding = padding
+        # @dilation = dilation
+        # @return_indices = return_indices
+        # @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("kernel_size: %s", @kernel_size)
+      end
+    end
+  end
+end
```
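
Only `kernel_size` is wired up so far; stride, padding, dilation, `return_indices`, and `ceil_mode` are left commented out. A hedged sketch using the concrete `MaxPool2d` subclass from the file list, assuming libtorch's default of stride = kernel size:

```ruby
pool = Torch::NN::MaxPool2d.new(2)  # kernel_size is the only wired-up option so far
x = Torch.randn(1, 1, 4, 4)         # (batch, channels, height, width)
pool.call(x).shape                  # => [1, 1, 2, 2]
```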
    
        data/lib/torch/nn/module.rb
    CHANGED
    
```diff
@@ -1,55 +1,220 @@
 module Torch
   module NN
     class Module
-      def inspect
-        str = String.new
-        str << "#{self.class.name}(\n"
-        modules.each do |name, mod|
-          str << "  (#{name}): #{mod.inspect}\n"
+      def initialize
+        @training = true
+        @parameters = {}
+        @buffers = {}
+        @modules = {}
+      end
+
+      def forward
+        raise NotImplementedError
+      end
+
+      def register_buffer(name, tensor)
+        # TODO add checks
+        @buffers[name] = tensor
+      end
+
+      def register_parameter(name, param)
+        # TODO add checks
+        @parameters[name] = param
+      end
+
+      def add_module(name, mod)
+        # TODO add checks
+        @modules[name] = mod
+      end
+
+      def _apply(fn)
+        children.each do |mod|
+          mod._apply(fn)
+        end
+        # TODO apply to more objects
+        self
+      end
+
+      def apply(fn)
+        children.each do |mod|
+          mod.apply(fn)
+        end
+        fn.call(self)
+        self
+      end
+
+      def cuda(device: nil)
+        _apply ->(t) { t.cuda(device) }
+      end
+
+      def cpu
+        _apply ->(t) { t.cpu }
+      end
+
+      def type(dst_type)
+        _apply ->(t) { t.type(dst_type) }
+      end
+
+      def float
+        _apply ->(t) { t.floating_point? ? t.float : t }
+      end
+
+      def double
+        _apply ->(t) { t.floating_point? ? t.double : t }
+      end
+
+      def half
+        _apply ->(t) { t.floating_point? ? t.half : t }
+      end
+
+      # modifies in-place
+      def to(device)
+        convert = lambda do |t|
+          t.to(device)
         end
-        str << ")"
+
+        _apply(convert)
       end
 
       def call(*input)
         forward(*input)
       end
 
+      def state_dict
+        raise NotImplementedYet
+      end
+
       def parameters
-        params = []
+        named_parameters.values
+      end
+
+      def named_parameters(prefix: "", recurse: true)
+        params = {}
+        if recurse
+          named_children.each do |name, mod|
+            params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          end
+        end
         instance_variables.each do |name|
           param = instance_variable_get(name)
-          params << param if param.is_a?(Parameter)
+          params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
+        end
+        @parameters.each do |name, param|
+          params[[prefix, name].join] = param
         end
-        params
+        params
+      end
+
+      def buffers
+        named_buffers.values
+      end
+
+      def named_buffers
+        @buffers || {}
+      end
+
+      def children
+        named_children.values
+      end
+
+      def named_children
+        modules = {}
+        instance_variables.each do |name|
+          mod = instance_variable_get(name)
+          modules[name[1..-1]] = mod if mod.is_a?(Module)
+        end
+        @modules.each do |name, mod|
+          modules[name] = mod
+        end
+        modules
+      end
+
+      def modules
+        named_modules.values
+      end
+
+      def named_modules
+        {"" => self}.merge(named_children)
+      end
+
+      def train(mode = true)
+        @training = mode
+        children.each do |mod|
+          mod.train(mode)
+        end
+        self
+      end
+
+      def eval
+        train(false)
+      end
+
+      def requires_grad!(requires_grad: true)
+        parameters.each do |p|
+          p.requires_grad!(requires_grad)
+        end
+        self
       end
 
       def zero_grad
         parameters.each do |param|
           if param.grad
-            raise Error, "Not supported yet"
             param.grad.detach!
             param.grad.zero!
           end
         end
       end
 
+      def share_memory
+        _apply ->(t) { t.share_memory! }
+      end
+
+      def inspect
+        name = self.class.name.split("::").last
+        if children.empty?
+          "#{name}(#{extra_inspect})"
+        else
+          str = String.new
+          str << "#{name}(\n"
+          children.each do |name, mod|
+            str << "  (#{name}): #{mod.inspect}\n"
+          end
+          str << ")"
+        end
+      end
+
       def method_missing(method, *args, &block)
-        modules[method.to_s] || super
+        name = method.to_s
+        if named_parameters.key?(name)
+          named_parameters[name]
+        elsif named_buffers.key?(name)
+          named_buffers[name]
+        elsif named_modules.key?(name)
+          named_modules[name]
+        else
+          super
+        end
       end
 
       def respond_to?(method, include_private = false)
-        modules.key?(method.to_s) || super
+        name = method.to_s
+        named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
       end
 
       private
 
-      def modules
-        modules = {}
-        instance_variables.each do |name|
-          mod = instance_variable_get(name)
-          modules[name[1..-1]] = mod if mod.is_a?(Module)
-        end
-        modules
+      def extra_inspect
+        nil
+      end
+
+      def format(str, *vars, **options)
+        vars =
+          if vars.any?
+            vars.map(&:inspect)
+          else
+            options.map { |k, v| [k, v.inspect] }.to_h
+          end
+        str % vars
       end
     end
   end
```
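
The rewritten `Module` is the backbone of the release: parameters, buffers, and children are discovered by reflecting over instance variables (plus the explicit `register_buffer`/`register_parameter`/`add_module` registries), and `method_missing` exposes each by name. A sketch grounded in the code above (layer sizes invented):

```ruby
class Net < Torch::NN::Module
  def initialize
    super()                            # sets up @training/@parameters/@buffers/@modules
    @fc = Torch::NN::Linear.new(10, 2)
  end

  def forward(x)
    @fc.call(x)
  end
end

net = Net.new
net.named_parameters.keys  # => ["fc.weight", "fc.bias"]
net.fc                     # the child module, resolved through method_missing
net.eval                   # train(false), applied recursively
net.zero_grad              # detaches and zeroes each param.grad
```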
    
        data/lib/torch/nn/mse_loss.rb
    CHANGED
    
    
data/lib/torch/nn/multi_label_soft_margin_loss.rb
ADDED

```diff
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class MultiLabelSoftMarginLoss < WeightedLoss
+      def initialize(weight: nil, reduction: "mean")
+        super(weight, reduction)
+      end
+
+      def forward(input, target)
+        F.multilabel_soft_margin_loss(input, target, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
```
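
A usage sketch (shapes invented); the optional per-class `weight` is passed up to the new `WeightedLoss` base class from the file list:

```ruby
input  = Torch.randn(3, 5)                       # 3 samples, 5 labels
target = Torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0],
                       [0.0, 1.0, 1.0, 0.0, 0.0],
                       [1.0, 1.0, 0.0, 0.0, 1.0]]) # 0/1 indicators per label

Torch::NN::MultiLabelSoftMarginLoss.new.call(input, target)
```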
data/lib/torch/nn/multi_margin_loss.rb
ADDED

```diff
@@ -0,0 +1,17 @@
+module Torch
+  module NN
+    class MultiMarginLoss < WeightedLoss
+      def initialize(p: 1, margin: 1.0, weight: nil, reduction: "mean")
+        super(weight, reduction)
+        raise ArgumentError, "only p == 1 and p == 2 supported" if p != 1 && p != 2
+        raise ArgumentError, "weight must be nil or have one dimension" unless weight.nil? || weight.dim == 1
+        @p = p
+        @margin = margin
+      end
+
+      def forward(input, target)
+        F.multi_margin_loss(input, target, p: @p, margin: @margin, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
```
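
Unlike most of the other losses, the constructor validates its arguments eagerly instead of deferring to libtorch:

```ruby
Torch::NN::MultiMarginLoss.new(p: 2)  # ok
Torch::NN::MultiMarginLoss.new(p: 3)  # raises ArgumentError: only p == 1 and p == 2 supported
```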
data/lib/torch/nn/nll_loss.rb
ADDED

```diff
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class NLLLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
```
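
As in PyTorch, `NLLLoss` expects log-probabilities, and `ignore_index: -100` mirrors PyTorch's default. It therefore pairs with the new `LogSoftmax` module from the file list (assuming it takes a `dim:` keyword like its PyTorch counterpart):

```ruby
log_probs = Torch::NN::LogSoftmax.new(dim: 1).call(Torch.randn(3, 5)) # 3 samples, 5 classes
target    = Torch.tensor([1, 0, 4])                                   # class indices

Torch::NN::NLLLoss.new.call(log_probs, target)
```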
data/lib/torch/nn/pairwise_distance.rb
ADDED

```diff
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PairwiseDistance < Module
+      def initialize(p: 2.0, eps: 1e-6, keepdim: false)
+        super()
+        @norm = p
+        @eps = eps
+        @keepdim = keepdim
+      end
+
+      def forward(x1, x2)
+        F.pairwise_distance(x1, x2, p: @norm, eps: @eps, keepdim: @keepdim)
+      end
+    end
+  end
+end
```
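
The `p:` keyword is stored as `@norm`, mirroring PyTorch's `PairwiseDistance`. A usage sketch (shapes invented):

```ruby
x1 = Torch.randn(5, 8)
x2 = Torch.randn(5, 8)

Torch::NN::PairwiseDistance.new(p: 2.0).call(x1, x2) # 5 Euclidean distances, one per row
```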
    
        data/lib/torch/nn/parameter.rb
    CHANGED
    
    
data/lib/torch/nn/poisson_nll_loss.rb
ADDED

```diff
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PoissonNLLLoss < Loss
+      def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
+        super(reduction)
+        @log_input = log_input
+        @full = full
+        @eps = eps
+      end
+
+      def forward(log_input, target)
+        F.poisson_nll_loss(log_input, target, log_input: @log_input, full: @full, eps: @eps, reduction: @reduction)
+      end
+    end
+  end
+end
```
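
`log_input` selects whether the input is read as log-rates (the loss is then `exp(input) - target * input`) or as raw rates, matching PyTorch's semantics. A sketch with invented data:

```ruby
log_rates = Torch.randn(6)
counts    = Torch.tensor([0.0, 1.0, 2.0, 0.0, 1.0, 3.0])

Torch::NN::PoissonNLLLoss.new.call(log_rates, counts)                        # input read as log(lambda)
Torch::NN::PoissonNLLLoss.new(log_input: false).call(log_rates.exp, counts)  # input read as lambda
```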
data/lib/torch/nn/prelu.rb
ADDED

```diff
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class PReLU < Module
+      def initialize(num_parameters: 1, init: 0.25)
+        @num_parameters = num_parameters
+        super()
+        @weight = Parameter.new(Tensor.new(num_parameters).fill!(init))
+      end
+
+      def forward(input)
+        F.prelu(input, @weight)
+      end
+
+      def extra_inspect
+        format("num_parameters: %s", @num_parameters)
+      end
+    end
+  end
+end
```
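
Unlike the stateless modules above, `PReLU` owns a learnable parameter: the negative-slope weight, filled with `init` and picked up by `Module#parameters`, so optimizers will train it along with the rest of the network:

```ruby
prelu = Torch::NN::PReLU.new
prelu.parameters.size                  # => 1, the slope weight (initialized to 0.25)
prelu.call(Torch.tensor([-1.0, 2.0]))  # ~> [-0.25, 2.0] before any training
```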