speedy-utils 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,387 @@
+ # from ._utils import *
+ from typing import (
+     Any,
+     List,
+     Literal,
+     Optional,
+     Type,
+     cast,
+ )
+
+ from loguru import logger
+ from openai import AuthenticationError, BadRequestError, RateLimitError
+ from pydantic import BaseModel
+ from speedy_utils import jloads
+
+ # from llm_utils.lm.async_lm.async_llm_task import OutputModelType
+ from llm_utils.lm.async_lm.async_lm_base import AsyncLMBase
+
+ from ._utils import (
+     LegacyMsgs,
+     Messages,
+     OutputModelType,
+     ParsedOutput,
+     RawMsgs,
+ )
+
+
+ def jloads_safe(content: str) -> Any:
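+     """Best-effort JSON parse: strip a surrounding ```json fence, then delegate to jloads."""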
+     # If the content wraps the JSON in a ```json fence, keep only the fenced payload.
+     if "```json" in content:
+         content = content.split("```json")[1].strip().split("```")[0].strip()
+     try:
+         return jloads(content)
+     except Exception as e:
+         logger.error(
+             f"Failed to parse JSON content: {content[:100]}... with error: {e}"
+         )
+         raise ValueError(f"Invalid JSON content: {content}") from e
+
+
+ class AsyncLM(AsyncLMBase):
+     """Unified **async** language-model wrapper with optional JSON parsing."""
+
+     def __init__(
+         self,
+         model: str,
+         *,
+         response_model: Optional[type[BaseModel]] = None,
+         temperature: float = 0.0,
+         max_tokens: int = 2_000,
+         host: str = "localhost",
+         port: Optional[int | str] = None,
+         base_url: Optional[str] = None,
+         api_key: Optional[str] = None,
+         cache: bool = True,
+         think: Literal[True, False, None] = None,
+         add_json_schema_to_instruction: Optional[bool] = None,
+         use_beta: bool = False,
+         ports: Optional[List[int]] = None,
+         top_p: float = 1.0,
+         presence_penalty: float = 0.0,
+         top_k: int = 1,
+         repetition_penalty: float = 1.0,
+         frequency_penalty: Optional[float] = None,
+     ) -> None:
+         super().__init__(
+             host=host,
+             port=port,
+             ports=ports,
+             base_url=base_url,
+             cache=cache,
+             api_key=api_key,
+         )
+
+         # Model behavior options
+         self.response_model = response_model
+         self.think = think
+         self._use_beta = use_beta
+         self.add_json_schema_to_instruction = add_json_schema_to_instruction
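+         # Without the beta guided-JSON path, the schema must be embedded in the
+         # instruction so the model knows the expected output shape.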
+         if not use_beta:
+             self.add_json_schema_to_instruction = True
+
+         # Store all model-related parameters in model_kwargs
+         self.model_kwargs = dict(
+             model=model,
+             temperature=temperature,
+             max_tokens=max_tokens,
+             top_p=top_p,
+             presence_penalty=presence_penalty,
+         )
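+         # Extra sampling knobs are passed through via extra_body rather than
+         # as top-level completion parameters.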
+         self.extra_body = dict(
+             top_k=top_k,
+             repetition_penalty=repetition_penalty,
+             frequency_penalty=frequency_penalty,
+         )
+
+     async def _unified_client_call(
+         self,
+         messages: list[dict],
+         extra_body: Optional[dict] = None,
+         cache_suffix: str = "",
+     ) -> dict:
+         """Unified method for all client interactions with caching and error handling."""
+         converted_messages = self._convert_messages(messages)
+         cache_key = None
+         completion = None
+
+         # Handle caching
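+         # cache_suffix keeps plain calls and schema-guided parse calls from
+         # colliding in the cache for otherwise identical messages.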
+         if self._cache:
+             cache_data = {
+                 "messages": converted_messages,
+                 "model_kwargs": self.model_kwargs,
+                 "extra_body": extra_body or {},
+                 "cache_suffix": cache_suffix,
+             }
+             cache_key = self._cache_key(cache_data, {}, str)
+             completion = self._load_cache(cache_key)
+
+         # Check for cached error responses
+         if (
+             completion
+             and isinstance(completion, dict)
+             and "error" in completion
+             and completion["error"]
+         ):
+             error_type = completion.get("error_type", "Unknown")
+             error_message = completion.get("error_message", "Cached error")
+             logger.warning(f"Found cached error ({error_type}): {error_message}")
+             raise ValueError(f"Cached {error_type}: {error_message}")
+
+         try:
+             # Get completion from API if not cached
+             if not completion:
+                 call_kwargs = {
+                     "messages": converted_messages,
+                     **self.model_kwargs,
+                 }
+                 if extra_body:
+                     call_kwargs["extra_body"] = extra_body
+
+                 completion = await self.client.chat.completions.create(**call_kwargs)
+
+                 if hasattr(completion, "model_dump"):
+                     completion = completion.model_dump()
+                 if cache_key:
+                     self._dump_cache(cache_key, completion)
+
+         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+             logger.error(error_msg)
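+             # Only BadRequestError is persisted to the cache; authentication and
+             # rate-limit errors are re-raised without being cached.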
+             if isinstance(exc, BadRequestError) and cache_key:
+                 error_response = {
+                     "error": True,
+                     "error_type": "BadRequestError",
+                     "error_message": str(exc),
+                     "choices": [],
+                 }
+                 self._dump_cache(cache_key, error_response)
+                 logger.debug(f"Cached BadRequestError for key: {cache_key}")
+             raise
+
+         return completion
+
+     async def _call_and_parse(
+         self,
+         messages: list[dict],
+         response_model: Type[OutputModelType],
+         json_schema: dict,
+     ) -> tuple[dict, list[dict], OutputModelType]:
+         """Unified call and parse with cache and error handling."""
+         if self._use_beta:
+             return await self._call_and_parse_with_beta(
+                 messages, response_model, json_schema
+             )
+
+         choice = None
+         try:
+             # Use unified client call
+             completion = await self._unified_client_call(
+                 messages,
+                 extra_body={**self.extra_body},
+                 cache_suffix=f"_parse_{response_model.__name__}",
+             )
+
+             # Parse the response
+             choice = completion["choices"][0]["message"]
+             if "content" not in choice:
+                 raise ValueError("Response choice must contain 'content' field.")
+
+             content = choice["content"]
+             if not content:
+                 raise ValueError("Response content is empty")
+
+             parsed = response_model.model_validate(jloads_safe(content))
+
+         except Exception as e:
+             # Try fallback to beta mode if regular parsing fails
+             if not isinstance(
+                 e, (AuthenticationError, RateLimitError, BadRequestError)
+             ):
+                 content = choice.get("content", "N/A") if choice else "N/A"
+                 logger.info(
+                     f"Regular parsing failed (malformed or empty content); falling back to beta mode: {content=}, {e=}"
+                 )
+                 try:
+                     return await self._call_and_parse_with_beta(
+                         messages, response_model, json_schema
+                     )
+                 except Exception as beta_e:
+                     logger.warning(f"Beta mode fallback also failed: {beta_e}")
+                     choice_info = choice if choice is not None else "N/A"
+                     raise ValueError(
+                         f"Failed to parse model response with both regular and beta modes. "
+                         f"Regular error: {e}. Beta error: {beta_e}. "
+                         f"Model response message: {choice_info}"
+                     ) from e
+             raise
+
+         assistant_msg = self._extract_assistant_message(choice)
+         full_messages = messages + [assistant_msg]
+
+         return completion, full_messages, cast(OutputModelType, parsed)
+
+     async def _call_and_parse_with_beta(
+         self,
+         messages: list[dict],
+         response_model: Type[OutputModelType],
+         json_schema: dict,
+     ) -> tuple[dict, list[dict], OutputModelType]:
+         """Call and parse for beta mode with guided JSON."""
+         choice = None
+         try:
+             # Use unified client call with guided JSON
+             completion = await self._unified_client_call(
+                 messages,
+                 extra_body={"guided_json": json_schema, **self.extra_body},
+                 cache_suffix=f"_beta_parse_{response_model.__name__}",
+             )
+
+             # Parse the response
+             choice = completion["choices"][0]["message"]
+             parsed = self._parse_complete_output(completion, response_model)
+
+         except Exception as e:
+             choice_info = choice if choice is not None else "N/A"
+             raise ValueError(
+                 f"Failed to parse model response: {e}\nModel response message: {choice_info}"
+             ) from e
+
+         assistant_msg = self._extract_assistant_message(choice)
+         full_messages = messages + [assistant_msg]
+
+         return completion, full_messages, cast(OutputModelType, parsed)
+
+     def _extract_assistant_message(self, choice) -> dict[str, str]:
+         # TODO: this currently assumes `choice` is a dict with "reasoning_content" and "content" keys.
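+         # Reasoning output, when present, is re-wrapped in <think> tags so it
+         # survives in the assistant message returned to the caller.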
+         has_reasoning = False
+         if "reasoning_content" in choice and isinstance(
+             choice["reasoning_content"], str
+         ):
+             reasoning_content = choice["reasoning_content"].strip()
+             has_reasoning = True
+
+         content = choice["content"]
+         _content = content.lstrip("\n")
+         if has_reasoning:
+             assistant_msg = {
+                 "role": "assistant",
+                 "content": f"<think>\n{reasoning_content}\n</think>\n\n{_content}",
+             }
+         else:
+             assistant_msg = {"role": "assistant", "content": _content}
+
+         return assistant_msg
+
+     async def __call__(
+         self,
+         prompt: Optional[str] = None,
+         messages: Optional[RawMsgs] = None,
+     ):
+         """Unified async call for language model, returns (assistant_message.model_dump(), messages)."""
+         if (prompt is None) == (messages is None):
+             raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
+
+         if prompt is not None:
+             messages = [{"role": "user", "content": prompt}]
+
+         assert messages is not None
+
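+         # Normalize legacy dict-style messages into the typed OpenAI format.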
+         openai_msgs: Messages = (
+             self._convert_messages(cast(LegacyMsgs, messages))
+             if isinstance(messages[0], dict)
+             else cast(Messages, messages)
+         )
+
+         assert self.model_kwargs["model"] is not None, (
+             "Model must be set before making a call."
+         )
+
+         # Use unified client call
+         raw_response = await self._unified_client_call(
+             list(openai_msgs), cache_suffix="_call"
+         )
+
+         if hasattr(raw_response, "model_dump"):
+             raw_response = raw_response.model_dump()  # type: ignore
+
+         # Extract the assistant's message
+         assistant_msg = raw_response["choices"][0]["message"]
+         # Build the full messages list (input + assistant reply)
+         full_messages = list(messages) + [
+             {"role": assistant_msg["role"], "content": assistant_msg["content"]}
+         ]
+         # Return the OpenAI message as a model_dump (if available) and the messages list
+         if hasattr(assistant_msg, "model_dump"):
+             msg_dump = assistant_msg.model_dump()
+         else:
+             msg_dump = dict(assistant_msg)
+         return msg_dump, full_messages
+
+     async def parse(
+         self,
+         instruction,
+         prompt,
+     ) -> ParsedOutput[BaseModel]:
+         """Parse the response against `self.response_model` and return a ParsedOutput."""
+         if not self._use_beta:
+             assert self.add_json_schema_to_instruction, (
+                 "add_json_schema_to_instruction must be True when use_beta is False; "
+                 "otherwise the model cannot produce a parseable response."
+             )
+
+         assert self.response_model is not None, "response_model must be set at init."
+         json_schema = self.response_model.model_json_schema()
+
+         # Build system message content in a single, clear block
+         assert instruction is not None, "Instruction must be provided."
+         assert prompt is not None, "Prompt must be provided."
+         system_content = instruction
+
+         # Add schema if needed
+         system_content = self.build_system_prompt(
+             self.response_model,
+             self.add_json_schema_to_instruction,
+             json_schema,
+             system_content,
+             think=self.think,
+         )
+
+         messages = [
+             {"role": "system", "content": system_content},
+             {"role": "user", "content": prompt},
+         ]  # type: ignore
+
+         completion, full_messages, parsed = await self._call_and_parse(
+             messages,
+             self.response_model,
+             json_schema,
+         )
+
+         return ParsedOutput(
+             messages=full_messages,
+             parsed=cast(BaseModel, parsed),
+             completion=completion,
+             model_kwargs=self.model_kwargs,
+         )
+
+     def _parse_complete_output(
+         self, completion: Any, response_model: Type[BaseModel]
+     ) -> BaseModel:
+         """Parse completion output to response model."""
+         if hasattr(completion, "model_dump"):
+             completion = completion.model_dump()
+
+         if "choices" not in completion or not completion["choices"]:
+             raise ValueError("No choices in OpenAI response")
+
+         content = completion["choices"][0]["message"]["content"]
+         if not content:
+             raise ValueError("Response content is empty")
+
+         try:
+             data = jloads(content)
+             return response_model.model_validate(data)
+         except Exception as exc:
+             raise ValueError(
+                 f"Failed to validate against response model {response_model.__name__}: {exc}\nRaw content: {content}"
+             ) from exc
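
For orientation, here is a minimal usage sketch of the new AsyncLM class. The import path is inferred from the package layout in this diff; the model name and port are placeholders for an OpenAI-compatible endpoint, and attribute access on the returned ParsedOutput assumes the fields shown in the constructor call above.

import asyncio

from pydantic import BaseModel

# Import path inferred from the diff's package layout; adjust if the module lives elsewhere.
from llm_utils.lm.async_lm.async_lm import AsyncLM


class Answer(BaseModel):
    reply: str


async def main() -> None:
    # Placeholder model name and port for an OpenAI-compatible server (e.g. vLLM).
    lm = AsyncLM("my-model", response_model=Answer, port=8000)
    result = await lm.parse(
        instruction="Answer the user in JSON.",
        prompt="Say hello.",
    )
    print(result.parsed)  # validated Answer instance


asyncio.run(main())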