keras-rs-nightly 0.3.1.dev202510050328__tar.gz → 0.3.1.dev202510070324__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of keras-rs-nightly might be problematic.

Files changed (61)
  1. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/PKG-INFO +1 -1
  2. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +34 -33
  3. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/jax/embedding_utils.py +3 -110
  4. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/version.py +1 -1
  5. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
  6. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/pyproject.toml +1 -1
  7. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/README.md +0 -0
  8. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/api/__init__.py +0 -0
  9. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/api/layers/__init__.py +0 -0
  10. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/api/losses/__init__.py +0 -0
  11. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/api/metrics/__init__.py +0 -0
  12. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/__init__.py +0 -0
  13. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/api_export.py +0 -0
  14. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/__init__.py +0 -0
  15. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/__init__.py +0 -0
  16. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
  17. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
  18. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
  19. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
  20. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  21. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
  22. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
  23. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
  24. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  25. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
  26. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
  27. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  28. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  29. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  30. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  31. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  32. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  33. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  34. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  35. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  36. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/losses/__init__.py +0 -0
  37. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  38. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  39. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/losses/pairwise_loss.py +0 -0
  40. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
  41. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  42. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  43. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/__init__.py +0 -0
  44. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/dcg.py +0 -0
  45. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/mean_average_precision.py +0 -0
  46. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
  47. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/ndcg.py +0 -0
  48. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/precision_at_k.py +0 -0
  49. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/ranking_metric.py +0 -0
  50. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
  51. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/recall_at_k.py +0 -0
  52. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/metrics/utils.py +0 -0
  53. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/types.py +0 -0
  54. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/utils/__init__.py +0 -0
  55. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/utils/doc_string_utils.py +0 -0
  56. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs/src/utils/keras_utils.py +0 -0
  57. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
  58. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  59. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs_nightly.egg-info/requires.txt +0 -0
  60. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  61. {keras_rs_nightly-0.3.1.dev202510050328 → keras_rs_nightly-0.3.1.dev202510070324}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-rs-nightly
-Version: 0.3.1.dev202510050328
+Version: 0.3.1.dev202510070324
 Summary: Multi-backend recommender systems with Keras 3.
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
keras_rs/src/layers/embedding/jax/distributed_embedding.py
@@ -441,7 +441,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )

         # Collect all stacked tables.
-        table_specs = embedding_utils.get_table_specs(feature_specs)
+        table_specs = embedding.get_table_specs(feature_specs)
         table_stacks = embedding_utils.get_table_stacks(table_specs)

         # Create variables for all stacked tables and slot variables.
@@ -515,9 +515,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         del inputs, weights, training

         # Each stacked-table gets a ShardedCooMatrix.
-        table_specs = embedding_utils.get_table_specs(
-            self._config.feature_specs
-        )
+        table_specs = embedding.get_table_specs(self._config.feature_specs)
         table_stacks = embedding_utils.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
@@ -600,40 +598,43 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         if training:
             # Synchronize input statistics across all devices and update the
             # underlying stacked tables specs in the feature specs.
-            prev_stats = embedding_utils.get_stacked_table_stats(
-                self._config.feature_specs
-            )

-            # Take the maximum with existing stats.
-            stats = keras.tree.map_structure(max, prev_stats, stats)
+            # Aggregate stats across all processes/devices via pmax.
+            num_local_cpu_devices = jax.local_device_count("cpu")

-            # Flatten the stats so we can more efficiently transfer them
-            # between hosts. We use jax.tree because we will later need to
-            # unflatten.
-            flat_stats, stats_treedef = jax.tree.flatten(stats)
+            def pmax_aggregate(x: Any) -> Any:
+                if not hasattr(x, "ndim"):
+                    x = np.array(x)
+                tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
+                return jax.pmap(
+                    lambda y: jax.lax.pmax(y, "all_cpus"),  # type: ignore[no-untyped-call]
+                    axis_name="all_cpus",
+                    backend="cpu",
+                )(tiled_x)[0]

-            # In the case of multiple local CPU devices per host, we need to
-            # replicate the stats to placate JAX collectives.
-            num_local_cpu_devices = jax.local_device_count("cpu")
-            tiled_stats = np.tile(
-                np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
-            )
+            full_stats = jax.tree.map(pmax_aggregate, stats)

-            # Aggregate variables across all processes/devices.
-            max_across_cpus = jax.pmap(
-                lambda x: jax.lax.pmax(  # type: ignore[no-untyped-call]
-                    x, "all_cpus"
-                ),
-                axis_name="all_cpus",
-                backend="cpu",
+            # Check if stats changed enough to warrant action.
+            stacked_table_specs = embedding.get_stacked_table_specs(
+                self._config.feature_specs
+            )
+            changed = any(
+                np.max(full_stats.max_ids_per_partition[stack_name])
+                > spec.max_ids_per_partition
+                or np.max(full_stats.max_unique_ids_per_partition[stack_name])
+                > spec.max_unique_ids_per_partition
+                or (
+                    np.max(full_stats.required_buffer_size_per_sc[stack_name])
+                    * num_sc_per_device
+                )
+                > (spec.suggested_coo_buffer_size_per_device or 0)
+                for stack_name, spec in stacked_table_specs.items()
             )
-            flat_stats = max_across_cpus(tiled_stats)[0].tolist()
-            stats = jax.tree.unflatten(stats_treedef, flat_stats)

             # Update configuration and repeat preprocessing if stats changed.
-            if stats != prev_stats:
-                embedding_utils.update_stacked_table_stats(
-                    self._config.feature_specs, stats
+            if changed:
+                embedding.update_preprocessing_parameters(
+                    self._config.feature_specs, full_stats, num_sc_per_device
                 )

             # Re-execute preprocessing with consistent input statistics.
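For readers unfamiliar with the collective used above: the new pmax_aggregate helper replicates each statistic across the local CPU devices and reduces it with jax.lax.pmax inside jax.pmap, which yields the element-wise maximum across all participating processes. A minimal standalone sketch of that pattern follows; it is illustrative only and not part of keras-rs, and the name pmax_across_cpus is made up here.

    # Standalone sketch: element-wise maximum across local CPU devices via a
    # pmap'd pmax. On a single-process run this reduces over one replica, but in
    # a multi-process setting pmap forms a global axis and pmax reduces over it.
    import jax
    import numpy as np

    num_local_cpu_devices = jax.local_device_count("cpu")

    def pmax_across_cpus(x):
        x = np.asarray(x)
        # Replicate the value once per local CPU device so pmap gets one shard each.
        tiled = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
        reduced = jax.pmap(
            lambda y: jax.lax.pmax(y, "all_cpus"),
            axis_name="all_cpus",
            backend="cpu",
        )(tiled)
        return reduced[0]

    print(pmax_across_cpus(np.array([1, 5, 3])))  # [1 5 3] on a single host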
@@ -718,7 +719,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs = embedding_utils.get_table_specs(config.feature_specs)
+        table_specs = embedding.get_table_specs(config.feature_specs)
         sharded_tables = embedding_utils.stack_and_shard_tables(
             table_specs,
             tables,
@@ -750,7 +751,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs = embedding_utils.get_table_specs(config.feature_specs)
+        table_specs = embedding.get_table_specs(config.feature_specs)

         # Extract only the table variables, not the gradient slot variables.
         table_variables = {
keras_rs/src/layers/embedding/jax/embedding_utils.py
@@ -1,7 +1,6 @@
 """Utility functions for manipulating JAX embedding tables and inputs."""

 import collections
-import dataclasses
 import typing
 from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar

@@ -35,12 +34,6 @@ class ShardedCooMatrix(NamedTuple):
     values: ArrayLike


-class InputStatsPerTable(NamedTuple):
-    max_ids_per_partition: int
-    max_unique_ids_per_partition: int
-    required_buffer_size_per_device: int
-
-
 def _round_up_to_multiple(value: int, multiple: int) -> int:
     return ((value + multiple - 1) // multiple) * multiple

@@ -303,15 +296,6 @@ def unshard_and_unstack_tables(
     return output


-def get_table_specs(feature_specs: Nested[FeatureSpec]) -> dict[str, TableSpec]:
-    table_spec_map: dict[str, TableSpec] = {}
-    flat_feature_specs, _ = jax.tree.flatten(feature_specs)
-    for feature_spec in flat_feature_specs:
-        table_spec = feature_spec.table_spec
-        table_spec_map[table_spec.name] = table_spec
-    return table_spec_map
-
-
 def get_table_stacks(
     table_specs: Nested[TableSpec],
 ) -> dict[str, list[TableSpec]]:
@@ -341,84 +325,6 @@ def get_table_stacks(
     return stacked_table_specs


-def get_stacked_table_stats(
-    feature_specs: Nested[FeatureSpec],
-) -> dict[str, InputStatsPerTable]:
-    """Extracts the stacked-table input statistics from the feature specs.
-
-    Args:
-        feature_specs: Feature specs from which to extracts the statistics.
-
-    Returns:
-        A mapping of stacked table names to input statistics per table.
-    """
-    stacked_table_specs: dict[str, StackedTableSpec] = {}
-    for feature_spec in jax.tree.flatten(feature_specs)[0]:
-        feature_spec = typing.cast(FeatureSpec, feature_spec)
-        stacked_table_spec = typing.cast(
-            StackedTableSpec, feature_spec.table_spec.stacked_table_spec
-        )
-        stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
-
-    stats: dict[str, InputStatsPerTable] = {}
-    for stacked_table_spec in stacked_table_specs.values():
-        buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device
-        buffer_size = buffer_size or 0
-        stats[stacked_table_spec.stack_name] = InputStatsPerTable(
-            max_ids_per_partition=stacked_table_spec.max_ids_per_partition,
-            max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition,
-            required_buffer_size_per_device=buffer_size,
-        )
-
-    return stats
-
-
-def update_stacked_table_stats(
-    feature_specs: Nested[FeatureSpec],
-    stats: Mapping[str, InputStatsPerTable],
-) -> None:
-    """Updates stacked-table input properties in the supplied feature specs.
-
-    Args:
-        feature_specs: Feature specs to update in-place.
-        stats: Per-stacked-table input statistics.
-    """
-    # Collect table specs and stacked table specs.
-    table_specs: dict[str, TableSpec] = {}
-    for feature_spec in jax.tree.flatten(feature_specs)[0]:
-        feature_spec = typing.cast(FeatureSpec, feature_spec)
-        table_specs[feature_spec.table_spec.name] = feature_spec.table_spec
-
-    stacked_table_specs: dict[str, StackedTableSpec] = {}
-    for table_spec in table_specs.values():
-        stacked_table_spec = typing.cast(
-            StackedTableSpec, table_spec.stacked_table_spec
-        )
-        stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
-
-    # Replace fields in the stacked_table_specs.
-    stack_names = stacked_table_specs.keys()
-    for stack_name in stack_names:
-        stack_stats = stats[stack_name]
-        stacked_table_spec = stacked_table_specs[stack_name]
-        buffer_size = stack_stats.required_buffer_size_per_device or None
-        stacked_table_specs[stack_name] = dataclasses.replace(
-            stacked_table_spec,
-            max_ids_per_partition=stack_stats.max_ids_per_partition,
-            max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition,
-            suggested_coo_buffer_size_per_device=buffer_size,
-        )
-
-    # Insert new stacked tables into tables.
-    for table_spec in table_specs.values():
-        stacked_table_spec = typing.cast(
-            StackedTableSpec, table_spec.stacked_table_spec
-        )
-        table_spec.stacked_table_spec = stacked_table_specs[
-            stacked_table_spec.stack_name
-        ]
-
-
 def convert_to_numpy(
     ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
     dtype: Any,
@@ -483,7 +389,7 @@ def ones_like(

     Args:
         ragged_or_dense: The ragged or dense input whose shape and data-type
-            define these same attributes of the returned array.
+            define these same attributes of the returned array.
         dtype: The data-type of the returned array.

     Returns:
@@ -567,7 +473,7 @@ def stack_and_shard_samples(
     global_device_count: int,
     num_sc_per_device: int,
     static_buffer_size: int | Mapping[str, int] | None = None,
-) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]:
+) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
     """Prepares input samples for use in embedding lookups.

     Args:
@@ -612,7 +518,6 @@
     )

     out: dict[str, ShardedCooMatrix] = {}
-    out_stats: dict[str, InputStatsPerTable] = {}
     tables_names = preprocessed_inputs.lhs_row_pointers.keys()
     for table_name in tables_names:
         shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
@@ -626,17 +531,5 @@
             row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
             values=preprocessed_inputs.lhs_gains[table_name],
         )
-        out_stats[table_name] = InputStatsPerTable(
-            max_ids_per_partition=np.max(
-                stats.max_ids_per_partition[table_name]
-            ),
-            max_unique_ids_per_partition=np.max(
-                stats.max_unique_ids_per_partition[table_name]
-            ),
-            required_buffer_size_per_device=np.max(
-                stats.required_buffer_size_per_sc[table_name]
-            )
-            * num_sc_per_device,
-        )

-    return out, out_stats
+    return out, stats
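With this change, stack_and_shard_samples returns the upstream SparseDenseMatmulInputStats object directly instead of the removed per-table InputStatsPerTable values. If a caller still wants the old per-device numbers, they can be derived from the returned per-SparseCore arrays the same way the deleted code did; a rough caller-side sketch is below, where the helper name per_device_stats is hypothetical and stats is assumed to be the returned object.

    # Hypothetical caller-side helper: collapse per-SparseCore statistics into
    # the per-device values that InputStatsPerTable used to carry. Field names
    # match those used in the diff above.
    import numpy as np

    def per_device_stats(stats, table_name, num_sc_per_device):
        return {
            "max_ids_per_partition": int(np.max(stats.max_ids_per_partition[table_name])),
            "max_unique_ids_per_partition": int(
                np.max(stats.max_unique_ids_per_partition[table_name])
            ),
            # Buffer size is reported per SparseCore; scale up to per device.
            "required_buffer_size_per_device": int(
                np.max(stats.required_buffer_size_per_sc[table_name]) * num_sc_per_device
            ),
        }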
keras_rs/src/version.py
@@ -1,7 +1,7 @@
 from keras_rs.src.api_export import keras_rs_export

 # Unique source of truth for the version number.
-__version__ = "0.3.1.dev202510050328"
+__version__ = "0.3.1.dev202510070324"


 @keras_rs_export("keras_rs.version")
keras_rs_nightly.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-rs-nightly
-Version: 0.3.1.dev202510050328
+Version: 0.3.1.dev202510070324
 Summary: Multi-backend recommender systems with Keras 3.
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
pyproject.toml
@@ -64,7 +64,7 @@ known-first-party = ["keras_rs"]
 [tool.mypy]
 python_version = "3.10"
 strict = "True"
-exclude = ["_test\\.py$", "^examples/"]
+exclude = ["_test\\.py$", "^examples/", "venv/"]
 untyped_calls_exclude = ["ml_dtypes"]
 disable_error_code = ["import-untyped", "unused-ignore"]
 disallow_subclassing_any = "False"