tf-models-nightly 2.17.0.dev20240711__py2.py3-none-any.whl → 2.17.0.dev20240713__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -74,6 +74,7 @@ class DataConfig(hyperparams.Config):
74
74
  cycle_length: int = 10
75
75
  sharding: bool = True
76
76
  num_shards_per_host: int = 8
77
+ use_cached_data: bool = False
77
78
 
78
79
 
79
80
  @dataclasses.dataclass
@@ -38,12 +38,14 @@ class CriteoTsvReader:
38
38
  params: config.DataConfig,
39
39
  num_dense_features: int,
40
40
  vocab_sizes: List[int],
41
- use_synthetic_data: bool = False):
41
+ use_synthetic_data: bool = False,
42
+ use_cached_data: bool = False):
42
43
  self._file_pattern = file_pattern
43
44
  self._params = params
44
45
  self._num_dense_features = num_dense_features
45
46
  self._vocab_sizes = vocab_sizes
46
47
  self._use_synthetic_data = use_synthetic_data
48
+ self._use_cached_data = use_cached_data
47
49
 
48
50
  def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
49
51
  params = self._params
@@ -117,6 +119,8 @@ class CriteoTsvReader:
117
119
  num_parallel_calls=tf.data.experimental.AUTOTUNE)
118
120
 
119
121
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
122
+ if self._use_cached_data:
123
+ dataset = dataset.take(1).cache().repeat()
120
124
 
121
125
  return dataset
122
126
 
@@ -173,6 +177,9 @@ class CriteoTsvReader:
173
177
  if params.is_training:
174
178
  dataset = dataset.repeat()
175
179
 
180
+ if self._use_cached_data:
181
+ dataset = dataset.take(1).cache().repeat()
182
+
176
183
  return dataset.batch(batch_size, drop_remainder=True)
177
184
 
178
185
 
@@ -45,13 +45,15 @@ class CriteoTsvReaderMultiHot:
45
45
  num_dense_features: int,
46
46
  vocab_sizes: List[int],
47
47
  multi_hot_sizes: List[int],
48
- use_synthetic_data: bool = False):
48
+ use_synthetic_data: bool = False,
49
+ use_cached_data: bool = False):
49
50
  self._file_pattern = file_pattern
50
51
  self._params = params
51
52
  self._num_dense_features = num_dense_features
52
53
  self._vocab_sizes = vocab_sizes
53
54
  self._use_synthetic_data = use_synthetic_data
54
55
  self._multi_hot_sizes = multi_hot_sizes
56
+ self._use_cached_data = use_cached_data
55
57
 
56
58
  def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
57
59
  params = self._params
@@ -144,6 +146,8 @@ class CriteoTsvReaderMultiHot:
144
146
  num_parallel_calls=tf.data.experimental.AUTOTUNE)
145
147
 
146
148
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
149
+ if self._use_cached_data:
150
+ dataset = dataset.take(1).cache().repeat()
147
151
 
148
152
  return dataset
149
153
 
@@ -215,12 +219,14 @@ class CriteoTFRecordReader(object):
215
219
  params: config.DataConfig,
216
220
  num_dense_features: int,
217
221
  vocab_sizes: List[int],
218
- multi_hot_sizes: List[int],):
222
+ multi_hot_sizes: List[int],
223
+ use_cached_data: bool = False):
219
224
  self._file_pattern = file_pattern
220
225
  self._params = params
221
226
  self._num_dense_features = num_dense_features
222
227
  self._vocab_sizes = vocab_sizes
223
228
  self._multi_hot_sizes = multi_hot_sizes
229
+ self._use_cached_data = use_cached_data
224
230
 
225
231
  self.label_features = 'label'
226
232
  self.dense_features = ['dense-feature-%d' % x for x in range(1, 14)]
@@ -307,6 +313,8 @@ class CriteoTFRecordReader(object):
307
313
  num_parallel_calls=tf.data.experimental.AUTOTUNE,
