tf-models-nightly 2.18.0.dev20240815__py2.py3-none-any.whl → 2.18.0.dev20240817__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -138,6 +138,8 @@ class ModelConfig(hyperparams.Config):
138
138
  max_ids_per_chip_per_sample: int | None = None
139
139
  max_ids_per_table: Union[int, List[int]] | None = None
140
140
  max_unique_ids_per_table: Union[int, List[int]] | None = None
141
+ allow_id_dropping: bool = False
142
+ initialize_tables_on_host: bool = False
141
143
 
142
144
 
143
145
  @dataclasses.dataclass
@@ -45,15 +45,13 @@ class CriteoTsvReaderMultiHot:
45
45
  num_dense_features: int,
46
46
  vocab_sizes: List[int],
47
47
  multi_hot_sizes: List[int],
48
- use_synthetic_data: bool = False,
49
- use_cached_data: bool = False):
48
+ use_synthetic_data: bool = False):
50
49
  self._file_pattern = file_pattern
51
50
  self._params = params
52
51
  self._num_dense_features = num_dense_features
53
52
  self._vocab_sizes = vocab_sizes
54
53
  self._use_synthetic_data = use_synthetic_data
55
54
  self._multi_hot_sizes = multi_hot_sizes
56
- self._use_cached_data = use_cached_data
57
55
 
58
56
  def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
59
57
  params = self._params
@@ -146,7 +144,7 @@ class CriteoTsvReaderMultiHot:
146
144
  num_parallel_calls=tf.data.experimental.AUTOTUNE)
147
145
 
148
146
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
149
- if self._use_cached_data:
147
+ if self._params.use_cached_data:
150
148
  dataset = dataset.take(1).cache().repeat()
151
149
 
152
150
  return dataset
@@ -39,6 +39,8 @@ def _get_tpu_embedding_feature_config(
39
39
  max_ids_per_chip_per_sample: Optional[int] = None,
40
40
  max_ids_per_table: Optional[Union[int, List[int]]] = None,
41
41
  max_unique_ids_per_table: Optional[Union[int, List[int]]] = None,
42
+ allow_id_dropping: bool = False,
43
+ initialize_tables_on_host: bool = False,
42
44
  ) -> Tuple[
43
45
  Dict[str, tf.tpu.experimental.embedding.FeatureConfig],
44
46
  Optional[tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig],
@@ -57,6 +59,10 @@ def _get_tpu_embedding_feature_config(
57
59
  sample.
58
60
  max_ids_per_table: Maximum number of embedding ids per table.
59
61
  max_unique_ids_per_table: Maximum number of unique embedding ids per table.
62
+ allow_id_dropping: bool to allow id dropping.
63
+ initialize_tables_on_host: bool : if the embedding table size is more than
64
+ what HBM can handle, this flag will help initialize the full embedding
65
+ tables on host and then copy shards to HBM.
60
66
 
61
67
  Returns:
62
68
  A dictionary of feature_name, FeatureConfig pairs.
@@ -140,7 +146,8 @@ def _get_tpu_embedding_feature_config(
140
146
  max_ids_per_chip_per_sample=max_ids_per_chip_per_sample,
141
147
  max_ids_per_table=max_ids_per_table_dict,
142
148
  max_unique_ids_per_table=max_unique_ids_per_table_dict,
143
- allow_id_dropping=False,
149
+ allow_id_dropping=allow_id_dropping,
150
+ initialize_tables_on_host=initialize_tables_on_host,
144
151
  )
145
152
 
146
153
  return feature_config, sparsecore_config
@@ -248,6 +255,8 @@ class RankingTask(base_task.Task):
248
255
  max_ids_per_chip_per_sample=self.task_config.model.max_ids_per_chip_per_sample,
249
256
  max_ids_per_table=self.task_config.model.max_ids_per_table,
250
257
  max_unique_ids_per_table=self.task_config.model.max_unique_ids_per_table,
258
+ allow_id_dropping=self.task_config.model.allow_id_dropping,
259
+ initialize_tables_on_host=self.task_config.model.initialize_tables_on_host,
251
260
  )
252
261
  )
