keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,257 @@
1
+ from typing import Callable
2
+
3
+ import keras
4
+ from keras import ops
5
+
6
+ from keras_rs.src import types
7
+
8
+
9
def get_shuffled_indices(
    shape: types.Shape,
    mask: types.Tensor | None = None,
    shuffle_ties: bool = True,
    seed: int | keras.random.SeedGenerator | None = None,
) -> types.Tensor:
    """Returns a permutation of the last axis with invalid entries moved last.

    One float32 sort key is drawn per element and the keys are argsorted along
    the last axis. Random keys give a random permutation. Constant keys give
    the identity order `[0, 1, 2, ...]`, provided `ops.argsort` keeps equal keys
    in their original order on the active backend.

    Args:
        shape: tuple. Shape of the tensor the permutation is generated for.
        mask: Optional boolean tensor of shape `shape`. Positions where `mask`
            is `False` are placed after every position where it is `True`.
            Defaults to `None`, meaning no masking.
        shuffle_ties: bool. If `True`, the keys are uniform random values.
            Otherwise all keys are zero, which preserves the original order.
        seed: Optional int or `keras.random.SeedGenerator` passed to
            `keras.random.uniform` when `shuffle_ties` is `True`. Defaults to
            `None`.

    Returns:
        An integer tensor of shape `shape` holding the permuted indices.
    """
    sort_keys = (
        keras.random.uniform(shape, seed=seed, dtype="float32")
        if shuffle_ties
        else ops.zeros(shape, dtype="float32")
    )

    if mask is not None:
        # Every key lies in [0, 1), so shifting the keys of invalid positions
        # by 1 sorts them strictly after all valid positions.
        pushed_back = ops.add(sort_keys, ops.cast(1, dtype="float32"))
        sort_keys = ops.where(mask, sort_keys, pushed_back)

    return ops.argsort(sort_keys)
52
+
53
+
54
def sort_by_scores(
    tensors_to_sort: list[types.Tensor],
    scores: types.Tensor,
    mask: types.Tensor | None = None,
    k: int | None = None,
    shuffle_ties: bool = True,
    seed: int | keras.random.SeedGenerator | None = None,
) -> list[types.Tensor]:
    """
    Utility function for sorting tensors by scores.

    Each tensor in `tensors_to_sort` is reordered by `scores` in descending
    order and truncated to the top `k` positions.

    Args:
        tensors_to_sort: list of tensors. All tensors are of shape
            `(batch_size, list_size)`. These tensors are sorted based on
            `scores`.
        scores: tensor. Of shape `(batch_size, list_size)`. The scores to sort
            by.
        mask: optional boolean tensor of shape `(batch_size, list_size)`.
            Entries where `mask` is `False` are moved behind all valid entries
            before the top-k selection, so ties in `scores` are resolved in
            favour of valid entries. Their scores are NOT modified here. If
            masked entries must never outrank valid ones, callers have to
            lower their scores beforehand.
        k: int. The number of top-ranked items to consider (the 'k' in 'top-k').
            If `None`, `list_size` is used. Larger values are clipped to
            `list_size`.
        shuffle_ties: bool. Whether to break ties between equal scores
            randomly. If `False`, ties are broken by position: valid entries
            first when `mask` is given, then lower index first.
        seed: int or `keras.random.SeedGenerator`. Seed for shuffling.

    Returns:
        List of tensors of shape `(batch_size, k)`. Each is the corresponding
        tensor of `tensors_to_sort`, reordered by descending `scores`.
    """
    max_possible_k = ops.shape(scores)[1]
    if k is None:
        k = max_possible_k
    elif isinstance(max_possible_k, int):
        k = min(k, max_possible_k)
    else:
        # Dynamic list size (symbolic shape): clip with a tensor op.
        k = ops.minimum(k, max_possible_k)

    # Reorder each list so that "lower position wins ties" gives the desired
    # tie-breaking: a random order when `shuffle_ties`, and masked entries
    # after all valid ones (they are pushed to the END of the list).
    # `shuffle_ties` is forwarded so that passing a mask does not silently
    # re-enable random tie-breaking.
    # NOTE(review): with `shuffle_ties=False`, keeping the original order among
    # valid entries relies on `ops.argsort` being stable on the backend.
    shuffled_indices = None
    if shuffle_ties or mask is not None:
        shuffled_indices = get_shuffled_indices(
            ops.shape(scores),
            mask=mask,
            shuffle_ties=shuffle_ties,
            seed=seed,
        )
        scores = ops.take_along_axis(scores, shuffled_indices, axis=1)

    indices = _top_k_indices_lower_index_first(scores, k)

    # Map positions in the reordered list back to original positions.
    if shuffled_indices is not None:
        indices = ops.take_along_axis(shuffled_indices, indices, axis=1)

    return [
        ops.take_along_axis(tensor_to_sort, indices, axis=1)
        for tensor_to_sort in tensors_to_sort
    ]


