secryst 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
+ ---
+ SHA256:
+   metadata.gz: 919071d9eb29b220b762212fafe8f7f7ace0f0605cde7c0402dbbe8eac2584c6
+   data.tar.gz: de1b20e79c10d14261c71e0043e82a196b38b531da2fcdec0db44bcb992c7239
+ SHA512:
+   metadata.gz: 94bfb8aff36218341cd5baa3bcb0eb48e5be147c54f4a8b072e25a82ae027c07bacc7769fafb7f657cff9f638f9a0a4a21385108f0a45384bdf1a7796c74ea67
+   data.tar.gz: a1fee058281400822908c30a2d5fcf27ceb320dfed8d78878986460323d4028e78fa89db1579146870864fd2104238ee038e9f7468710863ed16ab6c8652ce03
@@ -0,0 +1,103 @@
+ = Secryst
+
+ image:https://github.com/secryst/secryst/workflows/test/badge.svg["Build status", link="https://github.com/secryst/secryst/actions?workflow=test"]
+
+ == Purpose
+
+ A seq2seq transformer suited for transliteration. Written in Ruby.
+
+ Secryst was originally built for the
+ https://www.interscript.com[Interscript project]
+ (https://github.com/secryst/secryst[at GitHub]).
+
+ The goal is to allow:
+
+ * Developers to train models and provide the trained models to users. To train models, raw computing libraries and their bindings (e.g. OpenCL) can be used.
+
+ * Users who only want to run the trained models to do so in pure Ruby, without requiring any special bindings (see the usage sketch below).
+
+
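+ A minimal usage sketch of the second case. The library does load a
+ `Secryst::Translator` class, but the constructor arguments and `translate`
+ method shown here are illustrative assumptions, not documented API:
+
+ [source,ruby]
+ ----
+ require "secryst"
+
+ # Hypothetical: load a trained Khmer romanization model and apply it.
+ # The model path and option names are placeholders.
+ translator = Secryst::Translator.new(model_file: "checkpoints/khm-latn.pth")
+ puts translator.translate("កម្ពុជា")
+ ----
+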
+ == Status
+
+ Currently Secryst works with the Khmer romanization system cited below.
+
+
+ == Prerequisites
+
+ * Ruby 2.7 (*MUST* - 2.6 does not work with the latest torch-rb)
+
+ * `libtorch` (1.6.0)
+ * `fftw`
+ * `gsl`
+ * `lapack`
+ * `openblas`
+
+
+ On Ubuntu:
+
+ [source,sh]
+ ----
+ $ sudo apt-get -y install libfftw3-dev libgsl-dev libopenblas-dev \
+     liblapack-dev liblapacke-dev unzip automake make gcc g++ \
+     libtorch libtorch-dev
+ $ wget https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
+ $ unzip libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
+
+ $ gem install bundler -v "~> 2"
+ $ bundle config build.torch-rb \
+     --with-torch-dir=$(pwd)/libtorch
+
+ $ bundle install
+ ----
+
+
+ On macOS:
+
+ [source,sh]
+ ----
+ $ brew install libtorch gsl lapack openblas fftw automake gcc
+
+ $ gem install bundler -v "~> 2"
+ $ bundle config build.numo-linalg \
+     --with-openblas-dir=/usr/local/opt/openblas \
+     --with-lapack-lib=/usr/local/opt/lapack
+
+ $ bundle install
+ ----
+
+
+ NOTE: (macOS only)
+ If you mistakenly installed `numo-linalg` without the configuration options
+ above, uninstall it with the following command and then reconfigure the
+ bundle as described above:
+
+ [source,sh]
+ ----
+ $ bundle exec gem uninstall numo-linalg
+ ----
+
+
+
+ == References
+
+ Secryst is built on the transformer model, with an architecture
+ based on:
+
+ * Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin.
+ Attention is all you need. In:
+ _Advances in Neural Information Processing Systems_, 2017, pages 6000-6010.
+
+ The sample transliteration system implemented is the Khmer romanization system described in:
+
+ * https://viblo.asia/p/nlp-khmer-word-segmentation-YWOZrgNNlQ0
+ * https://viblo.asia/p/nlp-khmer-romanization-using-seq2seq-m68Z07OQKkG
+
+
+ == Origin of name
+
+ Scrying is the practice of peering into a crystal sphere for fortune telling.
+ Using `seq2seq` is not unlike scrying: looking into a crystal sphere and
+ waiting for some machine-learning magic to happen.
+
+ "`Secryst`" comes from the combination of "`seq2seq`" + "`crystal`" + "`scrying`".
@@ -0,0 +1,8 @@
+ require "secryst"
+
+ # only if training?
+ require "numo/linalg/use/openblas"
+
+ # transformer model
+ require "secryst/clip_grad_norm"
+ require "secryst/trainer"
@@ -0,0 +1,11 @@
+ require 'json'
+
+ # torch
+ require "torch-rb"
+
+ # transformer model
+ require "secryst/multihead_attention"
+ require "secryst/vocab"
+ require "secryst/transformer"
+
+ require "secryst/translator"
@@ -0,0 +1,25 @@
+ # ported from https://pytorch.org/docs/master/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
+
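+ # Usage sketch (illustrative only; `model` stands for any Torch::NN::Module
+ # whose gradients have been filled in by a backward pass, and 0.5 is an
+ # arbitrary example threshold):
+ #
+ #   Secryst::ClipGradNorm.clip_grad_norm(model.parameters, max_norm: 0.5)
+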
+ module Secryst
+   class ClipGradNorm < Torch::NN::F
+     def self.clip_grad_norm(parameters, max_norm:, norm_type: 2)
+       parameters = parameters.select {|p| p.grad }
+       max_norm = max_norm.to_f
+       if parameters.length == 0
+         return Torch.tensor(0.0)
+       end
+       device = parameters[0].grad.device
+       if norm_type == Float::INFINITY
+         # ... TODO
+       else
+         total_norm = Numo::Linalg.norm(Numo::NArray.concatenate(parameters.map {|p| Numo::Linalg.norm(p.grad.detach.numo, norm_type)}), norm_type)
+       end
+       clip_coef = max_norm / (total_norm + 1e-6)
+       if clip_coef < 1
+         parameters.each {|p| p.grad = p.grad.detach * clip_coef}
+       end
+
+       return total_norm
+     end
+   end
+ end
@@ -0,0 +1,288 @@
+ module Secryst
+   class MultiHeadAttentionForward < Torch::NN::F
+     # Args:
+     #   query, key, value: map a query and a set of key-value pairs to an output.
+     #     See "Attention Is All You Need" for more details.
+     #   embed_dim_to_check: total dimension of the model.
+     #   num_heads: parallel attention heads.
+     #   in_proj_weight, in_proj_bias: input projection weight and bias.
+     #   bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+     #   add_zero_attn: add a new batch of zeros to the key and
+     #     value sequences at dim=1.
+     #   dropout_p: probability of an element to be zeroed.
+     #   out_proj_weight, out_proj_bias: the output projection weight and bias.
+     #   training: apply dropout if ``true``.
+     #   key_padding_mask: if provided, specified padding elements in the key will
+     #     be ignored by the attention. This is a binary mask. When the value is true,
+     #     the corresponding value on the attention layer will be filled with -inf.
+     #   need_weights: output attn_output_weights.
+     #   attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+     #     the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+     #   use_separate_proj_weight: the function accepts the proj. weights for query, key,
+     #     and value in different forms. If false, in_proj_weight will be used, which is
+     #     a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+     #   q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+     #   static_k, static_v: static key and value used for attention operators.
+     # Shape:
+     #   Inputs:
+     #   - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+     #     the embedding dimension.
+     #   - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+     #     the embedding dimension.
+     #   - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+     #     the embedding dimension.
+     #   - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+     #     If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+     #     will be unchanged. If a BoolTensor is provided, the positions with the
+     #     value of ``true`` will be ignored while the position with the value of ``false`` will be unchanged.
+     #   - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+     #     3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+     #     S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+     #     positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+     #     while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``true``
+     #     are not allowed to attend while ``false`` values will be unchanged. If a FloatTensor
+     #     is provided, it will be added to the attention weight.
+     #   - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+     #     N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+     #   - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+     #     N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+     #   Outputs:
+     #   - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+     #     E is the embedding dimension.
+     #   - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+     #     L is the target sequence length, S is the source sequence length.
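+     #
+     # Illustrative call (not part of the library itself; the tensor shapes
+     # follow the documentation above with L=7, S=9, N=2, E=16, num_heads=4,
+     # and the projection weights are random placeholders):
+     #
+     #   q = Torch.rand(7, 2, 16)      # (L, N, E)
+     #   k = Torch.rand(9, 2, 16)      # (S, N, E)
+     #   v = Torch.rand(9, 2, 16)      # (S, N, E)
+     #   in_w  = Torch.rand(48, 16)    # combined q/k/v projection, (3*E, E)
+     #   in_b  = Torch.zeros(48)
+     #   out_w = Torch.rand(16, 16)
+     #   out_b = Torch.zeros(16)
+     #   out, weights = Secryst::MultiHeadAttentionForward.multi_head_attention_forward(
+     #     q, k, v, 16, 4, in_w, in_b, nil, nil, false, 0.1, out_w, out_b)
+     #   # out is (L, N, E) = (7, 2, 16); weights is (N, L, S) = (2, 7, 9)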
+     def self.multi_head_attention_forward(query,
+                                           key,
+                                           value,
+                                           embed_dim_to_check,
+                                           num_heads,
+                                           in_proj_weight,
+                                           in_proj_bias,
+                                           bias_k,
+                                           bias_v,
+                                           add_zero_attn,
+                                           dropout_p,
+                                           out_proj_weight,
+                                           out_proj_bias,
+                                           training: true,
+                                           key_padding_mask: nil,
+                                           need_weights: true,
+                                           attn_mask: nil,
+                                           use_separate_proj_weight: false,
+                                           q_proj_weight: nil,
+                                           k_proj_weight: nil,
+                                           v_proj_weight: nil,
+                                           static_k: nil,
+                                           static_v: nil)
+       tgt_len, bsz, embed_dim = query.size()
+       raise ArgumentError if embed_dim != embed_dim_to_check
+       # allow MHA to have different sizes for the feature dimension
+       raise ArgumentError if key.size(0) != value.size(0) or key.size(1) != value.size(1)
+
+       head_dim = embed_dim / num_heads
+       raise ArgumentError, "embed_dim must be divisible by num_heads" if head_dim * num_heads != embed_dim
+       scaling = head_dim.to_f ** -0.5
+
+       if !use_separate_proj_weight
+         if Torch.equal(query, key) && Torch.equal(key, value)
+           # self-attention
+           q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, -1)
+
+         elsif Torch.equal(key, value)
+           # encoder-decoder attention
+           # This is inline in_proj function with in_proj_weight and in_proj_bias
+           _b = in_proj_bias
+           _start = 0
+           _end = embed_dim
+           _w = in_proj_weight.slice(0, _start, _end) # NOTE: inc-trspl
+           if _b
+             _b = _b.slice(0, _start, _end)
+           end
+           q = linear(query, _w, _b)
+
+           if !key
+             raise ArgumentError if value
+             k = nil
+             v = nil
+           else
+             # This is inline in_proj function with in_proj_weight and in_proj_bias
+             _b = in_proj_bias
+             _start = embed_dim
+             _end = nil
+             _w = in_proj_weight.slice(0, _start)
+             if _b
+               _b = _b.slice(0, _start)
+             end
+             k, v = linear(key, _w, _b).chunk(2, -1)
+           end
+
+         else
+           # This is inline in_proj function with in_proj_weight and in_proj_bias
+           _b = in_proj_bias
+           _start = 0
+           _end = embed_dim
+           _w = in_proj_weight.slice(0, _start, _end)
+           if _b
+             _b = _b.slice(0, _start, _end)
+           end
+           q = linear(query, _w, _b)
+
+           # This is inline in_proj function with in_proj_weight and in_proj_bias
+           _b = in_proj_bias
+           _start = embed_dim
+           _end = embed_dim * 2
+           _w = in_proj_weight.slice(0, _start, _end)
+           if _b
+             _b = _b.slice(0, _start, _end)
+           end
+           k = linear(key, _w, _b)
+
+           # This is inline in_proj function with in_proj_weight and in_proj_bias
+           _b = in_proj_bias
+           _start = embed_dim * 2
+           _end = nil
+           _w = in_proj_weight.slice(0, _start)
+           if _b
+             _b = _b.slice(0, _start)
+           end
+           v = linear(value, _w, _b)
+         end
+       else
+         q_proj_weight_non_opt = q_proj_weight
+         len1, len2 = q_proj_weight_non_opt.size()
+         raise ArgumentError if len1 != embed_dim || len2 != query.size(-1)
+
+         k_proj_weight_non_opt = k_proj_weight
+         len1, len2 = k_proj_weight_non_opt.size()
+         raise ArgumentError if len1 != embed_dim || len2 != key.size(-1)
+
+         v_proj_weight_non_opt = v_proj_weight
+         len1, len2 = v_proj_weight_non_opt.size()
+         raise ArgumentError if len1 != embed_dim || len2 != value.size(-1)
+
+         if in_proj_bias
+           q = linear(query, q_proj_weight_non_opt, in_proj_bias.slice(0, 0, embed_dim))
+           k = linear(key, k_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim, embed_dim * 2))
+           v = linear(value, v_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim * 2))
+         else
+           q = linear(query, q_proj_weight_non_opt, in_proj_bias)
+           k = linear(key, k_proj_weight_non_opt, in_proj_bias)
+           v = linear(value, v_proj_weight_non_opt, in_proj_bias)
+         end
+       end
+       q = q * scaling
+
+       if attn_mask
+         raise ArgumentError, 'Only float, byte, and bool types are supported for attn_mask, not %s' % attn_mask.dtype unless attn_mask.dtype == Torch.float32 || attn_mask.dtype == Torch.float64 || attn_mask.dtype == Torch.float16 || attn_mask.dtype == Torch.uint8 || attn_mask.dtype == Torch.bool
+         if attn_mask.dtype == Torch.uint8
+           puts "Byte tensor for attn_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead."
+           attn_mask = attn_mask.to(Torch.bool)
+         end
+
+         if attn_mask.dim() == 2
+           attn_mask = attn_mask.unsqueeze(0)
+           raise ArgumentError, 'The size of the 2D attn_mask is not correct.' if attn_mask.size() != [1, query.size(0), key.size(0)]
+         elsif attn_mask.dim() == 3
+           raise ArgumentError, 'The size of the 3D attn_mask is not correct.' if attn_mask.size() != [bsz * num_heads, query.size(0), key.size(0)]
+         else
+           raise ArgumentError, "attn_mask's dimension %s is not supported" % attn_mask.dim()
+         end
+         # attn_mask's dim is 3 now.
+       end
+
+       # convert ByteTensor key_padding_mask to bool
+       if key_padding_mask && key_padding_mask.dtype == Torch.uint8
+         puts("Byte tensor for key_padding_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead.")
+         key_padding_mask = key_padding_mask.to(Torch.bool)
+       end
+
+       if bias_k && bias_v
+         if !static_k && !static_v
+           k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+           v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+           attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+           key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+         else
+           raise ArgumentError, "bias cannot be added to static key." unless !static_k
+           raise ArgumentError, "bias cannot be added to static value." unless !static_v
+         end
+       else
+         raise ArgumentError unless !bias_k
+         raise ArgumentError unless !bias_v
+       end
+
+       q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+       k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if k
+       v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if v
+
+       if static_k
+         raise ArgumentError unless static_k.size(0) == bsz * num_heads
+         raise ArgumentError unless static_k.size(2) == head_dim
+         k = static_k
+       end
+
+       if static_v
+         raise ArgumentError unless static_v.size(0) == bsz * num_heads
+         raise ArgumentError unless static_v.size(2) == head_dim
+         v = static_v
+       end
+
+       src_len = k.size(1)
+
+       if key_padding_mask
+         raise ArgumentError unless key_padding_mask.size(0) == bsz
+         raise ArgumentError unless key_padding_mask.size(1) == src_len
+       end
+
+       if add_zero_attn
+         src_len += 1
+         k_sizes = k.size()
+         k_sizes[1] = 1
+         k = Torch.cat([k, Torch.zeros(k_sizes, dtype: k.dtype, device: k.device)], 1)
+         v_sizes = v.size()
+         v_sizes[1] = 1
+         v = Torch.cat([v, Torch.zeros(v_sizes, dtype: v.dtype, device: v.device)], 1)
+         attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+         key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+       end
+
+       attn_output_weights = Torch.bmm(q, k.transpose(1, 2))
+       raise ArgumentError unless attn_output_weights.size() == [bsz * num_heads, tgt_len, src_len]
+
+       if attn_mask
+         if attn_mask.dtype == Torch.bool
+           attn_output_weights.masked_fill!(attn_mask, -1.0/0.0)
+         else
+           attn_output_weights += attn_mask
+         end
+       end
+
+
+       if key_padding_mask
+         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+         attn_output_weights = attn_output_weights.masked_fill(
+           key_padding_mask.unsqueeze(1).unsqueeze(2),
+           -1.0/0.0
+         )
+         attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+       end
+
+       attn_output_weights = softmax(
+         attn_output_weights, dim: -1)
+       attn_output_weights = dropout(attn_output_weights, p: dropout_p, training: training)
+
+       attn_output = Torch.bmm(attn_output_weights, v)
+       raise ArgumentError unless attn_output.size() == [bsz * num_heads, tgt_len, head_dim]
+       attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+       attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+       if need_weights
+         # average attention weights over heads
+         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+         return attn_output, attn_output_weights.sum(1) / num_heads
+       else
+         return attn_output, nil
+       end
+     end
+   end
+ end