mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/infer.py
ADDED
@@ -0,0 +1,140 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

from .config import ProjectConfig
from .llm.registry import get_llm_backend
from .models import resolve_model_spec


@dataclass
class ChatMessage:
    role: str
    content: str


def _messages_to_prompt(messages: List[ChatMessage], tokenizer, *, use_chat_template: bool) -> str:
    if use_chat_template and hasattr(tokenizer, "apply_chat_template"):
        payload = [{"role": m.role, "content": m.content} for m in messages]
        try:
            return tokenizer.apply_chat_template(payload, tokenize=False, add_generation_prompt=True)
        except Exception:
            pass
    joined = "\n".join([f"{m.role}: {m.content}" for m in messages])
    return f"{joined}\nassistant:"


def _load_backend(cfg: ProjectConfig, model_spec: str):
    llm = get_llm_backend(cfg.model.backend)
    base_model, adapter_path, _meta = resolve_model_spec(Path.cwd(), model_spec, cfg)
    llm.load(
        base_model,
        max_seq_len=cfg.model.max_seq_len,
        dtype=cfg.model.dtype,
        trust_remote_code=cfg.model.trust_remote_code,
    )
    if adapter_path:
        llm.apply_adapter(str(adapter_path))
    return llm, base_model


def run_prompt(
    cfg: ProjectConfig,
    model_spec: str,
    prompt: str,
    *,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    seed: Optional[int] = None,
) -> str:
    llm, _base_model = _load_backend(cfg, model_spec)
    gen = llm.generate(
        prompt,
        max_new_tokens=max_new_tokens or cfg.infer.max_new_tokens,
        temperature=temperature if temperature is not None else cfg.infer.temperature,
        top_p=top_p if top_p is not None else cfg.infer.top_p,
        top_k=top_k if top_k is not None else cfg.infer.top_k,
        seed=seed,
    )
    if gen.text.startswith(prompt):
        return gen.text[len(prompt) :]
    return gen.text


def run_chat(
    cfg: ProjectConfig,
    model_spec: str,
    messages: List[ChatMessage],
    *,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    seed: Optional[int] = None,
) -> str:
    llm, _base_model = _load_backend(cfg, model_spec)
    prompt = _messages_to_prompt(messages, llm.tokenizer, use_chat_template=cfg.model.use_chat_template)
    gen = llm.generate(
        prompt,
        max_new_tokens=max_new_tokens or cfg.infer.max_new_tokens,
        temperature=temperature if temperature is not None else cfg.infer.temperature,
        top_p=top_p if top_p is not None else cfg.infer.top_p,
        top_k=top_k if top_k is not None else cfg.infer.top_k,
        seed=seed,
    )
    if gen.text.startswith(prompt):
        return gen.text[len(prompt) :]
    return gen.text


def chat_repl(
    cfg: ProjectConfig,
    model_spec: str,
    *,
    system: Optional[str] = None,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    seed: Optional[int] = None,
    max_turns: Optional[int] = None,
) -> None:
    llm, _base_model = _load_backend(cfg, model_spec)
    messages: List[ChatMessage] = []
    if system:
        messages.append(ChatMessage(role="system", content=system))

    turns = 0
    while True:
        if max_turns is not None and turns >= max_turns:
            break
        try:
            user = input("user> ").strip()
        except EOFError:
            break
        if not user:
            continue
        if user.lower() in {"/exit", "/quit", "exit", "quit"}:
            break
        messages.append(ChatMessage(role="user", content=user))
        prompt = _messages_to_prompt(messages, llm.tokenizer, use_chat_template=cfg.model.use_chat_template)
        gen = llm.generate(
            prompt,
            max_new_tokens=max_new_tokens or cfg.infer.max_new_tokens,
            temperature=temperature if temperature is not None else cfg.infer.temperature,
            top_p=top_p if top_p is not None else cfg.infer.top_p,
            top_k=top_k if top_k is not None else cfg.infer.top_k,
            seed=seed,
        )
        if gen.text.startswith(prompt):
            reply = gen.text[len(prompt) :]
        else:
            reply = gen.text
        reply = reply.strip()
        print(f"assistant> {reply}\n")
        messages.append(ChatMessage(role="assistant", content=reply))
        turns += 1
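
A minimal usage sketch for the helpers above, assuming a ProjectConfig instance is already available (its construction lives in mlxsmith/config.py, which is not shown here); the "base" model spec and the quick_answer/interactive wrappers are illustrative, not part of the package.

# Hedged sketch: drive run_prompt / chat_repl from a script.
from mlxsmith.config import ProjectConfig  # assumption: cfg is built elsewhere
from mlxsmith.infer import chat_repl, run_prompt


def quick_answer(cfg: ProjectConfig, question: str) -> str:
    # "base" is a placeholder model spec; resolve_model_spec decides what it maps to.
    return run_prompt(cfg, "base", question, max_new_tokens=64, temperature=0.2)


def interactive(cfg: ProjectConfig) -> None:
    # Bounded REPL: exits after 5 turns or on /exit, /quit, or EOF.
    chat_repl(cfg, "base", system="You are a helpful assistant.", max_turns=5)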
mlxsmith/llm/__init__.py
ADDED
@@ -0,0 +1,16 @@
"""Model backends for mlxsmith."""

from .backend import LLMBackend, Generation, BackendNotAvailable, DecodingConfig
from .mlx_lm_backend import MlxLMBackend
from .mock_backend import MockBackend
from .registry import get_llm_backend

__all__ = [
    "LLMBackend",
    "Generation",
    "BackendNotAvailable",
    "MlxLMBackend",
    "MockBackend",
    "DecodingConfig",
    "get_llm_backend",
]
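
A short sketch of how this public surface might be used. DecodingConfig fields come from backend.py below; the "mock" registry key is an assumption, since registry.py (12 lines, not shown) defines the accepted names and callers elsewhere pass cfg.model.backend.

# Hedged sketch: construct decoding settings and look up a backend by name.
from mlxsmith.llm import DecodingConfig, get_llm_backend

decoding = DecodingConfig(max_new_tokens=128, temperature=0.2, top_p=0.9, seed=0)
backend = get_llm_backend("mock")  # assumed key; in practice cfg.model.backend supplies it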
mlxsmith/llm/backend.py
ADDED
@@ -0,0 +1,126 @@
"""LLM backend abstraction.

We keep this intentionally small so mlxsmith can support multiple model loaders
without hard-binding to one ecosystem.

Primary target: mlx-lm (HF -> MLX format + common chat models).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol, Sequence, Any, Optional, List, Dict


@dataclass
class Generation:
    text: str
    token_ids: list[int]
    prompt_len: int
    # Log-probabilities for generated tokens only (len = completion tokens), if available.
    logprobs: list[float] | None = None
    # Top-k logprobs per token: list of {token_str: logprob} dicts for each generated token.
    top_k_logprobs: List[Dict[str, float]] | None = None


@dataclass
class DecodingConfig:
    max_new_tokens: int = 256
    temperature: float = 0.8
    top_p: float = 1.0
    top_k: Optional[int] = None
    seed: Optional[int] = None
    stop: Optional[Sequence[str]] = None


class LLMBackend(Protocol):
    name: str

    def load(self, model_id_or_path: str, *, max_seq_len: int | None = None, dtype: str | None = None) -> None:
        """Load model + tokenizer into memory."""

    def encode(self, text: str) -> list[int]:
        """Tokenize text -> token ids."""

    def decode(self, ids: Sequence[int]) -> str:
        """Token ids -> text."""

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: int | None = None,
        seed: int | None = None,
    ) -> Generation:
        """Sample a completion and return ids for prompt+completion."""

    def generate_with_logprobs(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: int | None = None,
        seed: int | None = None,
        logprobs: int = 0,  # Number of top logprobs to return per token (0 = just the sampled token)
    ) -> Generation:
        """Sample a completion and include per-token logprobs when available.

        Args:
            prompt: Input prompt text
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            seed: Random seed for reproducibility
            logprobs: Number of top logprobs to return per token (0 = return only sampled token's logprob)

        Returns:
            Generation with logprobs and optionally top_k_logprobs
        """

    def sft_loss(self, token_ids: Sequence[int], *, train_on_prompt: bool, prompt_len: int) -> Any:
        """Return a scalar loss suitable for backprop (MLX array)."""

    def rl_loss(self, token_ids: Sequence[int], *, prompt_len: int, advantage: float) -> Any:
        """Return policy-gradient-style loss for a sampled trajectory."""

    def sequence_logprob(self, token_ids: Sequence[int], *, prompt_len: int) -> Any:
        """Return log-probability sum of response tokens (differentiable)."""

    def token_logprobs(
        self,
        token_ids: Sequence[int],
        *,
        prompt_len: int,
        top_k: int = 0,
        include_prompt: bool = False,
    ) -> tuple[list[float], List[Dict[str, float]] | None]:
        """Return per-token logprobs (and optional top-k logprobs).

        Args:
            token_ids: Full token sequence (prompt + completion).
            prompt_len: Prompt length in tokens.
            top_k: Number of top logprobs per token to return (0 = none).
            include_prompt: If True, include prompt tokens; otherwise only response tokens.
        """

    def value_and_grad(self, loss_fn) -> tuple[Any, Any | None]:
        """Return (loss, grads) using backend autograd when available."""

    def optimizer_and_params(self, *, lr: float, weight_decay: float = 0.0) -> tuple[Any, Any]:
        """Return (optimizer, trainable_params_tree)."""

    def apply_grads(self, optimizer: Any, grads: Any) -> None:
        """Update model parameters given gradients."""

    def save_adapter(self, out_dir: str, *, metadata: dict | None = None) -> None:
        """Persist adapter weights (LoRA) to out_dir."""


class BackendNotAvailable(RuntimeError):
    pass
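
Because LLMBackend is a structural Protocol, any object with matching attributes satisfies it. The toy class below is an illustration of that shape only, covering load/encode/decode/generate; it is not the package's MockBackend or MlxLMBackend, whose implementations are not shown here.

# Illustrative only: a minimal object matching the core of the LLMBackend Protocol.
from __future__ import annotations

from typing import Sequence

from mlxsmith.llm.backend import Generation


class EchoBackend:
    name = "echo"

    def load(self, model_id_or_path: str, *, max_seq_len: int | None = None, dtype: str | None = None) -> None:
        self.model_id = model_id_or_path  # nothing to load for an echo model

    def encode(self, text: str) -> list[int]:
        return [ord(c) for c in text]  # character codes stand in for token ids

    def decode(self, ids: Sequence[int]) -> str:
        return "".join(chr(i) for i in ids)

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: int | None = None,
        seed: int | None = None,
    ) -> Generation:
        completion = prompt[-max_new_tokens:]  # "generate" by echoing the prompt tail
        full = prompt + completion
        return Generation(text=full, token_ids=self.encode(full), prompt_len=len(self.encode(prompt)))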
mlxsmith/llm/interface.py
ADDED
@@ -0,0 +1,212 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List, Optional, Dict

from .registry import get_llm_backend
from .backend import DecodingConfig, Generation
from ..config import ProjectConfig
from ..models import resolve_model_spec
from ..train.lora import load_adapter_config


@dataclass
class LoadedModel:
    backend: any
    base_model: str
    adapter_path: Optional[str]


def load_base_model(model_id_or_path: str, cfg: ProjectConfig) -> LoadedModel:
    """Load a base model (and optionally adapter) with the configured backend."""
    backend = get_llm_backend(cfg.model.backend)
    from pathlib import Path

    base_model, adapter_path, _meta = resolve_model_spec(Path.cwd(), model_id_or_path, cfg)
    backend.load(
        base_model,
        max_seq_len=cfg.model.max_seq_len,
        dtype=cfg.model.dtype,
        trust_remote_code=cfg.model.trust_remote_code,
    )
    return LoadedModel(backend=backend, base_model=base_model, adapter_path=str(adapter_path) if adapter_path else None)


def load_adapter(adapter_path: str):
    """Load adapter config without applying it."""
    return load_adapter_config(adapter_path)


def apply_adapter(backend, adapter_path: str) -> None:
    backend.apply_adapter(adapter_path)


def generate(
    backend,
    tokenizer,
    prompts: Iterable[str],
    decoding_config: DecodingConfig,
    logprobs: int = 0,
) -> List[Generation]:
    """Generate completions for prompts.

    Args:
        backend: The LLM backend
        tokenizer: Tokenizer instance
        prompts: Iterable of prompt strings
        decoding_config: Decoding configuration
        logprobs: Number of top logprobs to return per token (0 = none)

    Returns:
        List of Generation results
    """
    results = []
    for p in prompts:
        if logprobs > 0:
            results.append(
                backend.generate_with_logprobs(
                    p,
                    max_new_tokens=decoding_config.max_new_tokens,
                    temperature=decoding_config.temperature,
                    top_p=decoding_config.top_p,
                    top_k_sampling=decoding_config.top_k,
                    seed=decoding_config.seed,
                    logprobs=logprobs,
                )
            )
        else:
            results.append(
                backend.generate(
                    p,
                    max_new_tokens=decoding_config.max_new_tokens,
                    temperature=decoding_config.temperature,
                    top_p=decoding_config.top_p,
                    top_k=decoding_config.top_k,
                    seed=decoding_config.seed,
                )
            )
    return results


def chat(
    backend,
    messages: List[Dict[str, str]],
    decoding_config: DecodingConfig,
    logprobs: int = 0,
    use_chat_template: bool = True,
) -> Generation:
    """Generate a chat completion.

    Args:
        backend: The LLM backend
        messages: List of message dicts with 'role' and 'content' keys
        decoding_config: Decoding configuration
        logprobs: Number of top logprobs to return per token (0 = none)
        use_chat_template: Whether to use the model's chat template

    Returns:
        Generation result
    """
    # Convert messages to prompt
    prompt = _messages_to_prompt(backend.tokenizer, messages, use_chat_template=use_chat_template)

    if logprobs > 0:
        return backend.generate_with_logprobs(
            prompt,
            max_new_tokens=decoding_config.max_new_tokens,
            temperature=decoding_config.temperature,
            top_p=decoding_config.top_p,
            top_k_sampling=decoding_config.top_k,
            seed=decoding_config.seed,
            logprobs=logprobs,
        )
    else:
        return backend.generate(
            prompt,
            max_new_tokens=decoding_config.max_new_tokens,
            temperature=decoding_config.temperature,
            top_p=decoding_config.top_p,
            top_k=decoding_config.top_k,
            seed=decoding_config.seed,
        )


def _messages_to_prompt(
    tokenizer,
    messages: List[Dict[str, str]],
    *,
    use_chat_template: bool = True
) -> str:
    """Convert chat messages to prompt string."""
    if use_chat_template and hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Fallback
    return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"


@dataclass
class LogprobResult:
    """Result from logprobs computation."""
    token_logprobs: List[float]
    top_k_logprobs: Optional[List[Dict[str, float]]] = None
    text: Optional[str] = None


def compute_logprobs(
    backend,
    prompt: str,
    completion: str,
    top_k: int = 0,
    max_seq_len: Optional[int] = None,
) -> LogprobResult:
    """Compute logprobs for a prompt-completion pair.

    Args:
        backend: The LLM backend
        prompt: Prompt text
        completion: Completion text
        top_k: Number of top logprobs per token to return (0 = none)
        max_seq_len: Maximum sequence length

    Returns:
        LogprobResult with token logprobs and optionally top-k logprobs
    """
    prompt_ids = backend.encode(prompt)
    ids = backend.encode(prompt + completion)

    # Truncate if needed
    if max_seq_len and len(ids) > max_seq_len:
        overflow = len(ids) - max_seq_len
        ids = ids[overflow:]
        prompt_len = max(0, len(prompt_ids) - overflow)
    else:
        prompt_len = len(prompt_ids)

    # Get generation with logprobs
    full_text = backend.decode(ids)

    # Use backend's sequence_logprob if available
    seq_logprob = backend.sequence_logprob(ids, prompt_len=prompt_len)

    # For per-token logprobs, we'd need to do a forward pass
    # This is a simplified version
    token_logprobs = []
    if hasattr(backend, '_response_logprobs'):
        token_logprobs = backend._response_logprobs(ids, prompt_len=prompt_len)

    # Get top-k logprobs if requested
    top_k_logprobs = None
    if top_k > 0 and hasattr(backend, 'generate_with_logprobs'):
        # Generate to get top-k for each position
        gen = backend.generate_with_logprobs(
            prompt,
            max_new_tokens=len(ids) - prompt_len,
            logprobs=top_k,
        )
        top_k_logprobs = gen.top_k_logprobs

    return LogprobResult(
        token_logprobs=token_logprobs,
        top_k_logprobs=top_k_logprobs,
        text=completion,
    )
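
A hedged sketch of how these helpers might be combined: load a model, sample from a list of prompts, then score a fixed completion. It assumes cfg is a populated ProjectConfig, "base" is a placeholder model spec, and the backend exposes a tokenizer attribute (as infer.py and chat() above already rely on); the demo function itself is not part of the package.

# Sketch only: batch generation plus prompt/completion scoring via interface.py.
from mlxsmith.config import ProjectConfig
from mlxsmith.llm.backend import DecodingConfig
from mlxsmith.llm.interface import compute_logprobs, generate, load_base_model


def demo(cfg: ProjectConfig) -> None:
    loaded = load_base_model("base", cfg)  # adapter, if any, is not applied here
    decoding = DecodingConfig(max_new_tokens=32, temperature=0.0)

    # Batch sampling over a list of prompts.
    gens = generate(loaded.backend, loaded.backend.tokenizer, ["2 + 2 ="], decoding)
    print(gens[0].text)

    # Per-token logprobs for a given prompt/completion pair.
    scores = compute_logprobs(loaded.backend, "2 + 2 =", " 4", top_k=5)
    print(scores.token_logprobs, scores.top_k_logprobs is not None)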