keras-rs-nightly 0.2.2.dev202509020329__py3-none-any.whl → 0.2.2.dev202509170322__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/src/layers/embedding/base_distributed_embedding.py +12 -10
- keras_rs/src/layers/embedding/distributed_embedding_config.py +2 -2
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +41 -174
- keras_rs/src/layers/embedding/jax/embedding_utils.py +68 -22
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +26 -19
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +15 -5
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.2.2.dev202509020329.dist-info → keras_rs_nightly-0.2.2.dev202509170322.dist-info}/METADATA +4 -3
- {keras_rs_nightly-0.2.2.dev202509020329.dist-info → keras_rs_nightly-0.2.2.dev202509170322.dist-info}/RECORD +11 -11
- {keras_rs_nightly-0.2.2.dev202509020329.dist-info → keras_rs_nightly-0.2.2.dev202509170322.dist-info}/WHEEL +0 -0
- {keras_rs_nightly-0.2.2.dev202509020329.dist-info → keras_rs_nightly-0.2.2.dev202509170322.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/base_distributed_embedding.py

@@ -822,13 +822,13 @@ class DistributedEmbedding(keras.layers.Layer):
         table_stacking: str | Sequence[Sequence[str]],
     ) -> None:
         del table_stacking
-
+        table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
         self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
 
         for path, feature_config in feature_configs.items():
-            if feature_config.table in
+            if id(feature_config.table) in table_config_id_to_embedding_layer:
                 self._default_device_embedding_layers[path] = (
-
+                    table_config_id_to_embedding_layer[id(feature_config.table)]
                 )
             else:
                 embedding_layer = EmbedReduce(
@@ -838,7 +838,9 @@ class DistributedEmbedding(keras.layers.Layer):
                     embeddings_initializer=feature_config.table.initializer,
                     combiner=feature_config.table.combiner,
                 )
-
+                table_config_id_to_embedding_layer[id(feature_config.table)] = (
+                    embedding_layer
+                )
                 self._default_device_embedding_layers[path] = embedding_layer
 
     def _default_device_build(
@@ -1013,8 +1015,8 @@ class DistributedEmbedding(keras.layers.Layer):
 
         # The serialized `TableConfig` objects.
         table_config_dicts: list[dict[str, Any]] = []
-        # Mapping from `TableConfig` to index in `table_config_dicts`.
-
+        # Mapping from `TableConfig` id to index in `table_config_dicts`.
+        table_config_id_to_index: dict[int, int] = {}
 
         def serialize_feature_config(
             feature_config: FeatureConfig,
@@ -1024,17 +1026,17 @@ class DistributedEmbedding(keras.layers.Layer):
             # key.
             feature_config_dict = feature_config.get_config()
 
-            if feature_config.table not in
+            if id(feature_config.table) not in table_config_id_to_index:
                 # Save the serialized `TableConfig` the first time we see it and
                 # remember its index.
-
+                table_config_id_to_index[id(feature_config.table)] = len(
                     table_config_dicts
                 )
                 table_config_dicts.append(feature_config_dict["table"])
 
             # Replace the serialized `TableConfig` with its index.
-            feature_config_dict["table"] =
-                feature_config.table
+            feature_config_dict["table"] = table_config_id_to_index[
+                id(feature_config.table)
             ]
             return feature_config_dict
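Note: the hunks above de-duplicate `TableConfig` objects by keying dictionaries on `id(...)`. A minimal, self-contained sketch of the same serialize-once, reference-by-index scheme (simplified, hypothetical names; not the library's code):

    from types import SimpleNamespace

    # Each distinct table config is serialized once; features store its index
    # so tables shared between features stay shared after deserialization.
    table_dicts: list[dict] = []
    table_id_to_index: dict[int, int] = {}

    def serialize_feature(feature) -> dict:
        key = id(feature.table)
        if key not in table_id_to_index:
            # First time this table is seen: remember its position.
            table_id_to_index[key] = len(table_dicts)
            table_dicts.append({"name": feature.table.name})
        return {"name": feature.name, "table": table_id_to_index[key]}

    shared_table = SimpleNamespace(name="movies")
    f1 = SimpleNamespace(name="watched", table=shared_table)
    f2 = SimpleNamespace(name="rated", table=shared_table)
    print([serialize_feature(f) for f in (f1, f2)])  # both reference table index 0
    print(table_dicts)  # a single serialized table entry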
keras_rs/src/layers/embedding/distributed_embedding_config.py

@@ -10,7 +10,7 @@ from keras_rs.src.api_export import keras_rs_export
 
 
 @keras_rs_export("keras_rs.layers.TableConfig")
-@dataclasses.dataclass(
+@dataclasses.dataclass(order=True)
 class TableConfig:
     """Configuration for one embedding table.
 
@@ -88,7 +88,7 @@ class TableConfig:
 
 
 @keras_rs_export("keras_rs.layers.FeatureConfig")
-@dataclasses.dataclass(
+@dataclasses.dataclass(order=True)
 class FeatureConfig:
     """Configuration for one embedding feature.
 
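Note: with `order=True`, `dataclasses.dataclass` generates `__lt__`/`__le__`/`__gt__`/`__ge__`; with the default `eq=True` and no `frozen=True`, instances are unhashable, which fits the `id(...)`-keyed dictionaries introduced in the other files. A small illustration (hypothetical `Table` class, not the real `TableConfig`):

    import dataclasses

    @dataclasses.dataclass(order=True)
    class Table:
        name: str
        vocabulary_size: int

    # order=True adds comparisons that treat the fields as a tuple.
    assert Table("a", 1) < Table("b", 1)

    # eq=True (the default) without frozen=True sets __hash__ to None, so the
    # config cannot be a dict key directly; id() is the stable key instead.
    a = Table("users", 100)
    b = Table("users", 100)
    try:
        {a: "layer"}
    except TypeError as err:
        print(err)  # unhashable type: 'Table'

    cache = {id(a): "embed_reduce_a"}
    assert id(b) not in cache  # equal-valued but distinct configs stay separate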
keras_rs/src/layers/embedding/jax/distributed_embedding.py

@@ -15,7 +15,6 @@ from jax_tpu_embedding.sparsecore.lib.nn import (
     table_stacking as jte_table_stacking,
 )
 from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
-from keras.src import backend
 
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import base_distributed_embedding
@@ -247,23 +246,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )
         return sparsecore_distribution, sparsecore_layout
 
-    def _create_cpu_distribution(
-        self, cpu_axis_name: str = "cpu"
-    ) -> tuple[
-        keras.distribution.ModelParallel, keras.distribution.TensorLayout
-    ]:
-        """Share a variable across all CPU processes."""
-        cpu_devices = jax.devices("cpu")
-        device_mesh = keras.distribution.DeviceMesh(
-            (len(cpu_devices),), [cpu_axis_name], cpu_devices
-        )
-        replicated_layout = keras.distribution.TensorLayout([], device_mesh)
-        layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
-        cpu_distribution = keras.distribution.ModelParallel(
-            layout_map=layout_map
-        )
-        return cpu_distribution, replicated_layout
-
     def _add_sparsecore_weight(
         self,
         name: str,
@@ -405,11 +387,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         self._sparsecore_layout = sparsecore_layout
         self._sparsecore_distribution = sparsecore_distribution
 
-        # Distribution for CPU operations.
-        cpu_distribution, cpu_layout = self._create_cpu_distribution()
-        self._cpu_distribution = cpu_distribution
-        self._cpu_layout = cpu_layout
-
         mesh = sparsecore_distribution.device_mesh.backend_mesh
         global_device_count = mesh.devices.size
         num_sc_per_device = jte_utils.num_sparsecores_per_device(
@@ -466,10 +443,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         # Collect all stacked tables.
         table_specs = embedding_utils.get_table_specs(feature_specs)
         table_stacks = embedding_utils.get_table_stacks(table_specs)
-        stacked_table_specs = {
-            stack_name: stack[0].stacked_table_spec
-            for stack_name, stack in table_stacks.items()
-        }
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -502,50 +475,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             )
             self._iterations.overwrite_with_gradient = True
 
-        with cpu_distribution.scope():
-            # Create variables to track static buffer size and max IDs for each
-            # table during preprocessing. These variables are shared across all
-            # processes on CPU. We don't add these via `add_weight` because we
-            # can't have them passed to the training function.
-            replicated_zeros_initializer = ShardedInitializer(
-                "zeros", cpu_layout
-            )
-
-            with backend.name_scope(self.name, caller=self):
-                self._preprocessing_buffer_size = {
-                    table_name: backend.Variable(
-                        initializer=replicated_zeros_initializer,
-                        shape=(),
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                        name=table_name + ":preprocessing:buffer_size",
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-                self._preprocessing_max_unique_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_unique_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
-                self._preprocessing_max_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
         self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
             feature_specs,
             mesh=mesh,
@@ -660,76 +589,35 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             mesh.devices.item(0)
         )
 
-        # Get current buffer size/max_ids.
-        previous_max_ids_per_partition = keras.tree.map_structure(
-            lambda max_ids_per_partition: max_ids_per_partition.value.item(),
-            self._preprocessing_max_ids_per_partition,
-        )
-        previous_max_unique_ids_per_partition = keras.tree.map_structure(
-            lambda max_unique_ids_per_partition: (
-                max_unique_ids_per_partition.value.item()
-            ),
-            self._preprocessing_max_unique_ids_per_partition,
-        )
-        previous_buffer_size = keras.tree.map_structure(
-            lambda buffer_size: buffer_size.value.item(),
-            self._preprocessing_buffer_size,
-        )
-
         preprocessed, stats = embedding_utils.stack_and_shard_samples(
             self._config.feature_specs,
             samples,
             local_device_count,
             global_device_count,
             num_sc_per_device,
-            static_buffer_size=previous_buffer_size,
         )
 
-        # Extract max unique IDs and buffer sizes.
-        # We need to replicate this value across all local CPU devices.
         if training:
+            # Synchronize input statistics across all devices and update the
+            # underlying stacked tables specs in the feature specs.
+            prev_stats = embedding_utils.get_stacked_table_stats(
+                self._config.feature_specs
+            )
+
+            # Take the maximum with existing stats.
+            stats = keras.tree.map_structure(max, prev_stats, stats)
+
+            # Flatten the stats so we can more efficiently transfer them
+            # between hosts. We use jax.tree because we will later need to
+            # unflatten.
+            flat_stats, stats_treedef = jax.tree.flatten(stats)
+
+            # In the case of multiple local CPU devices per host, we need to
+            # replicate the stats to placate JAX collectives.
             num_local_cpu_devices = jax.local_device_count("cpu")
-
-
-
-                np.maximum(
-                    np.max(elems),
-                    previous_max_ids_per_partition[table_name],
-                ),
-                num_local_cpu_devices,
-            )
-            for table_name, elems in stats.max_ids_per_partition.items()
-            }
-            local_max_unique_ids_per_partition = {
-                name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_unique_ids_per_partition[name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for name, elems in stats.max_unique_ids_per_partition.items()
-            }
-            local_buffer_size = {
-                table_name: np.repeat(
-                    np.maximum(
-                        np.max(
-                            # Round values up to the next multiple of 8.
-                            # Currently using this as a proxy for the actual
-                            # required buffer size.
-                            ((elems + 7) // 8) * 8
-                        )
-                        * global_device_count
-                        * num_sc_per_device
-                        * local_device_count
-                        * num_sc_per_device,
-                        previous_buffer_size[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
+            tiled_stats = np.tile(
+                np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
+            )
 
         # Aggregate variables across all processes/devices.
         max_across_cpus = jax.pmap(
@@ -737,48 +625,24 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
                     x, "all_cpus"
                 ),
                 axis_name="all_cpus",
-
+                backend="cpu",
             )
-
-
-
-
-
-
-
-
-            # Assign new preprocessing parameters.
-            with self._cpu_distribution.scope():
-                # For each process, all max ids/buffer sizes are replicated
-                # across all local devices. Take the value from the first
-                # device.
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_max_ids_per_partition,
-                    new_max_ids_per_partition,
-                )
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_max_unique_ids_per_partition,
-                    new_max_unique_ids_per_partition,
-                )
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_buffer_size,
-                    new_buffer_size,
-                )
-                # Update parameters in the underlying feature specs.
-                int_max_ids_per_partition = keras.tree.map_structure(
-                    lambda varray: varray.item(), new_max_ids_per_partition
-                )
-                int_max_unique_ids_per_partition = keras.tree.map_structure(
-                    lambda varray: varray.item(),
-                    new_max_unique_ids_per_partition,
+            flat_stats = max_across_cpus(tiled_stats)[0].tolist()
+            stats = jax.tree.unflatten(stats_treedef, flat_stats)
+
+            # Update configuration and repeat preprocessing if stats changed.
+            if stats != prev_stats:
+                embedding_utils.update_stacked_table_stats(
+                    self._config.feature_specs, stats
                 )
-
+
+                # Re-execute preprocessing with consistent input statistics.
+                preprocessed, _ = embedding_utils.stack_and_shard_samples(
                     self._config.feature_specs,
-
-
+                    samples,
+                    local_device_count,
+                    global_device_count,
+                    num_sc_per_device,
                 )
 
         return {"inputs": preprocessed}
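Note: the new preprocessing path synchronizes per-table input statistics across hosts with a `jax.lax.pmax` collective on CPU, using the same calls the hunks above add. A minimal sketch of that pattern in isolation (illustrative values):

    import jax
    import numpy as np

    # Flatten the per-table statistics, tile them over the local CPU devices,
    # and take an element-wise maximum over a named axis with pmax.
    flat_stats = [3, 7, 1]  # e.g. max_ids, max_unique_ids, buffer size

    num_local_cpu_devices = jax.local_device_count("cpu")
    tiled_stats = np.tile(
        np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
    )

    max_across_cpus = jax.pmap(
        lambda x: jax.lax.pmax(x, "all_cpus"),
        axis_name="all_cpus",
        backend="cpu",
    )

    # All rows are identical after the collective; keep the first one.
    synced = max_across_cpus(tiled_stats)[0].tolist()
    print(synced)  # [3, 7, 1] on one host; the global maximum in multi-host runs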
@@ -826,19 +690,22 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             raise ValueError("Layer must first be built before setting tables.")
 
         if "default_device" in self._placement_to_path_to_feature_config:
-
+            table_name_to_embedding_layer = {}
             for (
                 path,
                 feature_config,
             ) in self._placement_to_path_to_feature_config[
                 "default_device"
             ].items():
-
+                table_name_to_embedding_layer[feature_config.table.name] = (
                     self._default_device_embedding_layers[path]
                 )
 
-            for
-
+            for (
+                table_name,
+                embedding_layer,
+            ) in table_name_to_embedding_layer.items():
+                table_values = tables.get(table_name, None)
                 if table_values is not None:
                     if embedding_layer.lora_enabled:
                         raise ValueError("Cannot set table if LoRA is enabled.")
keras_rs/src/layers/embedding/jax/embedding_utils.py

@@ -35,6 +35,12 @@ class ShardedCooMatrix(NamedTuple):
     values: ArrayLike
 
 
+class InputStatsPerTable(NamedTuple):
+    max_ids_per_partition: int
+    max_unique_ids_per_partition: int
+    required_buffer_size_per_device: int
+
+
 def _round_up_to_multiple(value: int, multiple: int) -> int:
     return ((value + multiple - 1) // multiple) * multiple
 
@@ -335,19 +341,47 @@ def get_table_stacks(
     return stacked_table_specs
 
 
-def
+def get_stacked_table_stats(
     feature_specs: Nested[FeatureSpec],
-
-
+) -> dict[str, InputStatsPerTable]:
+    """Extracts the stacked-table input statistics from the feature specs.
+
+    Args:
+        feature_specs: Feature specs from which to extracts the statistics.
+
+    Returns:
+        A mapping of stacked table names to input statistics per table.
+    """
+    stacked_table_specs: dict[str, StackedTableSpec] = {}
+    for feature_spec in jax.tree.flatten(feature_specs)[0]:
+        feature_spec = typing.cast(FeatureSpec, feature_spec)
+        stacked_table_spec = typing.cast(
+            StackedTableSpec, feature_spec.table_spec.stacked_table_spec
+        )
+        stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
+
+    stats: dict[str, InputStatsPerTable] = {}
+    for stacked_table_spec in stacked_table_specs.values():
+        buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device
+        buffer_size = buffer_size or 0
+        stats[stacked_table_spec.stack_name] = InputStatsPerTable(
+            max_ids_per_partition=stacked_table_spec.max_ids_per_partition,
+            max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition,
+            required_buffer_size_per_device=buffer_size,
+        )
+
+    return stats
+
+
+def update_stacked_table_stats(
+    feature_specs: Nested[FeatureSpec],
+    stats: Mapping[str, InputStatsPerTable],
 ) -> None:
-    """Updates properties in the supplied feature specs.
+    """Updates stacked-table input properties in the supplied feature specs.
 
     Args:
         feature_specs: Feature specs to update in-place.
-
-            new `max_ids_per_partition` for the stack.
-        max_unique_ids_per_partition: Mapping of table stack name to
-            new `max_unique_ids_per_partition` for the stack.
+        stats: Per-stacked-table input statistics.
     """
     # Collect table specs and stacked table specs.
     table_specs: dict[str, TableSpec] = {}
@@ -363,18 +397,17 @@ def update_stacked_table_specs(
         stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
 
     # Replace fields in the stacked_table_specs.
-
-
+    stack_names = stacked_table_specs.keys()
+    for stack_name in stack_names:
+        stack_stats = stats[stack_name]
+        stacked_table_spec = stacked_table_specs[stack_name]
+        buffer_size = stack_stats.required_buffer_size_per_device or None
+        stacked_table_specs[stack_name] = dataclasses.replace(
             stacked_table_spec,
-            max_ids_per_partition=max_ids_per_partition
-
-
-            max_unique_ids_per_partition=max_unique_ids_per_partition[
-                stacked_table_spec.stack_name
-            ],
+            max_ids_per_partition=stack_stats.max_ids_per_partition,
+            max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition,
+            suggested_coo_buffer_size_per_device=buffer_size,
         )
-        for stack_name, stacked_table_spec in stacked_table_specs.items()
-    }
 
     # Insert new stacked tables into tables.
     for table_spec in table_specs.values():
@@ -534,7 +567,7 @@ def stack_and_shard_samples(
     global_device_count: int,
     num_sc_per_device: int,
     static_buffer_size: int | Mapping[str, int] | None = None,
-) -> tuple[dict[str, ShardedCooMatrix],
+) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]:
     """Prepares input samples for use in embedding lookups.
 
     Args:
@@ -544,8 +577,8 @@ def stack_and_shard_samples(
         global_device_count: Number of global JAX devices.
        num_sc_per_device: Number of sparsecores per device.
        static_buffer_size: The static buffer size to use for the samples.
-
-
+            Defaults to None, in which case an upper-bound for the buffer size
+            will be automatically determined.
 
     Returns:
         The preprocessed inputs, and statistics useful for updating FeatureSpecs
@@ -579,6 +612,7 @@ def stack_and_shard_samples(
     )
 
     out: dict[str, ShardedCooMatrix] = {}
+    out_stats: dict[str, InputStatsPerTable] = {}
     tables_names = preprocessed_inputs.lhs_row_pointers.keys()
     for table_name in tables_names:
         shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
@@ -592,5 +626,17 @@ def stack_and_shard_samples(
             row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
             values=preprocessed_inputs.lhs_gains[table_name],
         )
+        out_stats[table_name] = InputStatsPerTable(
+            max_ids_per_partition=np.max(
+                stats.max_ids_per_partition[table_name]
+            ),
+            max_unique_ids_per_partition=np.max(
+                stats.max_unique_ids_per_partition[table_name]
+            ),
+            required_buffer_size_per_device=np.max(
+                stats.required_buffer_size_per_sc[table_name]
+            )
+            * num_sc_per_device,
+        )
 
-    return out,
+    return out, out_stats
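Note: `InputStatsPerTable` is a `NamedTuple`, and the JAX layer merges previous and freshly observed statistics with `keras.tree.map_structure(max, ...)` before calling `update_stacked_table_stats`. A small illustration of that field-wise merge, assuming `keras.tree` traverses NamedTuples as nested structures (as optree/dm-tree do):

    from typing import NamedTuple

    import keras

    class InputStatsPerTable(NamedTuple):
        max_ids_per_partition: int
        max_unique_ids_per_partition: int
        required_buffer_size_per_device: int

    prev_stats = {"table_stack_0": InputStatsPerTable(16, 8, 256)}
    new_stats = {"table_stack_0": InputStatsPerTable(32, 4, 128)}

    # `max` is applied to each field independently.
    merged = keras.tree.map_structure(max, prev_stats, new_stats)
    print(merged)  # field-wise maximum: InputStatsPerTable(32, 8, 256)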
keras_rs/src/layers/embedding/tensorflow/config_conversion.py

@@ -53,7 +53,7 @@ OPTIMIZER_MAPPINGS = {
 # KerasRS to TensorFlow
 
 
-def
+def keras_to_tf_tpu_configuration(
     feature_configs: types.Nested[FeatureConfig],
     table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
     num_replicas_in_sync: int,
@@ -66,14 +66,15 @@ def translate_keras_rs_configuration(
     Args:
         feature_configs: The nested Keras RS feature configs.
         table_stacking: The Keras RS table stacking.
+        num_replicas_in_sync: The number of replicas in sync from the strategy.
 
     Returns:
         A tuple containing the TensorFlow TPU feature configs and the TensorFlow
         TPU sparse core embedding config.
     """
-    tables: dict[
+    tables: dict[int, tf.tpu.experimental.embedding.TableConfig] = {}
     feature_configs = keras.tree.map_structure(
-        lambda f:
+        lambda f: keras_to_tf_tpu_feature_config(
             f, tables, num_replicas_in_sync
         ),
         feature_configs,
@@ -108,9 +109,9 @@ translate_keras_rs_configuration(
     return feature_configs, sparse_core_embedding_config
 
 
-def
+def keras_to_tf_tpu_feature_config(
     feature_config: FeatureConfig,
-    tables: dict[
+    tables: dict[int, tf.tpu.experimental.embedding.TableConfig],
     num_replicas_in_sync: int,
 ) -> tf.tpu.experimental.embedding.FeatureConfig:
     """Translates a Keras RS feature config to a TensorFlow TPU feature config.
@@ -120,7 +121,8 @@ translate_keras_rs_feature_config(
 
     Args:
         feature_config: The Keras RS feature config to translate.
-        tables: A mapping of KerasRS table
+        tables: A mapping of KerasRS table config ids to TF TPU table configs.
+        num_replicas_in_sync: The number of replicas in sync from the strategy.
 
     Returns:
         The TensorFlow TPU feature config.
@@ -131,10 +133,10 @@ translate_keras_rs_feature_config(
             f"but got {num_replicas_in_sync}."
         )
 
-    table = tables.get(feature_config.table, None)
+    table = tables.get(id(feature_config.table), None)
     if table is None:
-        table =
-        tables[feature_config.table] = table
+        table = keras_to_tf_tpu_table_config(feature_config.table)
+        tables[id(feature_config.table)] = table
 
     if len(feature_config.output_shape) < 2:
         raise ValueError(
@@ -168,7 +170,7 @@ translate_keras_rs_feature_config(
     )
 
 
-def
+def keras_to_tf_tpu_table_config(
     table_config: TableConfig,
 ) -> tf.tpu.experimental.embedding.TableConfig:
     initializer = table_config.initializer
@@ -179,13 +181,13 @@ translate_keras_rs_table_config(
         vocabulary_size=table_config.vocabulary_size,
         dim=table_config.embedding_dim,
         initializer=initializer,
-        optimizer=
+        optimizer=to_tf_tpu_optimizer(table_config.optimizer),
         combiner=table_config.combiner,
         name=table_config.name,
     )
 
 
-def
+def keras_to_tf_tpu_optimizer(
     optimizer: keras.optimizers.Optimizer,
 ) -> TfTpuOptimizer:
     """Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
@@ -238,7 +240,12 @@ translate_keras_optimizer(
             "Unsupported optimizer option `Optimizer.loss_scale_factor`."
         )
 
-    optimizer_mapping =
+    optimizer_mapping = None
+    for optimizer_class, mapping in OPTIMIZER_MAPPINGS.items():
+        # Handle subclasses of the main optimizer class.
+        if isinstance(optimizer, optimizer_class):
+            optimizer_mapping = mapping
+            break
     if optimizer_mapping is None:
         raise ValueError(
             f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
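Note: the lookup over `OPTIMIZER_MAPPINGS` now scans with `isinstance`, so subclasses of the supported Keras optimizers are translated too. A standalone illustration of the difference (mapping values are placeholders, not the library's mapping objects):

    import keras

    class WarmupSGD(keras.optimizers.SGD):
        pass

    mappings = {
        keras.optimizers.SGD: "sgd mapping",
        keras.optimizers.Adam: "adam mapping",
    }
    optimizer = WarmupSGD(learning_rate=0.1)

    print(mappings.get(type(optimizer)))  # None: exact-type lookup misses the subclass

    optimizer_mapping = None
    for optimizer_class, mapping in mappings.items():
        if isinstance(optimizer, optimizer_class):  # also matches subclasses
            optimizer_mapping = mapping
            break
    print(optimizer_mapping)  # "sgd mapping"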
@@ -258,7 +265,7 @@ translate_keras_optimizer(
     return optimizer_mapping.tpu_optimizer_class(**tpu_optimizer_kwargs)
 
 
-def
+def to_tf_tpu_optimizer(
     optimizer: str | keras.optimizers.Optimizer | TfTpuOptimizer | None,
 ) -> TfTpuOptimizer:
     """Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
@@ -299,7 +306,7 @@ translate_optimizer(
             "'sgd', 'adagrad', 'adam', or 'ftrl'"
         )
     elif isinstance(optimizer, keras.optimizers.Optimizer):
-        return
+        return keras_to_tf_tpu_optimizer(optimizer)
     else:
         raise ValueError(
             f"Unknown optimizer type {type(optimizer)}. Please pass an "
@@ -312,7 +319,7 @@ translate_optimizer(
 # TensorFlow to TensorFlow
 
 
-def
+def clone_tf_tpu_feature_configs(
     feature_configs: types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
 ) -> types.Nested[tf.tpu.experimental.embedding.FeatureConfig]:
     """Clones and resolves TensorFlow TPU feature configs.
@@ -327,7 +334,7 @@ clone_tf_feature_configs(
     """
     table_configs_dict = {}
 
-    def
+    def clone_and_resolve_tf_tpu_feature_config(
         fc: tf.tpu.experimental.embedding.FeatureConfig,
     ) -> tf.tpu.experimental.embedding.FeatureConfig:
         if fc.table not in table_configs_dict:
@@ -336,7 +343,7 @@ clone_tf_feature_configs(
                 vocabulary_size=fc.table.vocabulary_size,
                 dim=fc.table.dim,
                 initializer=fc.table.initializer,
-                optimizer=
+                optimizer=to_tf_tpu_optimizer(fc.table.optimizer),
                 combiner=fc.table.combiner,
                 name=fc.table.name,
                 quantization_config=fc.table.quantization_config,
@@ -352,5 +359,5 @@ clone_tf_feature_configs(
     )
 
     return keras.tree.map_structure(
-
+        clone_and_resolve_tf_tpu_feature_config, feature_configs
     )
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py

@@ -106,7 +106,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
                 "for the configuration."
             )
         self._tpu_feature_configs, self._sparse_core_embedding_config = (
-            config_conversion.
+            config_conversion.keras_to_tf_tpu_configuration(
                 feature_configs,
                 table_stacking,
                 strategy.num_replicas_in_sync,
@@ -135,10 +135,10 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
                 "supported with this TPU generation."
            )
         self._tpu_feature_configs = (
-            config_conversion.
+            config_conversion.clone_tf_tpu_feature_configs(feature_configs)
         )
 
-        self._tpu_optimizer = config_conversion.
+        self._tpu_optimizer = config_conversion.to_tf_tpu_optimizer(
             self._optimizer
         )
 
@@ -281,8 +281,18 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
     def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
         tables: dict[str, types.Tensor] = {}
         strategy = tf.distribute.get_strategy()
-
-
+        if not self._is_tpu_strategy(strategy):
+            raise RuntimeError(
+                "`DistributedEmbedding.get_embedding_tables` needs to be "
+                "called under the TPUStrategy that DistributedEmbedding was "
+                f"created with, but is being called under strategy {strategy}. "
+                "Please use `with strategy.scope()` when calling "
+                "`get_embedding_tables`."
+            )
+
+        tpu_hardware = strategy.extended.tpu_hardware_feature
+        num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
+        num_shards = strategy.num_replicas_in_sync * num_sc_per_device
 
         def populate_table(
             feature_config: tf.tpu.experimental.embedding.FeatureConfig,
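Note: the new check requires `get_embedding_tables` to run under the same `TPUStrategy` the layer was created with. A hedged usage sketch (requires a TPU runtime; the config values and shapes below are illustrative, not taken from this diff):

    import tensorflow as tf
    import keras_rs

    # get_embedding_tables() must run under the TPUStrategy that the
    # DistributedEmbedding layer was created with, or the check above raises.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

    with strategy.scope():
        table = keras_rs.layers.TableConfig(
            name="movies", vocabulary_size=1000, embedding_dim=16
        )
        feature_configs = {
            "movie_id": keras_rs.layers.FeatureConfig(
                name="movie_id",
                table=table,
                input_shape=(None,),
                output_shape=(None, 16),
            )
        }
        embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
        # ... build / train the model ...
        tables = embedding.get_embedding_tables()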
keras_rs/src/version.py
CHANGED (+1 -1; diff content not shown in this view)

{keras_rs_nightly-0.2.2.dev202509020329.dist-info → keras_rs_nightly-0.2.2.dev202509170322.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-rs-nightly
-Version: 0.2.2.dev202509020329
+Version: 0.2.2.dev202509170322
 Summary: Multi-backend recommender systems with Keras 3.
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
@@ -8,8 +8,9 @@ Project-URL: Home, https://keras.io/keras_rs
 Project-URL: Repository, https://github.com/keras-team/keras-rs
 Classifier: Development Status :: 3 - Alpha
 Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: 3 :: Only
 Classifier: Operating System :: Unix
 Classifier: Operating System :: Microsoft :: Windows
@@ -17,7 +18,7 @@ Classifier: Operating System :: MacOS
 Classifier: Intended Audience :: Science/Research
 Classifier: Topic :: Scientific/Engineering
 Classifier: Topic :: Software Development
-Requires-Python: >=3.10
+Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: keras
 Requires-Dist: ml-dtypes
{keras_rs_nightly-0.2.2.dev202509020329.dist-info → keras_rs_nightly-0.2.2.dev202509170322.dist-info}/RECORD

@@ -5,22 +5,22 @@ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,
 keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
 keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
-keras_rs/src/version.py,sha256=
+keras_rs/src/version.py,sha256=PwX4FMnP4-c0qMKRrgFfXMqENOSFmjAEPfuKyVB7jS0,224
 keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=
+keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=RkXZ6notj3Cq6ryR9w30Wb8UlaWjLcUK2Os9ZUQvuhY,45568
 keras_rs/src/layers/embedding/distributed_embedding.py,sha256=94jxUHoGK3Gs9yfV0KxFTuqPo7XFnhgCNlO2FEeiSgM,1072
-keras_rs/src/layers/embedding/distributed_embedding_config.py,sha256=
+keras_rs/src/layers/embedding/distributed_embedding_config.py,sha256=L41x6W1xcXI-3m94nOh_OsHn6OYpoynakKJvNboJnvE,5762
 keras_rs/src/layers/embedding/embed_reduce.py,sha256=c-MnEw1-KWs0jTf0JJ_ZBOY-9hRkiFyu989Dof3DnS8,12343
 keras_rs/src/layers/embedding/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/layers/embedding/jax/checkpoint_utils.py,sha256=wZ4I5WZVNg5WnrD2j7nhAXgLzDc7xMrUEkSAOx5Sz5c,3495
 keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=Di1UzRwLgGHd7RuWYJMj2mCOr1u9MseFEWaYKnwD9Bs,16742
-keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=
+keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=qgokWQKaSMj7LeCtPdzGpdRk5ZGlqUo6m840y3FkNYw,29666
 keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=8LigXjPr7uQaUOdZM6yoLGoPYdRcbkXkFeL_sJoQ6uQ,8223
-keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=
+keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=nTRa0C40dCjrSO7VAdiX3PierXmodyjAJqNnLSXLxMU,22178
 keras_rs/src/layers/embedding/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-keras_rs/src/layers/embedding/tensorflow/config_conversion.py,sha256=
-keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256
+keras_rs/src/layers/embedding/tensorflow/config_conversion.py,sha256=HpuDthRQQ3X0EO8dW6OAdcgTODkujZlx_swgreVwXyk,13220
+keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256=TBPYV8gP3ZnAFEwtxmWr_Rp3s-Cj0RrKzF6UOALJ4B0,17817
 keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=Rs8xIHXNWQNiwjp_xzvQRmTSV1AyhJjDgVc3K5pTmrQ,8530
 keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=Wq_eQvO0WTRlep69QbKi8TwY8bnFoF9vreP_j6ZHNFE,8666
@@ -50,7 +50,7 @@ keras_rs/src/metrics/utils.py,sha256=fGTo8j0ykVE5Y3yQCS2orSFcHY20Uxt0NazyPsybUsw
 keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
 keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
-keras_rs_nightly-0.2.2.
-keras_rs_nightly-0.2.2.
-keras_rs_nightly-0.2.2.
-keras_rs_nightly-0.2.2.
+keras_rs_nightly-0.2.2.dev202509170322.dist-info/METADATA,sha256=6XLLYgYUyBV6ww_1AuexR8ORRX4WwcONBaqsObhloqs,5324
+keras_rs_nightly-0.2.2.dev202509170322.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+keras_rs_nightly-0.2.2.dev202509170322.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
+keras_rs_nightly-0.2.2.dev202509170322.dist-info/RECORD,,

{keras_rs_nightly-0.2.2.dev202509020329.dist-info → keras_rs_nightly-0.2.2.dev202509170322.dist-info}/WHEEL
File without changes

{keras_rs_nightly-0.2.2.dev202509020329.dist-info → keras_rs_nightly-0.2.2.dev202509170322.dist-info}/top_level.txt
File without changes