speedy-utils 1.1.23__py3-none-any.whl → 1.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +12 -8
- llm_utils/chat_format/__init__.py +2 -0
- llm_utils/chat_format/display.py +115 -44
- llm_utils/lm/__init__.py +14 -6
- llm_utils/lm/llm.py +413 -0
- llm_utils/lm/llm_signature.py +35 -0
- llm_utils/lm/mixins.py +379 -0
- llm_utils/lm/openai_memoize.py +18 -7
- llm_utils/lm/signature.py +26 -37
- llm_utils/lm/utils.py +61 -76
- speedy_utils/__init__.py +31 -2
- speedy_utils/all.py +30 -1
- speedy_utils/common/utils_cache.py +142 -1
- speedy_utils/common/utils_io.py +36 -26
- speedy_utils/common/utils_misc.py +25 -1
- speedy_utils/multi_worker/thread.py +145 -58
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.25.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.25.dist-info}/RECORD +20 -19
- llm_utils/lm/llm_as_a_judge.py +0 -390
- llm_utils/lm/llm_task.py +0 -614
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.25.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.25.dist-info}/entry_points.txt +0 -0
llm_utils/lm/mixins.py
ADDED
@@ -0,0 +1,379 @@
+"""Mixin classes for LLM functionality extensions."""
+
+import os
+import subprocess
+from time import sleep
+from typing import Any, Dict, List, Optional, Type, Union
+
+import requests
+from loguru import logger
+from openai import OpenAI
+from pydantic import BaseModel
+
+
+class TemperatureRangeMixin:
+    """Mixin for sampling with different temperature ranges."""
+
+    def temperature_range_sampling(
+        self,
+        input_data: Union[str, BaseModel, List[Dict]],
+        temperature_ranges: tuple[float, float],
+        n: int = 32,
+        response_model: Optional[Type[BaseModel] | Type[str]] = None,
+        **runtime_kwargs,
+    ) -> List[Dict[str, Any]]:
+        """
+        Sample LLM responses with a range of temperatures.
+
+        This method generates multiple responses by systematically varying
+        the temperature parameter, which controls randomness in the output.
+
+        Args:
+            input_data: Input data (string, BaseModel, or message list)
+            temperature_ranges: Tuple of (min_temp, max_temp) to sample
+            n: Number of temperature samples to generate (must be >= 2)
+            response_model: Optional response model override
+            **runtime_kwargs: Additional runtime parameters
+
+        Returns:
+            List of response dictionaries from all temperature samples
+        """
+        from speedy_utils.multi_worker.thread import multi_thread
+
+        min_temp, max_temp = temperature_ranges
+        if n < 2:
+            raise ValueError(f"n must be >= 2, got {n}")
+
+        step = (max_temp - min_temp) / (n - 1)
+        list_kwargs = []
+
+        for i in range(n):
+            kwargs = dict(
+                temperature=min_temp + i * step,
+                i=i,
+                **runtime_kwargs,
+            )
+            list_kwargs.append(kwargs)
+
+        def f(kwargs):
+            i = kwargs.pop("i")
+            sleep(i * 0.05)
+            return self.__inner_call__(
+                input_data,
+                response_model=response_model,
+                **kwargs,
+            )[0]
+
+        choices = multi_thread(f, list_kwargs, progress=False)
+        return [c for c in choices if c is not None]
+
+
+class TwoStepPydanticMixin:
+    """Mixin for two-step Pydantic parsing functionality."""
+
+    def two_step_pydantic_parse(
+        self,
+        input_data: Union[str, BaseModel, List[Dict]],
+        response_model: Type[BaseModel],
+        **runtime_kwargs,
+    ) -> List[Dict[str, Any]]:
+        """
+        Parse responses in two steps: text completion then Pydantic parsing.
+
+        This is useful for models that may include reasoning or extra text
+        before the JSON output.
+
+        Args:
+            input_data: Input data (string, BaseModel, or message list)
+            response_model: Pydantic model to parse into
+            **runtime_kwargs: Additional runtime parameters
+
+        Returns:
+            List of parsed response dictionaries
+        """
+        # Step 1: Get text completions
+        results = self.text_completion(input_data, **runtime_kwargs)
+        parsed_results = []
+
+        for result in results:
+            response_text = result["parsed"]
+            messages = result["messages"]
+
+            # Handle reasoning models that use <think> tags
+            if "</think>" in response_text:
+                response_text = response_text.split("</think>")[1]
+
+            try:
+                # Try direct parsing
+                parsed = response_model.model_validate_json(response_text)
+            except Exception:
+                # Fallback: use LLM to extract JSON
+                logger.warning("Failed to parse JSON directly, using LLM to extract")
+                _parsed_messages = [
+                    {
+                        "role": "system",
+                        "content": ("You are a helpful assistant that extracts JSON from text."),
+                    },
+                    {
+                        "role": "user",
+                        "content": (f"Extract JSON from the following text:\n{response_text}"),
+                    },
+                ]
+                parsed_result = self.pydantic_parse(
+                    _parsed_messages,
+                    response_model=response_model,
+                    **runtime_kwargs,
+                )[0]
+                parsed = parsed_result["parsed"]
+
+            parsed_results.append({"parsed": parsed, "messages": messages})
+
+        return parsed_results
+
+
+class VLLMMixin:
+    """Mixin for VLLM server management and LoRA operations."""
+
+    def _setup_vllm_server(self) -> None:
+        """
+        Setup VLLM server if vllm_cmd is provided.
+
+        This method handles:
+        - Server reuse logic
+        - Starting new servers
+        - Port management
+
+        Should be called from __init__.
+        """
+        from .utils import (
+            _extract_port_from_vllm_cmd,
+            _is_server_running,
+            _kill_vllm_on_port,
+            _start_vllm_server,
+            get_base_client,
+        )
+
+        if not hasattr(self, "vllm_cmd") or not self.vllm_cmd:
+            return
+
+        port = _extract_port_from_vllm_cmd(self.vllm_cmd)
+        reuse_existing = False
+
+        if self.vllm_reuse:
+            try:
+                reuse_client = get_base_client(port, cache=False)
+                models_response = reuse_client.models.list()
+                if getattr(models_response, "data", None):
+                    reuse_existing = True
+                    logger.info(
+                        f"VLLM server already running on port {port}, reusing existing server (vllm_reuse=True)"
+                    )
+                else:
+                    logger.info(f"No models returned from VLLM server on port {port}; starting a new server")
+            except Exception as exc:
+                logger.info(
+                    f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. Starting a new server."
+                )
+
+        if not self.vllm_reuse:
+            if _is_server_running(port):
+                logger.info(f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)")
+                _kill_vllm_on_port(port)
+            logger.info(f"Starting new VLLM server on port {port}")
+            self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+        elif not reuse_existing:
+            logger.info(f"Starting VLLM server on port {port}")
+            self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+
+    def _load_lora_adapter(self) -> None:
+        """
+        Load LoRA adapter from the specified lora_path.
+
+        This method:
+        1. Validates that lora_path is a valid LoRA directory
+        2. Checks if LoRA is already loaded (unless force_lora_unload)
+        3. Loads the LoRA adapter and updates the model name
+        """
+        from .utils import (
+            _is_lora_path,
+            _get_port_from_client,
+            _load_lora_adapter,
+        )
+
+        if not self.lora_path:
+            return
+
+        if not _is_lora_path(self.lora_path):
+            raise ValueError(f"Invalid LoRA path '{self.lora_path}': Directory must contain 'adapter_config.json'")
+
+        logger.info(f"Loading LoRA adapter from: {self.lora_path}")
+
+        # Get the expected LoRA name (basename of the path)
+        lora_name = os.path.basename(self.lora_path.rstrip("/\\"))
+        if not lora_name:  # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(self.lora_path))
+
+        # Get list of available models to check if LoRA is already loaded
+        try:
+            available_models = [m.id for m in self.client.models.list().data]
+        except Exception as e:
+            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
+            available_models = []
+
+        # Check if LoRA is already loaded
+        if lora_name in available_models and not self.force_lora_unload:
+            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
+            self.model_kwargs["model"] = lora_name
+            return
+
+        # Force unload if requested
+        if self.force_lora_unload and lora_name in available_models:
+            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
+            port = _get_port_from_client(self.client)
+            if port is not None:
+                try:
+                    VLLMMixin.unload_lora(port, lora_name)
+                    logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+                except Exception as e:
+                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")
+
+        # Get port from client for API calls
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                f"Cannot load LoRA adapter '{self.lora_path}': "
+                f"Unable to determine port from client base_url. "
+                f"LoRA loading requires a client initialized with port."
+            )
+
+        try:
+            # Load the LoRA adapter
+            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
+            logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
+
+            # Update model name to the loaded LoRA name
+            self.model_kwargs["model"] = loaded_lora_name
+
+        except requests.RequestException as e:
+            # Check if error is due to LoRA already being loaded
+            error_msg = str(e)
+            if "400" in error_msg or "Bad Request" in error_msg:
+                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
+                # Refresh the model list to check if it's now available
+                try:
+                    updated_models = [m.id for m in self.client.models.list().data]
+                    if lora_name in updated_models:
+                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
+                        self.model_kwargs["model"] = lora_name
+                        return
+                except Exception:
+                    pass  # Fall through to original error
+
+            raise ValueError(f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}")
+
+    def unload_lora_adapter(self, lora_path: str) -> None:
+        """
+        Unload a LoRA adapter.
+
+        Args:
+            lora_path: Path to the LoRA adapter directory to unload
+
+        Raises:
+            ValueError: If unable to determine port from client
+        """
+        from .utils import _get_port_from_client, _unload_lora_adapter
+
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                "Cannot unload LoRA adapter: "
+                "Unable to determine port from client base_url. "
+                "LoRA operations require a client initialized with port."
+            )
+
+        _unload_lora_adapter(lora_path, port)
+        lora_name = os.path.basename(lora_path.rstrip("/\\"))
+        logger.info(f"Unloaded LoRA adapter: {lora_name}")
+
+    @staticmethod
+    def unload_lora(port: int, lora_name: str) -> None:
+        """
+        Static method to unload a LoRA adapter by name.
+
+        Args:
+            port: Port number for the API endpoint
+            lora_name: Name of the LoRA adapter to unload
+
+        Raises:
+            requests.RequestException: If the API call fails
+        """
+        try:
+            response = requests.post(
+                f"http://localhost:{port}/v1/unload_lora_adapter",
+                headers={
+                    "accept": "application/json",
+                    "Content-Type": "application/json",
+                },
+                json={"lora_name": lora_name, "lora_int_id": 0},
+            )
+            response.raise_for_status()
+            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+        except requests.RequestException as e:
+            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
+            raise
+
+    def cleanup_vllm_server(self) -> None:
+        """Stop the VLLM server process if started by this instance."""
+        from .utils import stop_vllm_process
+
+        if hasattr(self, "vllm_process") and self.vllm_process is not None:
+            stop_vllm_process(self.vllm_process)
+            self.vllm_process = None
+
+    @staticmethod
+    def kill_all_vllm() -> int:
+        """
+        Kill all tracked VLLM server processes.
+
+        Returns:
+            Number of processes killed
+        """
+        from .utils import kill_all_vllm_processes
+
+        return kill_all_vllm_processes()
+
+    @staticmethod
+    def kill_vllm_on_port(port: int) -> bool:
+        """
+        Kill VLLM server running on a specific port.
+
+        Args:
+            port: Port number to kill server on
+
+        Returns:
+            True if a server was killed, False if no server was running
+        """
+        from .utils import _kill_vllm_on_port
+
+        return _kill_vllm_on_port(port)
+
+
+class ModelUtilsMixin:
+    """Mixin for model utility methods."""
+
+    @staticmethod
+    def list_models(client: Union[OpenAI, int, str, None] = None) -> List[str]:
+        """
+        List available models from the OpenAI client.
+
+        Args:
+            client: OpenAI client, port number, or base_url string
+
+        Returns:
+            List of available model names
+        """
+        from .utils import get_base_client
+
+        client_instance = get_base_client(client, cache=False)
+        models = client_instance.models.list().data
+        return [m.id for m in models]
llm_utils/lm/openai_memoize.py
CHANGED
@@ -42,13 +42,16 @@ class MOpenAI(OpenAI):

    def __init__(self, *args, cache=True, **kwargs):
        super().__init__(*args, **kwargs)
