rapidfireai-0.10.3rc1-py3-none-any.whl → rapidfireai-0.11.1rc1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of rapidfireai might be problematic.
Files changed (26)
  1. rapidfireai/automl/grid_search.py +4 -5
  2. rapidfireai/automl/model_config.py +41 -37
  3. rapidfireai/automl/random_search.py +21 -33
  4. rapidfireai/backend/controller.py +54 -148
  5. rapidfireai/backend/worker.py +14 -3
  6. rapidfireai/cli.py +148 -136
  7. rapidfireai/experiment.py +22 -11
  8. rapidfireai/frontend/build/asset-manifest.json +3 -3
  9. rapidfireai/frontend/build/index.html +1 -1
  10. rapidfireai/frontend/build/static/js/{main.e7d3b759.js → main.58393d31.js} +3 -3
  11. rapidfireai/frontend/build/static/js/{main.e7d3b759.js.map → main.58393d31.js.map} +1 -1
  12. rapidfireai/ml/callbacks.py +10 -24
  13. rapidfireai/ml/trainer.py +37 -81
  14. rapidfireai/utils/constants.py +3 -1
  15. rapidfireai/utils/interactive_controller.py +40 -61
  16. rapidfireai/utils/logging.py +1 -2
  17. rapidfireai/utils/mlflow_manager.py +1 -0
  18. rapidfireai/utils/ping.py +4 -2
  19. rapidfireai/version.py +2 -2
  20. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/METADATA +1 -1
  21. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/RECORD +26 -26
  22. /rapidfireai/frontend/build/static/js/{main.e7d3b759.js.LICENSE.txt → main.58393d31.js.LICENSE.txt} +0 -0
  23. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/WHEEL +0 -0
  24. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/entry_points.txt +0 -0
  25. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
  26. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc1.dist-info}/top_level.txt +0 -0
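
To reproduce a file-level comparison like the listing above, here is a minimal sketch using only the Python standard library; the wheel paths are placeholder local filenames and are assumed to have been downloaded beforehand (for example with pip download):

# Minimal sketch (assumed local paths): list files added/removed between two wheels.
import zipfile

OLD_WHEEL = "rapidfireai-0.10.3rc1-py3-none-any.whl"  # placeholder path
NEW_WHEEL = "rapidfireai-0.11.1rc1-py3-none-any.whl"  # placeholder path

def members(path: str) -> set[str]:
    # A wheel is a zip archive; namelist() returns every file inside it.
    with zipfile.ZipFile(path) as zf:
        return set(zf.namelist())

old_files, new_files = members(OLD_WHEEL), members(NEW_WHEEL)
print("added:", sorted(new_files - old_files))
print("removed:", sorted(old_files - new_files))
print("present in both (unchanged or modified):", len(old_files & new_files))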
rapidfireai/backend/controller.py

@@ -70,15 +70,11 @@ class Controller:
  self.logger.debug(f"Found {self.num_workers} workers/GPUs.")

  # initialize shared manager and registry, create shared memory manager instance
- self.shm_manager: SharedMemoryManager = SharedMemoryManager(
- name="controller-shm"
- )
+ self.shm_manager: SharedMemoryManager = SharedMemoryManager(name="controller-shm")

  registry, process_lock = self.shm_manager.get_shm_objects()

  # create worker manager
- self.worker_manager: WorkerManager = WorkerManager(
- self.num_workers, registry, process_lock
- )
+ self.worker_manager: WorkerManager = WorkerManager(self.num_workers, registry, process_lock)

  # create metric logger
  # Initialize DataPath temporarily to get experiment path for tensorboard logs
@@ -92,7 +88,7 @@
  tensorboard_log_dir=tensorboard_log_dir,
  )
  # Get experiment if using MLflow
- if hasattr(self.metric_logger, 'get_experiment'):
+ if hasattr(self.metric_logger, "get_experiment"):
  self.metric_logger.get_experiment(self.experiment_name)

  self.logger.debug("Controller initialized")
@@ -116,14 +112,12 @@
  for config_leaf in config_leafs:
  flattened_config = get_flattened_config_leaf(config_leaf)
  # print("flattened_config: ",flattened_config)
- total_steps = self._get_total_step(
- config_leaf, len_train_dataset, num_chunks
- )
+ total_steps = self._get_total_step(config_leaf, len_train_dataset, num_chunks)

  # get clone modify info
  warm_started_from = clone_modify_info.get("warm_started_from") if clone_modify_info else None
  cloned_from = clone_modify_info.get("cloned_from") if clone_modify_info else None
- chunk_offset = clone_modify_info.get("chunk_offset",0) if clone_modify_info else 0
+ chunk_offset = clone_modify_info.get("chunk_offset", 0) if clone_modify_info else 0

  run_id = self.db.create_run(
  config_leaf=config_leaf,
@@ -143,22 +137,16 @@
  try:
  base_run_path = DataPath.base_run_path(run_id)
  work_dir_path = DataPath.work_dir_path(base_run_path)
- initial_checkpoint_path = DataPath.initial_checkpoint_path(
- base_run_path
- )
+ initial_checkpoint_path = DataPath.initial_checkpoint_path(base_run_path)
  final_checkpoint_path = DataPath.final_checkpoint_path(base_run_path)
- intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(
- base_run_path
- )
+ intermediate_checkpoint_path = DataPath.intermediate_checkpoint_path(base_run_path)

  Path.mkdir(work_dir_path, parents=True, exist_ok=True)
  Path.mkdir(initial_checkpoint_path, parents=True, exist_ok=True)
  Path.mkdir(final_checkpoint_path, parents=True, exist_ok=True)
  Path.mkdir(intermediate_checkpoint_path, parents=True, exist_ok=True)
  except (PermissionError, OSError) as e:
- raise ControllerException(
- f"Failed to create required Run DataPath directories: {e}"
- ) from e
+ raise ControllerException(f"Failed to create required Run DataPath directories: {e}") from e

  # create new tracking run
  try:
@@ -169,16 +157,10 @@
  for key, value in flattened_config.items():
  self.metric_logger.log_param(mlflow_run_id, key, value)
  if warm_started_from:
- self.metric_logger.log_param(
- mlflow_run_id, "warm-start", str(warm_started_from)
- )
+ self.metric_logger.log_param(mlflow_run_id, "warm-start", str(warm_started_from))
  if cloned_from:
- self.metric_logger.log_param(
- mlflow_run_id, "parent-run", str(cloned_from)
- )
- self.logger.debug(
- f"Populated MLFlow with model config info for run {run_id}."
- )
+ self.metric_logger.log_param(mlflow_run_id, "parent-run", str(cloned_from))
+ self.logger.debug(f"Populated MLFlow with model config info for run {run_id}.")
  self.db.set_run_details(
  run_id=run_id,
  mlflow_run_id=mlflow_run_id,
@@ -192,9 +174,7 @@
  self.logger.error(msg, exc_info=True)

  total_runs = len(runs)
- self.logger.info(
- f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}"
- )
+ self.logger.info(f"Created {total_runs} runs - \n{pformat(runs, indent=4, width=120)}")
  self.logger.debug(f"Got {total_runs} runs for {source.value}.")

  # set experiment task to run_fit
@@ -208,24 +188,17 @@

  # check if there are any other runs with the same base model
  base_model_name = self.db.get_run(run_id)["config_leaf"]["model_name"]
- relevant_runs = self.db.get_runs_by_status(
- [RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED]
- )
+ relevant_runs = self.db.get_runs_by_status([RunStatus.ONGOING, RunStatus.NEW, RunStatus.STOPPED])

  # get shared object types to delete - if no other runs are using it
  delete_shared_objects = True
  for r_run_id, r_run_details in relevant_runs.items():
- if (
- r_run_details["config_leaf"]["model_name"] == base_model_name
- and r_run_id != run_id
- ):
+ if r_run_details["config_leaf"]["model_name"] == base_model_name and r_run_id != run_id:
  delete_shared_objects = False
  break

  # delete model object from shared memory
- self.shm_manager.delete_model_object(
- run_id, base_model_name if delete_shared_objects else None
- )
+ self.shm_manager.delete_model_object(run_id, base_model_name if delete_shared_objects else None)

  def _process_interactive_control(
  self,
@@ -250,9 +223,7 @@
  status=RunStatus.STOPPED,
  ended_by=RunEndedBy.INTERACTIVE_CONTROL,
  )
- self.db.set_ic_ops_task_status(
- run_state["task_id"], TaskStatus.COMPLETED
- )
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
  self.ic_logger.info(f"Stopping run {run_id} by Interactive Control")
  elif run_state["status"] == RunStatus.DELETED:
  # process deleted tasks
@@ -269,9 +240,7 @@
  status=RunStatus.DELETED,
  ended_by=RunEndedBy.INTERACTIVE_CONTROL,
  )
- self.db.set_ic_ops_task_status(
- run_state["task_id"], TaskStatus.COMPLETED
- )
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
  self.ic_logger.info(f"Deleting run {run_id} by Interactive Control")
  elif run_state["status"] == RunStatus.ONGOING:
  # process ongoing tasks
@@ -280,15 +249,11 @@
  status=RunStatus.ONGOING,
  ended_by="",
  )
- self.db.set_ic_ops_task_status(
- run_state["task_id"], TaskStatus.COMPLETED
- )
+ self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.COMPLETED)
  self.ic_logger.info(f"Resuming run {run_id} by Interactive Control")
  elif run_state["status"] == RunStatus.COMPLETED:
  # process completed tasks
- self.logger.warning(
- f"Run {run_id} is already completed. Skipping Interactive Control task."
- )
+ self.logger.warning(f"Run {run_id} is already completed. Skipping Interactive Control task.")
  self.db.set_ic_ops_task_status(run_state["task_id"], TaskStatus.SKIPPED)
  else:
  raise ValueError(f"Unsupported run status {run_state['status']}")
@@ -304,9 +269,7 @@
  # add additional_kwargs to config_leaf if it exists in the parent run
  parent_run_details = self.db.get_run(parent_run_id)
  if "additional_kwargs" in parent_run_details["config_leaf"]:
- config_leaf["additional_kwargs"] = parent_run_details["config_leaf"][
- "additional_kwargs"
- ]
+ config_leaf["additional_kwargs"] = parent_run_details["config_leaf"]["additional_kwargs"]

  # create model for the new run
  try:
@@ -324,7 +287,9 @@
  )
  elif ic_op == ControllerTask.IC_CLONE_MODIFY_WARM:
  # calculate clone chunk offset
- effective_batch_size = parent_run_details["config_leaf"]["training_args"].get("per_device_train_batch_size", 1) * parent_run_details["config_leaf"]["training_args"].get("gradient_accumulation_steps", 1)
+ effective_batch_size = parent_run_details["config_leaf"]["training_args"].get(
+ "per_device_train_batch_size", 1
+ ) * parent_run_details["config_leaf"]["training_args"].get("gradient_accumulation_steps", 1)
  chunker = DatasetChunks(
  len_train_dataset,
  num_chunks,
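
For context, the effective batch size computed in the hunk above is simply the per-device batch size times the gradient accumulation steps. A minimal standalone sketch follows; the config_leaf dict shape is taken from the surrounding diff, and the literal values are made up for illustration:

# Sketch of the effective batch size used when calculating the clone chunk offset.
# Defaults of 1 mirror the .get(..., 1) calls visible in the diff.
config_leaf = {"training_args": {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 8}}

training_args = config_leaf["training_args"]
effective_batch_size = training_args.get("per_device_train_batch_size", 1) * training_args.get(
    "gradient_accumulation_steps", 1
)
print(effective_batch_size)  # 32 samples consumed per optimizer step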
@@ -355,12 +320,8 @@
  )
  except Exception as e:
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
- self.ic_logger.error(
- f"Error creating model for run {parent_run_id}: {e}"
- )
- raise ControllerException(
- f"Error creating model for run {parent_run_id}: {e}"
- ) from e
+ self.ic_logger.error(f"Error creating model for run {parent_run_id}: {e}")
+ raise ControllerException(f"Error creating model for run {parent_run_id}: {e}") from e

  def _process_interm_ic_ops_states(
  self,
@@ -389,11 +350,7 @@
  if is_clone_modify_task:
  # clone_modify tasks
  # get latest run state
- run_status = (
- run_states[run_id]["status"]
- if run_id in run_states
- else self.db.get_run(run_id)["status"]
- )
+ run_status = run_states[run_id]["status"] if run_id in run_states else self.db.get_run(run_id)["status"]

  # track clone_modify tasks only for non-deleted runs
  if run_status != RunStatus.DELETED:
@@ -401,9 +358,7 @@
  self.ic_logger.info(f"Added {task['ic_op']} task for run {run_id}.")
  else:
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
- self.ic_logger.warning(
- f"Skipping {task['ic_op']} task for deleted run {run_id}."
- )
+ self.ic_logger.warning(f"Skipping {task['ic_op']} task for deleted run {run_id}.")
  else:
  # Non clone_modify tasks
  if run_id not in run_states:
@@ -420,32 +375,21 @@
  ControllerTask.IC_STOP,
  ]:
  # ignore RESUME/STOP tasks for completed runs
- self.ic_logger.warning(
- f"Ignoring RESUME/STOP task for run {run_id} as it is already completed"
- )
+ self.ic_logger.warning(f"Ignoring RESUME/STOP task for run {run_id} as it is already completed")
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
- elif (
- current_status == RunStatus.FAILED
- and task["ic_op"] != ControllerTask.IC_DELETE
- ):
+ elif current_status == RunStatus.FAILED and task["ic_op"] != ControllerTask.IC_DELETE:
  # ignore all tasks except DELETE for failed runs
- self.ic_logger.warning(
- f"Ignoring task {task['ic_op'].value} for failed run {run_id}"
- )
+ self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for failed run {run_id}")
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
  elif current_status == RunStatus.DELETED:
  # ignore all tasks for deleted runs
- self.ic_logger.warning(
- f"Ignoring task {task['ic_op'].value} for deleted run {run_id}"
- )
+ self.ic_logger.warning(f"Ignoring task {task['ic_op'].value} for deleted run {run_id}")
  self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.SKIPPED)
  else:
  # valid ic_op for this run
  # mark prev task as completed
  if run_states[run_id]["task_id"] is not None:
- self.db.set_ic_ops_task_status(
- run_states[run_id]["task_id"], TaskStatus.COMPLETED
- )
+ self.db.set_ic_ops_task_status(run_states[run_id]["task_id"], TaskStatus.COMPLETED)

  # add new task to run states
  if task["ic_op"] == ControllerTask.IC_STOP:
@@ -458,26 +402,20 @@
  updated_status = RunStatus.ONGOING
  info_msg = f"Received RESUME task for run {run_id}"
  else:
- self.db.set_ic_ops_task_status(
- task["task_id"], TaskStatus.FAILED
- )
+ self.db.set_ic_ops_task_status(task["task_id"], TaskStatus.FAILED)
  raise ValueError(f"Unsupported task {task['ic_op']}")
  run_states[run_id].update(
  {
  "task_id": task["task_id"],
  "task": task["ic_op"],
- "status": (
- updated_status if updated_status else current_status
- ),
+ "status": (updated_status if updated_status else current_status),
  }
  )
  self.ic_logger.info(info_msg)

  return run_states, clone_modify_tasks

- def _get_total_step(
- self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int
- ) -> int:
+ def _get_total_step(self, config_leaf: dict[str, Any], len_train_dataset: int, num_chunks: int) -> int:
  """Get the total number of steps for a run."""
  num_train_epochs = config_leaf["training_args"].get("num_train_epochs", 1)

@@ -487,25 +425,20 @@
  # ceil to nearest chunk multiple
  total_steps = config_leaf["training_args"]["max_steps"]
  elif num_train_epochs:
- per_device_train_batch_size = config_leaf["training_args"].get(
- "per_device_train_batch_size", 1
- )
- gradient_accumulation_steps = config_leaf["training_args"].get(
- "gradient_accumulation_steps", 1
- )
+ per_device_train_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1)
+ gradient_accumulation_steps = config_leaf["training_args"].get("gradient_accumulation_steps", 1)
  len_dataloader = math.ceil(len_train_dataset / per_device_train_batch_size)
  num_update_steps_per_epoch = max(
- len_dataloader // gradient_accumulation_steps
- + int(len_dataloader % gradient_accumulation_steps > 0),
+ len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
  1,
  )
  total_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)

  if config_leaf.get("trainer_type", "SFT") == "GRPO":
  num_generations = config_leaf["training_args"].get("num_generations", 8)
- total_steps = (
- num_generations * len_train_dataset * num_train_epochs
- ) // (gradient_accumulation_steps * per_device_train_batch_size)
+ total_steps = (num_generations * len_train_dataset * num_train_epochs) // (
+ gradient_accumulation_steps * per_device_train_batch_size
+ )
  return total_steps

  def run_fit(
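
For reference, the step-count logic reformatted in the hunk above can be read as the following standalone sketch. total_steps_sketch is a hypothetical helper written for illustration; the condition guarding max_steps is inferred from the visible lines, and the real _get_total_step may differ outside this hunk:

import math

def total_steps_sketch(config_leaf: dict, len_train_dataset: int) -> int:
    # Mirrors the branches visible in the hunk: an explicit max_steps wins,
    # otherwise steps are derived from epochs, batch size and gradient accumulation,
    # with a separate formula for GRPO-style trainers.
    args = config_leaf["training_args"]
    num_train_epochs = args.get("num_train_epochs", 1)
    if args.get("max_steps"):
        return args["max_steps"]
    per_device = args.get("per_device_train_batch_size", 1)
    grad_accum = args.get("gradient_accumulation_steps", 1)
    len_dataloader = math.ceil(len_train_dataset / per_device)
    updates_per_epoch = max(len_dataloader // grad_accum + int(len_dataloader % grad_accum > 0), 1)
    steps = math.ceil(num_train_epochs * updates_per_epoch)
    if config_leaf.get("trainer_type", "SFT") == "GRPO":
        num_generations = args.get("num_generations", 8)
        steps = (num_generations * len_train_dataset * num_train_epochs) // (grad_accum * per_device)
    return steps

# Example: 1000 samples, batch 4, grad accum 8, 3 epochs ->
# 250 batches per epoch, 250 // 8 + 1 = 32 optimizer updates per epoch, 3 * 32 = 96 steps.
print(total_steps_sketch(
    {"training_args": {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 8, "num_train_epochs": 3}},
    len_train_dataset=1000,
))  # 96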
@@ -521,9 +454,7 @@

  # set experiment task to create models
  self.db.set_experiment_current_task(ExperimentTask.CREATE_MODELS)
- self.logger.debug(
- f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}."
- )
+ self.logger.debug(f"Set experiment task to {ExperimentTask.CREATE_MODELS.value}.")

  # save train and eval dataset objects to a file for workers to load
  try:
@@ -608,10 +539,7 @@
  )

  # skip if task is the same as previous iteration (no change in status) or run is not active
- if (
- current_task_tuple == prev_task_tuple
- or worker_task["run_id"] not in scheduler.run_ids
- ):
+ if current_task_tuple == prev_task_tuple or worker_task["run_id"] not in scheduler.run_ids:
  continue

  if worker_task["status"] == TaskStatus.COMPLETED:
@@ -624,9 +552,7 @@
  run_id = worker_task["run_id"]
  chunk_id = worker_task["chunk_id"]
  run_details = all_run_details[run_id]
- self.logger.debug(
- f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}"
- )
+ self.logger.debug(f"Completed task: run {run_id}, chunk {chunk_id} on worker {worker_id}")
  self.logger.info(
  f"Run {run_id} completed steps - {run_details['completed_steps']}/{run_details['total_steps']}"
  )
@@ -648,11 +574,7 @@

  # Update progress
  progress_percentage = (
- (
- run_details["completed_steps"]
- / run_details["total_steps"]
- * 100
- )
+ (run_details["completed_steps"] / run_details["total_steps"] * 100)
  if run_details["total_steps"] > 0
  else 0
  )
@@ -673,16 +595,11 @@
  )
  # Check if run has completed only current epoch (hasn't reached total_steps yet)
  elif (
- new_chunks_visited == num_chunks
- and run_details["completed_steps"] < run_details["total_steps"]
+ new_chunks_visited == num_chunks and run_details["completed_steps"] < run_details["total_steps"]
  ):
  scheduler.reset_run(run_id)
- self.db.set_run_details(
- run_id=run_id, num_chunks_visited_curr_epoch=0
- )
- self.logger.info(
- f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)"
- )
+ self.db.set_run_details(run_id=run_id, num_chunks_visited_curr_epoch=0)
+ self.logger.info(f"Run {run_id} has completed epoch ({new_chunks_visited}/{num_chunks} chunks)")

  # Check for failed runs and update scheduler, local state, shm
  for worker_task in failed_tasks:
@@ -698,12 +615,8 @@

  # Process interactive control tasks (this fetches latest run states internally)
  try:
- currently_scheduled_runs = list(
- scheduler.worker_running_current_run.values()
- )
- run_states, clone_modify_tasks = self._process_interm_ic_ops_states(
- currently_scheduled_runs
- )
+ currently_scheduled_runs = list(scheduler.worker_running_current_run.values())
+ run_states, clone_modify_tasks = self._process_interm_ic_ops_states(currently_scheduled_runs)
  self._process_interactive_control(
  run_states,
  clone_modify_tasks,
@@ -712,9 +625,7 @@
  num_chunks,
  )
  except Exception as e:
- raise ControllerException(
- f"Error processing interactive control tasks: {e}"
- ) from e
+ raise ControllerException(f"Error processing interactive control tasks: {e}") from e

  # fetch latest run states again (post IC ops states)
  all_run_details = self.db.get_all_runs()
@@ -728,8 +639,7 @@
  self.logger.debug(f"Added run {run_id} to scheduler with {chunks_visited} chunks visited")
  # remove inactive runs from scheduler
  elif (
- run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED)
- and run_id in scheduler.run_ids
+ run_details["status"] in (RunStatus.STOPPED, RunStatus.DELETED) and run_id in scheduler.run_ids
  ):
  scheduler.remove_run(run_id)
  self.logger.debug(f"Removed run {run_id} from scheduler")
@@ -742,9 +652,7 @@

  # Check termination condition
  if run_id is None and worker_id is None and chunk_id is None:
- self.logger.info(
- "Scheduler indicates all runs have completed all chunks"
- )
+ self.logger.info("Scheduler indicates all runs have completed all chunks")
  all_done = True
  break

@@ -766,9 +674,7 @@
  chunk_id,
  config_options={"create_model_fn": create_model_fn},
  )
- self.logger.debug(
- f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}"
- )
+ self.logger.debug(f"Scheduled run {run_id} on worker {worker_id} for chunk {chunk_id}")

  # Small delay
  time.sleep(1)
rapidfireai/backend/worker.py

@@ -24,7 +24,16 @@ from rapidfireai.ml.checkpoint_utils import (
  save_model_to_shared_memory,
  )
  from rapidfireai.ml.trainer import create_trainer_instance
- from rapidfireai.utils.constants import MLFLOW_URL, USE_SHARED_MEMORY, TENSORBOARD_LOG_DIR, RunStatus, SHMObjectType, TaskStatus, WorkerTask, get_tracking_backend
+ from rapidfireai.utils.constants import (
+ MLFLOW_URL,
+ TENSORBOARD_LOG_DIR,
+ USE_SHARED_MEMORY,
+ RunStatus,
+ SHMObjectType,
+ TaskStatus,
+ WorkerTask,
+ get_tracking_backend,
+ )
  from rapidfireai.utils.datapaths import DataPath
  from rapidfireai.utils.exceptions import WorkerException
  from rapidfireai.utils.logging import RFLogger, TrainingLogger
@@ -82,7 +91,7 @@ class Worker:
  tensorboard_log_dir=tensorboard_log_dir,
  )
  # Get experiment if using MLflow
- if hasattr(self.metric_logger, 'get_experiment'):
+ if hasattr(self.metric_logger, "get_experiment"):
  self.metric_logger.get_experiment(self.experiment_name)

  # load datasets
@@ -119,7 +128,9 @@
  # torch.manual_seed(run_details["seed"])
  # np.random.seed(run_details["seed"])
  # random.seed(run_details["seed"])
- effective_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1) * config_leaf["training_args"].get("gradient_accumulation_steps", 1)
+ effective_batch_size = config_leaf["training_args"].get("per_device_train_batch_size", 1) * config_leaf[
+ "training_args"
+ ].get("gradient_accumulation_steps", 1)

  # fetch train dataset chunk
  train_dataset_chunker = DatasetChunks(
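
The DatasetChunks call above is truncated in this hunk, so its actual arguments and behaviour are not shown. Purely as an illustration of contiguous dataset chunking, a sketch follows; the names and logic below are assumptions for exposition, not the rapidfireai API:

# Illustrative only: split `length` indices into `num_chunks` near-equal contiguous ranges.
# This is NOT the rapidfireai DatasetChunks class, whose arguments are truncated in the hunk above.
def contiguous_chunks(length: int, num_chunks: int) -> list[range]:
    base, extra = divmod(length, num_chunks)
    chunks, start = [], 0
    for i in range(num_chunks):
        size = base + (1 if i < extra else 0)  # spread the remainder over the first chunks
        chunks.append(range(start, start + size))
        start += size
    return chunks

print([len(c) for c in contiguous_chunks(1000, 3)])  # [334, 333, 333]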