tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,404 @@
+ import asyncio
+ import shutil
+ from typing import Dict
+
+ import ray
+ import torch
+ from opentelemetry.trace import StatusCode
+ from peft import LoraConfig, get_peft_model
+ from ray.actor import ActorProxy
+ from tinker import types
+ from tinker.types import LoraConfig as TinkerLoraConfig
+ from torch.nn.utils.rnn import pad_sequence
+ from transformers import AutoModelForCausalLM
+
+ from tuft.checkpoints import CheckpointRecord
+ from tuft.config import ModelConfig
+ from tuft.loss_fn import get_loss_fn
+ from tuft.telemetry.tracing import extract_context, get_tracer
+
+
+ _get_tracer = lambda: get_tracer("tuft.hf_training_model")  # noqa: E731
+
+
+ MODULE_MAP = {
+     "llama": {
+         "attn": ["q_proj", "k_proj", "v_proj", "o_proj"],
+         "mlp": ["gate_proj", "up_proj", "down_proj"],
+         "unembed": ["lm_head"],
+     },
+     "qwen": {
+         "attn": ["q_proj", "k_proj", "v_proj", "o_proj"],
+         "mlp": ["gate_proj", "up_proj", "down_proj"],
+         "unembed": [],  # setting unembed causes a warning in Qwen models
+     },
+ }
+
+
+ def get_target_modules(model_path: str, lora_config: TinkerLoraConfig) -> list[str]:
+     if "qwen" in model_path.lower():
+         model_series = "qwen"
+     elif "llama" in model_path.lower():
+         model_series = "llama"
+     else:
+         raise ValueError(f"Unsupported model series: {model_path}")
+     target_modules = []
+     if lora_config.train_attn:
+         target_modules.extend(MODULE_MAP[model_series]["attn"])
+     if lora_config.train_mlp:
+         target_modules.extend(MODULE_MAP[model_series]["mlp"])
+     if lora_config.train_unembed:
+         target_modules.extend(MODULE_MAP[model_series]["unembed"])
+     return target_modules
+
+
+ class HFTrainingModel:
+     def __init__(self, config: ModelConfig) -> None:
+         self.config = config
+         self.model = self._init_peft_model(config)
+         self.adapter_optimizer: Dict[str, torch.optim.AdamW] = {}
+         self._lock = asyncio.Lock()
+
+     async def async_init(self) -> None:
+         """Do nothing for now. Just used to make sure the actor is ready."""
+         pass
+
+     # --------------------------------
+     # LoRA adapter management methods
+     # --------------------------------
+     async def create_adapter(
+         self,
+         lora_id: str,
+         lora_config: TinkerLoraConfig,
+         trace_context: dict[str, str] | None = None,
+     ):
+         ctx = extract_context(trace_context or {})
+         with _get_tracer().start_as_current_span("hf_model.create_adapter", context=ctx) as span:
+             span.set_attribute("tuft.lora_id", lora_id)
+             try:
+                 if lora_id in self.adapter_optimizer:
+                     raise ValueError(f"Adapter {lora_id} already exists.")
+                 peft_config = LoraConfig(
+                     r=lora_config.rank,
+                     target_modules=get_target_modules(str(self.config.model_path), lora_config),
+                     # TODO: lora_alpha is set equal to the rank, which is common practice,
+                     # but it may be exposed as a separate option in the future if needed.
+                     lora_alpha=lora_config.rank,
+                 )
+
+                 self.model.add_adapter(adapter_name=lora_id, peft_config=peft_config)
+                 async with self._lock:
+                     self.model.set_adapter(lora_id)
+                     params = [p for p in self.model.parameters() if p.requires_grad]
+                     self.adapter_optimizer[lora_id] = torch.optim.AdamW(params)
+             except Exception as e:
+                 span.record_exception(e)
+                 span.set_status(StatusCode.ERROR)
+                 raise
+
+     async def save_state(
+         self,
+         lora_id: str,
+         checkpoint_record: CheckpointRecord,
+         optimizer: bool,
+         trace_context: dict[str, str] | None = None,
+     ):
+         """
+         Save the LoRA adapter and optimizer state.
+
+         Args:
+             lora_id: The LoRA adapter ID to save.
+             checkpoint_record: The CheckpointRecord containing paths to save to.
+             optimizer: Whether to save the optimizer state.
+             trace_context: Optional trace context for distributed tracing.
+         """
+         ctx = extract_context(trace_context or {})
+         with _get_tracer().start_as_current_span("hf_model.save_state", context=ctx) as span:
+             span.set_attribute("tuft.lora_id", lora_id)
+             span.set_attribute("tuft.optimizer", optimizer)
+             try:
+                 if lora_id not in self.adapter_optimizer:
+                     raise ValueError(f"Adapter {lora_id} not found.")
+
+                 # 1. Save adapter (LoRA weights)
+                 adapter_dir = checkpoint_record.adapter_path
+                 adapter_dir.mkdir(parents=True, exist_ok=True)
+                 # peft automatically creates a subdirectory named after the adapter inside the given path
+                 self.model.save_pretrained(str(adapter_dir), selected_adapters=[lora_id])
+                 # move the files out of the subdirectory
+                 lora_subdir = adapter_dir / lora_id
+                 if lora_subdir.exists() and lora_subdir.is_dir():
+                     for item in lora_subdir.iterdir():
+                         dest = adapter_dir / item.name
+                         if dest.exists():
+                             if dest.is_file():
+                                 dest.unlink()
+                             elif dest.is_dir():
+                                 shutil.rmtree(dest)
+                         shutil.move(str(item), str(dest))
+                     lora_subdir.rmdir()
+
+                 # 2. Save optimizer state
+                 if optimizer:
+                     opt_dir = checkpoint_record.optimizer_path
+                     opt_dir.mkdir(parents=True, exist_ok=True)
+                     opt_state = self.adapter_optimizer[lora_id].state_dict()
+                     opt_path = opt_dir / f"{lora_id}.pt"
+                     torch.save(opt_state, opt_path)
+             except Exception as e:
+                 span.record_exception(e)
+                 span.set_status(StatusCode.ERROR)
+                 raise
+
+     async def load_state(
+         self,
+         lora_id: str,
+         checkpoint_record: CheckpointRecord,
+         optimizer: bool,
+         trace_context: dict[str, str] | None = None,
+     ):
+         """
+         Load the LoRA adapter and optimizer state (standard format).
+
+         Args:
+             lora_id: The LoRA adapter ID to load.
+             checkpoint_record: The CheckpointRecord containing paths to load from.
+             optimizer: Whether to load the optimizer state.
+             trace_context: Optional trace context for distributed tracing.
+         """
+         ctx = extract_context(trace_context or {})
+         with _get_tracer().start_as_current_span("hf_model.load_state", context=ctx) as span:
+             span.set_attribute("tuft.lora_id", lora_id)
+             span.set_attribute("tuft.optimizer", optimizer)
+             # 1. Load the adapter from the checkpoint directory
+             self.model.load_adapter(
+                 model_id=str(checkpoint_record.adapter_path), adapter_name=lora_id
+             )
+
+             # 2. Load the optimizer state if needed
+             async with self._lock:
+                 self.model.set_adapter(lora_id)
+                 params = [p for p in self.model.parameters() if p.requires_grad]
+                 optimizer_obj = torch.optim.AdamW(params)
+                 if optimizer:
+                     opt_dir = checkpoint_record.optimizer_path
+                     opt_path = opt_dir / f"{lora_id}.pt"
+                     state_dict = None
+                     if opt_path.exists():
+                         state_dict = torch.load(opt_path)
+                     if state_dict is not None:
+                         optimizer_obj.load_state_dict(state_dict)
+                 self.adapter_optimizer[lora_id] = optimizer_obj
+
+     async def remove_adapter(self, lora_id: str):
+         async with self._lock:
+             if lora_id in self.adapter_optimizer:
+                 self.model.delete_adapter(lora_id)
+                 self.adapter_optimizer.pop(lora_id)
+
+     # --------------------------------
+     # Training methods
+     # --------------------------------
+     async def forward(
+         self,
+         data: list[types.Datum],
+         lora_id: str,
+         loss_fn: types.LossFnType,
+         loss_fn_config: dict[str, float] | None,
+         backward: bool = False,
+         trace_context: dict[str, str] | None = None,
+     ) -> types.ForwardBackwardOutput:
+         """Forward pass (and backward pass if specified).
+
+         Args:
+             data: List of Datum objects containing input data.
+             lora_id: The LoRA adapter ID to use.
+             loss_fn: The loss function to apply.
+             loss_fn_config: Optional configuration for the loss function.
+             backward: Whether to perform the backward pass.
+             trace_context: Optional trace context for distributed tracing.
+
+         Returns:
+             ForwardBackwardOutput: The output of the forward (and backward) pass.
+         """
+         ctx = extract_context(trace_context or {})
+         span_name = "hf_model.forward_backward" if backward else "hf_model.forward"
+         with _get_tracer().start_as_current_span(span_name, context=ctx) as span:
+             span.set_attribute("tuft.lora_id", lora_id)
+             span.set_attribute("tuft.backward", backward)
+             span.set_attribute("tuft.data_count", len(data))
+             # Prepare input tensors
+             input_ids = [
+                 torch.tensor(datum.model_input.to_ints(), dtype=torch.long) for datum in data
+             ]
+             input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
+             attention_mask = (input_ids_padded != 0).long()
+             position_ids = (
+                 torch.arange(input_ids_padded.size(1), dtype=torch.long)
+                 .unsqueeze(0)
+                 .expand(input_ids_padded.size(0), -1)
+             )
+             # Move tensors to the model device
+             device = next(self.model.parameters()).device
+             input_ids_padded = input_ids_padded.to(device)
+             attention_mask = attention_mask.to(device)
+             position_ids = position_ids.to(device)
+
+             # Activate the correct adapter
+             async with self._lock:
+                 self._activate_adapter(lora_id)
+
+             # Forward pass
+             outputs = self.model(
+                 input_ids=input_ids_padded,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 return_dict=True,
+             )
+
+             # Compute loss
+             if loss_fn_config is None:
+                 loss_fn_config = {}
+             loss_fn_callable = get_loss_fn(loss_fn)
+             logits = outputs.logits
+             if "temperature" in loss_fn_config:
+                 temperature = loss_fn_config["temperature"]
+                 logits.div_(temperature)
+
+             loss_fn_inputs = self._prepare_loss_fn_inputs(data)
+
+             # compute target_logprobs from logits and target_tokens
+             target_tokens = loss_fn_inputs["target_tokens"]
+             target_logprobs = self._compute_logprobs_from_target_tokens(logits, target_tokens)
+             loss_fn_inputs["target_logprobs"] = target_logprobs
+
+             loss, metric = loss_fn_callable(loss_fn_inputs, loss_fn_config)
+
+             # Backward pass if needed
+             if backward:
+                 loss.backward()
+
+             unpadded_logprobs = self._unpad_tensor(
+                 target_logprobs, [len(datum.model_input.to_ints()) for datum in data]
+             )
+             return types.ForwardBackwardOutput(
+                 loss_fn_output_type=loss_fn,
+                 loss_fn_outputs=[
+                     {"logprobs": types.TensorData.from_torch(logprobs.detach())}
+                     for logprobs in unpadded_logprobs
+                 ],
+                 metrics=metric,
+             )
+
+     async def optim_step(
+         self,
+         adam_params: types.AdamParams,
+         lora_id: str,
+         trace_context: dict[str, str] | None = None,
+     ) -> types.OptimStepResponse:
+         """Perform an optimization step using the Adam optimizer.
+
+         Args:
+             adam_params: Parameters for the Adam optimizer.
+             lora_id: The LoRA adapter ID to use.
+             trace_context: Optional trace context for distributed tracing.
+
+         Returns:
+             OptimStepResponse: The response containing optimization metrics.
+         """
+         ctx = extract_context(trace_context or {})
+         with _get_tracer().start_as_current_span("hf_model.optim_step", context=ctx) as span:
+             span.set_attribute("tuft.lora_id", lora_id)
+             optimizer = self.adapter_optimizer[lora_id]
+             for param_group in optimizer.param_groups:
+                 param_group["lr"] = adam_params.learning_rate
+                 param_group["betas"] = (adam_params.beta1, adam_params.beta2)
+                 param_group["eps"] = adam_params.eps
+                 param_group["weight_decay"] = adam_params.weight_decay
+             optimizer.step()
+             optimizer.zero_grad()
+             return types.OptimStepResponse()
+
+     # --------------------------------
+     # Helper methods
+     # --------------------------------
+     def _prepare_loss_fn_inputs(self, data: list[types.Datum]) -> Dict[str, torch.Tensor]:
+         """Prepare input tensors from a list of Datum objects."""
+         device = next(self.model.parameters()).device
+
+         loss_fn_input_dict = {}
+         # prepare loss_fn_inputs tensors
+         loss_fn_input_keys = data[0].loss_fn_inputs.keys()
+         for key in loss_fn_input_keys:
+             tensors = [datum.loss_fn_inputs[key].to_torch() for datum in data]
+             # If the tensors are 1D, pad to the max length; if they already share a shape, stack directly
+             if all(t.dim() == 1 for t in tensors):
+                 padded = pad_sequence(tensors, batch_first=True, padding_value=0)
+                 loss_fn_input_dict[key] = padded.to(device)
+             else:
+                 # Try to stack; on a shape mismatch, pad each dim to the max size
+                 try:
+                     stacked = torch.stack(tensors)
+                     loss_fn_input_dict[key] = stacked.to(device)
+                 except Exception:
+                     # Pad each dim to the max size
+                     max_shape = list(tensors[0].shape)
+                     for t in tensors:
+                         for i, s in enumerate(t.shape):
+                             if s > max_shape[i]:
+                                 max_shape[i] = s
+                     padded_tensors = []
+                     for t in tensors:
+                         pad_width = [(0, m - s) for s, m in zip(t.shape, max_shape, strict=False)]
+                         pad_args = []
+                         for p in reversed(pad_width):
+                             pad_args.extend(p)
+                         padded = torch.nn.functional.pad(t, pad_args, value=0)
+                         padded_tensors.append(padded)
+                     stacked = torch.stack(padded_tensors)
+                     loss_fn_input_dict[key] = stacked.to(device)
+
+         return loss_fn_input_dict
+
+     def _compute_logprobs_from_target_tokens(
+         self, logits: torch.Tensor, target_tokens: torch.Tensor
+     ) -> torch.Tensor:
+         logits_labels = torch.gather(logits, dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)
+         # loop to reduce peak memory consumption
+         logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits])
+         logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+         return logprobs_labels
+
+     def _unpad_tensor(
+         self, padded_tensor: torch.Tensor, original_lengths: list[int]
+     ) -> list[torch.Tensor]:
+         """Unpad a padded tensor back into a list of tensors with their original lengths."""
+         tensors = []
+         for i, length in enumerate(original_lengths):
+             tensors.append(padded_tensor[i, :length])
+         return tensors
+
+     def _init_peft_model(self, config: ModelConfig):
+         model = AutoModelForCausalLM.from_pretrained(
+             str(config.model_path),
+             dtype="auto",
+             device_map="auto",
+         )
+         peft_config = LoraConfig()
+         peft_model = get_peft_model(model, peft_config=peft_config, adapter_name="default")
+         return peft_model
+
+     def _activate_adapter(self, lora_id: str):
+         if lora_id not in self.adapter_optimizer:
+             raise ValueError(f"Adapter {lora_id} not found.")
+         self.model.set_adapter(lora_id)
+
+     @classmethod
+     def get_actor(cls, config: ModelConfig) -> "ActorProxy":
+         return (
+             ray.remote(cls)
+             .options(
+                 name="training_model_" + config.model_name,
+                 num_gpus=1 if not config.colocate else 1 - config.sampling_memory_fraction,
+             )
+             .remote(config)
+         )
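
For orientation, here is a minimal usage sketch of the training actor added above. It is not part of the package diff: the module path `tuft.hf_training_model`, the constructor arguments shown for `ModelConfig`, `TinkerLoraConfig`, and `AdamParams`, and the `"cross_entropy"` loss name are assumptions chosen for illustration, while the call order (get_actor, async_init, create_adapter, forward with backward=True, optim_step) follows the method signatures in the hunk.

from tinker import types
from tinker.types import LoraConfig as TinkerLoraConfig

from tuft.config import ModelConfig
from tuft.hf_training_model import HFTrainingModel  # hypothetical module path


async def run_one_step(data: list[types.Datum]) -> None:
    # Hypothetical config values; only model_name / model_path / colocate appear in the hunk.
    config = ModelConfig(model_name="demo", model_path="Qwen/Qwen2.5-0.5B-Instruct")
    actor = HFTrainingModel.get_actor(config)  # Ray actor holding the PEFT model on one GPU
    await actor.async_init.remote()

    lora_id = "run-001"
    await actor.create_adapter.remote(
        lora_id,
        TinkerLoraConfig(rank=16, train_attn=True, train_mlp=True, train_unembed=False),
    )

    # Forward + backward on a batch, then one Adam step on the adapter's optimizer.
    out = await actor.forward.remote(
        data, lora_id, loss_fn="cross_entropy", loss_fn_config=None, backward=True
    )
    await actor.optim_step.remote(
        types.AdamParams(learning_rate=1e-4, beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=0.0),
        lora_id,
    )
    print(out.metrics)

Passing `data` in from the caller keeps the sketch independent of how `types.Datum` objects are built, which this hunk does not show; the coroutine would be driven by an asyncio event loop alongside the rest of the service.
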
@@ -0,0 +1,253 @@
+ """Sampling backend implemented using vLLM."""
+
+ import asyncio
+ from logging import getLogger
+ from pathlib import Path
+ from typing import Optional
+
+ from opentelemetry.trace import StatusCode
+ from tinker import types
+
+ from ..config import ModelConfig
+ from ..telemetry.tracing import get_tracer
+ from .base_backend import BaseSamplingBackend
+
+
+ _get_tracer = lambda: get_tracer("tuft.sampling_backend")  # noqa: E731
+
+
+ logger = getLogger(__name__)
+
+
+ class VLLMSamplingBackend(BaseSamplingBackend):
+     """A sampling backend using vLLM.
+
+     The user-side `sample`, `sample_async`, `compute_logprobs`, and
+     `compute_logprobs_async` calls are all served by the `sample` method.
+     """
+
+     def __init__(self, config: ModelConfig) -> None:
+         from vllm.lora.request import LoRARequest
+
+         super().__init__(config)
+         self.engine = self._create_engine(config)
+         self.lora_adapters: dict[str, LoRARequest] = {}
+         self._counter = 1
+         self._lock = asyncio.Lock()
+
+     def _create_engine(self, config: ModelConfig):
+         if config.colocate:
+             return self._create_colocated_engine(config)
+         else:
+             return self._create_standalone_engine(config)
+
+     def _create_colocated_engine(self, config: ModelConfig):
+         import ray
+         from trinity.common.config import InferenceModelConfig
+         from trinity.common.models.vllm_model import vLLMRolloutModel
+
+         return (
+             ray.remote(vLLMRolloutModel)
+             .options(
+                 name="sampling_model_" + self.base_model,
+                 num_gpus=config.sampling_memory_fraction,
+             )
+             .remote(
+                 config=InferenceModelConfig(
+                     model_path=str(config.model_path),
+                     tensor_parallel_size=1,
+                     max_model_len=config.max_model_len,
+                     temperature=config.temperature,
+                     top_p=config.top_p,
+                     top_k=config.top_k,
+                     logprobs=config.logprobs,
+                     min_response_tokens=config.min_response_tokens,
+                     repetition_penalty=1.0,
+                     enable_lora=True,
+                     enable_runtime_lora_updating=True,
+                     lora_kwargs={
+                         "max_lora_rank": config.max_lora_rank,
+                         "max_loras": config.max_loras,
+                     },
+                     # sampling uses less memory than training
+                     gpu_memory_utilization=config.sampling_memory_fraction,
+                 )
+             )
+         )
+
+     def _create_standalone_engine(self, config: ModelConfig):
+         import ray
+         from ray.util.placement_group import placement_group
+         from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+         from trinity.common.config import InferenceModelConfig
+         from trinity.common.models.vllm_model import vLLMRolloutModel
+
+         # create a placement group for this model
+         pg = placement_group(
+             [{"CPU": 1, "GPU": 1} for _ in range(config.tensor_parallel_size)],
+             strategy="PACK",
+         )
+         ray.get(pg.ready(), timeout=10)
+         return (
+             ray.remote(vLLMRolloutModel)
+             .options(
+                 name="sampling_model_" + self.base_model,
+                 num_gpus=0 if config.tensor_parallel_size > 1 else 1,
+                 scheduling_strategy=PlacementGroupSchedulingStrategy(
+                     placement_group=pg,
+                     placement_group_capture_child_tasks=True,
+                 ),
+             )
+             .remote(
+                 config=InferenceModelConfig(
+                     model_path=str(config.model_path),
+                     tensor_parallel_size=config.tensor_parallel_size,
+                     max_model_len=config.max_model_len,
+                     temperature=config.temperature,
+                     top_p=config.top_p,
+                     top_k=config.top_k,
+                     logprobs=config.logprobs,
+                     min_response_tokens=config.min_response_tokens,
+                     repetition_penalty=1.0,
+                     enable_lora=True,
+                     enable_runtime_lora_updating=True,
+                     lora_kwargs={
+                         "max_lora_rank": config.max_lora_rank,
+                         "max_loras": config.max_loras,
+                     },
+                 )
+             )
+         )
+
+     async def async_init(self) -> None:
+         """Initialize the backend for sampling."""
+         # The Ray @ray.remote decorator adds the .remote() method dynamically
+         await self.engine.prepare.remote()  # type: ignore[attr-defined]
+         logger.info(f"SamplingBackend for model {self.base_model} initialized.")
+
+     async def sample(
+         self,
+         prompt: types.ModelInput,
+         num_samples: int,
+         sampling_params: types.SamplingParams,
+         include_prompt_logprobs: bool = False,
+         topk_prompt_logprobs: int = 0,
+         lora_id: Optional[str] = None,
+     ) -> types.SampleResponse:
+         """Sample using the vLLM engine."""
+         with _get_tracer().start_as_current_span("sampling_backend.sample") as span:
+             span.set_attribute("tuft.num_samples", num_samples)
+             span.set_attribute("tuft.has_lora", lora_id is not None)
+             try:
+                 async with self._lock:
+                     if lora_id is not None and lora_id not in self.lora_adapters:
+                         raise ValueError(f"LoRA adapter {lora_id} not found in backend.")
+                     lora_request = self.lora_adapters[lora_id] if lora_id is not None else None
+                 # The Ray @ray.remote decorator adds the .remote() method dynamically
+                 return await self.engine.sample.remote(  # type: ignore[attr-defined]
+                     prompt=prompt,
+                     num_samples=num_samples,
+                     sampling_params=sampling_params,
+                     include_prompt_logprobs=include_prompt_logprobs,
+                     topk_prompt_logprobs=topk_prompt_logprobs,
+                     lora_request=lora_request,
+                 )
+             except Exception as e:
+                 span.record_exception(e)
+                 span.set_status(StatusCode.ERROR)
+                 raise
+
+     async def add_adapter(self, lora_id: str, adapter_path: Path) -> None:
+         from vllm.lora.request import LoRARequest
+
+         with _get_tracer().start_as_current_span("sampling_backend.add_adapter") as span:
+             span.set_attribute("tuft.lora_id", lora_id)
+             try:
+                 async with self._lock:
+                     self._counter += 1
+                     self.lora_adapters[lora_id] = LoRARequest(
+                         lora_int_id=self._counter + 1,
+                         lora_name=lora_id,
+                         lora_path=str(adapter_path),
+                     )
+                     if not adapter_path.exists():
+                         raise ValueError(f"LoRA adapter path {adapter_path} does not exist.")
+                     await self.engine.add_lora_adapter.remote(self.lora_adapters[lora_id])
+             except Exception as e:
+                 span.record_exception(e)
+                 span.set_status(StatusCode.ERROR)
+                 raise
+
+     async def remove_adapter(self, lora_id: str) -> None:
+         with _get_tracer().start_as_current_span("sampling_backend.remove_adapter") as span:
+             span.set_attribute("tuft.lora_id", lora_id)
+             async with self._lock:
+                 if lora_id in self.lora_adapters:
+                     await self.engine.remove_lora_adapter.remote(lora_id)
+                     del self.lora_adapters[lora_id]
+
+
+ class DummySamplingBackend(BaseSamplingBackend):
+     """A dummy sampling backend that returns fixed responses for unit tests."""
+
+     def __init__(self, config: ModelConfig) -> None:
+         super().__init__(config)
+         self.lora_adapters: dict[str, Path] = {}
+         self._counter = 1
+         self._lock = asyncio.Lock()
+
+     async def async_init(self) -> None:
+         """No initialization needed for the dummy backend."""
+         pass
+
+     async def sample(
+         self,
+         prompt: types.ModelInput,
+         num_samples: int,
+         sampling_params: types.SamplingParams,
+         include_prompt_logprobs: bool = False,
+         topk_prompt_logprobs: int = 0,
+         lora_id: Optional[str] = None,
+     ) -> types.SampleResponse:
+         """Return a fixed dummy response."""
+         prompt_tokens = prompt.to_ints()
+         max_tokens = sampling_params.max_tokens or 16
+         sequences: list[types.SampledSequence] = []
+         for _ in range(num_samples):
+             generated = self._generate_tokens(prompt_tokens, max_tokens)
+             seq = types.SampledSequence(
+                 stop_reason="length",
+                 tokens=generated,
+                 logprobs=[-0.3 for _ in generated],
+             )
+             sequences.append(seq)
+         prompt_logprobs = None
+         topk_prompt = None
+         if include_prompt_logprobs:
+             prompt_logprobs = [-0.1 if tok is not None else None for tok in prompt_tokens]
+         if topk_prompt_logprobs > 0:
+             topk_prompt = [
+                 [
+                     (token, round(-0.05 - idx * 0.01, 4))
+                     for idx, token in enumerate(prompt_tokens[:topk_prompt_logprobs])
+                 ]
+                 if token is not None
+                 else None
+                 for token in prompt_tokens
+             ]
+         return types.SampleResponse(
+             sequences=sequences,
+             prompt_logprobs=prompt_logprobs,
+             topk_prompt_logprobs=topk_prompt,
+         )
+
+     def _generate_tokens(self, prompt_tokens: list[int], max_tokens: int) -> list[int]:
+         start = prompt_tokens[-1] if prompt_tokens else (abs(self.config.seed) % 32000) + 1
+         return [(start + i) % 32000 for i in range(1, max_tokens + 1)]
+
+     async def add_adapter(self, lora_id: str, adapter_path: Path) -> None:
+         self.lora_adapters[lora_id] = adapter_path
+
+     async def remove_adapter(self, lora_id: str) -> None:
+         if lora_id in self.lora_adapters:
+             del self.lora_adapters[lora_id]
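
A similarly hedged sketch of the sampling contract, exercised against DummySamplingBackend so it needs no GPU or vLLM install. The import path, the `ModelConfig` fields, and the `ModelInput` / `SamplingParams` constructor calls are assumptions for illustration; the method names, argument order, and response fields come from the hunk.

import asyncio
from pathlib import Path

from tinker import types

from tuft.config import ModelConfig
from tuft.sampling.vllm_backend import DummySamplingBackend  # hypothetical module path


async def main() -> None:
    # Hypothetical config; `seed` is only consulted when the prompt is empty.
    config = ModelConfig(model_name="dummy", model_path="Qwen/Qwen2.5-0.5B-Instruct", seed=0)
    backend = DummySamplingBackend(config)
    await backend.async_init()

    # Register an adapter path, then sample against it.
    await backend.add_adapter("run-001", Path("/tmp/run-001/adapter"))
    response = await backend.sample(
        prompt=types.ModelInput.from_ints(tokens=[1, 2, 3]),  # constructor assumed
        num_samples=2,
        sampling_params=types.SamplingParams(max_tokens=8),  # only max_tokens is read by the dummy backend
        include_prompt_logprobs=True,
        lora_id="run-001",
    )
    for seq in response.sequences:
        print(seq.stop_reason, len(seq.tokens), seq.logprobs[:3])


asyncio.run(main())

Because both backends implement the same BaseSamplingBackend interface, swapping DummySamplingBackend for VLLMSamplingBackend keeps the call sites unchanged; only the engine construction and resource placement differ.
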