torch-rb 0.1.0 → 0.1.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (94) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
@@ -0,0 +1,14 @@
1
+ module Torch
2
+ module NN
3
+ class LogSoftmax < Module
4
+ def initialize(dim: nil)
5
+ super()
6
+ @dim = dim
7
+ end
8
+
9
+ def forward(input)
10
+ F.log_softmax(input, @dim)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class Loss < Module
4
+ def initialize(reduction)
5
+ super()
6
+ @reduction = reduction
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,14 @@
1
+ module Torch
2
+ module NN
3
+ class MarginRankingLoss < Loss
4
+ def initialize(margin: 1.0, reduction: "mean")
5
+ super(reduction)
6
+ @margin = margin
7
+ end
8
+
9
+ def forward(input1, input2, target)
10
+ F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class MaxPool2d < MaxPoolNd
4
+ def forward(input)
5
+ F.max_pool2d(input, @kernel_size) # TODO other parameters
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,19 @@
1
+ module Torch
2
+ module NN
3
+ class MaxPoolNd < Module
4
+ def initialize(kernel_size) #, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
5
+ super()
6
+ @kernel_size = kernel_size
7
+ # @stride = stride || kernel_size
8
+ # @padding = padding
9
+ # @dilation = dilation
10
+ # @return_indices = return_indices
11
+ # @ceil_mode = ceil_mode
12
+ end
13
+
14
+ def extra_inspect
15
+ format("kernel_size: %s", @kernel_size)
16
+ end
17
+ end
18
+ end
19
+ end
@@ -1,55 +1,220 @@
1
1
  module Torch
2
2
  module NN
3
3
  class Module
4
- def inspect
5
- str = String.new
6
- str << "#{self.class.name}(\n"
7
- modules.each do |name, mod|
8
- str << " (#{name}): #{mod.inspect}\n"
4
+ def initialize
5
+ @training = true
6
+ @parameters = {}
7
+ @buffers = {}
8
+ @modules = {}
9
+ end
10
+
11
+ def forward
12
+ raise NotImplementedError
13
+ end
14
+
15
+ def register_buffer(name, tensor)
16
+ # TODO add checks
17
+ @buffers[name] = tensor
18
+ end
19
+
20
+ def register_parameter(name, param)
21
+ # TODO add checks
22
+ @parameters[name] = param
23
+ end
24
+
25
+ def add_module(name, mod)
26
+ # TODO add checks
27
+ @modules[name] = mod
28
+ end
29
+
30
+ def _apply(fn)
31
+ children.each do |mod|
32
+ mod._apply(fn)
33
+ end
34
+ # TODO apply to more objects
35
+ self
36
+ end
37
+
38
+ def apply(fn)
39
+ children.each do |mod|
40
+ mod.apply(fn)
41
+ end
42
+ fn.call(self)
43
+ self
44
+ end
45
+
46
+ def cuda(device: nil)
47
+ _apply ->(t) { t.cuda(device) }
48
+ end
49
+
50
+ def cpu
51
+ _apply ->(t) { t.cpu }
52
+ end
53
+
54
+ def type(dst_type)
55
+ _apply ->(t) { t.type(dst_type) }
56
+ end
57
+
58
+ def float
59
+ _apply ->(t) { t.floating_point? ? t.float : t }
60
+ end
61
+
62
+ def double
63
+ _apply ->(t) { t.floating_point? ? t.double : t }
64
+ end
65
+
66
+ def half
67
+ _apply ->(t) { t.floating_point? ? t.half : t }
68
+ end
69
+
70
+ # modifies in-place
71
+ def to(device)
72
+ convert = lambda do |t|
73
+ t.to(device)
9
74
  end
10
- str << ")"
75
+
76
+ _apply(convert)
11
77
  end
12
78
 
13
79
  def call(*input)
14
80
  forward(*input)
15
81
  end
16
82
 
83
+ def state_dict
84
+ raise NotImplementedYet
85
+ end
86
+
17
87
  def parameters