def _top_k_indices_lower_index_first(
    scores: types.Tensor, k: int | types.Tensor
) -> types.Tensor:
    """Returns indices of the `k` largest `scores` along axis 1.

    Indices are in descending score order. Among equal scores, the lower index
    comes first.
    """
    if keras.backend.backend() == "torch":
        # `torch.topk` does not guarantee any order among equal values
        # (https://github.com/pytorch/pytorch/issues/27542,
        # https://github.com/pytorch/pytorch/issues/88227), unlike the JAX and
        # TF `top_k`. A stable descending sort gives that guarantee without
        # perturbing the scores. A small additive offset would reorder scores
        # that are merely close, and would vanish for large float32 scores.
        import torch

        sorted_indices = torch.sort(
            ops.convert_to_tensor(scores),
            dim=1,
            descending=True,
            stable=True,
        ).indices
        return sorted_indices[:, :k]

    _, indices = ops.top_k(scores, k=k, sorted=True)
    return indices
130
+
131
+
132
def get_list_weights(
    weights: types.Tensor, relevance: types.Tensor
) -> types.Tensor:
    """Derives one weight per list from per-item sample weights.

    A list whose weights sum to more than 0 and whose relevance sums to more
    than 0 gets the relevance-weighted mean of its item weights:

    ```
    per_list_weights = sum(weights * relevance) / sum(relevance).
    ```

    A list whose weights sum to more than 0 but whose relevance does not gets
    a default weight. The default is the sum of `per_list_weights` over all
    lists divided by the number of lists with both positive total weight and
    positive total relevance:

    ```
    sum(per_list_weights) / num(sum(relevance) > 0 AND sum(weights) > 0)
    ```

    If no list has both, the default weight is 1.0. A list whose weights do
    not sum to more than 0 gets weight 0.

    Two consequences:
    - With uniform weights of 1.0, every list gets weight 1.0, including lists
      without relevant items.
    - If every list has positive total weight and positive total relevance,
      the default weight is never used.

    Args:
        weights: tensor. Weights tensor of shape `(batch_size, list_size)`.
        relevance: tensor. The relevance `Tensor` of shape
            `(batch_size, list_size)`.

    Returns:
        A tensor of shape `(batch_size, 1)`, containing the per-list weights.
    """
    # Per-list totals, kept as `(batch_size, 1)` columns.
    weight_totals = ops.sum(weights, axis=1, keepdims=True)
    has_weight = ops.greater(weight_totals, 0.0)
    relevance_totals = ops.sum(relevance, axis=1, keepdims=True)
    has_relevance = ops.greater(relevance_totals, 0.0)

    # 1.0 for lists that take part in the default-weight average, else 0.0.
    qualifies = ops.cast(
        ops.logical_and(has_weight, has_relevance),
        dtype=weights.dtype,
    )
    num_qualifying = ops.sum(qualifies, axis=0, keepdims=True)

    # Relevance-weighted mean of item weights. It is 0 where the list has no
    # relevance at all, because divide_no_nan returns 0 for 0 / 0.
    weighted_relevance = ops.sum(
        ops.multiply(weights, relevance), axis=1, keepdims=True
    )
    list_weights = ops.divide_no_nan(weighted_relevance, relevance_totals)

    total_list_weight = ops.sum(list_weights, axis=0, keepdims=True)
    fallback_weight = ops.where(
        ops.greater(num_qualifying, 0.0),
        ops.divide(total_list_weight, num_qualifying),
        ops.cast(1, dtype=total_list_weight.dtype),
    )

    # Lists with positive weight use their own mean, or the fallback when they
    # have no relevance. All other lists are zeroed out.
    weight_if_positive = ops.where(has_relevance, list_weights, fallback_weight)
    return ops.where(
        has_weight,
        weight_if_positive,
        ops.cast(0, dtype=list_weights.dtype),
    )
