torch-rb 0.1.0 → 0.1.5

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
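Taken together, the new files listed above (the nn loss and layer modules, the optim classes, and the generated native functions) make a basic training loop possible in 0.1.5. The sketch below is illustrative only and is not taken from the diff; it assumes Torch.randn and the Linear/MSELoss/SGD signatures behave as described in the gem's README.

require "torch"

# toy data (Torch.randn assumed to mirror PyTorch's factory function)
x = Torch.randn(8, 4)
y = Torch.randn(8, 1)

net = Torch::NN::Linear.new(4, 1)
criterion = Torch::NN::MSELoss.new
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)

5.times do
  optimizer.zero_grad                    # clear gradients on each Parameter
  loss = criterion.call(net.call(x), y)  # forward pass + loss
  loss.backward                          # autograd backward pass
  optimizer.step                         # update using the new optimizer classes
end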
data/lib/torch/nn/log_softmax.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class LogSoftmax < Module
+      def initialize(dim: nil)
+        super()
+        @dim = dim
+      end
+
+      def forward(input)
+        F.log_softmax(input, @dim)
+      end
+    end
+  end
+end
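A hypothetical usage sketch (not part of the diff): the module simply delegates to F.log_softmax over the configured dimension. Torch.randn is assumed to behave as in PyTorch.

log_softmax = Torch::NN::LogSoftmax.new(dim: 1)
scores = Torch.randn(3, 5)            # 3 samples, 5 classes
log_probs = log_softmax.call(scores)  # each row's exp sums to ~1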
data/lib/torch/nn/loss.rb
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class Loss < Module
+      def initialize(reduction)
+        super()
+        @reduction = reduction
+      end
+    end
+  end
+end
data/lib/torch/nn/margin_ranking_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class MarginRankingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_pool2d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxPool2d < MaxPoolNd
+      def forward(input)
+        F.max_pool2d(input, @kernel_size) # TODO other parameters
+      end
+    end
+  end
+end
data/lib/torch/nn/max_poolnd.rb
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class MaxPoolNd < Module
+      def initialize(kernel_size) #, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
+        super()
+        @kernel_size = kernel_size
+        # @stride = stride || kernel_size
+        # @padding = padding
+        # @dilation = dilation
+        # @return_indices = return_indices
+        # @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("kernel_size: %s", @kernel_size)
+      end
+    end
+  end
+end
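As the commented-out options show, only kernel_size is honored for now; stride, padding, dilation, return_indices, and ceil_mode are not yet wired up. A hypothetical sketch (Torch.randn and the resulting shape are assumptions, not from the diff):

pool = Torch::NN::MaxPool2d.new(2)   # stride/padding/etc. not yet supported
images = Torch.randn(1, 3, 8, 8)     # NCHW batch
pooled = pool.call(images)           # => tensor of shape [1, 3, 4, 4]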
data/lib/torch/nn/module.rb
@@ -1,55 +1,220 @@
 module Torch
   module NN
     class Module
-      def inspect
-        str = String.new
-        str << "#{self.class.name}(\n"
-        modules.each do |name, mod|
-          str << " (#{name}): #{mod.inspect}\n"
+      def initialize
+        @training = true
+        @parameters = {}
+        @buffers = {}
+        @modules = {}
+      end
+
+      def forward
+        raise NotImplementedError
+      end
+
+      def register_buffer(name, tensor)
+        # TODO add checks
+        @buffers[name] = tensor
+      end
+
+      def register_parameter(name, param)
+        # TODO add checks
+        @parameters[name] = param
+      end
+
+      def add_module(name, mod)
+        # TODO add checks
+        @modules[name] = mod
+      end
+
+      def _apply(fn)
+        children.each do |mod|
+          mod._apply(fn)
+        end
+        # TODO apply to more objects
+        self
+      end
+
+      def apply(fn)
+        children.each do |mod|
+          mod.apply(fn)
+        end
+        fn.call(self)
+        self
+      end
+
+      def cuda(device: nil)
+        _apply ->(t) { t.cuda(device) }
+      end
+
+      def cpu
+        _apply ->(t) { t.cpu }
+      end
+
+      def type(dst_type)
+        _apply ->(t) { t.type(dst_type) }
+      end
+
+      def float
+        _apply ->(t) { t.floating_point? ? t.float : t }
+      end
+
+      def double
+        _apply ->(t) { t.floating_point? ? t.double : t }
+      end
+
+      def half
+        _apply ->(t) { t.floating_point? ? t.half : t }
+      end
+
+      # modifies in-place
+      def to(device)
+        convert = lambda do |t|
+          t.to(device)
         end
-        str << ")"
+
+        _apply(convert)
       end
 
       def call(*input)
         forward(*input)
       end
 
+      def state_dict
+        raise NotImplementedYet
+      end
+
       def parameters
-        params = []
+        named_parameters.values
+      end
+
+      def named_parameters(prefix: "", recurse: true)
+        params = {}
+        if recurse
+          named_children.each do |name, mod|
+            params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          end
+        end
         instance_variables.each do |name|
           param = instance_variable_get(name)
-          params << param if param.is_a?(Parameter)
+          params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
+        end
+        @parameters.each do |name, param|
+          params[[prefix, name].join] = param
         end
-        params + modules.flat_map { |_, mod| mod.parameters }
+        params
+      end
+
+      def buffers
+        named_buffers.values
+      end
+
+      def named_buffers
+        @buffers || {}
+      end
+
+      def children
+        named_children.values
+      end
+
+      def named_children
+        modules = {}
+        instance_variables.each do |name|
+          mod = instance_variable_get(name)
+          modules[name[1..-1]] = mod if mod.is_a?(Module)
+        end
+        @modules.each do |name, mod|
+          modules[name] = mod
+        end
+        modules
+      end
+
+      def modules
+        named_modules.values
+      end
+
+      def named_modules
+        {"" => self}.merge(named_children)
+      end
+
+      def train(mode = true)
+        @training = mode
+        children.each do |mod|
+          mod.train(mode)
+        end
+        self
+      end
+
+      def eval
+        train(false)
+      end
+
+      def requires_grad!(requires_grad: true)
+        parameters.each do |p|
+          p.requires_grad!(requires_grad)
+        end
+        self
       end
 
       def zero_grad
         parameters.each do |param|
           if param.grad
-            raise Error, "Not supported yet"
             param.grad.detach!
             param.grad.zero!
           end
         end
       end
 
+      def share_memory
+        _apply ->(t) { t.share_memory! }
+      end
+
+      def inspect
+        name = self.class.name.split("::").last
+        if children.empty?
+          "#{name}(#{extra_inspect})"
+        else
+          str = String.new
+          str << "#{name}(\n"
+          children.each do |name, mod|
+            str << " (#{name}): #{mod.inspect}\n"
+          end
+          str << ")"
+        end
+      end
+
       def method_missing(method, *args, &block)
-        modules[method.to_s] || super
+        name = method.to_s
+        if named_parameters.key?(name)
+          named_parameters[name]
+        elsif named_buffers.key?(name)
+          named_buffers[name]
+        elsif named_modules.key?(name)
+          named_modules[name]
+        else
+          super
+        end
       end
 
       def respond_to?(method, include_private = false)
-        modules.key?(method.to_s) || super
+        name = method.to_s
+        named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
       end
 
       private
 
-      def modules
-        modules = {}
-        instance_variables.each do |name|
-          mod = instance_variable_get(name)
-          modules[name[1..-1]] = mod if mod.is_a?(Module)
-        end
-        modules
+      def extra_inspect
+        nil
+      end
+
+      def format(str, *vars, **options)
+        vars =
+          if vars.any?
+            vars.map(&:inspect)
+          else
+            options.map { |k, v| [k, v.inspect] }.to_h
+          end
+        str % vars
       end
     end
   end
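As a hedged illustration of the reworked bookkeeping (the TinyNet class below is hypothetical, not from the gem): instance variables holding Parameter or Module objects are discovered automatically, named_parameters prefixes child names, method_missing resolves children by name, and train/eval propagate recursively.

class TinyNet < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(4, 8)
    @fc2 = Torch::NN::Linear.new(8, 2)
  end

  def forward(x)
    @fc2.call(@fc1.call(x))
  end
end

net = TinyNet.new
net.named_parameters.keys  # e.g. ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
net.fc1                    # child module resolved through method_missing
net.eval                   # sets @training = false here and on all children
puts net.inspect           # nested summary built by the new #inspect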
data/lib/torch/nn/mse_loss.rb
@@ -1,8 +1,8 @@
 module Torch
   module NN
-    class MSELoss < Module
+    class MSELoss < Loss
       def initialize(reduction: "mean")
-        @reduction = reduction
+        super(reduction)
       end
 
       def forward(input, target)
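With reduction handling now inherited from the new Loss base class, usage is unchanged; a minimal sketch (the Torch.tensor values are made up):

criterion = Torch::NN::MSELoss.new       # reduction: "mean" by default, stored by Loss
output = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = Torch.tensor([[1.0, 2.5], [3.0, 3.0]])
loss = criterion.call(output, target)    # mean of squared differences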
data/lib/torch/nn/multi_label_margin_loss.rb
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class MultiLabelMarginLoss < Loss
+      def initialize(reduction: "mean")
+        super(reduction)
+      end
+
+      def forward(input, target)
+        F.multilabel_margin_loss(input, target, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/multi_label_soft_margin_loss.rb
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class MultiLabelSoftMarginLoss < WeightedLoss
+      def initialize(weight: nil, reduction: "mean")
+        super(weight, reduction)
+      end
+
+      def forward(input, target)
+        F.multilabel_soft_margin_loss(input, target, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/multi_margin_loss.rb
@@ -0,0 +1,17 @@
+module Torch
+  module NN
+    class MultiMarginLoss < WeightedLoss
+      def initialize(p: 1, margin: 1.0, weight: nil, reduction: "mean")
+        super(weight, reduction)
+        raise ArgumentError, "only p == 1 and p == 2 supported" if p != 1 && p != 2
+        raise ArgumentError, "weight must be nil or have one dimension" unless weight.nil? || weight.dim == 1
+        @p = p
+        @margin = margin
+      end
+
+      def forward(input, target)
+        F.multi_margin_loss(input, target, p: @p, margin: @margin, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/nll_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class NLLLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
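NLLLoss pairs naturally with the LogSoftmax module added above; a hypothetical sketch (the Torch.randn/Torch.tensor inputs are assumptions):

log_probs = Torch::NN::LogSoftmax.new(dim: 1).call(Torch.randn(3, 5))
target = Torch.tensor([0, 2, 4])                        # integer class indices
loss = Torch::NN::NLLLoss.new.call(log_probs, target)   # mean negative log-likelihood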
data/lib/torch/nn/pairwise_distance.rb
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PairwiseDistance < Module
+      def initialize(p: 2.0, eps: 1e-6, keepdim: false)
+        super()
+        @norm = p
+        @eps = eps
+        @keepdim = keepdim
+      end
+
+      def forward(x1, x2)
+        F.pairwise_distance(x1, x2, p: @norm, eps: @eps, keepdim: @keepdim)
+      end
+    end
+  end
+end
data/lib/torch/nn/parameter.rb
@@ -5,6 +5,10 @@ module Torch
         data = Tensor.new unless data
         Tensor._make_subclass(data, requires_grad)
       end
+
+      def inspect
+        "Parameter containing:\n#{super}"
+      end
     end
   end
 end
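The new Parameter#inspect simply prefixes the underlying tensor output, matching PyTorch's display; a quick sketch (Torch.ones assumed to exist as a factory):

weight = Torch::NN::Parameter.new(Torch.ones(2, 3))
puts weight.inspect
# Parameter containing:
# (followed by the tensor's own inspect output)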
data/lib/torch/nn/poisson_nll_loss.rb
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PoissonNLLLoss < Loss
+      def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
+        super(reduction)
+        @log_input = log_input
+        @full = full
+        @eps = eps
+      end
+
+      def forward(log_input, target)
+        F.poisson_nll_loss(log_input, target, log_input: @log_input, full: @full, eps: @eps, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/prelu.rb
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class PReLU < Module
+      def initialize(num_parameters: 1, init: 0.25)
+        @num_parameters = num_parameters
+        super()
+        @weight = Parameter.new(Tensor.new(num_parameters).fill!(init))
+      end
+
+      def forward(input)
+        F.prelu(input, @weight)
+      end
+
+      def extra_inspect
+        format("num_parameters: %s", @num_parameters)
+      end
+    end
+  end
+end
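A hypothetical sketch of the new PReLU module (input values are made up): the learned slope is a single shared Parameter by default and shows up in the module summary via extra_inspect.

prelu = Torch::NN::PReLU.new                      # one slope, initialized to 0.25
out = prelu.call(Torch.tensor([-2.0, 0.0, 3.0]))  # => roughly [-0.5, 0.0, 3.0]
prelu.inspect                                     # => "PReLU(num_parameters: 1)"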