tf-models-nightly 2.18.0.dev20240816__py2.py3-none-any.whl → 2.18.0.dev20240818__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/recommendation/ranking/configs/config.py +2 -0
- official/recommendation/ranking/data/data_pipeline_multi_hot.py +2 -4
- official/recommendation/ranking/task.py +10 -1
- {tf_models_nightly-2.18.0.dev20240816.dist-info → tf_models_nightly-2.18.0.dev20240818.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.18.0.dev20240816.dist-info → tf_models_nightly-2.18.0.dev20240818.dist-info}/RECORD +9 -9
- {tf_models_nightly-2.18.0.dev20240816.dist-info → tf_models_nightly-2.18.0.dev20240818.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.18.0.dev20240816.dist-info → tf_models_nightly-2.18.0.dev20240818.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.18.0.dev20240816.dist-info → tf_models_nightly-2.18.0.dev20240818.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.18.0.dev20240816.dist-info → tf_models_nightly-2.18.0.dev20240818.dist-info}/top_level.txt +0 -0
@@ -138,6 +138,8 @@ class ModelConfig(hyperparams.Config):
|
|
138
138
|
max_ids_per_chip_per_sample: int | None = None
|
139
139
|
max_ids_per_table: Union[int, List[int]] | None = None
|
140
140
|
max_unique_ids_per_table: Union[int, List[int]] | None = None
|
141
|
+
allow_id_dropping: bool = False
|
142
|
+
initialize_tables_on_host: bool = False
|
141
143
|
|
142
144
|
|
143
145
|
@dataclasses.dataclass
|
@@ -45,15 +45,13 @@ class CriteoTsvReaderMultiHot:
|
|
45
45
|
num_dense_features: int,
|
46
46
|
vocab_sizes: List[int],
|
47
47
|
multi_hot_sizes: List[int],
|
48
|
-
use_synthetic_data: bool = False
|
49
|
-
use_cached_data: bool = False):
|
48
|
+
use_synthetic_data: bool = False):
|
50
49
|
self._file_pattern = file_pattern
|
51
50
|
self._params = params
|
52
51
|
self._num_dense_features = num_dense_features
|
53
52
|
self._vocab_sizes = vocab_sizes
|
54
53
|
self._use_synthetic_data = use_synthetic_data
|
55
54
|
self._multi_hot_sizes = multi_hot_sizes
|
56
|
-
self._use_cached_data = use_cached_data
|
57
55
|
|
58
56
|
def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
|
59
57
|
params = self._params
|
@@ -146,7 +144,7 @@ class CriteoTsvReaderMultiHot:
|
|
146
144
|
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
147
145
|
|
148
146
|
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
|
149
|
-
if self.
|
147
|
+
if self._params.use_cached_data:
|
150
148
|
dataset = dataset.take(1).cache().repeat()
|
151
149
|
|
152
150
|
return dataset
|
@@ -39,6 +39,8 @@ def _get_tpu_embedding_feature_config(
|
|
39
39
|
max_ids_per_chip_per_sample: Optional[int] = None,
|
40
40
|
max_ids_per_table: Optional[Union[int, List[int]]] = None,
|
41
41
|
max_unique_ids_per_table: Optional[Union[int, List[int]]] = None,
|
42
|
+
allow_id_dropping: bool = False,
|
43
|
+
initialize_tables_on_host: bool = False,
|
42
44
|
) -> Tuple[
|
43
45
|
Dict[str, tf.tpu.experimental.embedding.FeatureConfig],
|
44
46
|
Optional[tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig],
|
@@ -57,6 +59,10 @@ def _get_tpu_embedding_feature_config(
|
|
57
59
|
sample.
|
58
60
|
max_ids_per_table: Maximum number of embedding ids per table.
|
59
61
|
max_unique_ids_per_table: Maximum number of unique embedding ids per table.
|
62
|
+
allow_id_dropping: bool to allow id dropping.
|
63
|
+
initialize_tables_on_host: bool : if the embedding table size is more than
|
64
|
+
what HBM can handle, this flag will help initialize the full embedding
|
65
|
+
tables on host and then copy shards to HBM.
|
60
66
|
|
61
67
|
Returns:
|
62
68
|
A dictionary of feature_name, FeatureConfig pairs.
|
@@ -140,7 +146,8 @@ def _get_tpu_embedding_feature_config(
|
|
140
146
|
max_ids_per_chip_per_sample=max_ids_per_chip_per_sample,
|
141
147
|
max_ids_per_table=max_ids_per_table_dict,
|
142
148
|
max_unique_ids_per_table=max_unique_ids_per_table_dict,
|
143
|
-
allow_id_dropping=
|
149
|
+
allow_id_dropping=allow_id_dropping,
|
150
|
+
initialize_tables_on_host=initialize_tables_on_host,
|
144
151
|
)
|
145
152
|
|
146
153
|
return feature_config, sparsecore_config
|
@@ -248,6 +255,8 @@ class RankingTask(base_task.Task):
|
|
248
255
|
max_ids_per_chip_per_sample=self.task_config.model.max_ids_per_chip_per_sample,
|
249
256
|
max_ids_per_table=self.task_config.model.max_ids_per_table,
|
250
257
|
max_unique_ids_per_table=self.task_config.model.max_unique_ids_per_table,
|
258
|
+
allow_id_dropping=self.task_config.model.allow_id_dropping,
|
259
|
+
initialize_tables_on_host=self.task_config.model.initialize_tables_on_host,
|
251
260
|
)
|
252
261
|
)
|
253
262
|
|
@@ -877,16 +877,16 @@ official/recommendation/popen_helper.py,sha256=TMWMwsW1DF15YCJ0RG9pE3wsL6njLu8Ed
|
|
877
877
|
official/recommendation/stat_utils.py,sha256=BjWRO2jzmAJyqeRUw0hMHhQqlyDNrdNvQJKAmbDJ4Rc,3076
|
878
878
|
official/recommendation/ranking/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
|
879
879
|
official/recommendation/ranking/common.py,sha256=fUpBln1auJWwvX5BJIYJYxRQDzXPqCzbS6NYAcn0KaQ,3998
|
880
|
-
official/recommendation/ranking/task.py,sha256=
|
880
|
+
official/recommendation/ranking/task.py,sha256=Xw4VjbrdONt2pNZf_1e35A2swb5_fcYrqZqmEZ1mmAg,14480
|
881
881
|
official/recommendation/ranking/task_test.py,sha256=vPN_5oq1tWF3r7GuRTuAppKSF_mP3qoGFzGRoX09ylw,2891
|
882
882
|
official/recommendation/ranking/train.py,sha256=_7zC2SVsOOOBN--If0XlWw4gtdaEeTT04PbCQRSzieo,6604
|
883
883
|
official/recommendation/ranking/train_test.py,sha256=n96JN-gI5wj4cbZNJCwQ9XLViKt0ZAg9f9dtaTpAdV8,8756
|
884
884
|
official/recommendation/ranking/configs/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
|
885
|
-
official/recommendation/ranking/configs/config.py,sha256=
|
885
|
+
official/recommendation/ranking/configs/config.py,sha256=M0DgZvJ8CNtl6wUMutepQbn9joBtQtNQ-iEggbUxgec,14952
|
886
886
|
official/recommendation/ranking/configs/config_test.py,sha256=4_W0YUwJ2q0D8HdPdIvpJcAsaySjb_pDMWj26G7T1Ws,1474
|
887
887
|
official/recommendation/ranking/data/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
|
888
888
|
official/recommendation/ranking/data/data_pipeline.py,sha256=jtNyeAEZQFcFJaatTJhxYPcokN56OtETbF2fFqzWk8k,7578
|
889
|
-
official/recommendation/ranking/data/data_pipeline_multi_hot.py,sha256=
|
889
|
+
official/recommendation/ranking/data/data_pipeline_multi_hot.py,sha256=ghhxQh-L5OQU6cn0mfvo2S73Ppi50KfbU90Q2bn9dwo,12197
|
890
890
|
official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256=-NKJ4-0wwSjveCY7sMP9lChDkN6lWH5VahYMHUa9z2A,2949
|
891
891
|
official/recommendation/ranking/data/data_pipeline_test.py,sha256=34hKGxUBRjxyh9d2c7_eBhISpXI5F5PfgPfqBQoHM3Q,2573
|
892
892
|
official/recommendation/uplift/__init__.py,sha256=_jZilTPWKu-MfMaz1IgBjEW6wqkK3FNZ1QAP4a8my3I,990
|
@@ -1212,9 +1212,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
|
|
1212
1212
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1213
1213
|
tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
|
1214
1214
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1215
|
-
tf_models_nightly-2.18.0.
|
1216
|
-
tf_models_nightly-2.18.0.
|
1217
|
-
tf_models_nightly-2.18.0.
|
1218
|
-
tf_models_nightly-2.18.0.
|
1219
|
-
tf_models_nightly-2.18.0.
|
1220
|
-
tf_models_nightly-2.18.0.
|
1215
|
+
tf_models_nightly-2.18.0.dev20240818.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1216
|
+
tf_models_nightly-2.18.0.dev20240818.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1217
|
+
tf_models_nightly-2.18.0.dev20240818.dist-info/METADATA,sha256=HqIcfgOhVj0ZIvYNVNRLiZ8crrTfPwNKsa7O5n0dYDw,1432
|
1218
|
+
tf_models_nightly-2.18.0.dev20240818.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1219
|
+
tf_models_nightly-2.18.0.dev20240818.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1220
|
+
tf_models_nightly-2.18.0.dev20240818.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|