torch-rb 0.8.0 → 0.9.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,17 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "utils.h"
6
+
7
+ void init_backends(Rice::Module& m) {
8
+ auto rb_mBackends = Rice::define_module_under(m, "Backends");
9
+
10
+ Rice::define_module_under(rb_mBackends, "OpenMP")
11
+ .add_handler<torch::Error>(handle_error)
12
+ .define_singleton_function("available?", &torch::hasOpenMP);
13
+
14
+ Rice::define_module_under(rb_mBackends, "MKL")
15
+ .add_handler<torch::Error>(handle_error)
16
+ .define_singleton_function("available?", &torch::hasMKL);
17
+ }
data/ext/torch/ext.cpp CHANGED
@@ -2,10 +2,14 @@
2
2
 
3
3
  #include <rice/rice.hpp>
4
4
 
5
+ void init_fft(Rice::Module& m);
6
+ void init_linalg(Rice::Module& m);
5
7
  void init_nn(Rice::Module& m);
8
+ void init_special(Rice::Module& m);
6
9
  void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
7
10
  void init_torch(Rice::Module& m);
8
11
 
12
+ void init_backends(Rice::Module& m);
9
13
  void init_cuda(Rice::Module& m);
10
14
  void init_device(Rice::Module& m);
11
15
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
@@ -27,7 +31,11 @@ void Init_ext()
27
31
  init_torch(m);
28
32
  init_tensor(m, rb_cTensor, rb_cTensorOptions);
29
33
  init_nn(m);
34
+ init_fft(m);
35
+ init_linalg(m);
36
+ init_special(m);
30
37
 
38
+ init_backends(m);
31
39
  init_cuda(m);
32
40
  init_device(m);
33
41
  init_ivalue(m, rb_cIValue);
data/ext/torch/fft.cpp ADDED
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "fft_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_fft(Rice::Module& m) {
10
+ auto rb_mFFT = Rice::define_module_under(m, "FFT");
11
+ rb_mFFT.add_handler<torch::Error>(handle_error);
12
+ add_fft_functions(rb_mFFT);
13
+ }
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_fft_functions(Rice::Module& m);
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "linalg_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_linalg(Rice::Module& m) {
10
+ auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
11
+ rb_mLinalg.add_handler<torch::Error>(handle_error);
12
+ add_linalg_functions(rb_mLinalg);
13
+ }
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_linalg_functions(Rice::Module& m);
@@ -75,7 +75,7 @@ struct RubyArgs {
75
75
  int idx;
76
76
 
77
77
  inline at::Tensor tensor(int i);
78
- inline OptionalTensor optionalTensor(int i);
78
+ inline c10::optional<at::Tensor> optionalTensor(int i);
79
79
  inline at::Scalar scalar(int i);
80
80
  // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
81
81
  inline std::vector<at::Scalar> scalarlist(int i);
@@ -109,6 +109,9 @@ struct RubyArgs {
109
109
  // inline at::QScheme toQScheme(int i);
110
110
  inline std::string string(int i);
111
111
  inline c10::optional<std::string> stringOptional(int i);
112
+ inline c10::string_view stringView(int i);
113
+ // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
114
+ inline c10::optional<c10::string_view> stringViewOptional(int i);
112
115
  // inline PyObject* pyobject(int i);
113
116
  inline int64_t toInt64(int i);
114
117
  // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
125
128
  return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
126
129
  }
127
130
 
128
- inline OptionalTensor RubyArgs::optionalTensor(int i) {
129
- if (NIL_P(args[i])) return OptionalTensor(Nil);
131
+ inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
132
+ if (NIL_P(args[i])) return c10::nullopt;
130
133
  return tensor(i);
131
134
  }
132
135
 
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
322
325
  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
323
326
  }
324
327
 
