keras-rs-nightly: 0.3.1.dev202510060327.tar.gz → 0.3.1.dev202510080323.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of keras-rs-nightly has been flagged as potentially problematic.
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +34 -33
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/jax/embedding_utils.py +3 -110
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/pyproject.toml +1 -1
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/README.md +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/api/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/api/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/losses/pairwise_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/dcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/mean_average_precision.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/ndcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/precision_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/ranking_metric.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/recall_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/metrics/utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/utils/doc_string_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/setup.cfg +0 -0
{keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/jax/distributed_embedding.py
RENAMED
@@ -441,7 +441,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )

         # Collect all stacked tables.
-        table_specs =
+        table_specs = embedding.get_table_specs(feature_specs)
         table_stacks = embedding_utils.get_table_stacks(table_specs)

         # Create variables for all stacked tables and slot variables.
@@ -515,9 +515,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         del inputs, weights, training

         # Each stacked-table gets a ShardedCooMatrix.
-        table_specs =
-            self._config.feature_specs
-        )
+        table_specs = embedding.get_table_specs(self._config.feature_specs)
         table_stacks = embedding_utils.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
@@ -600,40 +598,43 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         if training:
             # Synchronize input statistics across all devices and update the
             # underlying stacked tables specs in the feature specs.
-            prev_stats = embedding_utils.get_stacked_table_stats(
-                self._config.feature_specs
-            )

-            #
-
+            # Aggregate stats across all processes/devices via pmax.
+            num_local_cpu_devices = jax.local_device_count("cpu")

-
-
-
-
+            def pmax_aggregate(x: Any) -> Any:
+                if not hasattr(x, "ndim"):
+                    x = np.array(x)
+                tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
+                return jax.pmap(
+                    lambda y: jax.lax.pmax(y, "all_cpus"),  # type: ignore[no-untyped-call]
+                    axis_name="all_cpus",
+                    backend="cpu",
+                )(tiled_x)[0]

-
-            # replicate the stats to placate JAX collectives.
-            num_local_cpu_devices = jax.local_device_count("cpu")
-            tiled_stats = np.tile(
-                np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
-            )
+            full_stats = jax.tree.map(pmax_aggregate, stats)

-            #
-
-
-
-
-
-
+            # Check if stats changed enough to warrant action.
+            stacked_table_specs = embedding.get_stacked_table_specs(
+                self._config.feature_specs
+            )
+            changed = any(
+                np.max(full_stats.max_ids_per_partition[stack_name])
+                > spec.max_ids_per_partition
+                or np.max(full_stats.max_unique_ids_per_partition[stack_name])
+                > spec.max_unique_ids_per_partition
+                or (
+                    np.max(full_stats.required_buffer_size_per_sc[stack_name])
+                    * num_sc_per_device
+                )
+                > (spec.suggested_coo_buffer_size_per_device or 0)
+                for stack_name, spec in stacked_table_specs.items()
             )
-            flat_stats = max_across_cpus(tiled_stats)[0].tolist()
-            stats = jax.tree.unflatten(stats_treedef, flat_stats)

             # Update configuration and repeat preprocessing if stats changed.
-            if
-
-                self._config.feature_specs,
+            if changed:
+                embedding.update_preprocessing_parameters(
+                    self._config.feature_specs, full_stats, num_sc_per_device
                 )

             # Re-execute preprocessing with consistent input statistics.
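Note: the hunk above replaces the old flatten/tile/unflatten synchronization with a per-leaf pmax reduction over local CPU devices. The following is a minimal, self-contained sketch of that pattern, not the library's exact code; the stats dict and table name below are made up for illustration.

import jax
import numpy as np

def pmax_aggregate(x, axis_name="all_cpus"):
    # jax.pmap requires a leading axis equal to the local device count, so the
    # per-host value is replicated across local CPU devices and then reduced
    # element-wise with jax.lax.pmax. In a multi-host run this yields the
    # global maximum on every participant.
    x = np.asarray(x)
    n = jax.local_device_count("cpu")
    tiled = np.tile(x, (n, *([1] * x.ndim)))
    reduced = jax.pmap(
        lambda y: jax.lax.pmax(y, axis_name),
        axis_name=axis_name,
        backend="cpu",
    )(tiled)
    return reduced[0]

# Hypothetical per-table statistics gathered on this host.
local_stats = {"stack_0": np.array([7, 3, 11], dtype=np.int32)}
global_stats = jax.tree.map(pmax_aggregate, local_stats)
print(global_stats)  # element-wise max across participating devices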
@@ -718,7 +719,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs =
+        table_specs = embedding.get_table_specs(config.feature_specs)
         sharded_tables = embedding_utils.stack_and_shard_tables(
             table_specs,
             tables,
@@ -750,7 +751,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs =
+        table_specs = embedding.get_table_specs(config.feature_specs)

         # Extract only the table variables, not the gradient slot variables.
         table_variables = {
{keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/keras_rs/src/layers/embedding/jax/embedding_utils.py
RENAMED
@@ -1,7 +1,6 @@
 """Utility functions for manipulating JAX embedding tables and inputs."""

 import collections
-import dataclasses
 import typing
 from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar

@@ -35,12 +34,6 @@ class ShardedCooMatrix(NamedTuple):
     values: ArrayLike


-class InputStatsPerTable(NamedTuple):
-    max_ids_per_partition: int
-    max_unique_ids_per_partition: int
-    required_buffer_size_per_device: int
-
-
 def _round_up_to_multiple(value: int, multiple: int) -> int:
     return ((value + multiple - 1) // multiple) * multiple

@@ -303,15 +296,6 @@ def unshard_and_unstack_tables(
     return output


-def get_table_specs(feature_specs: Nested[FeatureSpec]) -> dict[str, TableSpec]:
-    table_spec_map: dict[str, TableSpec] = {}
-    flat_feature_specs, _ = jax.tree.flatten(feature_specs)
-    for feature_spec in flat_feature_specs:
-        table_spec = feature_spec.table_spec
-        table_spec_map[table_spec.name] = table_spec
-    return table_spec_map
-
-
 def get_table_stacks(
     table_specs: Nested[TableSpec],
 ) -> dict[str, list[TableSpec]]:
@@ -341,84 +325,6 @@ def get_table_stacks(
     return stacked_table_specs


-def get_stacked_table_stats(
-    feature_specs: Nested[FeatureSpec],
-) -> dict[str, InputStatsPerTable]:
-    """Extracts the stacked-table input statistics from the feature specs.
-
-    Args:
-        feature_specs: Feature specs from which to extracts the statistics.
-
-    Returns:
-        A mapping of stacked table names to input statistics per table.
-    """
-    stacked_table_specs: dict[str, StackedTableSpec] = {}
-    for feature_spec in jax.tree.flatten(feature_specs)[0]:
-        feature_spec = typing.cast(FeatureSpec, feature_spec)
-        stacked_table_spec = typing.cast(
-            StackedTableSpec, feature_spec.table_spec.stacked_table_spec
-        )
-        stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
-
-    stats: dict[str, InputStatsPerTable] = {}
-    for stacked_table_spec in stacked_table_specs.values():
-        buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device
-        buffer_size = buffer_size or 0
-        stats[stacked_table_spec.stack_name] = InputStatsPerTable(
-            max_ids_per_partition=stacked_table_spec.max_ids_per_partition,
-            max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition,
-            required_buffer_size_per_device=buffer_size,
-        )
-
-    return stats
-
-
-def update_stacked_table_stats(
-    feature_specs: Nested[FeatureSpec],
-    stats: Mapping[str, InputStatsPerTable],
-) -> None:
-    """Updates stacked-table input properties in the supplied feature specs.
-
-    Args:
-        feature_specs: Feature specs to update in-place.
-        stats: Per-stacked-table input statistics.
-    """
-    # Collect table specs and stacked table specs.
-    table_specs: dict[str, TableSpec] = {}
-    for feature_spec in jax.tree.flatten(feature_specs)[0]:
-        feature_spec = typing.cast(FeatureSpec, feature_spec)
-        table_specs[feature_spec.table_spec.name] = feature_spec.table_spec
-
-    stacked_table_specs: dict[str, StackedTableSpec] = {}
-    for table_spec in table_specs.values():
-        stacked_table_spec = typing.cast(
-            StackedTableSpec, table_spec.stacked_table_spec
-        )
-        stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
-
-    # Replace fields in the stacked_table_specs.
-    stack_names = stacked_table_specs.keys()
-    for stack_name in stack_names:
-        stack_stats = stats[stack_name]
-        stacked_table_spec = stacked_table_specs[stack_name]
-        buffer_size = stack_stats.required_buffer_size_per_device or None
-        stacked_table_specs[stack_name] = dataclasses.replace(
-            stacked_table_spec,
-            max_ids_per_partition=stack_stats.max_ids_per_partition,
-            max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition,
-            suggested_coo_buffer_size_per_device=buffer_size,
-        )
-
-    # Insert new stacked tables into tables.
-    for table_spec in table_specs.values():
-        stacked_table_spec = typing.cast(
-            StackedTableSpec, table_spec.stacked_table_spec
-        )
-        table_spec.stacked_table_spec = stacked_table_specs[
-            stacked_table_spec.stack_name
-        ]
-
-
 def convert_to_numpy(
     ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
     dtype: Any,
@@ -483,7 +389,7 @@ def ones_like(

     Args:
         ragged_or_dense: The ragged or dense input whose shape and data-type
-
+            define these same attributes of the returned array.
         dtype: The data-type of the returned array.

     Returns:
@@ -567,7 +473,7 @@ def stack_and_shard_samples(
     global_device_count: int,
     num_sc_per_device: int,
     static_buffer_size: int | Mapping[str, int] | None = None,
-) -> tuple[dict[str, ShardedCooMatrix],
+) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
     """Prepares input samples for use in embedding lookups.

     Args:
@@ -612,7 +518,6 @@ def stack_and_shard_samples(
     )

     out: dict[str, ShardedCooMatrix] = {}
-    out_stats: dict[str, InputStatsPerTable] = {}
     tables_names = preprocessed_inputs.lhs_row_pointers.keys()
     for table_name in tables_names:
         shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
@@ -626,17 +531,5 @@ def stack_and_shard_samples(
             row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
             values=preprocessed_inputs.lhs_gains[table_name],
         )
-        out_stats[table_name] = InputStatsPerTable(
-            max_ids_per_partition=np.max(
-                stats.max_ids_per_partition[table_name]
-            ),
-            max_unique_ids_per_partition=np.max(
-                stats.max_unique_ids_per_partition[table_name]
-            ),
-            required_buffer_size_per_device=np.max(
-                stats.required_buffer_size_per_sc[table_name]
-            )
-            * num_sc_per_device,
-        )

-    return out,
+    return out, stats
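Note: with the hunk above, stack_and_shard_samples returns the raw stats object instead of pre-reduced per-table InputStatsPerTable values. A caller-side sketch of recovering the old per-table numbers follows; FakeStats and its contents are hypothetical stand-ins for the array-valued fields that appear in this diff.

import numpy as np

# Hypothetical stand-in for the stats returned by stack_and_shard_samples:
# each field maps a stacked-table name to an array with one entry per partition/SparseCore.
class FakeStats:
    max_ids_per_partition = {"stack_0": np.array([12, 9, 15, 7])}
    max_unique_ids_per_partition = {"stack_0": np.array([8, 6, 10, 5])}
    required_buffer_size_per_sc = {"stack_0": np.array([256, 192, 320, 128])}

num_sc_per_device = 4
stats = FakeStats()
for name in stats.max_ids_per_partition:
    # Reduce the per-partition arrays to the scalars the old helper produced.
    max_ids = int(np.max(stats.max_ids_per_partition[name]))
    max_unique = int(np.max(stats.max_unique_ids_per_partition[name]))
    buffer_per_device = int(np.max(stats.required_buffer_size_per_sc[name])) * num_sc_per_device
    print(name, max_ids, max_unique, buffer_per_device)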
{keras_rs_nightly-0.3.1.dev202510060327 → keras_rs_nightly-0.3.1.dev202510080323}/pyproject.toml
RENAMED
@@ -64,7 +64,7 @@ known-first-party = ["keras_rs"]
 [tool.mypy]
 python_version = "3.10"
 strict = "True"
-exclude = ["_test\\.py$", "^examples/"]
+exclude = ["_test\\.py$", "^examples/", "venv/"]
 untyped_calls_exclude = ["ml_dtypes"]
 disable_error_code = ["import-untyped", "unused-ignore"]
 disallow_subclassing_any = "False"
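Note: mypy's exclude entries are regular expressions matched against file paths. The quick sketch below uses plain Python re as an approximation of that matching (not mypy's own matcher) to show how the added "venv/" pattern filters out a local virtual environment.

import re

# mypy applies each exclude regex to relative file paths with a search-style match.
patterns = [r"_test\.py$", r"^examples/", r"venv/"]

def excluded(path: str) -> bool:
    return any(re.search(p, path) for p in patterns)

print(excluded("venv/lib/python3.10/site-packages/foo.py"))  # True
print(excluded("keras_rs/src/version.py"))                   # False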