torchrec 0.0.1 → 0.0.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +3 -0
- data/lib/torchrec/modules/cross_net/cross_net.rb +36 -0
- data/lib/torchrec/modules/deepfm/deepfm.rb +23 -0
- data/lib/torchrec/modules/deepfm/factorization_machine.rb +27 -0
- data/lib/torchrec/sparse/jagged_tensor.rb +33 -0
- data/lib/torchrec/version.rb +1 -1
- data/lib/torchrec.rb +15 -9
- metadata +10 -6
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8fa2077cf1a744fe2d11d9fdb030e82a428542892670d6f2ef69b4801c3d0f43
|
4
|
+
data.tar.gz: 6c95aa0542c037f0cc466e2e82f3b366ee7a099da580deeb7718fc4889702b1f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 796466080c868a2aadd167c7bd2db68c150bb25b56ff7f56ce05b7d50077f24f1361e488edab331e3e468fec929e0a78ef49df1a14e0632baedd465deb266538
|
7
|
+
data.tar.gz: 38b05acd34abe5bf3364697690dd8b9dbcf16943ed54287b3ef6ed75256f8902898acff2dd9caf64d3e03d94cf3f8071efc25bdb362bf2eb12b90c68d8a6d310
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -35,6 +35,9 @@ TorchRec::Models::DLRM::DenseArch.new(in_features, layer_sizes, device: nil)
|
|
35
35
|
|
36
36
|
```ruby
|
37
37
|
TorchRec::Modules::Activation::SwishLayerNorm.new(input_dims, device: nil)
|
38
|
+
TorchRec::Modules::CrossNet::CrossNet.new(in_features, num_layers)
|
39
|
+
TorchRec::Modules::DeepFM::DeepFM.new(dense_module)
|
40
|
+
TorchRec::Modules::DeepFM::FactorizationMachine.new
|
38
41
|
TorchRec::Modules::MLP::MLP.new(in_size, layer_sizes, bias: true, activation: :relu, device: nil)
|
39
42
|
TorchRec::Modules::MLP::Perceptron.new(in_size, out_size, bias: true, activation: Torch.method(:relu), device: nil)
|
40
43
|
```
|
@@ -0,0 +1,36 @@
|
|
1
|
+
module TorchRec
  module Modules
    module CrossNet
      # Stack of `num_layers` feature-crossing layers.
      # Each layer l computes: x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l
      class CrossNet < Torch::NN::Module
        # @param in_features [Integer] width N of each input row
        # @param num_layers [Integer] number of cross layers to stack
        def initialize(in_features, num_layers)
          super()
          @num_layers = num_layers

          # One (N, N) weight matrix per layer, Xavier-normal initialized.
          kernel_params = @num_layers.times.map do
            weight = Torch.empty(in_features, in_features)
            Torch::NN::Parameter.new(Torch::NN::Init.xavier_normal!(weight))
          end
          @kernels = Torch::NN::ParameterList.new(kernel_params)

          # One (N, 1) bias column per layer, zero-initialized.
          bias_params = @num_layers.times.map do
            Torch::NN::Parameter.new(Torch::NN::Init.zeros!(Torch.empty(in_features, 1)))
          end
          @bias = Torch::NN::ParameterList.new(bias_params)
        end

        # @param input [Torch::Tensor] assumed (B, N)
        # @return [Torch::Tensor] (B, N) after all cross layers
        def forward(input)
          base = input.unsqueeze(2) # (B, N, 1)

          # Fold the layers over the running state, starting from x_0 itself.
          crossed = @num_layers.times.reduce(base) do |state, layer|
            projected = Torch.matmul(@kernels[layer], state) # (B, N, 1)
            base * (projected + @bias[layer]) + state # (B, N, 1)
          end

          Torch.squeeze(crossed, dim: 2)
        end
      end
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module TorchRec
  module Modules
    module DeepFM
      # Deep half of DeepFM: flattens every embedding tensor from dim 1 onward,
      # concatenates the results along dim 1, and runs the user-supplied dense
      # module on that single tensor.
      class DeepFM < Torch::NN::Module
        # @param dense_module [#call] module applied to the flattened embeddings
        def initialize(dense_module)
          super()
          @dense_module = dense_module
        end

        # @param embeddings [Array<Torch::Tensor>] tensors assumed to share batch dim 0
        # @return whatever `dense_module.call` returns
        def forward(embeddings)
          @dense_module.call(flatten_input(embeddings))
        end

        private

        # Reshapes each tensor to (B, *) and joins them column-wise.
        def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module TorchRec
  module Modules
    module DeepFM
      # Parameter-free factorization-machine cross term over the flattened,
      # concatenated embeddings.
      class FactorizationMachine < Torch::NN::Module
        def initialize
          super()
        end

        # @param embeddings [Array<Torch::Tensor>] tensors assumed to share batch dim 0
        # @return [Torch::Tensor] (B, 1) pairwise-interaction term
        def forward(embeddings)
          flat = flatten_input(embeddings)

          # (sum x)^2 - sum(x^2) equals twice the sum of all pairwise products,
          # so halving it yields the pairwise interaction.
          row_sum = Torch.sum(flat, dim: 1, keepdim: true)
          row_sum_sq = Torch.sum(flat * flat, dim: 1, keepdim: true)
          pairwise = row_sum * row_sum - row_sum_sq

          Torch.sum(pairwise, dim: 1, keepdim: true) * 0.5 # [B, 1]
        end

        private

        # Reshapes each tensor to (B, *) and joins them column-wise.
        def flatten_input(inputs)
          flattened = inputs.map { |tensor| tensor.flatten(1) }
          Torch.cat(flattened, dim: 1)
        end
      end
    end
  end
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module TorchRec
  module Sparse
    # Holds a `values` tensor together with optional `weights` and a
    # `lengths` and/or `offsets` tensor (presumably per-row lengths / row
    # start offsets into `values`, as in PyTorch TorchRec — this constructor
    # only validates and stores them).
    class JaggedTensor
      # dtypes accepted for lengths/offsets; only enforced on non-empty tensors.
      INTEGER_DTYPES = [:int64, :int32, :int16, :int8, :uint8].freeze

      # Expose the stored tensors; previously they were write-only.
      attr_reader :values, :weights, :lengths, :offsets

      # @param values [Object] the values tensor (stored as given)
      # @param weights [Object, nil] optional weights (stored as given)
      # @param lengths [#numel, #dtype, nil] integer tensor
      # @param offsets [#numel, #dtype, nil] integer tensor
      # @raise [ArgumentError] if neither lengths nor offsets is given, or if a
      #   non-empty lengths/offsets tensor has a non-integer dtype
      def initialize(values, weights: nil, lengths: nil, offsets: nil)
        assert_offsets_or_lengths_is_provided(offsets, lengths)
        assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") unless offsets.nil?
        assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") unless lengths.nil?

        @values = values
        @weights = weights
        @lengths = lengths
        @offsets = offsets
      end

      private

      # At least one of offsets/lengths is required to describe the layout.
      def assert_offsets_or_lengths_is_provided(offsets, lengths)
        return unless offsets.nil? && lengths.nil?

        raise ArgumentError, "Must provide lengths or offsets"
      end

      # Empty tensors are accepted regardless of dtype.
      def assert_tensor_has_no_elements_or_has_integers(tensor, tensor_name)
        return if tensor.numel == 0 || INTEGER_DTYPES.include?(tensor.dtype)

        raise ArgumentError, "#{tensor_name} must be of integer type, but got #{tensor.dtype}"
      end
    end
  end
