keras-rs-nightly 0.0.1.dev2025022303__py3-none-any.whl → 0.0.1.dev2025022503__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic; see the package's registry page for more details.

@@ -5,6 +5,7 @@ from keras import ops
5
5
 
6
6
  from keras_rs.src import types
7
7
  from keras_rs.src.api_export import keras_rs_export
8
+ from keras_rs.src.utils.keras_utils import check_shapes_compatible
8
9
 
9
10
 
10
11
  @keras_rs_export("keras_rs.layers.DotInteraction")
@@ -44,6 +45,44 @@ class DotInteraction(keras.layers.Layer):
44
45
  self.self_interaction = self_interaction
45
46
  self.skip_gather = skip_gather
46
47
 
48
+ def _generate_tril_mask(
49
+ self, pairwise_interaction_matrix: types.Tensor
50
+ ) -> types.Tensor:
51
+ """Generates lower triangular mask."""
52
+
53
+ # If `self.self_interaction` is `True`, keep the main diagonal.
54
+ k = -1
55
+ if self.self_interaction:
56
+ k = 0
57
+
58
+ # Typecast k from Python int to tensor, because `ops.tril` uses
59
+ # `tf.cond` (which requires tensors).
60
+ # TODO (abheesht): Remove typecast once fix is merged in core Keras.
61
+ if keras.config.backend() == "tensorflow":
62
+ k = ops.array(k)
63
+ tril_mask = ops.tril(
64
+ ops.ones_like(pairwise_interaction_matrix, dtype=bool),
65
+ k=k,
66
+ )
67
+
68
+ return tril_mask
69
+
70
+ def _get_lower_triangular_indices(self, num_features: int) -> list[int]:
71
+ """Python function which generates indices to get the lower triangular
72
+ matrix as if it were flattened.
73
+ """
74
+ flattened_indices = []
75
+ for i in range(num_features):
76
+ k = i
77
+ # if `self.self_interaction` is `True`, keep the main diagonal.
78
+ if self.self_interaction:
79
+ k += 1
80
+ for j in range(k):
81
+ flattened_index = i * num_features + j
82
+ flattened_indices.append(flattened_index)
83
+
84
+ return flattened_indices
85
+
47
86
  def call(self, inputs: list[types.Tensor]) -> types.Tensor:
