speedy-utils 1.1.22__py3-none-any.whl → 1.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +11 -3
- llm_utils/lm/__init__.py +10 -0
- llm_utils/lm/llm_as_a_judge.py +390 -0
- llm_utils/lm/signature.py +282 -0
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.23.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.23.dist-info}/RECORD +8 -7
- llm_utils/lm/lm.py +0 -207
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.23.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.23.dist-info}/entry_points.txt +0 -0
llm_utils/__init__.py
CHANGED

@@ -1,5 +1,5 @@
 from llm_utils.lm.openai_memoize import MOpenAI
-from llm_utils.lm import LLMTask, AsyncLM, AsyncLLMTask
+from llm_utils.lm import LLMTask, AsyncLM, AsyncLLMTask, LLMJudgeBase, ChainOfThought, TranslationEvaluatorJudge, Signature, InputField, OutputField, Input, Output
 from llm_utils.vector_cache import VectorCache
 from llm_utils.lm.lm_base import get_model_name
 from llm_utils.lm.base_prompt_builder import BasePromptBuilder

@@ -15,7 +15,7 @@ def kill_vllm_on_port(port: int) -> bool:
     """Kill VLLM server on specific port. Returns True if server was killed."""
     return LLMTask.kill_vllm_on_port(port)

-from .chat_format import (
+from llm_utils.chat_format import (
     build_chatml_input,
     display_chat_messages_as_html,
     display_conversations,

@@ -46,5 +46,13 @@ __all__ = [
     "BasePromptBuilder",
     "LLM",
     "kill_all_vllm",
-    "kill_vllm_on_port"
+    "kill_vllm_on_port",
+    "LLMJudgeBase",
+    "ChainOfThought",
+    "TranslationEvaluatorJudge",
+    "Signature",
+    "InputField",
+    "OutputField",
+    "Input",
+    "Output",
 ]
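With these re-exports in place, the judge and signature helpers become importable from the package root. A minimal sketch of that surface, based only on what this diff adds (the template string below is illustrative, and any model/client configuration is whatever LLMTask accepts via **kwargs):

from llm_utils import LLMJudgeBase

# Plain template judge; LLMTask constructor kwargs are assumed to handle model/client setup.
judge = LLMJudgeBase("Rate the quality: {text}")
print(judge.format_system_prompt({"text": "This is good text"}))  # -> "Rate the quality: This is good text"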
llm_utils/lm/__init__.py
CHANGED

@@ -3,6 +3,8 @@ from .async_lm.async_llm_task import AsyncLLMTask
 from .lm_base import LMBase, get_model_name
 from .llm_task import LLMTask
 from .base_prompt_builder import BasePromptBuilder
+from .llm_as_a_judge import LLMJudgeBase, ChainOfThought, TranslationEvaluatorJudge
+from .signature import Signature, InputField, OutputField, Input, Output

 __all__ = [
     "LMBase",

@@ -10,4 +12,12 @@ __all__ = [
     "AsyncLM",
     "AsyncLLMTask",
     "BasePromptBuilder",
+    "LLMJudgeBase",
+    "ChainOfThought",
+    "TranslationEvaluatorJudge",
+    "Signature",
+    "InputField",
+    "OutputField",
+    "Input",
+    "Output",
 ]

llm_utils/lm/llm_as_a_judge.py
ADDED

@@ -0,0 +1,390 @@
+"""
+LLM-as-a-Judge implementation with template support and SFT export utilities.
+
+This module provides a base class for creating LLM judges with structured
+prompts, variable substitution, and export capabilities for fine-tuning.
+"""
+
+import json
+from typing import Any, Dict, List, Optional, Type, Union
+from pydantic import BaseModel
+from ..chat_format import get_conversation_one_turn
+from .llm_task import LLMTask
+from .signature import Signature
+
+
+class LLMJudgeBase(LLMTask):
+    """Base class for LLM judges with template support and SFT export."""
+
+    def __init__(
+        self,
+        system_prompt_template: str,
+        signature: Optional[Type[Signature]] = None,
+        **kwargs
+    ):
+        """
+        Initialize LLMJudgeBase.
+
+        Args:
+            system_prompt_template: System prompt template with {variable} placeholders
+            signature: Optional Signature class for structured I/O
+            **kwargs: Additional arguments passed to LLMTask
+        """
+        self.system_prompt_template = system_prompt_template
+        self.signature = signature
+        self.sft_data: List[Dict[str, Any]] = []  # Store SFT training examples
+
+        # Set instruction from signature if available
+        if signature is not None:
+            instruction = signature.get_instruction()
+            kwargs.setdefault('instruction', instruction)
+            kwargs.setdefault('output_model', signature.get_output_model())
+        else:
+            kwargs.setdefault('instruction', system_prompt_template)
+
+        super().__init__(**kwargs)
+
+    def format_system_prompt(self, variables: Dict[str, Any]) -> str:
+        """Format system prompt template with provided variables."""
+        try:
+            return self.system_prompt_template.format(**variables)
+        except KeyError as e:
+            missing_var = str(e).strip("'")
+            raise ValueError(f"Missing required variable '{missing_var}' for system prompt template")
+
+    def judge(
+        self,
+        input_data: Union[str, Dict[str, Any], BaseModel],
+        variables: Optional[Dict[str, Any]] = None,
+        **runtime_kwargs
+    ) -> List[Dict[str, Any]]:
+        """
+        Execute judgment with variable substitution in system prompt.
+
+        Args:
+            input_data: Input data for the judge
+            variables: Variables to substitute in system prompt template
+            **runtime_kwargs: Additional runtime arguments
+
+        Returns:
+            List of judgment results
+        """
+        variables = variables or {}
+
+        # Format system prompt with variables
+        formatted_prompt = self.format_system_prompt(variables)
+
+        # Temporarily override instruction
+        original_instruction = self.instruction
+        self.instruction = formatted_prompt
+
+        try:
+            # Handle different input types
+            if isinstance(input_data, dict):
+                processed_input = json.dumps(input_data)
+            else:
+                processed_input = input_data
+            results = self(processed_input, **runtime_kwargs)
+
+            # Store for SFT if needed
+            self._store_sft_example(input_data, results, variables, formatted_prompt)
+
+            return results
+        finally:
+            # Restore original instruction
+            self.instruction = original_instruction
+
+    def _store_sft_example(
+        self,
+        input_data: Union[str, Dict[str, Any], BaseModel],
+        results: List[Dict[str, Any]],
+        variables: Dict[str, Any],
+        formatted_prompt: str
+    ) -> None:
+        """Store example for SFT export."""
+        for result in results:
+            # Create input text
+            if isinstance(input_data, str):
+                input_text = input_data
+            elif isinstance(input_data, BaseModel):
+                input_text = input_data.model_dump_json()
+            elif isinstance(input_data, dict):
+                input_text = json.dumps(input_data)
+            else:
+                input_text = str(input_data)
+
+            # Extract output
+            output_text = result['parsed']
+            if isinstance(output_text, BaseModel):
+                output_text = output_text.model_dump_json()
+            elif not isinstance(output_text, str):
+                output_text = str(output_text)
+
+            # Create conversation format
+            messages = get_conversation_one_turn(
+                formatted_prompt,
+                input_text,
+                output_text
+            )
+
+            sft_example = {
+                'messages': messages,
+                'variables': variables,
+                'input_data': input_data,
+                'output': result['parsed']
+            }
+
+            self.sft_data.append(sft_example)
+
+    def export_sft_data(self, format: str = 'messages') -> List[Dict[str, Any]]:
+        """
+        Export stored examples in SFT format.
+
+        Args:
+            format: Export format ('messages', 'full', or 'sharegpt')
+
+        Returns:
+            List of SFT training examples
+        """
+        if format == 'messages':
+            return [{'messages': example['messages']} for example in self.sft_data]
+        elif format == 'full':
+            return self.sft_data
+        elif format == 'sharegpt':
+            # Convert to ShareGPT format
+            sharegpt_data = []
+            for example in self.sft_data:
+                conversations = []
+                for msg in example['messages']:
+                    conversations.append({
+                        'from': 'human' if msg['role'] == 'user' else 'gpt' if msg['role'] == 'assistant' else 'system',
+                        'value': msg['content']
+                    })
+                sharegpt_data.append({'conversations': conversations})
+            return sharegpt_data
+        else:
+            raise ValueError(f"Unsupported format: {format}. Choose from 'messages', 'full', or 'sharegpt'")
+
+    def save_sft_data(self, filepath: str, format: str = 'messages') -> None:
+        """Save SFT data to file."""
+        sft_data = self.export_sft_data(format)
+        with open(filepath, 'w', encoding='utf-8') as f:
+            json.dump(sft_data, f, indent=2, ensure_ascii=False)
+
+    def clear_sft_data(self) -> None:
+        """Clear stored SFT examples."""
+        self.sft_data.clear()
+
+
+class ChainOfThought:
+    """DSPy-like ChainOfThought wrapper for signatures."""
+
+    def __init__(self, signature: Type[Signature], **llm_kwargs):
+        """
+        Initialize ChainOfThought with a signature.
+
+        Args:
+            signature: Signature class defining input/output structure
+            **llm_kwargs: Arguments passed to LLMJudgeBase
+        """
+        self.signature = signature
+
+        # Create system prompt from signature
+        system_prompt = signature.get_instruction()
+
+        # Add reasoning instruction
+        system_prompt += "\n\nThink step by step before providing your final answer."
+
+        self.llm = LLMJudgeBase(
+            system_prompt_template=system_prompt,
+            signature=signature,
+            **llm_kwargs
+        )
+
+    def __call__(self, **kwargs) -> Any:
+        """Execute chain of thought reasoning."""
+        # Format input using signature
+        signature_instance = self.signature(**kwargs)
+        input_text = signature_instance.format_input(**kwargs)
+
+        results = self.llm.judge(input_text)
+
+        # Return the parsed output
+        if results:
+            return results[0]['parsed']
+        return None
+
+
+# Example usage classes based on the raw code
+class TranslationOutput(BaseModel):
+    """Output schema for translation evaluation."""
+    structure_score: int  # 0 = wrong, 1 = partially correct, 2 = correct
+    translation_score: int  # 0 = not faithful, 1 = somewhat faithful, 2 = fully faithful
+    term_score: int  # 0 = glossary not followed, 1 = partially followed, 2 = fully followed or no glossary provided
+
+
+class TranslationEvaluatorJudge(LLMJudgeBase):
+    """Translation evaluator judge based on the raw code example."""
+
+    def __init__(self, **kwargs):
+        system_prompt = """You are a careful **translation evaluator**.
+
+You are given five inputs:
+
+* **Source Prompt** (the original text & any constraints)
+* **AI Translation** (the machine translation to evaluate)
+* **Human Reference** (a reference rendering; use only for guidance, not as ground truth)
+* **System Message** (an automated hint about a possible structural error)
+* **Glossaries** (optional terminology constraints; may be empty)
+
+## Your tasks
+
+1. **Check structure correctness**:
+   - Use the System Message as a hint.
+   - Assign a `structure_score`:
+     * `0` = structure is clearly wrong or the error flagged is correct.
+     * `1` = partially correct but flawed.
+     * `2` = structure is correct; the system error is invalid.
+
+2. **Check translation quality**:
+   - Compare AI Translation with Source Prompt and Human Reference.
+   - Assign a `translation_score`:
+     * `0` = unfaithful (major omissions/additions/distortions/repetitions).
+     * `1` = somewhat faithful (mostly correct but noticeable issues).
+     * `2` = faithful (preserves meaning, scope, nuance; only minor style differences).
+
+3. **Check glossary/terminology adherence**:
+   - If no glossary is provided → `term_score = 2`.
+   - If glossary exists but only partially followed → `term_score = 1`.
+   - If glossary exists but not followed at all → `term_score = 0`.
+
+## Output format (JSON only; no commentary)
+
+Return exactly one JSON object with the three scores.
+Do not output any explanations.
+
+---
+
+### Inputs
+
+Source Prompt: {SOURCE_PROMPT}
+
+AI Translation: {AI_TRANSLATION}
+
+Human Reference: {HUMAN_REFERENCE}
+
+System Message: {SYSTEM_MESSAGE}
+
+Glossaries: {GLOSSARIES}
+"""
+
+        super().__init__(
+            system_prompt_template=system_prompt,
+            output_model=TranslationOutput,
+            **kwargs
+        )
+
+    def evaluate_translation(
+        self,
+        source_prompt: str,
+        ai_translation: str,
+        human_reference: str,
+        system_message: str,
+        glossaries: str
+    ) -> TranslationOutput:
+        """
+        Evaluate a translation with all required parameters.
+
+        Returns:
+            TranslationOutput with the three scores
+        """
+        variables = {
+            'SOURCE_PROMPT': source_prompt,
+            'AI_TRANSLATION': ai_translation,
+            'HUMAN_REFERENCE': human_reference,
+            'SYSTEM_MESSAGE': system_message,
+            'GLOSSARIES': glossaries
+        }
+
+        input_data = {
+            'source': source_prompt,
+            'target': human_reference,
+            'glossaries': glossaries,
+            'translation': ai_translation
+        }
+
+        results = self.judge(json.dumps(input_data), variables=variables)
+        return results[0]['parsed']
+
+
+# Example usage and testing
+if __name__ == "__main__":
+    # Test the Signature system
+    from .signature import Signature, InputField, OutputField
+
+    # Example 1: Using Signature with ChainOfThought (like DSPy)
+    class FactJudge(Signature):
+        """Judge if the answer is factually correct based on the context."""
+
+        context: str = InputField(desc="Context for the prediction")  # type: ignore
+        question: str = InputField(desc="Question to be answered")  # type: ignore
+        answer: str = InputField(desc="Answer for the question")  # type: ignore
+        factually_correct: bool = OutputField(desc="Is the answer factually correct based on the context?")  # type: ignore
+
+    print("=== Testing Signature System ===")
+    print("Instruction:")
+    print(FactJudge.get_instruction())
+
+    # Example 2: Using LLMJudgeBase directly
+    judge_prompt = """You are a factual accuracy judge.
+
+Given:
+- Context: {context}
+- Question: {question}
+- Answer: {answer}
+
+Determine if the answer is factually correct based on the context.
+Respond with true if correct, false if incorrect."""
+
+    print("\n=== Testing LLMJudgeBase ===")
+    print("System prompt template:")
+    print(judge_prompt)
+
+    # Example 3: Translation evaluator from raw code
+    print("\n=== Translation Evaluator Example ===")
+    evaluator = TranslationEvaluatorJudge()
+    print("Translation evaluator initialized with structured output schema.")
+    print("Output schema:", TranslationOutput.model_json_schema())
+
+    # Test SFT export functionality
+    print("\n=== SFT Export Test ===")
+    # Create a mock judge with some example data
+    mock_judge = LLMJudgeBase("Rate the quality: {text}")
+    mock_judge.sft_data = [
+        {
+            'messages': [
+                {'role': 'system', 'content': 'Rate the quality: This is good text'},
+                {'role': 'user', 'content': 'Please rate this text'},
+                {'role': 'assistant', 'content': '{"quality": "good"}'}
+            ],
+            'variables': {'text': 'This is good text'},
+            'input_data': 'Please rate this text',
+            'output': '{"quality": "good"}'
+        }
+    ]
+
+    sft_formats = ['messages', 'sharegpt']
+    for format_name in sft_formats:
+        exported = mock_judge.export_sft_data(format_name)
+        print(f"SFT export ({format_name} format): {len(exported)} examples")
+        if exported:
+            print(f"Sample structure: {list(exported[0].keys())}")
+
+    print("\n=== All Tests Completed ===")
+    print("The LLMJudgeBase system is ready for use!")
+    print("\nKey features:")
+    print("- System prompt templating with variables")
+    print("- DSPy-like Signature system")
+    print("- Automatic SFT data collection")
+    print("- Multiple export formats (messages, sharegpt, full)")
+    print("- Chain of Thought reasoning support")
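In practice the new module is meant to be driven roughly as sketched below. This is a non-authoritative example based only on the code above: the sample strings and the output filename are illustrative, and the judge inherits whatever model/client configuration LLMTask expects through its constructor kwargs.

from llm_utils.lm.llm_as_a_judge import TranslationEvaluatorJudge

evaluator = TranslationEvaluatorJudge()  # LLMTask kwargs assumed to handle model/client setup
scores = evaluator.evaluate_translation(
    source_prompt="Translate to French: Good morning.",
    ai_translation="Bonjour.",
    human_reference="Bonjour.",
    system_message="No structural error detected.",
    glossaries="",
)
print(scores.structure_score, scores.translation_score, scores.term_score)

# Each judged example is also recorded on the instance for fine-tuning export.
evaluator.save_sft_data("translation_judgments.json", format="sharegpt")  # filename is illustrative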
llm_utils/lm/signature.py
ADDED

@@ -0,0 +1,282 @@
+"""
+DSPy-like signature system for structured LLM interactions.
+
+This module provides a declarative way to define LLM input/output schemas
+with field descriptions and type annotations.
+"""
+
+from typing import Any, Dict, List, Type, Union, get_type_hints, Annotated, get_origin, get_args
+from pydantic import BaseModel, Field
+import inspect
+
+
+class InputField:
+    """Represents an input field in a signature."""
+
+    def __init__(self, desc: str = "", **kwargs):
+        self.desc = desc
+        self.kwargs = kwargs
+
+    def __class_getitem__(cls, item):
+        """Support for InputField[type] syntax."""
+        return item
+
+
+class OutputField:
+    """Represents an output field in a signature."""
+
+    def __init__(self, desc: str = "", **kwargs):
+        self.desc = desc
+        self.kwargs = kwargs
+
+    def __class_getitem__(cls, item):
+        """Support for OutputField[type] syntax."""
+        return item
+
+
+# Type aliases for cleaner syntax
+def Input(desc: str = "", **kwargs) -> Any:
+    """Create an input field descriptor."""
+    return InputField(desc=desc, **kwargs)
+
+
+def Output(desc: str = "", **kwargs) -> Any:
+    """Create an output field descriptor."""
+    return OutputField(desc=desc, **kwargs)
+
+
+class SignatureMeta(type):
+    """Metaclass for Signature that processes field annotations."""
+
+    def __new__(cls, name, bases, namespace, **kwargs):
+        # Get type hints for this class
+        annotations = namespace.get('__annotations__', {})
+
+        # Store field information
+        input_fields = {}
+        output_fields = {}
+
+        for field_name, field_type in annotations.items():
+            field_value = namespace.get(field_name)
+            field_desc = None
+
+            # Handle Annotated[Type, Field(...)] syntax using get_origin/get_args
+            if get_origin(field_type) is Annotated:
+                # Extract args from Annotated type
+                args = get_args(field_type)
+                if args:
+                    # First arg is the actual type
+                    field_type = args[0]
+                    # Look for InputField or OutputField in the metadata
+                    for metadata in args[1:]:
+                        if isinstance(metadata, (InputField, OutputField)):
+                            field_desc = metadata
+                            break
+
+            # Handle old syntax with direct assignment
+            if field_desc is None and isinstance(field_value, (InputField, OutputField)):
+                field_desc = field_value
+
+            # Store field information
+            if isinstance(field_desc, InputField):
+                input_fields[field_name] = {
+                    'type': field_type,
+                    'desc': field_desc.desc,
+                    **field_desc.kwargs
+                }
+            elif isinstance(field_desc, OutputField):
+                output_fields[field_name] = {
+                    'type': field_type,
+                    'desc': field_desc.desc,
+                    **field_desc.kwargs
+                }
+
+        # Store in class attributes
+        namespace['_input_fields'] = input_fields
+        namespace['_output_fields'] = output_fields
+
+        return super().__new__(cls, name, bases, namespace)
+
+
+class Signature(metaclass=SignatureMeta):
+    """Base class for defining LLM signatures with input and output fields."""
+
+    _input_fields: Dict[str, Dict[str, Any]] = {}
+    _output_fields: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, **kwargs):
+        """Initialize signature with field values."""
+        for field_name, value in kwargs.items():
+            setattr(self, field_name, value)
+
+    @classmethod
+    def get_instruction(cls) -> str:
+        """Generate instruction text from docstring and field descriptions."""
+        instruction = cls.__doc__ or "Complete the following task."
+        instruction = instruction.strip()
+
+        # Add input field descriptions
+        if cls._input_fields:
+            instruction += "\n\n**Input Fields:**\n"
+            for field_name, field_info in cls._input_fields.items():
+                desc = field_info.get('desc', '')
+                field_type = field_info['type']
+                type_str = getattr(field_type, '__name__', str(field_type))
+                instruction += f"- {field_name} ({type_str}): {desc}\n"
+
+        # Add output field descriptions
+        if cls._output_fields:
+            instruction += "\n**Output Fields:**\n"
+            for field_name, field_info in cls._output_fields.items():
+                desc = field_info.get('desc', '')
+                field_type = field_info['type']
+                type_str = getattr(field_type, '__name__', str(field_type))
+                instruction += f"- {field_name} ({type_str}): {desc}\n"
+
+        return instruction
+
+    @classmethod
+    def get_input_model(cls) -> Union[Type[BaseModel], type[str]]:
+        """Generate Pydantic input model from input fields."""
+        if not cls._input_fields:
+            return str
+
+        fields = {}
+        annotations = {}
+
+        for field_name, field_info in cls._input_fields.items():
+            field_type = field_info['type']
+            desc = field_info.get('desc', '')
+
+            # Create Pydantic field
+            field_kwargs = {k: v for k, v in field_info.items()
+                            if k not in ['type', 'desc']}
+            if desc:
+                field_kwargs['description'] = desc
+
+            fields[field_name] = Field(**field_kwargs) if field_kwargs else Field()
+            annotations[field_name] = field_type
+
+        # Create dynamic Pydantic model
+        input_model = type(
+            f"{cls.__name__}Input",
+            (BaseModel,),
+            {
+                '__annotations__': annotations,
+                **fields
+            }
+        )
+
+        return input_model
+
+    @classmethod
+    def get_output_model(cls) -> Union[Type[BaseModel], type[str]]:
+        """Generate Pydantic output model from output fields."""
+        if not cls._output_fields:
+            return str
+
+        fields = {}
+        annotations = {}
+
+        for field_name, field_info in cls._output_fields.items():
+            field_type = field_info['type']
+            desc = field_info.get('desc', '')
+
+            # Create Pydantic field
+            field_kwargs = {k: v for k, v in field_info.items()
+                            if k not in ['type', 'desc']}
+            if desc:
+                field_kwargs['description'] = desc
+
+            fields[field_name] = Field(**field_kwargs) if field_kwargs else Field()
+            annotations[field_name] = field_type
+
+        # Create dynamic Pydantic model
+        output_model = type(
+            f"{cls.__name__}Output",
+            (BaseModel,),
+            {
+                '__annotations__': annotations,
+                **fields
+            }
+        )
+
+        return output_model
+
+    def format_input(self, **kwargs) -> str:
+        """Format input fields as a string."""
+        input_data = {}
+
+        # Collect input field values
+        for field_name in self._input_fields:
+            if field_name in kwargs:
+                input_data[field_name] = kwargs[field_name]
+            elif hasattr(self, field_name):
+                input_data[field_name] = getattr(self, field_name)
+
+        # Format as key-value pairs
+        formatted_lines = []
+        for field_name, value in input_data.items():
+            field_info = self._input_fields[field_name]
+            desc = field_info.get('desc', '')
+            if desc:
+                formatted_lines.append(f"{field_name} ({desc}): {value}")
+            else:
+                formatted_lines.append(f"{field_name}: {value}")
+
+        return '\n'.join(formatted_lines)
+
+
+# Export functions for easier importing
+__all__ = ['Signature', 'InputField', 'OutputField', 'Input', 'Output']
+
+
+# Example usage for testing
+if __name__ == "__main__":
+    # Define a signature like DSPy - using Annotated approach
+    class FactJudge(Signature):
+        """Judge if the answer is factually correct based on the context."""
+
+        context: Annotated[str, Input("Context for the prediction")]
+        question: Annotated[str, Input("Question to be answered")]
+        answer: Annotated[str, Input("Answer for the question")]
+        factually_correct: Annotated[bool, Output("Is the answer factually correct based on the context?")]
+
+    # Alternative syntax still works but will show type warnings
+    class FactJudgeOldSyntax(Signature):
+        """Judge if the answer is factually correct based on the context."""
+
+        context: str = InputField(desc="Context for the prediction")  # type: ignore
+        question: str = InputField(desc="Question to be answered")  # type: ignore
+        answer: str = InputField(desc="Answer for the question")  # type: ignore
+        factually_correct: bool = OutputField(desc="Is the answer factually correct based on the context?")  # type: ignore
+
+    # Test both signatures
+    for judge_class in [FactJudge, FactJudgeOldSyntax]:
+        print(f"\n=== Testing {judge_class.__name__} ===")
+        print("Instruction:")
+        print(judge_class.get_instruction())
+
+        print("\nInput Model:")
+        input_model = judge_class.get_input_model()
+        if input_model is not str and hasattr(input_model, 'model_json_schema'):
+            print(input_model.model_json_schema())  # type: ignore
+        else:
+            print("String input model")
+
+        print("\nOutput Model:")
+        output_model = judge_class.get_output_model()
+        if output_model is not str and hasattr(output_model, 'model_json_schema'):
+            print(output_model.model_json_schema())  # type: ignore
+        else:
+            print("String output model")
+
+        # Test instance usage
+        judge = judge_class()
+        input_text = judge.format_input(
+            context="The sky is blue during daytime.",
+            question="What color is the sky?",
+            answer="Blue"
+        )
+        print("\nFormatted Input:")
+        print(input_text)
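These Annotated-style signatures are what the ChainOfThought wrapper added above in llm_as_a_judge.py consumes. A rough sketch of the intended end-to-end flow, with the context/question strings purely illustrative and the model configuration again left to the LLMJudgeBase/LLMTask kwargs:

from typing import Annotated
from llm_utils.lm.signature import Signature, Input, Output
from llm_utils.lm.llm_as_a_judge import ChainOfThought

class FactJudge(Signature):
    """Judge if the answer is factually correct based on the context."""
    context: Annotated[str, Input("Context for the prediction")]
    question: Annotated[str, Input("Question to be answered")]
    answer: Annotated[str, Input("Answer for the question")]
    factually_correct: Annotated[bool, Output("Is the answer factually correct?")]

cot = ChainOfThought(FactJudge)  # builds an LLMJudgeBase whose instruction comes from the signature
verdict = cot(
    context="The sky is blue during daytime.",
    question="What color is the sky?",
    answer="Blue",
)
# `verdict` is the parsed result; with output fields defined it should be the dynamically built FactJudgeOutput model.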
{speedy_utils-1.1.22.dist-info → speedy_utils-1.1.23.dist-info}/RECORD
CHANGED

@@ -1,15 +1,16 @@
-llm_utils/__init__.py,sha256=
+llm_utils/__init__.py,sha256=gUmdXk6DF7dTKgYr23LdnkXLNT6x8bUZoacKQJ9pi8I,1625
 llm_utils/group_messages.py,sha256=Oe2tlhg-zRodG1-hodYebddrR77j9UdE05LzJw0EvYI,3622
 llm_utils/chat_format/__init__.py,sha256=8dBIUqFJvkgQYedxBtcyxt-4tt8JxAKVap2JlTXmgaM,737
 llm_utils/chat_format/display.py,sha256=3jKDm4OTrvytK1qBhSOjRLltUIObHsYFdBLgm8SVDE8,14159
 llm_utils/chat_format/transform.py,sha256=eU0c3PdAHCNLuGP1UqPwln0B34Lv3bt_uV9v9BrlCN4,5402
 llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
-llm_utils/lm/__init__.py,sha256=
+llm_utils/lm/__init__.py,sha256=znjUTzke2tmCRWkR46sbOQPcRNe5oEbLo5zqg6Vxud0,632
 llm_utils/lm/base_prompt_builder.py,sha256=OLqyxbA8QeYIVFzB9EqxUiE_P2p4_MD_Lq4WSwxFtKU,12136
+llm_utils/lm/llm_as_a_judge.py,sha256=LwqzlIMSBbpv6A2Qq8-fhVO2CGO7_BtU6j0PXoIWFOA,14022
 llm_utils/lm/llm_task.py,sha256=gawOtoP-LOH-8iaUI_2-TXvhFAVuj2fWGxeQZ1xytAo,25288
-llm_utils/lm/lm.py,sha256=8TaLuU7naPQbOFmiS2NQyWVLG0jUUzRRBQsR0In7GVo,7249
 llm_utils/lm/lm_base.py,sha256=pqbHZOdR7yUMpvwt8uBG1dZnt76SY_Wk8BkXQQ-mpWs,9557
 llm_utils/lm/openai_memoize.py,sha256=KToCcB_rhyrULxolnwMfQgl5GNrAeykePxuLS4hBjtc,3442
+llm_utils/lm/signature.py,sha256=s3Zjxjs6AU97jbz1LZ2BTGw-F9aFCF3G346gSMtyEpE,10370
 llm_utils/lm/utils.py,sha256=25oOznZhbBWfen1-X1PXQfO09kQZgP5V9CDuqLrf_ZU,12440
 llm_utils/lm/async_lm/__init__.py,sha256=PUBbCuf5u6-0GBUu-2PI6YAguzsyXj-LPkU6vccqT6E,121
 llm_utils/lm/async_lm/_utils.py,sha256=P1-pUDf_0pDmo8WTIi43t5ARlyGA1RIJfpAhz-gfA5g,6105

@@ -44,7 +45,7 @@ speedy_utils/multi_worker/thread.py,sha256=UniMl8mw-Xw1y3aU9bKGMtBSlQj05QhFouWb5
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=IvywP7Y0_V6tWfMP-4MjPvN5_KfxWF21xaLJsCIayCk,3821
 speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
-speedy_utils-1.1.
-speedy_utils-1.1.
-speedy_utils-1.1.
-speedy_utils-1.1.
+speedy_utils-1.1.23.dist-info/METADATA,sha256=xmQbqlIBS8fHw9ZZFG2kLftUiZX37cCt0hMmFc7putk,8028
+speedy_utils-1.1.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+speedy_utils-1.1.23.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
+speedy_utils-1.1.23.dist-info/RECORD,,
llm_utils/lm/lm.py
DELETED

@@ -1,207 +0,0 @@
-# # from ._utils import *
-# from typing import (
-#     Any,
-#     List,
-#     Literal,
-#     Optional,
-#     Type,
-#     Union,
-#     cast,
-# )
-
-# from loguru import logger
-# from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
-# from pydantic import BaseModel
-# from speedy_utils import jloads
-
-# # from llm_utils.lm.async_lm.async_llm_task import OutputModelType
-# from llm_utils.lm.lm_base import LMBase
-
-# from .async_lm._utils import (
-#     LegacyMsgs,
-#     Messages,
-#     OutputModelType,
-#     ParsedOutput,
-#     RawMsgs,
-# )
-
-
-# class LM(LMBase):
-#     """Unified **sync** language‑model wrapper with optional JSON parsing."""
-
-#     def __init__(
-#         self,
-#         *,
-#         model: Optional[str] = None,
-#         response_model: Optional[type[BaseModel]] = None,
-#         temperature: float = 0.0,
-#         max_tokens: int = 2_000,
-#         base_url: Optional[str] = None,
-#         api_key: Optional[str] = None,
-#         cache: bool = True,
-#         ports: Optional[List[int]] = None,
-#         top_p: float = 1.0,
-#         presence_penalty: float = 0.0,
-#         top_k: int = 1,
-#         repetition_penalty: float = 1.0,
-#         frequency_penalty: Optional[float] = None,
-#     ) -> None:
-
-#         if model is None:
-#             if base_url is None:
-#                 raise ValueError("Either model or base_url must be provided")
-#             models = OpenAI(base_url=base_url, api_key=api_key or 'abc').models.list().data
-#             assert len(models) == 1, f"Found {len(models)} models, please specify one."
-#             model = models[0].id
-#             print(f"Using model: {model}")
-
-#         super().__init__(
-#             ports=ports,
-#             base_url=base_url,
-#             cache=cache,
-#             api_key=api_key,
-#         )
-
-#         # Model behavior options
-#         self.response_model = response_model
-
-#         # Store all model-related parameters in model_kwargs
-#         self.model_kwargs = dict(
-#             model=model,
-#             temperature=temperature,
-#             max_tokens=max_tokens,
-#             top_p=top_p,
-#             presence_penalty=presence_penalty,
-#         )
-#         self.extra_body = dict(
-#             top_k=top_k,
-#             repetition_penalty=repetition_penalty,
-#             frequency_penalty=frequency_penalty,
-#         )
-
-#     def _unified_client_call(
-#         self,
-#         messages: RawMsgs,
-#         extra_body: Optional[dict] = None,
-#         max_tokens: Optional[int] = None,
-#     ) -> dict:
-#         """Unified method for all client interactions (caching handled by MOpenAI)."""
-#         converted_messages: Messages = (
-#             self._convert_messages(cast(LegacyMsgs, messages))
-#             if messages and isinstance(messages[0], dict)
-#             else cast(Messages, messages)
-#         )
-#         if max_tokens is not None:
-#             self.model_kwargs["max_tokens"] = max_tokens
-
-#         try:
-#             # Get completion from API (caching handled by MOpenAI)
-#             call_kwargs = {
-#                 "messages": converted_messages,
-#                 **self.model_kwargs,
-#             }
-#             if extra_body:
-#                 call_kwargs["extra_body"] = extra_body
-
-#             completion = self.client.chat.completions.create(**call_kwargs)
-
-#             if hasattr(completion, "model_dump"):
-#                 completion = completion.model_dump()
-
-#         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
-#             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
-#             logger.error(error_msg)
-#             raise
-
-#         return completion
-
-#     def __call__(
-#         self,
-#         prompt: Optional[str] = None,
-#         messages: Optional[RawMsgs] = None,
-#         max_tokens: Optional[int] = None,
-#     ):  # -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...:# -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...:
-#         """Unified sync call for language model, returns (assistant_message.model_dump(), messages)."""
-#         if (prompt is None) == (messages is None):
-#             raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
-
-#         if prompt is not None:
-#             messages = [{"role": "user", "content": prompt}]
-
-#         assert messages is not None
-
-#         openai_msgs: Messages = (
-#             self._convert_messages(cast(LegacyMsgs, messages))
-#             if isinstance(messages[0], dict)
-#             else cast(Messages, messages)
-#         )
-
-#         assert self.model_kwargs["model"] is not None, (
-#             "Model must be set before making a call."
-#         )
-
-#         # Use unified client call
-#         raw_response = self._unified_client_call(
-#             list(openai_msgs), max_tokens=max_tokens
-#         )
-
-#         if hasattr(raw_response, "model_dump"):
-#             raw_response = raw_response.model_dump()  # type: ignore
-
-#         # Extract the assistant's message
-#         assistant_msg = raw_response["choices"][0]["message"]
-#         # Build the full messages list (input + assistant reply)
-#         full_messages = list(messages) + [
-#             {"role": assistant_msg["role"], "content": assistant_msg["content"]}
-#         ]
-#         # Return the OpenAI message as model_dump (if available) and the messages list
-#         if hasattr(assistant_msg, "model_dump"):
-#             msg_dump = assistant_msg.model_dump()
-#         else:
-#             msg_dump = dict(assistant_msg)
-#         return msg_dump, full_messages
-
-#     def parse(
-#         self,
-#         messages: Messages,
-#         response_model: Optional[type[BaseModel]] = None,
-#     ) -> ParsedOutput[BaseModel]:
-#         """Parse response using OpenAI's native parse API."""
-#         # Use provided response_model or fall back to instance default
-#         model_to_use = response_model or self.response_model
-#         assert model_to_use is not None, "response_model must be provided or set at init."
-
-#         # Use OpenAI's native parse API directly
-#         response = self.client.chat.completions.parse(
-#             model=self.model_kwargs["model"],
-#             messages=messages,
-#             response_format=model_to_use,
-#             **{k: v for k, v in self.model_kwargs.items() if k != "model"}
-#         )
-
-#         parsed = response.choices[0].message.parsed
-#         completion = response.model_dump() if hasattr(response, "model_dump") else {}
-#         full_messages = list(messages) + [
-#             {"role": "assistant", "content": parsed}
-#         ]
-
-#         return ParsedOutput(
-#             messages=full_messages,
-#             parsed=cast(BaseModel, parsed),
-#             completion=completion,
-#             model_kwargs=self.model_kwargs,
-#         )
-
-
-
-#     def __enter__(self):
-#         return self
-
-#     def __exit__(self, exc_type, exc_val, exc_tb):
-#         if hasattr(self, "_last_client"):
-#             last_client = self._last_client  # type: ignore
-#             if hasattr(last_client, "close"):
-#                 last_client.close()
-#         else:
-#             logger.warning("No last client to close")
-LM = None

{speedy_utils-1.1.22.dist-info → speedy_utils-1.1.23.dist-info}/WHEEL
File without changes

{speedy_utils-1.1.22.dist-info → speedy_utils-1.1.23.dist-info}/entry_points.txt
File without changes