speedy-utils 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,62 +1,51 @@
 # from ._utils import *
-import base64
-import hashlib
-import json
-import os
 from typing import (
     Any,
-    Dict,
     List,
     Literal,
     Optional,
-    Sequence,
     Type,
-    Union,
     cast,
-    overload,
 )

-from httpx import URL
 from loguru import logger
-from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
-from openai.pagination import AsyncPage as AsyncSyncPage
-
-# from openai.pagination import AsyncSyncPage
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionMessageParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionToolMessageParam,
-    ChatCompletionUserMessageParam,
-)
-from openai.types.model import Model
+from openai import AuthenticationError, BadRequestError, RateLimitError
 from pydantic import BaseModel
-
 from speedy_utils import jloads

+# from llm_utils.lm.async_lm.async_llm_task import OutputModelType
+from llm_utils.lm.async_lm.async_lm_base import AsyncLMBase
+
 from ._utils import (
     LegacyMsgs,
     Messages,
+    OutputModelType,
     ParsedOutput,
     RawMsgs,
-    TModel,
-    TParsed,
-    _blue,
-    _green,
-    _red,
-    _yellow,
-    get_tokenizer,
-    inspect_word_probs_async,
 )


-class AsyncLM:
+def jloads_safe(content: str) -> Any:
+    # if contain ```json, remove it
+    if "```json" in content:
+        content = content.split("```json")[1].strip().split("```")[0].strip()
+    try:
+        return jloads(content)
+    except Exception as e:
+        logger.error(
+            f"Failed to parse JSON content: {content[:100]}... with error: {e}"
+        )
+        raise ValueError(f"Invalid JSON content: {content}") from e
+
+
+class AsyncLM(AsyncLMBase):
     """Unified **async** language‑model wrapper with optional JSON parsing."""

     def __init__(
         self,
-        model: str | None = None,
+        model: str,
         *,
+        response_model: Optional[type[BaseModel]] = None,
         temperature: float = 0.0,
         max_tokens: int = 2_000,
         host: str = "localhost",
@@ -64,167 +53,101 @@ class AsyncLM:
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,
         cache: bool = True,
+        think: Literal[True, False, None] = None,
+        add_json_schema_to_instruction: Optional[bool] = None,
+        use_beta: bool = False,
         ports: Optional[List[int]] = None,
-        **openai_kwargs: Any,
+        top_p: float = 1.0,
+        presence_penalty: float = 0.0,
+        top_k: int = 1,
+        repetition_penalty: float = 1.0,
+        frequency_penalty: Optional[float] = None,
     ) -> None:
-        self.model = model
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.port = port
-        self.host = host
-        self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
-        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
-        self.openai_kwargs = openai_kwargs
-        self.do_cache = cache
-        self.ports = ports
-        self._init_port = port  # <-- store the port provided at init
-
-        # Async client
-
-    @property
-    def client(self) -> AsyncOpenAI:
-        # if have multiple ports
-        if self.ports:
-            import random
-
-            port = random.choice(self.ports)
-            api_base = f"http://{self.host}:{port}/v1"
-            logger.debug(f"Using port: {port}")
-        else:
-            api_base = self.base_url or f"http://{self.host}:{self.port}/v1"
-        client = AsyncOpenAI(
-            api_key=self.api_key, base_url=api_base, **self.openai_kwargs
+        super().__init__(
+            host=host,
+            port=port,
+            ports=ports,
+            base_url=base_url,
+            cache=cache,
+            api_key=api_key,
         )
-        return client

-    # ------------------------------------------------------------------ #
-    # Public API – typed overloads
-    # ------------------------------------------------------------------ #
-    @overload
-    async def __call__(
-        self,
-        *,
-        prompt: str | None = ...,
-        messages: RawMsgs | None = ...,
-        response_format: type[str] = str,
-        return_openai_response: bool = ...,
-        **kwargs: Any,
-    ) -> str: ...
-
-    @overload
-    async def __call__(
-        self,
-        *,
-        prompt: str | None = ...,
-        messages: RawMsgs | None = ...,
-        response_format: Type[TModel],
-        return_openai_response: bool = ...,
-        **kwargs: Any,
-    ) -> TModel: ...
-
-    async def _set_model(self) -> None:
-        if not self.model:
-            models = await self.list_models(port=self.port, host=self.host)
-            self.model = models[0] if models else None
-            logger.info(
-                f"No model specified. Using the first available model. {self.model}"
-            )
-
-    async def __call__(
-        self,
-        prompt: Optional[str] = None,
-        messages: Optional[RawMsgs] = None,
-        response_format: Union[type[str], Type[BaseModel]] = str,
-        cache: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        return_openai_response: bool = False,
-        **kwargs: Any,
-    ):
-        if (prompt is None) == (messages is None):
-            raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
-
-        if prompt is not None:
-            messages = [{"role": "user", "content": prompt}]
-
-        assert messages is not None
-        # assert self.model is not None, "Model must be set before calling."
-        await self._set_model()
-
-        openai_msgs: Messages = (
-            self._convert_messages(cast(LegacyMsgs, messages))
-            if isinstance(messages[0], dict)
-            else cast(Messages, messages)
-        )
-
-        kw = dict(
-            self.openai_kwargs,
-            temperature=self.temperature,
-            max_tokens=max_tokens or self.max_tokens,
+        # Model behavior options
+        self.response_model = response_model
+        self.think = think
+        self._use_beta = use_beta
+        self.add_json_schema_to_instruction = add_json_schema_to_instruction
+        if not use_beta:
+            self.add_json_schema_to_instruction = True
+
+        # Store all model-related parameters in model_kwargs
+        self.model_kwargs = dict(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            presence_penalty=presence_penalty,
         )
-        kw.update(kwargs)
-        use_cache = self.do_cache if cache is None else cache
-
-        raw_response = await self._call_raw(
-            openai_msgs,
-            response_format=response_format,
-            use_cache=use_cache,
-            **kw,
+        self.extra_body = dict(
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
         )

-        if return_openai_response:
-            response = raw_response
-        else:
-            response = self._parse_output(raw_response, response_format)
-
-        self._last_log = [prompt, messages, raw_response]
-        return response
-
-    # ------------------------------------------------------------------ #
-    # Model invocation (async)
-    # ------------------------------------------------------------------ #
-    async def _call_raw(
+    async def _unified_client_call(
         self,
-        messages: Sequence[ChatCompletionMessageParam],
-        response_format: Union[type[str], Type[BaseModel]],
-        use_cache: bool,
-        **kw: Any,
-    ):
-        assert self.model is not None, "Model must be set before making a call."
-        model: str = self.model
-
-        cache_key = (
-            self._cache_key(messages, kw, response_format) if use_cache else None
-        )
-        if cache_key and (hit := self._load_cache(cache_key)) is not None:
-            # Check if cached value is an error
-            if isinstance(hit, dict) and hit.get("error"):
-                error_type = hit.get("error_type", "Unknown")
-                error_msg = hit.get("error_message", "Cached error")
-                logger.warning(f"Found cached error ({error_type}): {error_msg}")
-                # Re-raise as a ValueError with meaningful message
-                raise ValueError(f"Cached {error_type}: {error_msg}")
-            return hit
+        messages: list[dict],
+        extra_body: Optional[dict] = None,
+        cache_suffix: str = "",
+    ) -> dict:
+        """Unified method for all client interactions with caching and error handling."""
+        converted_messages = self._convert_messages(messages)
+        cache_key = None
+        completion = None
+
+        # Handle caching
+        if self._cache:
+            cache_data = {
+                "messages": converted_messages,
+                "model_kwargs": self.model_kwargs,
+                "extra_body": extra_body or {},
+                "cache_suffix": cache_suffix,
+            }
+            cache_key = self._cache_key(cache_data, {}, str)
+            completion = self._load_cache(cache_key)
+
+        # Check for cached error responses
+        if (
+            completion
+            and isinstance(completion, dict)
+            and "error" in completion
+            and completion["error"]
+        ):
+            error_type = completion.get("error_type", "Unknown")
+            error_message = completion.get("error_message", "Cached error")
+            logger.warning(f"Found cached error ({error_type}): {error_message}")
+            raise ValueError(f"Cached {error_type}: {error_message}")

         try:
-            if response_format is not str and issubclass(response_format, BaseModel):
-                openai_response = await self.client.beta.chat.completions.parse(
-                    model=model,
-                    messages=list(messages),
-                    response_format=response_format,  # type: ignore[arg-type]
-                    **kw,
-                )
-            else:
-                openai_response = await self.client.chat.completions.create(
-                    model=model,
-                    messages=list(messages),
-                    **kw,
-                )
+            # Get completion from API if not cached
+            if not completion:
+                call_kwargs = {
+                    "messages": converted_messages,
+                    **self.model_kwargs,
+                }
+                if extra_body:
+                    call_kwargs["extra_body"] = extra_body
+
+                completion = await self.client.chat.completions.create(**call_kwargs)
+
+            if hasattr(completion, "model_dump"):
+                completion = completion.model_dump()
+            if cache_key:
+                self._dump_cache(cache_key, completion)

         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
             logger.error(error_msg)
-
-            # Cache the error if it's a BadRequestError to avoid repeated calls
             if isinstance(exc, BadRequestError) and cache_key:
                 error_response = {
                     "error": True,
@@ -234,153 +157,180 @@ class AsyncLM:
                 }
                 self._dump_cache(cache_key, error_response)
                 logger.debug(f"Cached BadRequestError for key: {cache_key}")
-
             raise

-        if cache_key:
-            self._dump_cache(cache_key, openai_response)
-
-        return openai_response
-
-    # ------------------------------------------------------------------ #
-    # Utilities below are unchanged (sync I/O is acceptable)
-    # ------------------------------------------------------------------ #
-    @staticmethod
-    def _convert_messages(msgs: LegacyMsgs) -> Messages:
-        converted: Messages = []
-        for msg in msgs:
-            role = msg["role"]
-            content = msg["content"]
-            if role == "user":
-                converted.append(
-                    ChatCompletionUserMessageParam(role="user", content=content)
-                )
-            elif role == "assistant":
-                converted.append(
-                    ChatCompletionAssistantMessageParam(
-                        role="assistant", content=content
-                    )
-                )
-            elif role == "system":
-                converted.append(
-                    ChatCompletionSystemMessageParam(role="system", content=content)
+        return completion
+
+    async def _call_and_parse(
+        self,
+        messages: list[dict],
+        response_model: Type[OutputModelType],
+        json_schema: dict,
+    ) -> tuple[dict, list[dict], OutputModelType]:
+        """Unified call and parse with cache and error handling."""
+        if self._use_beta:
+            return await self._call_and_parse_with_beta(
+                messages, response_model, json_schema
+            )
+
+        choice = None
+        try:
+            # Use unified client call
+            completion = await self._unified_client_call(
+                messages,
+                extra_body={**self.extra_body},
+                cache_suffix=f"_parse_{response_model.__name__}",
+            )
+
+            # Parse the response
+            choice = completion["choices"][0]["message"]
+            if "content" not in choice:
+                raise ValueError("Response choice must contain 'content' field.")
+
+            content = choice["content"]
+            if not content:
+                raise ValueError("Response content is empty")
+
+            parsed = response_model.model_validate(jloads_safe(content))
+
+        except Exception as e:
+            # Try fallback to beta mode if regular parsing fails
+            if not isinstance(
+                e, (AuthenticationError, RateLimitError, BadRequestError)
+            ):
+                content = choice.get("content", "N/A") if choice else "N/A"
+                logger.info(
+                    f"Regular parsing failed due to wrong format or content, now falling back to beta mode: {content=}, {e=}"
                 )
-            elif role == "tool":
-                converted.append(
-                    ChatCompletionToolMessageParam(
-                        role="tool",
-                        content=content,
-                        tool_call_id=msg.get("tool_call_id") or "",
+                try:
+                    return await self._call_and_parse_with_beta(
+                        messages, response_model, json_schema
                     )
-                )
-            else:
-                converted.append({"role": role, "content": content})  # type: ignore[arg-type]
-        return converted
-
-    @staticmethod
-    def _parse_output(
-        raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
-    ) -> str | BaseModel:
-        if hasattr(raw_response, "model_dump"):
-            raw_response = raw_response.model_dump()
-
-        if response_format is str:
-            if isinstance(raw_response, dict) and "choices" in raw_response:
-                message = raw_response["choices"][0]["message"]
-                return message.get("content", "") or ""
-            return cast(str, raw_response)
-
-        model_cls = cast(Type[BaseModel], response_format)
-
-        if isinstance(raw_response, dict) and "choices" in raw_response:
-            message = raw_response["choices"][0]["message"]
-            if "parsed" in message:
-                return model_cls.model_validate(message["parsed"])
-            content = message.get("content")
-            if content is None:
-                raise ValueError("Model returned empty content")
-            try:
-                data = json.loads(content)
-                return model_cls.model_validate(data)
-            except Exception as exc:
-                raise ValueError(
-                    f"Failed to parse model output as JSON:\n{content}"
-                ) from exc
-
-        if isinstance(raw_response, model_cls):
-            return raw_response
-        if isinstance(raw_response, dict):
-            return model_cls.model_validate(raw_response)
+            except Exception as beta_e:
+                logger.warning(f"Beta mode fallback also failed: {beta_e}")
+                choice_info = choice if choice is not None else "N/A"
+                raise ValueError(
+                    f"Failed to parse model response with both regular and beta modes. "
+                    f"Regular error: {e}. Beta error: {beta_e}. "
+                    f"Model response message: {choice_info}"
+                ) from e
+            raise
+
+        assistant_msg = self._extract_assistant_message(choice)
+        full_messages = messages + [assistant_msg]

+        return completion, full_messages, cast(OutputModelType, parsed)
+
+    async def _call_and_parse_with_beta(
+        self,
+        messages: list[dict],
+        response_model: Type[OutputModelType],
+        json_schema: dict,
+    ) -> tuple[dict, list[dict], OutputModelType]:
+        """Call and parse for beta mode with guided JSON."""
+        choice = None
         try:
-            data = json.loads(raw_response)
-            return model_cls.model_validate(data)
-        except Exception as exc:
+            # Use unified client call with guided JSON
+            completion = await self._unified_client_call(
+                messages,
+                extra_body={"guided_json": json_schema, **self.extra_body},
+                cache_suffix=f"_beta_parse_{response_model.__name__}",
+            )
+
+            # Parse the response
+            choice = completion["choices"][0]["message"]
+            parsed = self._parse_complete_output(completion, response_model)
+
+        except Exception as e:
+            choice_info = choice if choice is not None else "N/A"
             raise ValueError(
-                f"Model did not return valid JSON:\n---\n{raw_response}"
-            ) from exc
+                f"Failed to parse model response: {e}\nModel response message: {choice_info}"
+            ) from e

-    # ------------------------------------------------------------------ #
-    # Simple disk cache (sync)
-    # ------------------------------------------------------------------ #
-    @staticmethod
-    def _cache_key(
-        messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
-    ) -> str:
-        tag = response_format.__name__ if response_format is not str else "text"
-        blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
-        return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
-
-    @staticmethod
-    def _cache_path(key: str) -> str:
-        return os.path.expanduser(f"~/.cache/lm/{key}.json")
-
-    def _dump_cache(self, key: str, val: Any) -> None:
-        try:
-            path = self._cache_path(key)
-            os.makedirs(os.path.dirname(path), exist_ok=True)
-            with open(path, "w") as fh:
-                if isinstance(val, BaseModel):
-                    json.dump(val.model_dump(mode="json"), fh)
-                else:
-                    json.dump(val, fh)
-        except Exception as exc:
-            logger.debug(f"cache write skipped: {exc}")
+        assistant_msg = self._extract_assistant_message(choice)
+        full_messages = messages + [assistant_msg]
+
+        return completion, full_messages, cast(OutputModelType, parsed)
+
+    def _extract_assistant_message(self, choice):  # -> dict[str, str] | dict[str, Any]:
+        # TODO this current assume choice is a dict with "reasoning_content" and "content"
+        has_reasoning = False
+        if "reasoning_content" in choice and isinstance(
+            choice["reasoning_content"], str
+        ):
+            reasoning_content = choice["reasoning_content"].strip()
+            has_reasoning = True
+
+        content = choice["content"]
+        _content = content.lstrip("\n")
+        if has_reasoning:
+            assistant_msg = {
+                "role": "assistant",
+                "content": f"<think>\n{reasoning_content}\n</think>\n\n{_content}",
+            }
+        else:
+            assistant_msg = {"role": "assistant", "content": _content}
+
+        return assistant_msg
+
+    async def __call__(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+    ):  # -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...:
+        """Unified async call for language model, returns (assistant_message.model_dump(), messages)."""
+        if (prompt is None) == (messages is None):
+            raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
+
+        if prompt is not None:
+            messages = [{"role": "user", "content": prompt}]
+
+        assert messages is not None
+
+        openai_msgs: Messages = (
+            self._convert_messages(cast(LegacyMsgs, messages))
+            if isinstance(messages[0], dict)
+            else cast(Messages, messages)
+        )
+
+        assert self.model_kwargs["model"] is not None, (
+            "Model must be set before making a call."
+        )
+
+        # Use unified client call
+        raw_response = await self._unified_client_call(
+            list(openai_msgs), cache_suffix="_call"
+        )
+
+        if hasattr(raw_response, "model_dump"):
+            raw_response = raw_response.model_dump()  # type: ignore
+
+        # Extract the assistant's message
+        assistant_msg = raw_response["choices"][0]["message"]
+        # Build the full messages list (input + assistant reply)
+        full_messages = list(messages) + [
+            {"role": assistant_msg["role"], "content": assistant_msg["content"]}
+        ]
+        # Return the OpenAI message as model_dump (if available) and the messages list
+        if hasattr(assistant_msg, "model_dump"):
+            msg_dump = assistant_msg.model_dump()
+        else:
+            msg_dump = dict(assistant_msg)
+        return msg_dump, full_messages

-    def _load_cache(self, key: str) -> Any | None:
-        path = self._cache_path(key)
-        if not os.path.exists(path):
-            return None
-        try:
-            with open(path) as fh:
-                return json.load(fh)
-        except Exception:
-            return None
-
-    # ------------------------------------------------------------------ #
-    # Missing methods from LM class
-    # ------------------------------------------------------------------ #
     async def parse(
         self,
-        response_model: Type[TParsed],
         instruction,
         prompt,
-        think: Literal[True, False, None] = None,
-        add_json_schema_to_instruction: bool = False,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        cache: Optional[bool] = None,
-        use_beta: bool = False,
-        **kwargs,
-    ) -> ParsedOutput[TParsed]:
-        """Parse response using guided JSON generation."""
-
-        if not use_beta:
-            assert add_json_schema_to_instruction, (
+    ) -> ParsedOutput[BaseModel]:
+        """Parse response using guided JSON generation. Returns (parsed.model_dump(), messages)."""
+        if not self._use_beta:
+            assert self.add_json_schema_to_instruction, (
                 "add_json_schema_to_instruction must be True when use_beta is False. otherwise model will not be able to parse the response."
             )

-        json_schema = response_model.model_json_schema()
+        assert self.response_model is not None, "response_model must be set at init."
+        json_schema = self.response_model.model_json_schema()

         # Build system message content in a single, clear block
         assert instruction is not None, "Instruction must be provided."
@@ -388,122 +338,32 @@ class AsyncLM:
         system_content = instruction

         # Add schema if needed
-        system_content = self._build_system_prompt(
-            response_model,
-            add_json_schema_to_instruction,
+        system_content = self.build_system_prompt(
+            self.response_model,
+            self.add_json_schema_to_instruction,
             json_schema,
             system_content,
-            think=think,
+            think=self.think,
         )

-        # Rebuild messages with updated system message if needed
         messages = [
             {"role": "system", "content": system_content},
             {"role": "user", "content": prompt},
         ]  # type: ignore

-        model_kwargs = {}
-        if temperature is not None:
-            model_kwargs["temperature"] = temperature
-        if max_tokens is not None:
-            model_kwargs["max_tokens"] = max_tokens
-        model_kwargs.update(kwargs)
-
-        use_cache = self.do_cache if cache is None else cache
-        cache_key = None
-        completion = None
-        choice = None
-        parsed = None
-
-        if use_cache:
-            cache_data = {
-                "messages": messages,
-                "model_kwargs": model_kwargs,
-                "guided_json": json_schema,
-                "response_format": response_model.__name__,
-                "use_beta": use_beta,
-            }
-            cache_key = self._cache_key(cache_data, {}, response_model)
-            completion = self._load_cache(cache_key)  # dict
-
-        if not completion:
-            completion, choice, parsed = await self._call_and_parse_completion(
-                messages,
-                response_model,
-                json_schema,
-                use_beta=use_beta,
-                model_kwargs=model_kwargs,
-            )
-
-            if cache_key:
-                self._dump_cache(cache_key, completion)
-        else:
-            # Extract choice and parsed from cached completion
-            choice = completion["choices"][0]["message"]
-            try:
-                parsed = self._parse_complete_output(completion, response_model)
-            except Exception as e:
-                raise ValueError(
-                    f"Failed to parse cached completion: {e}\nRaw: {choice.get('content')}"
-                ) from e
-
-        assert isinstance(completion, dict), (
-            "Completion must be a dictionary with OpenAI response format."
+        completion, full_messages, parsed = await self._call_and_parse(
+            messages,
+            self.response_model,
+            json_schema,
         )
-        self._last_log = [prompt, messages, completion]
-
-        reasoning_content = choice.get("reasoning_content", "").strip()
-        _content = choice.get("content", "").lstrip("\n")
-        content = f"<think>\n{reasoning_content}\n</think>\n\n{_content}"
-
-        full_messages = messages + [{"role": "assistant", "content": content}]

         return ParsedOutput(
             messages=full_messages,
+            parsed=cast(BaseModel, parsed),
             completion=completion,
-            parsed=parsed,  # type: ignore
+            model_kwargs=self.model_kwargs,
         )

-    def _build_system_prompt(
-        self,
-        response_model,
-        add_json_schema_to_instruction,
-        json_schema,
-        system_content,
-        think,
-    ):
-        if add_json_schema_to_instruction and response_model:
-            schema_block = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
-            # if schema_block not in system_content:
-            if "<output_json_schema>" in system_content:
-                # remove exsting schema block
-                import re  # replace
-
-                system_content = re.sub(
-                    r"<output_json_schema>.*?</output_json_schema>",
-                    "",
-                    system_content,
-                    flags=re.DOTALL,
-                )
-                system_content = system_content.strip()
-            system_content += schema_block
-
-        if think is True:
-            if "/think" in system_content:
-                pass
-            elif "/no_think" in system_content:
-                system_content = system_content.replace("/no_think", "/think")
-            else:
-                system_content += "\n\n/think"
-        elif think is False:
-            if "/no_think" in system_content:
-                pass
-            elif "/think" in system_content:
-                system_content = system_content.replace("/think", "/no_think")
-            else:
-                system_content += "\n\n/no_think"
-        return system_content
-
     def _parse_complete_output(
         self, completion: Any, response_model: Type[BaseModel]
     ) -> BaseModel:
@@ -516,264 +376,12 @@ class AsyncLM:

         content = completion["choices"][0]["message"]["content"]
         if not content:
-            # Enhanced error for debugging: show input tokens and their count
-
-            # Try to extract tokens from the completion for debugging
-            input_tokens = None
-            try:
-                input_tokens = completion.get("usage", {}).get("prompt_tokens")
-            except Exception:
-                input_tokens = None
-
-            # Try to get the prompt/messages for tokenization
-            prompt = None
-            try:
-                prompt = completion.get("messages") or completion.get("prompt")
-            except Exception:
-                prompt = None
-
-            tokens_preview = ""
-            if prompt is not None:
-                try:
-                    tokenizer = get_tokenizer(self.model)
-                    if isinstance(prompt, list):
-                        prompt_text = "\n".join(
-                            m.get("content", "") for m in prompt if isinstance(m, dict)
-                        )
-                    else:
-                        prompt_text = str(prompt)
-                    tokens = tokenizer.encode(prompt_text)
-                    n_tokens = len(tokens)
-                    first_100 = tokens[:100]
-                    last_100 = tokens[-100:] if n_tokens > 100 else []
-                    tokens_preview = (
-                        f"\nInput tokens: {n_tokens}"
-                        f"\nFirst 100 tokens: {first_100}"
-                        f"\nLast 100 tokens: {last_100}"
-                    )
-                except Exception as exc:
-                    tokens_preview = f"\n[Tokenization failed: {exc}]"
-
-            raise ValueError(
-                f"Empty content in response."
-                f"\nInput tokens (if available): {input_tokens}"
-                f"{tokens_preview}"
-            )
+            raise ValueError("Response content is empty")

         try:
-            data = json.loads(content)
+            data = jloads(content)
             return response_model.model_validate(data)
         except Exception as exc:
             raise ValueError(
-                f"Failed to parse response as {response_model.__name__}: {content}"
+                f"Failed to validate against response model {response_model.__name__}: {exc}\nRaw content: {content}"
             ) from exc
-
-    async def inspect_word_probs(
-        self,
-        messages: Optional[List[Dict[str, Any]]] = None,
-        tokenizer: Optional[Any] = None,
-        do_print=True,
-        add_think: bool = True,
-    ) -> tuple[List[Dict[str, Any]], Any, str]:
-        """
-        Inspect word probabilities in a language model response.
-
-        Args:
-            tokenizer: Tokenizer instance to encode words.
-            messages: List of messages to analyze.
-
-        Returns:
-            A tuple containing:
-            - List of word probabilities with their log probabilities.
-            - Token log probability dictionaries.
-            - Rendered string with colored word probabilities.
-        """
-        if messages is None:
-            messages = await self.last_messages(add_think=add_think)
-            if messages is None:
-                raise ValueError("No messages provided and no last messages available.")
-
-        if tokenizer is None:
-            tokenizer = get_tokenizer(self.model)
-
-        ret = await inspect_word_probs_async(self, tokenizer, messages)
-        if do_print:
-            print(ret[-1])
-        return ret
-
-    async def last_messages(
-        self, add_think: bool = True
-    ) -> Optional[List[Dict[str, str]]]:
-        """Get the last conversation messages including assistant response."""
-        if not hasattr(self, "last_log"):
-            return None
-
-        last_conv = self._last_log
-        messages = last_conv[1] if len(last_conv) > 1 else None
-        last_msg = last_conv[2]
-        if not isinstance(last_msg, dict):
-            last_conv[2] = last_conv[2].model_dump()  # type: ignore
-        msg = last_conv[2]
-        # Ensure msg is a dict
-        if hasattr(msg, "model_dump"):
-            msg = msg.model_dump()
-        message = msg["choices"][0]["message"]
-        reasoning = message.get("reasoning_content")
-        answer = message.get("content")
-        if reasoning and add_think:
-            final_answer = f"<think>{reasoning}</think>\n{answer}"
-        else:
-            final_answer = f"<think>\n\n</think>\n{answer}"
-        assistant = {"role": "assistant", "content": final_answer}
-        messages = messages + [assistant]  # type: ignore
-        return messages if messages else None
-
-    # ------------------------------------------------------------------ #
-    # Utility helpers
-    # ------------------------------------------------------------------ #
-    async def inspect_history(self) -> None:
-        """Inspect the conversation history with proper formatting."""
-        if not hasattr(self, "last_log"):
-            raise ValueError("No history available. Please call the model first.")
-
-        prompt, messages, response = self._last_log
-        if hasattr(response, "model_dump"):
-            response = response.model_dump()
-        if not messages:
-            messages = [{"role": "user", "content": prompt}]
-
-        print("\n\n")
-        print(_blue("[Conversation History]") + "\n")
-
-        for msg in messages:
-            role = msg["role"]
-            content = msg["content"]
-            print(_red(f"{role.capitalize()}:"))
-            if isinstance(content, str):
-                print(content.strip())
-            elif isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "text":
-                        print(item["text"].strip())
-                    elif item.get("type") == "image_url":
-                        image_url = item["image_url"]["url"]
-                        if "base64" in image_url:
-                            len_base64 = len(image_url.split("base64,")[1])
-                            print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
-                        else:
-                            print(_blue(f"<image_url: {image_url}>"))
-            print("\n")
-
-        print(_red("Response:"))
-        if isinstance(response, dict) and response.get("choices"):
-            message = response["choices"][0].get("message", {})
-            reasoning = message.get("reasoning_content")
-            parsed = message.get("parsed")
-            content = message.get("content")
-            if reasoning:
-                print(_yellow("<think>"))
-                print(reasoning.strip())
-                print(_yellow("</think>\n"))
-            if parsed:
-                print(
-                    json.dumps(
-                        (
-                            parsed.model_dump()
-                            if hasattr(parsed, "model_dump")
-                            else parsed
-                        ),
-                        indent=2,
-                    )
-                    + "\n"
-                )
-            elif content:
-                print(content.strip())
-            else:
-                print(_green("[No content]"))
-            if len(response["choices"]) > 1:
-                print(
-                    _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
-                )
-        else:
-            print(_yellow("Warning: Not a standard OpenAI response object"))
-            if isinstance(response, str):
-                print(_green(response.strip()))
-            elif isinstance(response, dict):
-                print(_green(json.dumps(response, indent=2)))
-            else:
-                print(_green(str(response)))
-
-    # ------------------------------------------------------------------ #
-    # Misc helpers
-    # ------------------------------------------------------------------ #
-    def set_model(self, model: str) -> None:
-        self.model = model
-
-    @staticmethod
-    async def list_models(port=None, host="localhost") -> List[str]:
-        try:
-            client: AsyncOpenAI = AsyncLM(port=port, host=host).client  # type: ignore[arg-type]
-            base_url: URL = client.base_url
-            logger.debug(f"Base URL: {base_url}")
-            models: AsyncSyncPage[Model] = await client.models.list()  # type: ignore[assignment]
-            return [model.id for model in models.data]
-        except Exception as exc:
-            logger.error(f"Failed to list models: {exc}")
-            return []
-
-    async def _call_and_parse_completion(
-        self,
-        messages: list[dict],
-        response_model: Type[TParsed],
-        json_schema: dict,
-        use_beta: bool,
-        model_kwargs: dict,
-    ) -> tuple[dict, dict, TParsed]:
-        """Call vLLM or OpenAI-compatible endpoint and parse JSON response consistently."""
-        await self._set_model()  # Ensure model is set before making the call
-        # Convert messages to proper type
-        converted_messages = self._convert_messages(messages)  # type: ignore
-
-        if use_beta:
-            # Use guided JSON for structure enforcement
-            try:
-                completion = await self.client.chat.completions.create(
-                    model=str(self.model),  # type: ignore
-                    messages=converted_messages,
-                    extra_body={"guided_json": json_schema},  # type: ignore
-                    **model_kwargs,
-                )  # type: ignore
-            except Exception:
-                # Fallback if extra_body is not supported
-                completion = await self.client.chat.completions.create(
-                    model=str(self.model),  # type: ignore
-                    messages=converted_messages,
-                    response_format={"type": "json_object"},
-                    **model_kwargs,
-                )
-        else:
-            # Use OpenAI-style structured output
-            completion = await self.client.chat.completions.create(
-                model=str(self.model),  # type: ignore
-                messages=converted_messages,
-                response_format={"type": "json_object"},
-                **model_kwargs,
-            )
-
-        if hasattr(completion, "model_dump"):
-            completion = completion.model_dump()
-
-        choice = completion["choices"][0]["message"]
-
-        try:
-            parsed = (
-                self._parse_complete_output(completion, response_model)
-                if use_beta
-                else response_model.model_validate(jloads(choice.get("content")))
-            )
-        except Exception as e:
-            raise ValueError(
-                f"Failed to parse model response: {e}\nRaw: {choice.get('content')}"
-            ) from e
-
-        return completion, choice, parsed  # type: ignore
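
The new version moves most generation parameters into the constructor (the old per-call `response_format`, `temperature`, and `use_beta` arguments are gone), so callers configure `AsyncLM` once and then use `__call__` for plain chat or `parse` for schema-validated output. The sketch below is a hypothetical usage example inferred only from the signatures visible in this diff; the module path, model name, and endpoint URL are illustrative assumptions, not part of the package's documented API.

import asyncio
from pydantic import BaseModel
from llm_utils.lm.async_lm.async_lm import AsyncLM  # module path assumed from the imports above


class Answer(BaseModel):
    city: str
    country: str


async def main() -> None:
    lm = AsyncLM(
        "Qwen/Qwen2.5-7B-Instruct",           # assumed model name
        response_model=Answer,                 # required for .parse() in 1.1.8
        base_url="http://localhost:8000/v1",  # assumed OpenAI-compatible endpoint
        api_key="abc",
        temperature=0.0,
    )

    # Plain chat call: returns (assistant message dict, full message list).
    msg, history = await lm(prompt="Name the capital of France.")
    print(msg["content"])

    # Structured call: returns a ParsedOutput whose `parsed` field is an Answer instance.
    result = await lm.parse(
        instruction="Answer as JSON matching the schema.",
        prompt="Name the capital of France.",
    )
    print(result.parsed)


asyncio.run(main())

Note that helpers such as `client`, `_cache`, `_cache_key`, `_load_cache`, `_dump_cache`, `_convert_messages`, and `build_system_prompt` are not defined in this file; per the new import they are expected to come from `AsyncLMBase` in `llm_utils.lm.async_lm.async_lm_base`, which is outside the scope of this diff.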