secryst 0.1.0

@@ -0,0 +1,7 @@
+ ---
+ SHA256:
+   metadata.gz: 919071d9eb29b220b762212fafe8f7f7ace0f0605cde7c0402dbbe8eac2584c6
+   data.tar.gz: de1b20e79c10d14261c71e0043e82a196b38b531da2fcdec0db44bcb992c7239
+ SHA512:
+   metadata.gz: 94bfb8aff36218341cd5baa3bcb0eb48e5be147c54f4a8b072e25a82ae027c07bacc7769fafb7f657cff9f638f9a0a4a21385108f0a45384bdf1a7796c74ea67
+   data.tar.gz: a1fee058281400822908c30a2d5fcf27ceb320dfed8d78878986460323d4028e78fa89db1579146870864fd2104238ee038e9f7468710863ed16ab6c8652ce03
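
These are the published digests for the gem's payload files. As a quick sanity check, a minimal Ruby sketch (assuming `metadata.gz` and `data.tar.gz` have been extracted from `secryst-0.1.0.gem`, which is a plain tar archive) could compare local copies against the SHA256 values above:

[source,ruby]
----
require "digest"

# Expected SHA256 digests, copied from checksums.yaml.gz above.
expected = {
  "metadata.gz" => "919071d9eb29b220b762212fafe8f7f7ace0f0605cde7c0402dbbe8eac2584c6",
  "data.tar.gz" => "de1b20e79c10d14261c71e0043e82a196b38b531da2fcdec0db44bcb992c7239",
}

expected.each do |file, digest|
  actual = Digest::SHA256.file(file).hexdigest
  puts "#{file}: #{actual == digest ? 'OK' : 'MISMATCH'}"
end
----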
@@ -0,0 +1,103 @@
+ = Secryst
+
+ image:https://github.com/secryst/secryst/workflows/test/badge.svg["Build status", link="https://github.com/secryst/secryst/actions?workflow=test"]
+
+ == Purpose
+
+ A seq2seq transformer suited for transliteration. Written in Ruby.
+
+ Secryst was originally built for the
+ https://www.interscript.com[Interscript project]
+ (https://github.com/secryst/secryst[at GitHub]).
+
+ The goal is to allow:
+
+ * Developers to train models and provide the trained models to users. In order to train models, raw computing libraries and their bindings (e.g. OpenCL) can be used.
+
+ * Users of the library in Ruby, who only want to run the trained models, to do so without requiring any special bindings.
+
+
+ == Status
+
+ Currently, Secryst works with the Khmer romanization system cited in the References below.
+
+
+ == Prerequisites
+
+ * Ruby 2.7 (*MUST* - 2.6 does not work with the latest torch-rb)
+
+ * `libtorch` (1.6.0)
+ * `fftw`
+ * `gsl`
+ * `lapack`
+ * `openblas`
+
+
+ On Ubuntu:
+
+ [source,sh]
+ ----
+ $ sudo apt-get -y install libfftw3-dev libgsl-dev libopenblas-dev \
+     liblapack-dev liblapacke-dev unzip automake make gcc g++ \
+     libtorch libtorch-dev
+ $ wget https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
+ $ unzip libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
+
+ $ gem install bundler -v "~> 2"
+ $ bundle config build.torch-rb \
+     --with-torch-dir=$(pwd)/libtorch
+
+ $ bundle install
+ ----
+
+
+ On macOS:
+
+ [source,sh]
+ ----
+ $ brew install libtorch gsl lapack openblas fftw automake gcc
+
+ $ gem install bundler -v "~> 2"
+ $ bundle config build.numo-linalg \
+     --with-openblas-dir=/usr/local/opt/openblas \
+     --with-lapack-lib=/usr/local/opt/lapack
+
+ $ bundle install
+ ----
+
+
+ NOTE: (macOS) If you mistakenly installed `numo-linalg` without the
+ configuration options above, uninstall it as follows, then re-run the
+ `bundle config` and `bundle install` steps described above:
+
+ [source,sh]
+ ----
+ $ bundle exec gem uninstall numo-linalg
+ ----
+
+
+
+ == References
+
+ Secryst is built on the transformer model with architecture
+ based on:
+
+ * Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin.
+ Attention is all you need. 2017. In:
+ _Advances in Neural Information Processing Systems_, pages 6000-6010.
+
+ The sample transliteration system implemented is the Khmer system:
+
+ * https://viblo.asia/p/nlp-khmer-word-segmentation-YWOZrgNNlQ0
+ * https://viblo.asia/p/nlp-khmer-romanization-using-seq2seq-m68Z07OQKkG
+
+
+ == Origin of name
+
+ Scrying is the practice of peering into a crystal sphere for fortune telling.
+ Using `seq2seq` is much like scrying: you look into a crystal sphere and wait
+ for some machine-learning magic to happen.
+
+ "`Secryst`" comes from the combination of "`seq2seq`" + "`crystal`" + "`scrying`".
@@ -0,0 +1,8 @@
+ require "secryst"
+
+ # Numo::Linalg (OpenBLAS backend) is only needed for training
+ require "numo/linalg/use/openblas"
+
+ # training components
+ require "secryst/clip_grad_norm"
+ require "secryst/trainer"
@@ -0,0 +1,11 @@
+ require "json"
+
+ # torch
+ require "torch-rb"
+
+ # transformer model
+ require "secryst/multihead_attention"
+ require "secryst/vocab"
+ require "secryst/transformer"
+
+ require "secryst/translator"
@@ -0,0 +1,25 @@
+ # ported from https://pytorch.org/docs/master/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
+
+ module Secryst
+   class ClipGradNorm < Torch::NN::F
+     def self.clip_grad_norm(parameters, max_norm:, norm_type: 2)
+       parameters = parameters.select { |p| p.grad }
+       max_norm = max_norm.to_f
+       if parameters.length == 0
+         return Torch.tensor(0.0)
+       end
+       device = parameters[0].grad.device
+       if norm_type == Float::INFINITY
+         # ... TODO
+       else
+         total_norm = Numo::Linalg.norm(Numo::NArray.concatenate(parameters.map { |p| Numo::Linalg.norm(p.grad.detach.numo, norm_type) }), norm_type)
+       end
+       clip_coef = max_norm / (total_norm + 1e-6)
+       if clip_coef < 1
+         parameters.each { |p| p.grad = p.grad.detach * clip_coef }
+       end
+
+       return total_norm
+     end
+   end
+ end
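
A hedged usage sketch of the helper above, placed between the backward pass and the optimizer step of an ordinary torch-rb training loop. The `Torch::NN::Linear` model and MSE loss are stand-ins rather than Secryst's transformer; the call matches the `ClipGradNorm.clip_grad_norm` signature as defined above.

[source,ruby]
----
require "secryst"
require "numo/linalg/use/openblas"   # ClipGradNorm computes norms via Numo::Linalg
require "secryst/clip_grad_norm"

# Stand-in model and data; in Secryst this would be the seq2seq transformer.
model     = Torch::NN::Linear.new(8, 4)
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.01)
criterion = Torch::NN::MSELoss.new
x = Torch.randn(16, 8)
y = Torch.randn(16, 4)

optimizer.zero_grad
loss = criterion.call(model.call(x), y)
loss.backward

# Rescale gradients in place so their total L2 norm does not exceed max_norm,
# then take the optimizer step on the clipped gradients.
Secryst::ClipGradNorm.clip_grad_norm(model.parameters, max_norm: 1.0)
optimizer.step
----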
@@ -0,0 +1,288 @@
+ module Secryst
+   class MultiHeadAttentionForward < Torch::NN::F
+     # Args:
+     #   query, key, value: map a query and a set of key-value pairs to an output.
+     #     See "Attention Is All You Need" for more details.
+     #   embed_dim_to_check: total dimension of the model.
+     #   num_heads: parallel attention heads.
+     #   in_proj_weight, in_proj_bias: input projection weight and bias.
+     #   bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+     #   add_zero_attn: add a new batch of zeros to the key and
+     #     value sequences at dim=1.
+     #   dropout_p: probability of an element to be zeroed.
+     #   out_proj_weight, out_proj_bias: the output projection weight and bias.
+     #   training: apply dropout if ``true``.
+     #   key_padding_mask: if provided, specified padding elements in the key will
+     #     be ignored by the attention. This is a binary mask. When the value is true,
+     #     the corresponding value on the attention layer will be filled with -inf.
+     #   need_weights: output attn_output_weights.
+     #   attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast for all
+     #     the batches, while a 3D mask allows specifying a different mask for the entries of each batch.
+     #   use_separate_proj_weight: the function accepts the projection weights for query, key,
+     #     and value in different forms. If false, in_proj_weight will be used, which is
+     #     a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+     #   q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+     #   static_k, static_v: static key and value used for attention operators.
+     # Shape:
+     #   Inputs:
+     #   - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+     #     the embedding dimension.
+     #   - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+     #     the embedding dimension.
+     #   - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+     #     the embedding dimension.
+     #   - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+     #     If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+     #     will be unchanged. If a BoolTensor is provided, the positions with the
+     #     value of ``true`` will be ignored while the positions with the value of ``false`` will be unchanged.
+     #   - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+     #     3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+     #     S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+     #     positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+     #     while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``true``
+     #     are not allowed to attend while ``false`` values will be unchanged. If a FloatTensor
+     #     is provided, it will be added to the attention weight.
+     #   - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+     #     N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+     #   - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+     #     N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+     #   Outputs:
+     #   - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+     #     E is the embedding dimension.
+     #   - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+     #     L is the target sequence length, S is the source sequence length.
+     def self.multi_head_attention_forward(query,
+         key,
+         value,
+         embed_dim_to_check,
+         num_heads,
+         in_proj_weight,
+         in_proj_bias,
+         bias_k,
+         bias_v,
+         add_zero_attn,
+         dropout_p,
+         out_proj_weight,
+         out_proj_bias,
+         training: true,
+         key_padding_mask: nil,
+         need_weights: true,
+         attn_mask: nil,
+         use_separate_proj_weight: false,
+         q_proj_weight: nil,
+         k_proj_weight: nil,
+         v_proj_weight: nil,
+         static_k: nil,
+         static_v: nil)
+       tgt_len, bsz, embed_dim = query.size()
+       raise ArgumentError if embed_dim != embed_dim_to_check
+       # allow MHA to have different sizes for the feature dimension
+       raise ArgumentError if key.size(0) != value.size(0) || key.size(1) != value.size(1)
+
+       head_dim = embed_dim / num_heads
+       raise ArgumentError, "embed_dim must be divisible by num_heads" if head_dim * num_heads != embed_dim
+       scaling = head_dim.to_f ** -0.5
+
+       if !use_separate_proj_weight
+         if Torch.equal(query, key) && Torch.equal(key, value)
+           # self-attention
+           q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, -1)
+
+         elsif Torch.equal(key, value)
+           # encoder-decoder attention
+           # This is inline in_proj function with in_proj_weight and in_proj_bias
+           _b = in_proj_bias
+           _start = 0
+           _end = embed_dim
+           _w = in_proj_weight.slice(0, _start, _end) # NOTE: inc-trspl
+           if _b
+             _b = _b.slice(0, _start, _end)
+           end
+           q = linear(query, _w, _b)
+
+           if !key
+             raise ArgumentError if value
+             k = nil
+             v = nil
+           else
+             # This is inline in_proj function with in_proj_weight and in_proj_bias
+             _b = in_proj_bias
+             _start = embed_dim
+             _end = nil
+             _w = in_proj_weight.slice(0, _start)
+             if _b
+               _b = _b.slice(0, _start)
+             end
+             k, v = linear(key, _w, _b).chunk(2, -1)
+           end
+
+         else
+           # This is inline in_proj function with in_proj_weight and in_proj_bias
+           _b = in_proj_bias
+           _start = 0
+           _end = embed_dim
+           _w = in_proj_weight.slice(0, _start, _end)
+           if _b
+             _b = _b.slice(0, _start, _end)
+           end
+           q = linear(query, _w, _b)
+
+           # This is inline in_proj function with in_proj_weight and in_proj_bias
+           _b = in_proj_bias
+           _start = embed_dim
+           _end = embed_dim * 2
+           _w = in_proj_weight.slice(0, _start, _end)
+           if _b
+             _b = _b.slice(0, _start, _end)
+           end
+           k = linear(key, _w, _b)
+
+           # This is inline in_proj function with in_proj_weight and in_proj_bias
+           _b = in_proj_bias
+           _start = embed_dim * 2
+           _end = nil
+           _w = in_proj_weight.slice(0, _start)
+           if _b
+             _b = _b.slice(0, _start)
+           end
+           v = linear(value, _w, _b)
+         end
+       else
+         q_proj_weight_non_opt = q_proj_weight
+         len1, len2 = q_proj_weight_non_opt.size()
+         raise ArgumentError if len1 != embed_dim || len2 != query.size(-1)
+
+         k_proj_weight_non_opt = k_proj_weight
+         len1, len2 = k_proj_weight_non_opt.size()
+         raise ArgumentError if len1 != embed_dim || len2 != key.size(-1)
+
+         v_proj_weight_non_opt = v_proj_weight
+         len1, len2 = v_proj_weight_non_opt.size()
+         raise ArgumentError if len1 != embed_dim || len2 != value.size(-1)
+
+         if in_proj_bias
+           q = linear(query, q_proj_weight_non_opt, in_proj_bias.slice(0, 0, embed_dim))
+           k = linear(key, k_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim, embed_dim * 2))
+           v = linear(value, v_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim * 2))
+         else
+           q = linear(query, q_proj_weight_non_opt, in_proj_bias)
+           k = linear(key, k_proj_weight_non_opt, in_proj_bias)
+           v = linear(value, v_proj_weight_non_opt, in_proj_bias)
+         end
+       end
+       q = q * scaling
+
+       if attn_mask
+         raise ArgumentError, 'Only float, byte, and bool types are supported for attn_mask, not %s' % attn_mask.dtype unless attn_mask.dtype == Torch.float32 || attn_mask.dtype == Torch.float64 || attn_mask.dtype == Torch.float16 || attn_mask.dtype == Torch.uint8 || attn_mask.dtype == Torch.bool
+         if attn_mask.dtype == Torch.uint8
+           puts "Byte tensor for attn_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead."
+           attn_mask = attn_mask.to(Torch.bool)
+         end
+
+         if attn_mask.dim() == 2
+           attn_mask = attn_mask.unsqueeze(0)
+           raise ArgumentError, 'The size of the 2D attn_mask is not correct.' if attn_mask.size() != [1, query.size(0), key.size(0)]
+         elsif attn_mask.dim() == 3
+           raise ArgumentError, 'The size of the 3D attn_mask is not correct.' if attn_mask.size() != [bsz * num_heads, query.size(0), key.size(0)]
+         else
+           raise ArgumentError, "attn_mask's dimension %s is not supported" % attn_mask.dim()
+         end
+         # attn_mask's dim is 3 now.
+       end
+
+       # convert ByteTensor key_padding_mask to bool
+       if key_padding_mask && key_padding_mask.dtype == Torch.uint8
+         puts("Byte tensor for key_padding_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead.")
+         key_padding_mask = key_padding_mask.to(Torch.bool)
+       end
+
+       if bias_k && bias_v
+         if !static_k && !static_v
+           k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+           v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+           attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+           key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+         else
+           raise ArgumentError, "bias cannot be added to static key." unless !static_k
+           raise ArgumentError, "bias cannot be added to static value." unless !static_v
+         end
+       else
+         raise ArgumentError unless !bias_k
+         raise ArgumentError unless !bias_v
+       end
+
+       q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+       k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if k
+       v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if v
+
+       if static_k
+         raise ArgumentError unless static_k.size(0) == bsz * num_heads
+         raise ArgumentError unless static_k.size(2) == head_dim
+         k = static_k
+       end
+
+       if static_v
+         raise ArgumentError unless static_v.size(0) == bsz * num_heads
+         raise ArgumentError unless static_v.size(2) == head_dim
+         v = static_v
+       end
+
+       src_len = k.size(1)
+
+       if key_padding_mask
+         raise ArgumentError unless key_padding_mask.size(0) == bsz
+         raise ArgumentError unless key_padding_mask.size(1) == src_len
+       end
+
+       if add_zero_attn
+         src_len += 1
+         k_sizes = k.size()
+         k_sizes[1] = 1
+         k = Torch.cat([k, Torch.zeros(k_sizes, dtype: k.dtype, device: k.device)], 1)
+         v_sizes = v.size()
+         v_sizes[1] = 1
+         v = Torch.cat([v, Torch.zeros(v_sizes, dtype: v.dtype, device: v.device)], 1)
+         attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+         key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+       end
+
+       attn_output_weights = Torch.bmm(q, k.transpose(1, 2))
+       raise ArgumentError unless attn_output_weights.size() == [bsz * num_heads, tgt_len, src_len]
+
+       if attn_mask
+         if attn_mask.dtype == Torch.bool
+           attn_output_weights.masked_fill!(attn_mask, -1.0/0.0)
+         else
+           attn_output_weights += attn_mask
+         end
+       end
+
+
+       if key_padding_mask
+         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+         attn_output_weights = attn_output_weights.masked_fill(
+           key_padding_mask.unsqueeze(1).unsqueeze(2),
+           -1.0/0.0
+         )
+         attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+       end
+
+       attn_output_weights = softmax(
+         attn_output_weights, dim: -1)
+       attn_output_weights = dropout(attn_output_weights, p: dropout_p, training: training)
+
+       attn_output = Torch.bmm(attn_output_weights, v)
+       raise ArgumentError unless attn_output.size() == [bsz * num_heads, tgt_len, head_dim]
+       attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+       attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+       if need_weights
+         # average attention weights over heads
+         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+         return attn_output, attn_output_weights.sum(1) / num_heads
+       else
+         return attn_output, nil
+       end
+     end
+   end
+ end
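
For orientation, a minimal self-attention sketch of the functional interface above, with randomly initialized projection parameters. Inside Secryst the transformer layers are expected to prepare these weights themselves, so this is illustrative only; the positional and keyword arguments follow the signature defined above.

[source,ruby]
----
require "secryst"

embed_dim = 16
num_heads = 4
seq_len   = 5
batch     = 2

# A single tensor used as query, key and value triggers the self-attention path.
x = Torch.randn(seq_len, batch, embed_dim)            # (L, N, E)

# Combined input projection for q, k, v plus the output projection.
in_proj_weight  = Torch.randn(3 * embed_dim, embed_dim)
in_proj_bias    = Torch.zeros(3 * embed_dim)
out_proj_weight = Torch.randn(embed_dim, embed_dim)
out_proj_bias   = Torch.zeros(embed_dim)

attn_output, attn_weights = Secryst::MultiHeadAttentionForward.multi_head_attention_forward(
  x, x, x,
  embed_dim, num_heads,
  in_proj_weight, in_proj_bias,
  nil, nil,          # bias_k, bias_v
  false,             # add_zero_attn
  0.0,               # dropout_p
  out_proj_weight, out_proj_bias,
  training: false
)

p attn_output.size   # => [5, 2, 16]  (L, N, E)
p attn_weights.size  # => [2, 5, 5]   (N, L, S), averaged over heads
----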