tf-models-nightly 2.17.0.dev20240712__py2.py3-none-any.whl → 2.17.0.dev20240714__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/recommendation/ranking/configs/config.py +1 -0
- official/recommendation/ranking/data/data_pipeline.py +8 -1
- official/recommendation/ranking/data/data_pipeline_multi_hot.py +10 -2
- official/recommendation/ranking/data/data_pipeline_multi_hot_test.py +11 -5
- official/recommendation/ranking/data/data_pipeline_test.py +17 -7
- {tf_models_nightly-2.17.0.dev20240712.dist-info → tf_models_nightly-2.17.0.dev20240714.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.17.0.dev20240712.dist-info → tf_models_nightly-2.17.0.dev20240714.dist-info}/RECORD +11 -11
- {tf_models_nightly-2.17.0.dev20240712.dist-info → tf_models_nightly-2.17.0.dev20240714.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.17.0.dev20240712.dist-info → tf_models_nightly-2.17.0.dev20240714.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.17.0.dev20240712.dist-info → tf_models_nightly-2.17.0.dev20240714.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.17.0.dev20240712.dist-info → tf_models_nightly-2.17.0.dev20240714.dist-info}/top_level.txt +0 -0
@@ -38,12 +38,14 @@ class CriteoTsvReader:
|
|
38
38
|
params: config.DataConfig,
|
39
39
|
num_dense_features: int,
|
40
40
|
vocab_sizes: List[int],
|
41
|
-
use_synthetic_data: bool = False):
|
41
|
+
use_synthetic_data: bool = False,
|
42
|
+
use_cached_data: bool = False):
|
42
43
|
self._file_pattern = file_pattern
|
43
44
|
self._params = params
|
44
45
|
self._num_dense_features = num_dense_features
|
45
46
|
self._vocab_sizes = vocab_sizes
|
46
47
|
self._use_synthetic_data = use_synthetic_data
|
48
|
+
self._use_cached_data = use_cached_data
|
47
49
|
|
48
50
|
def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
|
49
51
|
params = self._params
|
@@ -117,6 +119,8 @@ class CriteoTsvReader:
|
|
117
119
|
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
118
120
|
|
119
121
|
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
|
122
|
+
if self._use_cached_data:
|
123
|
+
dataset = dataset.take(1).cache().repeat()
|
120
124
|
|
121
125
|
return dataset
|
122
126
|
|
@@ -173,6 +177,9 @@ class CriteoTsvReader:
|
|
173
177
|
if params.is_training:
|
174
178
|
dataset = dataset.repeat()
|
175
179
|
|
180
|
+
if self._use_cached_data:
|
181
|
+
dataset = dataset.take(1).cache().repeat()
|
182
|
+
|
176
183
|
return dataset.batch(batch_size, drop_remainder=True)
|
177
184
|
|
178
185
|
|
@@ -45,13 +45,15 @@ class CriteoTsvReaderMultiHot:
|
|
45
45
|
num_dense_features: int,
|
46
46
|
vocab_sizes: List[int],
|
47
47
|
multi_hot_sizes: List[int],
|
48
|
-
use_synthetic_data: bool = False):
|
48
|
+
use_synthetic_data: bool = False,
|
49
|
+
use_cached_data: bool = False):
|
49
50
|
self._file_pattern = file_pattern
|
50
51
|
self._params = params
|
51
52
|
self._num_dense_features = num_dense_features
|
52
53
|
self._vocab_sizes = vocab_sizes
|
53
54
|
self._use_synthetic_data = use_synthetic_data
|
54
55
|
self._multi_hot_sizes = multi_hot_sizes
|
56
|
+
self._use_cached_data = use_cached_data
|
55
57
|
|
56
58
|
def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
|
57
59
|
params = self._params
|
@@ -144,6 +146,8 @@ class CriteoTsvReaderMultiHot:
|
|
144
146
|
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
145
147
|
|
146
148
|
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
|
149
|
+
if self._use_cached_data:
|
150
|
+
dataset = dataset.take(1).cache().repeat()
|
147
151
|
|
148
152
|
return dataset
|
149
153
|
|
@@ -215,12 +219,14 @@ class CriteoTFRecordReader(object):
|
|
215
219
|
params: config.DataConfig,
|
216
220
|
num_dense_features: int,
|
217
221
|
vocab_sizes: List[int],
|
218
|
-
multi_hot_sizes: List[int],):
|
222
|
+
multi_hot_sizes: List[int],
|
223
|
+
use_cached_data: bool = False):
|
219
224
|
self._file_pattern = file_pattern
|
220
225
|
self._params = params
|
221
226
|
self._num_dense_features = num_dense_features
|
222
227
|
self._vocab_sizes = vocab_sizes
|
223
228
|
self._multi_hot_sizes = multi_hot_sizes
|
229
|
+
self._use_cached_data = use_cached_data
|
224
230
|
|
225
231
|
self.label_features = 'label'
|
226
232
|
self.dense_features = ['dense-feature-%d' % x for x in range(1, 14)]
|
@@ -307,6 +313,8 @@ class CriteoTFRecordReader(object):
|
|
307
313
|
num_parallel_calls=tf.data.experimental.AUTOTUNE,
|
308
314
|
)
|
309
315
|
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
|
316
|
+
if self._use_cached_data:
|
317
|
+
dataset = dataset.take(1).cache().repeat()
|
310
318
|
|
311
319
|
return dataset
|
312
320
|
|
@@ -23,9 +23,13 @@ from official.recommendation.ranking.data import data_pipeline_multi_hot
|
|
23
23
|
|
24
24
|
class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
|
25
25
|
|
26
|
-
@parameterized.named_parameters(
|
27
|
-
|
28
|
-
|
26
|
+
@parameterized.named_parameters(
|
27
|
+
('TrainCached', True, True),
|
28
|
+
('EvalNotCached', False, False),
|
29
|
+
('TrainNotCached', True, False),
|
30
|
+
('EvalCached', False, True),
|
31
|
+
)
|
32
|
+
def testSyntheticDataPipeline(self, is_training, use_cached_data):
|
29
33
|
task = config.Task(
|
30
34
|
model=config.ModelConfig(
|
31
35
|
embedding_dim=4,
|
@@ -39,8 +43,10 @@ class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
|
|
39
43
|
dcn_low_rank_dim=64,
|
40
44
|
bottom_mlp=[64, 32, 4],
|
41
45
|
top_mlp=[64, 32, 1]),
|
42
|
-
train_data=config.DataConfig(global_batch_size=16),
|
43
|
-
|
46
|
+
train_data=config.DataConfig(global_batch_size=16,
|
47
|
+
use_cached_data=use_cached_data),
|
48
|
+
validation_data=config.DataConfig(global_batch_size=16,
|
49
|
+
use_cached_data=use_cached_data),
|
44
50
|
use_synthetic_data=True)
|
45
51
|
|
46
52
|
num_dense_features = task.model.num_dense_features
|
@@ -23,19 +23,29 @@ from official.recommendation.ranking.data import data_pipeline
|
|
23
23
|
|
24
24
|
class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
|
25
25
|
|
26
|
-
@parameterized.named_parameters(
|
27
|
-
|
28
|
-
|
26
|
+
@parameterized.named_parameters(
|
27
|
+
('TrainCached', True, True),
|
28
|
+
('EvalNotCached', False, False),
|
29
|
+
('TrainNotCached', True, False),
|
30
|
+
('EvalCached', False, True),
|
31
|
+
)
|
32
|
+
def testSyntheticDataPipeline(self, is_training, use_cached_data):
|
29
33
|
task = config.Task(
|
30
34
|
model=config.ModelConfig(
|
31
35
|
embedding_dim=4,
|
32
36
|
num_dense_features=8,
|
33
37
|
vocab_sizes=[40, 12, 11, 13, 2, 5],
|
34
38
|
bottom_mlp=[64, 32, 4],
|
35
|
-
top_mlp=[64, 32, 1]
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
+
top_mlp=[64, 32, 1],
|
40
|
+
),
|
41
|
+
train_data=config.DataConfig(
|
42
|
+
global_batch_size=16, use_cached_data=use_cached_data
|
43
|
+
),
|
44
|
+
validation_data=config.DataConfig(
|
45
|
+
global_batch_size=16, use_cached_data=use_cached_data
|
46
|
+
),
|
47
|
+
use_synthetic_data=True,
|
48
|
+
)
|
39
49
|
|
40
50
|
num_dense_features = task.model.num_dense_features
|
41
51
|
num_sparse_features = len(task.model.vocab_sizes)
|
@@ -882,13 +882,13 @@ official/recommendation/ranking/task_test.py,sha256=vPN_5oq1tWF3r7GuRTuAppKSF_mP
|
|
882
882
|
official/recommendation/ranking/train.py,sha256=_7zC2SVsOOOBN--If0XlWw4gtdaEeTT04PbCQRSzieo,6604
|
883
883
|
official/recommendation/ranking/train_test.py,sha256=n96JN-gI5wj4cbZNJCwQ9XLViKt0ZAg9f9dtaTpAdV8,8756
|
884
884
|
official/recommendation/ranking/configs/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
|
885
|
-
official/recommendation/ranking/configs/config.py,sha256=
|
885
|
+
official/recommendation/ranking/configs/config.py,sha256=_VICZEsRaLS4gbKeVv6T4RDnktqU9ughgx0s-GfgMgA,14876
|
886
886
|
official/recommendation/ranking/configs/config_test.py,sha256=4_W0YUwJ2q0D8HdPdIvpJcAsaySjb_pDMWj26G7T1Ws,1474
|
887
887
|
official/recommendation/ranking/data/__init__.py,sha256=7oiypy0N82PDw9aSdcJBLVoGTd_oRSUOdvuJhMv4leQ,609
|
888
|
-
official/recommendation/ranking/data/data_pipeline.py,sha256=
|
889
|
-
official/recommendation/ranking/data/data_pipeline_multi_hot.py,sha256=
|
890
|
-
official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256
|
891
|
-
official/recommendation/ranking/data/data_pipeline_test.py,sha256=
|
888
|
+
official/recommendation/ranking/data/data_pipeline.py,sha256=jtNyeAEZQFcFJaatTJhxYPcokN56OtETbF2fFqzWk8k,7578
|
889
|
+
official/recommendation/ranking/data/data_pipeline_multi_hot.py,sha256=b8JGO0UyFOEjWn43NRXhEAs2wVBzEJuewwnwOGqTluM,12280
|
890
|
+
official/recommendation/ranking/data/data_pipeline_multi_hot_test.py,sha256=-NKJ4-0wwSjveCY7sMP9lChDkN6lWH5VahYMHUa9z2A,2949
|
891
|
+
official/recommendation/ranking/data/data_pipeline_test.py,sha256=34hKGxUBRjxyh9d2c7_eBhISpXI5F5PfgPfqBQoHM3Q,2573
|
892
892
|
official/recommendation/uplift/__init__.py,sha256=_jZilTPWKu-MfMaz1IgBjEW6wqkK3FNZ1QAP4a8my3I,990
|
893
893
|
official/recommendation/uplift/keras_test_case.py,sha256=gF5Z2FzXlKAvhuJDdj7PmFj3jsW_ZmUAv_F9Xokvs2M,6156
|
894
894
|
official/recommendation/uplift/keys.py,sha256=7zkxkPIcXceIN5hWm4ATai4h8ymwCj85dU2r8-3XifM,1032
|
@@ -1212,9 +1212,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
|
|
1212
1212
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1213
1213
|
tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
|
1214
1214
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1215
|
-
tf_models_nightly-2.17.0.
|
1216
|
-
tf_models_nightly-2.17.0.
|
1217
|
-
tf_models_nightly-2.17.0.
|
1218
|
-
tf_models_nightly-2.17.0.
|
1219
|
-
tf_models_nightly-2.17.0.
|
1220
|
-
tf_models_nightly-2.17.0.
|
1215
|
+
tf_models_nightly-2.17.0.dev20240714.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1216
|
+
tf_models_nightly-2.17.0.dev20240714.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1217
|
+
tf_models_nightly-2.17.0.dev20240714.dist-info/METADATA,sha256=YTQD7k71nGpFpJBjeSrX1NchtMazonQ6j-I5fzO3dcI,1432
|
1218
|
+
tf_models_nightly-2.17.0.dev20240714.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1219
|
+
tf_models_nightly-2.17.0.dev20240714.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1220
|
+
tf_models_nightly-2.17.0.dev20240714.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|