torch-rb 0.20.0 → 0.22.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,13 +7,18 @@ module Torch
7
7
  module NN
8
8
  class Transformer < Module
9
9
  def initialize(
10
- d_model: 512, nhead: 8,
11
- num_encoder_layers: 6, num_decoder_layers: 6,
12
- dim_feedforward: 2048, dropout: 0.1, activation: :relu,
13
- custom_encoder: nil, custom_decoder: nil,
14
- layer_norm_eps: 1e-5, batch_first: false
10
+ d_model: 512,
11
+ nhead: 8,
12
+ num_encoder_layers: 6,
13
+ num_decoder_layers: 6,
14
+ dim_feedforward: 2048,
15
+ dropout: 0.1,
16
+ activation: :relu,
17
+ custom_encoder: nil,
18
+ custom_decoder: nil,
19
+ layer_norm_eps: 1e-5,
20
+ batch_first: false
15
21
  )
16
-
17
22
  super()
18
23
 
19
24
  @encoder =
@@ -60,11 +65,15 @@ module Torch
60
65
  end
61
66
 
62
67
  def forward(
63
- src, tgt,
64
- src_mask: nil, tgt_mask: nil, memory_mask: nil,
65
- src_key_padding_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil
68
+ src,
69
+ tgt,
70
+ src_mask: nil,
71
+ tgt_mask: nil,
72
+ memory_mask: nil,
73
+ src_key_padding_mask: nil,
74
+ tgt_key_padding_mask: nil,
75
+ memory_key_padding_mask: nil
66
76
  )
67
-
68
77
  if (!batch_first? && src.size(1) != tgt.size(1)) ||
69
78
  (batch_first? && src.size(0) != tgt.size(0))
70
79
 
@@ -2,11 +2,14 @@ module Torch
2
2
  module NN
3
3
  class TransformerDecoderLayer < Module
4
4
  def initialize(
5
- d_model, n_head,
6
- dim_feedforward: 2048, dropout: 0.1, activation: :relu,
7
- layer_norm_eps: 1e-5, batch_first: false
5
+ d_model,
6
+ n_head,
7
+ dim_feedforward: 2048,
8
+ dropout: 0.1,
9
+ activation: :relu,
10
+ layer_norm_eps: 1e-5,
11
+ batch_first: false
8
12
  )
9
-
10
13
  super()
11
14
 
12
15
  @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
@@ -2,11 +2,14 @@ module Torch
2
2
  module NN
3
3
  class TransformerEncoderLayer < Module
4
4
  def initialize(
5
- d_model, n_head,
6
- dim_feedforward: 2048, dropout: 0.1, activation: :relu,
7
- layer_norm_eps: 1e-5, batch_first: false
5
+ d_model,
6
+ n_head,
7
+ dim_feedforward: 2048,
8
+ dropout: 0.1,
9
+ activation: :relu,
10
+ layer_norm_eps: 1e-5,
11
+ batch_first: false
8
12
  )
9
-
10
13
  super()
11
14
 
12
15
  @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.20.0"
2
+ VERSION = "0.22.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -210,7 +210,6 @@ require_relative "torch/utils/data/tensor_dataset"
210
210
  require_relative "torch/hub"
211
211
 
212
212
  module Torch
213
- class Error < StandardError; end
214
213
  class NotImplementedYet < StandardError
215
214
  def message
216
215
  "This feature has not been implemented yet. Consider submitting a PR."
@@ -439,7 +438,7 @@ module Torch
439
438
  # TODO check each dimensions for consistency in future
440
439
  raise Error, "Inconsistent dimensions" if data.size != size.inject(1, :*)
441
440
 
442
- # TOOD move to C++
441
+ # TODO move to C++
443
442
  data = data.map { |v| v ? 1 : 0 } if options[:dtype] == :bool
444
443
 
445
444
  _tensor(data, size, tensor_options(**options))
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.20.0
4
+ version: 0.22.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
@@ -234,14 +234,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
234
234
  requirements:
235
235
  - - ">="
236
236
  - !ruby/object:Gem::Version
237
- version: '3.1'
237
+ version: '3.2'
238
238
  required_rubygems_version: !ruby/object:Gem::Requirement
239
239
  requirements:
240
240
  - - ">="
241
241
  - !ruby/object:Gem::Version
242
242
  version: '0'
243
243
  requirements: []
244
- rubygems_version: 3.6.7
244
+ rubygems_version: 3.6.9
245
245
  specification_version: 4
246
246
  summary: Deep learning for Ruby, powered by LibTorch
247
247
  test_files: []