torchrec 0.0.1 → 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: b20b0c478258d789f1b6a6b2f80e3350600b6ce23f6ee1146cfb6819551a1150
-   data.tar.gz: 443383a994cd25d6225f84401e0a47886b848e590551ae0eec2eb31dba9468bd
+   metadata.gz: 573d37473e5a3e8e24b60bfe886ef1d14be87833ecc7c207a349c236480f1d6e
+   data.tar.gz: 9fcc58e9201bb545582feaa50a56727dd074688b4c830e495abd452aa31ad572
  SHA512:
-   metadata.gz: a8e6ce693978cd21505f4b7378efe92a87ee672ece2d012cf3ee69dc9a3a3332cc57a84d1e0edcc7c317385aaae212a8351e54d5ea33ad338f795ec7138c5cdc
-   data.tar.gz: eb105de85d924fd2d3c7ef444a61caecdfb8008a5f25dba7297bb233b04f2662a47e144305301bd60f6cc1e9734bb1551946911047e8f654581b5fb192330cf2
+   metadata.gz: bb80aa64a91fb1af2dad904fc4f39b366f79765ba541334c4bf4109011249b9a6075a6d050b1aa644f42236cf037110181ed09cb74e71bd47882c10402e15101
+   data.tar.gz: 7393ad51f8882915e433f3aa8314f3d743a70b192b174854f161e2479c811ad7bb57f7f1953c0e817211056c201de5cf88d4d15e78c361e52768598155e9bad9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
+ ## 0.0.2 (2022-03-14)
+
+ - Added `JaggedTensor`
+ - Added `CrossNet`, `DeepFM`, and `FactorizationMachine` modules
+
  ## 0.0.1 (2022-02-28)
 
  - First release
data/README.md CHANGED
@@ -35,6 +35,9 @@ TorchRec::Models::DLRM::DenseArch.new(in_features, layer_sizes, device: nil)
 
  ```ruby
  TorchRec::Modules::Activation::SwishLayerNorm.new(input_dims, device: nil)
+ TorchRec::Modules::CrossNet::CrossNet.new(in_features, num_layers)
+ TorchRec::Modules::DeepFM::DeepFM.new(dense_module)
+ TorchRec::Modules::DeepFM::FactorizationMachine.new
  TorchRec::Modules::MLP::MLP.new(in_size, layer_sizes, bias: true, activation: :relu, device: nil)
  TorchRec::Modules::MLP::Perceptron.new(in_size, out_size, bias: true, activation: Torch.method(:relu), device: nil)
  ```
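For orientation, the new constructors documented above can be exercised roughly as follows (a minimal sketch; the tensor shapes, the `MLP` dense module, and the `call` usage are illustrative assumptions, not taken from the README):

```ruby
require "torch"
require "torchrec"

# hypothetical dense module: DeepFM flattens its embedding inputs along dim 1,
# concatenates them, and passes the result through this module
dense = TorchRec::Modules::MLP::MLP.new(10, [16, 1])
deep_fm = TorchRec::Modules::DeepFM::DeepFM.new(dense)

# two embedding tensors for a batch of 3 (4 + 6 = 10 flattened features)
embeddings = [Torch.rand(3, 4), Torch.rand(3, 6)]
out = deep_fm.call(embeddings) # expected shape: [3, 1]
```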
data/lib/torchrec/modules/cross_net/cross_net.rb ADDED
@@ -0,0 +1,36 @@
+ module TorchRec
+   module Modules
+     module CrossNet
+       class CrossNet < Torch::NN::Module
+         def initialize(in_features, num_layers)
+           super()
+           @num_layers = num_layers
+           @kernels = Torch::NN::ParameterList.new(
+             @num_layers.times.map do |i|
+               Torch::NN::Parameter.new(
+                 Torch::NN::Init.xavier_normal!(Torch.empty(in_features, in_features))
+               )
+             end
+           )
+           @bias = Torch::NN::ParameterList.new(
+             @num_layers.times.map do |i|
+               Torch::NN::Parameter.new(Torch::NN::Init.zeros!(Torch.empty(in_features, 1)))
+             end
+           )
+         end
+
+         def forward(input)
+           x_0 = input.unsqueeze(2) # (B, N, 1)
+           x_l = x_0
+
+           @num_layers.times do |layer|
+             xl_w = Torch.matmul(@kernels[layer], x_l) # (B, N, 1)
+             x_l = x_0 * (xl_w + @bias[layer]) + x_l # (B, N, 1)
+           end
+
+           Torch.squeeze(x_l, dim: 2)
+         end
+       end
+     end
+   end
+ end
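Each pass of the loop in `forward` applies one cross layer, `x_{l+1} = x_0 * (W_l x_l + b_l) + x_l` (elementwise product with the original input plus a residual), so the output has the same shape as the input. A minimal usage sketch (the batch and feature sizes are illustrative assumptions):

```ruby
require "torch"
require "torchrec"

# batch of 4 examples with 8 dense features, 2 cross layers
cross_net = TorchRec::Modules::CrossNet::CrossNet.new(8, 2)
x = Torch.rand(4, 8)
y = cross_net.call(x)
p y.shape # expected: [4, 8], same shape as the input
```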
data/lib/torchrec/modules/deepfm/deepfm.rb ADDED
@@ -0,0 +1,23 @@
+ module TorchRec
+   module Modules
+     module DeepFM
+       class DeepFM < Torch::NN::Module
+         def initialize(dense_module)
+           super()
+           @dense_module = dense_module
+         end
+
+         def forward(embeddings)
+           deepfm_input = flatten_input(embeddings)
+           @dense_module.call(deepfm_input)
+         end
+
+         private
+
+         def flatten_input(inputs)
+           Torch.cat(inputs.map { |input| input.flatten(1) }, dim: 1)
+         end
+       end
+     end
+   end
+ end
data/lib/torchrec/modules/deepfm/factorization_machine.rb ADDED
@@ -0,0 +1,27 @@
+ module TorchRec
+   module Modules
+     module DeepFM
+       class FactorizationMachine < Torch::NN::Module
+         def initialize
+           super()
+         end
+
+         def forward(embeddings)
+           fm_input = flatten_input(embeddings)
+           sum_of_input = Torch.sum(fm_input, dim: 1, keepdim: true)
+           sum_of_square = Torch.sum(fm_input * fm_input, dim: 1, keepdim: true)
+           square_of_sum = sum_of_input * sum_of_input
+           cross_term = square_of_sum - sum_of_square
+           cross_term = Torch.sum(cross_term, dim: 1, keepdim: true) * 0.5 # [B, 1]
+           cross_term
+         end
+
+         private
+
+         def flatten_input(inputs)
+           Torch.cat(inputs.map { |input| input.flatten(1) }, dim: 1)
+         end
+       end
+     end
+   end
+ end
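The `forward` here relies on the standard factorization machine identity `0.5 * ((Σ_i x_i)^2 - Σ_i x_i^2) = Σ_{i<j} x_i x_j` over the flattened features, producing one second-order interaction score per example. A minimal usage sketch (the tensor sizes are illustrative assumptions):

```ruby
require "torch"
require "torchrec"

fm = TorchRec::Modules::DeepFM::FactorizationMachine.new

# two embedding tensors for a batch of 3; they are flattened and concatenated
e1 = Torch.rand(3, 4)
e2 = Torch.rand(3, 6)

score = fm.call([e1, e2]) # expected shape: [3, 1]
```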
data/lib/torchrec/sparse/jagged_tensor.rb ADDED
@@ -0,0 +1,33 @@
+ module TorchRec
+   module Sparse
+     class JaggedTensor
+       def initialize(values, weights: nil, lengths: nil, offsets: nil)
+         @values = values
+         @weights = weights
+         assert_offsets_or_lengths_is_provided(offsets, lengths)
+         if !offsets.nil?
+           assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
+         end
+         if !lengths.nil?
+           assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
+         end
+         @lengths = lengths
+         @offsets = offsets
+       end
+
+       private
+
+       def assert_offsets_or_lengths_is_provided(offsets, lengths)
+         if offsets.nil? && lengths.nil?
+           raise ArgumentError, "Must provide lengths or offsets"
+         end
+       end
+
+       def assert_tensor_has_no_elements_or_has_integers(tensor, tensor_name)
+         if tensor.numel != 0 && ![:int64, :int32, :int16, :int8, :uint8].include?(tensor.dtype)
+           raise ArgumentError, "#{tensor_name} must be of integer type, but got #{tensor.dtype}"
+         end
+       end
+     end
+   end
+ end
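A `JaggedTensor` pairs a flat values tensor with either per-row `lengths` or cumulative `offsets`, which must be integer tensors; in 0.0.2 the class only validates and stores these inputs. A minimal construction sketch (the example data is illustrative):

```ruby
require "torch"
require "torchrec"

# three jagged rows holding 2, 0, and 1 values respectively
values = Torch.tensor([1.0, 2.0, 3.0])
lengths = Torch.tensor([2, 0, 1]) # int64, so the dtype check passes

jt = TorchRec::Sparse::JaggedTensor.new(values, lengths: lengths)
```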
data/lib/torchrec/version.rb CHANGED
@@ -1,3 +1,3 @@
  module TorchRec
-   VERSION = "0.0.1"
+   VERSION = "0.0.2"
  end
data/lib/torchrec.rb CHANGED
@@ -8,10 +8,16 @@ require "torchrec/models/dlrm/dense_arch"
 
  # modules
  require "torchrec/modules/activation/swish_layer_norm"
+ require "torchrec/modules/cross_net/cross_net"
+ require "torchrec/modules/deepfm/deepfm"
+ require "torchrec/modules/deepfm/factorization_machine"
  require "torchrec/modules/mlp/mlp"
  require "torchrec/modules/mlp/perceptron"
  require "torchrec/modules/utils"
 
+ # sparse
+ require "torchrec/sparse/jagged_tensor"
+
  # other
  require "torchrec/version"
 
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torchrec
  version: !ruby/object:Gem::Version
-   version: 0.0.1
+   version: 0.0.2
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2022-02-28 00:00:00.000000000 Z
+ date: 2022-03-15 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: torch-rb
@@ -16,14 +16,14 @@ dependencies:
      requirements:
      - - ">="
        - !ruby/object:Gem::Version
-         version: '0'
+         version: 0.9.2
    type: :runtime
    prerelease: false
    version_requirements: !ruby/object:Gem::Requirement
      requirements:
      - - ">="
        - !ruby/object:Gem::Version
-         version: '0'
+         version: 0.9.2
  description:
  email: andrew@ankane.org
  executables: []
@@ -38,9 +38,13 @@ files:
  - lib/torchrec/models/deepfm/over_arch.rb
  - lib/torchrec/models/dlrm/dense_arch.rb
  - lib/torchrec/modules/activation/swish_layer_norm.rb
+ - lib/torchrec/modules/cross_net/cross_net.rb
+ - lib/torchrec/modules/deepfm/deepfm.rb
+ - lib/torchrec/modules/deepfm/factorization_machine.rb
  - lib/torchrec/modules/mlp/mlp.rb
  - lib/torchrec/modules/mlp/perceptron.rb
  - lib/torchrec/modules/utils.rb
+ - lib/torchrec/sparse/jagged_tensor.rb
  - lib/torchrec/version.rb
  homepage: https://github.com/ankane/torchrec-ruby
  licenses:
@@ -61,7 +65,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
    - !ruby/object:Gem::Version
      version: '0'
  requirements: []
- rubygems_version: 3.3.3
+ rubygems_version: 3.3.7
  signing_key:
  specification_version: 4
  summary: Deep learning recommendation systems for Ruby