torch-rb 0.20.0 → 0.21.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/device.cpp CHANGED
@@ -1,6 +1,9 @@
+ #include <string>
+
  #include <torch/torch.h>

  #include <rice/rice.hpp>
+ #include <rice/stl.hpp>

  #include "utils.h"

data/ext/torch/ext.cpp CHANGED
@@ -17,8 +17,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
  void init_random(Rice::Module& m);

  extern "C"
- void Init_ext()
- {
+ void Init_ext() {
  auto m = Rice::define_module("Torch");

  // need to define certain classes up front to keep Rice happy
data/ext/torch/ivalue.cpp CHANGED
@@ -1,3 +1,5 @@
+ #include <utility>
+
  #include <torch/torch.h>

  #include <rice/rice.hpp>
data/ext/torch/nn.cpp CHANGED
@@ -1,3 +1,5 @@
+ #include <utility>
+
  #include <torch/torch.h>

  #include <rice/rice.hpp>
@@ -93,7 +95,7 @@ void init_nn(Rice::Module& m) {
  "grad",
  [](Parameter& self) {
  auto grad = self.grad();
- return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
+ return grad.defined() ? Rice::Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
  })
  // can't use grad=
  // assignment methods fail with Ruby 3.0
@@ -1,5 +1,10 @@
  // adapted from PyTorch - python_arg_parser.cpp

+ #include <string>
+ #include <unordered_map>
+ #include <unordered_set>
+ #include <vector>
+
  #include "ruby_arg_parser.h"

  VALUE THPGeneratorClass = Qnil;
@@ -99,7 +104,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
  ruby_name = THPUtils_internSymbol(name);
  auto np_compat_it = numpy_compatibility_arg_names.find(name);
  if (np_compat_it != numpy_compatibility_arg_names.end()) {
- for (const auto& str: np_compat_it->second) {
+ for (const auto& str : np_compat_it->second) {
  numpy_python_names.push_back(THPUtils_internSymbol(str));
  }
  }
@@ -190,8 +195,7 @@ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
  }

  // argnum is needed for raising the TypeError, it's used in the error message.
- auto FunctionParameter::check(VALUE obj, int argnum) -> bool
- {
+ auto FunctionParameter::check(VALUE obj, int argnum) -> bool {
  switch (type_) {
  case ParameterType::TENSOR: {
  if (THPVariable_Check(obj)) {
@@ -3,6 +3,9 @@
  #pragma once

  #include <sstream>
+ #include <unordered_map>
+ #include <string>
+ #include <vector>

  #include <torch/torch.h>
  #include <rice/rice.hpp>
@@ -162,7 +165,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
  Check_Type(arg, T_ARRAY);
  auto size = RARRAY_LEN(arg);
  if (size != N) {
- rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, (int)size);
+ rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
  }
  for (int idx = 0; idx < size; idx++) {
  VALUE obj = rb_ary_entry(arg, idx);
@@ -463,7 +466,7 @@ struct RubyArgParser {
  template<int N>
  inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
  if (N < max_args) {
- rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (int)max_args, N);
+ rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
  }
  return raw_parse(self, argc, argv, dst.args);
  }
@@ -1,13 +1,13 @@
  #pragma once

+ #include <string>
+
  #ifdef isfinite
  #undef isfinite
  #endif

  #include <rice/rice.hpp>

- using namespace Rice;
-
  using torch::Device;
  using torch::Scalar;
  using torch::ScalarType;
@@ -41,55 +41,43 @@ using torch::nn::init::NonlinearityType;
  #define RETURN_NIL \
  return Qnil;

- namespace Rice::detail
- {
+ namespace Rice::detail {
  template<typename T>
- struct Type<c10::complex<T>>
- {
+ struct Type<c10::complex<T>> {
  static bool verify() { return true; }
  };

  template<typename T>
- class To_Ruby<c10::complex<T>>
- {
+ class To_Ruby<c10::complex<T>> {
  public:
- VALUE convert(c10::complex<T> const& x)
- {
+ VALUE convert(c10::complex<T> const& x) {
  return rb_dbl_complex_new(x.real(), x.imag());
  }
  };

  template<typename T>
- class From_Ruby<c10::complex<T>>
- {
+ class From_Ruby<c10::complex<T>> {
  public:
  Convertible is_convertible(VALUE value) { return Convertible::Cast; }

- c10::complex<T> convert(VALUE x)
- {
+ c10::complex<T> convert(VALUE x) {
  VALUE real = rb_funcall(x, rb_intern("real"), 0);
  VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
  return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
  }
  };
- }

- namespace Rice::detail
- {
  template<>
- struct Type<FanModeType>
- {
+ struct Type<FanModeType> {
  static bool verify() { return true; }
  };

  template<>
- class From_Ruby<FanModeType>
- {
+ class From_Ruby<FanModeType> {
  public:
  Convertible is_convertible(VALUE value) { return Convertible::Cast; }

- FanModeType convert(VALUE x)
- {
+ FanModeType convert(VALUE x) {
  auto s = String(x).str();
  if (s == "fan_in") {
  return torch::kFanIn;
@@ -102,19 +90,16 @@ namespace Rice::detail
  };

  template<>
- struct Type<NonlinearityType>
- {
+ struct Type<NonlinearityType> {
  static bool verify() { return true; }
  };

  template<>
- class From_Ruby<NonlinearityType>
- {
+ class From_Ruby<NonlinearityType> {
  public:
  Convertible is_convertible(VALUE value) { return Convertible::Cast; }

- NonlinearityType convert(VALUE x)
- {
+ NonlinearityType convert(VALUE x) {
  auto s = String(x).str();
  if (s == "linear") {
  return torch::kLinear;
@@ -145,19 +130,16 @@ namespace Rice::detail
  };

  template<>
- struct Type<Scalar>
- {
+ struct Type<Scalar> {
  static bool verify() { return true; }
  };

  template<>
- class From_Ruby<Scalar>
- {
+ class From_Ruby<Scalar> {
  public:
  Convertible is_convertible(VALUE value) { return Convertible::Cast; }

- Scalar convert(VALUE x)
- {
+ Scalar convert(VALUE x) {
  if (FIXNUM_P(x)) {
  return torch::Scalar(From_Ruby<int64_t>().convert(x));
  } else {
@@ -165,4 +147,4 @@ namespace Rice::detail
  }
  }
  };
- }
+ } // namespace Rice::detail
data/ext/torch/tensor.cpp CHANGED
@@ -1,3 +1,6 @@
+ #include <string>
+ #include <vector>
+
  #include <torch/torch.h>

  #include <rice/rice.hpp>
@@ -7,7 +10,8 @@
  #include "templates.h"
  #include "utils.h"

- using namespace Rice;
+ using Rice::Array;
+ using Rice::Object;
  using torch::indexing::TensorIndex;

  template<typename T>
@@ -21,7 +25,7 @@ Array flat_data(Tensor& tensor) {
  return a;
  }

- Class rb_cTensor;
+ Rice::Class rb_cTensor;

  std::vector<TensorIndex> index_vector(Array a) {
  Object obj;
@@ -62,10 +66,10 @@ std::vector<TensorIndex> index_vector(Array a) {
  indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
  } else if (obj.is_nil()) {
  indices.push_back(torch::indexing::None);
- } else if (obj == True || obj == False) {
+ } else if (obj == Rice::True || obj == Rice::False) {
  indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
  } else {
- throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+ throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
  }
  }
  return indices;
@@ -75,8 +79,7 @@ std::vector<TensorIndex> index_vector(Array a) {
  // https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
  // TODO add support for inputs argument
  // _backward
- static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
- {
+ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_) {
  HANDLE_TH_ERRORS
  Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
  static RubyArgParser parser({
@@ -165,7 +168,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
  "grad",
  [](Tensor& self) {
  auto grad = self.grad();
- return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
+ return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
  })
  // can't use grad=
  // assignment methods fail with Ruby 3.0
@@ -197,7 +200,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
  .define_method(
  "_dtype",
  [](Tensor& self) {
- return (int) at::typeMetaToScalarType(self.dtype());
+ return static_cast<int>(at::typeMetaToScalarType(self.dtype()));
  })
  .define_method(
  "_type",
data/ext/torch/torch.cpp CHANGED
@@ -1,8 +1,11 @@
+ #include <fstream>
+ #include <string>
+ #include <vector>
+
  #include <torch/torch.h>

  #include <rice/rice.hpp>
-
- #include <fstream>
+ #include <rice/stl.hpp>

  #include "torch_functions.h"
  #include "templates.h"
@@ -60,7 +63,7 @@ void init_torch(Rice::Module& m) {
  "_save",
  [](const torch::IValue &value) {
  auto v = torch::pickle_save(value);
- return Object(rb_str_new(v.data(), v.size()));
+ return Rice::Object(rb_str_new(v.data(), v.size()));
  })
  .define_singleton_function(
  "_load",
data/ext/torch/utils.h CHANGED
@@ -1,12 +1,14 @@
  #pragma once

+ #include <string>
+
  #include <torch/torch.h>

  #include <rice/rice.hpp>
  #include <rice/stl.hpp>

  static_assert(
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 7,
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 8,
  "Incompatible LibTorch version"
  );

@@ -1,9 +1,17 @@
  module Torch
  module NN
  class Conv1d < ConvNd
- def initialize(in_channels, out_channels, kernel_size, stride: 1,
- padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
-
+ def initialize(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride: 1,
+ padding: 0,
+ dilation: 1,
+ groups: 1,
+ bias: true,
+ padding_mode: "zeros"
+ )
  kernel_size = _single(kernel_size)
  stride = _single(stride)
  padding = _single(padding)
@@ -1,9 +1,17 @@
  module Torch
  module NN
  class Conv2d < ConvNd
- def initialize(in_channels, out_channels, kernel_size, stride: 1,
- padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
-
+ def initialize(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride: 1,
+ padding: 0,
+ dilation: 1,
+ groups: 1,
+ bias: true,
+ padding_mode: "zeros"
+ )
  kernel_size = _pair(kernel_size)
  stride = _pair(stride)
  padding = _pair(padding)
@@ -1,9 +1,17 @@
  module Torch
  module NN
  class Conv3d < ConvNd
- def initialize(in_channels, out_channels, kernel_size, stride: 1,
- padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
-
+ def initialize(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride: 1,
+ padding: 0,
+ dilation: 1,
+ groups: 1,
+ bias: true,
+ padding_mode: "zeros"
+ )
  kernel_size = _triple(kernel_size)
  stride = _triple(stride)
  padding = _triple(padding)
@@ -1,7 +1,7 @@
  module Torch
  module NN
  class ConvNd < Module
- attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_paddding, :groups, :padding_mode
+ attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_padding, :groups, :padding_mode

  def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
  super()
@@ -2,9 +2,16 @@
  module Torch
  module NN
  class Embedding < Module
- def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil,
- norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)
-
+ def initialize(
+ num_embeddings,
+ embedding_dim,
+ padding_idx: nil,
+ max_norm: nil,
+ norm_type: 2.0,
+ scale_grad_by_freq: false,
+ sparse: false,
+ _weight: nil
+ )
  super()
  @num_embeddings = num_embeddings
  @embedding_dim = embedding_dim
@@ -2,9 +2,16 @@
  module Torch
  module NN
  class EmbeddingBag < Module
- def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
- scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
-
+ def initialize(
+ num_embeddings,
+ embedding_dim,
+ max_norm: nil,
+ norm_type: 2.0,
+ scale_grad_by_freq: false,
+ mode: "mean",
+ sparse: false,
+ _weight: nil
+ )
  super()
  @num_embeddings = num_embeddings
  @embedding_dim = embedding_dim
@@ -250,9 +250,16 @@ module Torch

  # normalization layers

- def batch_norm(input, running_mean, running_var, weight: nil, bias: nil,
- training: false, momentum: 0.1, eps: 1e-5)
-
+ def batch_norm(
+ input,
+ running_mean,
+ running_var,
+ weight: nil,
+ bias: nil,
+ training: false,
+ momentum: 0.1,
+ eps: 1e-5
+ )
  if training
  size = input.size
  size_prods = size[0]
@@ -274,9 +281,16 @@ module Torch
  Torch.group_norm(input, num_groups, weight, bias, eps, false)
  end

- def instance_norm(input, running_mean: nil, running_var: nil, weight: nil,
- bias: nil, use_input_stats: true, momentum: 0.1, eps: 1e-5)
-
+ def instance_norm(
+ input,
+ running_mean: nil,
+ running_var: nil,
+ weight: nil,
+ bias: nil,
+ use_input_stats: true,
+ momentum: 0.1,
+ eps: 1e-5
+ )
  Torch.instance_norm(
  input, weight, bias, running_mean, running_var,
  use_input_stats, momentum, eps, false
@@ -33,11 +33,16 @@ module Torch
  end

  def in_projection(
- q, k, v,
- w_q, w_k, w_v,
- b_q: nil, b_k: nil, b_v: nil
+ q,
+ k,
+ v,
+ w_q,
+ w_k,
+ w_v,
+ b_q: nil,
+ b_k: nil,
+ b_v: nil
  )
-
  e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)

  raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
@@ -52,10 +57,12 @@ module Torch
  end

  def scaled_dot_product_attention(
- q, k, v,
- attn_mask: nil, dropout_p: 0.0
+ q,
+ k,
+ v,
+ attn_mask: nil,
+ dropout_p: 0.0
  )
-
  _b, _nt, e = q.shape

  q = q / Math.sqrt(e)
@@ -71,22 +78,30 @@ module Torch
  end

  def multi_head_attention_forward(
- query, key, value,
- embed_dim_to_check, num_heads,
- in_proj_weight, in_proj_bias,
- bias_k, bias_v,
+ query,
+ key,
+ value,
+ embed_dim_to_check,
+ num_heads,
+ in_proj_weight,
+ in_proj_bias,
+ bias_k,
+ bias_v,
  add_zero_attn,
  dropout_p,
- out_proj_weight, out_proj_bias,
+ out_proj_weight,
+ out_proj_bias,
  training: true,
  key_padding_mask: nil,
  need_weights: true,
  attn_mask: nil,
  use_separate_proj_weight: false,
- q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
- static_k: nil, static_v: nil
+ q_proj_weight: nil,
+ k_proj_weight: nil,
+ v_proj_weight: nil,
+ static_k: nil,
+ static_v: nil
  )
-
  tgt_len, bsz, embed_dim = query.shape
  src_len = key.shape.first

@@ -2,11 +2,18 @@ module Torch
  module NN
  class MultiheadAttention < Module
  def initialize(
- embed_dim, num_heads,
- dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
- kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
+ embed_dim,
+ num_heads,
+ dropout: 0.0,
+ bias: true,
+ add_bias_kv: false,
+ add_zero_attn: false,
+ kdim: nil,
+ vdim: nil,
+ batch_first: false,
+ device: nil,
+ dtype: nil
  )
-
  super()

  @embed_dim = embed_dim
@@ -77,10 +84,13 @@ module Torch
  end

  def forward(
- query, key, value,
- key_padding_mask: nil, need_weights: true, attn_mask: nil
+ query,
+ key,
+ value,
+ key_padding_mask: nil,
+ need_weights: true,
+ attn_mask: nil
  )
-
  if batch_first?
  query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
  end
@@ -1,9 +1,16 @@
  module Torch
  module NN
  class RNNBase < Module
- def initialize(mode, input_size, hidden_size, num_layers: 1, bias: true,
- batch_first: false, dropout: 0.0, bidirectional: false)
-
+ def initialize(
+ mode,
+ input_size,
+ hidden_size,
+ num_layers: 1,
+ bias: true,
+ batch_first: false,
+ dropout: 0.0,
+ bidirectional: false
+ )
  super()
  @mode = mode
  @input_size = input_size
@@ -7,13 +7,18 @@ module Torch
  module NN
  class Transformer < Module
  def initialize(
- d_model: 512, nhead: 8,
- num_encoder_layers: 6, num_decoder_layers: 6,
- dim_feedforward: 2048, dropout: 0.1, activation: :relu,
- custom_encoder: nil, custom_decoder: nil,
- layer_norm_eps: 1e-5, batch_first: false
+ d_model: 512,
+ nhead: 8,
+ num_encoder_layers: 6,
+ num_decoder_layers: 6,
+ dim_feedforward: 2048,
+ dropout: 0.1,
+ activation: :relu,
+ custom_encoder: nil,
+ custom_decoder: nil,
+ layer_norm_eps: 1e-5,
+ batch_first: false
  )
-
  super()

  @encoder =
@@ -60,11 +65,15 @@ module Torch
  end

  def forward(
- src, tgt,
- src_mask: nil, tgt_mask: nil, memory_mask: nil,
- src_key_padding_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil
+ src,
+ tgt,
+ src_mask: nil,
+ tgt_mask: nil,
+ memory_mask: nil,
+ src_key_padding_mask: nil,
+ tgt_key_padding_mask: nil,
+ memory_key_padding_mask: nil
  )
-
  if (!batch_first? && src.size(1) != tgt.size(1)) ||
  (batch_first? && src.size(0) != tgt.size(0))