+        self._orig_post = self.post
        if cache:
-
-
-
-
-
-            self.post =
+            self.set_cache(cache)
+
+    def set_cache(self, cache: bool) -> None:
+        """Enable or disable caching of the post method."""
+        if cache and self.post == self._orig_post:
+            self.post = memoize(self._orig_post)  # type: ignore
+        elif not cache and self.post != self._orig_post:
+            self.post = self._orig_post


class MAsyncOpenAI(AsyncOpenAI):
@@ -76,5 +79,13 @@ class MAsyncOpenAI(AsyncOpenAI):

    def __init__(self, *args, cache=True, **kwargs):
        super().__init__(*args, **kwargs)
+        self._orig_post = self.post
        if cache:
-            self.
+            self.set_cache(cache)
+
+    def set_cache(self, cache: bool) -> None:
+        """Enable or disable caching of the post method."""
+        if cache and self.post == self._orig_post:
+            self.post = memoize(self._orig_post)  # type: ignore
+        elif not cache and self.post != self._orig_post:
+            self.post = self._orig_post
llm_utils/lm/signature.py
CHANGED
@@ -5,43 +5,38 @@ This module provides a declarative way to define LLM input/output schemas
with field descriptions and type annotations.
"""

-from typing import Any, Dict, List, Type,
+from typing import Any, Dict, List, Type, get_type_hints, Annotated, get_origin, get_args, cast
from pydantic import BaseModel, Field
import inspect


-class
-    """
+class _FieldProxy:
+    """Proxy that stores field information while appearing type-compatible."""

