keras-rs-nightly 0.0.1.dev2025022303__py3-none-any.whl → 0.0.1.dev2025022503__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- keras_rs/src/layers/modeling/dot_interaction.py +66 -23
- keras_rs/src/utils/keras_utils.py +17 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025022303.dist-info → keras_rs_nightly-0.0.1.dev2025022503.dist-info}/METADATA +1 -1
- {keras_rs_nightly-0.0.1.dev2025022303.dist-info → keras_rs_nightly-0.0.1.dev2025022503.dist-info}/RECORD +7 -7
- {keras_rs_nightly-0.0.1.dev2025022303.dist-info → keras_rs_nightly-0.0.1.dev2025022503.dist-info}/WHEEL +0 -0
- {keras_rs_nightly-0.0.1.dev2025022303.dist-info → keras_rs_nightly-0.0.1.dev2025022503.dist-info}/top_level.txt +0 -0
|
@@ -5,6 +5,7 @@ from keras import ops
|
|
|
5
5
|
|
|
6
6
|
from keras_rs.src import types
|
|
7
7
|
from keras_rs.src.api_export import keras_rs_export
|
|
8
|
+
from keras_rs.src.utils.keras_utils import check_shapes_compatible
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
@keras_rs_export("keras_rs.layers.DotInteraction")
|
|
@@ -44,6 +45,44 @@ class DotInteraction(keras.layers.Layer):
|
|
|
44
45
|
self.self_interaction = self_interaction
|
|
45
46
|
self.skip_gather = skip_gather
|
|
46
47
|
|
|
48
|
+
def _generate_tril_mask(
    self, pairwise_interaction_matrix: types.Tensor
) -> types.Tensor:
    """Build a boolean lower-triangular mask matching the input's shape.

    Args:
        pairwise_interaction_matrix: Tensor whose shape the mask copies
            (via `ops.ones_like`).

    Returns:
        A boolean tensor produced by `ops.tril` over an all-`True` tensor.
        The main diagonal is kept when `self.self_interaction` is `True`
        (offset `k=0`) and excluded otherwise (offset `k=-1`).
    """
    # `ops.tril` offset: 0 keeps the main diagonal, -1 drops it.
    diagonal = 0 if self.self_interaction else -1

    # On the TensorFlow backend, `ops.tril` uses `tf.cond`, which requires
    # a tensor rather than a Python int, so wrap the offset.
    # TODO (abheesht): Remove typecast once fix is merged in core Keras.
    if keras.config.backend() == "tensorflow":
        diagonal = ops.array(diagonal)

    all_true = ops.ones_like(pairwise_interaction_matrix, dtype=bool)
    return ops.tril(all_true, k=diagonal)
