torchrec 0.0.1 → 0.0.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -0
- data/lib/torchrec/modules/cross_net/cross_net.rb +36 -0
- data/lib/torchrec/modules/deepfm/deepfm.rb +23 -0
- data/lib/torchrec/modules/deepfm/factorization_machine.rb +27 -0
- data/lib/torchrec/sparse/jagged_tensor.rb +33 -0
- data/lib/torchrec/version.rb +1 -1
- data/lib/torchrec.rb +6 -0
- metadata +9 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 573d37473e5a3e8e24b60bfe886ef1d14be87833ecc7c207a349c236480f1d6e
|
4
|
+
data.tar.gz: 9fcc58e9201bb545582feaa50a56727dd074688b4c830e495abd452aa31ad572
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: bb80aa64a91fb1af2dad904fc4f39b366f79765ba541334c4bf4109011249b9a6075a6d050b1aa644f42236cf037110181ed09cb74e71bd47882c10402e15101
|
7
|
+
data.tar.gz: 7393ad51f8882915e433f3aa8314f3d743a70b192b174854f161e2479c811ad7bb57f7f1953c0e817211056c201de5cf88d4d15e78c361e52768598155e9bad9
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -35,6 +35,9 @@ TorchRec::Models::DLRM::DenseArch.new(in_features, layer_sizes, device: nil)
|
|
35
35
|
|
36
36
|
```ruby
|
37
37
|
TorchRec::Modules::Activation::SwishLayerNorm.new(input_dims, device: nil)
|
38
|
+
TorchRec::Modules::CrossNet::CrossNet.new(in_features, num_layers)
|
39
|
+
TorchRec::Modules::DeepFM::DeepFM.new(dense_module)
|
40
|
+
TorchRec::Modules::DeepFM::FactorizationMachine.new
|
38
41
|
TorchRec::Modules::MLP::MLP.new(in_size, layer_sizes, bias: true, activation: :relu, device: nil)
|
39
42
|
TorchRec::Modules::MLP::Perceptron.new(in_size, out_size, bias: true, activation: Torch.method(:relu), device: nil)
|
40
43
|
```
|
@@ -0,0 +1,36 @@
|
|
1
|
+
module TorchRec
  module Modules
    module CrossNet
      # Stack of `num_layers` cross layers. With x_0 the input unsqueezed to a
      # column, each layer l updates the running state as
      #
      #   x_{l+1} = x_0 * (W_l x_l + b_l) + x_l
      #
      # where W_l is an (in_features, in_features) kernel and b_l an
      # (in_features, 1) bias.
      class CrossNet < Torch::NN::Module
        # @param in_features [Integer] size of the feature dimension of the input
        # @param num_layers [Integer] number of cross layers to stack
        def initialize(in_features, num_layers)
          super()
          @num_layers = num_layers
          # One Xavier-normal initialized square kernel per layer.
          @kernels = Torch::NN::ParameterList.new(
            Array.new(@num_layers) do
              Torch::NN::Parameter.new(
                Torch::NN::Init.xavier_normal!(Torch.empty(in_features, in_features))
              )
            end
          )
          # One zero-initialized column bias per layer.
          @bias = Torch::NN::ParameterList.new(
            Array.new(@num_layers) do
              Torch::NN::Parameter.new(Torch::NN::Init.zeros!(Torch.empty(in_features, 1)))
            end
          )
        end

        # @param input [Torch::Tensor] assumed (B, in_features) — it is
        #   unsqueezed to (B, in_features, 1) so the kernels can be matmul'ed
        #   against it; TODO confirm callers never pass higher-rank input
        # @return [Torch::Tensor] final cross-layer state with the trailing
        #   singleton dimension (dim 2) squeezed away
        def forward(input)
          x_0 = input.unsqueeze(2) # (B, N, 1)
          x_l = x_0

          @num_layers.times do |layer|
            xl_w = Torch.matmul(@kernels[layer], x_l) # (B, N, 1)
            # Residual connection: the cross term is added back onto x_l.
            x_l = x_0 * (xl_w + @bias[layer]) + x_l # (B, N, 1)
          end

          Torch.squeeze(x_l, dim: 2)
        end
      end
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module TorchRec
  module Modules
    module DeepFM
      # Flattens a list of embedding tensors, concatenates them into a single
      # tensor, and hands the result to a caller-supplied dense module.
      class DeepFM < Torch::NN::Module
        # @param dense_module [#call] module applied to the concatenated input
        def initialize(dense_module)
          super()
          @dense_module = dense_module
        end

        # @param embeddings [Enumerable<Torch::Tensor>] tensors that must agree
        #   on dim 0 (required by the concatenation along dim 1)
        # @return whatever `dense_module.call` returns for the flattened input
        def forward(embeddings)
          @dense_module.call(flatten_input(embeddings))
        end

        private

        # Flattens every tensor from dim 1 onward and joins them along dim 1.
        def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module TorchRec
  module Modules
    module DeepFM
      # Parameter-free pairwise-interaction term computed over a list of
      # embedding tensors.
      class FactorizationMachine < Torch::NN::Module
        def initialize
          super()
        end

        # Flattens and concatenates the embeddings along dim 1, then returns
        # 0.5 * ((sum of inputs)^2 - sum of squared inputs), with both sums
        # taken along dim 1.
        #
        # @param embeddings [Enumerable<Torch::Tensor>] tensors that must agree
        #   on dim 0 (required by the concatenation along dim 1)
        # @return [Torch::Tensor] shape [B, 1]
        def forward(embeddings)
          flat = flatten_input(embeddings)
          summed = Torch.sum(flat, dim: 1, keepdim: true)
          squares_summed = Torch.sum(flat * flat, dim: 1, keepdim: true)
          interaction = (summed * summed) - squares_summed
          # `interaction` already has size 1 along dim 1 (keepdim above), so this
          # reduction only keeps the [B, 1] shape before halving.
          Torch.sum(interaction, dim: 1, keepdim: true) * 0.5 # [B, 1]
        end

        private

        # Flattens every tensor from dim 1 onward and joins them along dim 1.
        def flatten_input(inputs)
          pieces = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(pieces, dim: 1)
        end
      end
    end
  end
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module TorchRec
  module Sparse
    # Holds a `values` tensor together with optional `weights` and at least one
    # of `lengths` / `offsets`. Presumably lengths/offsets describe how values
    # split into variable-length segments; that relationship is not validated
    # here — only presence and dtype are checked.
    class JaggedTensor
      # dtypes (as returned by Tensor#dtype) accepted for lengths/offsets.
      INTEGER_DTYPES = %i[int64 int32 int16 int8 uint8].freeze
      private_constant :INTEGER_DTYPES

      # @param values [Object] stored as given (expected to be a Torch::Tensor)
      # @param weights [Object, nil] stored as given, not validated
      # @param lengths [Torch::Tensor, nil]
      # @param offsets [Torch::Tensor, nil]
      # @raise [ArgumentError] if neither lengths nor offsets is given, or if a
      #   given non-empty lengths/offsets tensor has a non-integer dtype
      def initialize(values, weights: nil, lengths: nil, offsets: nil)
        @values = values
        @weights = weights
        assert_offsets_or_lengths_is_provided(offsets, lengths)
        assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") unless offsets.nil?
        assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") unless lengths.nil?
        @lengths = lengths
        @offsets = offsets
      end

      private

      # @raise [ArgumentError] when both offsets and lengths are nil
      def assert_offsets_or_lengths_is_provided(offsets, lengths)
        return unless offsets.nil? && lengths.nil?

        raise ArgumentError, "Must provide lengths or offsets"
      end

      # Empty tensors are accepted regardless of dtype; non-empty ones must use
      # one of INTEGER_DTYPES.
      # @raise [ArgumentError] naming `tensor_name` and the offending dtype
      def assert_tensor_has_no_elements_or_has_integers(tensor, tensor_name)
        return if tensor.numel.zero? || INTEGER_DTYPES.include?(tensor.dtype)

        raise ArgumentError, "#{tensor_name} must be of integer type, but got #{tensor.dtype}"
      end
    end
  end
