torch-rb 0.8.0 → 0.9.0

@@ -0,0 +1,17 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "utils.h"
+
+void init_backends(Rice::Module& m) {
+  auto rb_mBackends = Rice::define_module_under(m, "Backends");
+
+  Rice::define_module_under(rb_mBackends, "OpenMP")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_function("available?", &torch::hasOpenMP);
+
+  Rice::define_module_under(rb_mBackends, "MKL")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_function("available?", &torch::hasMKL);
+}
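The new Backends module exposes build-capability checks from libtorch. A minimal usage sketch (return values depend on how libtorch was compiled):

    require "torch"

    Torch::Backends::OpenMP.available? # true when libtorch was built with OpenMP
    Torch::Backends::MKL.available?    # true when libtorch was built with MKL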
data/ext/torch/ext.cpp CHANGED
@@ -2,10 +2,14 @@
 
 #include <rice/rice.hpp>
 
+void init_fft(Rice::Module& m);
+void init_linalg(Rice::Module& m);
 void init_nn(Rice::Module& m);
+void init_special(Rice::Module& m);
 void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
 void init_torch(Rice::Module& m);
 
+void init_backends(Rice::Module& m);
 void init_cuda(Rice::Module& m);
 void init_device(Rice::Module& m);
 void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
@@ -27,7 +31,11 @@ void Init_ext()
   init_torch(m);
   init_tensor(m, rb_cTensor, rb_cTensorOptions);
   init_nn(m);
+  init_fft(m);
+  init_linalg(m);
+  init_special(m);
 
+  init_backends(m);
   init_cuda(m);
   init_device(m);
   init_ivalue(m, rb_cIValue);
data/ext/torch/fft.cpp ADDED
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "fft_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_fft(Rice::Module& m) {
+  auto rb_mFFT = Rice::define_module_under(m, "FFT");
+  rb_mFFT.add_handler<torch::Error>(handle_error);
+  add_fft_functions(rb_mFFT);
+}
@@ -0,0 +1,6 @@
+// generated by rake generate:functions
+// do not edit by hand
+
+#pragma once
+
+void add_fft_functions(Rice::Module& m);
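init_fft registers the generated bindings under a Torch::FFT module. A usage sketch, assuming the generated functions mirror PyTorch's torch.fft namespace (fft here is that assumption):

    require "torch"

    x = Torch.tensor([0.0, 1.0, 0.0, -1.0])
    Torch::FFT.fft(x) # complex tensor of frequency components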
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "linalg_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_linalg(Rice::Module& m) {
+  auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
+  rb_mLinalg.add_handler<torch::Error>(handle_error);
+  add_linalg_functions(rb_mLinalg);
+}
@@ -0,0 +1,6 @@
+// generated by rake generate:functions
+// do not edit by hand
+
+#pragma once
+
+void add_linalg_functions(Rice::Module& m);
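Likewise, Torch::Linalg is populated from generated functions. A sketch, assuming they mirror torch.linalg (norm here is that assumption):

    require "torch"

    m = Torch.tensor([[3.0, 0.0], [0.0, 4.0]])
    Torch::Linalg.norm(m) # Frobenius norm => 5.0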
@@ -75,7 +75,7 @@ struct RubyArgs {
   int idx;
 
   inline at::Tensor tensor(int i);
-  inline OptionalTensor optionalTensor(int i);
+  inline c10::optional<at::Tensor> optionalTensor(int i);
   inline at::Scalar scalar(int i);
   // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
   inline std::vector<at::Scalar> scalarlist(int i);
@@ -109,6 +109,9 @@ struct RubyArgs {
   // inline at::QScheme toQScheme(int i);
   inline std::string string(int i);
   inline c10::optional<std::string> stringOptional(int i);
+  inline c10::string_view stringView(int i);
+  // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
+  inline c10::optional<c10::string_view> stringViewOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
   return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
 }
 
-inline OptionalTensor RubyArgs::optionalTensor(int i) {
-  if (NIL_P(args[i])) return OptionalTensor(Nil);
+inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
+  if (NIL_P(args[i])) return c10::nullopt;
   return tensor(i);
 }
 
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
   return Rice::detail::From_Ruby<std::string>().convert(args[i]);
 }
 
+inline c10::string_view RubyArgs::stringView(int i) {
+  auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
+  return c10::string_view(str.data(), str.size());
+}
+
+inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
+  if (NIL_P(args[i])) return c10::nullopt;
+  auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
+  return c10::string_view(str.data(), str.size());
+}
+
 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
   return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
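The new string_view converters let generated bindings pass Ruby strings to ops whose schemas take c10::string_view in recent libtorch (rounding_mode arguments, for example). From Ruby such arguments remain plain strings; a hedged sketch, assuming the generated div binding accepts this keyword:

    require "torch"

    a = Torch.tensor([5.0])
    b = Torch.tensor([2.0])
    Torch.div(a, b, rounding_mode: "floor") # keyword name assumed from PyTorch's schema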
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "special_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_special(Rice::Module& m) {
+  auto rb_mSpecial = Rice::define_module_under(m, "Special");
+  rb_mSpecial.add_handler<torch::Error>(handle_error);
+  add_special_functions(rb_mSpecial);
+}
@@ -0,0 +1,6 @@
+// generated by rake generate:functions
+// do not edit by hand
+
+#pragma once
+
+void add_special_functions(Rice::Module& m);
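A sketch for Torch::Special, assuming the generated functions mirror torch.special (expit, the logistic sigmoid, is the assumed example):

    require "torch"

    x = Torch.tensor([-1.0, 0.0, 1.0])
    Torch::Special.expit(x) # ~[0.2689, 0.5, 0.7311]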
@@ -41,24 +41,6 @@ using torch::nn::init::NonlinearityType;
 #define RETURN_NIL \
   return Qnil;
 
-class OptionalTensor {
-  torch::Tensor value;
-  public:
-    OptionalTensor(Object o) {
-      if (o.is_nil()) {
-        value = {};
-      } else {
-        value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
-      }
-    }
-    OptionalTensor(torch::Tensor o) {
-      value = o;
-    }
-    operator torch::Tensor() const {
-      return value;
-    }
-};
-
 namespace Rice::detail
 {
   template<>
@@ -131,25 +113,6 @@ namespace Rice::detail
     }
   };
 
-  template<>
-  struct Type<OptionalTensor>
-  {
-    static bool verify()
-    {
-      return true;
-    }
-  };
-
-  template<>
-  class From_Ruby<OptionalTensor>
-  {
-  public:
-    OptionalTensor convert(VALUE x)
-    {
-      return OptionalTensor(x);
-    }
-  };
-
   template<>
   struct Type<Scalar>
   {
data/ext/torch/tensor.cpp CHANGED
@@ -107,7 +107,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
   ParsedArgs<4> parsed_args;
   auto _r = parser.parse(self_, argc, argv, parsed_args);
   // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
-  auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
+  auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
     // in future, release GVL
     self._backward(inputs, gradient, retain_graph, create_graph);
   };
@@ -125,13 +125,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
   rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
 
   rb_cTensor
-    .define_method("cuda?", &torch::Tensor::is_cuda)
-    .define_method("sparse?", &torch::Tensor::is_sparse)
-    .define_method("quantized?", &torch::Tensor::is_quantized)
-    .define_method("dim", &torch::Tensor::dim)
-    .define_method("numel", &torch::Tensor::numel)
-    .define_method("element_size", &torch::Tensor::element_size)
-    .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
+    .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
+    .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
+    .define_method("dim", [](Tensor& self) { return self.dim(); })
+    .define_method("numel", [](Tensor& self) { return self.numel(); })
+    .define_method("element_size", [](Tensor& self) { return self.element_size(); })
+    .define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
     .define_method(
       "_size",
       [](Tensor& self, int64_t dim) {
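Wrapping the accessors in lambdas pins the receiver type for Rice; the Ruby-side behavior is unchanged. A quick check:

    require "torch"

    x = Torch.rand(2, 3)
    x.dim          # => 2
    x.numel        # => 6
    x.element_size # => 4 (float32)
    x.cuda?        # => false without a CUDA device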
@@ -1,6 +1,8 @@
 module Torch
   module NN
     class ConvNd < Module
+      attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_padding, :groups, :padding_mode
+
      def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
        super()
        raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
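With the new readers, a convolution's configuration is inspectable after construction. A sketch with a standard Conv2d (scalar sizes are expanded to pairs internally):

    require "torch"

    conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2)
    conv.in_channels  # => 3
    conv.out_channels # => 16
    conv.kernel_size  # => [3, 3]
    conv.stride       # => [2, 2]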
@@ -0,0 +1,241 @@
+module Torch
+  module NN
+    class Functional
+      class << self
+        def in_projection_packed(q, k, v, w, b: nil)
+          e = q.size(-1)
+
+          if k.eql? v
+            if q.eql? k
+              # self-attention
+              return linear(q, w, b).chunk(3, dim: -1)
+            else
+              # encoder-decoder attention
+              w_q, w_kv = w.split_with_sizes([e, e * 2])
+              if b.nil?
+                b_q = b_kv = nil
+              else
+                b_q, b_kv = b.split_with_sizes([e, e * 2])
+              end
+
+              return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+            end
+          else
+            w_q, w_k, w_v = w.chunk(3)
+            if b.nil?
+              b_q = b_k = b_v = nil
+            else
+              b_q, b_k, b_v = b.chunk(3)
+            end
+
+            return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+          end
+        end
+
+        def in_projection(
+          q, k, v,
+          w_q, w_k, w_v,
+          b_q: nil, b_k: nil, b_v: nil
+        )
+
+          e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
+
+          raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
+          raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
+          raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
+
+          raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
+          raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
+          raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
+
+          [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+        end
+
+        def scaled_dot_product_attention(
+          q, k, v,
+          attn_mask: nil, dropout_p: 0.0
+        )
+
+          b, nt, e = q.shape
+
+          q = q / Math.sqrt(e)
+
+          attn = Torch.bmm(q, k.transpose(-2, -1))
+          attn += attn_mask if attn_mask
+          attn = softmax(attn, dim: -1)
+          attn = dropout(attn, p: dropout_p) if dropout_p > 0
+
+          output = Torch.bmm(attn, v)
+
+          [output, attn]
+        end
+
+        def multi_head_attention_forward(
+          query, key, value,
+          embed_dim_to_check, num_heads,
+          in_proj_weight, in_proj_bias,
+          bias_k, bias_v,
+          add_zero_attn,
+          dropout_p,
+          out_proj_weight, out_proj_bias,
+          training: true,
+          key_padding_mask: nil,
+          need_weights: true,
+          attn_mask: nil,
+          use_separate_proj_weight: false,
+          q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
+          static_k: nil, static_v: nil
+        )
+
+          tgt_len, bsz, embed_dim = query.shape
+          src_len = key.shape.first
+
+          raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
+
+          head_dim = if embed_dim.is_a?(Torch::Tensor)
+            embed_dim.div(num_heads, rounding_mode: 'trunc')
+          else
+            embed_dim.div(num_heads)
+          end
+
+          if use_separate_proj_weight
+            raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
+          else
+            raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
+          end
+
+          # compute in-projection
+          q, k, v =
+            if use_separate_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
+
+              if in_proj_bias
+                b_q, b_k, b_v = in_proj_bias.chunk(3)
+              else
+                b_q = b_k = b_v = nil
+              end
+
+              in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
+            else
+              in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
+            end
+
+          # prep attention mask
+          if attn_mask
+            if attn_mask.dtype == :uint8
+              puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
+              attn_mask = attn_mask.bool
+            else
+              raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
+            end
+
+            if attn_mask.dim == 2
+              correct_2d_size = [tgt_len, src_len]
+              raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
+
+              attn_mask = attn_mask.unsqueeze(0)
+            elsif attn_mask.dim == 3
+              correct_3d_size = [bsz * num_heads, tgt_len, src_len]
+              raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
+            else
+              raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
+            end
+          end
+
+          # prep key padding mask
+          if key_padding_mask && key_padding_mask.dtype == :uint8
+            puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
+            key_padding_mask = key_padding_mask.bool
+          end
+
+          # add bias along batch dimension (currently second)
+          if bias_k && bias_v
+            raise ArgumentError, "bias cannot be added to static key." if static_k
+            raise ArgumentError, "bias cannot be added to static value." if static_v
+
+            k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+            v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          else
+            raise ArgumentError unless bias_k.nil?
+            raise ArgumentError unless bias_v.nil?
+          end
+
+          # reshape q, k, v for multihead attention and make em batch first
+          q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+
+          if static_k.nil?
+            k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
+
+            k = static_k
+          end
+
+          if static_v.nil?
+            v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
+
+            v = static_v
+          end
+
+          # add zero attention along batch dimension (now first)
+          if add_zero_attn
+            zero_attn_shape = [bsz * num_heads, 1, head_dim]
+            k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
+            v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          end
+
+          # update source sequence length after adjustments
+          src_len = k.size(1)
+
+          # merge key padding and attention masks
+          if key_padding_mask
+            raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
+
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+
+            attn_mask = if attn_mask.nil?
+              key_padding_mask
+            elsif attn_mask.dtype == :bool
+              attn_mask.logical_or(key_padding_mask)
+            else
+              attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
+            end
+          end
+
+          # convert mask to float
+          if attn_mask && attn_mask.dtype == :bool
+            new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
+            attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
+          end
+
+          dropout_p = 0.0 unless training
+
+          # (deep breath) calculate attention and out projection
+          attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
+          attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
+          attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+          if need_weights
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+            [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
+          else
+            [attn_output, nil]
+          end
+        end
+      end
+    end
+  end
+end
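The helpers above compose into multi_head_attention_forward; scaled_dot_product_attention is also usable on its own. A sketch with random batch-first inputs:

    require "torch"

    q = Torch.rand(2, 5, 8) # (batch, sequence, embed)
    k = Torch.rand(2, 5, 8)
    v = Torch.rand(2, 5, 8)
    out, attn = Torch::NN::Functional.scaled_dot_product_attention(q, k, v)
    out.shape  # => [2, 5, 8]
    attn.shape # => [2, 5, 5]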
@@ -3,6 +3,8 @@ module Torch
     class Module
       include Utils
 
+      attr_reader :training
+
       def initialize
         @training = true
         @parameters = {}
@@ -278,6 +280,11 @@ module Torch
         end
       end
 
+      def deep_dup
+        memo = {}
+        dup_value(self, memo)
+      end
+
       def method_missing(method, *args, &block)
         name = method.to_s
         if named_parameters.key?(name)
@@ -386,6 +393,29 @@ module Torch
           destination[prefix + k] = v
         end
       end
+
+      # keep memo hash like Python deepcopy
+      # https://docs.python.org/3/library/copy.html
+      def dup_value(v, memo)
+        memo[v.object_id] ||= begin
+          case v
+          when Method, UnboundMethod
+            v
+          when Hash
+            v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
+          when Array
+            v.map { |v2| dup_value(v2, memo) }
+          when Torch::NN::Module
+            copy = v.dup
+            v.instance_variables.each do |var|
+              copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
+            end
+            copy
+          else
+            v.dup
+          end
+        end
+      end
     end
   end
 end
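deep_dup recursively copies a module and its instance variables, memoizing on object_id so shared objects are duplicated only once, mirroring Python's deepcopy. A sketch:

    require "torch"

    model = Torch::NN::Linear.new(4, 2)
    model.training     # => true (now readable via attr_reader)
    copy = model.deep_dup
    copy.equal?(model) # => false; nested state was duplicated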
@@ -0,0 +1,49 @@
+module Torch
+  module NN
+    class ModuleList < Module
+      include Enumerable
+
+      def initialize(mods = nil)
+        super()
+
+        self.concat(mods) if mods
+      end
+
+      def length
+        @modules.length
+      end
+      alias_method :count, :length
+      alias_method :size, :length
+
+      def concat(mods)
+        raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
+
+        mods.each { |m| append m }
+
+        self
+      end
+
+      def each(&block)
+        if block_given?
+          @modules.values.each(&block)
+        else
+          to_enum(:each)
+        end
+      end
+
+      def append(mod)
+        raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
+        add_module(length.to_s, mod)
+        self
+      end
+
+      def [](idx)
+        if idx.is_a?(Range)
+          self.class.new(@modules.values[idx])
+        else
+          @modules[idx.to_s]
+        end
+      end
+    end
+  end
+end
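ModuleList registers each entry via add_module (keyed by its index) so contained modules are tracked for parameters and training mode. A usage sketch:

    require "torch"

    layers = Torch::NN::ModuleList.new([Torch::NN::Linear.new(8, 8), Torch::NN::ReLU.new])
    layers.length       # => 2
    layers[0]           # => the Linear module
    layers[0..-1]       # => a new ModuleList over the same range
    layers.map(&:class) # => [Torch::NN::Linear, Torch::NN::ReLU]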