torchrec 0.0.1 → 0.0.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b20b0c478258d789f1b6a6b2f80e3350600b6ce23f6ee1146cfb6819551a1150
4
- data.tar.gz: 443383a994cd25d6225f84401e0a47886b848e590551ae0eec2eb31dba9468bd
3
+ metadata.gz: 573d37473e5a3e8e24b60bfe886ef1d14be87833ecc7c207a349c236480f1d6e
4
+ data.tar.gz: 9fcc58e9201bb545582feaa50a56727dd074688b4c830e495abd452aa31ad572
5
5
  SHA512:
6
- metadata.gz: a8e6ce693978cd21505f4b7378efe92a87ee672ece2d012cf3ee69dc9a3a3332cc57a84d1e0edcc7c317385aaae212a8351e54d5ea33ad338f795ec7138c5cdc
7
- data.tar.gz: eb105de85d924fd2d3c7ef444a61caecdfb8008a5f25dba7297bb233b04f2662a47e144305301bd60f6cc1e9734bb1551946911047e8f654581b5fb192330cf2
6
+ metadata.gz: bb80aa64a91fb1af2dad904fc4f39b366f79765ba541334c4bf4109011249b9a6075a6d050b1aa644f42236cf037110181ed09cb74e71bd47882c10402e15101
7
+ data.tar.gz: 7393ad51f8882915e433f3aa8314f3d743a70b192b174854f161e2479c811ad7bb57f7f1953c0e817211056c201de5cf88d4d15e78c361e52768598155e9bad9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.0.2 (2022-03-14)
2
+
3
+ - Added `JaggedTensor`
4
+ - Added `CrossNet`, `DeepFM`, and `FactorizationMachine` modules
5
+
1
6
  ## 0.0.1 (2022-02-28)
2
7
 
3
8
  - First release
data/README.md CHANGED
@@ -35,6 +35,9 @@ TorchRec::Models::DLRM::DenseArch.new(in_features, layer_sizes, device: nil)
35
35
 
36
36
  ```ruby
37
37
  TorchRec::Modules::Activation::SwishLayerNorm.new(input_dims, device: nil)
38
+ TorchRec::Modules::CrossNet::CrossNet.new(in_features, num_layers)
39
+ TorchRec::Modules::DeepFM::DeepFM.new(dense_module)
40
+ TorchRec::Modules::DeepFM::FactorizationMachine.new
38
41
  TorchRec::Modules::MLP::MLP.new(in_size, layer_sizes, bias: true, activation: :relu, device: nil)
39
42
  TorchRec::Modules::MLP::Perceptron.new(in_size, out_size, bias: true, activation: Torch.method(:relu), device: nil)
40
43
  ```
@@ -0,0 +1,36 @@
1
module TorchRec
  module Modules
    module CrossNet
      # Stack of "crossing" layers. Layer l computes
      #   x_(l+1) = x_0 * (W_l . x_l + b_l) + x_l
      # where x_0 is the (unsqueezed) original input, W_l is an
      # (in_features x in_features) kernel and b_l an (in_features x 1) bias.
      class CrossNet < Torch::NN::Module
        # @param in_features [Integer] size of the last dimension of the input
        # @param num_layers [Integer] number of crossing layers to stack
        def initialize(in_features, num_layers)
          super()
          @num_layers = num_layers
          # NOTE(review): torch-rb appears to derive parameter names from
          # instance variables, so @kernels/@bias are kept stable for
          # state_dict compatibility.
          # Kernels start Xavier-normal initialized ...
          @kernels = Torch::NN::ParameterList.new(
            @num_layers.times.map do
              Torch::NN::Parameter.new(
                Torch::NN::Init.xavier_normal!(Torch.empty(in_features, in_features))
              )
            end
          )
          # ... and biases start at zero.
          @bias = Torch::NN::ParameterList.new(
            @num_layers.times.map do
              Torch::NN::Parameter.new(Torch::NN::Init.zeros!(Torch.empty(in_features, 1)))
            end
          )
        end

        # @param input [Torch::Tensor] assumed shape (B, N) with N == in_features — TODO confirm
        # @return [Torch::Tensor] the crossed features, squeezed back to (B, N)
        def forward(input)
          x_0 = input.unsqueeze(2) # (B, N, 1)
          x_l = x_0

          @num_layers.times do |layer|
            # (N, N) kernel broadcasts over the batch dimension in matmul.
            xl_w = Torch.matmul(@kernels[layer], x_l) # (B, N, 1)
            x_l = x_0 * (xl_w + @bias[layer]) + x_l # (B, N, 1)
          end

          Torch.squeeze(x_l, dim: 2)
        end
      end
    end
  end