225
+
226
+
227
@keras.saving.register_keras_serializable()  # type: ignore[untyped-decorator]
def default_gain_fn(label: types.Tensor) -> types.Tensor:
    """Exponential gain `2**label - 1`; a label of 0 contributes no gain."""
    exponentiated = ops.power(2.0, label)
    return ops.subtract(exponentiated, 1.0)
230
+
231
+
232
@keras.saving.register_keras_serializable()  # type: ignore[untyped-decorator]
def default_rank_discount_fn(rank: types.Tensor) -> types.Tensor:
    """Logarithmic position discount `1 / log2(1 + rank)`.

    A rank of 1 yields a discount of 1. Larger ranks are discounted more.
    """
    one = ops.cast(1, dtype=rank.dtype)
    return ops.divide(one, ops.log2(ops.add(one, rank)))
238
+
239
+
240
def compute_dcg(
    y_true: types.Tensor,
    sample_weight: types.Tensor,
    gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
    rank_discount_fn: Callable[
        [types.Tensor], types.Tensor
    ] = default_rank_discount_fn,
) -> types.Tensor:
    """Computes the discounted cumulative gain of each list.

    The item in column `j` (0-based) of `y_true` is assigned the 1-based rank
    `j + 1`. The result for each list is therefore:

    ```
    sum_j sample_weight_j * gain_fn(y_true_j) * rank_discount_fn(j + 1)
    ```

    `y_true` should already be in ranked order, presumably sorted by predicted
    score (confirm at call sites).

    Args:
        y_true: tensor of labels. The list length is read from axis 1,
            presumably shape `(batch_size, list_size)`.
        sample_weight: per-item weights, multiplied element-wise with the
            discounted gains.
        gain_fn: maps labels to gains.
        rank_discount_fn: maps 1-based ranks to discounts.

    Returns:
        The summed weighted discounted gains along axis 1, with the axis kept,
        i.e. `(batch_size, 1)` for 2-D inputs.
    """
    num_positions = ops.shape(y_true)[1]
    # 1-based ranks, in the same dtype as the labels.
    ranks = ops.arange(1, num_positions + 1, dtype=y_true.dtype)
    discounted_gain = ops.multiply(gain_fn(y_true), rank_discount_fn(ranks))
    weighted_gain = ops.multiply(sample_weight, discounted_gain)
    return ops.sum(weighted_gain, axis=1, keepdims=True)