-    def __init__(self, desc: str = "", **kwargs):
+    def __init__(self, field_type: str, desc: str = "", **kwargs):
+        self.field_type = field_type  # 'input' or 'output'
        self.desc = desc
        self.kwargs = kwargs
-
-    def __class_getitem__(cls, item):
-        """Support for InputField[type] syntax."""
-        return item


-
-    """
-
-
-
-
-
-
-        """Support for OutputField[type] syntax."""
-        return item
+def InputField(desc: str = "", **kwargs) -> Any:
+    """Create an input field descriptor."""
+    return cast(Any, _FieldProxy('input', desc=desc, **kwargs))
+
+
+def OutputField(desc: str = "", **kwargs) -> Any:
+    """Create an output field descriptor."""
+    return cast(Any, _FieldProxy('output', desc=desc, **kwargs))


# Type aliases for cleaner syntax
def Input(desc: str = "", **kwargs) -> Any:
-    """Create an input field descriptor."""
+    """Create an input field descriptor that's compatible with type annotations."""
    return InputField(desc=desc, **kwargs)


def Output(desc: str = "", **kwargs) -> Any:
-    """Create an output field descriptor."""
+    """Create an output field descriptor that's compatible with type annotations."""
    return OutputField(desc=desc, **kwargs)


@@ -67,24 +62,24 @@ class SignatureMeta(type):
            if args:
                # First arg is the actual type
                field_type = args[0]