end
@@ -0,0 +1,23 @@
1
module TorchRec
  module Modules
    module DeepFM
      # Flattens a list of embedding tensors into one tensor and hands the
      # result to a caller-supplied dense module.
      class DeepFM < Torch::NN::Module
        # @param dense_module [Torch::NN::Module] module invoked on the
        #   concatenated, flattened embeddings
        def initialize(dense_module)
          super()
          @dense_module = dense_module
        end

        # @param embeddings [Array<Torch::Tensor>] tensors assumed to share
        #   their first (batch) dimension — TODO confirm
        # @return whatever the wrapped dense module returns
        def forward(embeddings)
          @dense_module.call(flatten_input(embeddings))
        end

        private

        # Flattens each tensor from dim 1 onward, then joins them along dim 1.
        def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module TorchRec
  module Modules
    module DeepFM
      # Second-order interaction term: 0.5 * ((sum x)^2 - sum x^2), computed
      # over the flattened, concatenated embeddings. Holds no parameters.
      class FactorizationMachine < Torch::NN::Module
        def initialize
          super()
        end

        # @param embeddings [Array<Torch::Tensor>] tensors assumed to share
        #   their first (batch) dimension — TODO confirm
        # @return [Torch::Tensor] interaction term, [B, 1]
        def forward(embeddings)
          flat = flatten_input(embeddings)

          total = Torch.sum(flat, dim: 1, keepdim: true)
          squares_total = Torch.sum(flat * flat, dim: 1, keepdim: true)
          total_squared = total * total

          # NOTE(review): the operand is already [B, 1] after the keepdim sums
          # above, so this reduction looks like a no-op kept from the reference.
          Torch.sum(total_squared - squares_total, dim: 1, keepdim: true) * 0.5 # [B, 1]
        end

        private

        # Flattens each tensor from dim 1 onward, then joins them along dim 1.
        def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
@@ -0,0 +1,33 @@
1
module TorchRec
  module Sparse
    # Holds a values tensor together with lengths and/or offsets describing
    # how it is split (presumably per-row lengths / row boundaries — TODO
    # confirm), plus optional weights.
    class JaggedTensor
      # dtypes accepted for non-empty lengths/offsets tensors.
      INTEGER_DTYPES = %i[int64 int32 int16 int8 uint8].freeze
      private_constant :INTEGER_DTYPES

      # Inputs exactly as passed to the constructor. lengths/offsets are not
      # derived from one another, so whichever was omitted reads back as nil.
      attr_reader :values, :weights, :lengths, :offsets

      # @param values [Torch::Tensor] the underlying values
      # @param weights [Torch::Tensor, nil] optional weights
      # @param lengths [Torch::Tensor, nil] integer tensor; required unless offsets is given
      # @param offsets [Torch::Tensor, nil] integer tensor; required unless lengths is given
      # @raise [ArgumentError] if both lengths and offsets are nil, or if either
      #   one is non-empty with a non-integer dtype
      def initialize(values, weights: nil, lengths: nil, offsets: nil)
        @values = values
        @weights = weights
        assert_offsets_or_lengths_is_provided(offsets, lengths)
        assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") unless offsets.nil?
        assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") unless lengths.nil?
        @lengths = lengths
        @offsets = offsets
      end

      private

      def assert_offsets_or_lengths_is_provided(offsets, lengths)
        return unless offsets.nil? && lengths.nil?

        raise ArgumentError, "Must provide lengths or offsets"
      end

      # Empty tensors are accepted regardless of dtype.
      def assert_tensor_has_no_elements_or_has_integers(tensor, tensor_name)
        return if tensor.numel == 0 || INTEGER_DTYPES.include?(tensor.dtype)

        raise ArgumentError, "#{tensor_name} must be of integer type, but got #{tensor.dtype}"
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
module TorchRec
  # Gem version string; frozen so the shared constant cannot be mutated in place.
  VERSION = "0.0.2".freeze
end
data/lib/torchrec.rb CHANGED
@@ -8,10 +8,16 @@ require "torchrec/models/dlrm/dense_arch"
8
8
 
9
9
  # modules
10
10
  require "torchrec/modules/activation/swish_layer_norm"
11
+ require "torchrec/modules/cross_net/cross_net"
12
+ require "torchrec/modules/deepfm/deepfm"
13
+ require "torchrec/modules/deepfm/factorization_machine"
11
14
  require "torchrec/modules/mlp/mlp"
12
15
  require "torchrec/modules/mlp/perceptron"
13
16
  require "torchrec/modules/utils"
14
17
 
18
+ # sparse
19
+ require "torchrec/sparse/jagged_tensor"
20
+
15
21
  # other
16
22
  require "torchrec/version"
17
23
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchrec
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.0.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-02-28 00:00:00.000000000 Z
11
+ date: 2022-03-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: torch-rb
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: '0'
19
+ version: 0.9.2
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: '0'
26
+ version: 0.9.2
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []
@@ -38,9 +38,13 @@ files:
38
38
  - lib/torchrec/models/deepfm/over_arch.rb
39
39
  - lib/torchrec/models/dlrm/dense_arch.rb
40
40
  - lib/torchrec/modules/activation/swish_layer_norm.rb
41
+ - lib/torchrec/modules/cross_net/cross_net.rb
42
+ - lib/torchrec/modules/deepfm/deepfm.rb
43
+ - lib/torchrec/modules/deepfm/factorization_machine.rb
41
44
  - lib/torchrec/modules/mlp/mlp.rb
42
45
  - lib/torchrec/modules/mlp/perceptron.rb
43
46
  - lib/torchrec/modules/utils.rb
47
+ - lib/torchrec/sparse/jagged_tensor.rb
44
48
  - lib/torchrec/version.rb
45
49
  homepage: https://github.com/ankane/torchrec-ruby
46
50
  licenses:
@@ -61,7 +65,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
61
65
  - !ruby/object:Gem::Version
62
66
  version: '0'
63
67
  requirements: []
64
- rubygems_version: 3.3.3
68
+ rubygems_version: 3.3.7
65
69
  signing_key:
66
70
  specification_version: 4
67
71
  summary: Deep learning recommendation systems for Ruby