328
+ inline c10::string_view RubyArgs::stringView(int i) {
329
+ auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
330
+ return c10::string_view(str.data(), str.size());
331
+ }
332
+
333
+ inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
334
+ if (NIL_P(args[i])) return c10::nullopt;
335
+ auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
336
+ return c10::string_view(str.data(), str.size());
337
+ }
338
+
325
339
  inline int64_t RubyArgs::toInt64(int i) {
326
340
  if (NIL_P(args[i])) return signature.params[i].default_int;
327
341
  return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "special_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_special(Rice::Module& m) {
10
+ auto rb_mSpecial = Rice::define_module_under(m, "Special");
11
+ rb_mSpecial.add_handler<torch::Error>(handle_error);
12
+ add_special_functions(rb_mSpecial);
13
+ }
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_special_functions(Rice::Module& m);
@@ -41,24 +41,6 @@ using torch::nn::init::NonlinearityType;
41
41
  #define RETURN_NIL \
42
42
  return Qnil;
43
43
 
44
- class OptionalTensor {
45
- torch::Tensor value;
46
- public:
47
- OptionalTensor(Object o) {
48
- if (o.is_nil()) {
49
- value = {};
50
- } else {
51
- value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
52
- }
53
- }
54
- OptionalTensor(torch::Tensor o) {
55
- value = o;
56
- }
57
- operator torch::Tensor() const {
58
- return value;
59
- }
60
- };
61
-
62
44
  namespace Rice::detail
63
45
  {
64
46
  template<>
@@ -131,25 +113,6 @@ namespace Rice::detail
131
113
  }
132
114
  };
133
115
 
134
- template<>
135
- struct Type<OptionalTensor>
136
- {
137
- static bool verify()
138
- {
139
- return true;
140
- }
141
- };
142
-
143
- template<>
144
- class From_Ruby<OptionalTensor>
145
- {
146
- public:
147
- OptionalTensor convert(VALUE x)
148
- {
149
- return OptionalTensor(x);
150
- }
151
- };
152
-
153
116
  template<>
154
117
  struct Type<Scalar>
155
118
  {
data/ext/torch/tensor.cpp CHANGED
@@ -107,7 +107,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
107
107
  ParsedArgs<4> parsed_args;
108
108
  auto _r = parser.parse(self_, argc, argv, parsed_args);
109
109
  // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
110
- auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
110
+ auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
111
111
  // in future, release GVL
112
112
  self._backward(inputs, gradient, retain_graph, create_graph);
113
113
  };
@@ -125,13 +125,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
125
125
  rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
126
126
 
127
127
  rb_cTensor