18
- params = []
88
+ named_parameters.values
89
+ end
90
+
91
+ def named_parameters(prefix: "", recurse: true)
92
+ params = {}
93
+ if recurse
94
+ named_children.each do |name, mod|
95
+ params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
96
+ end
97
+ end
19
98
  instance_variables.each do |name|
20
99
  param = instance_variable_get(name)
21
- params << param if param.is_a?(Parameter)
100
+ params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
101
+ end
102
+ @parameters.each do |name, param|
103
+ params[[prefix, name].join] = param
22
104
  end
23
- params + modules.flat_map { |_, mod| mod.parameters }
105
+ params
106
+ end
107
+
108
+ def buffers
109
+ named_buffers.values
110
+ end
111
+
112
+ def named_buffers
113
+ @buffers || {}
114
+ end
115
+
116
+ def children
117
+ named_children.values
118
+ end
119
+
120
+ def named_children
121
+ modules = {}
122
+ instance_variables.each do |name|
123
+ mod = instance_variable_get(name)
124
+ modules[name[1..-1]] = mod if mod.is_a?(Module)
125
+ end
126
+ @modules.each do |name, mod|
127
+ modules[name] = mod
128
+ end
129
+ modules
130
+ end
131
+
132
+ def modules
133
+ named_modules.values
134
+ end
135
+
136
+ def named_modules
137
+ {"" => self}.merge(named_children)
138
+ end
139
+
140
+ def train(mode = true)
141
+ @training = mode
142
+ children.each do |mod|
143
+ mod.train(mode)
144
+ end
145
+ self
146
+ end
147
+
148
+ def eval
149
+ train(false)
150
+ end
151
+
152
+ def requires_grad!(requires_grad: true)
153
+ parameters.each do |p|
154
+ p.requires_grad!(requires_grad)
155
+ end
156
+ self
24
157
  end
25
158
 
26
159
  def zero_grad
27
160
  parameters.each do |param|
28
161
  if param.grad
29
- raise Error, "Not supported yet"
30
162
  param.grad.detach!
31
163
  param.grad.zero!
32
164
  end
33
165
  end
34
166
  end
35
167
 
168
+ def share_memory
169
+ _apply ->(t) { t.share_memory! }
170
+ end
171
+
172
+ def inspect
173
+ name = self.class.name.split("::").last
174
+ if children.empty?
175
+ "#{name}(#{extra_inspect})"
176
+ else
177
+ str = String.new
178
+ str << "#{name}(\n"
179
+ children.each do |name, mod|
180
+ str << " (#{name}): #{mod.inspect}\n"
181
+ end
182
+ str << ")"
183
+ end
184
+ end
185
+
36
186
  def method_missing(method, *args, &block)
37
- modules[method.to_s] || super
187
+ name = method.to_s
188
+ if named_parameters.key?(name)
189
+ named_parameters[name]
190
+ elsif named_buffers.key?(name)
191
+ named_buffers[name]
192
+ elsif named_modules.key?(name)
193
+ named_modules[name]
194
+ else
195
+ super
196
+ end
38
197
  end
39
198
 
40
199
  def respond_to?(method, include_private = false)
41
- modules.key?(method.to_s) || super
200
+ name = method.to_s
201
+ named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
42
202
  end
43
203
 
44
204
  private
45
205
 
46
- def modules
47
- modules = {}
48
- instance_variables.each do |name|
49
- mod = instance_variable_get(name)
50
- modules[name[1..-1]] = mod if mod.is_a?(Module)
51
- end
52
- modules
206
+ def extra_inspect
207
+ nil
208
+ end
209
+
210
+ def format(str, *vars, **options)
211
+ vars =
212
+ if vars.any?
213
+ vars.map(&:inspect)
214
+ else
215
+ options.map { |k, v| [k, v.inspect] }.to_h
216
+ end
217
+ str % vars
53
218
  end
54
219
  end
55
220
  end
@@ -1,8 +1,8 @@
1
1
  module Torch
2
2
  module NN
3
- class MSELoss < Module
3
+ class MSELoss < Loss
4
4
  def initialize(reduction: "mean")
5
- @reduction = reduction
5
+ super(reduction)
6
6
  end
