rapidfireai 0.10.2rc5__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rapidfireai might be problematic.

Files changed (36)
  1. rapidfireai/automl/grid_search.py +4 -5
  2. rapidfireai/automl/model_config.py +41 -37
  3. rapidfireai/automl/random_search.py +21 -33
  4. rapidfireai/backend/controller.py +80 -161
  5. rapidfireai/backend/worker.py +26 -8
  6. rapidfireai/cli.py +171 -132
  7. rapidfireai/db/rf_db.py +1 -1
  8. rapidfireai/db/tables.sql +1 -1
  9. rapidfireai/dispatcher/dispatcher.py +3 -1
  10. rapidfireai/dispatcher/gunicorn.conf.py +1 -1
  11. rapidfireai/experiment.py +86 -7
  12. rapidfireai/frontend/build/asset-manifest.json +3 -3
  13. rapidfireai/frontend/build/index.html +1 -1
  14. rapidfireai/frontend/build/static/js/{main.1bf27639.js → main.58393d31.js} +3 -3
  15. rapidfireai/frontend/build/static/js/{main.1bf27639.js.map → main.58393d31.js.map} +1 -1
  16. rapidfireai/frontend/proxy_middleware.py +1 -1
  17. rapidfireai/ml/callbacks.py +85 -59
  18. rapidfireai/ml/trainer.py +42 -86
  19. rapidfireai/start.sh +117 -34
  20. rapidfireai/utils/constants.py +22 -1
  21. rapidfireai/utils/experiment_utils.py +87 -43
  22. rapidfireai/utils/interactive_controller.py +473 -0
  23. rapidfireai/utils/logging.py +1 -2
  24. rapidfireai/utils/metric_logger.py +346 -0
  25. rapidfireai/utils/mlflow_manager.py +0 -1
  26. rapidfireai/utils/ping.py +4 -2
  27. rapidfireai/utils/worker_manager.py +16 -6
  28. rapidfireai/version.py +2 -2
  29. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/METADATA +7 -4
  30. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/RECORD +36 -33
  31. tutorial_notebooks/rf-colab-tensorboard-tutorial.ipynb +314 -0
  32. /rapidfireai/frontend/build/static/js/{main.1bf27639.js.LICENSE.txt → main.58393d31.js.LICENSE.txt} +0 -0
  33. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/WHEEL +0 -0
  34. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/entry_points.txt +0 -0
  35. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
  36. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/top_level.txt +0 -0
rapidfireai/backend/controller.py

@@ -9,7 +9,6 @@ from pathlib import Path
  from pprint import pformat
  from typing import Any

- import mlflow
  import torch
  from torch.utils.data import Dataset

@@ -20,6 +19,7 @@ from rapidfireai.db.rf_db import RfDb
  from rapidfireai.utils.automl_utils import get_flattened_config_leaf, get_runs
  from rapidfireai.utils.constants import (
  MLFLOW_URL,
+ TENSORBOARD_LOG_DIR,
  ControllerTask,
  ExperimentTask,
  RunEndedBy,
@@ -27,11 +27,12 @@ from rapidfireai.utils.constants import (
  RunStatus,
  TaskStatus,
  WorkerTask,
+ get_tracking_backend,
  )
  from rapidfireai.utils.datapaths import DataPath
  from rapidfireai.utils.exceptions import ControllerException, NoGPUsFoundException
  from rapidfireai.utils.logging import RFLogger
- from rapidfireai.utils.mlflow_manager import MLflowManager
+ from rapidfireai.utils.metric_logger import create_metric_logger
  from rapidfireai.utils.serialize import encode_payload
  from rapidfireai.utils.shm_manager import SharedMemoryManager
  from rapidfireai.utils.worker_manager import WorkerManager
@@ -69,19 +70,26 @@ class Controller:
  self.logger.debug(f"Found {self.num_workers} workers/GPUs.")

  # initialize shared manager and registry, create shared memory manager instance
- self.shm_manager: SharedMemoryManager = SharedMemoryManager(
- name="controller-shm"
- )
+ self.shm_manager: SharedMemoryManager = SharedMemoryManager(name="controller-shm")
  registry, process_lock = self.shm_manager.get_shm_objects()

  # create worker manager
- self.worker_manager: WorkerManager = WorkerManager(
- self.num_workers, registry, process_lock
+ self.worker_manager: WorkerManager = WorkerManager(self.num_workers, registry, process_lock)
+
+ # create metric logger
+ # Initialize DataPath temporarily to get experiment path for tensorboard logs
+ experiment_path = self.db.get_experiments_path(self.experiment_name)
+ DataPath.initialize(self.experiment_name, experiment_path)
+ tensorboard_log_dir = TENSORBOARD_LOG_DIR or str(DataPath.experiments_path / "tensorboard_logs")
+
+ self.metric_logger = create_metric_logger(
+ backend=get_tracking_backend(),
+ mlflow_tracking_uri=MLFLOW_URL,
+ tensorboard_log_dir=tensorboard_log_dir,
  )
-
- # create mlflow manager
- self.mlflow_manager: MLflowManager = MLflowManager(MLFLOW_URL)
- self.mlflow_manager.get_experiment(self.experiment_name)
+ # Get experiment if using MLflow
+ if hasattr(self.metric_logger, "get_experiment"):
+ self.metric_logger.get_experiment(self.experiment_name)

  self.logger.debug("Controller initialized")

@@ -104,14 +112,12 @@
  for config_leaf in config_leafs:
  flattened_config = get_flattened_config_leaf(config_leaf)
  # print("flattened_config: ",flattened_config)
- total_steps = self._get_total_step(
- config_leaf, len_train_dataset, num_chunks
- )
+ total_steps = self._get_total_step(config_leaf, len_train_dataset, num_chunks)

  # get clone modify info
  warm_started_from = clone_modify_info.get("warm_started_from") if clone_modify_info else None
  cloned_from = clone_modify_info.get("cloned_from") if clone_modify_info else None
- chunk_offset = clone_modify_info.get("chunk_offset",0) if clone_modify_info else 0
+ chunk_offset = clone_modify_info.get("chunk_offset", 0) if clone_modify_info else 0

  run_id = self.db.create_run(
  config_leaf=config_leaf,
@@ -131,57 +137,44 @@
  try:
  base_run_path = DataPath.base_run_path(run_id)
  work_dir_path = DataPath.work_dir_path(base_run_path)
- initial_checkpoint_path = DataPath.initial_checkpoint_path(
- base_run_path
- )
+ initial_checkpoint_path = DataPath.initial_checkpoint_path(base_run_path)
  final_checkpoint_path = DataPath.final_checkpoint_path(base_run_path)
- intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(
- base_run_path
- )
+ intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path)

  Path.mkdir(work_dir_path, parents=True, exist_ok=True)
  Path.mkdir(initial_checkpoint_path, parents=True, exist_ok=True)
  Path.mkdir(final_checkpoint_path, parents=True, exist_ok=True)
  Path.mkdir(intermediate_checkpoint_path, parents=True, exist_ok=True)
  except (PermissionError, OSError) as e:
- raise ControllerException(
- f"Failed to create required Run DataPath directories: {e}"
- ) from e
+ raise ControllerException(f"Failed to create required Run DataPath directories: {e}") from e

- # create new MlFlow run
+ # create new tracking run
  try:
- # create new MlFlow run and get the mlflow_run_id
- mlflow_run_id = self.mlflow_manager.create_run(str(run_id))
+ # create new tracking run and get the mlflow_run_id
+ mlflow_run_id = self.metric_logger.create_run(str(run_id))

- # populate MLFlow with model config info
+ # populate tracking backend with model config info
  for key, value in flattened_config.items():
- self.mlflow_manager.log_param(mlflow_run_id, key, value)
+ self.metric_logger.log_param(mlflow_run_id, key, value)
  if warm_started_from:
- self.mlflow_manager.log_param(
- mlflow_run_id, "warm-start", str(warm_started_from)
- )
+ self.metric_logger.log_param(mlflow_run_id, "warm-start", str(warm_started_from))
  if cloned_from:
- self.mlflow_manager.log_param(
- mlflow_run_id, "parent-run", str(cloned_from)
- )
- self.logger.debug(
- f"Populated MLFlow with model config info for run {run_id}."
- )
+ self.metric_logger.log_param(mlflow_run_id, "parent-run", str(cloned_from))
+ self.logger.debug(f"Populated MLFlow with model config info for run {run_id}.")
  self.db.set_run_details(
  run_id=run_id,
  mlflow_run_id=mlflow_run_id,
  flattened_config=flattened_config,
  )
- except mlflow.exceptions.MlflowException as e:
- msg = f"Error creating new MLFlow run for run {run_id} - {e}."
+ except Exception as e:
+ # Catch any metric logger exceptions (MLflow, TensorBoard, etc.)
+ msg = f"Error creating new tracking run for run {run_id} - {e}."
  print(msg)
- self.mlflow_manager.end_run(mlflow_run_id)
+ self.metric_logger.end_run(mlflow_run_id)
  self.logger.error(msg, exc_info=True)

  total_runs = len(runs)
- self.logger.info(
- f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}"
- )
+ self.logger.info(f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}")
  self.logger.debug(f"Got {total_runs} runs for {source.value}.")

  # set experiment task to run_fit
@@ -195,24 +188,17 @@

  # check if there are any other runs with the same base model
  base_model_name = self.db.get_run(run_id)["config_leaf"]["model_name"]
- relevant_runs = self.db.get_runs_by_status(
- [RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED]
- )
+ relevant_runs = self.db.get_runs_by_status([RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED])

  # get shared object types to delete - if no other runs are using it
  delete_shared_objects = True
  for r_run_id, r_run_details in relevant_runs.items():
- if (
- r_run_details["config_leaf"]["model_name"] == base_model_name
- and r_run_id != run_id
- ):
+ if r_run_details["config_leaf"]["model_name"] == base_model_name and r_run_id != run_id:
  delete_shared_objects = False
  break

  # delete model object from shared memory
- self.shm_manager.delete_model_object(
- run_id, base_model_name if delete_shared_objects else None
- )
+ self.shm_manager.delete_model_object(run_id, base_model_name if delete_shared_objects else None)

  def _process_interactive_control(
  self,
@@ -237,9 +223,7 @@
  status=RunStatus.STOPPED,
  ended_by=RunEndedBy.INTERACTIVE_CONTROL,
  )
- self.db.set_ic_ops_task_status(
- run_state["task_id"], TaskStatus.COMPLETED
- )
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
  self.ic_logger.info(f"Stopping run {run_id} by Interactive Control")
  elif run_state["status"] == RunStatus.DELETED:
  # process deleted tasks
@@ -249,16 +233,14 @@

  # delete run from MLFlow
  mlflow_run_id = self.db.get_run(run_id)["mlflow_run_id"]
- self.mlflow_manager.delete_run(mlflow_run_id)
+ self.metric_logger.delete_run(mlflow_run_id)
  # mark run as deleted
  self.db.set_run_details(
  run_id=run_id,
  status=RunStatus.DELETED,
  ended_by=RunEndedBy.INTERACTIVE_CONTROL,
  )
- self.db.set_ic_ops_task_status(
- run_state["task_id"], TaskStatus.COMPLETED
- )
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
  self.ic_logger.info(f"Deleting run {run_id} by Interactive Control")
  elif run_state["status"] == RunStatus.ONGOING:
  # process ongoing tasks
@@ -267,15 +249,11 @@
  status=RunStatus.ONGOING,
  ended_by="",
  )
- self.db.set_ic_ops_task_status(
- run_state["task_id"], TaskStatus.COMPLETED
- )
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
  self.ic_logger.info(f"Resuming run {run_id} by Interactive Control")
  elif run_state["status"] == RunStatus.COMPLETED:
  # process completed tasks
- self.logger.warning(
- f"Run {run_id} is already completed. Skipping Interactive Control task."
- )
+ self.logger.warning(f"Run {run_id} is already completed. Skipping Interactive Control task.")
  self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.SKIPPED)
  else:
  raise ValueError(f"Unsupported run status {run_state['status']}")
@@ -291,9 +269,7 @@
  # add additional_kwargs to config_leaf if it exists in the parent run
  parent_run_details = self.db.get_run(parent_run_id)
  if "additional_kwargs" in parent_run_details["config_leaf"]:
- config_leaf["additional_kwargs"] = parent_run_details["config_leaf"][
- "additional_kwargs"
- ]
+ config_leaf["additional_kwargs"] = parent_run_details["config_leaf"]["additional_kwargs"]

  # create model for the new run
  try:
@@ -311,7 +287,9 @@
  )
  elif ic_op == ControllerTask.IC_CLONE_MODIFY_WARM:
  # calculate clone chunk offset
- effective_batch_size = parent_run_details["config_leaf"]["training_args"].get("per_device_train_batch_size", 1) * parent_run_details["config_leaf"]["training_args"].get("gradient_accumulation_steps", 1)
+ effective_batch_size = parent_run_details["config_leaf"]["training_args"].get(
+ "per_device_train_batch_size", 1
+ ) * parent_run_details["config_leaf"]["training_args"].get("gradient_accumulation_steps", 1)
  chunker = DatasetChunks(
  len_train_dataset,
  num_chunks,
@@ -342,12 +320,8 @@
  )
  except Exception as e:
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
- self.ic_logger.error(
- f"Error creating model for run {parent_run_id}: {e}"
- )
- raise ControllerException(
- f"Error creating model for run {parent_run_id}: {e}"
- ) from e
+ self.ic_logger.error(f"Error creating model for run {parent_run_id}: {e}")
+ raise ControllerException(f"Error creating model for run {parent_run_id}: {e}") from e

  def _process_interm_ic_ops_states(
  self,
@@ -376,11 +350,7 @@
  if is_clone_modify_task:
  # clone_modify tasks
  # get latest run state
- run_status = (
- run_states[run_id]["status"]
- if run_id in run_states
- else self.db.get_run(run_id)["status"]
- )
+ run_status = run_states[run_id]["status"] if run_id in run_states else self.db.get_run(run_id)["status"]

  # track clone_modify tasks only for non-deleted runs
  if run_status != RunStatus.DELETED:
@@ -388,9 +358,7 @@
  self.ic_logger.info(f"Added {task['ic_op']} task for run {run_id}.")
  else:
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
- self.ic_logger.warning(
- f"Skipping {task['ic_op']} task for deleted run {run_id}."
- )
+ self.ic_logger.warning(f"Skipping {task['ic_op']} task for deleted run {run_id}.")
  else:
  # Non clone_modify tasks
  if run_id not in run_states:
@@ -407,32 +375,21 @@
  ControllerTask.IC_STOP,
  ]:
  # ignore RESUME/STOP tasks for completed runs
- self.ic_logger.warning(
- f"Ignoring RESUME/STOP task for run {run_id} as it is already completed"
- )
+ self.ic_logger.warning(f"Ignoring RESUME/STOP task for run {run_id} as it is already completed")
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
- elif (
- current_status == RunStatus.FAILED
- and task["ic_op"] != ControllerTask.IC_DELETE
- ):
+ elif current_status == RunStatus.FAILED and task["ic_op"] != ControllerTask.IC_DELETE:
  # ignore all tasks except DELETE for failed runs
- self.ic_logger.warning(
- f"Ignoring task {task['ic_op'].value} for failed run {run_id}"
- )
+ self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for failed run {run_id}")
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
  elif current_status == RunStatus.DELETED:
  # ignore all tasks for deleted runs
- self.ic_logger.warning(
- f"Ignoring task {task['ic_op'].value} for deleted run {run_id}"
- )
+ self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for deleted run {run_id}")
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
  else:
  # valid ic_op for this run
  # mark prev task as completed
  if run_states[run_id]["task_id"] is not None:
- self.db.set_ic_ops_task_status(
- run_states[run_id]["task_id"], TaskStatus.COMPLETED
- )
+ self.db.set_ic_ops_task_status(run_states[run_id]["task_id"], TaskStatus.COMPLETED)

  # add new task to run states
  if task["ic_op"] == ControllerTask.IC_STOP:
@@ -445,26 +402,20 @@
  updated_status = RunStatus.ONGOING
  info_msg = f"Received RESUME task for run {run_id}"
  else:
- self.db.set_ic_ops_task_status(
- task["task_id"], TaskStatus.FAILED
- )
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
  raise ValueError(f"Unsupported task {task['ic_op']}")
  run_states[run_id].update(
  {
  "task_id": task["task_id"],
  "task": task["ic_op"],
- "status": (
- updated_status if updated_status else current_status
- ),
+ "status": (updated_status if updated_status else current_status),
  }
  )
  self.ic_logger.info(info_msg)

  return run_states, clone_modify_tasks

- def _get_total_step(
- self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int
- ) -> int:
+ def _get_total_step(self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int) -> int:
  """Get the total number of steps for a run."""
  num_train_epochs = config_leaf["training_args"].get("num_train_epochs", 1)

@@ -474,25 +425,20 @@
  # ceil to nearest chunk multiple
  total_steps = config_leaf["training_args"]["max_steps"]
  elif num_train_epochs:
- per_device_train_batch_size = config_leaf["training_args"].get(
- "per_device_train_batch_size", 1
- )
- gradient_accumulation_steps = config_leaf["training_args"].get(
- "gradient_accumulation_steps", 1
- )
+ per_device_train_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1)
+ gradient_accumulation_steps = config_leaf["training_args"].get("gradient_accumulation_steps", 1)
  len_dataloader = math.ceil(len_train_dataset / per_device_train_batch_size)
  num_update_steps_per_epoch = max(
- len_dataloader // gradient_accumulation_steps
- + int(len_dataloader % gradient_accumulation_steps > 0),
+ len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
  1,
  )
  total_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)

  if config_leaf.get("trainer_type", "SFT") == "GRPO":
  num_generations = config_leaf["training_args"].get("num_generations", 8)
- total_steps = (
- num_generations * len_train_dataset * num_train_epochs
- ) // (gradient_accumulation_steps * per_device_train_batch_size)
+ total_steps = (num_generations * len_train_dataset * num_train_epochs) // (
+ gradient_accumulation_steps * per_device_train_batch_size
+ )
  return total_steps

  def run_fit(
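The reformatted _get_total_step above keeps the original arithmetic. A standalone sketch of that calculation with the config-dict lookups inlined as plain arguments (the helper name and example numbers are illustrative, not from the package; the max_steps short-circuit is omitted):

import math

def total_steps_sketch(
    len_train_dataset: int,
    num_train_epochs: int,
    per_device_train_batch_size: int = 1,
    gradient_accumulation_steps: int = 1,
    trainer_type: str = "SFT",
    num_generations: int = 8,
) -> int:
    # epoch-based branch: batches per epoch, then optimizer updates per epoch
    len_dataloader = math.ceil(len_train_dataset / per_device_train_batch_size)
    num_update_steps_per_epoch = max(
        len_dataloader // gradient_accumulation_steps
        + int(len_dataloader % gradient_accumulation_steps > 0),
        1,
    )
    total_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)
    if trainer_type == "GRPO":
        # GRPO branch scales with generations per prompt and divides by the effective batch size
        total_steps = (num_generations * len_train_dataset * num_train_epochs) // (
            gradient_accumulation_steps * per_device_train_batch_size
        )
    return total_steps

# 1,000 training examples, batch size 4, gradient accumulation 2, 3 epochs:
#   SFT  -> ceil(1000/4) = 250 batches, 125 updates per epoch, 375 total steps
#   GRPO -> (8 * 1000 * 3) // (2 * 4) = 3000 total steps
print(total_steps_sketch(1000, 3, 4, 2))          # 375
print(total_steps_sketch(1000, 3, 4, 2, "GRPO"))  # 3000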
@@ -508,9 +454,7 @@

  # set experiment task to create models
  self.db.set_experiment_current_task(ExperimentTask.CREATE_MODELS)
- self.logger.debug(
- f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}."
- )
+ self.logger.debug(f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}.")

  # save train and eval dataset objects to a file for workers to load
  try:
@@ -595,10 +539,7 @@
  )

  # skip if task is the same as previous iteration (no change in status) or run is not active
- if (
- current_task_tuple == prev_task_tuple
- or worker_task["run_id"] not in scheduler.run_ids
- ):
+ if current_task_tuple == prev_task_tuple or worker_task["run_id"] not in scheduler.run_ids:
  continue

  if worker_task["status"] == TaskStatus.COMPLETED:
@@ -611,9 +552,7 @@
  run_id = worker_task["run_id"]
  chunk_id = worker_task["chunk_id"]
  run_details = all_run_details[run_id]
- self.logger.debug(
- f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}"
- )
+ self.logger.debug(f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}")
  self.logger.info(
  f"Run {run_id} completed steps - {run_details['completed_steps']}/{run_details['total_steps']}"
  )
@@ -635,11 +574,7 @@

  # Update progress
  progress_percentage = (
- (
- run_details["completed_steps"]
- / run_details["total_steps"]
- * 100
- )
+ (run_details["completed_steps"] / run_details["total_steps"] * 100)
  if run_details["total_steps"] > 0
  else 0
  )
@@ -660,16 +595,11 @@
  )
  # Check if run has completed only current epoch (hasn't reached total_steps yet)
  elif (
- new_chunks_visited == num_chunks
- and run_details["completed_steps"] < run_details["total_steps"]
+ new_chunks_visited == num_chunks and run_details["completed_steps"] < run_details["total_steps"]
  ):
  scheduler.reset_run(run_id)
- self.db.set_run_details(
- run_id=run_id, num_chunks_visited_curr_epoch=0
- )
- self.logger.info(
- f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)"
- )
+ self.db.set_run_details(run_id=run_id, num_chunks_visited_curr_epoch=0)
+ self.logger.info(f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)")

  # Check for failed runs and update scheduler, local state, shm
  for worker_task in failed_tasks:
@@ -685,12 +615,8 @@

  # Process interactive control tasks (this fetches latest run states internally)
  try:
- currently_scheduled_runs = list(
- scheduler.worker_running_current_run.values()
- )
- run_states, clone_modify_tasks = self._process_interm_ic_ops_states(
- currently_scheduled_runs
- )
+ currently_scheduled_runs = list(scheduler.worker_running_current_run.values())
+ run_states, clone_modify_tasks = self._process_interm_ic_ops_states(currently_scheduled_runs)
  self._process_interactive_control(
  run_states,
  clone_modify_tasks,
@@ -699,9 +625,7 @@
  num_chunks,
  )
  except Exception as e:
- raise ControllerException(
- f"Error processing interactive control tasks: {e}"
- ) from e
+ raise ControllerException(f"Error processing interactive control tasks: {e}") from e

  # fetch latest run states again (post IC ops states)
  all_run_details = self.db.get_all_runs()
@@ -715,8 +639,7 @@
  self.logger.debug(f"Added run {run_id} to scheduler with {chunks_visited} chunks visited")
  # remove inactive runs from scheduler
  elif (
- run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED)
- and run_id in scheduler.run_ids
+ run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED) and run_id in scheduler.run_ids
  ):
  scheduler.remove_run(run_id)
  self.logger.debug(f"Removed run {run_id} from scheduler")
@@ -729,9 +652,7 @@

  # Check termination condition
  if run_id is None and worker_id is None and chunk_id is None:
- self.logger.info(
- "Scheduler indicates all runs have completed all chunks"
- )
+ self.logger.info("Scheduler indicates all runs have completed all chunks")
  all_done = True
  break

@@ -753,9 +674,7 @@
  chunk_id,
  config_options={"create_model_fn": create_model_fn},
  )
- self.logger.debug(
- f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}"
- )
+ self.logger.debug(f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}")

  # Small delay
  time.sleep(1)
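The new rapidfireai/utils/metric_logger.py (346 added lines) is listed above but its contents are not shown in this diff, so the classes it defines are not visible here. From the controller call sites, the logger only needs the duck-typed surface below; get_experiment is optional and is probed with hasattr() when the MLflow backend is in use. A Protocol sketch of that assumed interface (method names inferred from usage, not copied from the module):

from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class MetricLogger(Protocol):
    # assumed tracking-backend surface exercised by Controller and Worker in this diff
    def create_run(self, run_name: str) -> str: ...
    def log_param(self, run_id: str, key: str, value: Any) -> None: ...
    def end_run(self, run_id: str) -> None: ...
    def delete_run(self, run_id: str) -> None: ...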
rapidfireai/backend/worker.py

@@ -24,11 +24,20 @@ from rapidfireai.ml.checkpoint_utils import (
  save_model_to_shared_memory,
  )
  from rapidfireai.ml.trainer import create_trainer_instance
- from rapidfireai.utils.constants import MLFLOW_URL, USE_SHARED_MEMORY, RunStatus, SHMObjectType, TaskStatus, WorkerTask
+ from rapidfireai.utils.constants import (
+ MLFLOW_URL,
+ TENSORBOARD_LOG_DIR,
+ USE_SHARED_MEMORY,
+ RunStatus,
+ SHMObjectType,
+ TaskStatus,
+ WorkerTask,
+ get_tracking_backend,
+ )
  from rapidfireai.utils.datapaths import DataPath
  from rapidfireai.utils.exceptions import WorkerException
  from rapidfireai.utils.logging import RFLogger, TrainingLogger
- from rapidfireai.utils.mlflow_manager import MLflowManager
+ from rapidfireai.utils.metric_logger import create_metric_logger
  from rapidfireai.utils.serialize import decode_db_payload
  from rapidfireai.utils.shm_manager import SharedMemoryManager
  from rapidfireai.utils.trainer_config import TrainerConfig
@@ -71,13 +80,20 @@ class Worker:
  # get experiment name
  self.experiment_name: str = self.db.get_running_experiment()["experiment_name"]

- # create mlflow manager
- self.mlflow_manager: MLflowManager = MLflowManager(MLFLOW_URL)
- self.mlflow_manager.get_experiment(self.experiment_name)
-
  # initialize data paths
  DataPath.initialize(self.experiment_name, self.db.get_experiments_path(self.experiment_name))

+ # create metric logger
+ tensorboard_log_dir = TENSORBOARD_LOG_DIR or str(DataPath.experiments_path / "tensorboard_logs")
+ self.metric_logger = create_metric_logger(
+ backend=get_tracking_backend(),
+ mlflow_tracking_uri=MLFLOW_URL,
+ tensorboard_log_dir=tensorboard_log_dir,
+ )
+ # Get experiment if using MLflow
+ if hasattr(self.metric_logger, "get_experiment"):
+ self.metric_logger.get_experiment(self.experiment_name)
+
  # load datasets
  self.train_dataset, self.eval_dataset, self.num_chunks = self.load_datasets()
  self.len_train_dataset = len(self.train_dataset)
@@ -112,7 +128,9 @@
  # torch.manual_seed(run_details["seed"])
  # np.random.seed(run_details["seed"])
  # random.seed(run_details["seed"])
- effective_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1) * config_leaf["training_args"].get("gradient_accumulation_steps", 1)
+ effective_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1) * config_leaf[
+ "training_args"
+ ].get("gradient_accumulation_steps", 1)

  # fetch train dataset chunk
  train_dataset_chunker = DatasetChunks(
@@ -149,7 +167,7 @@
  stderr_buffer = StringIO()
  with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
  trainer_instance, base_model_name = create_trainer_instance(
- trainer_config, self.shm_manager, USE_SHARED_MEMORY, self.mlflow_manager, chunk_id
+ trainer_config, self.shm_manager, USE_SHARED_MEMORY, self.metric_logger, chunk_id
  )

  # if first time, save checkpoint to disk
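Both Controller and Worker now construct the tracking backend the same way. A condensed usage sketch of the calls shown in this diff (the experiment name, run name, parameter, and fallback directory below are placeholders):

from rapidfireai.utils.constants import MLFLOW_URL, TENSORBOARD_LOG_DIR, get_tracking_backend
from rapidfireai.utils.metric_logger import create_metric_logger

# fall back to a per-experiment directory when TENSORBOARD_LOG_DIR is unset, as both call sites do
tensorboard_log_dir = TENSORBOARD_LOG_DIR or "/tmp/my-experiment/tensorboard_logs"

metric_logger = create_metric_logger(
    backend=get_tracking_backend(),
    mlflow_tracking_uri=MLFLOW_URL,
    tensorboard_log_dir=tensorboard_log_dir,
)

# only the MLflow-backed logger exposes get_experiment, hence the hasattr probe
if hasattr(metric_logger, "get_experiment"):
    metric_logger.get_experiment("my-experiment")

run_id = metric_logger.create_run("run-0")
metric_logger.log_param(run_id, "learning_rate", 2e-4)
metric_logger.end_run(run_id)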