rapidfireai 0.10.2rc5__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidfireai might be problematic. Click here for more details.
- rapidfireai/automl/grid_search.py +4 -5
- rapidfireai/automl/model_config.py +41 -37
- rapidfireai/automl/random_search.py +21 -33
- rapidfireai/backend/controller.py +80 -161
- rapidfireai/backend/worker.py +26 -8
- rapidfireai/cli.py +171 -132
- rapidfireai/db/rf_db.py +1 -1
- rapidfireai/db/tables.sql +1 -1
- rapidfireai/dispatcher/dispatcher.py +3 -1
- rapidfireai/dispatcher/gunicorn.conf.py +1 -1
- rapidfireai/experiment.py +86 -7
- rapidfireai/frontend/build/asset-manifest.json +3 -3
- rapidfireai/frontend/build/index.html +1 -1
- rapidfireai/frontend/build/static/js/{main.1bf27639.js → main.58393d31.js} +3 -3
- rapidfireai/frontend/build/static/js/{main.1bf27639.js.map → main.58393d31.js.map} +1 -1
- rapidfireai/frontend/proxy_middleware.py +1 -1
- rapidfireai/ml/callbacks.py +85 -59
- rapidfireai/ml/trainer.py +42 -86
- rapidfireai/start.sh +117 -34
- rapidfireai/utils/constants.py +22 -1
- rapidfireai/utils/experiment_utils.py +87 -43
- rapidfireai/utils/interactive_controller.py +473 -0
- rapidfireai/utils/logging.py +1 -2
- rapidfireai/utils/metric_logger.py +346 -0
- rapidfireai/utils/mlflow_manager.py +0 -1
- rapidfireai/utils/ping.py +4 -2
- rapidfireai/utils/worker_manager.py +16 -6
- rapidfireai/version.py +2 -2
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/METADATA +7 -4
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/RECORD +36 -33
- tutorial_notebooks/rf-colab-tensorboard-tutorial.ipynb +314 -0
- /rapidfireai/frontend/build/static/js/{main.1bf27639.js.LICENSE.txt → main.58393d31.js.LICENSE.txt} +0 -0
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/WHEEL +0 -0
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/entry_points.txt +0 -0
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
- {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/top_level.txt +0 -0
|
@@ -25,7 +25,7 @@ class UserProxyManager:
|
|
|
25
25
|
self.default_proxy = {
|
|
26
26
|
'main_proxy_target': 'http://127.0.0.1:5002/',
|
|
27
27
|
'static_proxy_target': 'http://127.0.0.1:5002/',
|
|
28
|
-
'dispatcher_proxy_target': 'http://127.0.0.1:
|
|
28
|
+
'dispatcher_proxy_target': 'http://127.0.0.1:8081/',
|
|
29
29
|
}
|
|
30
30
|
|
|
31
31
|
def get_user_proxy(self, user_id: str) -> Dict[str, str]:
|
rapidfireai/ml/callbacks.py
CHANGED
|
@@ -1,14 +1,9 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections.abc import Callable
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
from datasets import Dataset
|
|
5
5
|
from tqdm import tqdm
|
|
6
|
-
from transformers import
|
|
7
|
-
TrainerCallback,
|
|
8
|
-
TrainerControl,
|
|
9
|
-
TrainerState,
|
|
10
|
-
TrainingArguments,
|
|
11
|
-
)
|
|
6
|
+
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
|
|
12
7
|
from transformers.trainer_utils import IntervalStrategy, SaveStrategy
|
|
13
8
|
|
|
14
9
|
|
|
@@ -17,10 +12,10 @@ class GenerationMetricsCallback(TrainerCallback):
|
|
|
17
12
|
self,
|
|
18
13
|
tokenizer,
|
|
19
14
|
eval_dataset: Dataset,
|
|
20
|
-
generation_config:
|
|
15
|
+
generation_config: dict | None = None,
|
|
21
16
|
compute_metrics: Callable = None,
|
|
22
17
|
batch_size: int = 8,
|
|
23
|
-
|
|
18
|
+
metric_logger=None,
|
|
24
19
|
mlflow_run_id: str = None,
|
|
25
20
|
completed_steps: int = 0,
|
|
26
21
|
):
|
|
@@ -36,7 +31,7 @@ class GenerationMetricsCallback(TrainerCallback):
|
|
|
36
31
|
"pad_token_id": tokenizer.pad_token_id,
|
|
37
32
|
"eos_token_id": tokenizer.eos_token_id,
|
|
38
33
|
}
|
|
39
|
-
self.
|
|
34
|
+
self.metric_logger = metric_logger
|
|
40
35
|
self.mlflow_run_id = mlflow_run_id
|
|
41
36
|
self.completed_steps = completed_steps
|
|
42
37
|
|
|
@@ -63,8 +58,8 @@ class GenerationMetricsCallback(TrainerCallback):
|
|
|
63
58
|
state.log_history.append(metrics)
|
|
64
59
|
|
|
65
60
|
for key, value in metrics.items():
|
|
66
|
-
if self.
|
|
67
|
-
self.
|
|
61
|
+
if self.metric_logger:
|
|
62
|
+
self.metric_logger.log_metric(
|
|
68
63
|
self.mlflow_run_id,
|
|
69
64
|
key,
|
|
70
65
|
value,
|
|
@@ -72,43 +67,69 @@ class GenerationMetricsCallback(TrainerCallback):
|
|
|
72
67
|
)
|
|
73
68
|
|
|
74
69
|
def _prepare_data(self, eval_dataset: Dataset) -> tuple:
|
|
75
|
-
"""Prepare batch data for generation"""
|
|
70
|
+
"""Prepare batch data for generation with defensive validation"""
|
|
76
71
|
input_texts = []
|
|
77
72
|
references = []
|
|
78
73
|
|
|
79
74
|
for item in eval_dataset:
|
|
80
|
-
if isinstance(item, dict):
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
75
|
+
if not isinstance(item, dict):
|
|
76
|
+
continue
|
|
77
|
+
|
|
78
|
+
input_text = None
|
|
79
|
+
reference = None
|
|
80
|
+
|
|
81
|
+
# Support multiple field name patterns
|
|
82
|
+
if "input" in item and "output" in item:
|
|
83
|
+
input_text = item["input"]
|
|
84
|
+
reference = item["output"]
|
|
85
|
+
elif "prompt" in item and "completion" in item:
|
|
86
|
+
input_text = item["prompt"]
|
|
87
|
+
reference = item["completion"][-1]["content"]
|
|
88
|
+
input_text = self.tokenizer.apply_chat_template(input_text, tokenize=False)
|
|
89
|
+
elif "text" in item:
|
|
90
|
+
# SFT format - use text as input, response as reference
|
|
91
|
+
input_text = item["text"]
|
|
92
|
+
reference = item.get("response", item.get("instruction", item["text"]))
|
|
93
|
+
elif "instruction" in item and "response" in item:
|
|
94
|
+
# Direct instruction/response format
|
|
95
|
+
input_text = item["instruction"]
|
|
96
|
+
reference = item["response"]
|
|
97
|
+
|
|
98
|
+
# Validate non-empty strings
|
|
99
|
+
if input_text and isinstance(input_text, str) and input_text.strip():
|
|
100
|
+
if reference and isinstance(reference, str) and reference.strip():
|
|
101
|
+
input_texts.append(input_text.strip())
|
|
102
|
+
references.append(reference.strip())
|
|
103
|
+
|
|
104
|
+
# Return safe empty values to prevent downstream errors
|
|
105
|
+
if not input_texts:
|
|
106
|
+
return [], []
|
|
95
107
|
|
|
96
108
|
return input_texts, references
|
|
97
109
|
|
|
98
|
-
def _generate_batch(self, model, input_texts:
|
|
99
|
-
"""Generate text for a batch of inputs"""
|
|
100
|
-
#
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
return_tensors="pt",
|
|
104
|
-
padding=True,
|
|
105
|
-
truncation=True,
|
|
106
|
-
max_length=512, # Adjust based on your model's context length
|
|
107
|
-
).to(model.device)
|
|
110
|
+
def _generate_batch(self, model, input_texts: list[str]) -> torch.Tensor:
|
|
111
|
+
"""Generate text for a batch of inputs with defensive validation"""
|
|
112
|
+
# Defensive validation for empty inputs
|
|
113
|
+
if not input_texts:
|
|
114
|
+
return torch.empty((0, 0), dtype=torch.long).to(model.device)
|
|
108
115
|
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
116
|
+
try:
|
|
117
|
+
# Tokenize batch
|
|
118
|
+
inputs = self.tokenizer(
|
|
119
|
+
input_texts,
|
|
120
|
+
return_tensors="pt",
|
|
121
|
+
padding=True,
|
|
122
|
+
truncation=True,
|
|
123
|
+
max_length=512, # Adjust based on your model's context length
|
|
124
|
+
).to(model.device)
|
|
125
|
+
|
|
126
|
+
return inputs["input_ids"]
|
|
127
|
+
except Exception as e:
|
|
128
|
+
# Log error and return empty tensor to prevent crash
|
|
129
|
+
print(f"Warning: Tokenization error in generation callback: {e}")
|
|
130
|
+
return torch.empty((0, 0), dtype=torch.long).to(model.device)
|
|
131
|
+
|
|
132
|
+
def _compute_generation_metrics(self, model, step: int) -> dict[str, float]:
|
|
112
133
|
"""Generate text and compute BLEU/ROUGE metrics with batch processing"""
|
|
113
134
|
model.eval()
|
|
114
135
|
|
|
@@ -121,16 +142,24 @@ class GenerationMetricsCallback(TrainerCallback):
|
|
|
121
142
|
|
|
122
143
|
# Process in batches
|
|
123
144
|
input_texts, batch_references = self._prepare_data(self.eval_dataset)
|
|
145
|
+
|
|
146
|
+
# Early return if no valid data
|
|
147
|
+
if not input_texts:
|
|
148
|
+
print("Warning: No valid eval data for generation metrics")
|
|
149
|
+
return {}
|
|
150
|
+
|
|
124
151
|
input_ids = self._generate_batch(model, input_texts)
|
|
152
|
+
|
|
153
|
+
# Check for empty generation batch
|
|
154
|
+
if input_ids.numel() == 0:
|
|
155
|
+
print("Warning: Empty input_ids from tokenization")
|
|
156
|
+
return {}
|
|
157
|
+
|
|
125
158
|
with torch.no_grad():
|
|
126
|
-
for i in tqdm(
|
|
127
|
-
range(0, len(indices), self.batch_size), desc="Generating for metrics"
|
|
128
|
-
):
|
|
159
|
+
for i in tqdm(range(0, len(indices), self.batch_size), desc="Generating for metrics"):
|
|
129
160
|
input_ids_batch = input_ids[i : i + self.batch_size]
|
|
130
161
|
with torch.inference_mode(), torch.cuda.amp.autocast():
|
|
131
|
-
outputs_batch = model.generate(
|
|
132
|
-
input_ids_batch, **self.generation_config
|
|
133
|
-
)
|
|
162
|
+
outputs_batch = model.generate(input_ids_batch, **self.generation_config)
|
|
134
163
|
generated_texts = self.tokenizer.batch_decode(
|
|
135
164
|
outputs_batch[:, input_ids_batch.shape[1] :],
|
|
136
165
|
skip_special_tokens=True,
|
|
@@ -155,18 +184,18 @@ class GenerationMetricsCallback(TrainerCallback):
|
|
|
155
184
|
|
|
156
185
|
|
|
157
186
|
class MLflowLoggingCallback(TrainerCallback):
|
|
158
|
-
"""Callback for logging metrics to
|
|
187
|
+
"""Callback for logging metrics to tracking backend during training"""
|
|
159
188
|
|
|
160
189
|
def __init__(
|
|
161
190
|
self,
|
|
162
|
-
|
|
191
|
+
metric_logger,
|
|
163
192
|
mlflow_run_id: str,
|
|
164
193
|
excluded_keys: list = None,
|
|
165
194
|
completed_steps: int = 0,
|
|
166
195
|
chunk_id: int = 0,
|
|
167
196
|
num_epochs_completed: int = 0,
|
|
168
197
|
):
|
|
169
|
-
self.
|
|
198
|
+
self.metric_logger = metric_logger
|
|
170
199
|
self.mlflow_run_id = mlflow_run_id
|
|
171
200
|
self.completed_steps = completed_steps
|
|
172
201
|
self.excluded_keys = excluded_keys or [
|
|
@@ -189,22 +218,22 @@ class MLflowLoggingCallback(TrainerCallback):
|
|
|
189
218
|
for key, value in logs.items():
|
|
190
219
|
if isinstance(value, (int, float)) and key not in self.excluded_keys:
|
|
191
220
|
try:
|
|
192
|
-
self.
|
|
221
|
+
self.metric_logger.log_metric(
|
|
193
222
|
self.mlflow_run_id,
|
|
194
223
|
key,
|
|
195
224
|
value,
|
|
196
225
|
step=self.completed_steps + state.global_step,
|
|
197
226
|
)
|
|
198
227
|
except Exception as e:
|
|
199
|
-
print(f"Warning: Failed to log metric {key} to
|
|
228
|
+
print(f"Warning: Failed to log metric {key} to tracking backend: {e}")
|
|
200
229
|
if "eval_loss" not in logs and "train_runtime" not in logs:
|
|
201
|
-
self.
|
|
230
|
+
self.metric_logger.log_metric(
|
|
202
231
|
self.mlflow_run_id,
|
|
203
232
|
"chunk number",
|
|
204
233
|
self.chunk_id,
|
|
205
234
|
step=self.completed_steps + state.global_step,
|
|
206
235
|
)
|
|
207
|
-
self.
|
|
236
|
+
self.metric_logger.log_metric(
|
|
208
237
|
self.mlflow_run_id,
|
|
209
238
|
"num_epochs_completed",
|
|
210
239
|
self.num_epochs_completed,
|
|
@@ -217,7 +246,7 @@ class LogLevelCallback(TrainerCallback):
|
|
|
217
246
|
A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
|
|
218
247
|
"""
|
|
219
248
|
|
|
220
|
-
def __init__(self, global_step_args:
|
|
249
|
+
def __init__(self, global_step_args: dict):
|
|
221
250
|
self.eval_first_step = global_step_args.get("eval_first_step", 0)
|
|
222
251
|
self.actual_steps = global_step_args.get("actual_steps", 0)
|
|
223
252
|
self.log_first_step = global_step_args.get("log_first_step", 0)
|
|
@@ -275,10 +304,7 @@ class LogLevelCallback(TrainerCallback):
|
|
|
275
304
|
control.should_log = True
|
|
276
305
|
|
|
277
306
|
# Evaluate
|
|
278
|
-
if
|
|
279
|
-
args.eval_strategy == IntervalStrategy.EPOCH
|
|
280
|
-
and args.eval_delay <= state.epoch
|
|
281
|
-
):
|
|
307
|
+
if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch:
|
|
282
308
|
control.should_evaluate = True
|
|
283
309
|
|
|
284
310
|
# Save
|
rapidfireai/ml/trainer.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import logging
|
|
2
1
|
import math
|
|
3
2
|
import os
|
|
4
3
|
|
|
@@ -7,11 +6,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
|
|
|
7
6
|
from transformers.utils.logging import set_verbosity_error
|
|
8
7
|
from trl import DPOConfig, DPOTrainer, GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
|
|
9
8
|
|
|
10
|
-
from rapidfireai.ml.callbacks import
|
|
11
|
-
GenerationMetricsCallback,
|
|
12
|
-
MLflowLoggingCallback,
|
|
13
|
-
LogLevelCallback,
|
|
14
|
-
)
|
|
9
|
+
from rapidfireai.ml.callbacks import GenerationMetricsCallback, LogLevelCallback, MLflowLoggingCallback
|
|
15
10
|
from rapidfireai.ml.checkpoint_utils import (
|
|
16
11
|
ensure_gradient_compatibility,
|
|
17
12
|
load_checkpoint_from_disk,
|
|
@@ -34,7 +29,7 @@ def create_trainer_instance(
|
|
|
34
29
|
trainer_config: TrainerConfig,
|
|
35
30
|
shm_manager: SharedMemoryManager,
|
|
36
31
|
use_shared_memory: bool = False,
|
|
37
|
-
|
|
32
|
+
metric_logger=None,
|
|
38
33
|
chunk_id: int = 0,
|
|
39
34
|
) -> tuple[SFTTrainer | DPOTrainer | GRPOTrainer | None, str]:
|
|
40
35
|
"""
|
|
@@ -51,21 +46,15 @@ def create_trainer_instance(
|
|
|
51
46
|
compute_metrics = additional_trainer_kwargs.get("compute_metrics", None)
|
|
52
47
|
|
|
53
48
|
# Configure training arguments
|
|
54
|
-
training_args, global_step_args = _configure_training_args(
|
|
55
|
-
training_args, trainer_config
|
|
56
|
-
)
|
|
49
|
+
training_args, global_step_args = _configure_training_args(training_args, trainer_config)
|
|
57
50
|
trainer_config_obj = _create_trainer_config_object(trainer_type, training_args)
|
|
58
51
|
# check if peft params is empty dict
|
|
59
52
|
is_peft = bool(config_leaf.get("peft_params"))
|
|
60
53
|
# Load model and tokenizer
|
|
61
54
|
if use_shared_memory:
|
|
62
|
-
model_instance, tokenizer = load_checkpoint_from_shared_memory(
|
|
63
|
-
trainer_config, shm_manager, is_peft=is_peft
|
|
64
|
-
)
|
|
55
|
+
model_instance, tokenizer = load_checkpoint_from_shared_memory(trainer_config, shm_manager, is_peft=is_peft)
|
|
65
56
|
else:
|
|
66
|
-
model_instance, tokenizer = load_checkpoint_from_disk(
|
|
67
|
-
trainer_config, is_peft=is_peft
|
|
68
|
-
)
|
|
57
|
+
model_instance, tokenizer = load_checkpoint_from_disk(trainer_config, is_peft=is_peft)
|
|
69
58
|
# add model name to model config
|
|
70
59
|
config_leaf["model_name"] = model_instance.config._name_or_path
|
|
71
60
|
|
|
@@ -84,30 +73,26 @@ def create_trainer_instance(
|
|
|
84
73
|
|
|
85
74
|
model_instance = model_instance.to(device)
|
|
86
75
|
|
|
87
|
-
trainer_kwargs, formatting_func, additional_trainer_kwargs = (
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
config_leaf,
|
|
96
|
-
)
|
|
76
|
+
trainer_kwargs, formatting_func, additional_trainer_kwargs = _prepare_trainer_kwargs(
|
|
77
|
+
model_instance,
|
|
78
|
+
trainer_config_obj,
|
|
79
|
+
tokenizer,
|
|
80
|
+
trainer_config,
|
|
81
|
+
additional_trainer_kwargs,
|
|
82
|
+
ref_model_instance,
|
|
83
|
+
config_leaf,
|
|
97
84
|
)
|
|
98
85
|
|
|
99
|
-
callbacks, additional_trainer_kwargs = (
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
global_step_args,
|
|
110
|
-
)
|
|
86
|
+
callbacks, additional_trainer_kwargs = _setup_callbacks( # FIXME: avoid returning additional_trainer_kwargs
|
|
87
|
+
metric_logger,
|
|
88
|
+
trainer_config,
|
|
89
|
+
chunk_id,
|
|
90
|
+
compute_metrics,
|
|
91
|
+
additional_trainer_kwargs,
|
|
92
|
+
tokenizer,
|
|
93
|
+
training_args,
|
|
94
|
+
formatting_func,
|
|
95
|
+
global_step_args,
|
|
111
96
|
)
|
|
112
97
|
|
|
113
98
|
if callbacks:
|
|
@@ -116,29 +101,22 @@ def create_trainer_instance(
|
|
|
116
101
|
trainer_kwargs.update(additional_trainer_kwargs)
|
|
117
102
|
trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if v is not None}
|
|
118
103
|
|
|
119
|
-
trainer = _create_trainer_by_type(
|
|
120
|
-
trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager
|
|
121
|
-
)
|
|
104
|
+
trainer = _create_trainer_by_type(trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager)
|
|
122
105
|
return trainer, config_leaf["model_name"]
|
|
123
106
|
|
|
124
107
|
|
|
125
|
-
def _configure_training_args(
|
|
126
|
-
training_args: dict, trainer_config: TrainerConfig
|
|
127
|
-
) -> dict:
|
|
108
|
+
def _configure_training_args(training_args: dict, trainer_config: TrainerConfig) -> dict:
|
|
128
109
|
"""Configure training arguments with default values."""
|
|
129
110
|
completed_steps = trainer_config.completed_steps
|
|
130
111
|
per_device_train_batch_size = training_args.get("per_device_train_batch_size", 1)
|
|
131
112
|
gradient_accumulation_steps = training_args.get("gradient_accumulation_steps", 1)
|
|
132
|
-
len_dataloader = math.ceil(
|
|
133
|
-
trainer_config.train_dataset.num_rows / per_device_train_batch_size
|
|
134
|
-
)
|
|
113
|
+
len_dataloader = math.ceil(trainer_config.train_dataset.num_rows / per_device_train_batch_size)
|
|
135
114
|
steps_per_epoch = max(
|
|
136
|
-
len_dataloader // gradient_accumulation_steps
|
|
137
|
-
+ int(len_dataloader % gradient_accumulation_steps > 0),
|
|
115
|
+
len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
|
|
138
116
|
1,
|
|
139
117
|
)
|
|
140
118
|
|
|
141
|
-
if trainer_config.config_leaf.get("trainer_type","SFT") == "GRPO":
|
|
119
|
+
if trainer_config.config_leaf.get("trainer_type", "SFT") == "GRPO":
|
|
142
120
|
num_generations = training_args.get("num_generations", 8)
|
|
143
121
|
steps_per_epoch = (num_generations * trainer_config.train_dataset.num_rows) // (
|
|
144
122
|
gradient_accumulation_steps * per_device_train_batch_size
|
|
@@ -215,10 +193,7 @@ def _setup_reference_model(
|
|
|
215
193
|
if model_adapter_name is not None and ref_adapter_name is not None:
|
|
216
194
|
if use_shared_memory:
|
|
217
195
|
peft_config = LoraConfig(**config_leaf["peft_params"])
|
|
218
|
-
if
|
|
219
|
-
trainer_config.completed_steps == 0
|
|
220
|
-
and trainer_config.warm_started_from is None
|
|
221
|
-
):
|
|
196
|
+
if trainer_config.completed_steps == 0 and trainer_config.warm_started_from is None:
|
|
222
197
|
reference_state_dict = get_peft_model_state_dict(model_instance)
|
|
223
198
|
reference_state_dict = move_tensors_to_cpu(reference_state_dict)
|
|
224
199
|
shm_manager.save_model_object(
|
|
@@ -230,14 +205,10 @@ def _setup_reference_model(
|
|
|
230
205
|
reference_state_dict = shm_manager.load_model_object(
|
|
231
206
|
trainer_config.run_id, SHMObjectType.REF_STATE_DICT
|
|
232
207
|
)
|
|
233
|
-
reference_state_dict = move_tensors_to_device(
|
|
234
|
-
reference_state_dict, device
|
|
235
|
-
)
|
|
208
|
+
reference_state_dict = move_tensors_to_device(reference_state_dict, device)
|
|
236
209
|
model_instance.add_adapter(ref_adapter_name, peft_config)
|
|
237
210
|
model_instance.set_adapter(ref_adapter_name)
|
|
238
|
-
set_peft_model_state_dict(
|
|
239
|
-
model_instance, reference_state_dict, adapter_name=ref_adapter_name
|
|
240
|
-
)
|
|
211
|
+
set_peft_model_state_dict(model_instance, reference_state_dict, adapter_name=ref_adapter_name)
|
|
241
212
|
model_instance.set_adapter(model_adapter_name)
|
|
242
213
|
else:
|
|
243
214
|
base_run_path = DataPath.base_run_path(trainer_config.run_id)
|
|
@@ -289,9 +260,7 @@ def _prepare_trainer_kwargs(
|
|
|
289
260
|
|
|
290
261
|
if additional_trainer_kwargs.get("formatting_func") is not None:
|
|
291
262
|
formatting_func = additional_trainer_kwargs.get("formatting_func")
|
|
292
|
-
train_dataset = train_dataset.map(
|
|
293
|
-
formatting_func
|
|
294
|
-
) # FIXME: add try exception with batched/unbatched
|
|
263
|
+
train_dataset = train_dataset.map(formatting_func) # FIXME: add try exception with batched/unbatched
|
|
295
264
|
if eval_dataset is not None:
|
|
296
265
|
eval_dataset = eval_dataset.map(formatting_func)
|
|
297
266
|
additional_trainer_kwargs_copy = additional_trainer_kwargs.copy()
|
|
@@ -314,7 +283,7 @@ def _prepare_trainer_kwargs(
|
|
|
314
283
|
|
|
315
284
|
|
|
316
285
|
def _setup_callbacks(
|
|
317
|
-
|
|
286
|
+
metric_logger,
|
|
318
287
|
trainer_config,
|
|
319
288
|
chunk_id,
|
|
320
289
|
compute_metrics,
|
|
@@ -327,9 +296,9 @@ def _setup_callbacks(
|
|
|
327
296
|
"""Setup callbacks for the trainer."""
|
|
328
297
|
callbacks = []
|
|
329
298
|
|
|
330
|
-
if
|
|
299
|
+
if metric_logger is not None and trainer_config.mlflow_run_id is not None:
|
|
331
300
|
mlflow_callback = MLflowLoggingCallback(
|
|
332
|
-
|
|
301
|
+
metric_logger=metric_logger,
|
|
333
302
|
mlflow_run_id=trainer_config.mlflow_run_id,
|
|
334
303
|
completed_steps=trainer_config.completed_steps,
|
|
335
304
|
chunk_id=chunk_id,
|
|
@@ -337,10 +306,7 @@ def _setup_callbacks(
|
|
|
337
306
|
)
|
|
338
307
|
callbacks.append(mlflow_callback)
|
|
339
308
|
|
|
340
|
-
if (
|
|
341
|
-
compute_metrics is not None
|
|
342
|
-
and additional_trainer_kwargs.get("generation_config") is not None
|
|
343
|
-
):
|
|
309
|
+
if compute_metrics is not None and additional_trainer_kwargs.get("generation_config") is not None:
|
|
344
310
|
compute_metrics_function = compute_metrics
|
|
345
311
|
if formatting_func is not None:
|
|
346
312
|
formatted_eval_dataset = trainer_config.eval_dataset.map(formatting_func)
|
|
@@ -353,7 +319,7 @@ def _setup_callbacks(
|
|
|
353
319
|
generation_config=additional_trainer_kwargs.get("generation_config"),
|
|
354
320
|
compute_metrics=compute_metrics_function,
|
|
355
321
|
batch_size=training_args.get("per_device_eval_batch_size"),
|
|
356
|
-
|
|
322
|
+
metric_logger=metric_logger,
|
|
357
323
|
mlflow_run_id=trainer_config.mlflow_run_id,
|
|
358
324
|
completed_steps=trainer_config.completed_steps,
|
|
359
325
|
)
|
|
@@ -365,15 +331,11 @@ def _setup_callbacks(
|
|
|
365
331
|
return callbacks, additional_trainer_kwargs
|
|
366
332
|
|
|
367
333
|
|
|
368
|
-
def _create_trainer_by_type(
|
|
369
|
-
trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager
|
|
370
|
-
):
|
|
334
|
+
def _create_trainer_by_type(trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager):
|
|
371
335
|
"""Create trainer instance based on type with proper state restoration."""
|
|
372
336
|
if trainer_type == "SFT":
|
|
373
337
|
dummy_trainer = SFTTrainer(**trainer_kwargs)
|
|
374
|
-
dummy_trainer.create_optimizer_and_scheduler(
|
|
375
|
-
num_training_steps=trainer_config.total_steps
|
|
376
|
-
)
|
|
338
|
+
dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
|
|
377
339
|
trainer = SFTTrainer(
|
|
378
340
|
**trainer_kwargs,
|
|
379
341
|
optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
|
|
@@ -382,9 +344,7 @@ def _create_trainer_by_type(
|
|
|
382
344
|
|
|
383
345
|
elif trainer_type == "DPO":
|
|
384
346
|
dummy_trainer = DPOTrainer(**trainer_kwargs)
|
|
385
|
-
dummy_trainer.create_optimizer_and_scheduler(
|
|
386
|
-
num_training_steps=trainer_config.total_steps
|
|
387
|
-
)
|
|
347
|
+
dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
|
|
388
348
|
trainer = DPOTrainer(
|
|
389
349
|
**trainer_kwargs,
|
|
390
350
|
optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
|
|
@@ -393,9 +353,7 @@ def _create_trainer_by_type(
|
|
|
393
353
|
|
|
394
354
|
elif trainer_type == "GRPO":
|
|
395
355
|
dummy_trainer = GRPOTrainer(**trainer_kwargs)
|
|
396
|
-
dummy_trainer.create_optimizer_and_scheduler(
|
|
397
|
-
num_training_steps=trainer_config.total_steps
|
|
398
|
-
)
|
|
356
|
+
dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
|
|
399
357
|
trainer = GRPOTrainer(
|
|
400
358
|
**trainer_kwargs,
|
|
401
359
|
optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
|
|
@@ -406,9 +364,7 @@ def _create_trainer_by_type(
|
|
|
406
364
|
|
|
407
365
|
if trainer_config.completed_steps > 0:
|
|
408
366
|
if use_shared_memory:
|
|
409
|
-
trainer = restore_trainer_from_shared_memory(
|
|
410
|
-
trainer, trainer_config, shm_manager
|
|
411
|
-
)
|
|
367
|
+
trainer = restore_trainer_from_shared_memory(trainer, trainer_config, shm_manager)
|
|
412
368
|
else:
|
|
413
369
|
trainer = restore_trainer_from_disk(trainer, trainer_config)
|
|
414
370
|
|