torch-rb 0.6.0 → 0.8.2
This diff shows the changes between publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/README.md +23 -41
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +43 -6
- data/codegen/native_functions.yaml +2007 -1327
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +22 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +23 -12
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +111 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +2 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +16 -0
- data/lib/torch/tensor.rb +2 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +11 -0
- metadata +20 -5
data/ext/torch/torch.cpp
CHANGED
@@ -1,6 +1,6 @@
 #include <torch/torch.h>
 
-#include <rice/
+#include <rice/rice.hpp>
 
 #include "torch_functions.h"
 #include "templates.h"
@@ -9,69 +9,78 @@
 void init_torch(Rice::Module& m) {
   m.add_handler<torch::Error>(handle_error);
   add_torch_functions(m);
-  m.
+  m.define_singleton_function(
       "grad_enabled?",
-
+      []() {
         return torch::GradMode::is_enabled();
       })
-    .
+    .define_singleton_function(
       "_set_grad_enabled",
-
+      [](bool enabled) {
        torch::GradMode::set_enabled(enabled);
       })
-    .
+    .define_singleton_function(
       "manual_seed",
-
+      [](uint64_t seed) {
        return torch::manual_seed(seed);
       })
     // config
-    .
+    .define_singleton_function(
       "show_config",
-
+      [] {
        return torch::show_config();
       })
-    .
+    .define_singleton_function(
       "parallel_info",
-
+      [] {
        return torch::get_parallel_info();
       })
     // begin operations
-    .
+    .define_singleton_function(
       "_save",
-
+      [](const torch::IValue &value) {
        auto v = torch::pickle_save(value);
        std::string str(v.begin(), v.end());
        return str;
       })
-    .
+    .define_singleton_function(
       "_load",
-
+      [](const std::string &s) {
        std::vector<char> v;
        std::copy(s.begin(), s.end(), std::back_inserter(v));
        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
        return torch::pickle_load(v);
       })
-    .
+    .define_singleton_function(
       "_from_blob",
-
+      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
       })
-    .
+    .define_singleton_function(
       "_tensor",
-
+      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
        auto dtype = options.dtype();
        torch::Tensor t;
        if (dtype == torch::kBool) {
          std::vector<uint8_t> vec;
          for (long i = 0; i < a.size(); i++) {
-            vec.push_back(
+            vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
+          }
+          t = torch::tensor(vec, options);
+        } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
+          // TODO use template
+          std::vector<c10::complex<double>> vec;
+          Object obj;
+          for (long i = 0; i < a.size(); i++) {
+            obj = a[i];
+            vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
          }
          t = torch::tensor(vec, options);
        } else {
          std::vector<float> vec;
          for (long i = 0; i < a.size(); i++) {
-            vec.push_back(
+            vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
          }
          // hack for requires_grad error
          if (options.requires_grad()) {
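The new `kComplexFloat`/`kComplexDouble` branch in `_tensor` is what lets Ruby `Complex` values be converted element-wise. A minimal sketch of how that surfaces in Ruby, assuming the 0.8.x gem API (exact printed output may differ):

```ruby
require "torch"

# Each Ruby Complex responds to #real and #imag, which the new
# c10::complex<double> branch reads via obj.call("real") / obj.call("imag").
x = Torch.tensor([Complex(1, 2), Complex(3, -4)], dtype: :complex64)
p x.dtype  # expected: :complex64
p x        # printed via the complex formatter in inspector.rb (see below)
```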
data/ext/torch/utils.h
CHANGED
@@ -1,11 +1,10 @@
 #pragma once
 
-#include <rice/
-#include <rice/
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
 
 // TODO find better place
-inline void handle_error(torch::Error const & ex)
-{
+inline void handle_error(torch::Error const & ex) {
   throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
 
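`handle_error` is what turns a C++ `torch::Error` into a plain Ruby `RuntimeError`. A small sketch of how that surfaces, assuming a standard shape-mismatch error from libtorch:

```ruby
require "torch"

begin
  # incompatible shapes raise torch::Error inside libtorch,
  # which handle_error re-raises as a Ruby RuntimeError
  Torch.ones([2, 2]) + Torch.ones([3, 3])
rescue RuntimeError => e
  puts e.message
end
```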
data/ext/torch/wrap_outputs.h
CHANGED
@@ -1,99 +1,106 @@
 #pragma once
 
 #include <torch/torch.h>
-#include <rice/
+#include <rice/rice.hpp>
 
-inline
-  return
+inline VALUE wrap(bool x) {
+  return Rice::detail::To_Ruby<bool>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(int64_t x) {
+  return Rice::detail::To_Ruby<int64_t>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(double x) {
+  return Rice::detail::To_Ruby<double>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(torch::Tensor x) {
+  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(torch::Scalar x) {
+  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(torch::ScalarType x) {
+  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(torch::QScheme x) {
+  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
 }
 
-inline
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+  );
 }
 
-inline
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    3,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+  );
 }
 
-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+  );
 }
 
-inline
-
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    5,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+  );
 }
 
-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline
-
-  for (auto
-    a
+inline VALUE wrap(torch::TensorList x) {
+  auto a = rb_ary_new2(x.size());
+  for (auto t : x) {
+    rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
   }
-  return
+  return a;
 }
 
-inline
-
-
-
-
+inline VALUE wrap(std::tuple<double, double> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+  );
 }
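These `wrap` overloads are what make multi-output operations come back to Ruby as plain arrays. A hedged usage sketch, assuming `sort` is among the generated tensor methods:

```ruby
require "torch"

x = Torch.tensor([[3.0, 1.0, 2.0]])

# sort returns a (values, indices) tuple in C++; the
# wrap(std::tuple<Tensor, Tensor>) overload turns it into a 2-element array
values, indices = x.sort
p values
p indices
```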
data/lib/torch/inspector.rb
CHANGED
@@ -96,8 +96,11 @@ module Torch
           ret = "%.#{PRINT_OPTS[:precision]}f" % value
         end
       elsif @complex_dtype
-
-
+        # TODO use float formatter for each part
+        precision = PRINT_OPTS[:precision]
+        imag = value.imag
+        sign = imag >= 0 ? "+" : "-"
+        ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
       else
         ret = value.to_s
       end
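The new branch renders each complex entry as `real±imagi` at the configured precision. A standalone Ruby sketch of the same expression, with hypothetical values just to show the output shape:

```ruby
precision = 4
value = Complex(1.5, -2.25)

# mirror the formatting logic from inspector.rb above
sign = value.imag >= 0 ? "+" : "-"
ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
puts ret  # => "1.5000-2.2500i"
```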
data/lib/torch/nn/convnd.rb
CHANGED
@@ -1,6 +1,8 @@
 module Torch
   module NN
     class ConvNd < Module
+      attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_paddding, :groups, :padding_mode
+
       def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
         super()
         raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
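The new readers above simply expose the constructor arguments on every conv layer. A quick sketch, assuming `Conv2d` (which subclasses `ConvNd`); exact printed forms may differ:

```ruby
require "torch"

conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2)
p conv.in_channels   # expected: 3
p conv.out_channels  # expected: 16
p conv.kernel_size   # expected: [3, 3]
p conv.stride        # expected: [2, 2]
```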
data/lib/torch/nn/functional_attention.rb
ADDED
@@ -0,0 +1,241 @@
+module Torch
+  module NN
+    class Functional
+      class << self
+        def in_projection_packed(q, k, v, w, b: nil)
+          e = q.size(-1)
+
+          if k.eql? v
+            if q.eql? k
+              # self-attention
+              return linear(q, w, b).chunk(3, dim: -1)
+            else
+              # encoder-decoder attention
+              w_q, w_kv = w.split_with_sizes([e, e * 2])
+              if b.nil?
+                b_q = b_kv = nil
+              else
+                b_q, b_kv = b.split_with_sizes([e, e * 2])
+              end
+
+              return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+            end
+          else
+            w_q, w_k, w_v = w.chunk(3)
+            if b.nil?
+              b_q = b_k = b_v = nil
+            else
+              b_q, b_k, b_v = b.chunk(3)
+            end
+
+            return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+          end
+        end
+
+        def in_projection(
+          q, k, v,
+          w_q, w_k, w_v,
+          b_q: nil, b_k: nil, b_v: nil
+        )
+
+          e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
+
+          raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
+          raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
+          raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
+
+          raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
+          raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
+          raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
+
+          [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+        end
+
+        def scaled_dot_product_attention(
+          q, k, v,
+          attn_mask: nil, dropout_p: 0.0
+        )
+
+          b, nt, e = q.shape
+
+          q = q / Math.sqrt(e)
+
+          attn = Torch.bmm(q, k.transpose(-2, -1))
+          attn += attn_mask if attn_mask
+          attn = softmax(attn, dim: -1)
+          attn = dropout(attn, p: dropout_p) if dropout_p > 0
+
+          output = Torch.bmm(attn, v)
+
+          [output, attn]
+        end
+
+        def multi_head_attention_forward(
+          query, key, value,
+          embed_dim_to_check, num_heads,
+          in_proj_weight, in_proj_bias,
+          bias_k, bias_v,
+          add_zero_attn,
+          dropout_p,
+          out_proj_weight, out_proj_bias,
+          training: true,
+          key_padding_mask: nil,
+          need_weights: true,
+          attn_mask: nil,
+          use_separate_proj_weight: false,
+          q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
+          static_k: nil, static_v: nil
+        )
+
+          tgt_len, bsz, embed_dim = query.shape
+          src_len = key.shape.first
+
+          raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
+
+          head_dim = if embed_dim.is_a?(Torch::Tensor)
+            embed_dim.div(num_heads, rounding_mode: 'trunc')
+          else
+            head_dim = embed_dim.div num_heads
+          end
+
+          if use_separate_proj_weight
+            raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
+          else
+            raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
+          end
+
+          # compute in-projection
+          q, k, v =
+            if use_separate_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
+
+              if in_proj_bias
+                b_q, b_k, b_v = in_proj_bias.chunk(3)
+              else
+                b_q = b_k = b_v = nil
+              end
+
+              in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
+            else
+              in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
+            end
+
+          # prep attention mask
+          if attn_mask
+            if attn_mask.dtype == :uint8
+              puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
+              attn_mask = attn_mask.bool
+            else
+              raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
+            end
+
+            if attn_mask.dim == 2
+              correct_2d_size = [tgt_len, src_len]
+              raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
+
+              attn_mask = attn_mask.unsqueeze(0)
+            elsif attn_mask.dim == 3
+              correct_3d_size = [bsz * num_heads, tgt_len, src_len]
+              raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
+            else
+              raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
+            end
+          end
+
+          # prep key padding mask
+          if key_padding_mask && key_padding_mask.dtype == :uint8
+            puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
+            key_padding_mask = key_padding_mask.bool
+          end
+
+          # add bias along batch dimension (currently second)
+          if bias_k && bias_v
+            raise ArgumentError, "bias cannot be added to static key." if static_k
+            raise ArgumentError, "bias cannot be added to static value." if static_v
+
+            k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+            v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          else
+            raise ArgumentError unless bias_k.nil?
+            raise ArgumentError unless bias_v.nil?
+          end
+
+          # reshape q, k, v for multihead attention and make em batch first
+          q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+
+          if static_k.nil?
+            k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
+
+            k = static_k
+          end
+
+          if static_v.nil?
+            v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
+
+            v = static_v
+          end
+
+          # add zero attention along batch dimension (now first)
+          if add_zero_attn
+            zero_attn_shape = [bsz * num_heads, 1, head_dim]
+            k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
+            v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          end
+
+          # update source sequence length after adjustments
+          src_len = k.size(1)
+
+          # merge key padding and attention masks
+          if key_padding_mask
+            raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
+
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+
+            attn_mask = if attn_mask.nil?
+              key_padding_mask
+            elsif attn_mask.dtype == :bool
+              attn_mask.logical_or(key_padding_mask)
+            else
+              attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
+            end
+          end
+
+          # convert mask to float
+          if attn_mask && attn_mask.dtype == :bool
+            new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
+            attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
+          end
+
+          dropout_p = 0.0 unless training
+
+          # (deep breath) calculate attention and out projection
+          attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
+          attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
+          attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+          if need_weights
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+            [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
+          else
+            [attn_output, nil]
+          end
+        end
+      end
+    end
+  end
+end
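A hedged usage sketch of the new functional attention, calling `scaled_dot_product_attention` directly with random batched inputs shaped `[batch, seq_len, embed_dim]`:

```ruby
require "torch"

q = Torch.randn(2, 5, 8)
k = Torch.randn(2, 5, 8)
v = Torch.randn(2, 5, 8)

# returns [output, attention_weights]
output, attn = Torch::NN::Functional.scaled_dot_product_attention(q, k, v)
p output.shape  # expected: [2, 5, 8]
p attn.shape    # expected: [2, 5, 5]
```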
data/lib/torch/nn/module.rb
CHANGED
data/lib/torch/nn/module_list.rb
ADDED
@@ -0,0 +1,49 @@
+module Torch
+  module NN
+    class ModuleList < Module
+      include Enumerable
+
+      def initialize(mods = nil)
+        super()
+
+        self.concat(mods) if mods
+      end
+
+      def length
+        @modules.length
+      end
+      alias_method :count, :length
+      alias_method :size, :length
+
+      def concat(mods)
+        raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
+
+        mods.each { |m| append m }
+
+        self
+      end
+
+      def each(&block)
+        if block_given?
+          @modules.values.each(&block)
+        else
+          to_enum(:each)
+        end
+      end
+
+      def append(mod)
+        raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
+        add_module(length.to_s, mod)
+        self
+      end
+
+      def [](idx)
+        if idx.is_a?(Range)
+          self.class.new(@modules.values[idx])
+        else
+          @modules[idx.to_s]
+        end
+      end
+    end
+  end
+end
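A short usage sketch of the new `ModuleList` inside a custom module (the `Block` class here is hypothetical, a minimal example rather than anything from the gem's docs):

```ruby
require "torch"

class Block < Torch::NN::Module
  def initialize
    super()
    # ModuleList registers each layer under its index via add_module
    @layers = Torch::NN::ModuleList.new([
      Torch::NN::Linear.new(10, 10),
      Torch::NN::Linear.new(10, 10)
    ])
  end

  def forward(x)
    @layers.each { |layer| x = layer.call(x) }
    x
  end
end

p Block.new.forward(Torch.randn(1, 10)).shape  # expected: [1, 10]
```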