tuft-0.1.0-py3-none-any.whl → tuft-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__init__.py +5 -2
- tuft/__main__.py +7 -0
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +124 -0
- tuft/config.py +123 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +368 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +728 -0
- tuft-0.1.2.dist-info/METADATA +633 -0
- tuft-0.1.2.dist-info/RECORD +36 -0
- {tuft-0.1.0.dist-info → tuft-0.1.2.dist-info}/WHEEL +1 -2
- tuft-0.1.2.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.2.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
tuft/backends/hf_training_model.py
@@ -0,0 +1,404 @@
+import asyncio
+import shutil
+from typing import Dict
+
+import ray
+import torch
+from opentelemetry.trace import StatusCode
+from peft import LoraConfig, get_peft_model
+from ray.actor import ActorProxy
+from tinker import types
+from tinker.types import LoraConfig as TinkerLoraConfig
+from torch.nn.utils.rnn import pad_sequence
+from transformers import AutoModelForCausalLM
+
+from tuft.checkpoints import CheckpointRecord
+from tuft.config import ModelConfig
+from tuft.loss_fn import get_loss_fn
+from tuft.telemetry.tracing import extract_context, get_tracer
+
+
+_get_tracer = lambda: get_tracer("tuft.hf_training_model")  # noqa: E731
+
+
+MODULE_MAP = {
+    "llama": {
+        "attn": ["q_proj", "k_proj", "v_proj", "o_proj"],
+        "mlp": ["gate_proj", "up_proj", "down_proj"],
+        "unembed": ["lm_head"],
+    },
+    "qwen": {
+        "attn": ["q_proj", "k_proj", "v_proj", "o_proj"],
+        "mlp": ["gate_proj", "up_proj", "down_proj"],
+        "unembed": [],  # set unembed will cause warning in Qwen models
+    },
+}
+
+
+def get_target_modules(model_path: str, lora_config: TinkerLoraConfig) -> list[str]:
+    if "qwen" in model_path.lower():
+        mode_series = "qwen"
+    elif "llama" in model_path.lower():
+        mode_series = "llama"
+    else:
+        raise ValueError(f"Unsupported model series: {model_path}")
+    target_modules = []
+    if lora_config.train_attn:
+        target_modules.extend(MODULE_MAP[mode_series]["attn"])
+    if lora_config.train_mlp:
+        target_modules.extend(MODULE_MAP[mode_series]["mlp"])
+    if lora_config.train_unembed:
+        target_modules.extend(MODULE_MAP[mode_series]["unembed"])
+    return target_modules
+
+
+class HFTrainingModel:
+    def __init__(self, config: ModelConfig) -> None:
+        self.config = config
+        self.model = self._init_peft_model(config)
+        self.adapter_optimizer: Dict[str, torch.optim.AdamW] = {}
+        self._lock = asyncio.Lock()
+
+    async def async_init(self) -> None:
+        """Do nothing for now. Just used to make sure the actor is ready."""
+        pass
+
+    # --------------------------------
+    # LoRA adapter management methods
+    # --------------------------------
+    async def create_adapter(
+        self,
+        lora_id: str,
+        lora_config: TinkerLoraConfig,
+        trace_context: dict[str, str] | None = None,
+    ):
+        ctx = extract_context(trace_context or {})
+        with _get_tracer().start_as_current_span("hf_model.create_adapter", context=ctx) as span:
+            span.set_attribute("tuft.lora_id", lora_id)
+            try:
+                if lora_id in self.adapter_optimizer:
+                    raise ValueError(f"Adapter {lora_id} already exists.")
+                peft_config = LoraConfig(
+                    r=lora_config.rank,
+                    target_modules=get_target_modules(str(self.config.model_path), lora_config),
+                    # TODO: here we set lora_alpha equal to rank for common practice,
+                    # but we may expose it in the future if needed.
+                    lora_alpha=lora_config.rank,
+                )
+
+                self.model.add_adapter(adapter_name=lora_id, peft_config=peft_config)
+                async with self._lock:
+                    self.model.set_adapter(lora_id)
+                    params = [p for p in self.model.parameters() if p.requires_grad]
+                    self.adapter_optimizer[lora_id] = torch.optim.AdamW(params)
+            except Exception as e:
+                span.record_exception(e)
+                span.set_status(StatusCode.ERROR)
+                raise
+
+    async def save_state(
+        self,
+        lora_id: str,
+        checkpoint_record: CheckpointRecord,
+        optimizer: bool,
+        trace_context: dict[str, str] | None = None,
+    ):
+        """
+        Save LoRA adapter and optimizer state.
+        Args:
+            lora_id: The LoRA adapter ID to save.
+            checkpoint_record: The CheckpointRecord containing paths to save to.
+            optimizer: Whether to save the optimizer state.
+            trace_context: Optional trace context for distributed tracing.
+        """
+        ctx = extract_context(trace_context or {})
+        with _get_tracer().start_as_current_span("hf_model.save_state", context=ctx) as span:
+            span.set_attribute("tuft.lora_id", lora_id)
+            span.set_attribute("tuft.optimizer", optimizer)
+            try:
+                if lora_id not in self.adapter_optimizer:
+                    raise ValueError(f"Adapter {lora_id} not found.")
+
+                # 1. Save adapter (LoRA weights)
+                adapter_dir = checkpoint_record.adapter_path
+                adapter_dir.mkdir(parents=True, exist_ok=True)
+                # peft automatically creates a subdirectory with adapter name inside the given path
+                self.model.save_pretrained(str(adapter_dir), selected_adapters=[lora_id])
+                # move the files out of the subdirectory
+                lora_subdir = adapter_dir / lora_id
+                if lora_subdir.exists() and lora_subdir.is_dir():
+                    for item in lora_subdir.iterdir():
+                        dest = adapter_dir / item.name
+                        if dest.exists():
+                            if dest.is_file():
+                                dest.unlink()
+                            elif dest.is_dir():
+                                shutil.rmtree(dest)
+                        shutil.move(str(item), str(dest))
+                    lora_subdir.rmdir()
+
+                # 2. Save optimizer state
+                if optimizer:
+                    opt_dir = checkpoint_record.optimizer_path
+                    opt_dir.mkdir(parents=True, exist_ok=True)
+                    opt_state = self.adapter_optimizer[lora_id].state_dict()
+                    opt_path = opt_dir / (f"{lora_id}.pt")
+                    torch.save(opt_state, opt_path)
+            except Exception as e:
+                span.record_exception(e)
+                span.set_status(StatusCode.ERROR)
+                raise
+
+    async def load_state(
+        self,
+        lora_id: str,
+        checkpoint_record: CheckpointRecord,
+        optimizer: bool,
+        trace_context: dict[str, str] | None = None,
+    ):
+        """
+        Load LoRA adapter and optimizer state (standard format).
+        Args:
+            lora_id: The LoRA adapter ID to load.
+            checkpoint_record: The CheckpointRecord containing paths to load from.
+            optimizer: Whether to load the optimizer state.
+            trace_context: Optional trace context for distributed tracing.
+        """
+        ctx = extract_context(trace_context or {})
+        with _get_tracer().start_as_current_span("hf_model.load_state", context=ctx) as span:
+            span.set_attribute("tuft.lora_id", lora_id)
+            span.set_attribute("tuft.optimizer", optimizer)
+            # 1. Load adapter
+            # find lora adapter name from the directory
+            self.model.load_adapter(
+                model_id=str(checkpoint_record.adapter_path), adapter_name=lora_id
+            )
+
+            # 2. Load optimizer state if needed
+            async with self._lock:
+                self.model.set_adapter(lora_id)
+                params = [p for p in self.model.parameters() if p.requires_grad]
+                optimizer_obj = torch.optim.AdamW(params)
+                if optimizer:
+                    opt_dir = checkpoint_record.optimizer_path
+                    opt_path = opt_dir / f"{lora_id}.pt"
+                    state_dict = None
+                    if opt_path.exists():
+                        state_dict = torch.load(opt_path)
+                    if state_dict is not None:
+                        optimizer_obj.load_state_dict(state_dict)
+                self.adapter_optimizer[lora_id] = optimizer_obj
+
+    async def remove_adapter(self, lora_id: str):
+        async with self._lock:
+            if lora_id in self.adapter_optimizer:
+                self.model.delete_adapter(lora_id)
+                self.adapter_optimizer.pop(lora_id)
+
+    # --------------------------------
+    # Training methods
+    # --------------------------------
+    async def forward(
+        self,
+        data: list[types.Datum],
+        lora_id: str,
+        loss_fn: types.LossFnType,
+        loss_fn_config: dict[str, float] | None,
+        backward: bool = False,
+        trace_context: dict[str, str] | None = None,
+    ) -> types.ForwardBackwardOutput:
+        """Forward pass (and backward if specified).
+
+        Args:
+            data: List of Datum objects containing input data.
+            lora_id: The LoRA adapter ID to use.
+            loss_fn: The loss function to apply.
+            loss_fn_config: Optional configuration for the loss function.
+            backward: Whether to perform backward pass.
+            trace_context: Optional trace context for distributed tracing.
+
+        Returns:
+            ForwardBackwardOutput: The output of the forward (and backward) pass.
+        """
+        ctx = extract_context(trace_context or {})
+        span_name = "hf_model.forward_backward" if backward else "hf_model.forward"
+        with _get_tracer().start_as_current_span(span_name, context=ctx) as span:
+            span.set_attribute("tuft.lora_id", lora_id)
+            span.set_attribute("tuft.backward", backward)
+            span.set_attribute("tuft.data_count", len(data))
+            # Prepare input tensors
+            input_ids = [
+                torch.tensor(datum.model_input.to_ints(), dtype=torch.long) for datum in data
+            ]
+            input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
+            attention_mask = (input_ids_padded != 0).long()
+            position_ids = (
+                torch.arange(input_ids_padded.size(1), dtype=torch.long)
+                .unsqueeze(0)
+                .expand(input_ids_padded.size(0), -1)
+            )
+            # Move tensors to model device
+            device = next(self.model.parameters()).device
+            input_ids_padded = input_ids_padded.to(device)
+            attention_mask = attention_mask.to(device)
+            position_ids = position_ids.to(device)
+
+            # Activate the correct adapter
+            async with self._lock:
+                self._activate_adapter(lora_id)
+
+                # Forward pass
+                outputs = self.model(
+                    input_ids=input_ids_padded,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    return_dict=True,
+                )
+
+                # Compute loss
+                if loss_fn_config is None:
+                    loss_fn_config = {}
+                loss_fn_callable = get_loss_fn(loss_fn)
+                logits = outputs.logits
+                if "temperature" in loss_fn_config:
+                    temperature = loss_fn_config["temperature"]
+                    logits.div_(temperature)
+
+                loss_fn_inputs = self._prepare_loss_fn_inputs(data)
+
+                ## compute target_logprobs from logits and target_tokens
+                target_tokens = loss_fn_inputs["target_tokens"]
+                target_logprobs = self._compute_logprobs_from_target_tokens(logits, target_tokens)
+                loss_fn_inputs["target_logprobs"] = target_logprobs
+
+                loss, metric = loss_fn_callable(loss_fn_inputs, loss_fn_config)
+
+                # Backward pass if needed
+                if backward:
+                    loss.backward()
+
+            unpaded_logprobs = self._unpad_tensor(
+                target_logprobs, [len(datum.model_input.to_ints()) for datum in data]
+            )
+            return types.ForwardBackwardOutput(
+                loss_fn_output_type=loss_fn,
+                loss_fn_outputs=[
+                    {"logprobs": types.TensorData.from_torch(logprobs.detach())}
+                    for logprobs in unpaded_logprobs
+                ],
+                metrics=metric,
+            )
+
+    async def optim_step(
+        self,
+        adam_params: types.AdamParams,
+        lora_id: str,
+        trace_context: dict[str, str] | None = None,
+    ) -> types.OptimStepResponse:
+        """Perform an optimization step using Adam optimizer.
+
+        Args:
+            adam_params: Parameters for the Adam optimizer.
+            lora_id: The LoRA adapter ID to use.
+            trace_context: Optional trace context for distributed tracing.
+
+        Returns:
+            OptimStepResponse: The response containing optimization metrics.
+        """
+        ctx = extract_context(trace_context or {})
+        with _get_tracer().start_as_current_span("hf_model.optim_step", context=ctx) as span:
+            span.set_attribute("tuft.lora_id", lora_id)
+            optimizer = self.adapter_optimizer[lora_id]
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = adam_params.learning_rate
+                param_group["betas"] = (adam_params.beta1, adam_params.beta2)
+                param_group["eps"] = adam_params.eps
+                param_group["weight_decay"] = adam_params.weight_decay
+            optimizer.step()
+            optimizer.zero_grad()
+            return types.OptimStepResponse()
+
+    # --------------------------------
+    # Helper methods
+    # --------------------------------
+    def _prepare_loss_fn_inputs(self, data: list[types.Datum]) -> Dict[str, torch.Tensor]:
+        """Prepare input tensors from Datum list."""
+        device = next(self.model.parameters()).device
+
+        loss_fn_input_dict = {}
+        # prepare loss_fn_inputs tensors
+        loss_fn_input_keys = data[0].loss_fn_inputs.keys()
+        for key in loss_fn_input_keys:
+            tensors = [datum.loss_fn_inputs[key].to_torch() for datum in data]
+            # If tensor is 1D, pad to max length; if already same shape, stack directly
+            if all(t.dim() == 1 for t in tensors):
+                padded = pad_sequence(tensors, batch_first=True, padding_value=0)
+                loss_fn_input_dict[key] = padded.to(device)
+            else:
+                # Try to stack, if shape mismatch, pad last dim
+                try:
+                    stacked = torch.stack(tensors)
+                    loss_fn_input_dict[key] = stacked.to(device)
+                except Exception:
+                    # Pad last dim to max length
+                    max_shape = list(tensors[0].shape)
+                    for t in tensors:
+                        for i, s in enumerate(t.shape):
+                            if s > max_shape[i]:
+                                max_shape[i] = s
+                    padded_tensors = []
+                    for t in tensors:
+                        pad_width = [(0, m - s) for s, m in zip(t.shape, max_shape, strict=False)]
+                        pad_args = []
+                        for p in reversed(pad_width):
+                            pad_args.extend(p)
+                        padded = torch.nn.functional.pad(t, pad_args, value=0)
+                        padded_tensors.append(padded)
+                    stacked = torch.stack(padded_tensors)
+                    loss_fn_input_dict[key] = stacked.to(device)
+
+        return loss_fn_input_dict
+
+    def _compute_logprobs_from_target_tokens(
+        self, logits: torch.Tensor, target_tokens: torch.Tensor
+    ) -> torch.Tensor:
+        logits_labels = torch.gather(logits, dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)
+        # loop to reduce peak mem consumption
+        logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits])
+        logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+        return logprobs_labels
+
+    def _unpad_tensor(
+        self, padded_tensor: torch.Tensor, original_lengths: list[int]
+    ) -> list[torch.Tensor]:
+        """Unpad a padded tensor back to list of tensors with original lengths."""
+        tensors = []
+        for i, length in enumerate(original_lengths):
+            tensors.append(padded_tensor[i, :length])
+        return tensors
+
+    def _init_peft_model(self, config: ModelConfig):
+        model = AutoModelForCausalLM.from_pretrained(
+            str(config.model_path),
+            dtype="auto",
+            device_map="auto",
+        )
+        peft_config = LoraConfig()
+        peft_model = get_peft_model(model, peft_config=peft_config, adapter_name="default")
+        return peft_model
+
+    def _activate_adapter(self, lora_id: str):
+        if lora_id not in self.adapter_optimizer:
+            raise ValueError(f"Adapter {lora_id} not found.")
+        self.model.set_adapter(lora_id)
+
+    @classmethod
+    def get_actor(cls, config: ModelConfig) -> "ActorProxy":
+        return (
+            ray.remote(cls)
+            .options(
+                name="training_model_" + config.model_name,
+                num_gpus=1 if not config.colocate else 1 - config.sampling_memory_fraction,
+            )
+            .remote(config)
+        )
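A note on `_compute_logprobs_from_target_tokens` above: it gathers the logit of each target token and relies on the identity log_softmax(x)_i = x_i - logsumexp(x) to turn those logits into per-token log-probabilities (the per-example logsumexp loop is only there to reduce peak memory). A minimal, self-contained check of that identity, with illustrative shapes and the vectorized form rather than the package's loop:

import torch

# (batch, seq_len, vocab) -- made-up sizes for illustration only
logits = torch.randn(2, 5, 11)
target_tokens = torch.randint(0, 11, (2, 5))

# gather the logit of each target token, then subtract logsumexp over the vocab
logits_labels = torch.gather(logits, dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)
logprobs = logits_labels - torch.logsumexp(logits, dim=-1)

# same result as gathering from an explicit log_softmax
reference = torch.log_softmax(logits, dim=-1).gather(-1, target_tokens.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(logprobs, reference, atol=1e-5)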
tuft/backends/sampling_backend.py
@@ -0,0 +1,253 @@
+"""Sampling backend implementated using vLLM"""
+
+import asyncio
+from logging import getLogger
+from pathlib import Path
+from typing import Optional
+
+from opentelemetry.trace import StatusCode
+from tinker import types
+
+from ..config import ModelConfig
+from ..telemetry.tracing import get_tracer
+from .base_backend import BaseSamplingBackend
+
+
+_get_tracer = lambda: get_tracer("tuft.sampling_backend")  # noqa: E731
+
+
+logger = getLogger(__name__)
+
+
+class VLLMSamplingBackend(BaseSamplingBackend):
+    """A sampling backend using vLLM.
+
+    User side `sample`, `sample_async`, `compute_logprobs` and
+    `compute_logprobs_async` are all supported by the sample method.
+    """
+
+    def __init__(self, config: ModelConfig) -> None:
+        from vllm.lora.request import LoRARequest
+
+        super().__init__(config)
+        self.engine = self._create_engine(config)
+        self.lora_adapters: dict[str, LoRARequest] = {}
+        self._counter = 1
+        self._lock = asyncio.Lock()
+
+    def _create_engine(self, config: ModelConfig):
+        if config.colocate:
+            return self._create_colocated_engine(config)
+        else:
+            return self._create_standalone_engine(config)
+
+    def _create_colocated_engine(self, config: ModelConfig):
+        import ray
+        from trinity.common.config import InferenceModelConfig
+        from trinity.common.models.vllm_model import vLLMRolloutModel
+
+        return (
+            ray.remote(vLLMRolloutModel)
+            .options(
+                name="sampling_model_" + self.base_model,
+                num_gpus=config.sampling_memory_fraction,
+            )
+            .remote(
+                config=InferenceModelConfig(
+                    model_path=str(config.model_path),
+                    tensor_parallel_size=1,
+                    max_model_len=config.max_model_len,
+                    temperature=config.temperature,
+                    top_p=config.top_p,
+                    top_k=config.top_k,
+                    logprobs=config.logprobs,
+                    min_response_tokens=config.min_response_tokens,
+                    repetition_penalty=1.0,
+                    enable_lora=True,
+                    enable_runtime_lora_updating=True,
+                    lora_kwargs={
+                        "max_lora_rank": config.max_lora_rank,
+                        "max_loras": config.max_loras,
+                    },
+                    # sampling use less memory than training
+                    gpu_memory_utilization=config.sampling_memory_fraction,
+                )
+            )
+        )
+
+    def _create_standalone_engine(self, config: ModelConfig):
+        import ray
+        from ray.util.placement_group import placement_group
+        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+        from trinity.common.config import InferenceModelConfig
+        from trinity.common.models.vllm_model import vLLMRolloutModel
+
+        # create a placement group for this model
+        pg = placement_group(
+            [{"CPU": 1, "GPU": 1} for _ in range(config.tensor_parallel_size)],
+            strategy="PACK",
+        )
+        ray.get(pg.ready(), timeout=10)
+        return (
+            ray.remote(vLLMRolloutModel)
+            .options(
+                name="sampling_model_" + self.base_model,
+                num_gpus=0 if config.tensor_parallel_size > 1 else 1,
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=pg,
+                    placement_group_capture_child_tasks=True,
+                ),
+            )
+            .remote(
+                config=InferenceModelConfig(
+                    model_path=str(config.model_path),
+                    tensor_parallel_size=config.tensor_parallel_size,
+                    max_model_len=config.max_model_len,
+                    temperature=config.temperature,
+                    top_p=config.top_p,
+                    top_k=config.top_k,
+                    logprobs=config.logprobs,
+                    min_response_tokens=config.min_response_tokens,
+                    repetition_penalty=1.0,
+                    enable_lora=True,
+                    enable_runtime_lora_updating=True,
+                    lora_kwargs={
+                        "max_lora_rank": config.max_lora_rank,
+                        "max_loras": config.max_loras,
+                    },
+                )
+            )
+        )
+
+    async def async_init(self) -> None:
+        """Initialize the backend for sampling."""
+        # Ray @ray.remote decorator adds .remote() method dynamically
+        await self.engine.prepare.remote()  # type: ignore[attr-defined]
+        logger.info(f"SamplingBackend for model {self.base_model} initialized.")
+
+    async def sample(
+        self,
+        prompt: types.ModelInput,
+        num_samples: int,
+        sampling_params: types.SamplingParams,
+        include_prompt_logprobs: bool = False,
+        topk_prompt_logprobs: int = 0,
+        lora_id: Optional[str] = None,
+    ) -> types.SampleResponse:
+        """Sampling using vLLM engine."""
+        with _get_tracer().start_as_current_span("sampling_backend.sample") as span:
+            span.set_attribute("tuft.num_samples", num_samples)
+            span.set_attribute("tuft.has_lora", lora_id is not None)
+            try:
+                async with self._lock:
+                    if lora_id is not None and lora_id not in self.lora_adapters:
+                        raise ValueError(f"LoRA adapter {lora_id} not found in backend.")
+                    lora_request = self.lora_adapters[lora_id] if lora_id is not None else None
+                # Ray @ray.remote decorator adds .remote() method dynamically
+                return await self.engine.sample.remote(  # type: ignore[attr-defined]
+                    prompt=prompt,
+                    num_samples=num_samples,
+                    sampling_params=sampling_params,
+                    include_prompt_logprobs=include_prompt_logprobs,
+                    topk_prompt_logprobs=topk_prompt_logprobs,
+                    lora_request=lora_request,
+                )
+            except Exception as e:
+                span.record_exception(e)
+                span.set_status(StatusCode.ERROR)
+                raise
+
+    async def add_adapter(self, lora_id: str, adapter_path: Path) -> None:
+        from vllm.lora.request import LoRARequest
+
+        with _get_tracer().start_as_current_span("sampling_backend.add_adapter") as span:
+            span.set_attribute("tuft.lora_id", lora_id)
+            try:
+                async with self._lock:
+                    self._counter += 1
+                    self.lora_adapters[lora_id] = LoRARequest(
+                        lora_int_id=self._counter + 1,
+                        lora_name=lora_id,
+                        lora_path=str(adapter_path),
+                    )
+                    if not adapter_path.exists():
+                        raise ValueError(f"LoRA adapter path {adapter_path} does not exist.")
+                    await self.engine.add_lora_adapter.remote(self.lora_adapters[lora_id])
+            except Exception as e:
+                span.record_exception(e)
+                span.set_status(StatusCode.ERROR)
+                raise
+
+    async def remove_adapter(self, lora_id: str) -> None:
+        with _get_tracer().start_as_current_span("sampling_backend.remove_adapter") as span:
+            span.set_attribute("tuft.lora_id", lora_id)
+            async with self._lock:
+                if lora_id in self.lora_adapters:
+                    await self.engine.remove_lora_adapter.remote(lora_id)
+                    del self.lora_adapters[lora_id]
+
+
+class DummySamplingBackend(BaseSamplingBackend):
+    """A dummy sampling backend that returns fixed responses for unittest."""
+
+    def __init__(self, config: ModelConfig) -> None:
+        super().__init__(config)
+        self.lora_adapters: dict[str, Path] = {}
+        self._counter = 1
+        self._lock = asyncio.Lock()
+
+    async def async_init(self) -> None:
+        """No initialization needed for dummy backend."""
+        pass
+
+    async def sample(
+        self,
+        prompt: types.ModelInput,
+        num_samples: int,
+        sampling_params: types.SamplingParams,
+        include_prompt_logprobs: bool = False,
+        topk_prompt_logprobs: int = 0,
+        lora_id: Optional[str] = None,
+    ) -> types.SampleResponse:
+        """Return a fixed dummy response."""
+        prompt_tokens = prompt.to_ints()
+        max_tokens = sampling_params.max_tokens or 16
+        sequences: list[types.SampledSequence] = []
+        for _ in range(num_samples):
+            generated = self._generate_tokens(prompt_tokens, max_tokens)
+            seq = types.SampledSequence(
+                stop_reason="length",
+                tokens=generated,
+                logprobs=[-0.3 for _ in generated],
+            )
+            sequences.append(seq)
+        prompt_logprobs = None
+        topk_prompt = None
+        if include_prompt_logprobs:
+            prompt_logprobs = [-0.1 if tok is not None else None for tok in prompt_tokens]
+            if topk_prompt_logprobs > 0:
+                topk_prompt = [
+                    [
+                        (token, round(-0.05 - idx * 0.01, 4))
+                        for idx, token in enumerate(prompt_tokens[:topk_prompt_logprobs])
+                    ]
+                    if token is not None
+                    else None
+                    for token in prompt_tokens
+                ]
+        return types.SampleResponse(
+            sequences=sequences,
+            prompt_logprobs=prompt_logprobs,
+            topk_prompt_logprobs=topk_prompt,
+        )
+
+    def _generate_tokens(self, prompt_tokens: list[int], max_tokens: int) -> list[int]:
+        start = prompt_tokens[-1] if prompt_tokens else (abs(self.config.seed) % 32000) + 1
+        return [(start + i) % 32000 for i in range(1, max_tokens + 1)]
+
+    async def add_adapter(self, lora_id: str, adapter_path: Path) -> None:
+        self.lora_adapters[lora_id] = adapter_path
+
+    async def remove_adapter(self, lora_id: str) -> None:
+        if lora_id in self.lora_adapters:
+            del self.lora_adapters[lora_id]
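For reference, `DummySamplingBackend._generate_tokens` above produces deterministic token ids by counting up from the last prompt token modulo an assumed 32000-token vocabulary, falling back to a seed-derived start when the prompt is empty. A standalone sketch of the same arithmetic (`seed` here is a hypothetical stand-in for `config.seed`):

def generate_tokens(prompt_tokens: list[int], max_tokens: int, seed: int = 0) -> list[int]:
    # start counting from the last prompt token, or from a seed-derived value
    start = prompt_tokens[-1] if prompt_tokens else (abs(seed) % 32000) + 1
    return [(start + i) % 32000 for i in range(1, max_tokens + 1)]

print(generate_tokens([101, 7592, 2088], 5))  # [2089, 2090, 2091, 2092, 2093]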