end
|
data/lib/torchrec/version.rb
CHANGED
data/lib/torchrec.rb
CHANGED
@@ -8,10 +8,16 @@ require "torchrec/models/dlrm/dense_arch"
|
|
8
8
|
|
9
9
|
# modules
|
10
10
|
require "torchrec/modules/activation/swish_layer_norm"
|
11
|
+
require "torchrec/modules/cross_net/cross_net"
|
12
|
+
require "torchrec/modules/deepfm/deepfm"
|
13
|
+
require "torchrec/modules/deepfm/factorization_machine"
|
11
14
|
require "torchrec/modules/mlp/mlp"
|
12
15
|
require "torchrec/modules/mlp/perceptron"
|
13
16
|
require "torchrec/modules/utils"
|
14
17
|
|
18
|
+
# sparse
|
19
|
+
require "torchrec/sparse/jagged_tensor"
|
20
|
+
|
15
21
|
# other
|
16
22
|
require "torchrec/version"
|
17
23
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchrec
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-03-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: torch-rb
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version:
|
19
|
+
version: 0.9.2
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version:
|
26
|
+
version: 0.9.2
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|
@@ -38,9 +38,13 @@ files:
|
|
38
38
|
- lib/torchrec/models/deepfm/over_arch.rb
|
39
39
|
- lib/torchrec/models/dlrm/dense_arch.rb
|
40
40
|
- lib/torchrec/modules/activation/swish_layer_norm.rb
|
41
|
+
- lib/torchrec/modules/cross_net/cross_net.rb
|
42
|
+
- lib/torchrec/modules/deepfm/deepfm.rb
|
43
|
+
- lib/torchrec/modules/deepfm/factorization_machine.rb
|
41
44
|
- lib/torchrec/modules/mlp/mlp.rb
|
42
45
|
- lib/torchrec/modules/mlp/perceptron.rb
|
43
46
|
- lib/torchrec/modules/utils.rb
|
47
|
+
- lib/torchrec/sparse/jagged_tensor.rb
|
44
48
|
- lib/torchrec/version.rb
|
45
49
|
homepage: https://github.com/ankane/torchrec-ruby
|
46
50
|
licenses:
|
@@ -61,7 +65,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
61
65
|
- !ruby/object:Gem::Version
|
62
66
|
version: '0'
|
63
67
|
requirements: []
|
64
|
-
rubygems_version: 3.3.
|
68
|
+
rubygems_version: 3.3.7
|
65
69
|
signing_key:
|
66
70
|
specification_version: 4
|
67
71
|
summary: Deep learning recommendation systems for Ruby
|