torch-rb 0.8.0 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

data/lib/torch/nn/multihead_attention.rb ADDED
@@ -0,0 +1,123 @@
+ module Torch
+   module NN
+     class MultiheadAttention < Module
+       def initialize(
+         embed_dim, num_heads,
+         dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
+         kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
+       )
+
+         super()
+
+         @embed_dim = embed_dim
+         @kdim = kdim || @embed_dim
+         @vdim = vdim || @embed_dim
+
+         @qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim
+
+         @num_heads = num_heads
+         @dropout = dropout
+         @batch_first = batch_first
+
+         @head_dim = @embed_dim.div @num_heads
+
+         raise ArgumentError, "embed_dim must be divisible by num_heads" unless @head_dim * @num_heads == @embed_dim
+
+         if @qkv_same_embed_dim
+           @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
+           %w(q k v).each { |x| register_parameter("#{x}_proj_weight", nil) }
+         else
+           @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
+           @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
+           @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))
+
+           register_parameter('in_proj_weight', nil)
+         end
+
+         if bias
+           @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
+         else
+           register_parameter('in_proj_bias', nil)
+         end
+
+         @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)
+
+         if add_bias_kv
+           @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
+           @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
+         else
+           @bias_k = @bias_v = nil
+         end
+
+         @add_zero_attn = add_zero_attn
+
+         reset_parameters
+       end
+
+       def batch_first?
+         !!@batch_first
+       end
+
+       def reset_parameters
+         if @qkv_same_embed_dim
+           Init.xavier_uniform!(@in_proj_weight)
+         else
+           Init.xavier_uniform!(@q_proj_weight)
+           Init.xavier_uniform!(@k_proj_weight)
+           Init.xavier_uniform!(@v_proj_weight)
+         end
+
+         if @in_proj_bias
+           Init.constant!(@in_proj_bias, 0.0)
+           Init.constant!(@out_proj.bias, 0.0)
+         end
+
+         Init.xavier_uniform!(@bias_k) if @bias_k
+         Init.xavier_uniform!(@bias_v) if @bias_v
+       end
+
+       def forward(
+         query, key, value,
+         key_padding_mask: nil, need_weights: true, attn_mask: nil
+       )
+
+         if batch_first?
+           query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
+         end
+
+         attn_output, attn_output_weights =
+           if @qkv_same_embed_dim
+             F.multi_head_attention_forward(
+               query, key, value,
+               @embed_dim, @num_heads,
+               @in_proj_weight, @in_proj_bias,
+               @bias_k, @bias_v, @add_zero_attn,
+               @dropout, @out_proj.weight, @out_proj.bias,
+               training: @training,
+               key_padding_mask: key_padding_mask,
+               need_weights: need_weights,
+               attn_mask: attn_mask
+             )
+           else
+             F.multi_head_attention_forward(
+               query, key, value,
+               @embed_dim, @num_heads,
+               @in_proj_weight, @in_proj_bias,
+               @bias_k, @bias_v, @add_zero_attn,
+               @dropout, @out_proj.weight, @out_proj.bias,
+               training: @training,
+               key_padding_mask: key_padding_mask,
+               need_weights: need_weights,
+               attn_mask: attn_mask,
+               use_separate_proj_weight: true,
+               q_proj_weight: @q_proj_weight, k_proj_weight: @k_proj_weight, v_proj_weight: @v_proj_weight
+             )
+           end
+
+         attn_output = attn_output.transpose(1, 0) if batch_first?
+
+         [attn_output, attn_output_weights]
+       end
+     end
+   end
+ end
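
For context, a minimal usage sketch of the new layer (not part of the diff; sizes are arbitrary, and the default sequence-first layout applies since batch_first defaults to false):

    require "torch"

    attn = Torch::NN::MultiheadAttention.new(16, 4)   # embed_dim 16, 4 heads
    query = Torch.randn([5, 2, 16])                   # (target_len, batch, embed_dim)
    key   = Torch.randn([7, 2, 16])                   # (source_len, batch, embed_dim)
    value = Torch.randn([7, 2, 16])

    output, weights = attn.call(query, key, value)
    output.shape    # => [5, 2, 16]
    weights.shape   # => [2, 5, 7], attention weights averaged over heads
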
data/lib/torch/nn/parameter.rb CHANGED
@@ -9,6 +9,12 @@ module Torch
        def inspect
          "Parameter containing:\n#{super}"
        end
+
+       def dup
+         Torch.no_grad do
+           Parameter.new(clone, requires_grad: requires_grad)
+         end
+       end
      end
    end
  end
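
A quick illustration of the new Parameter#dup (a sketch, not part of the diff): the copy is made under Torch.no_grad, and requires_grad is passed through explicitly because an operation run in no-grad mode would otherwise produce a tensor with requires_grad false:

    w = Torch::NN::Parameter.new(Torch.randn([3, 3]))
    w2 = w.dup

    w2.equal?(w)       # => false, a separate tensor with copied data
    w2.requires_grad   # => true, carried over from the original
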
data/lib/torch/nn/transformer.rb ADDED
@@ -0,0 +1,92 @@
+ require_relative 'transformer_encoder_layer'
+ require_relative 'transformer_encoder'
+ require_relative 'transformer_decoder_layer'
+ require_relative 'transformer_decoder'
+
+ module Torch
+   module NN
+     class Transformer < Module
+       def initialize(
+         d_model: 512, nhead: 8,
+         num_encoder_layers: 6, num_decoder_layers: 6,
+         dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+         custom_encoder: nil, custom_decoder: nil,
+         layer_norm_eps: 1e-5, batch_first: false
+       )
+
+         super()
+
+         @encoder =
+           if custom_encoder
+             custom_encoder
+           else
+             encoder_layer = TransformerEncoderLayer.new(
+               d_model, nhead,
+               dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
+               layer_norm_eps: layer_norm_eps, batch_first: batch_first
+             )
+             encoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
+             TransformerEncoder.new(encoder_layer, num_encoder_layers, norm: encoder_norm)
+           end
+
+         @decoder =
+           if custom_decoder
+             custom_decoder
+           else
+             decoder_layer = TransformerDecoderLayer.new(
+               d_model, nhead,
+               dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
+               layer_norm_eps: layer_norm_eps, batch_first: batch_first
+             )
+             decoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
+             TransformerDecoder.new(decoder_layer, num_decoder_layers, norm: decoder_norm)
+           end
+
+         reset_parameters
+
+         @d_model = d_model
+         @nhead = nhead
+         @batch_first = batch_first
+       end
+
+       attr_reader :d_model, :nhead, :encoder, :decoder
+
+       def batch_first?
+         !!@batch_first
+       end
+
+       def reset_parameters
+         parameters.each { |p| Init.xavier_uniform!(p) if p.dim > 1 }
+       end
+
+       def forward(
+         src, tgt,
+         src_mask: nil, tgt_mask: nil, memory_mask: nil,
+         src_key_padding_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil
+       )
+
+         if (!batch_first? && src.size(1) != tgt.size(1)) ||
+             (batch_first? && src.size(0) != tgt.size(0))
+
+           raise ArgumentError, "The batch number of src and tgt must be equal"
+         end
+
+         if src.size(2) != d_model || tgt.size(2) != d_model
+           raise ArgumentError, "The feature number of src and tgt must be equal to d_model"
+         end
+
+         memory = @encoder.(src, mask: src_mask, src_key_padding_mask: src_key_padding_mask)
+         @decoder.(
+           tgt, memory,
+           tgt_mask: tgt_mask, memory_mask: memory_mask,
+           tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask
+         )
+       end
+
+       def generate_square_subsequent_mask(sz)
+         mask = Torch.triu(Torch.ones([sz, sz])).eq(1).transpose(0, 1)
+         mask.float.masked_fill!(mask.eq(0), -Float::INFINITY).masked_fill!(mask.eq(1), 0.0)
+       end
+     end
+   end
+ end
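
A minimal sketch of driving the new Transformer end to end (illustrative only; sizes are arbitrary and the default sequence-first layout is assumed):

    model = Torch::NN::Transformer.new(d_model: 32, nhead: 4, num_encoder_layers: 2, num_decoder_layers: 2)

    src = Torch.randn([10, 2, 32])   # (source_len, batch, d_model)
    tgt = Torch.randn([20, 2, 32])   # (target_len, batch, d_model)

    # causal mask: each target position may only attend to earlier positions
    tgt_mask = model.generate_square_subsequent_mask(20)

    out = model.call(src, tgt, tgt_mask: tgt_mask)
    out.shape   # => [20, 2, 32]
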
data/lib/torch/nn/transformer_decoder.rb ADDED
@@ -0,0 +1,25 @@
+ module Torch
+   module NN
+     class TransformerDecoder < Module
+       def initialize(decoder_layer, num_layers, norm: nil)
+         super()
+
+         @layers = _clones(decoder_layer, num_layers)
+         @num_layers = num_layers
+         @norm = norm
+       end
+
+       def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
+         output = tgt
+
+         @layers.each do |mod|
+           output = mod.call(output, memory, tgt_mask: tgt_mask, memory_mask: memory_mask, tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask)
+         end
+
+         output = @norm.call(output) if @norm
+
+         output
+       end
+     end
+   end
+ end
data/lib/torch/nn/transformer_decoder_layer.rb ADDED
@@ -0,0 +1,43 @@
+ module Torch
+   module NN
+     class TransformerDecoderLayer < Module
+       def initialize(
+         d_model, n_head,
+         dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+         layer_norm_eps: 1e-5, batch_first: false
+       )
+
+         super()
+
+         @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+         @multihead_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+
+         @linear1 = Linear.new(d_model, dim_feedforward)
+         @dropout = Dropout.new(p: dropout)
+         @linear2 = Linear.new(dim_feedforward, d_model)
+
+         @norm1 = LayerNorm.new(d_model, eps: layer_norm_eps)
+         @norm2 = LayerNorm.new(d_model, eps: layer_norm_eps)
+         @norm3 = LayerNorm.new(d_model, eps: layer_norm_eps)
+
+         @dropout1 = Dropout.new(p: dropout)
+         @dropout2 = Dropout.new(p: dropout)
+         @dropout3 = Dropout.new(p: dropout)
+
+         @activation = _activation_fn(activation)
+       end
+
+       def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
+         tgt2 = @self_attn.(tgt, tgt, tgt, attn_mask: tgt_mask, key_padding_mask: tgt_key_padding_mask).first
+         tgt += @dropout1.(tgt2)
+         tgt = @norm1.(tgt)
+         tgt2 = @multihead_attn.(tgt, memory, memory, attn_mask: memory_mask, key_padding_mask: memory_key_padding_mask).first
+         tgt += @dropout2.(tgt2)
+         tgt = @norm2.(tgt)
+         tgt2 = @linear2.(@dropout.(@activation.(@linear1.(tgt))))
+         tgt += @dropout3.(tgt2)
+         @norm3.(tgt)
+       end
+     end
+   end
+ end
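
The decoder layer and the TransformerDecoder container above can also be used on their own, for example to build a stack to pass as custom_decoder: to Transformer.new. A sketch with arbitrary sizes:

    layer = Torch::NN::TransformerDecoderLayer.new(32, 4, dim_feedforward: 64)
    decoder = Torch::NN::TransformerDecoder.new(layer, 2, norm: Torch::NN::LayerNorm.new(32))

    tgt    = Torch.randn([20, 2, 32])   # (target_len, batch, d_model)
    memory = Torch.randn([10, 2, 32])   # encoder output, (source_len, batch, d_model)

    out = decoder.call(tgt, memory)
    out.shape   # => [20, 2, 32]
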
data/lib/torch/nn/transformer_encoder.rb ADDED
@@ -0,0 +1,25 @@
+ module Torch
+   module NN
+     class TransformerEncoder < Module
+       def initialize(encoder_layer, num_layers, norm: nil)
+         super()
+
+         @layers = _clones(encoder_layer, num_layers)
+         @num_layers = num_layers
+         @norm = norm
+       end
+
+       def forward(src, mask: nil, src_key_padding_mask: nil)
+         output = src
+
+         @layers.each do |mod|
+           output = mod.call(output, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
+         end
+
+         output = @norm.call(output) if @norm
+
+         output
+       end
+     end
+   end
+ end
data/lib/torch/nn/transformer_encoder_layer.rb ADDED
@@ -0,0 +1,36 @@
+ module Torch
+   module NN
+     class TransformerEncoderLayer < Module
+       def initialize(
+         d_model, n_head,
+         dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+         layer_norm_eps: 1e-5, batch_first: false
+       )
+
+         super()
+
+         @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+         @linear1 = Linear.new(d_model, dim_feedforward)
+         @dropout = Dropout.new(p: dropout)
+         @linear2 = Linear.new(dim_feedforward, d_model)
+
+         @norm1 = LayerNorm.new(d_model, eps: layer_norm_eps)
+         @norm2 = LayerNorm.new(d_model, eps: layer_norm_eps)
+
+         @dropout1 = Dropout.new(p: dropout)
+         @dropout2 = Dropout.new(p: dropout)
+
+         @activation = _activation_fn(activation)
+       end
+
+       def forward(src, src_mask: nil, src_key_padding_mask: nil)
+         src2 = @self_attn.(src, src, src, attn_mask: src_mask, key_padding_mask: src_key_padding_mask).first
+         src += @dropout1.(src2)
+         src = @norm1.(src)
+         src2 = @linear2.(@dropout.(@activation.(@linear1.(src))))
+         src += @dropout2.(src2)
+         @norm2.(src)
+       end
+     end
+   end
+ end
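
Likewise, the encoder pieces stack directly; a small sketch (arbitrary sizes, :gelu chosen only to show the alternative activation):

    layer = Torch::NN::TransformerEncoderLayer.new(32, 4, dim_feedforward: 64, activation: :gelu)
    encoder = Torch::NN::TransformerEncoder.new(layer, 2)

    src = Torch.randn([10, 2, 32])   # (source_len, batch, d_model)
    out = encoder.call(src)
    out.shape   # => [10, 2, 32]
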
data/lib/torch/nn/utils.rb CHANGED
@@ -20,6 +20,18 @@ module Torch
        def _ntuple(n, value)
          value.is_a?(Array) ? value : [value] * n
        end
+
+       def _clones(mod, n)
+         ModuleList.new(n.times.map { mod.deep_dup })
+       end
+
+       def _activation_fn(activation)
+         case activation.to_sym
+         when :relu then F.method(:relu)
+         when :gelu then F.method(:gelu)
+         else raise ArgumentError, "Activation should be relu/gelu, not `#{activation}`"
+         end
+       end
      end
    end
  end
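
The two new helpers are internal to Torch::NN::Utils rather than public API; roughly, their effect on the transformer classes above is the following (illustrative sketch):

    # _activation_fn accepts only :relu or :gelu; anything else raises at construction time
    Torch::NN::TransformerEncoderLayer.new(8, 2, activation: :tanh)
    # => ArgumentError: Activation should be relu/gelu, not `tanh`

    # _clones deep-copies the layer, so stacked layers do not share parameters
    layer = Torch::NN::TransformerEncoderLayer.new(8, 2)
    stack = Torch::NN::TransformerEncoder.new(layer, 3)
    stack.parameters.length == 3 * layer.parameters.length   # => true
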
data/lib/torch/tensor.rb CHANGED
@@ -19,6 +19,8 @@ module Torch
      alias_method :&, :logical_and
      alias_method :|, :logical_or
      alias_method :^, :logical_xor
+     alias_method :<<, :__lshift__
+     alias_method :>>, :__rshift__

      def self.new(*args)
        FloatTensor.new(*args)
@@ -183,5 +185,23 @@ module Torch
      def stft(*args)
        Torch.stft(*args)
      end
+
+     def dup
+       Torch.no_grad do
+         clone
+       end
+     end
+
+     # not a method in native_functions.yaml
+     # attribute in Python rather than method
+     def imag
+       Torch.imag(self)
+     end
+
+     # not a method in native_functions.yaml
+     # attribute in Python rather than method
+     def real
+       Torch.real(self)
+     end
    end
  end
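
Brief usage notes for the Tensor additions (a sketch; the complex-number example assumes the underlying libtorch build exposes Torch.complex and complex dtypes):

    x = Torch.tensor([1, 2, 4])
    x << 2   # => tensor([ 4,  8, 16]), alias for x.__lshift__(2)
    x >> 1   # => tensor([0, 1, 2]),    alias for x.__rshift__(1)

    c = Torch.complex(Torch.tensor([1.0, 3.0]), Torch.tensor([2.0, -4.0]))
    c.real   # => tensor([1., 3.])
    c.imag   # => tensor([ 2., -4.])

    y = Torch.ones(2, 2, requires_grad: true)
    y.dup.requires_grad   # => false, because dup clones inside Torch.no_grad
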
@@ -25,6 +25,8 @@ module Torch
          end

          def each
+           return to_enum(:each) unless block_given?
+
            # try to keep the random number generator in sync with Python
            # this makes it easy to compare results
            base_seed = Torch.empty([], dtype: :int64).random!.item
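
This hunk appears to be DataLoader#each (the RNG-sync comment and base_seed match lib/torch/utils/data/data_loader.rb); returning an Enumerator when no block is given lets the loader compose with the usual Enumerable idioms. A sketch:

    x = Torch.randn([6, 3])
    y = Torch.tensor([0, 1, 0, 1, 0, 1])
    dataset = Torch::Utils::Data::TensorDataset.new(x, y)
    loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 2)

    first_x, first_y = loader.each.first   # take one batch without a block
    loader.each.with_index do |(xb, yb), i|
      puts "batch #{i}: #{xb.shape.inspect}"
    end
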
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Torch
-   VERSION = "0.8.0"
+   VERSION = "0.9.0"
  end
data/lib/torch.rb CHANGED
@@ -39,6 +39,7 @@ require "torch/nn/utils"

  # nn containers
  require "torch/nn/module"
+ require "torch/nn/module_list"
  require "torch/nn/sequential"

  # nn convolution layers
@@ -143,6 +144,10 @@ require "torch/nn/softmin"
  require "torch/nn/embedding"
  require "torch/nn/embedding_bag"

+ # attention is all you need
+ require "torch/nn/multihead_attention"
+ require "torch/nn/transformer"
+
  # nn distance functions
  require "torch/nn/cosine_similarity"
  require "torch/nn/pairwise_distance"
@@ -174,6 +179,7 @@ require "torch/nn/upsample"

  # nn other
  require "torch/nn/functional"
+ require "torch/nn/functional_attention"
  require "torch/nn/init"

  # utils
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
-   version: 0.8.0
+   version: 0.9.0
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2021-06-15 00:00:00.000000000 Z
+ date: 2021-10-23 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rice
@@ -37,16 +37,23 @@ files:
  - codegen/function.rb
  - codegen/generate_functions.rb
  - codegen/native_functions.yaml
+ - ext/torch/backends.cpp
  - ext/torch/cuda.cpp
  - ext/torch/device.cpp
  - ext/torch/ext.cpp
  - ext/torch/extconf.rb
+ - ext/torch/fft.cpp
+ - ext/torch/fft_functions.h
  - ext/torch/ivalue.cpp
+ - ext/torch/linalg.cpp
+ - ext/torch/linalg_functions.h
  - ext/torch/nn.cpp
  - ext/torch/nn_functions.h
  - ext/torch/random.cpp
  - ext/torch/ruby_arg_parser.cpp
  - ext/torch/ruby_arg_parser.h
+ - ext/torch/special.cpp
+ - ext/torch/special_functions.h
  - ext/torch/templates.h
  - ext/torch/tensor.cpp
  - ext/torch/tensor_functions.h
@@ -99,6 +106,7 @@ files:
  - lib/torch/nn/feature_alpha_dropout.rb
  - lib/torch/nn/fold.rb
  - lib/torch/nn/functional.rb
+ - lib/torch/nn/functional_attention.rb
  - lib/torch/nn/group_norm.rb
  - lib/torch/nn/gru.rb
  - lib/torch/nn/hardshrink.rb
@@ -132,10 +140,12 @@ files:
  - lib/torch/nn/max_unpool3d.rb
  - lib/torch/nn/max_unpoolnd.rb
  - lib/torch/nn/module.rb
+ - lib/torch/nn/module_list.rb
  - lib/torch/nn/mse_loss.rb
  - lib/torch/nn/multi_label_margin_loss.rb
  - lib/torch/nn/multi_label_soft_margin_loss.rb
  - lib/torch/nn/multi_margin_loss.rb
+ - lib/torch/nn/multihead_attention.rb
  - lib/torch/nn/nll_loss.rb
  - lib/torch/nn/pairwise_distance.rb
  - lib/torch/nn/parameter.rb
@@ -163,6 +173,11 @@ files:
  - lib/torch/nn/softsign.rb
  - lib/torch/nn/tanh.rb
  - lib/torch/nn/tanhshrink.rb
+ - lib/torch/nn/transformer.rb
+ - lib/torch/nn/transformer_decoder.rb
+ - lib/torch/nn/transformer_decoder_layer.rb
+ - lib/torch/nn/transformer_encoder.rb
+ - lib/torch/nn/transformer_encoder_layer.rb
  - lib/torch/nn/triplet_margin_loss.rb
  - lib/torch/nn/unfold.rb
  - lib/torch/nn/upsample.rb
@@ -212,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
      - !ruby/object:Gem::Version
        version: '0'
  requirements: []
- rubygems_version: 3.2.3
+ rubygems_version: 3.2.22
  signing_key:
  specification_version: 4
  summary: Deep learning for Ruby, powered by LibTorch