torch-rb 0.8.1 → 0.8.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d52cf2bf4770e9166623614f6071e5180e9492b0063757be8bdec73a2c930b38
4
- data.tar.gz: 1b650a3277d1aebe28cdd5d75ce54420feba7a1c9a2046335c9271c72eb3f74f
3
+ metadata.gz: '05811faa93ab089485bfa213362bfed0462227e6964e726bf1c3f9fc0cdba0c3'
4
+ data.tar.gz: 8d25304063db51850e535e71c2b4fe643e038a18d046b8f27ee50269e5e9695d
5
5
  SHA512:
6
- metadata.gz: 3afad67c5ca6cedc4925dab1aadc37541daac6752a8b508a8d4bd6cd3b25b71600ed32625a65651bf150d6ddd519ec86ca7b9ccd2f12022366ed692603f65c1a
7
- data.tar.gz: 970fa451044ce68d60e13f2da297ea8e20e53f01a1ffee31f152f190baf5d8f709f9be3defcd6b24739499cb400c5448abd43939a451be826e87b2e62eba3cac
6
+ metadata.gz: 8cea906b03b37ec848be7b1c7cfa6bfb0fde4ef7ed384818bcc85826d04611621835bbeecad0fd31c94b497bb527a21c107901a9d991692b6fffa6bb24c23c38
7
+ data.tar.gz: 2ded65d614d274afe61e061898268172e6dec85dc28e3094481c2650713a129c0574b7ba17dbafccd8b8dc7ee064602fd261ce5034cec4a184fa32f2965eb476
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.8.2 (2021-10-03)
2
+
3
+ - Added transformers
4
+ - Added left shift and right shift
5
+
1
6
  ## 0.8.1 (2021-06-15)
2
7
 
3
8
  - Added `Backends` module
data/README.md CHANGED
@@ -28,15 +28,19 @@ It can take a few minutes to compile the extension.
28
28
 
29
29
  ## Getting Started
30
30
 
31
- Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
31
+ A good place to start is [Deep Learning with Torch.rb: A 60 Minute Blitz](tutorials/blitz/README.md).
32
32
 
33
- We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Under advanced options, set the container name to:
33
+ ## Tutorials
34
34
 
35
- ```text
36
- ankane/ml-stack:torch-gpu
37
- ```
35
+ - [Transfer learning](tutorials/transfer_learning/README.md)
36
+ - [Sequence models](tutorials/nlp/sequence_models.md)
37
+ - [Word embeddings](tutorials/nlp/word_embeddings.md)
38
38
 
39
- And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
39
+ ## Examples
40
+
41
+ - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
42
+ - [Collaborative filtering with MovieLens](examples/movielens)
43
+ - [Generative adversarial networks](examples/gan)
40
44
 
41
45
  ## API
42
46
 
@@ -48,7 +52,7 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
48
52
 
49
53
  You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
50
54
 
51
- ## Tutorial
55
+ ## Overview
52
56
 
53
57
  Some examples below are from [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
54
58
 
@@ -214,7 +218,7 @@ Define a neural network
214
218
  ```ruby
215
219
  class MyNet < Torch::NN::Module
216
220
  def initialize
217
- super
221
+ super()
218
222
  @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
219
223
  @conv2 = Torch::NN::Conv2d.new(6, 16, 3)
220
224
  @fc1 = Torch::NN::Linear.new(16 * 6 * 6, 120)
@@ -225,20 +229,10 @@ class MyNet < Torch::NN::Module
225
229
  def forward(x)
226
230
  x = Torch::NN::F.max_pool2d(Torch::NN::F.relu(@conv1.call(x)), [2, 2])
227
231
  x = Torch::NN::F.max_pool2d(Torch::NN::F.relu(@conv2.call(x)), 2)
228
- x = x.view(-1, num_flat_features(x))
232
+ x = Torch.flatten(x, 1)
229
233
  x = Torch::NN::F.relu(@fc1.call(x))
230
234
  x = Torch::NN::F.relu(@fc2.call(x))
231
- x = @fc3.call(x)
232
- x
233
- end
234
-
235
- def num_flat_features(x)
236
- size = x.size[1..-1]
237
- num_features = 1
238
- size.each do |s|
239
- num_features *= s
240
- end
241
- num_features
235
+ @fc3.call(x)
242
236
  end
243
237
  end
244
238
  ```
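The change above swaps the hand-rolled `num_flat_features` helper for `Torch.flatten(x, 1)`, which flattens every dimension after the batch dimension. A minimal sketch of the equivalence (shapes chosen purely for illustration):

```ruby
require "torch"

x = Torch.rand(8, 16, 6, 6)     # a batch of 8 feature maps
flat = Torch.flatten(x, 1)      # keep dim 0, flatten the rest
flat.shape                      # => [8, 576]

# equivalent to the removed helper
x.view(-1, x.size[1..-1].inject(:*)).shape  # => [8, 576]
```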
@@ -402,19 +396,9 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
402
396
  Torch.zeros(3) # tensor([0, 0, 0])
403
397
  ```
404
398
 
405
- ## Examples
406
-
407
- Here are a few full examples:
408
-
409
- - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
410
- - [Collaborative filtering with MovieLens](examples/movielens)
411
- - [Sequence models and word embeddings](examples/nlp)
412
- - [Generative adversarial networks](examples/gan)
413
- - [Transfer learning](examples/transfer-learning)
414
-
415
399
  ## LibTorch Installation
416
400
 
417
- [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
401
+ [Download LibTorch](https://pytorch.org/) (for Linux, use the `cxx11 ABI` version). Then run:
418
402
 
419
403
  ```sh
420
404
  bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
@@ -444,9 +428,7 @@ Then install the gem (no need for `bundle config`).
444
428
 
445
429
  ## Performance
446
430
 
447
- ### Linux
448
-
449
- Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
431
+ Deep learning is significantly faster on a GPU. With Linux, install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
450
432
 
451
433
  Check if CUDA is available
452
434
 
@@ -460,15 +442,14 @@ Move a neural network to a GPU
460
442
  net.cuda
461
443
  ```
462
444
 
463
- ## rbenv
445
+ If you don’t have a GPU that supports CUDA, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan. We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Under advanced options, set the container name to:
464
446
 
465
- This library uses [Rice](https://github.com/jasonroelofs/rice) to interface with LibTorch. Rice and earlier versions of rbenv don’t play nicely together. If you encounter an error during installation, upgrade ruby-build and reinstall your Ruby version.
466
-
467
- ```sh
468
- brew upgrade ruby-build
469
- rbenv install [version]
447
+ ```text
448
+ ankane/ml-stack:torch-gpu
470
449
  ```
471
450
 
451
+ And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
452
+
472
453
  ## History
473
454
 
474
455
  View the [changelog](https://github.com/ankane/torch.rb/blob/master/CHANGELOG.md)
@@ -23,7 +23,7 @@ end
23
23
 
24
24
  def skip_functions(functions)
25
25
  functions.reject do |f|
26
- f.base_name.start_with?("_") ||
26
+ (f.base_name.start_with?("_") && f.base_name != "__lshift__" && f.base_name != "__rshift__") ||
27
27
  f.base_name.include?("_backward") ||
28
28
  f.base_name.include?("_forward") ||
29
29
  f.base_name == "to" ||
@@ -133,6 +133,7 @@ def generate_attach_def(name, type, def_method)
133
133
  ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
134
134
  ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
135
135
  ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
136
+ ruby_name = name if name.start_with?("__")
136
137
 
137
138
  # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
138
139
  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
@@ -1,6 +1,8 @@
1
1
  module Torch
2
2
  module NN
3
3
  class ConvNd < Module
4
+ attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_paddding, :groups, :padding_mode
5
+
4
6
  def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
5
7
  super()
6
8
  raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
@@ -0,0 +1,241 @@
1
+ module Torch
2
+ module NN
3
+ class Functional
4
+ class << self
5
+ def in_projection_packed(q, k, v, w, b: nil)
6
+ e = q.size(-1)
7
+
8
+ if k.eql? v
9
+ if q.eql? k
10
+ # self-attention
11
+ return linear(q, w, b).chunk(3, dim: -1)
12
+ else
13
+ # encoder-decoder attention
14
+ w_q, w_kv = w.split_with_sizes([e, e * 2])
15
+ if b.nil?
16
+ b_q = b_kv = nil
17
+ else
18
+ b_q, b_kv = b.split_with_sizes([e, e * 2])
19
+ end
20
+
21
+ return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
22
+ end
23
+ else
24
+ w_q, w_k, w_v = w.chunk(3)
25
+ if b.nil?
26
+ b_q = b_k = b_v = nil
27
+ else
28
+ b_q, b_k, b_v = b.chunk(3)
29
+ end
30
+
31
+ return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
32
+ end
33
+ end
34
+
35
+ def in_projection(
36
+ q, k, v,
37
+ w_q, w_k, w_v,
38
+ b_q: nil, b_k: nil, b_v: nil
39
+ )
40
+
41
+ e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
42
+
43
+ raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
44
+ raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
45
+ raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
46
+
47
+ raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
48
+ raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
49
+ raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
50
+
51
+ [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
52
+ end
53
+
54
+ def scaled_dot_product_attention(
55
+ q, k, v,
56
+ attn_mask: nil, dropout_p: 0.0
57
+ )
58
+
59
+ b, nt, e = q.shape
60
+
61
+ q = q / Math.sqrt(e)
62
+
63
+ attn = Torch.bmm(q, k.transpose(-2, -1))
64
+ attn += attn_mask if attn_mask
65
+ attn = softmax(attn, dim: -1)
66
+ attn = dropout(attn, p: dropout_p) if dropout_p > 0
67
+
68
+ output = Torch.bmm(attn, v)
69
+
70
+ [output, attn]
71
+ end
72
+
73
+ def multi_head_attention_forward(
74
+ query, key, value,
75
+ embed_dim_to_check, num_heads,
76
+ in_proj_weight, in_proj_bias,
77
+ bias_k, bias_v,
78
+ add_zero_attn,
79
+ dropout_p,
80
+ out_proj_weight, out_proj_bias,
81
+ training: true,
82
+ key_padding_mask: nil,
83
+ need_weights: true,
84
+ attn_mask: nil,
85
+ use_separate_proj_weight: false,
86
+ q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
87
+ static_k: nil, static_v: nil
88
+ )
89
+
90
+ tgt_len, bsz, embed_dim = query.shape
91
+ src_len = key.shape.first
92
+
93
+ raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
94
+
95
+ head_dim = if embed_dim.is_a?(Torch::Tensor)
96
+ embed_dim.div(num_heads, rounding_mode: 'trunc')
97
+ else
98
+ embed_dim.div(num_heads)
99
+ end
100
+
101
+ if use_separate_proj_weight
102
+ raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
103
+ else
104
+ raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
105
+ end
106
+
107
+ # compute in-projection
108
+ q, k, v =
109
+ if use_separate_proj_weight
110
+ raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
111
+ raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
112
+ raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
113
+
114
+ if in_proj_bias
115
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
116
+ else
117
+ b_q = b_k = b_v = nil
118
+ end
119
+
120
+ in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
121
+ else
122
+ in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
123
+ end
124
+
125
+ # prep attention mask
126
+ if attn_mask
127
+ if attn_mask.dtype == :uint8
128
+ puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
129
+ attn_mask = attn_mask.bool
130
+ else
131
+ raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
132
+ end
133
+
134
+ if attn_mask.dim == 2
135
+ correct_2d_size = [tgt_len, src_len]
136
+ raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
137
+
138
+ attn_mask = attn_mask.unsqueeze(0)
139
+ elsif attn_mask.dim == 3
140
+ correct_3d_size = [bsz * num_heads, tgt_len, src_len]
141
+ raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
142
+ else
143
+ raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
144
+ end
145
+ end
146
+
147
+ # prep key padding mask
148
+ if key_padding_mask && key_padding_mask.dtype == :uint8
149
+ puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
150
+ key_padding_mask = key_padding_mask.bool
151
+ end
152
+
153
+ # add bias along batch dimension (currently second)
154
+ if bias_k && bias_v
155
+ raise ArgumentError, "bias cannot be added to static key." if static_k
156
+ raise ArgumentError, "bias cannot be added to static value." if static_v
157
+
158
+ k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
159
+ v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
160
+
161
+ attn_mask = pad(attn_mask, [0, 1]) if attn_mask
162
+ key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
163
+ else
164
+ raise ArgumentError unless bias_k.nil?
165
+ raise ArgumentError unless bias_v.nil?
166
+ end
167
+
168
+ # reshape q, k, v for multihead attention and make em batch first
169
+ q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
170
+
171
+ if static_k.nil?
172
+ k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
173
+ else
174
+ raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
175
+ raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
176
+
177
+ k = static_k
178
+ end
179
+
180
+ if static_v.nil?
181
+ v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
182
+ else
183
+ raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
184
+ raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
185
+
186
+ v = static_v
187
+ end
188
+
189
+ # add zero attention along batch dimension (now first)
190
+ if add_zero_attn
191
+ zero_attn_shape = [bsz * num_heads, 1, head_dim]
192
+ k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
193
+ v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
194
+
195
+ attn_mask = pad(attn_mask, [0, 1]) if attn_mask
196
+ key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
197
+ end
198
+
199
+ # update source sequence length after adjustments
200
+ src_len = k.size(1)
201
+
202
+ # merge key padding and attention masks
203
+ if key_padding_mask
204
+ raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
205
+
206
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
207
+
208
+ attn_mask = if attn_mask.nil?
209
+ key_padding_mask
210
+ elsif attn_mask.dtype == :bool
211
+ attn_mask.logical_or(key_padding_mask)
212
+ else
213
+ attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
214
+ end
215
+ end
216
+
217
+ # convert mask to float
218
+ if attn_mask && attn_mask.dtype == :bool
219
+ new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
220
+ attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
221
+ end
222
+
223
+ dropout_p = 0.0 unless training
224
+
225
+ # (deep breath) calculate attention and out projection
226
+ attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
227
+ attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
228
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
229
+
230
+ if need_weights
231
+ # average attention weights over heads
232
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
233
+ [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
234
+ else
235
+ [attn_output, nil]
236
+ end
237
+ end
238
+ end
239
+ end
240
+ end
241
+ end
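For a sense of how the new functional pieces fit together, here is a minimal sketch that calls `scaled_dot_product_attention` directly with already-projected tensors shaped `[batch * heads, seq_len, head_dim]` (the shapes and names are illustrative, not from the gem's docs):

```ruby
require "torch"

heads, seq_len, head_dim = 4, 10, 16
q = Torch.rand(heads, seq_len, head_dim)
k = Torch.rand(heads, seq_len, head_dim)
v = Torch.rand(heads, seq_len, head_dim)

output, attn = Torch::NN::F.scaled_dot_product_attention(q, k, v)
output.shape # => [4, 10, 16]
attn.shape   # => [4, 10, 10]
```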
@@ -3,6 +3,8 @@ module Torch
3
3
  class Module
4
4
  include Utils
5
5
 
6
+ attr_reader :training
7
+
6
8
  def initialize
7
9
  @training = true
8
10
  @parameters = {}
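`Module` now exposes its training flag through `attr_reader :training`. A small sketch, assuming the existing `train`/`eval` toggles that mirror PyTorch:

```ruby
require "torch"

net = Torch::NN::Linear.new(4, 2)
net.training # => true (modules start in training mode)
net.eval
net.training # => false
```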
@@ -0,0 +1,49 @@
1
+ module Torch
2
+ module NN
3
+ class ModuleList < Module
4
+ include Enumerable
5
+
6
+ def initialize(mods = nil)
7
+ super()
8
+
9
+ self.concat(mods) if mods
10
+ end
11
+
12
+ def length
13
+ @modules.length
14
+ end
15
+ alias_method :count, :length
16
+ alias_method :size, :length
17
+
18
+ def concat(mods)
19
+ raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
20
+
21
+ mods.each { |m| append m }
22
+
23
+ self
24
+ end
25
+
26
+ def each(&block)
27
+ if block_given?
28
+ @modules.values.each(&block)
29
+ else
30
+ to_enum(:each)
31
+ end
32
+ end
33
+
34
+ def append(mod)
35
+ raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
36
+ add_module(length.to_s, mod)
37
+ self
38
+ end
39
+
40
+ def [](idx)
41
+ if idx.is_a?(Range)
42
+ self.class.new(@modules.values[idx])
43
+ else
44
+ @modules[idx.to_s]
45
+ end
46
+ end
47
+ end
48
+ end
49
+ end
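`ModuleList` keeps child modules registered (so their parameters are picked up by the parent module) while behaving like an `Enumerable`. A quick sketch based on the API above:

```ruby
require "torch"

layers = Torch::NN::ModuleList.new([Torch::NN::Linear.new(8, 8)])
layers.append(Torch::NN::ReLU.new)
layers.length                 # => 2
layers[0]                     # => the Linear module
layers[0..-1]                 # => a new ModuleList with both modules
layers.each { |m| puts m.class }
```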
@@ -0,0 +1,123 @@
1
+ module Torch
2
+ module NN
3
+ class MultiheadAttention < Module
4
+ def initialize(
5
+ embed_dim, num_heads,
6
+ dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
7
+ kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
8
+ )
9
+
10
+ super()
11
+
12
+ @embed_dim = embed_dim
13
+ @kdim = kdim || @embed_dim
14
+ @vdim = vdim || @embed_dim
15
+
16
+ @qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim
17
+
18
+ @num_heads = num_heads
19
+ @dropout = dropout
20
+ @batch_first = batch_first
21
+
22
+ @head_dim = @embed_dim.div @num_heads
23
+
24
+ raise ArgumentError, "embed_dim must be divisible by num_heads" unless @head_dim * @num_heads == @embed_dim
25
+
26
+ if @qkv_same_embed_dim
27
+ @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
28
+ %w(q k v).each { |x| register_parameter("#{x}_proj_weight", nil) }
29
+ else
30
+ @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
31
+ @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
32
+ @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))
33
+
34
+ register_parameter('in_proj_weight', nil)
35
+ end
36
+
37
+ if bias
38
+ @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
39
+ else
40
+ register_parameter('in_proj_bias', nil)
41
+ end
42
+
43
+ @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)
44
+
45
+ if add_bias_kv
46
+ @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
47
+ @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
48
+ else
49
+ @bias_k = @bias_v = nil
50
+ end
51
+
52
+ @add_zero_attn = add_zero_attn
53
+
54
+ reset_parameters
55
+ end
56
+
57
+ def batch_first?
58
+ !!@batch_first
59
+ end
60
+
61
+ def reset_parameters
62
+ if @qkv_same_embed_dim
63
+ Init.xavier_uniform!(@in_proj_weight)
64
+ else
65
+ Init.xavier_uniform!(@q_proj_weight)
66
+ Init.xavier_uniform!(@k_proj_weight)
67
+ Init.xavier_uniform!(@v_proj_weight)
68
+ end
69
+
70
+ if @in_proj_bias
71
+ Init.constant!(@in_proj_bias, 0.0)
72
+ Init.constant!(@out_proj.bias, 0.0)
73
+ end
74
+
75
+ Init.xavier_uniform!(@bias_k) if @bias_k
76
+ Init.xavier_uniform!(@bias_v) if @bias_v
77
+ end
78
+
79
+ def forward(
80
+ query, key, value,
81
+ key_padding_mask: nil, need_weights: true, attn_mask: nil
82
+ )
83
+
84
+ if batch_first?
85
+ query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
86
+ end
87
+
88
+ attn_output, attn_output_weights =
89
+ if @qkv_same_embed_dim
90
+ F.multi_head_attention_forward(
91
+ query, key, value,
92
+ @embed_dim, @num_heads,
93
+ @in_proj_weight, @in_proj_bias,
94
+ @bias_k, @bias_v, @add_zero_attn,
95
+ @dropout, @out_proj.weight, @out_proj.bias,
96
+ training: @training,
97
+ key_padding_mask: key_padding_mask,
98
+ need_weights: need_weights,
99
+ attn_mask: attn_mask
100
+ )
101
+ else
102
+ F.multi_head_attention_forward(
103
+ query, key, value,
104
+ @embed_dim, @num_heads,
105
+ @in_proj_weight, @in_proj_bias,
106
+ @bias_k, @bias_v, @add_zero_attn,
107
+ @dropout, @out_proj.weight, @out_proj.bias,
108
+ training: @training,
109
+ key_padding_mask: key_padding_mask,
110
+ need_weights: need_weights,
111
+ attn_mask: attn_mask,
112
+ use_separate_proj_weight: true,
113
+ q_proj_weight: @q_proj_weight, k_proj_weight: @k_proj_weight, v_proj_weight: @v_proj_weight
114
+ )
115
+ end
116
+
117
+ attn_output = attn_output.transpose(1, 0) if batch_first?
118
+
119
+ [attn_output, attn_output_weights]
120
+ end
121
+ end
122
+ end
123
+ end
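A minimal usage sketch for the new `MultiheadAttention` module. With the default `batch_first: false`, inputs are shaped `[seq_len, batch, embed_dim]` (values here are illustrative):

```ruby
require "torch"

mha = Torch::NN::MultiheadAttention.new(64, 8)   # embed_dim, num_heads

seq_len, batch = 10, 2
x = Torch.rand(seq_len, batch, 64)

attn_output, attn_weights = mha.call(x, x, x)    # self-attention
attn_output.shape   # => [10, 2, 64]
attn_weights.shape  # => [2, 10, 10] (averaged over heads)
```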
@@ -0,0 +1,92 @@
1
+ require_relative 'transformer_encoder_layer'
2
+ require_relative 'transformer_encoder'
3
+ require_relative 'transformer_decoder_layer'
4
+ require_relative 'transformer_decoder'
5
+
6
+ module Torch
7
+ module NN
8
+ class Transformer < Module
9
+ def initialize(
10
+ d_model: 512, nhead: 8,
11
+ num_encoder_layers: 6, num_decoder_layers: 6,
12
+ dim_feedforward: 2048, dropout: 0.1, activation: :relu,
13
+ custom_encoder: nil, custom_decoder: nil,
14
+ layer_norm_eps: 1e-5, batch_first: false
15
+ )
16
+
17
+ super()
18
+
19
+ @encoder =
20
+ if custom_encoder
21
+ custom_encoder
22
+ else
23
+ encoder_layer = TransformerEncoderLayer.new(
24
+ d_model, nhead,
25
+ dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
26
+ layer_norm_eps: layer_norm_eps, batch_first: batch_first
27
+ )
28
+ encoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
29
+ TransformerEncoder.new(encoder_layer, num_encoder_layers, norm: encoder_norm)
30
+ end
31
+
32
+ @decoder =
33
+ if custom_decoder
34
+ custom_decoder
35
+ else
36
+ decoder_layer = TransformerDecoderLayer.new(
37
+ d_model, nhead,
38
+ dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
39
+ layer_norm_eps: layer_norm_eps, batch_first: batch_first
40
+ )
41
+ decoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
42
+ TransformerDecoder.new(decoder_layer, num_decoder_layers, norm: decoder_norm)
43
+ end
44
+
45
+ reset_parameters
46
+
47
+ @d_model = d_model
48
+ @nhead = nhead
49
+ @batch_first = batch_first
50
+ end
51
+
52
+ attr_reader :d_model, :nhead, :encoder, :decoder
53
+
54
+ def batch_first?
55
+ !!@batch_first
56
+ end
57
+
58
+ def reset_parameters
59
+ parameters.each { |p| Init.xavier_uniform!(p) if p.dim > 1 }
60
+ end
61
+
62
+ def forward(
63
+ src, tgt,
64
+ src_mask: nil, tgt_mask: nil, memory_mask: nil,
65
+ src_key_padding_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil
66
+ )
67
+
68
+ if (!batch_first? && src.size(1) != tgt.size(1)) ||
69
+ (batch_first? && src.size(0) != tgt.size(0))
70
+
71
+ raise ArgumentError, "The batch number of src and tgt must be equal"
72
+ end
73
+
74
+ if src.size(2) != d_model || tgt.size(2) != d_model
75
+ raise ArgumentError, "The feature number of src and tgt must be equal to d_model"
76
+ end
77
+
78
+ memory = @encoder.(src, mask: src_mask, src_key_padding_mask: src_key_padding_mask)
79
+ @decoder.(
80
+ tgt, memory,
81
+ tgt_mask: tgt_mask, memory_mask: memory_mask,
82
+ tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask
83
+ )
84
+ end
85
+
86
+ def generate_square_subsequent_mask(sz)
87
+ mask = Torch.triu(Torch.ones([sz, sz])).eq(1).transpose(0, 1)
88
+ mask.float.masked_fill!(mask.eq(0), -Float::INFINITY).masked_fill!(mask.eq(1), 0.0)
89
+ end
90
+ end
91
+ end
92
+ end
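Putting the new `Transformer` container to work: a minimal sketch with the default seq-first layout (`batch_first: false`), so tensors are `[seq_len, batch, d_model]` (sizes are illustrative):

```ruby
require "torch"

model = Torch::NN::Transformer.new(d_model: 512, nhead: 8,
                                   num_encoder_layers: 2, num_decoder_layers: 2)

src = Torch.rand(10, 32, 512)   # source sequence
tgt = Torch.rand(20, 32, 512)   # target sequence

out = model.call(src, tgt)
out.shape # => [20, 32, 512]

# mask to stop the decoder attending to future positions
mask = model.generate_square_subsequent_mask(20)
out = model.call(src, tgt, tgt_mask: mask)
```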
@@ -0,0 +1,25 @@
1
+ module Torch
2
+ module NN
3
+ class TransformerDecoder < Module
4
+ def initialize(decoder_layer, num_layers, norm: nil)
5
+ super()
6
+
7
+ @layers = _clones(decoder_layer, num_layers)
8
+ @num_layers = num_layers
9
+ @norm = norm
10
+ end
11
+
12
+ def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
13
+ output = tgt
14
+
15
+ @layers.each do |mod|
16
+ output = mod.call(output, memory, tgt_mask: tgt_mask, memory_mask: memory_mask, tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask)
17
+ end
18
+
19
+ output = @norm.call(output) if @norm
20
+
21
+ output
22
+ end
23
+ end
24
+ end
25
+ end
@@ -0,0 +1,43 @@
1
+ module Torch
2
+ module NN
3
+ class TransformerDecoderLayer < Module
4
+ def initialize(
5
+ d_model, n_head,
6
+ dim_feedforward: 2048, dropout: 0.1, activation: :relu,
7
+ layer_norm_eps: 1e-5, batch_first: false
8
+ )
9
+
10
+ super()
11
+
12
+ @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
13
+ @multihead_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
14
+
15
+ @linear1 = Linear.new(d_model, dim_feedforward)
16
+ @dropout = Dropout.new(p: dropout)
17
+ @linear2 = Linear.new(dim_feedforward, d_model)
18
+
19
+ @norm1 = LayerNorm.new(d_model, eps: layer_norm_eps)
20
+ @norm2 = LayerNorm.new(d_model, eps: layer_norm_eps)
21
+ @norm3 = LayerNorm.new(d_model, eps: layer_norm_eps)
22
+
23
+ @dropout1 = Dropout.new(p: dropout)
24
+ @dropout2 = Dropout.new(p: dropout)
25
+ @dropout3 = Dropout.new(p: dropout)
26
+
27
+ @activation = _activation_fn(activation)
28
+ end
29
+
30
+ def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
31
+ tgt2 = @self_attn.(tgt, tgt, tgt, attn_mask: tgt_mask, key_padding_mask: tgt_key_padding_mask).first
32
+ tgt += @dropout1.(tgt2)
33
+ tgt = @norm1.(tgt)
34
+ tgt2 = @multihead_attn.(tgt, memory, memory, attn_mask: memory_mask, key_padding_mask: memory_key_padding_mask).first
35
+ tgt += @dropout2.(tgt2)
36
+ tgt = @norm2.(tgt)
37
+ tgt2 = @linear2.(@dropout.(@activation.(@linear1.(tgt))))
38
+ tgt += @dropout3.(tgt2)
39
+ @norm3.(tgt)
40
+ end
41
+ end
42
+ end
43
+ end
@@ -0,0 +1,25 @@
1
+ module Torch
2
+ module NN
3
+ class TransformerEncoder < Module
4
+ def initialize(encoder_layer, num_layers, norm: nil)
5
+ super()
6
+
7
+ @layers = _clones(encoder_layer, num_layers)
8
+ @num_layers = num_layers
9
+ @norm = norm
10
+ end
11
+
12
+ def forward(src, mask: nil, src_key_padding_mask: nil)
13
+ output = src
14
+
15
+ @layers.each do |mod|
16
+ output = mod.call(output, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
17
+ end
18
+
19
+ output = @norm.call(output) if @norm
20
+
21
+ output
22
+ end
23
+ end
24
+ end
25
+ end
@@ -0,0 +1,36 @@
1
+ module Torch
2
+ module NN
3
+ class TransformerEncoderLayer < Module
4
+ def initialize(
5
+ d_model, n_head,
6
+ dim_feedforward: 2048, dropout: 0.1, activation: :relu,
7
+ layer_norm_eps: 1e-5, batch_first: false
8
+ )
9
+
10
+ super()
11
+
12
+ @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
13
+ @linear1 = Linear.new(d_model, dim_feedforward)
14
+ @dropout = Dropout.new(p: dropout)
15
+ @linear2 = Linear.new(dim_feedforward, d_model)
16
+
17
+ @norm1 = LayerNorm.new(d_model, eps: layer_norm_eps)
18
+ @norm2 = LayerNorm.new(d_model, eps: layer_norm_eps)
19
+
20
+ @dropout1 = Dropout.new(p: dropout)
21
+ @dropout2 = Dropout.new(p: dropout)
22
+
23
+ @activation = _activation_fn(activation)
24
+ end
25
+
26
+ def forward(src, src_mask: nil, src_key_padding_mask: nil)
27
+ src2 = @self_attn.(src, src, src, attn_mask: src_mask, key_padding_mask: src_key_padding_mask).first
28
+ src += @dropout1.(src2)
29
+ src = @norm1.(src)
30
+ src2 = @linear2.(@dropout.(@activation.(@linear1.(src))))
31
+ src += @dropout2.(src2)
32
+ @norm2.(src)
33
+ end
34
+ end
35
+ end
36
+ end
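The encoder pieces can also be used on their own, e.g. stacking a `TransformerEncoderLayer` into a `TransformerEncoder` (a sketch, parameters illustrative):

```ruby
require "torch"

layer = Torch::NN::TransformerEncoderLayer.new(256, 4, dim_feedforward: 512)
encoder = Torch::NN::TransformerEncoder.new(layer, 3)

src = Torch.rand(15, 8, 256)   # [seq_len, batch, d_model]
encoder.call(src).shape        # => [15, 8, 256]
```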
@@ -20,6 +20,22 @@ module Torch
20
20
  def _ntuple(n, value)
21
21
  value.is_a?(Array) ? value : [value] * n
22
22
  end
23
+
24
+ def _clones(mod, n)
25
+ state = mod.state_dict
26
+ layers = n.times.map do |i|
27
+ mod.clone.tap { |l| l.load_state_dict(state) }
28
+ end
29
+ ModuleList.new(layers)
30
+ end
31
+
32
+ def _activation_fn(activation)
33
+ case activation.to_sym
34
+ when :relu then F.method(:relu)
35
+ when :gelu then F.method(:gelu)
36
+ else raise ArgumentError, "Activation should be relu/gelu, not `#{activation}`"
37
+ end
38
+ end
23
39
  end
24
40
  end
25
41
  end
data/lib/torch/tensor.rb CHANGED
@@ -19,6 +19,8 @@ module Torch
19
19
  alias_method :&, :logical_and
20
20
  alias_method :|, :logical_or
21
21
  alias_method :^, :logical_xor
22
+ alias_method :<<, :__lshift__
23
+ alias_method :>>, :__rshift__
22
24
 
23
25
  def self.new(*args)
24
26
  FloatTensor.new(*args)
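With the new `__lshift__`/`__rshift__` bindings and these aliases, bitwise shifts work directly on integer tensors. A small sketch:

```ruby
require "torch"

t = Torch.tensor([1, 2, 4], dtype: :int64)
(t << 1).to_a  # => [2, 4, 8]
(t >> 1).to_a  # => [0, 1, 2]
```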
@@ -25,6 +25,8 @@ module Torch
25
25
  end
26
26
 
27
27
  def each
28
+ return to_enum(:each) unless block_given?
29
+
28
30
  # try to keep the random number generator in sync with Python
29
31
  # this makes it easy to compare results
30
32
  base_seed = Torch.empty([], dtype: :int64).random!.item
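This hunk makes `each` return an `Enumerator` when no block is given (the file header isn't shown here, but the seeded-batch comment suggests the data loader). A hedged sketch, assuming a `Torch::Utils::Data::DataLoader` built from a `TensorDataset`:

```ruby
require "torch"

x = Torch.rand(6, 3)
y = Torch.tensor([0, 1, 0, 1, 0, 1])

dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader  = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 2)

enum = loader.each            # no block: now returns an Enumerator
batch_x, batch_y = enum.next  # first batch
batch_x.shape                 # => [2, 3]
```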
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.8.1"
2
+ VERSION = "0.8.2"
3
3
  end
data/lib/torch.rb CHANGED
@@ -39,6 +39,7 @@ require "torch/nn/utils"
39
39
 
40
40
  # nn containers
41
41
  require "torch/nn/module"
42
+ require "torch/nn/module_list"
42
43
  require "torch/nn/sequential"
43
44
 
44
45
  # nn convolution layers
@@ -143,6 +144,10 @@ require "torch/nn/softmin"
143
144
  require "torch/nn/embedding"
144
145
  require "torch/nn/embedding_bag"
145
146
 
147
+ # attention is all you need
148
+ require "torch/nn/multihead_attention"
149
+ require "torch/nn/transformer"
150
+
146
151
  # nn distance functions
147
152
  require "torch/nn/cosine_similarity"
148
153
  require "torch/nn/pairwise_distance"
@@ -174,6 +179,7 @@ require "torch/nn/upsample"
174
179
 
175
180
  # nn other
176
181
  require "torch/nn/functional"
182
+ require "torch/nn/functional_attention"
177
183
  require "torch/nn/init"
178
184
 
179
185
  # utils
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.1
4
+ version: 0.8.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-06-16 00:00:00.000000000 Z
11
+ date: 2021-10-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -106,6 +106,7 @@ files:
106
106
  - lib/torch/nn/feature_alpha_dropout.rb
107
107
  - lib/torch/nn/fold.rb
108
108
  - lib/torch/nn/functional.rb
109
+ - lib/torch/nn/functional_attention.rb
109
110
  - lib/torch/nn/group_norm.rb
110
111
  - lib/torch/nn/gru.rb
111
112
  - lib/torch/nn/hardshrink.rb
@@ -139,10 +140,12 @@ files:
139
140
  - lib/torch/nn/max_unpool3d.rb
140
141
  - lib/torch/nn/max_unpoolnd.rb
141
142
  - lib/torch/nn/module.rb
143
+ - lib/torch/nn/module_list.rb
142
144
  - lib/torch/nn/mse_loss.rb
143
145
  - lib/torch/nn/multi_label_margin_loss.rb
144
146
  - lib/torch/nn/multi_label_soft_margin_loss.rb
145
147
  - lib/torch/nn/multi_margin_loss.rb
148
+ - lib/torch/nn/multihead_attention.rb
146
149
  - lib/torch/nn/nll_loss.rb
147
150
  - lib/torch/nn/pairwise_distance.rb
148
151
  - lib/torch/nn/parameter.rb
@@ -170,6 +173,11 @@ files:
170
173
  - lib/torch/nn/softsign.rb
171
174
  - lib/torch/nn/tanh.rb
172
175
  - lib/torch/nn/tanhshrink.rb
176
+ - lib/torch/nn/transformer.rb
177
+ - lib/torch/nn/transformer_decoder.rb
178
+ - lib/torch/nn/transformer_decoder_layer.rb
179
+ - lib/torch/nn/transformer_encoder.rb
180
+ - lib/torch/nn/transformer_encoder_layer.rb
173
181
  - lib/torch/nn/triplet_margin_loss.rb
174
182
  - lib/torch/nn/unfold.rb
175
183
  - lib/torch/nn/upsample.rb
@@ -219,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
219
227
  - !ruby/object:Gem::Version
220
228
  version: '0'
221
229
  requirements: []
222
- rubygems_version: 3.2.3
230
+ rubygems_version: 3.2.22
223
231
  signing_key:
224
232
  specification_version: 4
225
233
  summary: Deep learning for Ruby, powered by LibTorch