253
262
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.18.0.dev20240815
3
+ Version: 2.18.0.dev20240817
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -877,16 +877,16 @@ official/recommendation/popen_helper.py,sha256=TMWMwsW1DF15YCJ0RG9pE3wsL6njLu8Ed
877
877
  official/recommendation/stat_utils.py,sha256=BjWRO2jzmAJyqeRUw0hMHhQqlyDNrdNvQJKAmbDJ4Rc,3076
878
878
  official/recommendation/ranking/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
879
879
  official/recommendation/ranking/common.py,sha256=fUpBln1auJWwvX5BJIYJYxRQDzXPqCzbS6NYAcn0KaQ,3998
880
- official/recommendation/ranking/task.py,sha256=NLoU4ZrCzrxfqg3h-xk_S8kVN1ZedZ-PKjdULmwusQU,13908
880
+ official/recommendation/ranking/task.py,sha256=Xw4VjbrdONt2pNZf_1e35A2swb5_fcYrqZqmEZ1mmAg,14480
881
881
  official/recommendation/ranking/task_test.py,sha256=vPN_5oq1tWF3r7GuRTuAppKSF_mP3qoGFzGRoX09ylw,2891
882
882
  official/recommendation/ranking/train.py,sha256=_7zC2SVsOOOBN--If0XlWw4gtdaEeTT04PbCQRSzieo,6604
883
883
  official/recommendation/ranking/train_test.py,sha256=n96JN-gI5wj4cbZNJCwQ9XLViKt0ZAg9f9dtaTpAdV8,8756
884
884
  official/recommendation/ranking/configs/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
885
- official/recommendation/ranking/configs/config.py,sha256=_VICZEsRaLS4gbKeVv6T4RDnktqU9ughgx0s-GfgMgA,14876
885
+ official/recommendation/ranking/configs/config.py,sha256=M0DgZvJ8CNtl6wUMutepQbn9joBtQtNQ-iEggbUxgec,14952
886
886
  official/recommendation/ranking/configs/config_test.py,sha256=4_W0YUwJ2q0D8HdPdIvpJcAsaySjb_pDMWj26G7T1Ws,1474
887
887
  official/recommendation/ranking/data/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
888
888
  official/recommendation/ranking/data/data_pipeline.py,sha256=jtNyeAEZQFcFJaatTJhxYPcokN56OtETbF2fFqzWk8k,7578
889
- official/recommendation/ranking/data/data_pipeline_multi_hot.py,sha256=b8JGO0UyFOEjWn43NRXhEAs2wVBzEJuewwnwOGqTluM,12280
889
+ official/recommendation/ranking/data/data_pipeline_multi_hot.py,sha256=ghhxQh-L5OQU6cn0mfvo2S73Ppi50KfbU90Q2bn9dwo,12197
890
890
  official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256=-NKJ4-0wwSjveCY7sMP9lChDkN6lWH5VahYMHUa9z2A,2949
891
891
  official/recommendation/ranking/data/data_pipeline_test.py,sha256=34hKGxUBRjxyh9d2c7_eBhISpXI5F5PfgPfqBQoHM3Q,2573
892
892
  official/recommendation/uplift/__init__.py,sha256=_jZilTPWKu-MfMaz1IgBjEW6wqkK3FNZ1QAP4a8my3I,990
@@ -1212,9 +1212,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
1212
1212
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1213
1213
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1214
1214
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1215
- tf_models_nightly-2.18.0.dev20240815.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1216
- tf_models_nightly-2.18.0.dev20240815.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1217
- tf_models_nightly-2.18.0.dev20240815.dist-info/METADATA,sha256=hckZ4iIvoDOYPyOgzMUC1oi0oRJgvU560js0o6GDilk,1432
1218
- tf_models_nightly-2.18.0.dev20240815.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1219
- tf_models_nightly-2.18.0.dev20240815.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1220
- tf_models_nightly-2.18.0.dev20240815.dist-info/RECORD,,
1215
+ tf_models_nightly-2.18.0.dev20240817.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1216
+ tf_models_nightly-2.18.0.dev20240817.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1217
+ tf_models_nightly-2.18.0.dev20240817.dist-info/METADATA,sha256=TmF7lI61RDYhY0CsJ0KWZ-dTsBFtw7lJkFqqCO6S8PU,1432
1218
+ tf_models_nightly-2.18.0.dev20240817.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1219
+ tf_models_nightly-2.18.0.dev20240817.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1220
+ tf_models_nightly-2.18.0.dev20240817.dist-info/RECORD,,