speedy-utils 1.1.17__py3-none-any.whl → 1.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/lm.py ADDED
@@ -0,0 +1,207 @@
+ # # from ._utils import *
+ # from typing import (
+ #     Any,
+ #     List,
+ #     Literal,
+ #     Optional,
+ #     Type,
+ #     Union,
+ #     cast,
+ # )
+
+ # from loguru import logger
+ # from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
+ # from pydantic import BaseModel
+ # from speedy_utils import jloads
+
+ # # from llm_utils.lm.async_lm.async_llm_task import OutputModelType
+ # from llm_utils.lm.lm_base import LMBase
+
+ # from .async_lm._utils import (
+ #     LegacyMsgs,
+ #     Messages,
+ #     OutputModelType,
+ #     ParsedOutput,
+ #     RawMsgs,
+ # )
+
+
+ # class LM(LMBase):
+ #     """Unified **sync** language‑model wrapper with optional JSON parsing."""
+
+ #     def __init__(
+ #         self,
+ #         *,
+ #         model: Optional[str] = None,
+ #         response_model: Optional[type[BaseModel]] = None,
+ #         temperature: float = 0.0,
+ #         max_tokens: int = 2_000,
+ #         base_url: Optional[str] = None,
+ #         api_key: Optional[str] = None,
+ #         cache: bool = True,
+ #         ports: Optional[List[int]] = None,
+ #         top_p: float = 1.0,
+ #         presence_penalty: float = 0.0,
+ #         top_k: int = 1,
+ #         repetition_penalty: float = 1.0,
+ #         frequency_penalty: Optional[float] = None,
+ #     ) -> None:
+
+ #         if model is None:
+ #             if base_url is None:
+ #                 raise ValueError("Either model or base_url must be provided")
+ #             models = OpenAI(base_url=base_url, api_key=api_key or 'abc').models.list().data
+ #             assert len(models) == 1, f"Found {len(models)} models, please specify one."
+ #             model = models[0].id
+ #             print(f"Using model: {model}")
+
+ #         super().__init__(
+ #             ports=ports,
+ #             base_url=base_url,
+ #             cache=cache,
+ #             api_key=api_key,
+ #         )
+
+ #         # Model behavior options
+ #         self.response_model = response_model
+
+ #         # Store all model-related parameters in model_kwargs
+ #         self.model_kwargs = dict(
+ #             model=model,
+ #             temperature=temperature,
+ #             max_tokens=max_tokens,
+ #             top_p=top_p,
+ #             presence_penalty=presence_penalty,
+ #         )
+ #         self.extra_body = dict(
+ #             top_k=top_k,
+ #             repetition_penalty=repetition_penalty,
+ #             frequency_penalty=frequency_penalty,
+ #         )
+
+ #     def _unified_client_call(
+ #         self,
+ #         messages: RawMsgs,
+ #         extra_body: Optional[dict] = None,
+ #         max_tokens: Optional[int] = None,
+ #     ) -> dict:
+ #         """Unified method for all client interactions (caching handled by MOpenAI)."""
+ #         converted_messages: Messages = (
+ #             self._convert_messages(cast(LegacyMsgs, messages))
+ #             if messages and isinstance(messages[0], dict)
+ #             else cast(Messages, messages)
+ #         )
+ #         if max_tokens is not None:
+ #             self.model_kwargs["max_tokens"] = max_tokens
+
+ #         try:
+ #             # Get completion from API (caching handled by MOpenAI)
+ #             call_kwargs = {
+ #                 "messages": converted_messages,
+ #                 **self.model_kwargs,
+ #             }
+ #             if extra_body:
+ #                 call_kwargs["extra_body"] = extra_body
+
+ #             completion = self.client.chat.completions.create(**call_kwargs)
+
+ #             if hasattr(completion, "model_dump"):
+ #                 completion = completion.model_dump()
+
+ #         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+ #             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+ #             logger.error(error_msg)
+ #             raise
+
+ #         return completion
+
+ #     def __call__(
+ #         self,
+ #         prompt: Optional[str] = None,
+ #         messages: Optional[RawMsgs] = None,
+ #         max_tokens: Optional[int] = None,
+ #     ): # -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...:
+ #         """Unified sync call for language model, returns (assistant_message.model_dump(), messages)."""
+ #         if (prompt is None) == (messages is None):
+ #             raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
+
+ #         if prompt is not None:
+ #             messages = [{"role": "user", "content": prompt}]
+
+ #         assert messages is not None
+
+ #         openai_msgs: Messages = (
+ #             self._convert_messages(cast(LegacyMsgs, messages))
+ #             if isinstance(messages[0], dict)
+ #             else cast(Messages, messages)
+ #         )
+
+ #         assert self.model_kwargs["model"] is not None, (
+ #             "Model must be set before making a call."
+ #         )
+
+ #         # Use unified client call
+ #         raw_response = self._unified_client_call(
+ #             list(openai_msgs), max_tokens=max_tokens
+ #         )
+
+ #         if hasattr(raw_response, "model_dump"):
+ #             raw_response = raw_response.model_dump() # type: ignore
+
+ #         # Extract the assistant's message
+ #         assistant_msg = raw_response["choices"][0]["message"]
+ #         # Build the full messages list (input + assistant reply)
+ #         full_messages = list(messages) + [
+ #             {"role": assistant_msg["role"], "content": assistant_msg["content"]}
+ #         ]
+ #         # Return the OpenAI message as model_dump (if available) and the messages list
+ #         if hasattr(assistant_msg, "model_dump"):
+ #             msg_dump = assistant_msg.model_dump()
+ #         else:
+ #             msg_dump = dict(assistant_msg)
+ #         return msg_dump, full_messages
+
+ #     def parse(
+ #         self,
+ #         messages: Messages,
+ #         response_model: Optional[type[BaseModel]] = None,
+ #     ) -> ParsedOutput[BaseModel]:
+ #         """Parse response using OpenAI's native parse API."""
+ #         # Use provided response_model or fall back to instance default
+ #         model_to_use = response_model or self.response_model
+ #         assert model_to_use is not None, "response_model must be provided or set at init."
+
+ #         # Use OpenAI's native parse API directly
+ #         response = self.client.chat.completions.parse(
+ #             model=self.model_kwargs["model"],
+ #             messages=messages,
+ #             response_format=model_to_use,
+ #             **{k: v for k, v in self.model_kwargs.items() if k != "model"}
+ #         )
+
+ #         parsed = response.choices[0].message.parsed
+ #         completion = response.model_dump() if hasattr(response, "model_dump") else {}
+ #         full_messages = list(messages) + [
+ #             {"role": "assistant", "content": parsed}
+ #         ]
+
+ #         return ParsedOutput(
+ #             messages=full_messages,
+ #             parsed=cast(BaseModel, parsed),
+ #             completion=completion,
+ #             model_kwargs=self.model_kwargs,
+ #         )
+
+
+
+ #     def __enter__(self):
+ #         return self
+
+ #     def __exit__(self, exc_type, exc_val, exc_tb):
+ #         if hasattr(self, "_last_client"):
+ #             last_client = self._last_client # type: ignore
+ #             if hasattr(last_client, "close"):
+ #                 last_client.close()
+ #         else:
+ #             logger.warning("No last client to close")
+ LM = None
llm_utils/lm/lm_base.py ADDED
@@ -0,0 +1,285 @@
+ # from ._utils import *
+ import json
+ import os
+ from typing import (
+     Any,
+     List,
+     Optional,
+     Type,
+     Union,
+     cast,
+     overload,
+ )
+
+ from httpx import URL
+ from loguru import logger
+ from openai import OpenAI
+ from openai.pagination import SyncPage
+ from openai.types.chat import (
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionToolMessageParam,
+     ChatCompletionUserMessageParam,
+ )
+ from openai.types.model import Model
+ from pydantic import BaseModel
+
+ from llm_utils.lm.openai_memoize import MOpenAI
+
+ from .async_lm._utils import (
+     LegacyMsgs,
+     Messages,
+     RawMsgs,
+     TModel,
+ )
+
+
+ class LMBase:
+     """Unified **sync** language‑model wrapper with optional JSON parsing."""
+
+     def __init__(
+         self,
+         *,
+         base_url: Optional[str] = None,
+         api_key: Optional[str] = None,
+         cache: bool = True,
+         ports: Optional[List[int]] = None,
+     ) -> None:
+         self.base_url = base_url
+         self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+         self._cache = cache
+         self.ports = ports
+
+     @property
+     def client(self) -> MOpenAI:
+         # if have multiple ports
+         if self.ports and self.base_url:
+             import random
+             import re
+
+             port = random.choice(self.ports)
+             # Replace port in base_url if it exists
+             base_url_pattern = r'(https?://[^:/]+):?\d*(/.*)?'
+             match = re.match(base_url_pattern, self.base_url)
+             if match:
+                 host_part = match.group(1)
+                 path_part = match.group(2) or '/v1'
+                 api_base = f"{host_part}:{port}{path_part}"
+             else:
+                 api_base = self.base_url
+             logger.debug(f"Using port: {port}")
+         else:
+             api_base = self.base_url
+
+         if api_base is None:
+             raise ValueError("base_url must be provided")
+
+         client = MOpenAI(
+             api_key=self.api_key,
+             base_url=api_base,
+             cache=self._cache,
+         )
+         self._last_client = client
+         return client
+
+     # ------------------------------------------------------------------ #
+     # Public API – typed overloads
+     # ------------------------------------------------------------------ #
+     @overload
+     def __call__( # type: ignore
+         self,
+         *,
+         prompt: Optional[str] = ...,
+         messages: Optional[RawMsgs] = ...,
+         response_format: type[str] = str,
+         return_openai_response: bool = ...,
+         **kwargs: Any,
+     ) -> str: ...
+
+     @overload
+     def __call__(
+         self,
+         *,
+         prompt: Optional[str] = ...,
+         messages: Optional[RawMsgs] = ...,
+         response_format: Type[TModel],
+         return_openai_response: bool = ...,
+         **kwargs: Any,
+     ) -> TModel: ...
+
+     # ------------------------------------------------------------------ #
+     # Utilities below are unchanged (sync I/O is acceptable)
+     # ------------------------------------------------------------------ #
+     @staticmethod
+     def _convert_messages(msgs: LegacyMsgs) -> Messages:
+         converted: Messages = []
+         for msg in msgs:
+             role = msg["role"]
+             content = msg["content"]
+             if role == "user":
+                 converted.append(
+                     ChatCompletionUserMessageParam(role="user", content=content)
+                 )
+             elif role == "assistant":
+                 converted.append(
+                     ChatCompletionAssistantMessageParam(
+                         role="assistant", content=content
+                     )
+                 )
+             elif role == "system":
+                 converted.append(
+                     ChatCompletionSystemMessageParam(role="system", content=content)
+                 )
+             elif role == "tool":
+                 converted.append(
+                     ChatCompletionToolMessageParam(
+                         role="tool",
+                         content=content,
+                         tool_call_id=msg.get("tool_call_id") or "",
+                     )
+                 )
+             else:
+                 converted.append({"role": role, "content": content}) # type: ignore[arg-type]
+         return converted
+
+     @staticmethod
+     def _parse_output(
+         raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
+     ) -> Union[str, BaseModel]:
+         if hasattr(raw_response, "model_dump"):
+             raw_response = raw_response.model_dump()
+
+         if response_format is str:
+             if isinstance(raw_response, dict) and "choices" in raw_response:
+                 message = raw_response["choices"][0]["message"]
+                 return message.get("content", "") or ""
+             return cast(str, raw_response)
+
+         model_cls = cast(Type[BaseModel], response_format)
+
+         if isinstance(raw_response, dict) and "choices" in raw_response:
+             message = raw_response["choices"][0]["message"]
+             if "parsed" in message:
+                 return model_cls.model_validate(message["parsed"])
+             content = message.get("content")
+             if content is None:
+                 raise ValueError("Model returned empty content")
+             try:
+                 data = json.loads(content)
+                 return model_cls.model_validate(data)
+             except Exception as exc:
+                 raise ValueError(
+                     f"Failed to parse model output as JSON:\n{content}"
+                 ) from exc
+
+         if isinstance(raw_response, model_cls):
+             return raw_response
+         if isinstance(raw_response, dict):
+             return model_cls.model_validate(raw_response)
+
+         try:
+             data = json.loads(raw_response)
+             return model_cls.model_validate(data)
+         except Exception as exc:
+             raise ValueError(
+                 f"Model did not return valid JSON:\n---\n{raw_response}"
+             ) from exc
+
+     # ------------------------------------------------------------------ #
+     # Misc helpers
+     # ------------------------------------------------------------------ #
+
+     @staticmethod
+     def list_models(base_url: Optional[str] = None) -> List[str]:
+         try:
+             if base_url is None:
+                 raise ValueError("base_url must be provided")
+             client = LMBase(base_url=base_url).client
+             base_url_obj: URL = client.base_url
+             logger.debug(f"Base URL: {base_url_obj}")
+             models: SyncPage[Model] = client.models.list() # type: ignore[assignment]
+             return [model.id for model in models.data]
+         except Exception as exc:
+             logger.error(f"Failed to list models: {exc}")
+             return []
+
+     def build_system_prompt(
+         self,
+         response_model,
+         add_json_schema_to_instruction,
+         json_schema,
+         system_content,
+         think,
+     ):
+         if add_json_schema_to_instruction and response_model:
+             schema_block = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
+             # if schema_block not in system_content:
+             if "<output_json_schema>" in system_content:
+                 # remove exsting schema block
+                 import re # replace
+
+                 system_content = re.sub(
+                     r"<output_json_schema>.*?</output_json_schema>",
+                     "",
+                     system_content,
+                     flags=re.DOTALL,
+                 )
+                 system_content = system_content.strip()
+             system_content += schema_block
+
+         if think is True:
+             if "/think" in system_content:
+                 pass
+             elif "/no_think" in system_content:
+                 system_content = system_content.replace("/no_think", "/think")
+             else:
+                 system_content += "\n\n/think"
+         elif think is False:
+             if "/no_think" in system_content:
+                 pass
+             elif "/think" in system_content:
+                 system_content = system_content.replace("/think", "/no_think")
+             else:
+                 system_content += "\n\n/no_think"
+         return system_content
+
+     def inspect_history(self):
+         """Inspect the history of the LLM calls."""
+         pass
+
+
+ def get_model_name(client: OpenAI|str|int) -> str:
+     """
+     Get the first available model name from the client.
+
+     Args:
+         client: OpenAI client, base_url string, or port number
+
+     Returns:
+         Name of the first available model
+
+     Raises:
+         ValueError: If no models are available or client is invalid
+     """
+     try:
+         if isinstance(client, OpenAI):
+             openai_client = client
+         elif isinstance(client, str):
+             # String base_url
+             openai_client = OpenAI(base_url=client, api_key='abc')
+         elif isinstance(client, int):
+             # Port number
+             base_url = f"http://localhost:{client}/v1"
+             openai_client = OpenAI(base_url=base_url, api_key='abc')
+         else:
+             raise ValueError(f"Unsupported client type: {type(client)}")
+
+         models = openai_client.models.list()
+         if not models.data:
+             raise ValueError("No models available")
+
+         return models.data[0].id
+
+     except Exception as exc:
+         logger.error(f"Failed to get model name: {exc}")
+         raise ValueError(f"Could not retrieve model name: {exc}") from exc
+ raise ValueError(f"Could not retrieve model name: {exc}") from exc