torchrec 0.0.1 → 0.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -0
- data/lib/torchrec/modules/cross_net/cross_net.rb +36 -0
- data/lib/torchrec/modules/deepfm/deepfm.rb +23 -0
- data/lib/torchrec/modules/deepfm/factorization_machine.rb +27 -0
- data/lib/torchrec/sparse/jagged_tensor.rb +33 -0
- data/lib/torchrec/version.rb +1 -1
- data/lib/torchrec.rb +6 -0
- metadata +9 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 573d37473e5a3e8e24b60bfe886ef1d14be87833ecc7c207a349c236480f1d6e
|
4
|
+
data.tar.gz: 9fcc58e9201bb545582feaa50a56727dd074688b4c830e495abd452aa31ad572
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: bb80aa64a91fb1af2dad904fc4f39b366f79765ba541334c4bf4109011249b9a6075a6d050b1aa644f42236cf037110181ed09cb74e71bd47882c10402e15101
|
7
|
+
data.tar.gz: 7393ad51f8882915e433f3aa8314f3d743a70b192b174854f161e2479c811ad7bb57f7f1953c0e817211056c201de5cf88d4d15e78c361e52768598155e9bad9
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -35,6 +35,9 @@ TorchRec::Models::DLRM::DenseArch.new(in_features, layer_sizes, device: nil)
|
|
35
35
|
|
36
36
|
```ruby
|
37
37
|
TorchRec::Modules::Activation::SwishLayerNorm.new(input_dims, device: nil)
|
38
|
+
TorchRec::Modules::CrossNet::CrossNet.new(in_features, num_layers)
|
39
|
+
TorchRec::Modules::DeepFM::DeepFM.new(dense_module)
|
40
|
+
TorchRec::Modules::DeepFM::FactorizationMachine.new
|
38
41
|
TorchRec::Modules::MLP::MLP.new(in_size, layer_sizes, bias: true, activation: :relu, device: nil)
|
39
42
|
TorchRec::Modules::MLP::Perceptron.new(in_size, out_size, bias: true, activation: Torch.method(:relu), device: nil)
|
40
43
|
```
|
@@ -0,0 +1,36 @@
|
|
1
|
+
module TorchRec
  module Modules
    module CrossNet
      # Stack of `num_layers` cross layers. Layer l computes
      #
      #   x_{l+1} = x_0 * (W_l x_l + b_l) + x_l
      #
      # where x_0 is the module input, W_l is an (in_features, in_features)
      # kernel initialised with Xavier-normal, and b_l is an (in_features, 1)
      # bias initialised to zeros.
      class CrossNet < Torch::NN::Module
        # @param in_features [Integer] size N of the input's feature dimension
        # @param num_layers [Integer] number of cross layers to stack
        def initialize(in_features, num_layers)
          super()
          @num_layers = num_layers
          # One learnable kernel W_l per layer (no index needed, so no block param).
          @kernels = Torch::NN::ParameterList.new(
            @num_layers.times.map do
              Torch::NN::Parameter.new(
                Torch::NN::Init.xavier_normal!(Torch.empty(in_features, in_features))
              )
            end
          )
          # One learnable column bias b_l per layer, starting at zero.
          @bias = Torch::NN::ParameterList.new(
            @num_layers.times.map do
              Torch::NN::Parameter.new(Torch::NN::Init.zeros!(Torch.empty(in_features, 1)))
            end
          )
        end

        # Applies every cross layer in order.
        #
        # @param input [Torch::Tensor] presumably (B, N) with N == in_features — TODO confirm
        # @return [Torch::Tensor] input-shaped result (the extra trailing dim added
        #   by `unsqueeze(2)` is squeezed away again)
        def forward(input)
          x_0 = input.unsqueeze(2) # (B, N, 1)
          x_l = x_0

          @num_layers.times do |layer|
            # (N, N) @ (B, N, 1) broadcasts over the batch dimension.
            xl_w = Torch.matmul(@kernels[layer], x_l) # (B, N, 1)
            # Residual connection: each layer adds onto its own input.
            x_l = x_0 * (xl_w + @bias[layer]) + x_l # (B, N, 1)
          end

          Torch.squeeze(x_l, dim: 2)
        end
      end
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module TorchRec
  module Modules
    module DeepFM
      # Deep component of DeepFM: concatenates a list of embedding tensors
      # into a single 2-D feature matrix and runs it through a user-supplied
      # dense module.
      class DeepFM < Torch::NN::Module
        # @param dense_module [#call] module applied to the flattened features
        def initialize(dense_module)
          super()
          @dense_module = dense_module
        end

        # @param embeddings [Array<Torch::Tensor>] tensors sharing the leading
        #   (batch) dimension — assumed, not checked here
        # @return whatever `dense_module.call` returns for the flattened input
        def forward(embeddings)
          @dense_module.call(flatten_input(embeddings))
        end

        private

        # Flattens every tensor from dim 1 onward, then joins them along dim 1.
        def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module TorchRec
  module Modules
    module DeepFM
      # Parameter-free factorization-machine interaction term.
      #
      # All input tensors are flattened and concatenated into one feature row
      # per example. The output is 0.5 * ((sum x)^2 - sum(x^2)), i.e. the sum of
      # all pairwise products x_i * x_j (i < j) within each row.
      class FactorizationMachine < Torch::NN::Module
        def initialize
          super()
        end

        # @param embeddings [Array<Torch::Tensor>] tensors sharing the leading
        #   (batch) dimension — assumed, not checked here
        # @return [Torch::Tensor] shape [B, 1]
        def forward(embeddings)
          features = flatten_input(embeddings)

          total = Torch.sum(features, dim: 1, keepdim: true)
          squared_total = total * total
          total_of_squares = Torch.sum(features * features, dim: 1, keepdim: true)
          interaction = squared_total - total_of_squares

          # `interaction` already has size 1 along dim 1, so this sum only keeps
          # parity with the reference formula before halving.
          Torch.sum(interaction, dim: 1, keepdim: true) * 0.5 # [B, 1]
        end

        private

        # Flattens every tensor from dim 1 onward, then joins them along dim 1.
        def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module TorchRec
  module Sparse
    # Container for variable-length (jagged) data. It holds a `values` tensor
    # together with either `lengths` or `offsets` describing how the values
    # are split, plus optional per-value `weights`.
    class JaggedTensor
      # dtypes accepted for non-empty `lengths` / `offsets` tensors.
      INTEGER_DTYPES = [:int64, :int32, :int16, :int8, :uint8].freeze
      private_constant :INTEGER_DTYPES

      # Expose the stored state. Without these readers the object's data is
      # unreachable by callers.
      attr_reader :values, :weights, :lengths, :offsets

      # @param values [Torch::Tensor] the packed values
      # @param weights [Torch::Tensor, nil] optional weights
      # @param lengths [Torch::Tensor, nil] integer lengths (or empty)
      # @param offsets [Torch::Tensor, nil] integer offsets (or empty)
      # @raise [ArgumentError] if neither lengths nor offsets is given, or if a
      #   non-empty lengths/offsets tensor has a non-integer dtype
      def initialize(values, weights: nil, lengths: nil, offsets: nil)
        @values = values
        @weights = weights
        assert_offsets_or_lengths_is_provided(offsets, lengths)
        assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") unless offsets.nil?
        assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") unless lengths.nil?
        @lengths = lengths
        @offsets = offsets
      end

      private

      # At least one of the two split descriptions is required.
      def assert_offsets_or_lengths_is_provided(offsets, lengths)
        return unless offsets.nil? && lengths.nil?

        raise ArgumentError, "Must provide lengths or offsets"
      end

      # Empty tensors are accepted regardless of dtype. Otherwise the dtype
      # must be one of INTEGER_DTYPES.
      def assert_tensor_has_no_elements_or_has_integers(tensor, tensor_name)
        return if tensor.numel.zero? || INTEGER_DTYPES.include?(tensor.dtype)

        raise ArgumentError, "#{tensor_name} must be of integer type, but got #{tensor.dtype}"
      end
    end
  end