7
7
 
8
8
  def forward(input, target)
@@ -0,0 +1,13 @@
1
+ module Torch
2
+ module NN
3
+ class MultiLabelMarginLoss < Loss
4
+ def initialize(reduction: "mean")
5
+ super(reduction)
6
+ end
7
+
8
+ def forward(input, target)
9
+ F.multilabel_margin_loss(input, target, reduction: @reduction)
10
+ end
11
+ end
12
+ end
13
+ end
@@ -0,0 +1,13 @@
1
+ module Torch
2
+ module NN
3
+ class MultiLabelSoftMarginLoss < WeightedLoss
4
+ def initialize(weight: nil, reduction: "mean")
5
+ super(weight, reduction)
6
+ end
7
+
8
+ def forward(input, target)
9
+ F.multilabel_soft_margin_loss(input, target, weight: @weight, reduction: @reduction)
10
+ end
11
+ end
12
+ end
13
+ end
@@ -0,0 +1,17 @@
1
+ module Torch
2
+ module NN
3
+ class MultiMarginLoss < WeightedLoss
4
+ def initialize(p: 1, margin: 1.0, weight: nil, reduction: "mean")
5
+ super(weight, reduction)
6
+ raise ArgumentError, "only p == 1 and p == 2 supported" if p != 1 && p != 2
7
+ raise ArgumentError, "weight must be nil or have one dimension" unless weight.nil? || weight.dim == 1
8
+ @p = p
9
+ @margin = margin
10
+ end
11
+
12
+ def forward(input, target)
13
+ F.multi_margin_loss(input, target, p: @p, margin: @margin, weight: @weight, reduction: @reduction)
14
+ end
15
+ end
16
+ end
17
+ end
@@ -0,0 +1,14 @@
1
+ module Torch
2
+ module NN
3
+ class NLLLoss < WeightedLoss
4
+ def initialize(weight: nil, ignore_index: -100, reduction: "mean")
5
+ super(weight, reduction)
6
+ @ignore_index = ignore_index
7
+ end
8
+
9
+ def forward(input, target)
10
+ F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,16 @@
1
+ module Torch
2
+ module NN
3
+ class PairwiseDistance < Module
4
+ def initialize(p: 2.0, eps: 1e-6, keepdim: false)
5
+ super()
6
+ @norm = p
7
+ @eps = eps
8
+ @keepdim = keepdim
9
+ end
10
+
11
+ def forward(x1, x2)
12
+ F.pairwise_distance(x1, x2, p: @norm, eps: @eps, keepdim: @keepdim)
13
+ end
14
+ end
15
+ end
16
+ end
@@ -5,6 +5,10 @@ module Torch
5
5
  data = Tensor.new unless data
6
6
  Tensor._make_subclass(data, requires_grad)
7
7
  end
8
+
9
+ def inspect
10
+ "Parameter containing:\n#{super}"
11
+ end
8
12
  end
9
13
  end
10
14
  end
@@ -0,0 +1,16 @@
1
+ module Torch
2
+ module NN
3
+ class PoissonNLLLoss < Loss
4
+ def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
5
+ super(reduction)
6
+ @log_input = log_input
7
+ @full = full
8
+ @eps = eps
9
+ end
10
+
11
+ def forward(log_input, target)
12
+ F.poisson_nll_loss(log_input, target, log_input: @log_input, full: @full, eps: @eps, reduction: @reduction)
13
+ end
14
+ end
15
+ end
16
+ end
@@ -0,0 +1,19 @@
1
+ module Torch
2
+ module NN
3
+ class PReLU < Module
4
+ def initialize(num_parameters: 1, init: 0.25)
5
+ @num_parameters = num_parameters
6
+ super()
7
+ @weight = Parameter.new(Tensor.new(num_parameters).fill!(init))
8
+ end
9
+
10
+ def forward(input)
11
+ F.prelu(input, @weight)
12
+ end
13
+
14
+ def extra_inspect
15
+ format("num_parameters: %s", @num_parameters)
16
+ end
17
+ end
18
+ end
19
+ end