speedy-utils 1.1.22__py3-none-any.whl → 1.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/mixins.py ADDED
@@ -0,0 +1,379 @@
+"""Mixin classes for LLM functionality extensions."""
+
+import os
+import subprocess
+from time import sleep
+from typing import Any, Dict, List, Optional, Type, Union
+
+import requests
+from loguru import logger
+from openai import OpenAI
+from pydantic import BaseModel
+
+
+class TemperatureRangeMixin:
+    """Mixin for sampling with different temperature ranges."""
+
+    def temperature_range_sampling(
+        self,
+        input_data: Union[str, BaseModel, List[Dict]],
+        temperature_ranges: tuple[float, float],
+        n: int = 32,
+        response_model: Optional[Type[BaseModel] | Type[str]] = None,
+        **runtime_kwargs,
+    ) -> List[Dict[str, Any]]:
+        """
+        Sample LLM responses with a range of temperatures.
+
+        This method generates multiple responses by systematically varying
+        the temperature parameter, which controls randomness in the output.
+
+        Args:
+            input_data: Input data (string, BaseModel, or message list)
+            temperature_ranges: Tuple of (min_temp, max_temp) to sample
+            n: Number of temperature samples to generate (must be >= 2)
+            response_model: Optional response model override
+            **runtime_kwargs: Additional runtime parameters
+
+        Returns:
+            List of response dictionaries from all temperature samples
+        """
+        from speedy_utils.multi_worker.thread import multi_thread
+
+        min_temp, max_temp = temperature_ranges
+        if n < 2:
+            raise ValueError(f"n must be >= 2, got {n}")
+
+        step = (max_temp - min_temp) / (n - 1)
+        list_kwargs = []
+
+        for i in range(n):
+            kwargs = dict(
+                temperature=min_temp + i * step,
+                i=i,
+                **runtime_kwargs,
+            )
+            list_kwargs.append(kwargs)
+
+        def f(kwargs):
+            i = kwargs.pop("i")
+            sleep(i * 0.05)
+            return self.__inner_call__(
+                input_data,
+                response_model=response_model,
+                **kwargs,
+            )[0]
+
+        choices = multi_thread(f, list_kwargs, progress=False)
+        return [c for c in choices if c is not None]
+
+
+class TwoStepPydanticMixin:
+    """Mixin for two-step Pydantic parsing functionality."""
+
+    def two_step_pydantic_parse(
+        self,
+        input_data: Union[str, BaseModel, List[Dict]],
+        response_model: Type[BaseModel],
+        **runtime_kwargs,
+    ) -> List[Dict[str, Any]]:
+        """
+        Parse responses in two steps: text completion then Pydantic parsing.
+
+        This is useful for models that may include reasoning or extra text
+        before the JSON output.
+
+        Args:
+            input_data: Input data (string, BaseModel, or message list)
+            response_model: Pydantic model to parse into
+            **runtime_kwargs: Additional runtime parameters
+
+        Returns:
+            List of parsed response dictionaries
+        """
+        # Step 1: Get text completions
+        results = self.text_completion(input_data, **runtime_kwargs)
+        parsed_results = []
+
+        for result in results:
+            response_text = result["parsed"]
+            messages = result["messages"]
+
+            # Handle reasoning models that use <think> tags
+            if "</think>" in response_text:
+                response_text = response_text.split("</think>")[1]
+
+            try:
+                # Try direct parsing
+                parsed = response_model.model_validate_json(response_text)
+            except Exception:
+                # Fallback: use LLM to extract JSON
+                logger.warning("Failed to parse JSON directly, using LLM to extract")
+                _parsed_messages = [
+                    {
+                        "role": "system",
+                        "content": ("You are a helpful assistant that extracts JSON from text."),
+                    },
+                    {
+                        "role": "user",
+                        "content": (f"Extract JSON from the following text:\n{response_text}"),
+                    },
+                ]
+                parsed_result = self.pydantic_parse(
+                    _parsed_messages,
+                    response_model=response_model,
+                    **runtime_kwargs,
+                )[0]
+                parsed = parsed_result["parsed"]
+
+            parsed_results.append({"parsed": parsed, "messages": messages})
+
+        return parsed_results
+
+
+class VLLMMixin:
+    """Mixin for VLLM server management and LoRA operations."""
+
+    def _setup_vllm_server(self) -> None:
+        """
+        Setup VLLM server if vllm_cmd is provided.
+
+        This method handles:
+        - Server reuse logic
+        - Starting new servers
+        - Port management
+
+        Should be called from __init__.
+        """
+        from .utils import (
+            _extract_port_from_vllm_cmd,
+            _is_server_running,
+            _kill_vllm_on_port,
+            _start_vllm_server,
+            get_base_client,
+        )
+
+        if not hasattr(self, "vllm_cmd") or not self.vllm_cmd:
+            return
+
+        port = _extract_port_from_vllm_cmd(self.vllm_cmd)
+        reuse_existing = False
+
+        if self.vllm_reuse:
+            try:
+                reuse_client = get_base_client(port, cache=False)
+                models_response = reuse_client.models.list()
+                if getattr(models_response, "data", None):
+                    reuse_existing = True
+                    logger.info(
+                        f"VLLM server already running on port {port}, reusing existing server (vllm_reuse=True)"
+                    )
+                else:
+                    logger.info(f"No models returned from VLLM server on port {port}; starting a new server")
+            except Exception as exc:
+                logger.info(
+                    f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. Starting a new server."
+                )
+
+        if not self.vllm_reuse:
+            if _is_server_running(port):
+                logger.info(f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)")
+                _kill_vllm_on_port(port)
+            logger.info(f"Starting new VLLM server on port {port}")
+            self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+        elif not reuse_existing:
+            logger.info(f"Starting VLLM server on port {port}")
+            self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+
+    def _load_lora_adapter(self) -> None:
+        """
+        Load LoRA adapter from the specified lora_path.
+
+        This method:
+        1. Validates that lora_path is a valid LoRA directory
+        2. Checks if LoRA is already loaded (unless force_lora_unload)
+        3. Loads the LoRA adapter and updates the model name
+        """
+        from .utils import (
+            _is_lora_path,
+            _get_port_from_client,
+            _load_lora_adapter,
+        )
+
+        if not self.lora_path:
+            return
+
+        if not _is_lora_path(self.lora_path):
+            raise ValueError(f"Invalid LoRA path '{self.lora_path}': Directory must contain 'adapter_config.json'")
+
+        logger.info(f"Loading LoRA adapter from: {self.lora_path}")
+
+        # Get the expected LoRA name (basename of the path)
+        lora_name = os.path.basename(self.lora_path.rstrip("/\\"))
+        if not lora_name:  # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(self.lora_path))
+
+        # Get list of available models to check if LoRA is already loaded
+        try:
+            available_models = [m.id for m in self.client.models.list().data]
+        except Exception as e:
+            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
+            available_models = []
+
+        # Check if LoRA is already loaded
+        if lora_name in available_models and not self.force_lora_unload:
+            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
+            self.model_kwargs["model"] = lora_name
+            return
+
+        # Force unload if requested
+        if self.force_lora_unload and lora_name in available_models:
+            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
+            port = _get_port_from_client(self.client)
+            if port is not None:
+                try:
+                    VLLMMixin.unload_lora(port, lora_name)
+                    logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+                except Exception as e:
+                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")
+
+        # Get port from client for API calls
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                f"Cannot load LoRA adapter '{self.lora_path}': "
+                f"Unable to determine port from client base_url. "
+                f"LoRA loading requires a client initialized with port."
+            )
+
+        try:
+            # Load the LoRA adapter
+            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
+            logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
+
+            # Update model name to the loaded LoRA name
+            self.model_kwargs["model"] = loaded_lora_name
+
+        except requests.RequestException as e:
+            # Check if error is due to LoRA already being loaded
+            error_msg = str(e)
+            if "400" in error_msg or "Bad Request" in error_msg:
+                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
+                # Refresh the model list to check if it's now available
+                try:
+                    updated_models = [m.id for m in self.client.models.list().data]
+                    if lora_name in updated_models:
+                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
+                        self.model_kwargs["model"] = lora_name
+                        return
+                except Exception:
+                    pass  # Fall through to original error
+
+            raise ValueError(f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}")
+
+    def unload_lora_adapter(self, lora_path: str) -> None:
+        """
+        Unload a LoRA adapter.
+
+        Args:
+            lora_path: Path to the LoRA adapter directory to unload
+
+        Raises:
+            ValueError: If unable to determine port from client
+        """
+        from .utils import _get_port_from_client, _unload_lora_adapter
+
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                "Cannot unload LoRA adapter: "
+                "Unable to determine port from client base_url. "
+                "LoRA operations require a client initialized with port."
+            )
+
+        _unload_lora_adapter(lora_path, port)
+        lora_name = os.path.basename(lora_path.rstrip("/\\"))
+        logger.info(f"Unloaded LoRA adapter: {lora_name}")
+
+    @staticmethod
+    def unload_lora(port: int, lora_name: str) -> None:
+        """
+        Static method to unload a LoRA adapter by name.
+
+        Args:
+            port: Port number for the API endpoint
+            lora_name: Name of the LoRA adapter to unload
+
+        Raises:
+            requests.RequestException: If the API call fails
+        """
+        try:
+            response = requests.post(
+                f"http://localhost:{port}/v1/unload_lora_adapter",
+                headers={
+                    "accept": "application/json",
+                    "Content-Type": "application/json",
+                },
+                json={"lora_name": lora_name, "lora_int_id": 0},
+            )
+            response.raise_for_status()
+            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+        except requests.RequestException as e:
+            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
+            raise
+
+    def cleanup_vllm_server(self) -> None:
+        """Stop the VLLM server process if started by this instance."""
+        from .utils import stop_vllm_process
+
+        if hasattr(self, "vllm_process") and self.vllm_process is not None:
+            stop_vllm_process(self.vllm_process)
+            self.vllm_process = None
+
+    @staticmethod
+    def kill_all_vllm() -> int:
+        """
+        Kill all tracked VLLM server processes.
+
+        Returns:
+            Number of processes killed
+        """
+        from .utils import kill_all_vllm_processes
+
+        return kill_all_vllm_processes()
+
+    @staticmethod
+    def kill_vllm_on_port(port: int) -> bool:
+        """
+        Kill VLLM server running on a specific port.
+
+        Args:
+            port: Port number to kill server on
+
+        Returns:
+            True if a server was killed, False if no server was running
+        """
+        from .utils import _kill_vllm_on_port
+
+        return _kill_vllm_on_port(port)
+
+
+class ModelUtilsMixin:
+    """Mixin for model utility methods."""
+
+    @staticmethod
+    def list_models(client: Union[OpenAI, int, str, None] = None) -> List[str]:
+        """
+        List available models from the OpenAI client.
+
+        Args:
+            client: OpenAI client, port number, or base_url string
+
+        Returns:
+            List of available model names
+        """
+        from .utils import get_base_client
+
+        client_instance = get_base_client(client, cache=False)
+        models = client_instance.models.list().data
+        return [m.id for m in models]
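
To make the intent of these mixins concrete, here is a minimal usage sketch. It assumes a hypothetical `LM` class in `llm_utils` that composes the mixins and supplies the `__inner_call__`, `text_completion`, and `pydantic_parse` methods they call; the import path, constructor arguments, and model name are illustrative, not taken from this diff:

    from pydantic import BaseModel
    from llm_utils import LM  # hypothetical composing class, for illustration

    class Answer(BaseModel):
        text: str

    lm = LM(model="Qwen/Qwen2.5-7B-Instruct")  # assumed constructor
    # Sweep temperatures evenly from 0.0 to 1.0 across 8 threaded calls
    choices = lm.temperature_range_sampling("Tell me a joke", (0.0, 1.0), n=8)
    # Free-form completion first, then JSON parsing (with LLM-based fallback)
    results = lm.two_step_pydantic_parse("Answer in JSON.", response_model=Answer)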
@@ -42,13 +42,16 @@ class MOpenAI(OpenAI):
 
     def __init__(self, *args, cache=True, **kwargs):
         super().__init__(*args, **kwargs)
+        self._orig_post = self.post
         if cache:
-            # Create a memoized wrapper for the instance's post method.
-            # The memoize decorator now preserves exact type information,
-            # so no casting is needed.
-            orig_post = self.post
-            memoized = memoize(orig_post)
-            self.post = memoized
+            self.set_cache(cache)
+
+    def set_cache(self, cache: bool) -> None:
+        """Enable or disable caching of the post method."""
+        if cache and self.post == self._orig_post:
+            self.post = memoize(self._orig_post)  # type: ignore
+        elif not cache and self.post != self._orig_post:
+            self.post = self._orig_post
 
 
 class MAsyncOpenAI(AsyncOpenAI):
@@ -76,5 +79,13 @@ class MAsyncOpenAI(AsyncOpenAI):
 
     def __init__(self, *args, cache=True, **kwargs):
         super().__init__(*args, **kwargs)
+        self._orig_post = self.post
        if cache:
-            self.post = memoize(self.post)  # type: ignore
+            self.set_cache(cache)
+
+    def set_cache(self, cache: bool) -> None:
+        """Enable or disable caching of the post method."""
+        if cache and self.post == self._orig_post:
+            self.post = memoize(self._orig_post)  # type: ignore
+        elif not cache and self.post != self._orig_post:
+            self.post = self._orig_post
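
A short sketch of the new cache toggle (construction arguments are illustrative; `MOpenAI` wraps the standard OpenAI client, so any valid client kwargs apply):

    client = MOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY", cache=True)
    client.set_cache(False)  # restore the original post; bypass memoization for fresh sampling
    client.set_cache(True)   # re-wrap post with the memoize cache; repeated calls hit the cache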
@@ -0,0 +1,271 @@
+"""
+DSPy-like signature system for structured LLM interactions.
+
+This module provides a declarative way to define LLM input/output schemas
+with field descriptions and type annotations.
+"""
+
+from typing import Any, Dict, List, Type, get_type_hints, Annotated, get_origin, get_args, cast
+from pydantic import BaseModel, Field
+import inspect
+
+
+class _FieldProxy:
+    """Proxy that stores field information while appearing type-compatible."""
+
+    def __init__(self, field_type: str, desc: str = "", **kwargs):
+        self.field_type = field_type  # 'input' or 'output'
+        self.desc = desc
+        self.kwargs = kwargs
+
+
+def InputField(desc: str = "", **kwargs) -> Any:
+    """Create an input field descriptor."""
+    return cast(Any, _FieldProxy('input', desc=desc, **kwargs))
+
+
+def OutputField(desc: str = "", **kwargs) -> Any:
+    """Create an output field descriptor."""
+    return cast(Any, _FieldProxy('output', desc=desc, **kwargs))
+
+
+# Type aliases for cleaner syntax
+def Input(desc: str = "", **kwargs) -> Any:
+    """Create an input field descriptor that's compatible with type annotations."""
+    return InputField(desc=desc, **kwargs)
+
+
+def Output(desc: str = "", **kwargs) -> Any:
+    """Create an output field descriptor that's compatible with type annotations."""
+    return OutputField(desc=desc, **kwargs)
+
+
+class SignatureMeta(type):
+    """Metaclass for Signature that processes field annotations."""
+
+    def __new__(cls, name, bases, namespace, **kwargs):
+        # Get type hints for this class
+        annotations = namespace.get('__annotations__', {})
+
+        # Store field information
+        input_fields = {}
+        output_fields = {}
+
+        for field_name, field_type in annotations.items():
+            field_value = namespace.get(field_name)
+            field_desc = None
+
+            # Handle Annotated[Type, Field(...)] syntax using get_origin/get_args
+            if get_origin(field_type) is Annotated:
+                # Extract args from Annotated type
+                args = get_args(field_type)
+                if args:
+                    # First arg is the actual type
+                    field_type = args[0]
+                    # Look for _FieldProxy in the metadata
+                    for metadata in args[1:]:
+                        if isinstance(metadata, _FieldProxy):
+                            field_desc = metadata
+                            break
+
+            # Handle old syntax with direct assignment
+            if field_desc is None and isinstance(field_value, _FieldProxy):
+                field_desc = field_value
+
+            # Store field information
+            if field_desc and field_desc.field_type == 'input':
+                input_fields[field_name] = {
+                    'type': field_type,
+                    'desc': field_desc.desc,
+                    **field_desc.kwargs
+                }
+            elif field_desc and field_desc.field_type == 'output':
+                output_fields[field_name] = {
+                    'type': field_type,
+                    'desc': field_desc.desc,
+                    **field_desc.kwargs
+                }
+
+        # Store in class attributes
+        namespace['_input_fields'] = input_fields
+        namespace['_output_fields'] = output_fields
+
+        return super().__new__(cls, name, bases, namespace)
+
+
+class Signature(metaclass=SignatureMeta):
+    """Base class for defining LLM signatures with input and output fields."""
+
+    _input_fields: Dict[str, Dict[str, Any]] = {}
+    _output_fields: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, **kwargs):
+        """Initialize signature with field values."""
+        for field_name, value in kwargs.items():
+            setattr(self, field_name, value)
+
+    @classmethod
+    def get_instruction(cls) -> str:
+        """Generate instruction text from docstring and field descriptions."""
+        instruction = cls.__doc__ or "Complete the following task."
+        instruction = instruction.strip()
+
+        # Add input field descriptions
+        if cls._input_fields:
+            instruction += "\n\n**Input Fields:**\n"
+            for field_name, field_info in cls._input_fields.items():
+                desc = field_info.get('desc', '')
+                field_type = field_info['type']
+                type_str = getattr(field_type, '__name__', str(field_type))
+                instruction += f"- {field_name} ({type_str}): {desc}\n"
+
+        # Add output field descriptions
+        if cls._output_fields:
+            instruction += "\n**Output Fields:**\n"
+            for field_name, field_info in cls._output_fields.items():
+                desc = field_info.get('desc', '')
+                field_type = field_info['type']
+                type_str = getattr(field_type, '__name__', str(field_type))
+                instruction += f"- {field_name} ({type_str}): {desc}\n"
+
+        return instruction
+
+    @classmethod
+    def get_input_model(cls) -> Type[BaseModel]:
+        """Generate Pydantic input model from input fields."""
+        if not cls._input_fields:
+            raise ValueError(f"Signature {cls.__name__} must have at least one input field")
+
+        fields = {}
+        annotations = {}
+
+        for field_name, field_info in cls._input_fields.items():
+            field_type = field_info['type']
+            desc = field_info.get('desc', '')
+
+            # Create Pydantic field
+            field_kwargs = {k: v for k, v in field_info.items()
+                            if k not in ['type', 'desc']}
+            if desc:
+                field_kwargs['description'] = desc
+
+            fields[field_name] = Field(**field_kwargs) if field_kwargs else Field()
+            annotations[field_name] = field_type
+
+        # Create dynamic Pydantic model
+        input_model = type(
+            f"{cls.__name__}Input",
+            (BaseModel,),
+            {
+                '__annotations__': annotations,
+                **fields
+            }
+        )
+
+        return input_model
+
+    @classmethod
+    def get_output_model(cls) -> Type[BaseModel]:
+        """Generate Pydantic output model from output fields."""
+        if not cls._output_fields:
+            raise ValueError(f"Signature {cls.__name__} must have at least one output field")
+
+        fields = {}
+        annotations = {}
+
+        for field_name, field_info in cls._output_fields.items():
+            field_type = field_info['type']
+            desc = field_info.get('desc', '')
+
+            # Create Pydantic field
+            field_kwargs = {k: v for k, v in field_info.items()
+                            if k not in ['type', 'desc']}
+            if desc:
+                field_kwargs['description'] = desc
+
+            fields[field_name] = Field(**field_kwargs) if field_kwargs else Field()
+            annotations[field_name] = field_type
+
+        # Create dynamic Pydantic model
+        output_model = type(
+            f"{cls.__name__}Output",
+            (BaseModel,),
+            {
+                '__annotations__': annotations,
+                **fields
+            }
+        )
+
+        return output_model
+
+    def format_input(self, **kwargs) -> str:
+        """Format input fields as a string."""
+        input_data = {}
+
+        # Collect input field values
+        for field_name in self._input_fields:
+            if field_name in kwargs:
+                input_data[field_name] = kwargs[field_name]
+            elif hasattr(self, field_name):
+                input_data[field_name] = getattr(self, field_name)
+
+        # Format as key-value pairs
+        formatted_lines = []
+        for field_name, value in input_data.items():
+            field_info = self._input_fields[field_name]
+            desc = field_info.get('desc', '')
+            if desc:
+                formatted_lines.append(f"{field_name} ({desc}): {value}")
+            else:
+                formatted_lines.append(f"{field_name}: {value}")
+
+        return '\n'.join(formatted_lines)
+
+
+# Export functions for easier importing
+__all__ = ['Signature', 'InputField', 'OutputField', 'Input', 'Output']
+
+
+# Example usage for testing
+if __name__ == "__main__":
+    # Define a signature like DSPy - using Annotated approach
+    class FactJudge(Signature):
+        """Judge if the answer is factually correct based on the context."""
+
+        context: Annotated[str, Input("Context for the prediction")]
+        question: Annotated[str, Input("Question to be answered")]
+        answer: Annotated[str, Input("Answer for the question")]
+        factually_correct: Annotated[bool, Output("Is the answer factually correct based on the context?")]
+
+    # Alternative syntax still works but will show type warnings
+    class FactJudgeOldSyntax(Signature):
+        """Judge if the answer is factually correct based on the context."""
+
+        context: str = InputField(desc="Context for the prediction")  # type: ignore
+        question: str = InputField(desc="Question to be answered")  # type: ignore
+        answer: str = InputField(desc="Answer for the question")  # type: ignore
+        factually_correct: bool = OutputField(desc="Is the answer factually correct based on the context?")  # type: ignore
+
+    # Test both signatures
+    for judge_class in [FactJudge, FactJudgeOldSyntax]:
+        print(f"\n=== Testing {judge_class.__name__} ===")
+        print("Instruction:")
+        print(judge_class.get_instruction())
+
+        print("\nInput Model:")
+        input_model = judge_class.get_input_model()
+        print(input_model.model_json_schema())
+
+        print("\nOutput Model:")
+        output_model = judge_class.get_output_model()
+        print(output_model.model_json_schema())
+
+        # Test instance usage
+        judge = judge_class()
+        input_text = judge.format_input(
+            context="The sky is blue during daytime.",
+            question="What color is the sky?",
+            answer="Blue"
+        )
+        print("\nFormatted Input:")
+        print(input_text)
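
Because `get_output_model()` returns an ordinary Pydantic class, it can be handed to any structured-output API that accepts one. A minimal sketch under that assumption, reusing a `FactJudge`-style signature like the one in the example block above with the OpenAI SDK's `beta.chat.completions.parse` (the model name is illustrative):

    from openai import OpenAI

    client = OpenAI()
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",  # illustrative model name
        messages=[
            {"role": "system", "content": FactJudge.get_instruction()},
            {
                "role": "user",
                "content": FactJudge().format_input(
                    context="The sky is blue during daytime.",
                    question="What color is the sky?",
                    answer="Blue",
                ),
            },
        ],
        response_format=FactJudge.get_output_model(),
    )
    verdict = completion.choices[0].message.parsed  # FactJudgeOutput instance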