speedy-utils 1.1.22__py3-none-any.whl → 1.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +19 -7
- llm_utils/chat_format/__init__.py +2 -0
- llm_utils/chat_format/display.py +115 -44
- llm_utils/lm/__init__.py +20 -2
- llm_utils/lm/llm.py +413 -0
- llm_utils/lm/llm_signature.py +35 -0
- llm_utils/lm/mixins.py +379 -0
- llm_utils/lm/openai_memoize.py +18 -7
- llm_utils/lm/signature.py +271 -0
- llm_utils/lm/utils.py +61 -76
- speedy_utils/__init__.py +28 -1
- speedy_utils/all.py +30 -1
- speedy_utils/common/utils_io.py +36 -26
- speedy_utils/common/utils_misc.py +25 -1
- speedy_utils/multi_worker/thread.py +145 -58
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.24.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.24.dist-info}/RECORD +19 -17
- llm_utils/lm/llm_task.py +0 -614
- llm_utils/lm/lm.py +0 -207
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.24.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.24.dist-info}/entry_points.txt +0 -0
llm_utils/lm/mixins.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
"""Mixin classes for LLM functionality extensions."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import subprocess
|
|
5
|
+
from time import sleep
|
|
6
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
|
7
|
+
|
|
8
|
+
import requests
|
|
9
|
+
from loguru import logger
|
|
10
|
+
from openai import OpenAI
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TemperatureRangeMixin:
    """Mixin for sampling with different temperature ranges.

    The host class must provide ``__inner_call__(input_data,
    response_model=..., **kwargs)`` returning a sequence whose first item is
    the response for that single call.
    """

    def temperature_range_sampling(
        self,
        input_data: Union[str, BaseModel, List[Dict]],
        temperature_ranges: tuple[float, float],
        n: int = 32,
        response_model: Optional[Type[BaseModel] | Type[str]] = None,
        **runtime_kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Sample LLM responses with a range of temperatures.

        This method generates multiple responses by systematically varying
        the temperature parameter, which controls randomness in the output.
        ``n`` temperatures are spaced evenly from ``min_temp`` to ``max_temp``
        (both ends included) and the calls are dispatched through
        ``multi_thread``.

        Args:
            input_data: Input data (string, BaseModel, or message list)
            temperature_ranges: Tuple of (min_temp, max_temp) to sample
            n: Number of temperature samples to generate (must be >= 2)
            response_model: Optional response model override
            **runtime_kwargs: Additional runtime parameters forwarded to every
                call. A ``temperature`` entry is ignored (with a warning)
                because each sample sets its own temperature.

        Returns:
            List of response dictionaries from all temperature samples;
            samples whose result is ``None`` are dropped.

        Raises:
            ValueError: If ``n < 2``.
        """
        from speedy_utils.multi_worker.thread import multi_thread

        min_temp, max_temp = temperature_ranges
        if n < 2:
            raise ValueError(f"n must be >= 2, got {n}")

        # A caller-supplied ``temperature`` would collide with the sampled one
        # (``dict(temperature=..., **runtime_kwargs)`` raises TypeError on a
        # duplicate keyword); the sampled value must win.
        shared_kwargs = dict(runtime_kwargs)
        if "temperature" in shared_kwargs:
            shared_kwargs.pop("temperature")
            logger.warning(
                "Ignoring 'temperature' in runtime kwargs; "
                "temperature_range_sampling sets it for each sample"
            )

        step = (max_temp - min_temp) / (n - 1)
        list_kwargs = [
            {**shared_kwargs, "temperature": min_temp + i * step, "i": i}
            for i in range(n)
        ]

        def f(kwargs):
            # Work on a copy so the queued dict keeps its "i" entry even if
            # the same item is handed to this function more than once.
            call_kwargs = dict(kwargs)
            i = call_kwargs.pop("i")
            # Stagger start times by 50 ms per sample instead of firing all at once.
            sleep(i * 0.05)
            return self.__inner_call__(
                input_data,
                response_model=response_model,
                **call_kwargs,
            )[0]

        choices = multi_thread(f, list_kwargs, progress=False)
        return [c for c in choices if c is not None]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class TwoStepPydanticMixin:
    """Mixin for two-step Pydantic parsing functionality.

    The host class must provide ``text_completion(input_data, **kwargs)``,
    returning a list of dicts with ``"parsed"`` (the text) and ``"messages"``
    keys. It must also provide ``pydantic_parse(messages,
    response_model=..., **kwargs)``, returning a list of dicts with a
    ``"parsed"`` key.
    """

    def two_step_pydantic_parse(
        self,
        input_data: Union[str, BaseModel, List[Dict]],
        response_model: Type[BaseModel],
        **runtime_kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Parse responses in two steps: text completion then Pydantic parsing.

        This is useful for models that may include reasoning or extra text
        before the JSON output.

        Args:
            input_data: Input data (string, BaseModel, or message list)
            response_model: Pydantic model to parse into
            **runtime_kwargs: Additional runtime parameters, forwarded to both
                the text completion and the fallback extraction call

        Returns:
            List of parsed response dictionaries, each with ``"parsed"`` (a
            ``response_model`` instance) and ``"messages"`` from step 1
        """
        # Step 1: Get text completions
        results = self.text_completion(input_data, **runtime_kwargs)
        parsed_results = []

        for result in results:
            response_text = result["parsed"]
            messages = result["messages"]

            # Handle reasoning models that use <think> tags: the reasoning
            # ends at the FIRST "</think>", so keep everything after it.
            # (Taking split(...)[1] would truncate the answer at any later
            # occurrence of the tag.)
            if "</think>" in response_text:
                response_text = response_text.partition("</think>")[2]

            try:
                # Try direct parsing
                parsed = response_model.model_validate_json(response_text)
            except Exception:
                # Deliberate broad fallback: any parse/validation failure is
                # retried by asking the LLM to extract the JSON.
                logger.warning("Failed to parse JSON directly, using LLM to extract")
                _parsed_messages = [
                    {
                        "role": "system",
                        "content": ("You are a helpful assistant that extracts JSON from text."),
                    },
                    {
                        "role": "user",
                        "content": (f"Extract JSON from the following text:\n{response_text}"),
                    },
                ]
                parsed_result = self.pydantic_parse(
                    _parsed_messages,
                    response_model=response_model,
                    **runtime_kwargs,
                )[0]
                parsed = parsed_result["parsed"]

            parsed_results.append({"parsed": parsed, "messages": messages})

        return parsed_results
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class VLLMMixin:
    """Mixin for VLLM server management and LoRA operations."""

    def _setup_vllm_server(self) -> None:
        """
        Setup VLLM server if vllm_cmd is provided.

        This method handles:
        - Server reuse logic
        - Starting new servers
        - Port management

        Reads ``self.vllm_cmd``, ``self.vllm_reuse`` and ``self.vllm_timeout``.
        It sets ``self.vllm_process`` whenever it starts a server.

        Should be called from __init__.
        """
        from .utils import (
            _extract_port_from_vllm_cmd,
            _is_server_running,
            _kill_vllm_on_port,
            _start_vllm_server,
            get_base_client,
        )

        if not getattr(self, "vllm_cmd", None):
            return

        port = _extract_port_from_vllm_cmd(self.vllm_cmd)

        if not self.vllm_reuse:
            # Fresh-server mode: evict whatever is on the port, then start ours.
            if _is_server_running(port):
                logger.info(f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)")
                _kill_vllm_on_port(port)
            logger.info(f"Starting new VLLM server on port {port}")
            self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
            return

        # Reuse mode: only start a server if the port does not already serve models.
        reuse_existing = False
        try:
            reuse_client = get_base_client(port, cache=False)
            models_response = reuse_client.models.list()
            if getattr(models_response, "data", None):
                reuse_existing = True
                logger.info(
                    f"VLLM server already running on port {port}, reusing existing server (vllm_reuse=True)"
                )
            else:
                logger.info(f"No models returned from VLLM server on port {port}; starting a new server")
        except Exception as exc:
            logger.info(
                f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. Starting a new server."
            )

        if not reuse_existing:
            logger.info(f"Starting VLLM server on port {port}")
            self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)

    def _load_lora_adapter(self) -> None:
        """
        Load LoRA adapter from the specified lora_path.

        This method:
        1. Validates that lora_path is a valid LoRA directory
        2. Checks if LoRA is already loaded (unless force_lora_unload)
        3. Loads the LoRA adapter and updates the model name

        On success ``self.model_kwargs["model"]`` is set to the adapter name.

        Raises:
            ValueError: If the path is not a LoRA directory, the port cannot
                be derived from ``self.client``, or loading fails.
        """
        from .utils import (
            _get_port_from_client,
            _is_lora_path,
            _load_lora_adapter,
        )

        if not self.lora_path:
            return

        if not _is_lora_path(self.lora_path):
            raise ValueError(f"Invalid LoRA path '{self.lora_path}': Directory must contain 'adapter_config.json'")

        logger.info(f"Loading LoRA adapter from: {self.lora_path}")

        # Get the expected LoRA name (basename of the path)
        lora_name = os.path.basename(self.lora_path.rstrip("/\\"))
        if not lora_name:  # Handle edge case of empty basename
            lora_name = os.path.basename(os.path.dirname(self.lora_path))

        # Get list of available models to check if LoRA is already loaded
        try:
            available_models = [m.id for m in self.client.models.list().data]
        except Exception as e:
            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
            available_models = []

        # Check if LoRA is already loaded
        if lora_name in available_models and not self.force_lora_unload:
            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
            self.model_kwargs["model"] = lora_name
            return

        # Port is needed both for the optional unload and for the load call.
        port = _get_port_from_client(self.client)

        # Force unload if requested
        if self.force_lora_unload and lora_name in available_models:
            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
            if port is not None:
                try:
                    # unload_lora logs its own success message.
                    VLLMMixin.unload_lora(port, lora_name)
                except Exception as e:
                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")

        if port is None:
            raise ValueError(
                f"Cannot load LoRA adapter '{self.lora_path}': "
                f"Unable to determine port from client base_url. "
                f"LoRA loading requires a client initialized with port."
            )

        try:
            # Load the LoRA adapter
            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
        except requests.RequestException as e:
            # Check if error is due to LoRA already being loaded
            error_msg = str(e)
            status_code = getattr(getattr(e, "response", None), "status_code", None)
            if status_code == 400 or "400" in error_msg or "Bad Request" in error_msg:
                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
                # Refresh the model list to check if it's now available
                try:
                    updated_models = [m.id for m in self.client.models.list().data]
                    if lora_name in updated_models:
                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
                        self.model_kwargs["model"] = lora_name
                        return
                except Exception:
                    pass  # Best effort only; fall through to the original error

            raise ValueError(f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}") from e

        logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
        # Update model name to the loaded LoRA name
        self.model_kwargs["model"] = loaded_lora_name

    def unload_lora_adapter(self, lora_path: str) -> None:
        """
        Unload a LoRA adapter.

        Args:
            lora_path: Path to the LoRA adapter directory to unload

        Raises:
            ValueError: If unable to determine port from client
        """
        from .utils import _get_port_from_client, _unload_lora_adapter

        port = _get_port_from_client(self.client)
        if port is None:
            raise ValueError(
                "Cannot unload LoRA adapter: "
                "Unable to determine port from client base_url. "
                "LoRA operations require a client initialized with port."
            )

        _unload_lora_adapter(lora_path, port)
        lora_name = os.path.basename(lora_path.rstrip("/\\"))
        logger.info(f"Unloaded LoRA adapter: {lora_name}")

    @staticmethod
    def unload_lora(port: int, lora_name: str, timeout: Optional[float] = 60.0) -> None:
        """
        Static method to unload a LoRA adapter by name.

        Args:
            port: Port number for the API endpoint (on localhost)
            lora_name: Name of the LoRA adapter to unload
            timeout: Seconds to wait for the server (passed to
                ``requests.post``); ``None`` waits indefinitely.

        Raises:
            requests.RequestException: If the API call fails or times out
        """
        try:
            response = requests.post(
                f"http://localhost:{port}/v1/unload_lora_adapter",
                headers={
                    "accept": "application/json",
                    "Content-Type": "application/json",
                },
                json={"lora_name": lora_name, "lora_int_id": 0},
                timeout=timeout,
            )
            response.raise_for_status()
            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
        except requests.RequestException as e:
            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
            raise

    def cleanup_vllm_server(self) -> None:
        """Stop the VLLM server process if started by this instance."""
        from .utils import stop_vllm_process

        if getattr(self, "vllm_process", None) is not None:
            stop_vllm_process(self.vllm_process)
            self.vllm_process = None

    @staticmethod
    def kill_all_vllm() -> int:
        """
        Kill all tracked VLLM server processes.

        Returns:
            Number of processes killed
        """
        from .utils import kill_all_vllm_processes

        return kill_all_vllm_processes()

    @staticmethod
    def kill_vllm_on_port(port: int) -> bool:
        """
        Kill VLLM server running on a specific port.

        Args:
            port: Port number to kill server on

        Returns:
            True if a server was killed, False if no server was running
        """
        from .utils import _kill_vllm_on_port

        return _kill_vllm_on_port(port)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class ModelUtilsMixin:
    """Helpers for inspecting the models served by an OpenAI-compatible endpoint."""

    @staticmethod
    def list_models(client: Union[OpenAI, int, str, None] = None) -> List[str]:
        """
        Return the ids of the models the endpoint reports.

        Args:
            client: OpenAI client, port number, or base_url string; resolved
                through ``get_base_client(..., cache=False)``.

        Returns:
            Model ids, in the order the server lists them.
        """
        from .utils import get_base_client

        listing = get_base_client(client, cache=False).models.list()
        return [entry.id for entry in listing.data]
|
llm_utils/lm/openai_memoize.py
CHANGED
|
@@ -42,13 +42,16 @@ class MOpenAI(OpenAI):
|
|
|
42
42
|
|
|
43
43
|
def __init__(self, *args, cache=True, **kwargs):
    """Build the client; when ``cache`` is truthy, memoize ``post`` via ``set_cache``."""
    super().__init__(*args, **kwargs)
    # Keep the unwrapped bound method so set_cache can switch back to it.
    self._orig_post = self.post
    if not cache:
        return
    self.set_cache(cache)
|
|
48
|
+
|
|
49
|
+
def set_cache(self, cache: bool) -> None:
    """Turn memoization of ``post`` on or off.

    Calling it with the state that is already active leaves ``self.post``
    untouched, so repeated calls do not stack wrappers.
    """
    currently_unwrapped = self.post == self._orig_post
    if cache:
        if currently_unwrapped:
            self.post = memoize(self._orig_post)  # type: ignore
    elif not currently_unwrapped:
        # Drop the memoizing wrapper and restore the original bound method.
        self.post = self._orig_post
|
|
52
55
|
|
|
53
56
|
|
|
54
57
|
class MAsyncOpenAI(AsyncOpenAI):
|
|
@@ -76,5 +79,13 @@ class MAsyncOpenAI(AsyncOpenAI):
|
|
|
76
79
|
|
|
77
80
|
def __init__(self, *args, cache=True, **kwargs):
    """Build the async client; when ``cache`` is truthy, memoize ``post`` via ``set_cache``."""
    super().__init__(*args, **kwargs)
    # Keep the unwrapped bound method so set_cache can switch back to it.
    self._orig_post = self.post
    if not cache:
        return
    self.set_cache(cache)
|
|
85
|
+
|
|
86
|
+
def set_cache(self, cache: bool) -> None:
    """Turn memoization of ``post`` on or off.

    Calling it with the state that is already active leaves ``self.post``
    untouched, so repeated calls do not stack wrappers.
    """
    currently_unwrapped = self.post == self._orig_post
    if cache:
        if currently_unwrapped:
            self.post = memoize(self._orig_post)  # type: ignore
    elif not currently_unwrapped:
        # Drop the memoizing wrapper and restore the original bound method.
        self.post = self._orig_post
|
|
llm_utils/lm/signature.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DSPy-like signature system for structured LLM interactions.
|
|
3
|
+
|
|
4
|
+
This module provides a declarative way to define LLM input/output schemas
|
|
5
|
+
with field descriptions and type annotations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, Dict, List, Type, get_type_hints, Annotated, get_origin, get_args, cast
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
import inspect
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _FieldProxy:
|
|
14
|
+
"""Proxy that stores field information while appearing type-compatible."""
|
|
15
|
+
|
|
16
|
+
def __init__(self, field_type: str, desc: str = "", **kwargs):
|
|
17
|
+
self.field_type = field_type # 'input' or 'output'
|
|
18
|
+
self.desc = desc
|
|
19
|
+
self.kwargs = kwargs
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def InputField(desc: str = "", **kwargs) -> Any:
    """Return a marker that declares a signature *input* field.

    The result is typed as ``Any`` so it can be used inside ``Annotated[...]``
    or assigned as a class attribute without type-checker complaints. Extra
    keyword arguments are kept on the marker and end up in the field info.
    """
    marker = _FieldProxy('input', desc=desc, **kwargs)
    return cast(Any, marker)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def OutputField(desc: str = "", **kwargs) -> Any:
    """Return a marker that declares a signature *output* field.

    The result is typed as ``Any`` so it can be used inside ``Annotated[...]``
    or assigned as a class attribute without type-checker complaints. Extra
    keyword arguments are kept on the marker and end up in the field info.
    """
    marker = _FieldProxy('output', desc=desc, **kwargs)
    return cast(Any, marker)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Shorthand constructors for cleaner syntax, e.g. Annotated[str, Input("...")]
|
|
33
|
+
def Input(desc: str = "", **kwargs) -> Any:
    """Shorthand for ``InputField``, meant for ``Annotated[T, Input("...")]``."""
    marker = InputField(desc, **kwargs)
    return marker
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def Output(desc: str = "", **kwargs) -> Any:
    """Shorthand for ``OutputField``, meant for ``Annotated[T, Output("...")]``."""
    marker = OutputField(desc, **kwargs)
    return marker
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SignatureMeta(type):
    """Metaclass for Signature that processes field annotations.

    For each annotated attribute in the class body it looks for a
    ``_FieldProxy`` marker, either as ``Annotated[T, marker]`` metadata or as
    the assigned value (``name: T = InputField(...)``). The collected fields
    are stored on the class as ``_input_fields`` / ``_output_fields``: dicts
    mapping field name to ``{'type': T, 'desc': str, **extra_kwargs}``.

    Fields declared on base classes are inherited. A subclass may redeclare a
    field, including moving it between input and output. Class keyword
    arguments are forwarded to ``type.__new__`` (and thus to
    ``__init_subclass__``).
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        # Annotations declared directly in this class body
        annotations = namespace.get('__annotations__', {})

        input_fields: Dict[str, Dict[str, Any]] = {}
        output_fields: Dict[str, Dict[str, Any]] = {}

        # Inherit fields from bases; earlier-listed bases are applied last so
        # they take precedence. Each base's tables already include its own
        # ancestors, and a name can live in only one of the two tables.
        for base in reversed(bases):
            for field_name, info in getattr(base, '_input_fields', {}).items():
                input_fields[field_name] = dict(info)
                output_fields.pop(field_name, None)
            for field_name, info in getattr(base, '_output_fields', {}).items():
                output_fields[field_name] = dict(info)
                input_fields.pop(field_name, None)

        for field_name, field_type in annotations.items():
            field_desc = None

            # Handle Annotated[Type, Input(...)] syntax using get_origin/get_args
            if get_origin(field_type) is Annotated:
                args = get_args(field_type)
                if args:
                    # First arg is the actual type; the rest is metadata
                    field_type = args[0]
                    field_desc = next(
                        (meta for meta in args[1:] if isinstance(meta, _FieldProxy)),
                        None,
                    )

            # Handle old syntax with direct assignment
            if field_desc is None:
                field_value = namespace.get(field_name)
                if isinstance(field_value, _FieldProxy):
                    field_desc = field_value

            if field_desc is None:
                continue  # plain annotation, not a signature field

            info = {
                'type': field_type,
                'desc': field_desc.desc,
                **field_desc.kwargs,
            }
            if field_desc.field_type == 'input':
                input_fields[field_name] = info
                output_fields.pop(field_name, None)
            elif field_desc.field_type == 'output':
                output_fields[field_name] = info
                input_fields.pop(field_name, None)

        # Store in class attributes
        namespace['_input_fields'] = input_fields
        namespace['_output_fields'] = output_fields

        return super().__new__(cls, name, bases, namespace, **kwargs)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class Signature(metaclass=SignatureMeta):
    """Base class for defining LLM signatures with input and output fields.

    Subclasses declare fields as ``name: Annotated[T, Input("...")]`` /
    ``Annotated[T, Output("...")]``, or with the legacy
    ``name: T = InputField(...)`` form. The metaclass collects them into
    ``_input_fields`` / ``_output_fields``. The subclass docstring is used as
    the task instruction.
    """

    # Filled per class by SignatureMeta: field name -> {'type', 'desc', **extra}.
    _input_fields: Dict[str, Dict[str, Any]] = {}
    _output_fields: Dict[str, Dict[str, Any]] = {}

    def __init__(self, **kwargs):
        """Initialize signature with field values (stored as instance attributes)."""
        for field_name, value in kwargs.items():
            setattr(self, field_name, value)

    @staticmethod
    def _describe_fields(fields: Dict[str, Dict[str, Any]]) -> str:
        """Render one newline-terminated ``- name (type): desc`` line per field."""
        lines = []
        for field_name, field_info in fields.items():
            field_type = field_info['type']
            type_str = getattr(field_type, '__name__', str(field_type))
            desc = field_info.get('desc', '')
            lines.append(f"- {field_name} ({type_str}): {desc}\n")
        return "".join(lines)

    @classmethod
    def get_instruction(cls) -> str:
        """Generate instruction text from docstring and field descriptions."""
        instruction = (cls.__doc__ or "Complete the following task.").strip()

        if cls._input_fields:
            instruction += "\n\n**Input Fields:**\n" + cls._describe_fields(cls._input_fields)

        if cls._output_fields:
            instruction += "\n**Output Fields:**\n" + cls._describe_fields(cls._output_fields)

        return instruction

    @classmethod
    def _build_model(cls, fields: Dict[str, Dict[str, Any]], suffix: str) -> Type[BaseModel]:
        """Create a pydantic model named ``<ClassName><suffix>`` from a field table.

        ``desc`` becomes the pydantic ``description``; every other non-``type``
        key in the field info is passed through to ``Field``.
        """
        annotations = {}
        model_fields = {}
        for field_name, field_info in fields.items():
            field_kwargs = {k: v for k, v in field_info.items() if k not in ('type', 'desc')}
            desc = field_info.get('desc', '')
            if desc:
                field_kwargs['description'] = desc
            model_fields[field_name] = Field(**field_kwargs)
            annotations[field_name] = field_info['type']

        return type(
            f"{cls.__name__}{suffix}",
            (BaseModel,),
            {'__annotations__': annotations, **model_fields},
        )

    @classmethod
    def get_input_model(cls) -> Type[BaseModel]:
        """Generate Pydantic input model from input fields.

        Raises:
            ValueError: If the signature declares no input fields.
        """
        if not cls._input_fields:
            raise ValueError(f"Signature {cls.__name__} must have at least one input field")
        return cls._build_model(cls._input_fields, "Input")

    @classmethod
    def get_output_model(cls) -> Type[BaseModel]:
        """Generate Pydantic output model from output fields.

        Raises:
            ValueError: If the signature declares no output fields.
        """
        if not cls._output_fields:
            raise ValueError(f"Signature {cls.__name__} must have at least one output field")
        return cls._build_model(cls._output_fields, "Output")

    def format_input(self, **kwargs) -> str:
        """Format input fields as ``name (desc): value`` lines.

        Values come from ``kwargs`` first, then from attributes on the
        instance. Fields with no value are omitted.
        """
        missing = object()
        formatted_lines = []

        for field_name, field_info in self._input_fields.items():
            if field_name in kwargs:
                value = kwargs[field_name]
            else:
                value = getattr(self, field_name, missing)
                # Legacy declarations (``x: str = InputField(...)``) leave the
                # marker object as a class attribute; it is not a real value.
                if value is missing or isinstance(value, _FieldProxy):
                    continue

            desc = field_info.get('desc', '')
            if desc:
                formatted_lines.append(f"{field_name} ({desc}): {value}")
            else:
                formatted_lines.append(f"{field_name}: {value}")

        return '\n'.join(formatted_lines)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
# Public names exported by ``from <module> import *``
__all__ = [
    'Signature',
    'InputField',
    'OutputField',
    'Input',
    'Output',
]
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
# Example usage for testing
if __name__ == "__main__":

    # Preferred declaration style: Annotated[T, Input(...)] / Annotated[T, Output(...)]
    class FactJudge(Signature):
        """Judge if the answer is factually correct based on the context."""

        context: Annotated[str, Input("Context for the prediction")]
        question: Annotated[str, Input("Question to be answered")]
        answer: Annotated[str, Input("Answer for the question")]
        factually_correct: Annotated[bool, Output("Is the answer factually correct based on the context?")]

    # Legacy style: assign InputField/OutputField as defaults. Still supported,
    # but type checkers reject it, hence the ``# type: ignore`` comments.
    class FactJudgeOldSyntax(Signature):
        """Judge if the answer is factually correct based on the context."""

        context: str = InputField(desc="Context for the prediction")  # type: ignore
        question: str = InputField(desc="Question to be answered")  # type: ignore
        answer: str = InputField(desc="Answer for the question")  # type: ignore
        factually_correct: bool = OutputField(desc="Is the answer factually correct based on the context?")  # type: ignore

    def _show_signature(signature_cls):
        """Print the instruction, both JSON schemas, and a formatted sample input."""
        print(f"\n=== Testing {signature_cls.__name__} ===")
        print("Instruction:")
        print(signature_cls.get_instruction())

        print("\nInput Model:")
        print(signature_cls.get_input_model().model_json_schema())

        print("\nOutput Model:")
        print(signature_cls.get_output_model().model_json_schema())

        sample_text = signature_cls().format_input(
            context="The sky is blue during daytime.",
            question="What color is the sky?",
            answer="Blue",
        )
        print("\nFormatted Input:")
        print(sample_text)

    for signature_cls in (FactJudge, FactJudgeOldSyntax):
        _show_signature(signature_cls)
|