rapidfireai 0.10.2rc5__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rapidfireai might be problematic.

Files changed (36)
  1. rapidfireai/automl/grid_search.py +4 -5
  2. rapidfireai/automl/model_config.py +41 -37
  3. rapidfireai/automl/random_search.py +21 -33
  4. rapidfireai/backend/controller.py +80 -161
  5. rapidfireai/backend/worker.py +26 -8
  6. rapidfireai/cli.py +171 -132
  7. rapidfireai/db/rf_db.py +1 -1
  8. rapidfireai/db/tables.sql +1 -1
  9. rapidfireai/dispatcher/dispatcher.py +3 -1
  10. rapidfireai/dispatcher/gunicorn.conf.py +1 -1
  11. rapidfireai/experiment.py +86 -7
  12. rapidfireai/frontend/build/asset-manifest.json +3 -3
  13. rapidfireai/frontend/build/index.html +1 -1
  14. rapidfireai/frontend/build/static/js/{main.1bf27639.js → main.58393d31.js} +3 -3
  15. rapidfireai/frontend/build/static/js/{main.1bf27639.js.map → main.58393d31.js.map} +1 -1
  16. rapidfireai/frontend/proxy_middleware.py +1 -1
  17. rapidfireai/ml/callbacks.py +85 -59
  18. rapidfireai/ml/trainer.py +42 -86
  19. rapidfireai/start.sh +117 -34
  20. rapidfireai/utils/constants.py +22 -1
  21. rapidfireai/utils/experiment_utils.py +87 -43
  22. rapidfireai/utils/interactive_controller.py +473 -0
  23. rapidfireai/utils/logging.py +1 -2
  24. rapidfireai/utils/metric_logger.py +346 -0
  25. rapidfireai/utils/mlflow_manager.py +0 -1
  26. rapidfireai/utils/ping.py +4 -2
  27. rapidfireai/utils/worker_manager.py +16 -6
  28. rapidfireai/version.py +2 -2
  29. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/METADATA +7 -4
  30. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/RECORD +36 -33
  31. tutorial_notebooks/rf-colab-tensorboard-tutorial.ipynb +314 -0
  32. /rapidfireai/frontend/build/static/js/{main.1bf27639.js.LICENSE.txt → main.58393d31.js.LICENSE.txt} +0 -0
  33. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/WHEEL +0 -0
  34. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/entry_points.txt +0 -0
  35. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
  36. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/top_level.txt +0 -0
rapidfireai/frontend/proxy_middleware.py CHANGED
@@ -25,7 +25,7 @@ class UserProxyManager:
         self.default_proxy = {
             'main_proxy_target': 'http://127.0.0.1:5002/',
             'static_proxy_target': 'http://127.0.0.1:5002/',
-            'dispatcher_proxy_target': 'http://127.0.0.1:8080/',
+            'dispatcher_proxy_target': 'http://127.0.0.1:8081/',
         }

     def get_user_proxy(self, user_id: str) -> Dict[str, str]:
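For orientation, here is a minimal sketch of how a proxy map like default_proxy is typically consumed. The body of get_user_proxy is not part of this diff, so the per-user override store and the merge behavior below are assumptions, not the package's actual implementation:

from typing import Dict

class UserProxyManager:
    def __init__(self) -> None:
        # Defaults as of 0.11.1rc1: the dispatcher target moved from port 8080 to 8081
        self.default_proxy = {
            'main_proxy_target': 'http://127.0.0.1:5002/',
            'static_proxy_target': 'http://127.0.0.1:5002/',
            'dispatcher_proxy_target': 'http://127.0.0.1:8081/',
        }
        # Hypothetical per-user override store (not shown in the diff)
        self._user_proxies: Dict[str, Dict[str, str]] = {}

    def get_user_proxy(self, user_id: str) -> Dict[str, str]:
        # Plausible behavior: per-user targets layered over the defaults
        return {**self.default_proxy, **self._user_proxies.get(user_id, {})}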
rapidfireai/ml/callbacks.py CHANGED
@@ -1,14 +1,9 @@
-from typing import Callable, Dict, List, Optional
+from collections.abc import Callable

 import torch
 from datasets import Dataset
 from tqdm import tqdm
-from transformers import (
-    TrainerCallback,
-    TrainerControl,
-    TrainerState,
-    TrainingArguments,
-)
+from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
 from transformers.trainer_utils import IntervalStrategy, SaveStrategy


@@ -17,10 +12,10 @@ class GenerationMetricsCallback(TrainerCallback):
         self,
         tokenizer,
         eval_dataset: Dataset,
-        generation_config: Optional[Dict] = None,
+        generation_config: dict | None = None,
         compute_metrics: Callable = None,
         batch_size: int = 8,
-        mlflow_manager=None,
+        metric_logger=None,
         mlflow_run_id: str = None,
         completed_steps: int = 0,
     ):
@@ -36,7 +31,7 @@ class GenerationMetricsCallback(TrainerCallback):
             "pad_token_id": tokenizer.pad_token_id,
             "eos_token_id": tokenizer.eos_token_id,
         }
-        self.mlflow_manager = mlflow_manager
+        self.metric_logger = metric_logger
         self.mlflow_run_id = mlflow_run_id
         self.completed_steps = completed_steps

@@ -63,8 +58,8 @@ class GenerationMetricsCallback(TrainerCallback):
         state.log_history.append(metrics)

         for key, value in metrics.items():
-            if self.mlflow_manager:
-                self.mlflow_manager.log_metric(
+            if self.metric_logger:
+                self.metric_logger.log_metric(
                     self.mlflow_run_id,
                     key,
                     value,
@@ -72,43 +67,69 @@ class GenerationMetricsCallback(TrainerCallback):
                 )

     def _prepare_data(self, eval_dataset: Dataset) -> tuple:
-        """Prepare batch data for generation"""
+        """Prepare batch data for generation with defensive validation"""
         input_texts = []
         references = []

         for item in eval_dataset:
-            if isinstance(item, dict):
-                if "input" in item and "output" in item:
-                    input_text = item["input"]
-                    reference = item["output"]
-                elif "prompt" in item and "completion" in item:
-                    input_text = item["prompt"]
-                    reference = item["completion"][-1]["content"]
-                    input_text = self.tokenizer.apply_chat_template(
-                        input_text, tokenize=False
-                    )
-                else:
-                    continue
-
-                input_texts.append(input_text)
-                references.append(reference)
+            if not isinstance(item, dict):
+                continue
+
+            input_text = None
+            reference = None
+
+            # Support multiple field name patterns
+            if "input" in item and "output" in item:
+                input_text = item["input"]
+                reference = item["output"]
+            elif "prompt" in item and "completion" in item:
+                input_text = item["prompt"]
+                reference = item["completion"][-1]["content"]
+                input_text = self.tokenizer.apply_chat_template(input_text, tokenize=False)
+            elif "text" in item:
+                # SFT format - use text as input, response as reference
+                input_text = item["text"]
+                reference = item.get("response", item.get("instruction", item["text"]))
+            elif "instruction" in item and "response" in item:
+                # Direct instruction/response format
+                input_text = item["instruction"]
+                reference = item["response"]
+
+            # Validate non-empty strings
+            if input_text and isinstance(input_text, str) and input_text.strip():
+                if reference and isinstance(reference, str) and reference.strip():
+                    input_texts.append(input_text.strip())
+                    references.append(reference.strip())
+
+        # Return safe empty values to prevent downstream errors
+        if not input_texts:
+            return [], []

         return input_texts, references

-    def _generate_batch(self, model, input_texts: List[str]) -> List[str]:
-        """Generate text for a batch of inputs"""
-        # Tokenize batch
-        inputs = self.tokenizer(
-            input_texts,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=512,  # Adjust based on your model's context length
-        ).to(model.device)
+    def _generate_batch(self, model, input_texts: list[str]) -> torch.Tensor:
+        """Generate text for a batch of inputs with defensive validation"""
+        # Defensive validation for empty inputs
+        if not input_texts:
+            return torch.empty((0, 0), dtype=torch.long).to(model.device)

-        return inputs["input_ids"]
-
-    def _compute_generation_metrics(self, model, step: int) -> Dict[str, float]:
+        try:
+            # Tokenize batch
+            inputs = self.tokenizer(
+                input_texts,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512,  # Adjust based on your model's context length
+            ).to(model.device)
+
+            return inputs["input_ids"]
+        except Exception as e:
+            # Log error and return empty tensor to prevent crash
+            print(f"Warning: Tokenization error in generation callback: {e}")
+            return torch.empty((0, 0), dtype=torch.long).to(model.device)
+
+    def _compute_generation_metrics(self, model, step: int) -> dict[str, float]:
         """Generate text and compute BLEU/ROUGE metrics with batch processing"""
         model.eval()

@@ -121,16 +142,24 @@ class GenerationMetricsCallback(TrainerCallback):

         # Process in batches
         input_texts, batch_references = self._prepare_data(self.eval_dataset)
+
+        # Early return if no valid data
+        if not input_texts:
+            print("Warning: No valid eval data for generation metrics")
+            return {}
+
         input_ids = self._generate_batch(model, input_texts)
+
+        # Check for empty generation batch
+        if input_ids.numel() == 0:
+            print("Warning: Empty input_ids from tokenization")
+            return {}
+
         with torch.no_grad():
-            for i in tqdm(
-                range(0, len(indices), self.batch_size), desc="Generating for metrics"
-            ):
+            for i in tqdm(range(0, len(indices), self.batch_size), desc="Generating for metrics"):
                 input_ids_batch = input_ids[i : i + self.batch_size]
                 with torch.inference_mode(), torch.cuda.amp.autocast():
-                    outputs_batch = model.generate(
-                        input_ids_batch, **self.generation_config
-                    )
+                    outputs_batch = model.generate(input_ids_batch, **self.generation_config)
                 generated_texts = self.tokenizer.batch_decode(
                     outputs_batch[:, input_ids_batch.shape[1] :],
                     skip_special_tokens=True,
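The rewritten _prepare_data accepts four record shapes, checked in the order shown in the hunk above. The toy rows below illustrate each pattern; they are our examples, not data from the package:

rows = [
    # input/output pairs
    {"input": "Translate: bonjour", "output": "hello"},
    # prompt/completion chat format: the prompt is rendered with the
    # tokenizer's chat template, the reference is the last completion turn
    {"prompt": [{"role": "user", "content": "Hi"}],
     "completion": [{"role": "assistant", "content": "Hello!"}]},
    # SFT "text" format: reference falls back response -> instruction -> text
    {"text": "Q: 2+2? A:", "response": "4"},
    # direct instruction/response pairs
    {"instruction": "Say hi", "response": "hi"},
    # non-dict rows and rows with empty strings are silently skipped
    {"input": "", "output": "dropped"},
]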
@@ -155,18 +184,18 @@


 class MLflowLoggingCallback(TrainerCallback):
-    """Callback for logging metrics to MLflow during training"""
+    """Callback for logging metrics to tracking backend during training"""

     def __init__(
         self,
-        mlflow_manager,
+        metric_logger,
         mlflow_run_id: str,
         excluded_keys: list = None,
         completed_steps: int = 0,
         chunk_id: int = 0,
         num_epochs_completed: int = 0,
     ):
-        self.mlflow_manager = mlflow_manager
+        self.metric_logger = metric_logger
         self.mlflow_run_id = mlflow_run_id
         self.completed_steps = completed_steps
         self.excluded_keys = excluded_keys or [
@@ -189,22 +218,22 @@ class MLflowLoggingCallback(TrainerCallback):
         for key, value in logs.items():
             if isinstance(value, (int, float)) and key not in self.excluded_keys:
                 try:
-                    self.mlflow_manager.log_metric(
+                    self.metric_logger.log_metric(
                         self.mlflow_run_id,
                         key,
                         value,
                         step=self.completed_steps + state.global_step,
                     )
                 except Exception as e:
-                    print(f"Warning: Failed to log metric {key} to MLflow: {e}")
+                    print(f"Warning: Failed to log metric {key} to tracking backend: {e}")
         if "eval_loss" not in logs and "train_runtime" not in logs:
-            self.mlflow_manager.log_metric(
+            self.metric_logger.log_metric(
                 self.mlflow_run_id,
                 "chunk number",
                 self.chunk_id,
                 step=self.completed_steps + state.global_step,
             )
-            self.mlflow_manager.log_metric(
+            self.metric_logger.log_metric(
                 self.mlflow_run_id,
                 "num_epochs_completed",
                 self.num_epochs_completed,
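Both callbacks only ever call log_metric on the injected object, so any backend satisfying that method works. The new rapidfireai/utils/metric_logger.py (+346 lines) is not shown in this diff; the Protocol below is an assumed reading of the interface implied by the call sites above, not the package's actual class:

from typing import Protocol

class MetricLogger(Protocol):
    # Interface inferred from the call sites; the real implementation may differ
    def log_metric(self, run_id: str, key: str, value: float, step: int = 0) -> None:
        ...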
@@ -217,7 +246,7 @@ class LogLevelCallback(TrainerCallback):
     A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
     """

-    def __init__(self, global_step_args: Dict):
+    def __init__(self, global_step_args: dict):
         self.eval_first_step = global_step_args.get("eval_first_step", 0)
         self.actual_steps = global_step_args.get("actual_steps", 0)
         self.log_first_step = global_step_args.get("log_first_step", 0)
@@ -275,10 +304,7 @@ class LogLevelCallback(TrainerCallback):
         control.should_log = True

         # Evaluate
-        if (
-            args.eval_strategy == IntervalStrategy.EPOCH
-            and args.eval_delay <= state.epoch
-        ):
+        if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch:
             control.should_evaluate = True

         # Save
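A hedged usage sketch of the renamed constructor parameter, mirroring the __init__ signature shown earlier in this file's diff; every variable here (tokenizer, eval_ds, my_bleu_rouge_fn, metric_logger, run_id, trainer) is a placeholder of ours, not package code:

callback = GenerationMetricsCallback(
    tokenizer=tokenizer,
    eval_dataset=eval_ds,
    generation_config={"max_new_tokens": 64},  # merged with pad/eos defaults
    compute_metrics=my_bleu_rouge_fn,
    batch_size=8,
    metric_logger=metric_logger,   # was mlflow_manager before 0.11.x
    mlflow_run_id=run_id,
)
trainer.add_callback(callback)  # standard HF Trainer API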
rapidfireai/ml/trainer.py CHANGED
@@ -1,4 +1,3 @@
-import logging
 import math
 import os

@@ -7,11 +6,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
 from transformers.utils.logging import set_verbosity_error
 from trl import DPOConfig, DPOTrainer, GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer

-from rapidfireai.ml.callbacks import (
-    GenerationMetricsCallback,
-    MLflowLoggingCallback,
-    LogLevelCallback,
-)
+from rapidfireai.ml.callbacks import GenerationMetricsCallback, LogLevelCallback, MLflowLoggingCallback
 from rapidfireai.ml.checkpoint_utils import (
     ensure_gradient_compatibility,
     load_checkpoint_from_disk,
@@ -34,7 +29,7 @@ def create_trainer_instance(
     trainer_config: TrainerConfig,
     shm_manager: SharedMemoryManager,
     use_shared_memory: bool = False,
-    mlflow_manager=None,
+    metric_logger=None,
     chunk_id: int = 0,
 ) -> tuple[SFTTrainer | DPOTrainer | GRPOTrainer | None, str]:
     """
@@ -51,21 +46,15 @@ def create_trainer_instance(
     compute_metrics = additional_trainer_kwargs.get("compute_metrics", None)

     # Configure training arguments
-    training_args, global_step_args = _configure_training_args(
-        training_args, trainer_config
-    )
+    training_args, global_step_args = _configure_training_args(training_args, trainer_config)
     trainer_config_obj = _create_trainer_config_object(trainer_type, training_args)
     # check if peft params is empty dict
     is_peft = bool(config_leaf.get("peft_params"))
     # Load model and tokenizer
     if use_shared_memory:
-        model_instance, tokenizer = load_checkpoint_from_shared_memory(
-            trainer_config, shm_manager, is_peft=is_peft
-        )
+        model_instance, tokenizer = load_checkpoint_from_shared_memory(trainer_config, shm_manager, is_peft=is_peft)
     else:
-        model_instance, tokenizer = load_checkpoint_from_disk(
-            trainer_config, is_peft=is_peft
-        )
+        model_instance, tokenizer = load_checkpoint_from_disk(trainer_config, is_peft=is_peft)
     # add model name to model config
     config_leaf["model_name"] = model_instance.config._name_or_path

@@ -84,30 +73,26 @@ def create_trainer_instance(

     model_instance = model_instance.to(device)

-    trainer_kwargs, formatting_func, additional_trainer_kwargs = (
-        _prepare_trainer_kwargs(
-            model_instance,
-            trainer_config_obj,
-            tokenizer,
-            trainer_config,
-            additional_trainer_kwargs,
-            ref_model_instance,
-            config_leaf,
-        )
+    trainer_kwargs, formatting_func, additional_trainer_kwargs = _prepare_trainer_kwargs(
+        model_instance,
+        trainer_config_obj,
+        tokenizer,
+        trainer_config,
+        additional_trainer_kwargs,
+        ref_model_instance,
+        config_leaf,
     )

-    callbacks, additional_trainer_kwargs = (
-        _setup_callbacks(  # FIXME: avoid returning additional_trainer_kwargs
-            mlflow_manager,
-            trainer_config,
-            chunk_id,
-            compute_metrics,
-            additional_trainer_kwargs,
-            tokenizer,
-            training_args,
-            formatting_func,
-            global_step_args,
-        )
+    callbacks, additional_trainer_kwargs = _setup_callbacks(  # FIXME: avoid returning additional_trainer_kwargs
+        metric_logger,
+        trainer_config,
+        chunk_id,
+        compute_metrics,
+        additional_trainer_kwargs,
+        tokenizer,
+        training_args,
+        formatting_func,
+        global_step_args,
     )

     if callbacks:
@@ -116,29 +101,22 @@ def create_trainer_instance(
     trainer_kwargs.update(additional_trainer_kwargs)
     trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if v is not None}

-    trainer = _create_trainer_by_type(
-        trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager
-    )
+    trainer = _create_trainer_by_type(trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager)
     return trainer, config_leaf["model_name"]


-def _configure_training_args(
-    training_args: dict, trainer_config: TrainerConfig
-) -> dict:
+def _configure_training_args(training_args: dict, trainer_config: TrainerConfig) -> dict:
     """Configure training arguments with default values."""
     completed_steps = trainer_config.completed_steps
     per_device_train_batch_size = training_args.get("per_device_train_batch_size", 1)
     gradient_accumulation_steps = training_args.get("gradient_accumulation_steps", 1)
-    len_dataloader = math.ceil(
-        trainer_config.train_dataset.num_rows / per_device_train_batch_size
-    )
+    len_dataloader = math.ceil(trainer_config.train_dataset.num_rows / per_device_train_batch_size)
     steps_per_epoch = max(
-        len_dataloader // gradient_accumulation_steps
-        + int(len_dataloader % gradient_accumulation_steps > 0),
+        len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
         1,
     )

-    if trainer_config.config_leaf.get("trainer_type","SFT") == "GRPO":
+    if trainer_config.config_leaf.get("trainer_type", "SFT") == "GRPO":
         num_generations = training_args.get("num_generations", 8)
         steps_per_epoch = (num_generations * trainer_config.train_dataset.num_rows) // (
             gradient_accumulation_steps * per_device_train_batch_size
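To make the step arithmetic concrete, a worked example with made-up numbers (none of these values come from the package):

import math

num_rows = 1000                      # illustrative dataset size
per_device_train_batch_size = 4
gradient_accumulation_steps = 8

len_dataloader = math.ceil(num_rows / per_device_train_batch_size)   # 250 batches
steps_per_epoch = max(
    len_dataloader // gradient_accumulation_steps
    + int(len_dataloader % gradient_accumulation_steps > 0),          # 31 + 1
    1,
)                                                                     # 32 optimizer steps

# GRPO path: every row is expanded into num_generations completions
num_generations = 8
grpo_steps_per_epoch = (num_generations * num_rows) // (
    gradient_accumulation_steps * per_device_train_batch_size
)                                                                     # 8000 // 32 = 250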
@@ -215,10 +193,7 @@ def _setup_reference_model(
     if model_adapter_name is not None and ref_adapter_name is not None:
         if use_shared_memory:
             peft_config = LoraConfig(**config_leaf["peft_params"])
-            if (
-                trainer_config.completed_steps == 0
-                and trainer_config.warm_started_from is None
-            ):
+            if trainer_config.completed_steps == 0 and trainer_config.warm_started_from is None:
                 reference_state_dict = get_peft_model_state_dict(model_instance)
                 reference_state_dict = move_tensors_to_cpu(reference_state_dict)
                 shm_manager.save_model_object(
@@ -230,14 +205,10 @@ def _setup_reference_model(
                 reference_state_dict = shm_manager.load_model_object(
                     trainer_config.run_id, SHMObjectType.REF_STATE_DICT
                 )
-                reference_state_dict = move_tensors_to_device(
-                    reference_state_dict, device
-                )
+                reference_state_dict = move_tensors_to_device(reference_state_dict, device)
             model_instance.add_adapter(ref_adapter_name, peft_config)
             model_instance.set_adapter(ref_adapter_name)
-            set_peft_model_state_dict(
-                model_instance, reference_state_dict, adapter_name=ref_adapter_name
-            )
+            set_peft_model_state_dict(model_instance, reference_state_dict, adapter_name=ref_adapter_name)
             model_instance.set_adapter(model_adapter_name)
         else:
             base_run_path = DataPath.base_run_path(trainer_config.run_id)
@@ -289,9 +260,7 @@ def _prepare_trainer_kwargs(

     if additional_trainer_kwargs.get("formatting_func") is not None:
         formatting_func = additional_trainer_kwargs.get("formatting_func")
-        train_dataset = train_dataset.map(
-            formatting_func
-        )  # FIXME: add try exception with batched/unbatched
+        train_dataset = train_dataset.map(formatting_func)  # FIXME: add try exception with batched/unbatched
         if eval_dataset is not None:
             eval_dataset = eval_dataset.map(formatting_func)
     additional_trainer_kwargs_copy = additional_trainer_kwargs.copy()
@@ -314,7 +283,7 @@


 def _setup_callbacks(
-    mlflow_manager,
+    metric_logger,
     trainer_config,
     chunk_id,
     compute_metrics,
@@ -327,9 +296,9 @@ def _setup_callbacks(
     """Setup callbacks for the trainer."""
     callbacks = []

-    if mlflow_manager is not None and trainer_config.mlflow_run_id is not None:
+    if metric_logger is not None and trainer_config.mlflow_run_id is not None:
         mlflow_callback = MLflowLoggingCallback(
-            mlflow_manager=mlflow_manager,
+            metric_logger=metric_logger,
             mlflow_run_id=trainer_config.mlflow_run_id,
             completed_steps=trainer_config.completed_steps,
             chunk_id=chunk_id,
@@ -337,10 +306,7 @@ def _setup_callbacks(
         )
         callbacks.append(mlflow_callback)

-    if (
-        compute_metrics is not None
-        and additional_trainer_kwargs.get("generation_config") is not None
-    ):
+    if compute_metrics is not None and additional_trainer_kwargs.get("generation_config") is not None:
         compute_metrics_function = compute_metrics
         if formatting_func is not None:
             formatted_eval_dataset = trainer_config.eval_dataset.map(formatting_func)
@@ -353,7 +319,7 @@ def _setup_callbacks(
             generation_config=additional_trainer_kwargs.get("generation_config"),
             compute_metrics=compute_metrics_function,
             batch_size=training_args.get("per_device_eval_batch_size"),
-            mlflow_manager=mlflow_manager,
+            metric_logger=metric_logger,
             mlflow_run_id=trainer_config.mlflow_run_id,
             completed_steps=trainer_config.completed_steps,
         )
@@ -365,15 +331,11 @@ def _setup_callbacks(
     return callbacks, additional_trainer_kwargs


-def _create_trainer_by_type(
-    trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager
-):
+def _create_trainer_by_type(trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager):
     """Create trainer instance based on type with proper state restoration."""
     if trainer_type == "SFT":
         dummy_trainer = SFTTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
         trainer = SFTTrainer(
             **trainer_kwargs,
             optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -382,9 +344,7 @@ def _create_trainer_by_type(

     elif trainer_type == "DPO":
         dummy_trainer = DPOTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
         trainer = DPOTrainer(
             **trainer_kwargs,
             optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -393,9 +353,7 @@ def _create_trainer_by_type(

     elif trainer_type == "GRPO":
         dummy_trainer = GRPOTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
         trainer = GRPOTrainer(
             **trainer_kwargs,
             optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -406,9 +364,7 @@ def _create_trainer_by_type(

     if trainer_config.completed_steps > 0:
         if use_shared_memory:
-            trainer = restore_trainer_from_shared_memory(
-                trainer, trainer_config, shm_manager
-            )
+            trainer = restore_trainer_from_shared_memory(trainer, trainer_config, shm_manager)
         else:
             trainer = restore_trainer_from_disk(trainer, trainer_config)
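Throughout _create_trainer_by_type, each trainer is built twice: a throwaway instance materializes the optimizer and LR scheduler, and the real instance receives them through the optimizers tuple so that restore_trainer_from_shared_memory or restore_trainer_from_disk can load saved state into them before training resumes. A minimal sketch of that pattern for the SFT case; trainer_kwargs stands in for the fully assembled kwargs and is an assumption of ours:

from trl import SFTTrainer

def build_sft_trainer(trainer_kwargs: dict, total_steps: int) -> SFTTrainer:
    # Throwaway instance, used only to create the optimizer and LR scheduler
    dummy_trainer = SFTTrainer(**trainer_kwargs)
    dummy_trainer.create_optimizer_and_scheduler(num_training_steps=total_steps)
    # The real trainer reuses those objects via the optimizers tuple, so a
    # later checkpoint restore can overwrite their state before training
    return SFTTrainer(
        **trainer_kwargs,
        optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
    )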