@@ -0,0 +1,108 @@
1
+ from keras import ops
2
+
3
+ from keras_rs.src import types
4
+ from keras_rs.src.api_export import keras_rs_export
5
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
6
+ from keras_rs.src.metrics.ranking_metric import (
7
+ ranking_metric_subclass_doc_string,
8
+ )
9
+ from keras_rs.src.metrics.ranking_metric import (
10
+ ranking_metric_subclass_doc_string_post_desc,
11
+ )
12
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
13
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
14
+ from keras_rs.src.utils.doc_string_utils import format_docstring
15
+
16
+
17
@keras_rs_export("keras_rs.metrics.RecallAtK")
class RecallAtK(RankingMetric):
    def compute_metric(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
        mask: types.Tensor,
        sample_weight: types.Tensor,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Computes per-list recall@k and the matching per-list weights.

        An item counts as relevant when its label is `>= 1`. A list's recall
        is the number of relevant items among its top `self.k` items, ranked
        by `y_pred` (the whole list when `self.k` is `None`), divided by the
        number of relevant items in the whole list. Lists without relevant
        items get a recall of 0.

        Args:
            y_true: tensor of labels, presumably `(batch_size, list_size)` as
                `sort_by_scores` expects.
            y_pred: tensor of scores with the same shape as `y_true`.
            mask: boolean validity mask, forwarded to `sort_by_scores`. The
                denominator below is computed from the unsorted `y_true` and
                does not apply `mask`.
            sample_weight: per-item weights, converted to per-list weights by
                `get_list_weights`.

        Returns:
            A `(per_list_recall, per_list_weights)` tuple.
        """
        # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
        # `sorted_y_true = [0, 1, 0]` (sorted in descending order).
        (sorted_y_true,) = sort_by_scores(
            tensors_to_sort=[y_true],
            scores=y_pred,
            mask=mask,
            k=self.k,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )

        # Relevant items retrieved within the top-k.
        relevance = ops.cast(
            ops.greater_equal(
                sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
            ),
            dtype=y_pred.dtype,
        )
        # Relevant items anywhere in the list (the recall denominator).
        overall_relevance = ops.cast(
            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
            dtype=y_pred.dtype,
        )
        # divide_no_nan maps 0 / 0 (no relevant items) to a recall of 0.
        per_list_recall = ops.divide_no_nan(
            ops.sum(relevance, axis=1, keepdims=True),
            ops.sum(overall_relevance, axis=1, keepdims=True),
        )

        # Get weights.
        per_list_weights = get_list_weights(
            weights=sample_weight, relevance=overall_relevance
        )

        return per_list_recall, per_list_weights
58
+
59
+
60
concept_sentence = (
    "It measures the proportion of relevant items found in the top-k "
    "recommendations out of the total number of relevant items for a user"
)
relevance_type = "binary indicators (0 or 1) of relevance"
score_range_interpretation = (
    "Scores range from 0 to 1, with 1 indicating that all relevant items "
    "for the user were found within the top-k recommendations"
)
# `rank(s_i)` is defined explicitly as 1-based, so the indicator `<= k`
# selects exactly the top-k items that `compute_metric` counts.
formula = """```
R@k(y, s) = sum_i I[rank(s_i) <= k] y_i / sum_j y_j
```

where `y_i` is the relevance label (0/1) of item `i`, `rank(s_i)` is the
1-based position of item `i` when items are sorted by score `s` in descending
order, and `I[condition]` is 1 if the condition is met, otherwise 0."""
extra_args = ""
example = """
    >>> batch_size = 2
    >>> list_size = 5
    >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
    >>> scores = np.random.random(size=(batch_size, list_size))
    >>> metric = keras_rs.metrics.RecallAtK()(
    ...     y_true=labels, y_pred=scores
    ... )

    Mask certain elements (can be used for uneven inputs):

    >>> batch_size = 2
    >>> list_size = 5
    >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
    >>> scores = np.random.random(size=(batch_size, list_size))
    >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
    >>> metric = keras_rs.metrics.RecallAtK()(
    ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
    ... )
"""

# The class docstring is assembled from the shared ranking-metric templates.
RecallAtK.__doc__ = format_docstring(
    ranking_metric_subclass_doc_string,
    width=80,
    metric_name="Recall@k",
    metric_abbreviation="R@k",
    concept_sentence=concept_sentence,
    relevance_type=relevance_type,
    score_range_interpretation=score_range_interpretation,
    formula=formula,
) + ranking_metric_subclass_doc_string_post_desc.format(
    extra_args=extra_args, example=example
)
@@ -0,0 +1,70 @@
1
+ from keras import ops
2
+
3
+ from keras_rs.src import types
4
+ from keras_rs.src.utils.keras_utils import check_rank
5
+ from keras_rs.src.utils.keras_utils import check_shapes_compatible
6
+
7
+
8
def standardize_call_inputs_ranks(
    y_true: types.Tensor,
    y_pred: types.Tensor,
    mask: types.Tensor | None = None,
    check_y_true_rank: bool = True,
) -> tuple[types.Tensor, types.Tensor, types.Tensor | None, bool]:
    """
    Validates the inputs of a loss or metric and gives them a batch axis.

    The steps, in order:

    - Check that `y_true` (only if `check_y_true_rank`), `y_pred`, and `mask`
      (when given) have rank 1 or 2;
    - Check that the shapes of `y_pred` and `mask` are compatible with the
      shape of `y_true`;
    - If `y_true` has rank 1, add a leading batch axis to every input.

    Args:
        y_true: tensor. Ground truth values.
        y_pred: tensor. The predicted values.
        mask: tensor. Boolean mask for `y_true`.
        check_y_true_rank: bool. Whether to check the rank of `y_true`.

    Returns:
        Tuple of processed `y_true`, `y_pred`, `mask`, and `batched`. `batched`
        is `False` when `y_true` had rank 1 (a batch axis was added) and `True`
        otherwise.

    Raises:
        ValueError: if a rank check fails or the shapes are incompatible.
    """
    y_true_shape = ops.shape(y_true)
    y_pred_shape = ops.shape(y_pred)
    mask_shape = None if mask is None else ops.shape(mask)

    if check_y_true_rank:
        check_rank(len(y_true_shape), allowed_ranks=(1, 2), tensor_name="y_true")
    check_rank(len(y_pred_shape), allowed_ranks=(1, 2), tensor_name="y_pred")
    if mask is not None:
        check_rank(len(mask_shape), allowed_ranks=(1, 2), tensor_name="mask")

    if not check_shapes_compatible(y_true_shape, y_pred_shape):
        raise ValueError(
            "`y_true` and `y_pred` should have the same shape. Received: "
            f"`y_true.shape` = {y_true_shape}, `y_pred.shape` = {y_pred_shape}."
        )
    if mask is not None and not check_shapes_compatible(
        y_true_shape, mask_shape
    ):
        raise ValueError(
            "`y_true['labels']` and `y_true['mask']` should have the same "
            f"shape. Received: `y_true['labels'].shape` = {y_true_shape}, "
            f"`y_true['mask'].shape` = {mask_shape}."
        )

    # Rank-1 inputs describe a single list; give them a batch axis of size 1.
    batched = len(y_true_shape) != 1
    if not batched:
        y_true = ops.expand_dims(y_true, axis=0)
        y_pred = ops.expand_dims(y_pred, axis=0)
        if mask is not None:
            mask = ops.expand_dims(mask, axis=0)

    return y_true, y_pred, mask, batched
keras_rs/src/types.py CHANGED
@@ -1,6 +1,8 @@
1
1
  """Type definitions."""
2
2
 
3
- from typing import Any, Optional, Sequence
3
+ from typing import Any, Callable, Mapping, Sequence, TypeAlias, TypeVar, Union
4
+
5
+ import keras
4
6
 
"""
A tensor in any of the backends.

We do not define it explicitly to not require all the backends to be installed
and imported. The explicit definition would be:
```
(
    numpy.ndarray
    | tensorflow.Tensor
    | tensorflow.RaggedTensor
    | tensorflow.SparseTensor
    | tensorflow.IndexedSlices
    | jax.Array
    | jax.experimental.sparse.JAXSparse
    | torch.Tensor
    | keras.KerasTensor
)
```
"""
Tensor: TypeAlias = Any

# A shape as a sequence of dimensions; `None` marks an unknown dimension.
Shape: TypeAlias = Sequence[int | None]

# A dtype given by name, e.g. `"float32"`.
DType: TypeAlias = str

# Anything accepted where a constraint is expected: a string (presumably an
# identifier resolved by Keras), a `Constraint` instance or subclass, or a
# plain callable applied to a tensor.
ConstraintLike: TypeAlias = (
    str
    | keras.constraints.Constraint
    | type[keras.constraints.Constraint]
    | Callable[[Tensor], Tensor]
)

# Anything accepted where an initializer is expected: a string identifier, an
# `Initializer` instance or subclass, a `(shape, dtype) -> tensor` callable, or
# a tensor holding the initial value.
InitializerLike: TypeAlias = (
    str
    | keras.initializers.Initializer
    | type[keras.initializers.Initializer]
    | Callable[[Shape, DType], Tensor]
    | Tensor
)

# Anything accepted where a regularizer is expected: a string identifier, a
# `Regularizer` instance or subclass, or a plain callable applied to a tensor.
RegularizerLike: TypeAlias = (
    str
    | keras.regularizers.Regularizer
    | type[keras.regularizers.Regularizer]
    | Callable[[Tensor], Tensor]
)

T = TypeVar("T")
# An arbitrarily nested structure of `T` values inside sequences and
# string-keyed mappings.
Nested: TypeAlias = (
    T | Sequence[Union[T, "Nested[T]"]] | Mapping[str, Union[T, "Nested[T]"]]
)
@@ -0,0 +1,53 @@
1
+ import re
2
+ import textwrap
3
+ from typing import Any
4
+
5
+
6
def format_docstring(template: str, width: int = 80, **kwargs: Any) -> str:
    """Renders `template` with `kwargs` and reflows it as an indented docstring.

    The rendered text is dedented and stripped, then split into paragraphs at
    blank lines. Each paragraph is indented by four spaces:

    - paragraphs containing a ``` fence are re-indented as-is;
    - paragraphs containing "where:" get a blank line after their first line
      and are re-indented as-is;
    - all other paragraphs are re-wrapped with `textwrap.fill`. The lines,
      including the indent, are at most `width - 4` characters wide.

    Blank-line separators are kept verbatim. The result always starts with
    the four-space indent.
    """
    indent = " " * 4

    def _render(chunk: str) -> str:
        body = chunk.strip()
        if not body:
            # Paragraph separators (and empty pieces) pass through unchanged.
            return chunk
        if "```" in body:
            return textwrap.indent(textwrap.dedent(body), indent)
        if "where:" in body:
            # The lines after the header are expected to be pre-indented.
            header, *rest = body.split("\n")
            return textwrap.indent(header + "\n\n" + "\n".join(rest), indent)
        return textwrap.fill(
            body,
            width=width - len(indent),
            initial_indent=indent,
            subsequent_indent=indent,
        )

    text = textwrap.dedent(template.format(**kwargs)).strip()
    # The capturing group keeps the separators in the result list.
    pieces = re.split(r"(\n\s*\n)", text)
    return indent + "".join(_render(piece) for piece in pieces).strip()
@@ -1,11 +1,35 @@
1
- from typing import Union
1
+ from typing import Any, Callable
2
2
 
3
3
  import keras
4
4
 
5
+ from keras_rs.src import types
6
+
7
+
8
def no_automatic_dependency_tracking(
    fn: Callable[..., Any],
) -> Callable[..., Any]:
    """Wraps `fn` to disable automatic dependency tracking in Keras and TF.

    On the TensorFlow backend, TF's own decorator is applied first. Keras's
    decorator is always applied on top.

    Args:
        fn: the function to disable automatic dependency tracking for.

    Returns:
        a wrapped version of `fn`.
    """
    target = fn
    if keras.backend.backend() == "tensorflow":
        # Imported lazily: TensorFlow is only needed on this backend.
        import tensorflow as tf

        target = tf.__internal__.tracking.no_automatic_dependency_tracking(
            target
        )

    keras_wrapped: Callable[..., Any] = (
        keras.src.utils.tracking.no_automatic_dependency_tracking(target)
    )
    return keras_wrapped
28
+
5
29
 
6
30
  def clone_initializer(
7
- initializer: Union[str, keras.initializers.Initializer],
8
- ) -> keras.initializers.Initializer:
31
+ initializer: types.InitializerLike,
32
+ ) -> types.InitializerLike:
9
33
  """Clones an initializer to ensure a new seed.
10
34
 
11
35
  Args:
@@ -25,3 +49,28 @@ def clone_initializer(
25
49
  return initializer_class.from_config(config)
26
50
  # If we get a string or dict, just return as we cannot and should not clone.
27
51
  return initializer
52
+
53
+
54
def check_shapes_compatible(shape1: types.Shape, shape2: types.Shape) -> bool:
    """Returns whether two shapes could describe the same tensor.

    The ranks must be equal. A pair of dimensions is only compared when both
    are Python `int`s, so `None` (or any non-int, e.g. symbolic) dimensions
    match anything.
    """
    if len(shape1) != len(shape2):
        return False
    return all(
        not (isinstance(dim_a, int) and isinstance(dim_b, int))
        or dim_a == dim_b
        for dim_a, dim_b in zip(shape1, shape2)
    )
65
+
66
+
67
def check_rank(
    x_rank: int,
    allowed_ranks: tuple[int, ...],
    tensor_name: str,
) -> None:
    """Raises a `ValueError` unless `x_rank` is one of `allowed_ranks`.

    Args:
        x_rank: the rank to validate.
        allowed_ranks: the accepted ranks.
        tensor_name: name of the tensor, used in the error message.

    Raises:
        ValueError: if `x_rank` is not in `allowed_ranks`.
    """
    if x_rank not in allowed_ranks:
        # The trailing space separates the two sentences of the message.
        raise ValueError(
            f"`{tensor_name}` should have a rank from `{allowed_ranks}`. "
            f"Received: `{x_rank}`."
        )