-                # Look for
+                # Look for _FieldProxy in the metadata
                for metadata in args[1:]:
-                    if isinstance(metadata,
+                    if isinstance(metadata, _FieldProxy):
                        field_desc = metadata
                        break

            # Handle old syntax with direct assignment
-            if field_desc is None and isinstance(field_value,
+            if field_desc is None and isinstance(field_value, _FieldProxy):
                field_desc = field_value

            # Store field information
-            if
+            if field_desc and field_desc.field_type == 'input':
                input_fields[field_name] = {
                    'type': field_type,
                    'desc': field_desc.desc,
                    **field_desc.kwargs
                }
-            elif
+            elif field_desc and field_desc.field_type == 'output':
                output_fields[field_name] = {
                    'type': field_type,
                    'desc': field_desc.desc,
@@ -136,10 +131,10 @@ class Signature(metaclass=SignatureMeta):
        return instruction

    @classmethod
-    def get_input_model(cls) ->
+    def get_input_model(cls) -> Type[BaseModel]:
        """Generate Pydantic input model from input fields."""
        if not cls._input_fields:
-
+            raise ValueError(f"Signature {cls.__name__} must have at least one input field")

        fields = {}
        annotations = {}
@@ -170,10 +165,10 @@ class Signature(metaclass=SignatureMeta):
        return input_model

    @classmethod
-    def get_output_model(cls) ->
+    def get_output_model(cls) -> Type[BaseModel]:
        """Generate Pydantic output model from output fields."""
        if not cls._output_fields:
-
+            raise ValueError(f"Signature {cls.__name__} must have at least one output field")

        fields = {}
        annotations = {}
@@ -259,17 +254,11 @@ if __name__ == "__main__":

    print("\nInput Model:")
    input_model = judge_class.get_input_model()
-
-        print(input_model.model_json_schema())  # type: ignore
-    else:
-        print("String input model")
+    print(input_model.model_json_schema())

    print("\nOutput Model:")
    output_model = judge_class.get_output_model()
-
-        print(output_model.model_json_schema())  # type: ignore
-    else:
-        print("String output model")
+    print(output_model.model_json_schema())

    # Test instance usage
    judge = judge_class()