rapidfireai 0.10.3rc1__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidfireai might be problematic.
- rapidfireai/automl/grid_search.py +4 -5
- rapidfireai/automl/model_config.py +41 -37
- rapidfireai/automl/random_search.py +21 -33
- rapidfireai/backend/controller.py +54 -148
- rapidfireai/backend/worker.py +14 -3
- rapidfireai/cli.py +148 -136
- rapidfireai/experiment.py +22 -11
- rapidfireai/frontend/build/asset-manifest.json +3 -3
- rapidfireai/frontend/build/index.html +1 -1
- rapidfireai/frontend/build/static/js/{main.e7d3b759.js → main.58393d31.js} +3 -3
- rapidfireai/frontend/build/static/js/{main.e7d3b759.js.map → main.58393d31.js.map} +1 -1
- rapidfireai/ml/callbacks.py +10 -24
- rapidfireai/ml/trainer.py +37 -81
- rapidfireai/utils/constants.py +3 -1
- rapidfireai/utils/interactive_controller.py +40 -61
- rapidfireai/utils/logging.py +1 -2
- rapidfireai/utils/mlflow_manager.py +1 -0
- rapidfireai/utils/ping.py +4 -2
- rapidfireai/version.py +2 -2
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/METADATA +1 -1
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/RECORD +26 -26
- /rapidfireai/frontend/build/static/js/{main.e7d3b759.js.LICENSE.txt → main.58393d31.js.LICENSE.txt} +0 -0
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/WHEEL +0 -0
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/entry_points.txt +0 -0
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/top_level.txt +0 -0
@@ -70,15 +70,11 @@ class Controller:
         self.logger.debug(f"Found {self.num_workers} workers/GPUs.")

         # initialize shared manager and registry, create shared memory manager instance
-        self.shm_manager: SharedMemoryManager = SharedMemoryManager(
-            name="controller-shm"
-        )
+        self.shm_manager: SharedMemoryManager = SharedMemoryManager(name="controller-shm")
         registry, process_lock = self.shm_manager.get_shm_objects()

         # create worker manager
-        self.worker_manager: WorkerManager = WorkerManager(
-            self.num_workers, registry, process_lock
-        )
+        self.worker_manager: WorkerManager = WorkerManager(self.num_workers, registry, process_lock)

         # create metric logger
         # Initialize DataPath temporarily to get experiment path for tensorboard logs
@@ -92,7 +88,7 @@ class Controller:
             tensorboard_log_dir=tensorboard_log_dir,
         )
         # Get experiment if using MLflow
-        if hasattr(self.metric_logger,
+        if hasattr(self.metric_logger, "get_experiment"):
             self.metric_logger.get_experiment(self.experiment_name)

         self.logger.debug("Controller initialized")
@@ -116,14 +112,12 @@ class Controller:
         for config_leaf in config_leafs:
             flattened_config = get_flattened_config_leaf(config_leaf)
             # print("flattened_config: ",flattened_config)
-            total_steps = self._get_total_step(
-                config_leaf, len_train_dataset, num_chunks
-            )
+            total_steps = self._get_total_step(config_leaf, len_train_dataset, num_chunks)

             # get clone modify info
             warm_started_from = clone_modify_info.get("warm_started_from") if clone_modify_info else None
             cloned_from = clone_modify_info.get("cloned_from") if clone_modify_info else None
-            chunk_offset = clone_modify_info.get("chunk_offset",0) if clone_modify_info else 0
+            chunk_offset = clone_modify_info.get("chunk_offset", 0) if clone_modify_info else 0

             run_id = self.db.create_run(
                 config_leaf=config_leaf,
@@ -143,22 +137,16 @@ class Controller:
             try:
                 base_run_path = DataPath.base_run_path(run_id)
                 work_dir_path = DataPath.work_dir_path(base_run_path)
-                initial_checkpoint_path = DataPath.initial_checkpoint_path(
-                    base_run_path
-                )
+                initial_checkpoint_path = DataPath.initial_checkpoint_path(base_run_path)
                 final_checkpoint_path = DataPath.final_checkpoint_path(base_run_path)
-                intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(
-                    base_run_path
-                )
+                intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path)

                 Path.mkdir(work_dir_path, parents=True, exist_ok=True)
                 Path.mkdir(initial_checkpoint_path, parents=True, exist_ok=True)
                 Path.mkdir(final_checkpoint_path, parents=True, exist_ok=True)
                 Path.mkdir(intermediate_checkpoint_path, parents=True, exist_ok=True)
             except (PermissionError, OSError) as e:
-                raise ControllerException(
-                    f"Failed to create required Run DataPath directories: {e}"
-                ) from e
+                raise ControllerException(f"Failed to create required Run DataPath directories: {e}") from e

             # create new tracking run
             try:
@@ -169,16 +157,10 @@ class Controller:
                 for key, value in flattened_config.items():
                     self.metric_logger.log_param(mlflow_run_id, key, value)
                 if warm_started_from:
-                    self.metric_logger.log_param(
-                        mlflow_run_id, "warm-start", str(warm_started_from)
-                    )
+                    self.metric_logger.log_param(mlflow_run_id, "warm-start", str(warm_started_from))
                 if cloned_from:
-                    self.metric_logger.log_param(
-
-                    )
-                self.logger.debug(
-                    f"Populated MLFlow with model config info for run {run_id}."
-                )
+                    self.metric_logger.log_param(mlflow_run_id, "parent-run", str(cloned_from))
+                self.logger.debug(f"Populated MLFlow with model config info for run {run_id}.")
             self.db.set_run_details(
                 run_id=run_id,
                 mlflow_run_id=mlflow_run_id,
@@ -192,9 +174,7 @@ class Controller:
                 self.logger.error(msg, exc_info=True)

         total_runs = len(runs)
-        self.logger.info(
-            f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}"
-        )
+        self.logger.info(f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}")
        self.logger.debug(f"Got {total_runs} runs for {source.value}.")

         # set experiment task to run_fit
@@ -208,24 +188,17 @@ class Controller:

         # check if there are any other runs with the same base model
         base_model_name = self.db.get_run(run_id)["config_leaf"]["model_name"]
-        relevant_runs = self.db.get_runs_by_status(
-            [RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED]
-        )
+        relevant_runs = self.db.get_runs_by_status([RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED])

         # get shared object types to delete - if no other runs are using it
         delete_shared_objects = True
         for r_run_id, r_run_details in relevant_runs.items():
-            if (
-                r_run_details["config_leaf"]["model_name"] == base_model_name
-                and r_run_id != run_id
-            ):
+            if r_run_details["config_leaf"]["model_name"] == base_model_name and r_run_id != run_id:
                 delete_shared_objects = False
                 break

         # delete model object from shared memory
-        self.shm_manager.delete_model_object(
-            run_id, base_model_name if delete_shared_objects else None
-        )
+        self.shm_manager.delete_model_object(run_id, base_model_name if delete_shared_objects else None)

     def _process_interactive_control(
         self,
@@ -250,9 +223,7 @@ class Controller:
                     status=RunStatus.STOPPED,
                     ended_by=RunEndedBy.INTERACTIVE_CONTROL,
                 )
-                self.db.set_ic_ops_task_status(
-                    run_state["task_id"], TaskStatus.COMPLETED
-                )
+                self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
                 self.ic_logger.info(f"Stopping run {run_id} by Interactive Control")
             elif run_state["status"] == RunStatus.DELETED:
                 # process deleted tasks
@@ -269,9 +240,7 @@ class Controller:
                     status=RunStatus.DELETED,
                     ended_by=RunEndedBy.INTERACTIVE_CONTROL,
                 )
-                self.db.set_ic_ops_task_status(
-                    run_state["task_id"], TaskStatus.COMPLETED
-                )
+                self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
                 self.ic_logger.info(f"Deleting run {run_id} by Interactive Control")
             elif run_state["status"] == RunStatus.ONGOING:
                 # process ongoing tasks
@@ -280,15 +249,11 @@ class Controller:
                     status=RunStatus.ONGOING,
                     ended_by="",
                 )
-                self.db.set_ic_ops_task_status(
-                    run_state["task_id"], TaskStatus.COMPLETED
-                )
+                self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
                 self.ic_logger.info(f"Resuming run {run_id} by Interactive Control")
             elif run_state["status"] == RunStatus.COMPLETED:
                 # process completed tasks
-                self.logger.warning(
-                    f"Run {run_id} is already completed. Skipping Interactive Control task."
-                )
+                self.logger.warning(f"Run {run_id} is already completed. Skipping Interactive Control task.")
                 self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.SKIPPED)
             else:
                 raise ValueError(f"Unsupported run status {run_state['status']}")
@@ -304,9 +269,7 @@ class Controller:
         # add additional_kwargs to config_leaf if it exists in the parent run
         parent_run_details = self.db.get_run(parent_run_id)
         if "additional_kwargs" in parent_run_details["config_leaf"]:
-            config_leaf["additional_kwargs"] = parent_run_details["config_leaf"][
-                "additional_kwargs"
-            ]
+            config_leaf["additional_kwargs"] = parent_run_details["config_leaf"]["additional_kwargs"]

         # create model for the new run
         try:
@@ -324,7 +287,9 @@ class Controller:
             )
         elif ic_op == ControllerTask.IC_CLONE_MODIFY_WARM:
             # calculate clone chunk offset
-            effective_batch_size = parent_run_details["config_leaf"]["training_args"].get(
+            effective_batch_size = parent_run_details["config_leaf"]["training_args"].get(
+                "per_device_train_batch_size", 1
+            ) * parent_run_details["config_leaf"]["training_args"].get("gradient_accumulation_steps", 1)
             chunker = DatasetChunks(
                 len_train_dataset,
                 num_chunks,
@@ -355,12 +320,8 @@ class Controller:
             )
         except Exception as e:
             self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
-            self.ic_logger.error(
-
-            )
-            raise ControllerException(
-                f"Error creating model for run {parent_run_id}: {e}"
-            ) from e
+            self.ic_logger.error(f"Error creating model for run {parent_run_id}: {e}")
+            raise ControllerException(f"Error creating model for run {parent_run_id}: {e}") from e

     def _process_interm_ic_ops_states(
         self,
@@ -389,11 +350,7 @@ class Controller:
             if is_clone_modify_task:
                 # clone_modify tasks
                 # get latest run state
-                run_status = (
-                    run_states[run_id]["status"]
-                    if run_id in run_states
-                    else self.db.get_run(run_id)["status"]
-                )
+                run_status = run_states[run_id]["status"] if run_id in run_states else self.db.get_run(run_id)["status"]

                 # track clone_modify tasks only for non-deleted runs
                 if run_status != RunStatus.DELETED:
@@ -401,9 +358,7 @@ class Controller:
                     self.ic_logger.info(f"Added {task['ic_op']} task for run {run_id}.")
                 else:
                     self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
-                    self.ic_logger.warning(
-                        f"Skipping {task['ic_op']} task for deleted run {run_id}."
-                    )
+                    self.ic_logger.warning(f"Skipping {task['ic_op']} task for deleted run {run_id}.")
             else:
                 # Non clone_modify tasks
                 if run_id not in run_states:
@@ -420,32 +375,21 @@ class Controller:
                     ControllerTask.IC_STOP,
                 ]:
                     # ignore RESUME/STOP tasks for completed runs
-                    self.ic_logger.warning(
-                        f"Ignoring RESUME/STOP task for run {run_id} as it is already completed"
-                    )
+                    self.ic_logger.warning(f"Ignoring RESUME/STOP task for run {run_id} as it is already completed")
                     self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
-                elif (
-                    current_status == RunStatus.FAILED
-                    and task["ic_op"] != ControllerTask.IC_DELETE
-                ):
+                elif current_status == RunStatus.FAILED and task["ic_op"] != ControllerTask.IC_DELETE:
                     # ignore all tasks except DELETE for failed runs
-                    self.ic_logger.warning(
-                        f"Ignoring task {task['ic_op'].value} for failed run {run_id}"
-                    )
+                    self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for failed run {run_id}")
                     self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
                 elif current_status == RunStatus.DELETED:
                     # ignore all tasks for deleted runs
-                    self.ic_logger.warning(
-                        f"Ignoring task {task['ic_op'].value} for deleted run {run_id}"
-                    )
+                    self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for deleted run {run_id}")
                     self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
                 else:
                     # valid ic_op for this run
                     # mark prev task as completed
                     if run_states[run_id]["task_id"] is not None:
-                        self.db.set_ic_ops_task_status(
-                            run_states[run_id]["task_id"], TaskStatus.COMPLETED
-                        )
+                        self.db.set_ic_ops_task_status(run_states[run_id]["task_id"], TaskStatus.COMPLETED)

                     # add new task to run states
                     if task["ic_op"] == ControllerTask.IC_STOP:
@@ -458,26 +402,20 @@ class Controller:
                         updated_status = RunStatus.ONGOING
                         info_msg = f"Received RESUME task for run {run_id}"
                     else:
-                        self.db.set_ic_ops_task_status(
-                            task["task_id"], TaskStatus.FAILED
-                        )
+                        self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
                         raise ValueError(f"Unsupported task {task['ic_op']}")
                     run_states[run_id].update(
                         {
                             "task_id": task["task_id"],
                             "task": task["ic_op"],
-                            "status": (
-                                updated_status if updated_status else current_status
-                            ),
+                            "status": (updated_status if updated_status else current_status),
                         }
                     )
                     self.ic_logger.info(info_msg)

         return run_states, clone_modify_tasks

-    def _get_total_step(
-        self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int
-    ) -> int:
+    def _get_total_step(self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int) -> int:
         """Get the total number of steps for a run."""
         num_train_epochs = config_leaf["training_args"].get("num_train_epochs", 1)

@@ -487,25 +425,20 @@ class Controller:
             # ceil to nearest chunk multiple
             total_steps = config_leaf["training_args"]["max_steps"]
         elif num_train_epochs:
-            per_device_train_batch_size = config_leaf["training_args"].get(
-
-            )
-            gradient_accumulation_steps = config_leaf["training_args"].get(
-                "gradient_accumulation_steps", 1
-            )
+            per_device_train_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1)
+            gradient_accumulation_steps = config_leaf["training_args"].get("gradient_accumulation_steps", 1)
             len_dataloader = math.ceil(len_train_dataset / per_device_train_batch_size)
             num_update_steps_per_epoch = max(
-                len_dataloader // gradient_accumulation_steps
-                + int(len_dataloader % gradient_accumulation_steps > 0),
+                len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
                 1,
             )
             total_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)

         if config_leaf.get("trainer_type", "SFT") == "GRPO":
             num_generations = config_leaf["training_args"].get("num_generations", 8)
-            total_steps = (
-
-            )
+            total_steps = (num_generations * len_train_dataset * num_train_epochs) // (
+                gradient_accumulation_steps * per_device_train_batch_size
+            )
         return total_steps

     def run_fit(
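The step arithmetic rewrapped in the `_get_total_step` hunk above can be read in isolation as follows. This is a minimal standalone sketch with illustrative numbers; the function name `estimate_total_steps` and the sample values are hypothetical and not part of the package.

import math

def estimate_total_steps(
    len_train_dataset: int,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 1,
    gradient_accumulation_steps: int = 1,
    trainer_type: str = "SFT",
    num_generations: int = 8,
) -> int:
    # one optimizer update happens after gradient_accumulation_steps micro-batches
    len_dataloader = math.ceil(len_train_dataset / per_device_train_batch_size)
    num_update_steps_per_epoch = max(
        len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
        1,
    )
    total_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)
    if trainer_type == "GRPO":
        # GRPO-style training produces num_generations completions per prompt,
        # so the step count scales with the generation count as in the hunk above
        total_steps = (num_generations * len_train_dataset * num_train_epochs) // (
            gradient_accumulation_steps * per_device_train_batch_size
        )
    return total_steps

# e.g. 1,000 examples, 2 epochs, batch 4, grad accum 2 -> 250 update steps
print(estimate_total_steps(1000, num_train_epochs=2, per_device_train_batch_size=4, gradient_accumulation_steps=2))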
@@ -521,9 +454,7 @@ class Controller:

         # set experiment task to create models
         self.db.set_experiment_current_task(ExperimentTask.CREATE_MODELS)
-        self.logger.debug(
-            f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}."
-        )
+        self.logger.debug(f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}.")

         # save train and eval dataset objects to a file for workers to load
         try:
@@ -608,10 +539,7 @@ class Controller:
             )

             # skip if task is the same as previous iteration (no change in status) or run is not active
-            if (
-                current_task_tuple == prev_task_tuple
-                or worker_task["run_id"] not in scheduler.run_ids
-            ):
+            if current_task_tuple == prev_task_tuple or worker_task["run_id"] not in scheduler.run_ids:
                 continue

             if worker_task["status"] == TaskStatus.COMPLETED:
@@ -624,9 +552,7 @@ class Controller:
                 run_id = worker_task["run_id"]
                 chunk_id = worker_task["chunk_id"]
                 run_details = all_run_details[run_id]
-                self.logger.debug(
-                    f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}"
-                )
+                self.logger.debug(f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}")
                 self.logger.info(
                     f"Run {run_id} completed steps - {run_details['completed_steps']}/{run_details['total_steps']}"
                 )
@@ -648,11 +574,7 @@ class Controller:

                 # Update progress
                 progress_percentage = (
-                    (
-                        run_details["completed_steps"]
-                        / run_details["total_steps"]
-                        * 100
-                    )
+                    (run_details["completed_steps"] / run_details["total_steps"] * 100)
                     if run_details["total_steps"] > 0
                     else 0
                 )
@@ -673,16 +595,11 @@ class Controller:
                 )
                 # Check if run has completed only current epoch (hasn't reached total_steps yet)
                 elif (
-                    new_chunks_visited == num_chunks
-                    and run_details["completed_steps"] < run_details["total_steps"]
+                    new_chunks_visited == num_chunks and run_details["completed_steps"] < run_details["total_steps"]
                 ):
                     scheduler.reset_run(run_id)
-                    self.db.set_run_details(
-
-                    )
-                    self.logger.info(
-                        f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)"
-                    )
+                    self.db.set_run_details(run_id=run_id, num_chunks_visited_curr_epoch=0)
+                    self.logger.info(f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)")

             # Check for failed runs and update scheduler, local state, shm
             for worker_task in failed_tasks:
@@ -698,12 +615,8 @@ class Controller:

             # Process interactive control tasks (this fetches latest run states internally)
             try:
-                currently_scheduled_runs = list(
-
-                )
-                run_states, clone_modify_tasks = self._process_interm_ic_ops_states(
-                    currently_scheduled_runs
-                )
+                currently_scheduled_runs = list(scheduler.worker_running_current_run.values())
+                run_states, clone_modify_tasks = self._process_interm_ic_ops_states(currently_scheduled_runs)
                 self._process_interactive_control(
                     run_states,
                     clone_modify_tasks,
@@ -712,9 +625,7 @@ class Controller:
                     num_chunks,
                 )
             except Exception as e:
-                raise ControllerException(
-                    f"Error processing interactive control tasks: {e}"
-                ) from e
+                raise ControllerException(f"Error processing interactive control tasks: {e}") from e

             # fetch latest run states again (post IC ops states)
             all_run_details = self.db.get_all_runs()
@@ -728,8 +639,7 @@ class Controller:
                    self.logger.debug(f"Added run {run_id} to scheduler with {chunks_visited} chunks visited")
                # remove inactive runs from scheduler
                elif (
-                    run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED)
-                    and run_id in scheduler.run_ids
+                    run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED) and run_id in scheduler.run_ids
                ):
                    scheduler.remove_run(run_id)
                    self.logger.debug(f"Removed run {run_id} from scheduler")
@@ -742,9 +652,7 @@ class Controller:

             # Check termination condition
             if run_id is None and worker_id is None and chunk_id is None:
-                self.logger.info(
-                    "Scheduler indicates all runs have completed all chunks"
-                )
+                self.logger.info("Scheduler indicates all runs have completed all chunks")
                 all_done = True
                 break

@@ -766,9 +674,7 @@ class Controller:
                 chunk_id,
                 config_options={"create_model_fn": create_model_fn},
             )
-            self.logger.debug(
-                f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}"
-            )
+            self.logger.debug(f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}")

             # Small delay
             time.sleep(1)
rapidfireai/backend/worker.py
CHANGED
@@ -24,7 +24,16 @@ from rapidfireai.ml.checkpoint_utils import (
     save_model_to_shared_memory,
 )
 from rapidfireai.ml.trainer import create_trainer_instance
-from rapidfireai.utils.constants import
+from rapidfireai.utils.constants import (
+    MLFLOW_URL,
+    TENSORBOARD_LOG_DIR,
+    USE_SHARED_MEMORY,
+    RunStatus,
+    SHMObjectType,
+    TaskStatus,
+    WorkerTask,
+    get_tracking_backend,
+)
 from rapidfireai.utils.datapaths import DataPath
 from rapidfireai.utils.exceptions import WorkerException
 from rapidfireai.utils.logging import RFLogger, TrainingLogger
@@ -82,7 +91,7 @@ class Worker:
             tensorboard_log_dir=tensorboard_log_dir,
         )
         # Get experiment if using MLflow
-        if hasattr(self.metric_logger,
+        if hasattr(self.metric_logger, "get_experiment"):
             self.metric_logger.get_experiment(self.experiment_name)

         # load datasets
@@ -119,7 +128,9 @@
         # torch.manual_seed(run_details["seed"])
         # np.random.seed(run_details["seed"])
         # random.seed(run_details["seed"])
-        effective_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1) * config_leaf[
+        effective_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1) * config_leaf[
+            "training_args"
+        ].get("gradient_accumulation_steps", 1)

         # fetch train dataset chunk
         train_dataset_chunker = DatasetChunks(
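The rewrapped expression above computes the effective batch size as the per-device batch size multiplied by the gradient accumulation steps. A minimal sketch with made-up values (the dictionary literal below is illustrative and not taken from the package):

training_args = {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 8}
effective_batch_size = training_args.get("per_device_train_batch_size", 1) * training_args.get(
    "gradient_accumulation_steps", 1
)
print(effective_batch_size)  # 32 samples contribute to each optimizer update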