48
87
  """Forward pass of the dot interaction layer.
49
88
 
@@ -64,23 +103,25 @@ class DotInteraction(keras.layers.Layer):
64
103
  # Check if all feature tensors have the same shape and are of rank 2.
65
104
  shape = ops.shape(inputs[0])
66
105
  for idx, tensor in enumerate(inputs):
67
- if ops.shape(tensor) != shape:
106
+ other_shape = ops.shape(tensor)
107
+
108
+ if len(shape) != 2:
109
+ raise ValueError(
110
+ "All feature tensors inside `inputs` should have rank 2. "
111
+ f"Received rank {len(shape)} at index {idx}."
112
+ )
113
+
114
+ if not check_shapes_compatible(shape, other_shape):
68
115
  raise ValueError(
69
116
  "All feature tensors in `inputs` should have the same "
70
117
  f"shape. Found at least one conflict: shape = {shape} at "
71
- f"index 0 and shape = {ops.shape(tensor)} at index {idx}"
118
+ f"index 0 and shape = {ops.shape(tensor)} at index {idx}."
72
119
  )
73
120
 
74
- if len(shape) != 2:
75
- raise ValueError(
76
- "All feature tensors inside `inputs` should have rank 2. "
77
- f"Received rank {len(shape)}."
78
- )
79
-
80
121
  # `(batch_size, num_features, feature_dim)`
81
122
  features = ops.stack(inputs, axis=1)
82
123
 
83
- batch_size, _, _ = ops.shape(features)
124
+ batch_size, num_features, _ = ops.shape(features)
84
125
 
85
126
  # Compute the dot product to get feature interactions. The shape here is
86
127
  # `(batch_size, num_features, num_features)`.
@@ -88,28 +129,30 @@ class DotInteraction(keras.layers.Layer):
88
129
  features, ops.transpose(features, axes=(0, 2, 1))
89
130
  )
90
131
 
91
- # if `self.self_interaction` is `True`, keep the main diagonal.
92
- k = -1
93
- if self.self_interaction:
94
- k = 0
95
-
96
- tril_mask = ops.tril(
97
- ops.ones_like(pairwise_interaction_matrix, dtype=bool),
98
- k=k,
99
- )
100
-
101
132
  # Set the upper triangle entries to 0, if `self.skip_gather` is True.
102
133
  # Else, "pick" only the lower triangle entries.
103
134
  if self.skip_gather:
135
+ tril_mask = self._generate_tril_mask(pairwise_interaction_matrix)
136
+
104
137
  activations = ops.multiply(
105
138
  pairwise_interaction_matrix,
106
139
  ops.cast(tril_mask, dtype=pairwise_interaction_matrix.dtype),
107
140
  )
141
+ # Rank-2 tensor.
142
+ activations = ops.reshape(
143
+ activations, (batch_size, num_features * num_features)
144
+ )
108
145
  else:
109
- activations = pairwise_interaction_matrix[tril_mask]
110
-
111
- # Rank-2 tensor.
112
- activations = ops.reshape(activations, (batch_size, -1))
146
+ flattened_indices = self._get_lower_triangular_indices(num_features)
147
+ pairwise_interaction_matrix_flattened = ops.reshape(
148
+ pairwise_interaction_matrix,
149
+ (batch_size, num_features * num_features),
150
+ )
151
+ activations = ops.take(
152
+ pairwise_interaction_matrix_flattened,
153
+ flattened_indices,
154
+ axis=-1,
155
+ )
113
156
 
114
157
  return activations
115
158
 
@@ -2,6 +2,8 @@ from typing import Union
2
2
 
3
3
  import keras
4
4
 
5
+ from keras_rs.src import types
6
+
5
7
 
6
8
  def clone_initializer(
7
9
  initializer: Union[str, keras.initializers.Initializer],
@@ -25,3 +27,18 @@ def clone_initializer(
25
27
  return initializer_class.from_config(config)
26
28
  # If we get a string or dict, just return as we cannot and should not clone.
27
29
  return initializer
30
+
31
+
32
+ def check_shapes_compatible(
33
+ shape1: types.TensorShape, shape2: types.TensorShape
34
+ ) -> bool:
35
+ # Check rank first.
36
+ if len(shape1) != len(shape2):
37
+ return False
38
+
39
+ for d1, d2 in zip(shape1, shape2):
40
+ if isinstance(d1, int) and isinstance(d2, int):
41
+ if d1 != d2:
42
+ return False
43
+
44
+ return True
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.0.1.dev2025022303"
4
+ __version__ = "0.0.1.dev2025022503"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025022303
3
+ Version: 0.0.1.dev2025022503
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras RS team <keras-rs@google.com>
6
6
  License: Apache License 2.0
@@ -4,16 +4,16 @@ keras_rs/api/layers/__init__.py,sha256=moxzBqnCCQsZiIgEUXn9mdtSALFTU5e0c4Ct8c0bW
4
4
  keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
6
6
  keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
7
- keras_rs/src/version.py,sha256=cqw3Oj-fSCIXtwptq6VfGiNxxoW5tmFVqK53TwkAWjM,222
7
+ keras_rs/src/version.py,sha256=MDAVWC9bHK8-Jwx-8Ya6HC2AHweZmT9LsoohVj41IQM,222
8
8
  keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  keras_rs/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- keras_rs/src/layers/modeling/dot_interaction.py,sha256=kzxs9IUI-llDfgqy-JUr-zBez9pA7A3zIG-O0lGxwWY,5097
10
+ keras_rs/src/layers/modeling/dot_interaction.py,sha256=Ove9USc5wFFICHfgkTfJuZjSWt7G53K3eE0HtwAO-Dw,6850
11
11
  keras_rs/src/layers/modeling/feature_cross.py,sha256=nd2SBIe1gXC-vL_BeF3vrq8XqTAb6iLxhQGmtyUl3EM,8105
12
12
  keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=mohILOt6PC6jHBztaowDbj3QBnSetuvkq55FmE39PlY,7321
14
14
  keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- keras_rs/src/utils/keras_utils.py,sha256=TUEpKHWDDBmzXJ0g0cTWYwQNjsHkPQ5HMvzNshLDOSU,897
16
- keras_rs_nightly-0.0.1.dev2025022303.dist-info/METADATA,sha256=cd-0g1CXChmPiUzV3iKlIXZ99RKxRJvZJAEc2qYHfvs,3522
17
- keras_rs_nightly-0.0.1.dev2025022303.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
18
- keras_rs_nightly-0.0.1.dev2025022303.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
19
- keras_rs_nightly-0.0.1.dev2025022303.dist-info/RECORD,,
15
+ keras_rs/src/utils/keras_utils.py,sha256=IjWSRieBkv7UX12qgUoI1tcOeISstCLRSTqSHpT06yE,1275
16
+ keras_rs_nightly-0.0.1.dev2025022503.dist-info/METADATA,sha256=1XeOVbD3f-SNXAqgMsOBkZRJb_2uXbeOZY2dC2vg0ew,3522
17
+ keras_rs_nightly-0.0.1.dev2025022503.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
18
+ keras_rs_nightly-0.0.1.dev2025022503.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
19
+ keras_rs_nightly-0.0.1.dev2025022503.dist-info/RECORD,,