end
|
data/lib/torchrec/version.rb
CHANGED
data/lib/torchrec.rb
CHANGED
@@ -1,19 +1,25 @@
|
|
1
1
|
# dependencies
|
2
|
-
require "torch"
|
2
|
+
require "torch-rb"
|
3
3
|
|
4
4
|
# models
|
5
|
-
|
6
|
-
|
7
|
-
|
5
|
+
require_relative "torchrec/models/deepfm/dense_arch"
|
6
|
+
require_relative "torchrec/models/deepfm/over_arch"
|
7
|
+
require_relative "torchrec/models/dlrm/dense_arch"
|
8
8
|
|
9
9
|
# modules
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
10
|
+
require_relative "torchrec/modules/activation/swish_layer_norm"
|
11
|
+
require_relative "torchrec/modules/cross_net/cross_net"
|
12
|
+
require_relative "torchrec/modules/deepfm/deepfm"
|
13
|
+
require_relative "torchrec/modules/deepfm/factorization_machine"
|
14
|
+
require_relative "torchrec/modules/mlp/mlp"
|
15
|
+
require_relative "torchrec/modules/mlp/perceptron"
|
16
|
+
require_relative "torchrec/modules/utils"
|
17
|
+
|
18
|
+
# sparse
|
19
|
+
require_relative "torchrec/sparse/jagged_tensor"
|
14
20
|
|
15
21
|
# other
|
16
|
-
|
22
|
+
require_relative "torchrec/version"
|
17
23
|
|
18
24
|
module TorchRec
|
19
25
|
class Error < StandardError; end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchrec
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-07-25 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: torch-rb
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: '0'
|
19
|
+
version: '0.10'
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: '0'
|
26
|
+
version: '0.10'
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|
@@ -38,9 +38,13 @@ files:
|
|
38
38
|
- lib/torchrec/models/deepfm/over_arch.rb
|
39
39
|
- lib/torchrec/models/dlrm/dense_arch.rb
|
40
40
|
- lib/torchrec/modules/activation/swish_layer_norm.rb
|
41
|
+
- lib/torchrec/modules/cross_net/cross_net.rb
|
42
|
+
- lib/torchrec/modules/deepfm/deepfm.rb
|
43
|
+
- lib/torchrec/modules/deepfm/factorization_machine.rb
|
41
44
|
- lib/torchrec/modules/mlp/mlp.rb
|
42
45
|
- lib/torchrec/modules/mlp/perceptron.rb
|
43
46
|
- lib/torchrec/modules/utils.rb
|
47
|
+
- lib/torchrec/sparse/jagged_tensor.rb
|
44
48
|
- lib/torchrec/version.rb
|
45
49
|
homepage: https://github.com/ankane/torchrec-ruby
|
46
50
|
licenses:
|
@@ -54,14 +58,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
54
58
|
requirements:
|
55
59
|
- - ">="
|
56
60
|
- !ruby/object:Gem::Version
|
57
|
-
version: '
|
61
|
+
version: '3'
|
58
62
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
59
63
|
requirements:
|
60
64
|
- - ">="
|
61
65
|
- !ruby/object:Gem::Version
|
62
66
|
version: '0'
|
63
67
|
requirements: []
|
64
|
-
rubygems_version: 3.
|
68
|
+
rubygems_version: 3.4.10
|
65
69
|
signing_key:
|
66
70
|
specification_version: 4
|
67
71
|
summary: Deep learning recommendation systems for Ruby
|