tf-models-nightly 2.19.0.dev20250224__py2.py3-none-any.whl → 2.19.0.dev20250226__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/modeling/multitask/train_lib.py +45 -3
- {tf_models_nightly-2.19.0.dev20250224.dist-info → tf_models_nightly-2.19.0.dev20250226.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.19.0.dev20250224.dist-info → tf_models_nightly-2.19.0.dev20250226.dist-info}/RECORD +7 -7
- {tf_models_nightly-2.19.0.dev20250224.dist-info → tf_models_nightly-2.19.0.dev20250226.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.19.0.dev20250224.dist-info → tf_models_nightly-2.19.0.dev20250226.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.19.0.dev20250224.dist-info → tf_models_nightly-2.19.0.dev20250226.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.19.0.dev20250224.dist-info → tf_models_nightly-2.19.0.dev20250226.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Multitask training driver library."""
|
16
16
|
# pytype: disable=attribute-error
|
17
17
|
import os
|
18
|
-
from typing import Any, List, Mapping, Optional, Tuple, Union
|
18
|
+
from typing import Any, List, Mapping, Optional, Tuple, Union, Callable
|
19
19
|
from absl import logging
|
20
20
|
import orbit
|
21
21
|
import tensorflow as tf, tf_keras
|
@@ -157,6 +157,25 @@ def run_experiment(
|
|
157
157
|
return model
|
158
158
|
|
159
159
|
|
160
|
+
TrainActionsFactoryType = Callable[
|
161
|
+
[
|
162
|
+
configs.MultiEvalExperimentConfig,
|
163
|
+
orbit.StandardTrainer,
|
164
|
+
str,
|
165
|
+
tf.train.CheckpointManager,
|
166
|
+
],
|
167
|
+
List[orbit.Action],
|
168
|
+
]
|
169
|
+
EvalActionsFactoryType = Callable[
|
170
|
+
[
|
171
|
+
configs.MultiEvalExperimentConfig,
|
172
|
+
orbit.AbstractEvaluator,
|
173
|
+
str,
|
174
|
+
],
|
175
|
+
List[orbit.Action],
|
176
|
+
]
|
177
|
+
|
178
|
+
|
160
179
|
def run_experiment_with_multitask_eval(
|
161
180
|
*,
|
162
181
|
distribution_strategy: tf.distribute.Strategy,
|
@@ -171,6 +190,8 @@ def run_experiment_with_multitask_eval(
|
|
171
190
|
eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
|
172
191
|
best_ckpt_exporter_creator: Optional[Any] = train_utils
|
173
192
|
.maybe_create_best_ckpt_exporter,
|
193
|
+
train_actions_factory: Optional[TrainActionsFactoryType] = None,
|
194
|
+
eval_actions_factory: Optional[EvalActionsFactoryType] = None,
|
174
195
|
) -> Tuple[Any, Any]:
|
175
196
|
"""Runs train/eval configured by the experiment params.
|
176
197
|
|
@@ -193,6 +214,8 @@ def run_experiment_with_multitask_eval(
|
|
193
214
|
will be created internally for TensorBoard summaries by default from the
|
194
215
|
`eval_summary_dir`.
|
195
216
|
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
|
217
|
+
train_actions_factory: Optional factory function to create train actions.
|
218
|
+
eval_actions_factory: Optional factory function to create eval actions.
|
196
219
|
|
197
220
|
Returns:
|
198
221
|
model: `tf_keras.Model` instance.
|
@@ -214,7 +237,6 @@ def run_experiment_with_multitask_eval(
|
|
214
237
|
|
215
238
|
# Build the model or fetch the pre-cached one (which could be either
|
216
239
|
# multi-task model or single task model).
|
217
|
-
model = None
|
218
240
|
if trainer is None:
|
219
241
|
if isinstance(train_task, multitask.MultiTask):
|
220
242
|
model = train_task.build_multitask_model()
|
@@ -254,6 +276,23 @@ def run_experiment_with_multitask_eval(
|
|
254
276
|
checkpoint_interval=params.trainer.checkpoint_interval,
|
255
277
|
init_fn=trainer.initialize if trainer else None)
|
256
278
|
|
279
|
+
if trainer and train_actions_factory:
|
280
|
+
# pytype: disable=wrong-keyword-args
|
281
|
+
train_actions = train_actions_factory(
|
282
|
+
params=params,
|
283
|
+
trainer=trainer,
|
284
|
+
model_dir=model_dir,
|
285
|
+
checkpoint_manager=checkpoint_manager,
|
286
|
+
)
|
287
|
+
# pytype: enable=wrong-keyword-args
|
288
|
+
else:
|
289
|
+
train_actions = None
|
290
|
+
|
291
|
+
if evaluator and eval_actions_factory:
|
292
|
+
eval_actions = eval_actions_factory(params, evaluator, model_dir)
|
293
|
+
else:
|
294
|
+
eval_actions = None
|
295
|
+
|
257
296
|
controller = orbit.Controller(
|
258
297
|
strategy=distribution_strategy,
|
259
298
|
trainer=trainer,
|
@@ -266,7 +305,10 @@ def run_experiment_with_multitask_eval(
|
|
266
305
|
(save_summary) else None,
|
267
306
|
eval_summary_manager=eval_summary_manager,
|
268
307
|
summary_interval=params.trainer.summary_interval if
|
269
|
-
(save_summary) else None
|
308
|
+
(save_summary) else None,
|
309
|
+
train_actions=train_actions,
|
310
|
+
eval_actions=eval_actions,
|
311
|
+
)
|
270
312
|
|
271
313
|
logging.info('Starts to execute mode: %s', mode)
|
272
314
|
with distribution_strategy.scope():
|
@@ -224,7 +224,7 @@ official/modeling/multitask/multitask.py,sha256=DV-ysfhPiIZgsrzZNylsPBxKNBf_xzPx
|
|
224
224
|
official/modeling/multitask/task_sampler.py,sha256=SGVVdjMb5oG4vnCczpfdgBtbsdsXiyBLl9si_0V6nko,4897
|
225
225
|
official/modeling/multitask/task_sampler_test.py,sha256=wkPTp1LCNx4uJbfHfVbKrQULQxumZS3ctBemEbnMokk,3037
|
226
226
|
official/modeling/multitask/test_utils.py,sha256=fPi_TxtzHy_NGTYmdDAQay9TRjwG70UWHwysDqol4jw,4315
|
227
|
-
official/modeling/multitask/train_lib.py,sha256=
|
227
|
+
official/modeling/multitask/train_lib.py,sha256=ZHRN6afigqcupAz61P6qBEh_RhomAwVgvjgkRHSv0eU,12827
|
228
228
|
official/modeling/multitask/train_lib_test.py,sha256=F_uDEQhSlXq9BH1QBlamUMQ_vRUflli7ef2P9MUKXeg,4766
|
229
229
|
official/modeling/optimization/__init__.py,sha256=BI86b89P0xOksBlzOmbAdWOhlJQe6cxtGIRN_zqCwa8,1201
|
230
230
|
official/modeling/optimization/adafactor_optimizer.py,sha256=zEHHFg9iH1UwcBSzzYE8bCNWln7QIIftXJEXTGTtCtE,792
|
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
|
|
1248
1248
|
tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
|
1249
1249
|
tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
|
1250
1250
|
tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
|
1251
|
-
tf_models_nightly-2.19.0.
|
1252
|
-
tf_models_nightly-2.19.0.
|
1253
|
-
tf_models_nightly-2.19.0.
|
1254
|
-
tf_models_nightly-2.19.0.
|
1255
|
-
tf_models_nightly-2.19.0.
|
1256
|
-
tf_models_nightly-2.19.0.
|
1251
|
+
tf_models_nightly-2.19.0.dev20250226.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
1252
|
+
tf_models_nightly-2.19.0.dev20250226.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
1253
|
+
tf_models_nightly-2.19.0.dev20250226.dist-info/METADATA,sha256=pXPpfanhACUR5sKJVyUH2c8SY0vi-8tcXz-JPP71KpQ,1432
|
1254
|
+
tf_models_nightly-2.19.0.dev20250226.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
1255
|
+
tf_models_nightly-2.19.0.dev20250226.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
1256
|
+
tf_models_nightly-2.19.0.dev20250226.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|