128
- .define_method("cuda?", &torch::Tensor::is_cuda)
129
- .define_method("sparse?", &torch::Tensor::is_sparse)
130
- .define_method("quantized?", &torch::Tensor::is_quantized)
131
- .define_method("dim", &torch::Tensor::dim)
132
- .define_method("numel", &torch::Tensor::numel)
133
- .define_method("element_size", &torch::Tensor::element_size)
134
- .define_method("requires_grad", &torch::Tensor::requires_grad)
128
+ .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
129
+ .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
130
+ .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
131
+ .define_method("dim", [](Tensor& self) { return self.dim(); })
132
+ .define_method("numel", [](Tensor& self) { return self.numel(); })
133
+ .define_method("element_size", [](Tensor& self) { return self.element_size(); })
134
+ .define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
135
135
  .define_method(
136
136
  "_size",
137
137
  [](Tensor& self, int64_t dim) {
@@ -1,6 +1,8 @@
1
1
  module Torch
2
2
  module NN
3
3
  class ConvNd < Module
4
+ attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_paddding, :groups, :padding_mode
5
+
4
6
  def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
5
7
  super()
6
8
  raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
@@ -0,0 +1,241 @@
1
+ module Torch
2
+ module NN
3
+ class Functional
4
+ class << self
5
# Projects q, k and v using a single packed weight `w` (and optional packed
# bias `b`). Returns the three projected tensors as [q', k', v'].
#
# Three cases, chosen with #eql? on the inputs:
# * k and v differ           - w (and b) are chunked in three, one per input
# * q, k and v are the same  - one linear call, result chunked in three
# * only k and v are the same - w (and b) are split into a q part of size e
#   and a combined k/v part of size 2e, where e = q.size(-1)
def in_projection_packed(q, k, v, w, b: nil)
  embed_dim = q.size(-1)

  unless k.eql?(v)
    # independent key and value: one weight/bias chunk per input
    w_q, w_k, w_v = w.chunk(3)
    b_q, b_k, b_v = b.nil? ? [nil, nil, nil] : b.chunk(3)

    return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
  end

  # self-attention
  return linear(q, w, b).chunk(3, dim: -1) if q.eql?(k)

  # encoder-decoder attention
  w_q, w_kv = w.split_with_sizes([embed_dim, embed_dim * 2])
  b_q, b_kv = b.nil? ? [nil, nil] : b.split_with_sizes([embed_dim, embed_dim * 2])

  [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
end
34
+
35
# Projects q, k and v with separate weights (and optional separate biases).
# Returns [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)].
#
# Each weight must have shape [d, d] and each given bias shape [d], where d is
# the last dimension of the matching input; otherwise ArgumentError is raised.
# All weight shapes are checked before any bias shape.
def in_projection(
  q, k, v,
  w_q, w_k, w_v,
  b_q: nil, b_k: nil, b_v: nil
)
  q_dim = q.size(-1)
  k_dim = k.size(-1)
  v_dim = v.size(-1)

  if w_q.shape != [q_dim, q_dim]
    raise ArgumentError, "Expecting query weights shape of #{[q_dim, q_dim]}, but got #{w_q.shape}"
  end
  if w_k.shape != [k_dim, k_dim]
    raise ArgumentError, "Expecting key weights shape of #{[k_dim, k_dim]}, but got #{w_k.shape}"
  end
  if w_v.shape != [v_dim, v_dim]
    raise ArgumentError, "Expecting value weights shape of #{[v_dim, v_dim]}, but got #{w_v.shape}"
  end

  # biases are optional; only validate the ones that were supplied
  if b_q && b_q.shape != [q_dim]
    raise ArgumentError, "Expecting query bias shape of #{[q_dim]}, but got #{b_q.shape}"
  end
  if b_k && b_k.shape != [k_dim]
    raise ArgumentError, "Expecting key bias shape of #{[k_dim]}, but got #{b_k.shape}"
  end
  if b_v && b_v.shape != [v_dim]
    raise ArgumentError, "Expecting value bias shape of #{[v_dim]}, but got #{b_v.shape}"
  end

  projected_q = linear(q, w_q, b_q)
  projected_k = linear(k, w_k, b_k)
  projected_v = linear(v, w_v, b_v)
  [projected_q, projected_k, projected_v]
end
53
+
54
# Computes attention of q over k and v and returns [output, weights].
#
# q is divided by the square root of the third entry of q.shape, multiplied
# against the transpose of k with Torch.bmm, optionally offset by attn_mask,
# passed through softmax over the last dimension and, when dropout_p > 0,
# through dropout. The resulting weights are then multiplied against v.
def scaled_dot_product_attention(
  q, k, v,
  attn_mask: nil, dropout_p: 0.0
)
  # only the third shape entry is needed; the first two are ignored
  _batch, _tgt_len, embed_dim = q.shape

  scaled_q = q / Math.sqrt(embed_dim)

  weights = Torch.bmm(scaled_q, k.transpose(-2, -1))
  weights += attn_mask if attn_mask
  weights = softmax(weights, dim: -1)
  weights = dropout(weights, p: dropout_p) if dropout_p > 0

  [Torch.bmm(weights, v), weights]
end
72
+
73
# Runs multi-head attention and returns [attn_output, attn_weights].
#
# query.shape is destructured as (tgt_len, bsz, embed_dim) and the first
# dimension of key is taken as the source length.
#
# Steps, in order:
# 1. validate embed_dim against embed_dim_to_check and key/value shapes
# 2. project query/key/value (packed weights, or separate weights when
#    use_separate_proj_weight is true)
# 3. normalise attn_mask (uint8 -> bool, 2D -> 3D) and key_padding_mask
# 4. optionally append bias_k / bias_v and a zero-attention slot
# 5. reshape to (bsz * num_heads, len, head_dim), merge masks, convert a bool
#    mask to a float mask filled with -Infinity
# 6. run scaled_dot_product_attention and the output projection
#
# attn_weights is averaged over heads when need_weights is true, otherwise nil.
# dropout_p is forced to 0.0 when training is false.
#
# Raises ArgumentError on any shape/type mismatch, when use_separate_proj_weight
# is true but a projection weight is missing, when bias_k/bias_v are combined
# with static_k/static_v, or when only one of bias_k/bias_v is given.
def multi_head_attention_forward(
  query, key, value,
  embed_dim_to_check, num_heads,
  in_proj_weight, in_proj_bias,
  bias_k, bias_v,
  add_zero_attn,
  dropout_p,
  out_proj_weight, out_proj_bias,
  training: true,
  key_padding_mask: nil,
  need_weights: true,
  attn_mask: nil,
  use_separate_proj_weight: false,
  q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
  static_k: nil, static_v: nil
)
  tgt_len, bsz, embed_dim = query.shape
  src_len = key.shape.first

  raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check

  # integer division for plain Integers; truncating division for a Tensor
  head_dim =
    if embed_dim.is_a?(Torch::Tensor)
      embed_dim.div(num_heads, rounding_mode: 'trunc')
    else
      embed_dim.div num_heads
    end

  if use_separate_proj_weight
    raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
  else
    raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
  end

  # compute in-projection
  q, k, v =
    if use_separate_proj_weight
      raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
      raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
      raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight

      if in_proj_bias
        b_q, b_k, b_v = in_proj_bias.chunk(3)
      else
        b_q = b_k = b_v = nil
      end

      in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
    else
      in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
    end

  # prep attention mask
  if attn_mask
    if attn_mask.dtype == :uint8
      puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
      attn_mask = attn_mask.bool
    else
      raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
    end

    if attn_mask.dim == 2
      correct_2d_size = [tgt_len, src_len]
      raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size

      attn_mask = attn_mask.unsqueeze(0)
    elsif attn_mask.dim == 3
      correct_3d_size = [bsz * num_heads, tgt_len, src_len]
      raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
    else
      raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
    end
  end

  # prep key padding mask
  if key_padding_mask && key_padding_mask.dtype == :uint8
    puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
    key_padding_mask = key_padding_mask.bool
  end

  # add bias along batch dimension (currently second)
  if bias_k && bias_v
    raise ArgumentError, "bias cannot be added to static key." if static_k
    raise ArgumentError, "bias cannot be added to static value." if static_v

    k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
    v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])

    attn_mask = pad(attn_mask, [0, 1]) if attn_mask
    key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
  else
    # reaching here means at most one of the two biases was supplied
    raise ArgumentError, "bias_k was given without bias_v" unless bias_k.nil?
    raise ArgumentError, "bias_v was given without bias_k" unless bias_v.nil?
  end

  # reshape q, k, v for multihead attention and make em batch first
  q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

  if static_k.nil?
    k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
  else
    raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
    raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim

    k = static_k
  end

  if static_v.nil?
    v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
  else
    raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
    raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim

    v = static_v
  end

  # add zero attention along batch dimension (now first)
  if add_zero_attn
    zero_attn_shape = [bsz * num_heads, 1, head_dim]
    k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
    v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)

    attn_mask = pad(attn_mask, [0, 1]) if attn_mask
    key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
  end

  # update source sequence length after adjustments
  src_len = k.size(1)

  # merge key padding and attention masks
  if key_padding_mask
    raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]

    key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)

    attn_mask =
      if attn_mask.nil?
        key_padding_mask
      elsif attn_mask.dtype == :bool
        attn_mask.logical_or(key_padding_mask)
      else
        attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
      end
  end

  # convert mask to float
  if attn_mask && attn_mask.dtype == :bool
    new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
    attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
  end

  dropout_p = 0.0 unless training

  # (deep breath) calculate attention and out projection
  attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
  attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
  attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

  if need_weights
    # average attention weights over heads
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
  else
    [attn_output, nil]
  end
