rapidfireai 0.10.2rc5__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidfireai might be problematic.
- rapidfireai/automl/grid_search.py +4 -5
- rapidfireai/automl/model_config.py +41 -37
- rapidfireai/automl/random_search.py +21 -33
- rapidfireai/backend/controller.py +80 -161
- rapidfireai/backend/worker.py +26 -8
- rapidfireai/cli.py +171 -132
- rapidfireai/db/rf_db.py +1 -1
- rapidfireai/db/tables.sql +1 -1
- rapidfireai/dispatcher/dispatcher.py +3 -1
- rapidfireai/dispatcher/gunicorn.conf.py +1 -1
- rapidfireai/experiment.py +86 -7
- rapidfireai/frontend/build/asset-manifest.json +3 -3
- rapidfireai/frontend/build/index.html +1 -1
- rapidfireai/frontend/build/static/js/{main.1bf27639.js → main.58393d31.js} +3 -3
- rapidfireai/frontend/build/static/js/{main.1bf27639.js.map → main.58393d31.js.map} +1 -1
- rapidfireai/frontend/proxy_middleware.py +1 -1
- rapidfireai/ml/callbacks.py +85 -59
- rapidfireai/ml/trainer.py +42 -86
- rapidfireai/start.sh +117 -34
- rapidfireai/utils/constants.py +22 -1
- rapidfireai/utils/experiment_utils.py +87 -43
- rapidfireai/utils/interactive_controller.py +473 -0
- rapidfireai/utils/logging.py +1 -2
- rapidfireai/utils/metric_logger.py +346 -0
- rapidfireai/utils/mlflow_manager.py +0 -1
- rapidfireai/utils/ping.py +4 -2
- rapidfireai/utils/worker_manager.py +16 -6
- rapidfireai/version.py +2 -2
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/METADATA +7 -4
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/RECORD +36 -33
- tutorial_notebooks/rf-colab-tensorboard-tutorial.ipynb +314 -0
- /rapidfireai/frontend/build/static/js/{main.1bf27639.js.LICENSE.txt → main.58393d31.js.LICENSE.txt} +0 -0
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/WHEEL +0 -0
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/entry_points.txt +0 -0
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/top_level.txt +0 -0
rapidfireai/backend/controller.py
CHANGED

@@ -9,7 +9,6 @@ from pathlib import Path
 from pprint import pformat
 from typing import Any

-import mlflow
 import torch
 from torch.utils.data import Dataset

@@ -20,6 +19,7 @@ from rapidfireai.db.rf_db import RfDb
 from rapidfireai.utils.automl_utils import get_flattened_config_leaf, get_runs
 from rapidfireai.utils.constants import (
     MLFLOW_URL,
+    TENSORBOARD_LOG_DIR,
     ControllerTask,
     ExperimentTask,
     RunEndedBy,
@@ -27,11 +27,12 @@ from rapidfireai.utils.constants import (
     RunStatus,
     TaskStatus,
     WorkerTask,
+    get_tracking_backend,
 )
 from rapidfireai.utils.datapaths import DataPath
 from rapidfireai.utils.exceptions import ControllerException, NoGPUsFoundException
 from rapidfireai.utils.logging import RFLogger
-from rapidfireai.utils.
+from rapidfireai.utils.metric_logger import create_metric_logger
 from rapidfireai.utils.serialize import encode_payload
 from rapidfireai.utils.shm_manager import SharedMemoryManager
 from rapidfireai.utils.worker_manager import WorkerManager
@@ -69,19 +70,26 @@ class Controller:
        self.logger.debug(f"Found {self.num_workers} workers/GPUs.")

        # initialize shared manager and registry, create shared memory manager instance
-        self.shm_manager: SharedMemoryManager = SharedMemoryManager(
-            name="controller-shm"
-        )
+        self.shm_manager: SharedMemoryManager = SharedMemoryManager(name="controller-shm")
        registry, process_lock = self.shm_manager.get_shm_objects()

        # create worker manager
-        self.worker_manager: WorkerManager = WorkerManager(
-
+        self.worker_manager: WorkerManager = WorkerManager(self.num_workers, registry, process_lock)
+
+        # create metric logger
+        # Initialize DataPath temporarily to get experiment path for tensorboard logs
+        experiment_path = self.db.get_experiments_path(self.experiment_name)
+        DataPath.initialize(self.experiment_name, experiment_path)
+        tensorboard_log_dir = TENSORBOARD_LOG_DIR or str(DataPath.experiments_path / "tensorboard_logs")
+
+        self.metric_logger = create_metric_logger(
+            backend=get_tracking_backend(),
+            mlflow_tracking_uri=MLFLOW_URL,
+            tensorboard_log_dir=tensorboard_log_dir,
        )
-
-
-
-        self.mlflow_manager.get_experiment(self.experiment_name)
+        # Get experiment if using MLflow
+        if hasattr(self.metric_logger, "get_experiment"):
+            self.metric_logger.get_experiment(self.experiment_name)

        self.logger.debug("Controller initialized")

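The hunk above replaces the controller's MLflow-specific manager with a backend-agnostic metric logger returned by create_metric_logger, and probes for the MLflow-only get_experiment with hasattr. Only the call sites are visible in this diff, so the following is a minimal illustrative sketch of the duck-typed interface those calls imply (create_run, log_param, end_run, delete_run); the class and the "tensorboard" branch are assumptions for illustration, not the actual rapidfireai.utils.metric_logger implementation.

# Illustrative sketch only: an interface inferred from the call sites in this diff.
# Not the real rapidfireai implementation.
from typing import Any, Protocol


class MetricLogger(Protocol):
    def create_run(self, run_name: str) -> str: ...
    def log_param(self, run_id: str, key: str, value: Any) -> None: ...
    def end_run(self, run_id: str) -> None: ...
    def delete_run(self, run_id: str) -> None: ...
    # MLflow-backed loggers may additionally expose get_experiment(name); the
    # controller probes for it with hasattr so other backends can omit it.


class TensorBoardMetricLogger:
    """Hypothetical TensorBoard-style backend: a run is just a named log subdirectory."""

    def __init__(self, log_dir: str) -> None:
        self.log_dir = log_dir

    def create_run(self, run_name: str) -> str:
        return run_name

    def log_param(self, run_id: str, key: str, value: Any) -> None:
        print(f"[{self.log_dir}/{run_id}] {key}={value}")

    def end_run(self, run_id: str) -> None:
        pass

    def delete_run(self, run_id: str) -> None:
        pass


def create_metric_logger(backend: str, mlflow_tracking_uri: str = "", tensorboard_log_dir: str = "") -> MetricLogger:
    # Sketch of the factory's shape; the real factory presumably also builds an MLflow-backed logger.
    if backend == "tensorboard":
        return TensorBoardMetricLogger(tensorboard_log_dir)
    raise NotImplementedError(f"sketch only covers 'tensorboard', got {backend!r}")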
@@ -104,14 +112,12 @@ class Controller:
        for config_leaf in config_leafs:
            flattened_config = get_flattened_config_leaf(config_leaf)
            # print("flattened_config: ",flattened_config)
-            total_steps = self._get_total_step(
-                config_leaf, len_train_dataset, num_chunks
-            )
+            total_steps = self._get_total_step(config_leaf, len_train_dataset, num_chunks)

            # get clone modify info
            warm_started_from = clone_modify_info.get("warm_started_from") if clone_modify_info else None
            cloned_from = clone_modify_info.get("cloned_from") if clone_modify_info else None
-            chunk_offset = clone_modify_info.get("chunk_offset",0) if clone_modify_info else 0
+            chunk_offset = clone_modify_info.get("chunk_offset", 0) if clone_modify_info else 0

            run_id = self.db.create_run(
                config_leaf=config_leaf,
@@ -131,57 +137,44 @@ class Controller:
            try:
                base_run_path = DataPath.base_run_path(run_id)
                work_dir_path = DataPath.work_dir_path(base_run_path)
-                initial_checkpoint_path = DataPath.initial_checkpoint_path(
-                    base_run_path
-                )
+                initial_checkpoint_path = DataPath.initial_checkpoint_path(base_run_path)
                final_checkpoint_path = DataPath.final_checkpoint_path(base_run_path)
-                intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(
-                    base_run_path
-                )
+                intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path)

                Path.mkdir(work_dir_path, parents=True, exist_ok=True)
                Path.mkdir(initial_checkpoint_path, parents=True, exist_ok=True)
                Path.mkdir(final_checkpoint_path, parents=True, exist_ok=True)
                Path.mkdir(intermediate_checkpoint_path, parents=True, exist_ok=True)
            except (PermissionError, OSError) as e:
-                raise ControllerException(
-                    f"Failed to create required Run DataPath directories: {e}"
-                ) from e
+                raise ControllerException(f"Failed to create required Run DataPath directories: {e}") from e

-            # create new
+            # create new tracking run
            try:
-                # create new
-                mlflow_run_id = self.
+                # create new tracking run and get the mlflow_run_id
+                mlflow_run_id = self.metric_logger.create_run(str(run_id))

-                # populate
+                # populate tracking backend with model config info
                for key, value in flattened_config.items():
-                    self.
+                    self.metric_logger.log_param(mlflow_run_id, key, value)
                if warm_started_from:
-                    self.
-                        mlflow_run_id, "warm-start", str(warm_started_from)
-                    )
+                    self.metric_logger.log_param(mlflow_run_id, "warm-start", str(warm_started_from))
                if cloned_from:
-                    self.
-
-                    )
-                self.logger.debug(
-                    f"Populated MLFlow with model config info for run {run_id}."
-                )
+                    self.metric_logger.log_param(mlflow_run_id, "parent-run", str(cloned_from))
+                self.logger.debug(f"Populated MLFlow with model config info for run {run_id}.")
                self.db.set_run_details(
                    run_id=run_id,
                    mlflow_run_id=mlflow_run_id,
                    flattened_config=flattened_config,
                )
-            except
-
+            except Exception as e:
+                # Catch any metric logger exceptions (MLflow, TensorBoard, etc.)
+                msg = f"Error creating new tracking run for run {run_id} - {e}."
                print(msg)
-                self.
+                self.metric_logger.end_run(mlflow_run_id)
                self.logger.error(msg, exc_info=True)

        total_runs = len(runs)
-        self.logger.info(
-            f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}"
-        )
+        self.logger.info(f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}")
        self.logger.debug(f"Got {total_runs} runs for {source.value}.")

        # set experiment task to run_fit
@@ -195,24 +188,17 @@ class Controller:

        # check if there are any other runs with the same base model
        base_model_name = self.db.get_run(run_id)["config_leaf"]["model_name"]
-        relevant_runs = self.db.get_runs_by_status(
-            [RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED]
-        )
+        relevant_runs = self.db.get_runs_by_status([RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED])

        # get shared object types to delete - if no other runs are using it
        delete_shared_objects = True
        for r_run_id, r_run_details in relevant_runs.items():
-            if
-                r_run_details["config_leaf"]["model_name"] == base_model_name
-                and r_run_id != run_id
-            ):
+            if r_run_details["config_leaf"]["model_name"] == base_model_name and r_run_id != run_id:
                delete_shared_objects = False
                break

        # delete model object from shared memory
-        self.shm_manager.delete_model_object(
-            run_id, base_model_name if delete_shared_objects else None
-        )
+        self.shm_manager.delete_model_object(run_id, base_model_name if delete_shared_objects else None)

    def _process_interactive_control(
        self,
@@ -237,9 +223,7 @@ class Controller:
                    status=RunStatus.STOPPED,
                    ended_by=RunEndedBy.INTERACTIVE_CONTROL,
                )
-                self.db.set_ic_ops_task_status(
-                    run_state["task_id"], TaskStatus.COMPLETED
-                )
+                self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
                self.ic_logger.info(f"Stopping run {run_id} by Interactive Control")
            elif run_state["status"] == RunStatus.DELETED:
                # process deleted tasks
@@ -249,16 +233,14 @@ class Controller:

                # delete run from MLFlow
                mlflow_run_id = self.db.get_run(run_id)["mlflow_run_id"]
-                self.
+                self.metric_logger.delete_run(mlflow_run_id)
                # mark run as deleted
                self.db.set_run_details(
                    run_id=run_id,
                    status=RunStatus.DELETED,
                    ended_by=RunEndedBy.INTERACTIVE_CONTROL,
                )
-                self.db.set_ic_ops_task_status(
-                    run_state["task_id"], TaskStatus.COMPLETED
-                )
+                self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
                self.ic_logger.info(f"Deleting run {run_id} by Interactive Control")
            elif run_state["status"] == RunStatus.ONGOING:
                # process ongoing tasks
@@ -267,15 +249,11 @@ class Controller:
                    status=RunStatus.ONGOING,
                    ended_by="",
                )
-                self.db.set_ic_ops_task_status(
-                    run_state["task_id"], TaskStatus.COMPLETED
-                )
+                self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
                self.ic_logger.info(f"Resuming run {run_id} by Interactive Control")
            elif run_state["status"] == RunStatus.COMPLETED:
                # process completed tasks
-                self.logger.warning(
-                    f"Run {run_id} is already completed. Skipping Interactive Control task."
-                )
+                self.logger.warning(f"Run {run_id} is already completed. Skipping Interactive Control task.")
                self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.SKIPPED)
            else:
                raise ValueError(f"Unsupported run status {run_state['status']}")
@@ -291,9 +269,7 @@ class Controller:
        # add additional_kwargs to config_leaf if it exists in the parent run
        parent_run_details = self.db.get_run(parent_run_id)
        if "additional_kwargs" in parent_run_details["config_leaf"]:
-            config_leaf["additional_kwargs"] = parent_run_details["config_leaf"][
-                "additional_kwargs"
-            ]
+            config_leaf["additional_kwargs"] = parent_run_details["config_leaf"]["additional_kwargs"]

        # create model for the new run
        try:
@@ -311,7 +287,9 @@ class Controller:
                )
            elif ic_op == ControllerTask.IC_CLONE_MODIFY_WARM:
                # calculate clone chunk offset
-                effective_batch_size = parent_run_details["config_leaf"]["training_args"].get(
+                effective_batch_size = parent_run_details["config_leaf"]["training_args"].get(
+                    "per_device_train_batch_size", 1
+                ) * parent_run_details["config_leaf"]["training_args"].get("gradient_accumulation_steps", 1)
                chunker = DatasetChunks(
                    len_train_dataset,
                    num_chunks,
@@ -342,12 +320,8 @@ class Controller:
            )
        except Exception as e:
            self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
-            self.ic_logger.error(
-
-            )
-            raise ControllerException(
-                f"Error creating model for run {parent_run_id}: {e}"
-            ) from e
+            self.ic_logger.error(f"Error creating model for run {parent_run_id}: {e}")
+            raise ControllerException(f"Error creating model for run {parent_run_id}: {e}") from e

    def _process_interm_ic_ops_states(
        self,
@@ -376,11 +350,7 @@ class Controller:
            if is_clone_modify_task:
                # clone_modify tasks
                # get latest run state
-                run_status = (
-                    run_states[run_id]["status"]
-                    if run_id in run_states
-                    else self.db.get_run(run_id)["status"]
-                )
+                run_status = run_states[run_id]["status"] if run_id in run_states else self.db.get_run(run_id)["status"]

                # track clone_modify tasks only for non-deleted runs
                if run_status != RunStatus.DELETED:
@@ -388,9 +358,7 @@ class Controller:
                    self.ic_logger.info(f"Added {task['ic_op']} task for run {run_id}.")
                else:
                    self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
-                    self.ic_logger.warning(
-                        f"Skipping {task['ic_op']} task for deleted run {run_id}."
-                    )
+                    self.ic_logger.warning(f"Skipping {task['ic_op']} task for deleted run {run_id}.")
            else:
                # Non clone_modify tasks
                if run_id not in run_states:
@@ -407,32 +375,21 @@ class Controller:
                        ControllerTask.IC_STOP,
                    ]:
                        # ignore RESUME/STOP tasks for completed runs
-                        self.ic_logger.warning(
-                            f"Ignoring RESUME/STOP task for run {run_id} as it is already completed"
-                        )
+                        self.ic_logger.warning(f"Ignoring RESUME/STOP task for run {run_id} as it is already completed")
                        self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
-                    elif
-                        current_status == RunStatus.FAILED
-                        and task["ic_op"] != ControllerTask.IC_DELETE
-                    ):
+                    elif current_status == RunStatus.FAILED and task["ic_op"] != ControllerTask.IC_DELETE:
                        # ignore all tasks except DELETE for failed runs
-                        self.ic_logger.warning(
-                            f"Ignoring task {task['ic_op'].value} for failed run {run_id}"
-                        )
+                        self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for failed run {run_id}")
                        self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
                    elif current_status == RunStatus.DELETED:
                        # ignore all tasks for deleted runs
-                        self.ic_logger.warning(
-                            f"Ignoring task {task['ic_op'].value} for deleted run {run_id}"
-                        )
+                        self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for deleted run {run_id}")
                        self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
                    else:
                        # valid ic_op for this run
                        # mark prev task as completed
                        if run_states[run_id]["task_id"] is not None:
-                            self.db.set_ic_ops_task_status(
-                                run_states[run_id]["task_id"], TaskStatus.COMPLETED
-                            )
+                            self.db.set_ic_ops_task_status(run_states[run_id]["task_id"], TaskStatus.COMPLETED)

                        # add new task to run states
                        if task["ic_op"] == ControllerTask.IC_STOP:
@@ -445,26 +402,20 @@ class Controller:
                            updated_status = RunStatus.ONGOING
                            info_msg = f"Received RESUME task for run {run_id}"
                        else:
-                            self.db.set_ic_ops_task_status(
-                                task["task_id"], TaskStatus.FAILED
-                            )
+                            self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
                            raise ValueError(f"Unsupported task {task['ic_op']}")
                        run_states[run_id].update(
                            {
                                "task_id": task["task_id"],
                                "task": task["ic_op"],
-                                "status": (
-                                    updated_status if updated_status else current_status
-                                ),
+                                "status": (updated_status if updated_status else current_status),
                            }
                        )
                        self.ic_logger.info(info_msg)

        return run_states, clone_modify_tasks

-    def _get_total_step(
-        self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int
-    ) -> int:
+    def _get_total_step(self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int) -> int:
        """Get the total number of steps for a run."""
        num_train_epochs = config_leaf["training_args"].get("num_train_epochs", 1)

@@ -474,25 +425,20 @@ class Controller:
            # ceil to nearest chunk multiple
            total_steps = config_leaf["training_args"]["max_steps"]
        elif num_train_epochs:
-            per_device_train_batch_size = config_leaf["training_args"].get(
-
-            )
-            gradient_accumulation_steps = config_leaf["training_args"].get(
-                "gradient_accumulation_steps", 1
-            )
+            per_device_train_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1)
+            gradient_accumulation_steps = config_leaf["training_args"].get("gradient_accumulation_steps", 1)
            len_dataloader = math.ceil(len_train_dataset / per_device_train_batch_size)
            num_update_steps_per_epoch = max(
-                len_dataloader // gradient_accumulation_steps
-                + int(len_dataloader % gradient_accumulation_steps > 0),
+                len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
                1,
            )
            total_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)

        if config_leaf.get("trainer_type", "SFT") == "GRPO":
            num_generations = config_leaf["training_args"].get("num_generations", 8)
-            total_steps = (
-
-            )
+            total_steps = (num_generations * len_train_dataset * num_train_epochs) // (
+                gradient_accumulation_steps * per_device_train_batch_size
+            )
        return total_steps

    def run_fit(
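To make the reformatted _get_total_step arithmetic above concrete, here is a small worked example; the dataset size and hyperparameters are invented for illustration and are not rapidfireai defaults.

import math

# Illustrative numbers only.
len_train_dataset = 1000
per_device_train_batch_size = 4
gradient_accumulation_steps = 2
num_train_epochs = 3

# Epoch-based path from the hunk above.
len_dataloader = math.ceil(len_train_dataset / per_device_train_batch_size)  # 250
num_update_steps_per_epoch = max(
    len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
    1,
)  # 125
total_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)  # 375

# GRPO path overrides total_steps using num_generations.
num_generations = 8
grpo_total_steps = (num_generations * len_train_dataset * num_train_epochs) // (
    gradient_accumulation_steps * per_device_train_batch_size
)  # 3000

print(total_steps, grpo_total_steps)  # 375 3000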
@@ -508,9 +454,7 @@ class Controller:

        # set experiment task to create models
        self.db.set_experiment_current_task(ExperimentTask.CREATE_MODELS)
-        self.logger.debug(
-            f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}."
-        )
+        self.logger.debug(f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}.")

        # save train and eval dataset objects to a file for workers to load
        try:
@@ -595,10 +539,7 @@ class Controller:
            )

            # skip if task is the same as previous iteration (no change in status) or run is not active
-            if
-                current_task_tuple == prev_task_tuple
-                or worker_task["run_id"] not in scheduler.run_ids
-            ):
+            if current_task_tuple == prev_task_tuple or worker_task["run_id"] not in scheduler.run_ids:
                continue

            if worker_task["status"] == TaskStatus.COMPLETED:
@@ -611,9 +552,7 @@ class Controller:
                run_id = worker_task["run_id"]
                chunk_id = worker_task["chunk_id"]
                run_details = all_run_details[run_id]
-                self.logger.debug(
-                    f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}"
-                )
+                self.logger.debug(f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}")
                self.logger.info(
                    f"Run {run_id} completed steps - {run_details['completed_steps']}/{run_details['total_steps']}"
                )
@@ -635,11 +574,7 @@ class Controller:

                # Update progress
                progress_percentage = (
-                    (
-                        run_details["completed_steps"]
-                        / run_details["total_steps"]
-                        * 100
-                    )
+                    (run_details["completed_steps"] / run_details["total_steps"] * 100)
                    if run_details["total_steps"] > 0
                    else 0
                )
@@ -660,16 +595,11 @@ class Controller:
                )
                # Check if run has completed only current epoch (hasn't reached total_steps yet)
                elif (
-                    new_chunks_visited == num_chunks
-                    and run_details["completed_steps"] < run_details["total_steps"]
+                    new_chunks_visited == num_chunks and run_details["completed_steps"] < run_details["total_steps"]
                ):
                    scheduler.reset_run(run_id)
-                    self.db.set_run_details(
-
-                    )
-                    self.logger.info(
-                        f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)"
-                    )
+                    self.db.set_run_details(run_id=run_id, num_chunks_visited_curr_epoch=0)
+                    self.logger.info(f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)")

            # Check for failed runs and update scheduler, local state, shm
            for worker_task in failed_tasks:
@@ -685,12 +615,8 @@ class Controller:

            # Process interactive control tasks (this fetches latest run states internally)
            try:
-                currently_scheduled_runs = list(
-
-                )
-                run_states, clone_modify_tasks = self._process_interm_ic_ops_states(
-                    currently_scheduled_runs
-                )
+                currently_scheduled_runs = list(scheduler.worker_running_current_run.values())
+                run_states, clone_modify_tasks = self._process_interm_ic_ops_states(currently_scheduled_runs)
                self._process_interactive_control(
                    run_states,
                    clone_modify_tasks,
@@ -699,9 +625,7 @@ class Controller:
                    num_chunks,
                )
            except Exception as e:
-                raise ControllerException(
-                    f"Error processing interactive control tasks: {e}"
-                ) from e
+                raise ControllerException(f"Error processing interactive control tasks: {e}") from e

            # fetch latest run states again (post IC ops states)
            all_run_details = self.db.get_all_runs()
@@ -715,8 +639,7 @@ class Controller:
                self.logger.debug(f"Added run {run_id} to scheduler with {chunks_visited} chunks visited")
            # remove inactive runs from scheduler
            elif (
-                run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED)
-                and run_id in scheduler.run_ids
+                run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED) and run_id in scheduler.run_ids
            ):
                scheduler.remove_run(run_id)
                self.logger.debug(f"Removed run {run_id} from scheduler")
@@ -729,9 +652,7 @@ class Controller:

            # Check termination condition
            if run_id is None and worker_id is None and chunk_id is None:
-                self.logger.info(
-                    "Scheduler indicates all runs have completed all chunks"
-                )
+                self.logger.info("Scheduler indicates all runs have completed all chunks")
                all_done = True
                break

@@ -753,9 +674,7 @@ class Controller:
                    chunk_id,
                    config_options={"create_model_fn": create_model_fn},
                )
-                self.logger.debug(
-                    f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}"
-                )
+                self.logger.debug(f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}")

                # Small delay
                time.sleep(1)
rapidfireai/backend/worker.py
CHANGED

@@ -24,11 +24,20 @@ from rapidfireai.ml.checkpoint_utils import (
     save_model_to_shared_memory,
 )
 from rapidfireai.ml.trainer import create_trainer_instance
-from rapidfireai.utils.constants import
+from rapidfireai.utils.constants import (
+    MLFLOW_URL,
+    TENSORBOARD_LOG_DIR,
+    USE_SHARED_MEMORY,
+    RunStatus,
+    SHMObjectType,
+    TaskStatus,
+    WorkerTask,
+    get_tracking_backend,
+)
 from rapidfireai.utils.datapaths import DataPath
 from rapidfireai.utils.exceptions import WorkerException
 from rapidfireai.utils.logging import RFLogger, TrainingLogger
-from rapidfireai.utils.
+from rapidfireai.utils.metric_logger import create_metric_logger
 from rapidfireai.utils.serialize import decode_db_payload
 from rapidfireai.utils.shm_manager import SharedMemoryManager
 from rapidfireai.utils.trainer_config import TrainerConfig
@@ -71,13 +80,20 @@ class Worker:
        # get experiment name
        self.experiment_name: str = self.db.get_running_experiment()["experiment_name"]

-        # create mlflow manager
-        self.mlflow_manager: MLflowManager = MLflowManager(MLFLOW_URL)
-        self.mlflow_manager.get_experiment(self.experiment_name)
-
        # initialize data paths
        DataPath.initialize(self.experiment_name, self.db.get_experiments_path(self.experiment_name))

+        # create metric logger
+        tensorboard_log_dir = TENSORBOARD_LOG_DIR or str(DataPath.experiments_path / "tensorboard_logs")
+        self.metric_logger = create_metric_logger(
+            backend=get_tracking_backend(),
+            mlflow_tracking_uri=MLFLOW_URL,
+            tensorboard_log_dir=tensorboard_log_dir,
+        )
+        # Get experiment if using MLflow
+        if hasattr(self.metric_logger, "get_experiment"):
+            self.metric_logger.get_experiment(self.experiment_name)
+
        # load datasets
        self.train_dataset, self.eval_dataset, self.num_chunks = self.load_datasets()
        self.len_train_dataset = len(self.train_dataset)
@@ -112,7 +128,9 @@ class Worker:
        # torch.manual_seed(run_details["seed"])
        # np.random.seed(run_details["seed"])
        # random.seed(run_details["seed"])
-        effective_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1) * config_leaf[
+        effective_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1) * config_leaf[
+            "training_args"
+        ].get("gradient_accumulation_steps", 1)

        # fetch train dataset chunk
        train_dataset_chunker = DatasetChunks(
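The worker hunk above and the controller's clone-offset hunk compute the same effective batch size (per-device batch size times gradient accumulation steps). A quick numeric check, with made-up values:

# Illustrative values only.
training_args = {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 2}

effective_batch_size = training_args.get("per_device_train_batch_size", 1) * training_args.get(
    "gradient_accumulation_steps", 1
)
print(effective_batch_size)  # 8 samples consumed per optimizer step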
@@ -149,7 +167,7 @@ class Worker:
        stderr_buffer = StringIO()
        with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
            trainer_instance, base_model_name = create_trainer_instance(
-                trainer_config, self.shm_manager, USE_SHARED_MEMORY, self.
+                trainer_config, self.shm_manager, USE_SHARED_MEMORY, self.metric_logger, chunk_id
            )

        # if first time, save checkpoint to disk