308
314
  )
309
315
  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
316
+ if self._use_cached_data:
317
+ dataset = dataset.take(1).cache().repeat()
310
318
 
311
319
  return dataset
312
320
 
@@ -23,9 +23,13 @@ from official.recommendation.ranking.data import data_pipeline_multi_hot
23
23
 
24
24
  class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
25
25
 
26
- @parameterized.named_parameters(('Train', True),
27
- ('Eval', False))
28
- def testSyntheticDataPipeline(self, is_training):
26
+ @parameterized.named_parameters(
27
+ ('TrainCached', True, True),
28
+ ('EvalNotCached', False, False),
29
+ ('TrainNotCached', True, False),
30
+ ('EvalCached', False, True),
31
+ )
32
+ def testSyntheticDataPipeline(self, is_training, use_cached_data):
29
33
  task = config.Task(
30
34
  model=config.ModelConfig(
31
35
  embedding_dim=4,
@@ -39,8 +43,10 @@ class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
39
43
  dcn_low_rank_dim=64,
40
44
  bottom_mlp=[64, 32, 4],
41
45
  top_mlp=[64, 32, 1]),
42
- train_data=config.DataConfig(global_batch_size=16),
43
- validation_data=config.DataConfig(global_batch_size=16),
46
+ train_data=config.DataConfig(global_batch_size=16,
47
+ use_cached_data=use_cached_data),
48
+ validation_data=config.DataConfig(global_batch_size=16,
49
+ use_cached_data=use_cached_data),
44
50
  use_synthetic_data=True)
45
51
 
46
52
  num_dense_features = task.model.num_dense_features
@@ -23,19 +23,29 @@ from official.recommendation.ranking.data import data_pipeline
23
23
 
24
24
  class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
25
25
 
26
- @parameterized.named_parameters(('Train', True),
27
- ('Eval', False))
28
- def testSyntheticDataPipeline(self, is_training):
26
+ @parameterized.named_parameters(
27
+ ('TrainCached', True, True),
28
+ ('EvalNotCached', False, False),
29
+ ('TrainNotCached', True, False),
30
+ ('EvalCached', False, True),
31
+ )
32
+ def testSyntheticDataPipeline(self, is_training, use_cached_data):
29
33
  task = config.Task(
30
34
  model=config.ModelConfig(
31
35
  embedding_dim=4,
32
36
  num_dense_features=8,
33
37
  vocab_sizes=[40, 12, 11, 13, 2, 5],
34
38
  bottom_mlp=[64, 32, 4],
35
- top_mlp=[64, 32, 1]),
36
- train_data=config.DataConfig(global_batch_size=16),
37
- validation_data=config.DataConfig(global_batch_size=16),
38
- use_synthetic_data=True)
39
+ top_mlp=[64, 32, 1],
40
+ ),
41
+ train_data=config.DataConfig(
42
+ global_batch_size=16, use_cached_data=use_cached_data
43
+ ),
44
+ validation_data=config.DataConfig(
45
+ global_batch_size=16, use_cached_data=use_cached_data
46
+ ),
47
+ use_synthetic_data=True,
48
+ )
39
49
 
40
50
  num_dense_features = task.model.num_dense_features
41
51
  num_sparse_features = len(task.model.vocab_sizes)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.17.0.dev20240711
3
+ Version: 2.17.0.dev20240713
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -882,13 +882,13 @@ official/recommendation/ranking/task_test.py,sha256=vPN_5oq1tWF3r7GuRTuAppKSF_mP
882
882
  official/recommendation/ranking/train.py,sha256=_7zC2SVsOOOBN--If0XlWw4gtdaEeTT04PbCQRSzieo,6604
883
883
  official/recommendation/ranking/train_test.py,sha256=n96JN-gI5wj4cbZNJCwQ9XLViKt0ZAg9f9dtaTpAdV8,8756