|
|
69
|
+
|
|
70
|
+
def _get_lower_triangular_indices(self, num_features: int) -> list[int]:
|
|
71
|
+
"""Python function which generates indices to get the lower triangular
|
|
72
|
+
matrix as if it were flattened.
|
|
73
|
+
"""
|
|
74
|
+
flattened_indices = []
|
|
75
|
+
for i in range(num_features):
|
|
76
|
+
k = i
|
|
77
|
+
# if `self.self_interaction` is `True`, keep the main diagonal.
|
|
78
|
+
if self.self_interaction:
|
|
79
|
+
k += 1
|
|
80
|
+
for j in range(k):
|
|
81
|
+
flattened_index = i * num_features + j
|
|
82
|
+
flattened_indices.append(flattened_index)
|
|
83
|
+
|
|
84
|
+
return flattened_indices
|
|
85
|
+
|
|
47
86
|
def call(self, inputs: list[types.Tensor]) -> types.Tensor:
|
|
48
87
|
"""Forward pass of the dot interaction layer.
|
|
49
88
|
|
|
@@ -64,23 +103,25 @@ class DotInteraction(keras.layers.Layer):
|
|
|
64
103
|
# Check if all feature tensors have the same shape and are of rank 2.
|
|
65
104
|
shape = ops.shape(inputs[0])
|
|
66
105
|
for idx, tensor in enumerate(inputs):
|
|
67
|
-
|
|
106
|
+
other_shape = ops.shape(tensor)
|
|
107
|
+
|
|
108
|
+
if len(shape) != 2:
|
|
109
|
+
raise ValueError(
|
|
110
|
+
"All feature tensors inside `inputs` should have rank 2. "
|
|
111
|
+
f"Received rank {len(shape)} at index {idx}."
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
if not check_shapes_compatible(shape, other_shape):
|
|
68
115
|
raise ValueError(
|
|
69
116
|
"All feature tensors in `inputs` should have the same "
|
|
70
117
|
f"shape. Found at least one conflict: shape = {shape} at "
|
|
71
|
-
f"index 0 and shape = {ops.shape(tensor)} at index {idx}"
|
|
118
|
+
f"index 0 and shape = {ops.shape(tensor)} at index {idx}."
|
|
72
119
|
)
|
|
73
120
|
|
|
74
|
-
if len(shape) != 2:
|
|
75
|
-
raise ValueError(
|
|
76
|
-
"All feature tensors inside `inputs` should have rank 2. "
|
|
77
|
-
f"Received rank {len(shape)}."
|
|
78
|
-
)
|
|
79
|
-
|
|
80
121
|
# `(batch_size, num_features, feature_dim)`
|
|
81
122
|
features = ops.stack(inputs, axis=1)
|
|
82
123
|
|
|
83
|
-
batch_size,
|
|
124
|
+
batch_size, num_features, _ = ops.shape(features)
|
|
84
125
|
|
|
85
126
|
# Compute the dot product to get feature interactions. The shape here is
|
|
86
127
|
# `(batch_size, num_features, num_features)`.
|
|
@@ -88,28 +129,30 @@ class DotInteraction(keras.layers.Layer):
|
|
|
88
129
|
features, ops.transpose(features, axes=(0, 2, 1))
|
|
89
130
|
)
|
|
90
131
|
|
|
91
|
-
# if `self.self_interaction` is `True`, keep the main diagonal.
|
|
92
|
-
k = -1
|
|
93
|
-
if self.self_interaction:
|
|
94
|
-
k = 0
|
|
95
|
-
|
|
96
|
-
tril_mask = ops.tril(
|
|
97
|
-
ops.ones_like(pairwise_interaction_matrix, dtype=bool),
|
|
98
|
-
k=k,
|
|
99
|
-
)
|
|
100
|
-
|
|
101
132
|
# Set the upper triangle entries to 0, if `self.skip_gather` is True.
|
|
102
133
|
# Else, "pick" only the lower triangle entries.
|
|
103
134
|
if self.skip_gather:
|
|
135
|
+
tril_mask = self._generate_tril_mask(pairwise_interaction_matrix)
|
|
136
|
+
|
|
104
137
|
activations = ops.multiply(
|
|
105
138
|
pairwise_interaction_matrix,
|
|
106
139
|
ops.cast(tril_mask, dtype=pairwise_interaction_matrix.dtype),
|
|
107
140
|
)
|
|
141
|
+
# Rank-2 tensor.
|
|
142
|
+
activations = ops.reshape(
|
|
143
|
+
activations, (batch_size, num_features * num_features)
|
|
144
|
+
)
|
|
108
145
|
else:
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
146
|
+
flattened_indices = self._get_lower_triangular_indices(num_features)
|
|
147
|
+
pairwise_interaction_matrix_flattened = ops.reshape(
|
|
148
|
+
pairwise_interaction_matrix,
|
|
149
|
+
(batch_size, num_features * num_features),
|
|
150
|
+
)
|
|
151
|
+
activations = ops.take(
|
|
152
|
+
pairwise_interaction_matrix_flattened,
|
|
153
|
+
flattened_indices,
|
|
154
|
+
axis=-1,
|
|
155
|
+
)
|
|
113
156
|
|
|
114
157
|
return activations
|
|
115
158
|
|
|
@@ -2,6 +2,8 @@ from typing import Union
|
|
|
2
2
|
|
|
3
3
|
import keras
|
|
4
4
|
|
|
5
|
+
from keras_rs.src import types
|
|
6
|
+
|
|
5
7
|
|
|
6
8
|
def clone_initializer(
|
|
7
9
|
initializer: Union[str, keras.initializers.Initializer],
|
|
@@ -25,3 +27,18 @@ def clone_initializer(
|
|
|
25
27
|
return initializer_class.from_config(config)
|
|
26
28
|
# If we get a string or dict, just return as we cannot and should not clone.
|
|
27
29
|
return initializer
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def check_shapes_compatible(
    shape1: types.TensorShape, shape2: types.TensorShape
) -> bool:
    """Return whether two shapes are compatible.

    Two shapes are compatible when they have the same rank and every pair of
    dimensions that are both integers is equal. A dimension that is not an
    integer (e.g. `None`) is treated as a wildcard and matches anything.

    Args:
        shape1: First shape, as a sized sequence of dimensions.
        shape2: Second shape, as a sized sequence of dimensions.

    Returns:
        `True` if the shapes are compatible, `False` otherwise.
    """
    import numbers

    # Check rank first.
    if len(shape1) != len(shape2):
        return False

    # `numbers.Integral` (rather than `int`) also recognises NumPy integer
    # dimensions; with a plain `int` check they were silently treated as
    # wildcards, so real mismatches went undetected.
    return all(
        d1 == d2
        for d1, d2 in zip(shape1, shape2)
        if isinstance(d1, numbers.Integral) and isinstance(d2, numbers.Integral)
    )
|
keras_rs/src/version.py
CHANGED
|
@@ -4,16 +4,16 @@ keras_rs/api/layers/__init__.py,sha256=moxzBqnCCQsZiIgEUXn9mdtSALFTU5e0c4Ct8c0bW
|
|
|
4
4
|
keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
5
|
keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
|
|
6
6
|
keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
|
|
7
|
-
keras_rs/src/version.py,sha256=
|
|
7
|
+
keras_rs/src/version.py,sha256=MDAVWC9bHK8-Jwx-8Ya6HC2AHweZmT9LsoohVj41IQM,222
|
|
8
8
|
keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
9
|
keras_rs/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
-
keras_rs/src/layers/modeling/dot_interaction.py,sha256=
|
|
10
|
+
keras_rs/src/layers/modeling/dot_interaction.py,sha256=Ove9USc5wFFICHfgkTfJuZjSWt7G53K3eE0HtwAO-Dw,6850
|
|
11
11
|
keras_rs/src/layers/modeling/feature_cross.py,sha256=nd2SBIe1gXC-vL_BeF3vrq8XqTAb6iLxhQGmtyUl3EM,8105
|
|
12
12
|
keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=mohILOt6PC6jHBztaowDbj3QBnSetuvkq55FmE39PlY,7321
|
|
14
14
|
keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
|
-
keras_rs/src/utils/keras_utils.py,sha256=
|
|
16
|
-
keras_rs_nightly-0.0.1.
|
|
17
|
-
keras_rs_nightly-0.0.1.
|
|
18
|
-
keras_rs_nightly-0.0.1.
|
|
19
|
-
keras_rs_nightly-0.0.1.
|
|
15
|
+
keras_rs/src/utils/keras_utils.py,sha256=IjWSRieBkv7UX12qgUoI1tcOeISstCLRSTqSHpT06yE,1275
|
|
16
|
+
keras_rs_nightly-0.0.1.dev2025022503.dist-info/METADATA,sha256=1XeOVbD3f-SNXAqgMsOBkZRJb_2uXbeOZY2dC2vg0ew,3522
|
|
17
|
+
keras_rs_nightly-0.0.1.dev2025022503.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
18
|
+
keras_rs_nightly-0.0.1.dev2025022503.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
|
|
19
|
+
keras_rs_nightly-0.0.1.dev2025022503.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|