end
|
data/lib/torchrec/version.rb
CHANGED
data/lib/torchrec.rb
CHANGED
@@ -8,10 +8,16 @@ require "torchrec/models/dlrm/dense_arch"
|
|
8
8
|
|
9
9
|
# modules
|
10
10
|
require "torchrec/modules/activation/swish_layer_norm"
|
11
|
+
require "torchrec/modules/cross_net/cross_net"
|
12
|
+
require "torchrec/modules/deepfm/deepfm"
|
13
|
+
require "torchrec/modules/deepfm/factorization_machine"
|
11
14
|
require "torchrec/modules/mlp/mlp"
|
12
15
|
require "torchrec/modules/mlp/perceptron"
|
13
16
|
require "torchrec/modules/utils"
|
14
17
|
|
18
|
+
# sparse
|
19
|
+
require "torchrec/sparse/jagged_tensor"
|
20
|
+
|
15
21
|
# other
|
16
22
|
require "torchrec/version"
|
17
23
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchrec
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.1
|
4
|
+
version: 0.0.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-03-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: torch-rb
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version:
|
19
|
+
version: 0.9.2
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version:
|
26
|
+
version: 0.9.2
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|
@@ -38,9 +38,13 @@ files:
|
|
38
38
|
- lib/torchrec/models/deepfm/over_arch.rb
|
39
39
|
- lib/torchrec/models/dlrm/dense_arch.rb
|
40
40
|
- lib/torchrec/modules/activation/swish_layer_norm.rb
|
41
|
+
- lib/torchrec/modules/cross_net/cross_net.rb
|
42
|
+
- lib/torchrec/modules/deepfm/deepfm.rb
|
43
|
+
- lib/torchrec/modules/deepfm/factorization_machine.rb
|
41
44
|
- lib/torchrec/modules/mlp/mlp.rb
|
42
45
|
- lib/torchrec/modules/mlp/perceptron.rb
|
43
46
|
- lib/torchrec/modules/utils.rb
|
47
|
+
- lib/torchrec/sparse/jagged_tensor.rb
|
44
48
|
- lib/torchrec/version.rb
|
45
49
|
homepage: https://github.com/ankane/torchrec-ruby
|
46
50
|
licenses:
|
@@ -61,7 +65,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
61
65
|
- !ruby/object:Gem::Version
|
62
66
|
version: '0'
|
63
67
|
requirements: []
|
64
|
-
rubygems_version: 3.3.
|
68
|
+
rubygems_version: 3.3.7
|
65
69
|
signing_key:
|
66
70
|
specification_version: 4
|
67
71
|
summary: Deep learning recommendation systems for Ruby
|