884
884
  official/recommendation/ranking/configs/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
885
- official/recommendation/ranking/configs/config.py,sha256=03KupVFJ9QeslF81Ec7n4VyPn-bVvyZaEd10GQWomVA,14844
885
+ official/recommendation/ranking/configs/config.py,sha256=_VICZEsRaLS4gbKeVv6T4RDnktqU9ughgx0s-GfgMgA,14876
886
886
  official/recommendation/ranking/configs/config_test.py,sha256=4_W0YUwJ2q0D8HdPdIvpJcAsaySjb_pDMWj26G7T1Ws,1474
887
887
  official/recommendation/ranking/data/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
888
- official/recommendation/ranking/data/data_pipeline.py,sha256=5GSF1naOl6GdqOVin7_Yx_hprmzaq0WAImHqDxysAoI,7329
889
- official/recommendation/ranking/data/data_pipeline_multi_hot.py,sha256=BOTItacIrH7f6ALYVx3vHkl_2ad4jKNpyQ4wFma7Hmc,11943
890
- official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256=arLjX742h_J6p2qqVH0vX86kl8WABWGIrU3hne5g-Dw,2702
891
- official/recommendation/ranking/data/data_pipeline_test.py,sha256=VRYo7WqURRkM3lbmfctvSZxyH1EzUfqwxR3sy2sZxdc,2345
888
+ official/recommendation/ranking/data/data_pipeline.py,sha256=jtNyeAEZQFcFJaatTJhxYPcokN56OtETbF2fFqzWk8k,7578
889
+ official/recommendation/ranking/data/data_pipeline_multi_hot.py,sha256=b8JGO0UyFOEjWn43NRXhEAs2wVBzEJuewwnwOGqTluM,12280
890
+ official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256=-NKJ4-0wwSjveCY7sMP9lChDkN6lWH5VahYMHUa9z2A,2949
891
+ official/recommendation/ranking/data/data_pipeline_test.py,sha256=34hKGxUBRjxyh9d2c7_eBhISpXI5F5PfgPfqBQoHM3Q,2573
892
892
  official/recommendation/uplift/__init__.py,sha256=_jZilTPWKu-MfMaz1IgBjEW6wqkK3FNZ1QAP4a8my3I,990
893
893
  official/recommendation/uplift/keras_test_case.py,sha256=gF5Z2FzXlKAvhuJDdj7PmFj3jsW_ZmUAv_F9Xokvs2M,6156
894
894
  official/recommendation/uplift/keys.py,sha256=7zkxkPIcXceIN5hWm4ATai4h8ymwCj85dU2r8-3XifM,1032
@@ -1212,9 +1212,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
1212
1212
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1213
1213
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1214
1214
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1215
- tf_models_nightly-2.17.0.dev20240711.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1216
- tf_models_nightly-2.17.0.dev20240711.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1217
- tf_models_nightly-2.17.0.dev20240711.dist-info/METADATA,sha256=YbWFC7bTtfE7KBzkLclvMxWAou4SWeJXau_SEYFS2u0,1432
1218
- tf_models_nightly-2.17.0.dev20240711.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1219
- tf_models_nightly-2.17.0.dev20240711.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1220
- tf_models_nightly-2.17.0.dev20240711.dist-info/RECORD,,
1215
+ tf_models_nightly-2.17.0.dev20240713.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1216
+ tf_models_nightly-2.17.0.dev20240713.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1217
+ tf_models_nightly-2.17.0.dev20240713.dist-info/METADATA,sha256=HKKg98Cbp65vo1xox9QmOHryuDX73MHbqfcVUbKW7Qs,1432
1218
+ tf_models_nightly-2.17.0.dev20240713.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1219
+ tf_models_nightly-2.17.0.dev20240713.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1220
+ tf_models_nightly-2.17.0.dev20240713.dist-info/RECORD,,