end
238
+ end
239
+ end
240
+ end
241
+ end
@@ -3,6 +3,8 @@ module Torch
3
3
  class Module
4
4
  include Utils
5
5
 
6
+ attr_reader :training
7
+
6
8
  def initialize
7
9
  @training = true
8
10
  @parameters = {}
@@ -278,6 +280,11 @@ module Torch
278
280
  end
279
281
  end
280
282
 
283
# Returns a deep copy of this module, produced by dup_value starting from
# self with a fresh, empty memo hash.
def deep_dup
  dup_value(self, {})
end
287
+
281
288
  def method_missing(method, *args, &block)
282
289
  name = method.to_s
283
290
  if named_parameters.key?(name)
@@ -386,6 +393,29 @@ module Torch
386
393
  destination[prefix + k] = v
387
394
  end
388
395
  end
396
+
397
# Recursively copies `v`, keeping a memo hash like Python deepcopy
# https://docs.python.org/3/library/copy.html
#
# `memo` maps the object_id of every object already visited to its copy, so
# an object reachable through several paths is copied once and all paths end
# up pointing at the same copy.
#
# Each container copy is registered in `memo` *before* its children are
# visited. Without that, a structure that refers back to itself (directly or
# through a child) would recurse until the stack overflowed.
#
# * Method / UnboundMethod - returned as-is
# * Hash                   - new Hash with copied keys and values
# * Array                  - new Array with copied elements
# * Torch::NN::Module      - shallow dup, then every instance variable is
#                            replaced by its copy
# * anything else          - v.dup
def dup_value(v, memo)
  key = v.object_id
  return memo[key] if memo.key?(key)

  case v
  when Method, UnboundMethod
    memo[key] = v
  when Hash
    copy = memo[key] = {}
    v.each { |k, v2| copy[dup_value(k, memo)] = dup_value(v2, memo) }
    copy
  when Array
    copy = memo[key] = []
    v.each { |v2| copy << dup_value(v2, memo) }
    copy
  when Torch::NN::Module
    copy = memo[key] = v.dup
    v.instance_variables.each do |var|
      copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
    end
    copy
  else
    memo[key] = v.dup
  end
