torchrec 0.0.1 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b20b0c478258d789f1b6a6b2f80e3350600b6ce23f6ee1146cfb6819551a1150
4
- data.tar.gz: 443383a994cd25d6225f84401e0a47886b848e590551ae0eec2eb31dba9468bd
3
+ metadata.gz: 8fa2077cf1a744fe2d11d9fdb030e82a428542892670d6f2ef69b4801c3d0f43
4
+ data.tar.gz: 6c95aa0542c037f0cc466e2e82f3b366ee7a099da580deeb7718fc4889702b1f
5
5
  SHA512:
6
- metadata.gz: a8e6ce693978cd21505f4b7378efe92a87ee672ece2d012cf3ee69dc9a3a3332cc57a84d1e0edcc7c317385aaae212a8351e54d5ea33ad338f795ec7138c5cdc
7
- data.tar.gz: eb105de85d924fd2d3c7ef444a61caecdfb8008a5f25dba7297bb233b04f2662a47e144305301bd60f6cc1e9734bb1551946911047e8f654581b5fb192330cf2
6
+ metadata.gz: 796466080c868a2aadd167c7bd2db68c150bb25b56ff7f56ce05b7d50077f24f1361e488edab331e3e468fec929e0a78ef49df1a14e0632baedd465deb266538
7
+ data.tar.gz: 38b05acd34abe5bf3364697690dd8b9dbcf16943ed54287b3ef6ed75256f8902898acff2dd9caf64d3e03d94cf3f8071efc25bdb362bf2eb12b90c68d8a6d310
data/CHANGELOG.md CHANGED
@@ -1,3 +1,12 @@
1
+ ## 0.0.3 (2023-07-24)
2
+
3
+ - Dropped support for Ruby < 3
4
+
5
+ ## 0.0.2 (2022-03-14)
6
+
7
+ - Added `JaggedTensor`
8
+ - Added `CrossNet`, `DeepFM`, and `FactorizationMachine` modules
9
+
1
10
  ## 0.0.1 (2022-02-28)
2
11
 
3
12
  - First release
data/README.md CHANGED
@@ -35,6 +35,9 @@ TorchRec::Models::DLRM::DenseArch.new(in_features, layer_sizes, device: nil)
35
35
 
36
36
  ```ruby
37
37
  TorchRec::Modules::Activation::SwishLayerNorm.new(input_dims, device: nil)
38
+ TorchRec::Modules::CrossNet::CrossNet.new(in_features, num_layers)
39
+ TorchRec::Modules::DeepFM::DeepFM.new(dense_module)
40
+ TorchRec::Modules::DeepFM::FactorizationMachine.new
38
41
  TorchRec::Modules::MLP::MLP.new(in_size, layer_sizes, bias: true, activation: :relu, device: nil)
39
42
  TorchRec::Modules::MLP::Perceptron.new(in_size, out_size, bias: true, activation: Torch.method(:relu), device: nil)
40
43
  ```
@@ -0,0 +1,36 @@
1
module TorchRec
  module Modules
    module CrossNet
      # Stack of `num_layers` cross layers. Each layer computes
      #
      #   x_{l+1} = x_0 * (W_l x_l + b_l) + x_l
      #
      # where x_0 is the (unsqueezed) input, W_l is an (in_features, in_features)
      # kernel and b_l is an (in_features, 1) bias.
      class CrossNet < Torch::NN::Module
        # @param in_features [Integer] size N of the input's feature dimension
        # @param num_layers [Integer] number of stacked cross layers
        # @raise [ArgumentError] if num_layers is negative (from Array.new)
        def initialize(in_features, num_layers)
          super()
          @num_layers = num_layers

          # Kernels use Xavier-normal initialization; biases start at zero.
          kernels = Array.new(num_layers) do
            Torch::NN::Parameter.new(
              Torch::NN::Init.xavier_normal!(Torch.empty(in_features, in_features))
            )
          end
          biases = Array.new(num_layers) do
            Torch::NN::Parameter.new(Torch::NN::Init.zeros!(Torch.empty(in_features, 1)))
          end

          @kernels = Torch::NN::ParameterList.new(kernels)
          @bias = Torch::NN::ParameterList.new(biases)
        end

        # @param input [Torch::Tensor] expected shape (B, N) with N == in_features
        # @return [Torch::Tensor] tensor of shape (B, N)
        def forward(input)
          x_0 = input.unsqueeze(2) # (B, N, 1)
          x_l = x_0

          @num_layers.times do |layer|
            # (N, N) matmul (B, N, 1) broadcasts over the batch -> (B, N, 1)
            xl_w = Torch.matmul(@kernels[layer], x_l)
            x_l = x_0 * (xl_w + @bias[layer]) + x_l # (B, N, 1)
          end

          Torch.squeeze(x_l, dim: 2) # drop the trailing singleton -> (B, N)
        end
      end
    end
  end
