speedy-utils 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/__init__.py CHANGED
@@ -8,7 +8,7 @@ from .chat_format import (
     build_chatml_input,
     format_msgs,
 )
-from .lm import OAI_LM, LM
+from .lm import PydanticLM, TextLM
 from .group_messages import (
     split_indices_by_length,
     group_messages_by_len,
@@ -23,8 +23,8 @@ __all__ = [
     "display_conversations",
     "build_chatml_input",
     "format_msgs",
-    "OAI_LM",
-    "LM",
     "split_indices_by_length",
     "group_messages_by_len",
+    "PydanticLM",
+    "TextLM",
 ]
@@ -0,0 +1,12 @@
+from .base_lm import LM
+from .text_lm import TextLM
+from .pydantic_lm import PydanticLM
+from .chat_session import ChatSession, Message
+
+__all__ = [
+    "LM",
+    "TextLM",
+    "PydanticLM",
+    "ChatSession",
+    "Message",
+]
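Together, the two hunks above rework the public surface: the top-level package now re-exports PydanticLM and TextLM instead of OAI_LM and LM, while the new `.lm` subpackage __init__ exposes LM, TextLM, PydanticLM, ChatSession, and Message. A migration sketch for downstream code, assuming the 12-line file above is llm_utils/lm/__init__.py:

# Migration sketch only; adjust to your own imports.
# 1.0.0 code that did:
#     from llm_utils import OAI_LM, LM
# would, under the 1.0.1 layout shown in this diff, become:
from llm_utils import PydanticLM, TextLM            # top-level re-exports
from llm_utils.lm import LM, ChatSession, Message   # full lm subpackage surface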
@@ -0,0 +1,337 @@
+import os
+import random
+import time
+from typing import (
+    Any,
+    List,
+    Literal,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    Dict,
+    overload,
+    Tuple,
+)
+from pydantic import BaseModel
+from speedy_utils import dump_json_or_pickle, identify_uuid, load_json_or_pickle
+from loguru import logger
+from copy import deepcopy
+import numpy as np
+import tempfile
+import fcntl
+
+
+class LM:
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        model_type: Literal["chat", "text"] = "chat",
+        temperature: float = 0.0,
+        max_tokens: int = 2000,
+        cache: bool = True,
+        callbacks: Optional[Any] = None,
+        num_retries: int = 3,
+        host: str = "localhost",
+        port: Optional[int] = None,
+        ports: Optional[List[int]] = None,
+        api_key: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs,
+    ):
+        from openai import OpenAI
+
+        self.ports = ports
+        self.host = host
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+        resolved_base_url_from_kwarg = kwargs.get("base_url")
+        if resolved_base_url_from_kwarg is not None and not isinstance(
+            resolved_base_url_from_kwarg, str
+        ):
+            logger.warning(
+                f"base_url in kwargs was not a string ({type(resolved_base_url_from_kwarg)}), ignoring."
+            )
+            resolved_base_url_from_kwarg = None
+        resolved_base_url: Optional[str] = resolved_base_url_from_kwarg
+        if resolved_base_url is None:
+            selected_port = port
+            if selected_port is None and ports is not None and len(ports) > 0:
+                selected_port = ports[0]
+            if selected_port is not None:
+                resolved_base_url = f"http://{host}:{selected_port}/v1"
+        self.base_url = resolved_base_url
+
+        if model is None:
+            if self.base_url:
+                try:
+                    model_list = self.list_models()
+                    if model_list:
+                        model_name_from_list = model_list[0]
+                        model = model_name_from_list
+                        logger.info(f"Using default model: {model}")
+                    else:
+                        logger.warning(
+                            f"No models found at {self.base_url}. Please specify a model."
+                        )
+                except Exception as e:
+                    example_cmd = (
+                        "LM.start_server('unsloth/gemma-3-1b-it')\n"
+                        "# Or manually run: svllm serve --model unsloth/gemma-3-1b-it --gpus 0 -hp localhost:9150"
+                    )
+                    logger.error(
+                        f"Failed to list models from {self.base_url}: {e}\n"
+                        f"Make sure your model server is running and accessible.\n"
+                        f"Example to start a server:\n{example_cmd}"
+                    )
+            else:
+                logger.warning(
+                    "base_url not configured, cannot fetch default model. Please specify a model."
+                )
+        assert (
+            model is not None
+        ), "Model name must be provided or discoverable via list_models"
+
+        # Remove 'openai/' prefix if present
+        if model.startswith("openai/"):
+            model = model[7:]
+
+        self.kwargs = {"temperature": temperature, "max_tokens": max_tokens, **kwargs}
+        self.model = model
+        self.model_type = model_type
+        self.num_retries = num_retries
+        self.do_cache = cache
+        self.callbacks = callbacks
+
+        # Initialize OpenAI client
+        self.openai_client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+        self.system_prompt = system_prompt
+
+    def dump_cache(
+        self, id: str, result: Union[str, BaseModel, List[Union[str, BaseModel]]]
+    ):
+        try:
+            cache_file = f"~/.cache/oai_lm/{self.model}/{id}.pkl"
+            cache_file = os.path.expanduser(cache_file)
+            dump_json_or_pickle(result, cache_file)
+        except Exception as e:
+            logger.warning(f"Cache dump failed: {e}")
+
+    def load_cache(
+        self, id: str
+    ) -> Optional[Union[str, BaseModel, List[Union[str, BaseModel]]]]:
+        try:
+            cache_file = f"~/.cache/oai_lm/{self.model}/{id}.pkl"
+            cache_file = os.path.expanduser(cache_file)
+            if not os.path.exists(cache_file):
+                return None
+            return load_json_or_pickle(cache_file)
+        except Exception as e:
+            logger.warning(f"Cache load failed for {id}: {e}")
+            return None
+
+    def list_models(self) -> List[str]:
+        from openai import OpenAI
+
+        if not self.base_url:
+            raise ValueError("Cannot list models: base_url is not configured.")
+        if not self.api_key:
+            logger.warning(
+                "API key not available for listing models. Using default 'abc'."
+            )
+        api_key_str = str(self.api_key) if self.api_key is not None else "abc"
+        base_url_str = str(self.base_url) if self.base_url is not None else None
+        if isinstance(self.base_url, float):
+            raise TypeError(
+                f"base_url must be a string or None, got float: {self.base_url}"
+            )
+        client = OpenAI(base_url=base_url_str, api_key=api_key_str)
+        try:
+            page = client.models.list()
+            return [d.id for d in page.data]
+        except Exception as e:
+            logger.error(f"Error listing models: {e}")
+            return []
+
+    def get_least_used_port(self) -> int:
+        if self.ports is None:
+            raise ValueError("Ports must be configured to pick the least used port.")
+        if not self.ports:
+            raise ValueError("Ports list is empty, cannot pick a port.")
+        return self._pick_least_used_port(self.ports)
+
+    def _pick_least_used_port(self, ports: List[int]) -> int:
+        global_lock_file = "/tmp/ports.lock"
+        with open(global_lock_file, "w") as lock_file:
+            fcntl.flock(lock_file, fcntl.LOCK_EX)
+            try:
+                port_use: Dict[int, int] = {}
+                for port in ports:
+                    file_counter = f"/tmp/port_use_counter_{port}.npy"
+                    if os.path.exists(file_counter):
+                        try:
+                            counter = np.load(file_counter)
+                        except Exception as e:
+                            logger.warning(f"Corrupted usage file {file_counter}: {e}")
+                            counter = np.array([0])
+                    else:
+                        counter = np.array([0])
+                    port_use[port] = counter[0]
+                if not port_use:
+                    if ports:
+                        raise ValueError(
+                            "Port usage data is empty, cannot pick a port."
+                        )
+                    else:
+                        raise ValueError("No ports provided to pick from.")
+                lsp = min(port_use, key=lambda k: port_use[k])
+                self._update_port_use(lsp, 1)
+            finally:
+                fcntl.flock(lock_file, fcntl.LOCK_UN)
+        return lsp
+
+    def _update_port_use(self, port: int, increment: int):
+        file_counter = f"/tmp/port_use_counter_{port}.npy"
+        file_counter_lock = f"/tmp/port_use_counter_{port}.lock"
+        with open(file_counter_lock, "w") as lock_file:
+            fcntl.flock(lock_file, fcntl.LOCK_EX)
+            try:
+                if os.path.exists(file_counter):
+                    try:
+                        counter = np.load(file_counter)
+                    except Exception as e:
+                        logger.warning(f"Corrupted usage file {file_counter}: {e}")
+                        counter = np.array([0])
+                else:
+                    counter = np.array([0])
+                counter[0] += increment
+                self._atomic_save(counter, file_counter)
+            finally:
+                fcntl.flock(lock_file, fcntl.LOCK_UN)
+
+    def _atomic_save(self, array: np.ndarray, filename: str):
+        tmp_dir: str = os.path.dirname(filename) or "."
+        with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
+            np.save(tmp, array)
+            temp_name: str = tmp.name
+        os.replace(temp_name, filename)
+
+    def _prepare_call_inputs(
+        self,
+        messages: List[Any],
+        max_tokens: Optional[int],
+        port: Optional[int],
+        use_loadbalance: Optional[bool],
+        cache: Optional[bool],
+        **kwargs,
+    ) -> Tuple[dict, bool, Optional[int], List[Any]]:
+        """Prepare inputs for the LLM call."""
+        # Prepare kwargs
+        effective_kwargs = {**self.kwargs, **kwargs}
+        if max_tokens is not None:
+            effective_kwargs["max_tokens"] = max_tokens
+
+        # Set effective cache
+        effective_cache = cache if cache is not None else self.do_cache
+
+        # Setup port
+        current_port = port
+        if self.ports and not current_port:
+            current_port = (
+                self.get_least_used_port()
+                if use_loadbalance
+                else random.choice(self.ports)
+            )
+        if current_port:
+            base_url = f"http://{self.host}:{current_port}/v1"
+            effective_kwargs["base_url"] = base_url
+            # Update client with new base_url
+            from openai import OpenAI
+
+            self.openai_client = OpenAI(api_key=self.api_key, base_url=base_url)
+
+        return effective_kwargs, effective_cache, current_port, messages
+
+    def _call_llm(
+        self,
+        dspy_main_input: List[Any],
+        current_port: Optional[int],
+        use_loadbalance: Optional[bool],
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError("This method should be implemented in subclasses.")
+        # """Call the OpenAI API directly and get raw output (no retries)."""
+        # try:
+        #     # Handle message list
+        #     response = self.openai_client.chat.completions.create(
+        #         model=self.model, messages=dspy_main_input, **kwargs
+        #     )
+
+        #     # Update port usage stats if needed
+        #     if current_port and use_loadbalance is True:
+        #         self._update_port_use(current_port, -1)
+
+        #     return response.choices[0].message.content
+
+        # except Exception as e:
+        #     logger.error(f"API call failed: {e}")
+        #     raise
+
+    def _generate_cache_key_base(
+        self,
+        messages: List[Any],
+        effective_kwargs: dict,
+    ) -> List[Any]:
+        """Base method to generate cache key components."""
+        return [
+            messages,
+            effective_kwargs.get("temperature"),
+            effective_kwargs.get("max_tokens"),
+            self.model,
+        ]
+
+    def _store_in_cache_base(
+        self, effective_cache: bool, id_for_cache: Optional[str], result: Any
+    ):
+        """Base method to store result in cache if caching is enabled."""
+        if effective_cache and id_for_cache:
+            self.dump_cache(id_for_cache, result)
+
+    def __call__(
+        self,
+        prompt: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        messages: Optional[List[Any]] = None,
+        **kwargs,
+    ) -> Any:
307
+ """
308
+ If have prompt but not messages, convert prompt to messages.
309
+ If both raise
310
+ If neither, raise
311
+ """
312
+ if prompt is not None and messages is not None:
313
+ raise ValueError("Cannot provide both prompt and messages.")
314
+ if prompt is None and messages is None:
315
+ raise ValueError("Either prompt or messages must be provided.")
316
+
317
+ # Convert prompt to messages if needed
318
+ if prompt is not None:
319
+ effective_system_prompt = system_prompt or self.system_prompt
320
+ if effective_system_prompt is not None:
321
+ messages = [
322
+ {"role": "system", "content": effective_system_prompt},
323
+ {"role": "user", "content": prompt},
324
+ ]
325
+ else:
326
+ messages = [{"role": "user", "content": prompt}]
327
+
328
+ # Call the LLM with the prepared inputs
329
+ assert messages is not None, "messages must not be None"
330
+ return self.forward_messages(messages=messages, **kwargs)
331
+
332
+ def forward_messages(
333
+ self,
334
+ messages: List[Any],
335
+ **kwargs,
336
+ ) -> str:
337
+ raise NotImplementedError
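The 337-line module added above (imported as `.base_lm` by the new lm package __init__) is the shared base class: it resolves a base_url from host and port(s), discovers a default model via the OpenAI-compatible /v1/models endpoint when none is given, memoizes results under ~/.cache/oai_lm/&lt;model&gt;/, and balances traffic across multiple ports using fcntl-locked counter files in /tmp. __call__ normalizes a prompt (plus optional system prompt) into a messages list and delegates to forward_messages, which subclasses must implement. A minimal illustrative subclass, not part of the package and only a sketch of the contract, might look like this:

# Illustrative only: a hypothetical EchoLM subclass showing the contract that
# forward_messages() must fulfil. The package's real subclasses are TextLM
# and PydanticLM. Assumes the llm_utils.lm layout shown in this diff.
from typing import Any, List

from llm_utils.lm import LM


class EchoLM(LM):
    def forward_messages(self, messages: List[Any], **kwargs) -> str:
        # A real subclass would call self.openai_client here; we just echo
        # the last user message to demonstrate the call flow.
        return str(messages[-1]["content"])


lm = EchoLM(model="dummy-model", port=9150, cache=False)
print(lm("Hello"))                                        # prompt form
print(lm(messages=[{"role": "user", "content": "Hi"}]))   # messages form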
@@ -0,0 +1,115 @@
+from copy import deepcopy
+import time  # Add time for possible retries
+from typing import (
+    Any,
+    List,
+    Literal,
+    Optional,
+    TypedDict,
+    Type,
+    Union,
+    TypeVar,
+    overload,
+)
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class Message(TypedDict):
+    role: Literal["user", "assistant", "system"]
+    content: str | BaseModel
+
+
+class ChatSession:
+    def __init__(
+        self,
+        lm: Any,
+        system_prompt: Optional[str] = None,
+        history: List[Message] = [],
+        callback=None,
+        response_format: Optional[Type[BaseModel]] = None,
+    ):
+        self.lm = deepcopy(lm)
+        self.history = deepcopy(history)
+        self.callback = callback
+        self.response_format = response_format
+        if system_prompt:
+            system_message: Message = {
+                "role": "system",
+                "content": system_prompt,
+            }
+            self.history.insert(0, system_message)
+
+    def __len__(self):
+        return len(self.history)
+
+    @overload
+    def __call__(
+        self, text, response_format: Type[T], display=False, max_prev_turns=3, **kwargs
+    ) -> T: ...
+    @overload
+    def __call__(
+        self,
+        text,
+        response_format: None = None,
+        display=False,
+        max_prev_turns=3,
+        **kwargs,
+    ) -> str: ...
+    def __call__(
+        self,
+        text,
+        response_format: Optional[Type[BaseModel]] = None,
+        display=False,
+        max_prev_turns=3,
+        **kwargs,
+    ) -> Union[str, BaseModel]:
+        current_response_format = response_format or self.response_format
+        self.history.append({"role": "user", "content": text})
+        output = self.lm(
+            messages=self.parse_history(),
+            response_format=current_response_format,
+            **kwargs,
+        )
+        if isinstance(output, BaseModel):
+            self.history.append({"role": "assistant", "content": output})
+        else:
+            assert response_format is None
+            self.history.append({"role": "assistant", "content": output})
+        if display:
+            self.inspect_history(max_prev_turns=max_prev_turns)
+        if self.callback:
+            self.callback(self, output)
+        return output
+
+    def send_message(self, text, **kwargs):
+        return self.__call__(text, **kwargs)
+
+    def parse_history(self, indent=None):
+        parsed_history = []
+        for m in self.history:
+            if isinstance(m["content"], str):
+                parsed_history.append(m)
+            elif isinstance(m["content"], BaseModel):
+                parsed_history.append(
+                    {
+                        "role": m["role"],
+                        "content": m["content"].model_dump_json(indent=indent),
+                    }
+                )
+            else:
+                raise ValueError(f"Unexpected content type: {type(m['content'])}")
+        return parsed_history
+
+    def inspect_history(self, max_prev_turns=3):
+        from llm_utils import display_chat_messages_as_html
+
+        h = self.parse_history(indent=2)
+        try:
+            from IPython.display import clear_output
+
+            clear_output()
+            display_chat_messages_as_html(h[-max_prev_turns * 2 :])
+        except:
+            pass
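ChatSession, added above, keeps a mutable message history (deep-copying the lm and any seed history), prepends an optional system prompt, serializes Pydantic answers back to JSON via parse_history() before each call, and can render recent turns as HTML inside IPython. A usage sketch under stated assumptions: the model name and port are placeholders for whatever OpenAI-compatible server you run, and PydanticLM is the structured-output subclass added later in this diff.

# Sketch: a structured chat session backed by PydanticLM. Placeholder model
# name and port; see the svllm example in the base LM's error message.
from pydantic import BaseModel

from llm_utils.lm import ChatSession, PydanticLM


class Answer(BaseModel):
    summary: str
    confidence: float


lm = PydanticLM(model="unsloth/gemma-3-1b-it", port=9150)
chat = ChatSession(lm, system_prompt="Answer concisely.", response_format=Answer)

first = chat("What does a Python context manager do?")  # -> Answer instance
print(first.summary, first.confidence)
print(len(chat))   # system + user + assistant = 3 entries in history
follow_up = chat.send_message("Now give a one-line example.")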
@@ -0,0 +1,195 @@
+import random
+import time
+from typing import Any, List, Optional, Type, Union, cast, TypeVar, Generic
+
+from openai import AuthenticationError, RateLimitError
+from pydantic import BaseModel
+
+from speedy_utils.common.logger import logger
+from speedy_utils.common.utils_cache import identify_uuid
+
+from .base_lm import LM
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class PydanticLM(LM):
+    """
+    Language model that returns outputs as Pydantic models.
+    """
+
+    def _generate_cache_key(
+        self,
+        messages: List[Any],
+        response_format: Optional[Type[BaseModel]],
+        kwargs: dict,
+    ) -> str:
+        """
+        Generate a cache key based on input parameters.
+        """
+        cache_key_base = self._generate_cache_key_base(messages, kwargs)
+        cache_key_base.insert(
+            1, (response_format.model_json_schema() if response_format else None)
+        )
+        return identify_uuid(str(cache_key_base))
+
+    def _parse_cached_result(
+        self, cached_result: Any, response_format: Optional[Type[BaseModel]]
+    ) -> Optional[BaseModel]:
+        """
+        Parse cached result into a BaseModel instance.
+        """
+        if isinstance(cached_result, BaseModel):
+            return cached_result
+        elif isinstance(cached_result, str):
+            if response_format is None:
+                raise ValueError(
+                    "response_format must be provided to parse cached string result."
+                )
+            import json
+
+            return response_format.model_validate_json(cached_result)
+        elif (
+            isinstance(cached_result, list)
+            and cached_result
+            and isinstance(cached_result[0], (str, BaseModel))
+        ):
+            first = cached_result[0]
+            if isinstance(first, BaseModel):
+                return first
+            elif isinstance(first, str):
+                if response_format is None:
+                    raise ValueError(
+                        "response_format must be provided to parse cached string result."
+                    )
+                import json
+
+                return response_format.model_validate_json(first)
+        else:
+            logger.warning(
+                f"Cached result has unexpected type {type(cached_result)}. Ignoring cache."
+            )
+        return None
+
+    def _check_cache(
+        self,
+        effective_cache: bool,
+        messages: List[Any],
+        response_format: Optional[Type[BaseModel]],
+        effective_kwargs: dict,
+    ):
+        """Check if result is in cache and return it if available."""
+        if not effective_cache:
+            return None, None
+
+        cache_id = self._generate_cache_key(messages, response_format, effective_kwargs)
+        cached_result = self.load_cache(cache_id)
+        parsed_cache = self._parse_cached_result(cached_result, response_format)
+
+        return cache_id, parsed_cache
+
+    def _call_llm(
+        self,
+        dspy_main_input: List[Any],
+        response_format: Type[BaseModel],
+        current_port: Optional[int],
+        use_loadbalance: Optional[bool],
+        **kwargs,
+    ):
+        """Call the LLM with response format support using OpenAI's parse method."""
+        # Use messages directly
+        messages = dspy_main_input
+
+        # Use OpenAI's parse method for structured output
+        try:
+            response = self.openai_client.beta.chat.completions.parse(
+                model=self.model,
+                messages=messages,
+                response_format=response_format,
+                **kwargs,
+            )
+        except AuthenticationError as e:
+            logger.error(f"Authentication error: {e}")
+            raise
+        except TimeoutError as e:
+            logger.error(f"Timeout error: {e}")
+            raise
+        except RateLimitError as e:
+            logger.error(f"Rate limit exceeded: {e}")
+            raise
+        # Update port usage stats if needed
+        if current_port and use_loadbalance is True:
+            self._update_port_use(current_port, -1)
+
+        return response.choices[0].message.parsed
+
+    def _parse_llm_output(
+        self, llm_output: Any, response_format: Optional[Type[BaseModel]]
+    ) -> BaseModel:
+        """Parse the LLM output into the correct format."""
+        if isinstance(llm_output, BaseModel):
+            return llm_output
+        elif isinstance(llm_output, dict):
+            if not response_format:
+                raise ValueError("response_format required to parse dict output.")
+            return response_format.model_validate(llm_output)
+        elif isinstance(llm_output, str):
+            if not response_format:
+                raise ValueError("response_format required to parse string output.")
+            import json
+
+            return response_format.model_validate_json(llm_output)
+        else:
+            if not response_format:
+                raise ValueError("response_format required to parse output.")
+            return response_format.model_validate_json(str(llm_output))
+
+    def _store_in_cache(
+        self, effective_cache: bool, cache_id: Optional[str], result: BaseModel
+    ):
+        """Store the result in cache if caching is enabled."""
+        if result and isinstance(result, BaseModel):
+            self._store_in_cache_base(
+                effective_cache, cache_id, result.model_dump_json()
+            )
+
+    def forward_messages(
+        self,
+        response_format: Type[T],
+        messages: List[Any],
+        cache: Optional[bool] = None,
+        port: Optional[int] = None,
+        use_loadbalance: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> T:
+        # 1. Prepare inputs
+        effective_kwargs, effective_cache, current_port, dspy_main_input = (
+            self._prepare_call_inputs(
+                messages, max_tokens, port, use_loadbalance, cache, **kwargs
+            )
+        )
+
+        # 2. Check cache
+        cache_id, cached_result = self._check_cache(
+            effective_cache, messages, response_format, effective_kwargs
+        )
+        if cached_result:
+            return cast(T, cached_result)
+
+        # 3. Call LLM using OpenAI's parse method
+        llm_output = self._call_llm(
+            dspy_main_input,
+            response_format,
+            current_port,
+            use_loadbalance,
+            **effective_kwargs,
+        )
+
+        # 4. Parse output
+        result = self._parse_llm_output(llm_output, response_format)
+
+        # 5. Store in cache
+        self._store_in_cache(effective_cache, cache_id, result)
+
+        return cast(T, result)
+ return cast(T, result)