end
389
419
  end
390
420
  end
391
421
  end
@@ -0,0 +1,49 @@
1
module Torch
  module NN
    # An ordered list of sub-modules. Every element is registered through
    # add_module under its position as a string key ("0", "1", ...), and the
    # list is Enumerable over the registered modules.
    # NOTE(review): relies on the superclass providing @modules and
    # add_module - confirm against Torch::NN::Module.
    class ModuleList < Module
      include Enumerable

      # mods - optional collection responding to #each whose elements are
      #        appended in order.
      def initialize(mods = nil)
        super()

        self.concat(mods) if mods
      end

      # Number of registered modules.
      def length
        @modules.length
      end
      alias_method :size, :length

      # With no argument and no block, same as #length. Otherwise defers to
      # Enumerable#count, so count(obj) and count { |m| ... } work instead of
      # being shadowed by a plain alias of #length.
      def count(*args, &block)
        return length if args.empty? && block.nil?

        super
      end

      # Appends every element of mods and returns self.
      # Raises ArgumentError when mods does not respond to #each.
      def concat(mods)
        raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)

        mods.each { |m| append m }

        self
      end

      # Yields each registered module in order; returns an Enumerator when no
      # block is given.
      def each(&block)
        if block_given?
          @modules.values.each(&block)
        else
          to_enum(:each)
        end
      end

      # Registers mod under the next index and returns self.
      # Raises ArgumentError when mod is not a Module.
      def append(mod)
        raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
        add_module(length.to_s, mod)
        self
      end

      # idx is a Range   - returns a new list of the same class holding that slice.
      # idx is otherwise - returns the module stored under idx.to_s, or nil.
      # Negative integers count from the end, as with Array#[].
      def [](idx)
        if idx.is_a?(Range)
          self.class.new(@modules.values[idx])
        else
          # keys are non-negative positions, so translate -1, -2, ... first
          idx += length if idx.is_a?(Integer) && idx < 0
          @modules[idx.to_s]
        end
      end
    end
  end
end