end
@@ -0,0 +1,23 @@
1
module TorchRec
  module Modules
    module DeepFM
      # Flattens each embedding tensor past its first dimension, joins the
      # results along dim 1, and passes that single tensor to `dense_module`.
      class DeepFM < Torch::NN::Module
        # @param dense_module [#call] module applied to the concatenated input
        #   (typically a Torch::NN::Module)
        def initialize(dense_module)
          super()
          @dense_module = dense_module
        end

        # @param embeddings [Array<Torch::Tensor>] tensors that must agree on dim 0
        # @return whatever `dense_module` returns for the concatenated tensor
        def forward(embeddings)
          @dense_module.call(flatten_input(embeddings))
        end

        # Flattens every tensor from dim 1 onward, then concatenates along dim 1.
        private def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module TorchRec
  module Modules
    module DeepFM
      # Second-order interaction term computed over the flattened, concatenated
      # embeddings v of shape (B, N):
      #
      #   0.5 * ((sum_j v_j)^2 - sum_j v_j^2)   per row, returned as (B, 1)
      class FactorizationMachine < Torch::NN::Module
        def initialize
          super()
        end

        # @param embeddings [Array<Torch::Tensor>] tensors that must agree on dim 0
        # @return [Torch::Tensor] shape (B, 1)
        def forward(embeddings)
          v = flatten_input(embeddings)
          row_sum = Torch.sum(v, dim: 1, keepdim: true)
          row_sum_of_squares = Torch.sum(v * v, dim: 1, keepdim: true)
          pairwise = (row_sum * row_sum) - row_sum_of_squares

          # pairwise is already (B, 1); this reduction keeps that shape and
          # mirrors the reference computation before halving.
          Torch.sum(pairwise, dim: 1, keepdim: true) * 0.5 # [B, 1]
        end

        # Flattens every tensor from dim 1 onward, then concatenates along dim 1.
        private def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
@@ -0,0 +1,33 @@
1
module TorchRec
  module Sparse
    # Container for a `values` tensor plus `lengths` and/or `offsets`
    # (presumably describing how `values` splits into variable-length segments,
    # as in PyTorch TorchRec — confirm with callers).
    #
    # Only the tensors passed in are stored; nothing is derived, so `lengths`
    # is nil when only `offsets` was given and vice versa.
    class JaggedTensor
      # dtypes (as reported by Tensor#dtype) accepted for non-empty lengths/offsets.
      INTEGER_DTYPES = [:int64, :int32, :int16, :int8, :uint8].freeze
      private_constant :INTEGER_DTYPES

      attr_reader :values, :weights, :lengths, :offsets

      # @param values [Torch::Tensor] stored as-is
      # @param weights [Torch::Tensor, nil] optional; stored as-is
      # @param lengths [Torch::Tensor, nil] integer-typed unless empty
      # @param offsets [Torch::Tensor, nil] integer-typed unless empty
      # @raise [ArgumentError] if neither lengths nor offsets is given, or if a
      #   non-empty lengths/offsets tensor has a non-integer dtype
      def initialize(values, weights: nil, lengths: nil, offsets: nil)
        assert_offsets_or_lengths_is_provided(offsets, lengths)
        assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") unless offsets.nil?
        assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") unless lengths.nil?

        @values = values
        @weights = weights
        @lengths = lengths
        @offsets = offsets
      end

      private

      def assert_offsets_or_lengths_is_provided(offsets, lengths)
        return unless offsets.nil? && lengths.nil?

        raise ArgumentError, "Must provide lengths or offsets"
      end

      # Empty tensors are accepted regardless of dtype.
      def assert_tensor_has_no_elements_or_has_integers(tensor, tensor_name)
        return if tensor.numel == 0 || INTEGER_DTYPES.include?(tensor.dtype)

        raise ArgumentError, "#{tensor_name} must be of integer type, but got #{tensor.dtype}"
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
module TorchRec
  # Released version string of this gem.
  VERSION = "0.0.3"
end
data/lib/torchrec.rb CHANGED
@@ -1,19 +1,25 @@
1
1
  # dependencies
2
- require "torch"
2
+ require "torch-rb"
3
3
 
4
4
  # models
5
- require "torchrec/models/deepfm/dense_arch"
6
- require "torchrec/models/deepfm/over_arch"
7
- require "torchrec/models/dlrm/dense_arch"
5
+ require_relative "torchrec/models/deepfm/dense_arch"
6
+ require_relative "torchrec/models/deepfm/over_arch"
7
+ require_relative "torchrec/models/dlrm/dense_arch"
8
8
 
9
9
  # modules
10
- require "torchrec/modules/activation/swish_layer_norm"
11
- require "torchrec/modules/mlp/mlp"
12
- require "torchrec/modules/mlp/perceptron"
13
- require "torchrec/modules/utils"
10
+ require_relative "torchrec/modules/activation/swish_layer_norm"
11
+ require_relative "torchrec/modules/cross_net/cross_net"
12
+ require_relative "torchrec/modules/deepfm/deepfm"
13
+ require_relative "torchrec/modules/deepfm/factorization_machine"
14
+ require_relative "torchrec/modules/mlp/mlp"
15
+ require_relative "torchrec/modules/mlp/perceptron"
16
+ require_relative "torchrec/modules/utils"
17
+
18
+ # sparse
19
+ require_relative "torchrec/sparse/jagged_tensor"
14
20
 
15
21
  # other
16
- require "torchrec/version"
22
+ require_relative "torchrec/version"
17
23
 
18
24
  module TorchRec
19
25
  class Error < StandardError; end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchrec
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.0.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-02-28 00:00:00.000000000 Z
11
+ date: 2023-07-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: torch-rb
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: '0'
19
+ version: '0.10'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: '0'
26
+ version: '0.10'
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []
@@ -38,9 +38,13 @@ files:
38
38
  - lib/torchrec/models/deepfm/over_arch.rb
39
39
  - lib/torchrec/models/dlrm/dense_arch.rb
40
40
  - lib/torchrec/modules/activation/swish_layer_norm.rb
41
+ - lib/torchrec/modules/cross_net/cross_net.rb
42
+ - lib/torchrec/modules/deepfm/deepfm.rb
43
+ - lib/torchrec/modules/deepfm/factorization_machine.rb
41
44
  - lib/torchrec/modules/mlp/mlp.rb
42
45
  - lib/torchrec/modules/mlp/perceptron.rb
43
46
  - lib/torchrec/modules/utils.rb
47
+ - lib/torchrec/sparse/jagged_tensor.rb
44
48
  - lib/torchrec/version.rb
45
49
  homepage: https://github.com/ankane/torchrec-ruby
46
50
  licenses:
@@ -54,14 +58,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
54
58
  requirements:
55
59
  - - ">="
56
60
  - !ruby/object:Gem::Version
57
- version: '2.6'
61
+ version: '3'
58
62
  required_rubygems_version: !ruby/object:Gem::Requirement
59
63
  requirements:
60
64
  - - ">="
61
65
  - !ruby/object:Gem::Version
62
66
  version: '0'
63
67
  requirements: []
64
- rubygems_version: 3.3.3
68
+ rubygems_version: 3.4.10
65
69
  signing_key:
66
70
  specification_version: 4
67
71
  summary: Deep learning recommendation systems for Ruby