torch-rb 0.8.1 → 0.9.1

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
data/lib/torch/nn/parameter.rb CHANGED
@@ -9,6 +9,12 @@ module Torch
       def inspect
         "Parameter containing:\n#{super}"
       end
+
+      def dup
+        Torch.no_grad do
+          Parameter.new(clone, requires_grad: requires_grad)
+        end
+      end
     end
   end
 end
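Note (not part of the diff): a minimal sketch of what the new Parameter#dup gives you. The copy is built inside Torch.no_grad, so duplicating a parameter is not recorded in the autograd graph, while requires_grad carries over from the original.

    layer = Torch::NN::Linear.new(4, 2)
    copy = layer.weight.dup      # independent Parameter with the same values
    copy.requires_grad           # => true, preserved by the new dup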
data/lib/torch/nn/transformer.rb ADDED
@@ -0,0 +1,92 @@
+require_relative 'transformer_encoder_layer'
+require_relative 'transformer_encoder'
+require_relative 'transformer_decoder_layer'
+require_relative 'transformer_decoder'
+
+module Torch
+  module NN
+    class Transformer < Module
+      def initialize(
+        d_model: 512, nhead: 8,
+        num_encoder_layers: 6, num_decoder_layers: 6,
+        dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+        custom_encoder: nil, custom_decoder: nil,
+        layer_norm_eps: 1e-5, batch_first: false
+      )
+
+        super()
+
+        @encoder =
+          if custom_encoder
+            custom_encoder
+          else
+            encoder_layer = TransformerEncoderLayer.new(
+              d_model, nhead,
+              dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
+              layer_norm_eps: layer_norm_eps, batch_first: batch_first
+            )
+            encoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
+            TransformerEncoder.new(encoder_layer, num_encoder_layers, norm: encoder_norm)
+          end
+
+        @decoder =
+          if custom_decoder
+            custom_decoder
+          else
+            decoder_layer = TransformerDecoderLayer.new(
+              d_model, nhead,
+              dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
+              layer_norm_eps: layer_norm_eps, batch_first: batch_first
+            )
+            decoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
+            TransformerDecoder.new(decoder_layer, num_decoder_layers, norm: decoder_norm)
+          end
+
+        reset_parameters
+
+        @d_model = d_model
+        @nhead = nhead
+        @batch_first = batch_first
+      end
+
+      attr_reader :d_model, :nhead, :encoder, :decoder
+
+      def batch_first?
+        !!@batch_first
+      end
+
+      def reset_parameters
+        parameters.each { |p| Init.xavier_uniform!(p) if p.dim > 1 }
+      end
+
+      def forward(
+        src, tgt,
+        src_mask: nil, tgt_mask: nil, memory_mask: nil,
+        src_key_padding_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil
+      )
+
+        if (!batch_first? && src.size(1) != tgt.size(1)) ||
+            (batch_first? && src.size(0) != tgt.size(0))
+
+          raise ArgumentError, "The batch number of src and tgt must be equal"
+        end
+
+        if src.size(2) != d_model || tgt.size(2) != d_model
+          raise ArgumentError, "The feature number of src and tgt must be equal to d_model"
+        end
+
+        memory = @encoder.(src, mask: src_mask, src_key_padding_mask: src_key_padding_mask)
+        @decoder.(
+          tgt, memory,
+          tgt_mask: tgt_mask, memory_mask: memory_mask,
+          tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask
+        )
+      end
+
+      def generate_square_subsequent_mask(sz)
+        mask = Torch.triu(Torch.ones([sz, sz])).eq(1).transpose(0, 1)
+        mask.float.masked_fill!(mask.eq(0), -Float::INFINITY).masked_fill!(mask.eq(1), 0.0)
+      end
+    end
+  end
+end
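Usage sketch (not part of the diff) for the new Torch::NN::Transformer, assuming the PyTorch-style shapes implied by the checks in forward: with the default batch_first: false, src is (S, N, E) and tgt is (T, N, E), and E must equal d_model.

    model = Torch::NN::Transformer.new(d_model: 512, nhead: 8)
    src = Torch.rand([10, 32, 512])   # (source length, batch, d_model)
    tgt = Torch.rand([20, 32, 512])   # (target length, batch, d_model)

    # causal mask: position i in tgt may only attend to positions <= i
    tgt_mask = model.generate_square_subsequent_mask(tgt.size(0))

    out = model.call(src, tgt, tgt_mask: tgt_mask)
    out.size                          # => [20, 32, 512]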
data/lib/torch/nn/transformer_decoder.rb ADDED
@@ -0,0 +1,25 @@
+module Torch
+  module NN
+    class TransformerDecoder < Module
+      def initialize(decoder_layer, num_layers, norm: nil)
+        super()
+
+        @layers = _clones(decoder_layer, num_layers)
+        @num_layers = num_layers
+        @norm = norm
+      end
+
+      def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
+        output = tgt
+
+        @layers.each do |mod|
+          output = mod.call(output, memory, tgt_mask: tgt_mask, memory_mask: memory_mask, tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask)
+        end
+
+        output = @norm.call(output) if @norm
+
+        output
+      end
+    end
+  end
+end
data/lib/torch/nn/transformer_decoder_layer.rb ADDED
@@ -0,0 +1,43 @@
+module Torch
+  module NN
+    class TransformerDecoderLayer < Module
+      def initialize(
+        d_model, n_head,
+        dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+        layer_norm_eps: 1e-5, batch_first: false
+      )
+
+        super()
+
+        @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+        @multihead_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+
+        @linear1 = Linear.new(d_model, dim_feedforward)
+        @dropout = Dropout.new(p: dropout)
+        @linear2 = Linear.new(dim_feedforward, d_model)
+
+        @norm1 = LayerNorm.new(d_model, eps: layer_norm_eps)
+        @norm2 = LayerNorm.new(d_model, eps: layer_norm_eps)
+        @norm3 = LayerNorm.new(d_model, eps: layer_norm_eps)
+
+        @dropout1 = Dropout.new(p: dropout)
+        @dropout2 = Dropout.new(p: dropout)
+        @dropout3 = Dropout.new(p: dropout)
+
+        @activation = _activation_fn(activation)
+      end
+
+      def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
+        tgt2 = @self_attn.(tgt, tgt, tgt, attn_mask: tgt_mask, key_padding_mask: tgt_key_padding_mask).first
+        tgt += @dropout1.(tgt2)
+        tgt = @norm1.(tgt)
+        tgt2 = @multihead_attn.(tgt, memory, memory, attn_mask: memory_mask, key_padding_mask: memory_key_padding_mask).first
+        tgt += @dropout2.(tgt2)
+        tgt = @norm2.(tgt)
+        tgt2 = @linear2.(@dropout.(@activation.(@linear1.(tgt))))
+        tgt += @dropout3.(tgt2)
+        @norm3.(tgt)
+      end
+    end
+  end
+end
data/lib/torch/nn/transformer_encoder.rb ADDED
@@ -0,0 +1,25 @@
+module Torch
+  module NN
+    class TransformerEncoder < Module
+      def initialize(encoder_layer, num_layers, norm: nil)
+        super()
+
+        @layers = _clones(encoder_layer, num_layers)
+        @num_layers = num_layers
+        @norm = norm
+      end
+
+      def forward(src, mask: nil, src_key_padding_mask: nil)
+        output = src
+
+        @layers.each do |mod|
+          output = mod.call(output, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
+        end
+
+        output = @norm.call(output) if @norm
+
+        output
+      end
+    end
+  end
+end
data/lib/torch/nn/transformer_encoder_layer.rb ADDED
@@ -0,0 +1,36 @@
+module Torch
+  module NN
+    class TransformerEncoderLayer < Module
+      def initialize(
+        d_model, n_head,
+        dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+        layer_norm_eps: 1e-5, batch_first: false
+      )
+
+        super()
+
+        @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+        @linear1 = Linear.new(d_model, dim_feedforward)
+        @dropout = Dropout.new(p: dropout)
+        @linear2 = Linear.new(dim_feedforward, d_model)
+
+        @norm1 = LayerNorm.new(d_model, eps: layer_norm_eps)
+        @norm2 = LayerNorm.new(d_model, eps: layer_norm_eps)
+
+        @dropout1 = Dropout.new(p: dropout)
+        @dropout2 = Dropout.new(p: dropout)
+
+        @activation = _activation_fn(activation)
+      end
+
+      def forward(src, src_mask: nil, src_key_padding_mask: nil)
+        src2 = @self_attn.(src, src, src, attn_mask: src_mask, key_padding_mask: src_key_padding_mask).first
+        src += @dropout1.(src2)
+        src = @norm1.(src)
+        src2 = @linear2.(@dropout.(@activation.(@linear1.(src))))
+        src += @dropout2.(src2)
+        @norm2.(src)
+      end
+    end
+  end
+end
data/lib/torch/nn/utils.rb CHANGED
@@ -20,6 +20,18 @@ module Torch
       def _ntuple(n, value)
         value.is_a?(Array) ? value : [value] * n
       end
+
+      def _clones(mod, n)
+        ModuleList.new(n.times.map { mod.deep_dup })
+      end
+
+      def _activation_fn(activation)
+        case activation.to_sym
+        when :relu then F.method(:relu)
+        when :gelu then F.method(:gelu)
+        else raise ArgumentError, "Activation should be relu/gelu, not `#{activation}`"
+        end
+      end
     end
   end
 end
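Illustrative sketch (not part of the diff) of what the new helpers do: _activation_fn maps :relu / :gelu to bound Method objects so a layer can call @activation.(x) uniformly, and _clones builds a ModuleList of independent copies of a prototype layer, roughly equivalent to:

    layer = Torch::NN::TransformerEncoderLayer.new(512, 8)
    # what _clones(layer, 3) produces: three deep copies, each with its own parameters
    stack = Torch::NN::ModuleList.new(3.times.map { layer.deep_dup })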
data/lib/torch/tensor.rb CHANGED
@@ -19,6 +19,8 @@ module Torch
     alias_method :&, :logical_and
     alias_method :|, :logical_or
     alias_method :^, :logical_xor
+    alias_method :<<, :__lshift__
+    alias_method :>>, :__rshift__
 
     def self.new(*args)
       FloatTensor.new(*args)
@@ -104,6 +106,7 @@ module Torch
       size(0)
     end
 
+    remove_method :item
     def item
       if numel != 1
         raise Error, "only one element tensors can be converted to Ruby scalars"
@@ -131,18 +134,10 @@ module Torch
       cls.from_string(_data_str).reshape(*shape)
     end
 
-    def new_ones(*size, **options)
-      Torch.ones_like(Torch.empty(*size), **options)
-    end
-
     def requires_grad=(requires_grad)
      _requires_grad!(requires_grad)
     end
 
-    def requires_grad!(requires_grad = true)
-      _requires_grad!(requires_grad)
-    end
-
     def type(dtype)
       if dtype.is_a?(Class)
         raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
@@ -183,5 +178,23 @@ module Torch
     def stft(*args)
       Torch.stft(*args)
     end
+
+    def dup
+      Torch.no_grad do
+        clone
+      end
+    end
+
+    # not a method in native_functions.yaml
+    # attribute in Python rather than method
+    def imag
+      Torch.imag(self)
+    end
+
+    # not a method in native_functions.yaml
+    # attribute in Python rather than method
+    def real
+      Torch.real(self)
+    end
   end
 end
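A few quick examples (not part of the diff) of the tensor changes: << and >> now alias the bitwise shift functions, dup clones without tracking the copy in autograd, and real / imag delegate to Torch.real / Torch.imag because they are attributes rather than methods in Python.

    x = Torch.tensor([1, 2, 4])
    x << 1                  # => tensor([2, 4, 8])
    x >> 1                  # => tensor([0, 1, 2])

    w = Torch.ones([2, 2])
    w.requires_grad = true
    copy = w.dup            # copied under Torch.no_grad, so the copy is detached from the graph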
data/lib/torch/utils/data/data_loader.rb CHANGED
@@ -25,9 +25,11 @@ module Torch
         end
 
         def each
+          return to_enum(:each) unless block_given?
+
           # try to keep the random number generator in sync with Python
           # this makes it easy to compare results
-          base_seed = Torch.empty([], dtype: :int64).random!.item
+          _base_seed = Torch.empty([], dtype: :int64).random!.item
 
           indexes =
             if @shuffle
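A small sketch (not part of the diff) of what the DataLoader change enables: each without a block now returns an Enumerator, so a loader can be consumed lazily or composed with Enumerable calls. The dataset below is purely illustrative.

    dataset = Torch::Utils::Data::TensorDataset.new(Torch.rand([100, 4]), Torch.rand([100, 1]))
    loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 32)

    batches = loader.each          # Enumerator; previously this required a block
    first_x, first_y = batches.first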
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.8.1"
+  VERSION = "0.9.1"
 end
data/lib/torch.rb CHANGED
@@ -39,6 +39,7 @@ require "torch/nn/utils"
 
 # nn containers
 require "torch/nn/module"
+require "torch/nn/module_list"
 require "torch/nn/sequential"
 
 # nn convolution layers
@@ -143,6 +144,10 @@ require "torch/nn/softmin"
 require "torch/nn/embedding"
 require "torch/nn/embedding_bag"
 
+# attention is all you need
+require "torch/nn/multihead_attention"
+require "torch/nn/transformer"
+
 # nn distance functions
 require "torch/nn/cosine_similarity"
 require "torch/nn/pairwise_distance"
@@ -174,6 +179,7 @@ require "torch/nn/upsample"
 
 # nn other
 require "torch/nn/functional"
+require "torch/nn/functional_attention"
 require "torch/nn/init"
 
 # utils
@@ -371,8 +377,6 @@ module Torch
       to_ruby(_load(File.binread(f)))
     end
 
-    # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
-
     def tensor(data, **options)
       if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
         numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
@@ -405,41 +409,6 @@ module Torch
       _tensor(data, size, tensor_options(**options))
     end
 
-    # --- begin like ---
-
-    def ones_like(input, **options)
-      ones(input.size, **like_options(input, options))
-    end
-
-    def empty_like(input, **options)
-      empty(input.size, **like_options(input, options))
-    end
-
-    def full_like(input, fill_value, **options)
-      full(input.size, fill_value, **like_options(input, options))
-    end
-
-    def rand_like(input, **options)
-      rand(input.size, **like_options(input, options))
-    end
-
-    def randint_like(input, low, high = nil, **options)
-      # ruby doesn't support input, low = 0, high, ...
-      if high.nil?
-        high = low
-        low = 0
-      end
-      randint(low, high, input.size, **like_options(input, options))
-    end
-
-    def randn_like(input, **options)
-      randn(input.size, **like_options(input, options))
-    end
-
-    def zeros_like(input, **options)
-      zeros(input.size, **like_options(input, options))
-    end
-
     # center option
     def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
       if center
@@ -566,13 +535,5 @@ module Torch
       end
       options
     end
-
-    def like_options(input, options)
-      options = options.dup
-      options[:dtype] ||= input.dtype
-      options[:layout] ||= input.layout
-      options[:device] ||= input.device
-      options
-    end
   end
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.8.1
+  version: 0.9.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2021-06-16 00:00:00.000000000 Z
+date: 2022-02-03 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -106,6 +106,7 @@ files:
 - lib/torch/nn/feature_alpha_dropout.rb
 - lib/torch/nn/fold.rb
 - lib/torch/nn/functional.rb
+- lib/torch/nn/functional_attention.rb
 - lib/torch/nn/group_norm.rb
 - lib/torch/nn/gru.rb
 - lib/torch/nn/hardshrink.rb
@@ -139,10 +140,12 @@ files:
 - lib/torch/nn/max_unpool3d.rb
 - lib/torch/nn/max_unpoolnd.rb
 - lib/torch/nn/module.rb
+- lib/torch/nn/module_list.rb
 - lib/torch/nn/mse_loss.rb
 - lib/torch/nn/multi_label_margin_loss.rb
 - lib/torch/nn/multi_label_soft_margin_loss.rb
 - lib/torch/nn/multi_margin_loss.rb
+- lib/torch/nn/multihead_attention.rb
 - lib/torch/nn/nll_loss.rb
 - lib/torch/nn/pairwise_distance.rb
 - lib/torch/nn/parameter.rb
@@ -170,6 +173,11 @@ files:
 - lib/torch/nn/softsign.rb
 - lib/torch/nn/tanh.rb
 - lib/torch/nn/tanhshrink.rb
+- lib/torch/nn/transformer.rb
+- lib/torch/nn/transformer_decoder.rb
+- lib/torch/nn/transformer_decoder_layer.rb
+- lib/torch/nn/transformer_encoder.rb
+- lib/torch/nn/transformer_encoder_layer.rb
 - lib/torch/nn/triplet_margin_loss.rb
 - lib/torch/nn/unfold.rb
 - lib/torch/nn/upsample.rb
@@ -219,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
     - !ruby/object:Gem::Version
       version: '0'
 requirements: []
-rubygems_version: 3.2.3
+rubygems_version: 3.3.3
 signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch