symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +269 -173
- symai/backend/base.py +123 -110
- symai/backend/engines/drawing/engine_bfl.py +45 -44
- symai/backend/engines/drawing/engine_gpt_image.py +112 -97
- symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
- symai/backend/engines/embedding/engine_openai.py +25 -21
- symai/backend/engines/execute/engine_python.py +19 -18
- symai/backend/engines/files/engine_io.py +104 -95
- symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
- symai/backend/engines/index/engine_pinecone.py +124 -97
- symai/backend/engines/index/engine_qdrant.py +1011 -0
- symai/backend/engines/index/engine_vectordb.py +84 -56
- symai/backend/engines/lean/engine_lean4.py +96 -52
- symai/backend/engines/neurosymbolic/__init__.py +41 -13
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
- symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
- symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
- symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
- symai/backend/engines/ocr/engine_apilayer.py +23 -27
- symai/backend/engines/output/engine_stdout.py +10 -13
- symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
- symai/backend/engines/search/engine_openai.py +100 -88
- symai/backend/engines/search/engine_parallel.py +665 -0
- symai/backend/engines/search/engine_perplexity.py +44 -45
- symai/backend/engines/search/engine_serpapi.py +37 -34
- symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
- symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
- symai/backend/engines/text_to_speech/engine_openai.py +20 -26
- symai/backend/engines/text_vision/engine_clip.py +39 -37
- symai/backend/engines/userinput/engine_console.py +5 -6
- symai/backend/mixin/__init__.py +13 -0
- symai/backend/mixin/anthropic.py +48 -38
- symai/backend/mixin/deepseek.py +6 -5
- symai/backend/mixin/google.py +7 -4
- symai/backend/mixin/groq.py +2 -4
- symai/backend/mixin/openai.py +140 -110
- symai/backend/settings.py +87 -20
- symai/chat.py +216 -123
- symai/collect/__init__.py +7 -1
- symai/collect/dynamic.py +80 -70
- symai/collect/pipeline.py +67 -51
- symai/collect/stats.py +161 -109
- symai/components.py +707 -360
- symai/constraints.py +24 -12
- symai/core.py +1857 -1233
- symai/core_ext.py +83 -80
- symai/endpoints/api.py +166 -104
- symai/extended/.DS_Store +0 -0
- symai/extended/__init__.py +46 -12
- symai/extended/api_builder.py +29 -21
- symai/extended/arxiv_pdf_parser.py +23 -14
- symai/extended/bibtex_parser.py +9 -6
- symai/extended/conversation.py +156 -126
- symai/extended/document.py +50 -30
- symai/extended/file_merger.py +57 -14
- symai/extended/graph.py +51 -32
- symai/extended/html_style_template.py +18 -14
- symai/extended/interfaces/blip_2.py +2 -3
- symai/extended/interfaces/clip.py +4 -3
- symai/extended/interfaces/console.py +9 -1
- symai/extended/interfaces/dall_e.py +4 -2
- symai/extended/interfaces/file.py +2 -0
- symai/extended/interfaces/flux.py +4 -2
- symai/extended/interfaces/gpt_image.py +16 -7
- symai/extended/interfaces/input.py +2 -1
- symai/extended/interfaces/llava.py +1 -2
- symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
- symai/extended/interfaces/naive_vectordb.py +9 -10
- symai/extended/interfaces/ocr.py +5 -3
- symai/extended/interfaces/openai_search.py +2 -0
- symai/extended/interfaces/parallel.py +30 -0
- symai/extended/interfaces/perplexity.py +2 -0
- symai/extended/interfaces/pinecone.py +12 -9
- symai/extended/interfaces/python.py +2 -0
- symai/extended/interfaces/serpapi.py +3 -1
- symai/extended/interfaces/terminal.py +2 -4
- symai/extended/interfaces/tts.py +3 -2
- symai/extended/interfaces/whisper.py +3 -2
- symai/extended/interfaces/wolframalpha.py +2 -1
- symai/extended/metrics/__init__.py +11 -1
- symai/extended/metrics/similarity.py +14 -13
- symai/extended/os_command.py +39 -29
- symai/extended/packages/__init__.py +29 -3
- symai/extended/packages/symdev.py +51 -43
- symai/extended/packages/sympkg.py +41 -35
- symai/extended/packages/symrun.py +63 -50
- symai/extended/repo_cloner.py +14 -12
- symai/extended/seo_query_optimizer.py +15 -13
- symai/extended/solver.py +116 -91
- symai/extended/summarizer.py +12 -10
- symai/extended/taypan_interpreter.py +17 -18
- symai/extended/vectordb.py +122 -92
- symai/formatter/__init__.py +9 -1
- symai/formatter/formatter.py +51 -47
- symai/formatter/regex.py +70 -69
- symai/functional.py +325 -176
- symai/imports.py +190 -147
- symai/interfaces.py +57 -28
- symai/memory.py +45 -35
- symai/menu/screen.py +28 -19
- symai/misc/console.py +66 -56
- symai/misc/loader.py +8 -5
- symai/models/__init__.py +17 -1
- symai/models/base.py +395 -236
- symai/models/errors.py +1 -2
- symai/ops/__init__.py +32 -22
- symai/ops/measures.py +24 -25
- symai/ops/primitives.py +1149 -731
- symai/post_processors.py +58 -50
- symai/pre_processors.py +86 -82
- symai/processor.py +21 -13
- symai/prompts.py +764 -685
- symai/server/huggingface_server.py +135 -49
- symai/server/llama_cpp_server.py +21 -11
- symai/server/qdrant_server.py +206 -0
- symai/shell.py +100 -42
- symai/shellsv.py +700 -492
- symai/strategy.py +630 -346
- symai/symbol.py +368 -322
- symai/utils.py +100 -78
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
- symbolicai-1.1.0.dist-info/RECORD +168 -0
- symbolicai-0.21.0.dist-info/RECORD +0 -162
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/strategy.py
CHANGED
|
@@ -2,7 +2,8 @@ import inspect
|
|
|
2
2
|
import logging
|
|
3
3
|
import time
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any, ClassVar
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
from beartype import beartype
|
|
@@ -14,9 +15,9 @@ from rich.panel import Panel
|
|
|
14
15
|
from rich.table import Table
|
|
15
16
|
|
|
16
17
|
from .components import Function
|
|
17
|
-
from .models import
|
|
18
|
-
build_dynamic_llm_datamodel)
|
|
18
|
+
from .models import LLMDataModel, TypeValidationError, build_dynamic_llm_datamodel
|
|
19
19
|
from .symbol import Expression
|
|
20
|
+
from .utils import UserMessage
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class ValidationFunction(Function):
|
|
@@ -28,19 +29,20 @@ class ValidationFunction(Function):
|
|
|
28
29
|
• Pause/backoff logic
|
|
29
30
|
• Error simplification
|
|
30
31
|
"""
|
|
32
|
+
|
|
31
33
|
# Have some default retry params that don't add overhead
|
|
32
|
-
_default_retry_params =
|
|
33
|
-
tries
|
|
34
|
-
delay
|
|
35
|
-
backoff
|
|
36
|
-
jitter
|
|
37
|
-
max_delay
|
|
38
|
-
graceful
|
|
39
|
-
|
|
34
|
+
_default_retry_params: ClassVar[dict[str, int | float | bool]] = {
|
|
35
|
+
"tries": 8,
|
|
36
|
+
"delay": 0.015,
|
|
37
|
+
"backoff": 1.25,
|
|
38
|
+
"jitter": 0.0,
|
|
39
|
+
"max_delay": 0.25,
|
|
40
|
+
"graceful": False,
|
|
41
|
+
}
|
|
40
42
|
|
|
41
43
|
def __init__(
|
|
42
44
|
self,
|
|
43
|
-
retry_params: dict[str, int | float | bool] = None,
|
|
45
|
+
retry_params: dict[str, int | float | bool] | None = None,
|
|
44
46
|
verbose: bool = False,
|
|
45
47
|
*args,
|
|
46
48
|
**kwargs,
|
|
@@ -92,10 +94,7 @@ class ValidationFunction(Function):
|
|
|
92
94
|
seed = 42
|
|
93
95
|
|
|
94
96
|
rnd = np.random.RandomState(seed=seed)
|
|
95
|
-
|
|
96
|
-
0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16
|
|
97
|
-
).tolist()
|
|
98
|
-
return seeds
|
|
97
|
+
return rnd.randint(0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16).tolist()
|
|
99
98
|
|
|
100
99
|
def simplify_validation_errors(self, error: ValidationError) -> str:
|
|
101
100
|
"""
|
|
@@ -123,21 +122,25 @@ class ValidationFunction(Function):
|
|
|
123
122
|
return "\n".join(simplified_errors)
|
|
124
123
|
|
|
125
124
|
def _pause(self, attempt):
|
|
126
|
-
base = self.retry_params[
|
|
127
|
-
jit = (
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
125
|
+
base = self.retry_params["delay"] * (self.retry_params["backoff"] ** attempt)
|
|
126
|
+
jit = (
|
|
127
|
+
np.random.uniform(*self.retry_params["jitter"])
|
|
128
|
+
if isinstance(self.retry_params["jitter"], tuple)
|
|
129
|
+
else self.retry_params["jitter"]
|
|
130
|
+
)
|
|
131
|
+
_delay = min(base + jit, self.retry_params["max_delay"])
|
|
131
132
|
time.sleep(_delay)
|
|
132
133
|
|
|
133
|
-
def remedy_prompt(self, *
|
|
134
|
+
def remedy_prompt(self, *_args, **_kwargs):
|
|
134
135
|
"""
|
|
135
136
|
Abstract or base remedy prompt method.
|
|
136
137
|
Child classes typically override this to include additional context needed for correction.
|
|
137
138
|
"""
|
|
138
|
-
|
|
139
|
+
msg = "Each child class needs its own remedy_prompt implementation."
|
|
140
|
+
UserMessage(msg)
|
|
141
|
+
raise NotImplementedError(msg)
|
|
139
142
|
|
|
140
|
-
def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1,2)):
|
|
143
|
+
def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1, 2)):
|
|
141
144
|
"""
|
|
142
145
|
Display content in a rich panel with consistent formatting.
|
|
143
146
|
|
|
@@ -149,11 +152,13 @@ class ValidationFunction(Function):
|
|
|
149
152
|
padding: Padding for the panel (default: (1,2))
|
|
150
153
|
"""
|
|
151
154
|
body = escape(content)
|
|
152
|
-
panel = Panel.fit(
|
|
155
|
+
panel = Panel.fit(
|
|
156
|
+
body, title=title, padding=padding, border_style=border_style, style=style
|
|
157
|
+
)
|
|
153
158
|
self.console.print(panel)
|
|
154
159
|
|
|
155
160
|
def forward(self, *args, **kwargs):
|
|
156
|
-
return super().forward(*args, **kwargs)
|
|
161
|
+
return super().forward(*args, **kwargs) # Just propagate to Function
|
|
157
162
|
|
|
158
163
|
|
|
159
164
|
class TypeValidationFunction(ValidationFunction):
|
|
@@ -163,6 +168,7 @@ class TypeValidationFunction(ValidationFunction):
|
|
|
163
168
|
if a user provides a callable designed to semantically validate the
|
|
164
169
|
structure of the type-validated data.
|
|
165
170
|
"""
|
|
171
|
+
|
|
166
172
|
def __init__(
|
|
167
173
|
self,
|
|
168
174
|
retry_params: dict[str, int | float | bool] = ValidationFunction._default_retry_params,
|
|
@@ -171,21 +177,29 @@ class TypeValidationFunction(ValidationFunction):
|
|
|
171
177
|
*args,
|
|
172
178
|
**kwargs,
|
|
173
179
|
):
|
|
174
|
-
super().__init__(retry_params=retry_params, verbose=verbose,
|
|
180
|
+
super().__init__(*args, retry_params=retry_params, verbose=verbose, **kwargs)
|
|
175
181
|
self.input_data_model = None
|
|
176
182
|
self.output_data_model = None
|
|
177
183
|
self.accumulate_errors = accumulate_errors
|
|
178
184
|
self.verbose = verbose
|
|
179
185
|
|
|
180
|
-
def register_expected_data_model(
|
|
181
|
-
|
|
186
|
+
def register_expected_data_model(
|
|
187
|
+
self, data_model: LLMDataModel, attach_to: str, override: bool = False
|
|
188
|
+
):
|
|
189
|
+
assert attach_to in ["input", "output"], (
|
|
190
|
+
f"Invalid attach_to value: {attach_to}; must be either 'input' or 'output'"
|
|
191
|
+
)
|
|
182
192
|
if attach_to == "input":
|
|
183
193
|
if self.input_data_model is not None and not override:
|
|
184
|
-
|
|
194
|
+
msg = "There is already a data model attached to the input. If you want to override it, set `override=True`."
|
|
195
|
+
UserMessage(msg)
|
|
196
|
+
raise ValueError(msg)
|
|
185
197
|
self.input_data_model = data_model
|
|
186
198
|
elif attach_to == "output":
|
|
187
199
|
if self.output_data_model is not None and not override:
|
|
188
|
-
|
|
200
|
+
msg = "There is already a data model attached to the output. If you want to override it, set `override=True`."
|
|
201
|
+
UserMessage(msg)
|
|
202
|
+
raise ValueError(msg)
|
|
189
203
|
self.output_data_model = data_model
|
|
190
204
|
|
|
191
205
|
def remedy_prompt(self, prompt: str, output: str, errors: str) -> str:
|
|
@@ -200,12 +214,12 @@ Your prompt was:
|
|
|
200
214
|
|
|
201
215
|
The input data model is:
|
|
202
216
|
<input_data_model>
|
|
203
|
-
{self.input_data_model.simplify_json_schema() if self.input_data_model is not None else
|
|
217
|
+
{self.input_data_model.simplify_json_schema() if self.input_data_model is not None else "N/A"}
|
|
204
218
|
</input_data_model>
|
|
205
219
|
|
|
206
220
|
The given input was:
|
|
207
221
|
<input>
|
|
208
|
-
{str(self.input_data_model) if self.input_data_model is not None else
|
|
222
|
+
{str(self.input_data_model) if self.input_data_model is not None else "N/A"}
|
|
209
223
|
</input>
|
|
210
224
|
|
|
211
225
|
The output data model is:
|
|
@@ -247,12 +261,12 @@ You are given the following prompt:
|
|
|
247
261
|
|
|
248
262
|
The input data model is:
|
|
249
263
|
<input_data_model>
|
|
250
|
-
{self.input_data_model.simplify_json_schema() if self.input_data_model is not None else
|
|
264
|
+
{self.input_data_model.simplify_json_schema() if self.input_data_model is not None else "N/A"}
|
|
251
265
|
</input_data_model>
|
|
252
266
|
|
|
253
267
|
The given input is:
|
|
254
268
|
<input>
|
|
255
|
-
{str(self.input_data_model) if self.input_data_model is not None else
|
|
269
|
+
{str(self.input_data_model) if self.input_data_model is not None else "N/A"}
|
|
256
270
|
</input>
|
|
257
271
|
|
|
258
272
|
The output data model is:
|
|
@@ -268,118 +282,185 @@ Important guidelines:
|
|
|
268
282
|
</guidelines>
|
|
269
283
|
"""
|
|
270
284
|
|
|
271
|
-
def
|
|
285
|
+
def _ensure_output_model(self):
|
|
272
286
|
if self.output_data_model is None:
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
287
|
+
msg = (
|
|
288
|
+
"While the input data model is optional, the output data model must be provided. "
|
|
289
|
+
"Please register it before calling the `forward` method."
|
|
290
|
+
)
|
|
291
|
+
UserMessage(msg)
|
|
292
|
+
raise ValueError(msg)
|
|
293
|
+
|
|
294
|
+
def _display_verbose_panels(self, prompt: str):
|
|
295
|
+
if not self.verbose:
|
|
296
|
+
return
|
|
297
|
+
for label, body in [
|
|
298
|
+
("Prompt", prompt),
|
|
299
|
+
(
|
|
300
|
+
"Input data model",
|
|
301
|
+
self.input_data_model.simplify_json_schema() if self.input_data_model else "N/A",
|
|
302
|
+
),
|
|
303
|
+
("Output data model", self.output_data_model.simplify_json_schema()),
|
|
304
|
+
]:
|
|
305
|
+
self.display_panel(body, title=label)
|
|
306
|
+
|
|
307
|
+
def _check_semantic_conditions(
|
|
308
|
+
self,
|
|
309
|
+
result,
|
|
310
|
+
f_semantic_conditions: list[Callable] | None,
|
|
311
|
+
) -> str | None:
|
|
312
|
+
if f_semantic_conditions is None:
|
|
313
|
+
return None
|
|
314
|
+
try:
|
|
315
|
+
assert all(
|
|
316
|
+
f(
|
|
317
|
+
result
|
|
318
|
+
if not getattr(self.output_data_model, "_is_dynamic_model", False)
|
|
319
|
+
else result.value
|
|
320
|
+
)
|
|
321
|
+
for f in f_semantic_conditions
|
|
322
|
+
)
|
|
323
|
+
except Exception as err:
|
|
324
|
+
return f"Semantic validation failed with:\n{err!s}"
|
|
325
|
+
return None
|
|
326
|
+
|
|
327
|
+
def _format_validation_error(self, error: Exception) -> str:
|
|
328
|
+
if isinstance(error, ValidationError):
|
|
329
|
+
return self.simplify_validation_errors(error)
|
|
330
|
+
return str(error)
|
|
331
|
+
|
|
332
|
+
def _handle_failed_validation_attempt(
|
|
333
|
+
self,
|
|
334
|
+
attempt_index: int,
|
|
335
|
+
prompt: str,
|
|
336
|
+
json_str: str,
|
|
337
|
+
errors: list[str],
|
|
338
|
+
error: Exception,
|
|
339
|
+
remedy_seeds: list[Any],
|
|
340
|
+
kwargs: dict,
|
|
341
|
+
) -> str:
|
|
342
|
+
logger.info(f"Validation attempt {attempt_index + 1} failed, pausing before retry…")
|
|
343
|
+
self._pause(attempt_index)
|
|
344
|
+
error_str = self._format_validation_error(error)
|
|
345
|
+
errors.append(error_str)
|
|
346
|
+
|
|
347
|
+
logger.error("Validation errors identified!")
|
|
278
348
|
if self.verbose:
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
349
|
+
errors_report = "\n".join(errors) if self.accumulate_errors else error_str
|
|
350
|
+
title = f"Validation Errors ({'accumulated errors' if self.accumulate_errors else 'last error'})"
|
|
351
|
+
self.display_panel(errors_report, title=title, border_style="red")
|
|
352
|
+
|
|
353
|
+
logger.info("Updating remedy function context…")
|
|
354
|
+
context = self.remedy_prompt(
|
|
355
|
+
prompt=prompt,
|
|
356
|
+
output=json_str,
|
|
357
|
+
errors="\n".join(errors) if self.accumulate_errors else error_str,
|
|
358
|
+
)
|
|
359
|
+
self.remedy_function.clear()
|
|
360
|
+
self.remedy_function.adapt(context)
|
|
361
|
+
if self.verbose:
|
|
362
|
+
self.display_panel(self.remedy_function.dynamic_context, title="New Context")
|
|
289
363
|
|
|
290
|
-
|
|
291
|
-
logger.info(
|
|
364
|
+
json_str = self.remedy_function(seed=remedy_seeds[attempt_index], **kwargs).value
|
|
365
|
+
logger.info("Applied remedy function with updated context!")
|
|
366
|
+
return json_str
|
|
292
367
|
|
|
368
|
+
def _run_validation_attempts(
|
|
369
|
+
self,
|
|
370
|
+
prompt: str,
|
|
371
|
+
f_semantic_conditions: list[Callable] | None,
|
|
372
|
+
validation_context: dict,
|
|
373
|
+
remedy_seeds: list[Any],
|
|
374
|
+
json_str: str,
|
|
375
|
+
kwargs: dict,
|
|
376
|
+
) -> tuple[Any | None, str, list[str]]:
|
|
377
|
+
errors: list[str] = []
|
|
293
378
|
result = None
|
|
294
|
-
|
|
295
|
-
for
|
|
296
|
-
if
|
|
297
|
-
logger.info(
|
|
379
|
+
total_attempts = self.retry_params["tries"] + 1
|
|
380
|
+
for attempt in range(total_attempts):
|
|
381
|
+
if attempt != self.retry_params["tries"]:
|
|
382
|
+
logger.info(
|
|
383
|
+
f"Attempt {attempt + 1}/{self.retry_params['tries']}: Attempting validation…"
|
|
384
|
+
)
|
|
298
385
|
try:
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
for f in f_semantic_conditions
|
|
312
|
-
)
|
|
313
|
-
except Exception as e:
|
|
314
|
-
# If we are in the last attempt and semantic validation fails, result will be None and we propagate the error
|
|
315
|
-
if i == self.retry_params["tries"]:
|
|
316
|
-
result = None
|
|
317
|
-
errors.append(f"Semantic validation failed with:\n{str(e)}")
|
|
318
|
-
break # We break to avoid going into the remedy loop
|
|
319
|
-
raise e
|
|
386
|
+
result = self.output_data_model.model_validate_json(
|
|
387
|
+
json_str,
|
|
388
|
+
strict=False,
|
|
389
|
+
context=validation_context,
|
|
390
|
+
)
|
|
391
|
+
semantic_error = self._check_semantic_conditions(result, f_semantic_conditions)
|
|
392
|
+
if semantic_error is not None:
|
|
393
|
+
if attempt == total_attempts - 1:
|
|
394
|
+
result = None
|
|
395
|
+
errors.append(semantic_error)
|
|
396
|
+
break
|
|
397
|
+
raise AssertionError(semantic_error)
|
|
320
398
|
break
|
|
321
|
-
except Exception as
|
|
322
|
-
|
|
399
|
+
except Exception as error:
|
|
400
|
+
json_str = self._handle_failed_validation_attempt(
|
|
401
|
+
attempt,
|
|
402
|
+
prompt,
|
|
403
|
+
json_str,
|
|
404
|
+
errors,
|
|
405
|
+
error,
|
|
406
|
+
remedy_seeds,
|
|
407
|
+
kwargs,
|
|
408
|
+
)
|
|
409
|
+
return result, json_str, errors
|
|
410
|
+
|
|
411
|
+
def _handle_validation_failure(self, prompt: str, json_str: str, errors: list[str]):
|
|
412
|
+
logger.error("All validation attempts failed!")
|
|
413
|
+
if self.retry_params["graceful"]:
|
|
414
|
+
return
|
|
415
|
+
raise TypeValidationError(
|
|
416
|
+
prompt=prompt,
|
|
417
|
+
result=json_str,
|
|
418
|
+
violations=errors,
|
|
419
|
+
)
|
|
323
420
|
|
|
324
|
-
|
|
421
|
+
def forward(
|
|
422
|
+
self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs
|
|
423
|
+
):
|
|
424
|
+
self._ensure_output_model()
|
|
425
|
+
validation_context = kwargs.pop("validation_context", {})
|
|
426
|
+
kwargs["response_format"] = {"type": "json_object"}
|
|
427
|
+
logger.info("Initializing validation…")
|
|
428
|
+
self._display_verbose_panels(prompt)
|
|
325
429
|
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
# Update remedy function context
|
|
342
|
-
logger.info("Updating remedy function context…")
|
|
343
|
-
context = self.remedy_prompt(prompt=prompt, output=json_str, errors="\n".join(errors) if self.accumulate_errors else error_str)
|
|
344
|
-
self.remedy_function.clear()
|
|
345
|
-
self.remedy_function.adapt(context)
|
|
346
|
-
if self.verbose:
|
|
347
|
-
self.display_panel(
|
|
348
|
-
self.remedy_function.dynamic_context,
|
|
349
|
-
title="New Context"
|
|
350
|
-
)
|
|
351
|
-
|
|
352
|
-
# Apply the remedy function
|
|
353
|
-
json_str = self.remedy_function(seed=remedy_seeds[i], **kwargs).value
|
|
354
|
-
logger.info("Applied remedy function with updated context!")
|
|
430
|
+
context = self.zero_shot_prompt(prompt=prompt)
|
|
431
|
+
json_str = super().forward(context, *args, **kwargs).value
|
|
432
|
+
|
|
433
|
+
remedy_seeds = self.prepare_seeds(self.retry_params["tries"] + 1, **kwargs)
|
|
434
|
+
logger.info(f"Prepared {len(remedy_seeds)} remedy seeds for validation attempts…")
|
|
435
|
+
|
|
436
|
+
result, json_str, errors = self._run_validation_attempts(
|
|
437
|
+
prompt,
|
|
438
|
+
f_semantic_conditions,
|
|
439
|
+
validation_context,
|
|
440
|
+
remedy_seeds,
|
|
441
|
+
json_str,
|
|
442
|
+
kwargs,
|
|
443
|
+
)
|
|
355
444
|
|
|
356
445
|
if result is None:
|
|
357
|
-
|
|
358
|
-
if self.retry_params['graceful']:
|
|
359
|
-
return
|
|
360
|
-
raise TypeValidationError(
|
|
361
|
-
prompt=prompt,
|
|
362
|
-
result=json_str,
|
|
363
|
-
violations=errors,
|
|
364
|
-
)
|
|
446
|
+
return self._handle_validation_failure(prompt, json_str, errors)
|
|
365
447
|
|
|
366
448
|
logger.success("Validation completed successfully!")
|
|
367
|
-
# Clear artifacts from the remedy function
|
|
368
449
|
self.remedy_function.clear()
|
|
369
|
-
|
|
370
450
|
return result
|
|
371
451
|
|
|
372
452
|
|
|
373
453
|
@beartype
|
|
374
454
|
class contract:
|
|
375
|
-
_default_remedy_retry_params =
|
|
376
|
-
tries
|
|
377
|
-
delay
|
|
378
|
-
backoff
|
|
379
|
-
jitter
|
|
380
|
-
max_delay
|
|
381
|
-
graceful
|
|
382
|
-
|
|
455
|
+
_default_remedy_retry_params: ClassVar[dict[str, int | float | bool]] = {
|
|
456
|
+
"tries": 8,
|
|
457
|
+
"delay": 0.015,
|
|
458
|
+
"backoff": 1.25,
|
|
459
|
+
"jitter": 0.0,
|
|
460
|
+
"max_delay": 0.25,
|
|
461
|
+
"graceful": False,
|
|
462
|
+
}
|
|
463
|
+
_internal_forward_kwargs: ClassVar[set[str]] = {"validation_context"}
|
|
383
464
|
|
|
384
465
|
def __init__(
|
|
385
466
|
self,
|
|
@@ -389,37 +470,47 @@ class contract:
|
|
|
389
470
|
verbose: bool = False,
|
|
390
471
|
remedy_retry_params: dict[str, int | float | bool] = _default_remedy_retry_params,
|
|
391
472
|
):
|
|
392
|
-
|
|
473
|
+
"""
|
|
393
474
|
A contract class decorator inspired by DbC principles. It ensures that the function's input and output
|
|
394
475
|
adhere to specified data models both syntactically and semantically. This implementation includes retry
|
|
395
476
|
logic to handle transient errors and gracefully handle failures.
|
|
396
|
-
|
|
477
|
+
"""
|
|
397
478
|
self.pre_remedy = pre_remedy
|
|
398
479
|
self.post_remedy = post_remedy
|
|
399
480
|
self.remedy_retry_params = remedy_retry_params
|
|
400
|
-
self.f_type_validation_remedy = TypeValidationFunction(
|
|
481
|
+
self.f_type_validation_remedy = TypeValidationFunction(
|
|
482
|
+
accumulate_errors=accumulate_errors, verbose=verbose, retry_params=remedy_retry_params
|
|
483
|
+
)
|
|
401
484
|
|
|
402
485
|
if not verbose:
|
|
403
486
|
logger.disable(__name__)
|
|
404
487
|
else:
|
|
405
488
|
logger.enable(__name__)
|
|
406
489
|
|
|
407
|
-
def _is_valid_input(self,
|
|
408
|
-
if
|
|
490
|
+
def _is_valid_input(self, input_value):
|
|
491
|
+
if input_value is None:
|
|
409
492
|
logger.error("No `input` argument provided!")
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
493
|
+
msg = "Please provide an `input` argument."
|
|
494
|
+
UserMessage(msg)
|
|
495
|
+
raise ValueError(msg)
|
|
496
|
+
if not isinstance(input_value, LLMDataModel):
|
|
497
|
+
logger.error(f"Invalid input type: {type(input_value)}")
|
|
498
|
+
msg = f"Expected input to be of type `LLMDataModel`, got {type(input_value)}"
|
|
499
|
+
UserMessage(msg)
|
|
500
|
+
raise TypeError(msg)
|
|
414
501
|
return True
|
|
415
502
|
|
|
416
503
|
def _is_valid_output(self, output_type):
|
|
417
504
|
if output_type == inspect._empty:
|
|
418
505
|
logger.error("Missing return type annotation!")
|
|
419
|
-
|
|
506
|
+
msg = "The contract requires a return type annotation."
|
|
507
|
+
UserMessage(msg)
|
|
508
|
+
raise ValueError(msg)
|
|
420
509
|
if not issubclass(output_type, LLMDataModel):
|
|
421
510
|
logger.error(f"Invalid return type: {output_type}")
|
|
422
|
-
|
|
511
|
+
msg = "The return type annotation must be a subclass of `LLMDataModel`."
|
|
512
|
+
UserMessage(msg)
|
|
513
|
+
raise TypeError(msg)
|
|
423
514
|
return True
|
|
424
515
|
|
|
425
516
|
def _try_dynamic_type_annotation(self, original_forward, *, context: str = "input"):
|
|
@@ -428,84 +519,122 @@ class contract:
|
|
|
428
519
|
)
|
|
429
520
|
sig = inspect.signature(original_forward)
|
|
430
521
|
try:
|
|
522
|
+
resolved_param = None
|
|
431
523
|
# Fallback: look at the relevant part of the function signature
|
|
432
524
|
# depending on whether we deal with an *input* or *output*
|
|
433
525
|
if context == "input":
|
|
434
526
|
param = sig.parameters.get("input")
|
|
527
|
+
if param is None:
|
|
528
|
+
for candidate in sig.parameters.values():
|
|
529
|
+
if candidate.name == "self":
|
|
530
|
+
continue
|
|
531
|
+
if candidate.kind in (
|
|
532
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
533
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
534
|
+
inspect.Parameter.KEYWORD_ONLY,
|
|
535
|
+
):
|
|
536
|
+
param = candidate
|
|
537
|
+
break
|
|
538
|
+
resolved_param = param
|
|
435
539
|
if param is None or param.annotation == inspect._empty:
|
|
436
|
-
|
|
540
|
+
msg = "Failed to infer type from input parameter annotation"
|
|
541
|
+
UserMessage(msg)
|
|
542
|
+
raise TypeError(msg)
|
|
437
543
|
dynamic_model = build_dynamic_llm_datamodel(param.annotation)
|
|
438
544
|
else: # context == "output"
|
|
439
545
|
if sig.return_annotation == inspect._empty:
|
|
440
|
-
|
|
546
|
+
msg = "Failed to infer type from return annotation"
|
|
547
|
+
UserMessage(msg)
|
|
548
|
+
raise TypeError(msg)
|
|
441
549
|
dynamic_model = build_dynamic_llm_datamodel(sig.return_annotation)
|
|
442
|
-
except Exception:
|
|
443
|
-
logger.exception(f"Failed to build dynamic LLMDataModel from {
|
|
444
|
-
|
|
550
|
+
except Exception as err:
|
|
551
|
+
logger.exception(f"Failed to build dynamic LLMDataModel from {resolved_param}!")
|
|
552
|
+
msg = (
|
|
445
553
|
"The type annotation must be a subclass of `LLMDataModel` or a "
|
|
446
554
|
"valid Python typing object supported by Pydantic."
|
|
447
555
|
)
|
|
556
|
+
UserMessage(msg)
|
|
557
|
+
raise TypeError(msg) from err
|
|
448
558
|
|
|
449
559
|
dynamic_model._is_dynamic_model = True
|
|
450
560
|
return dynamic_model
|
|
451
561
|
|
|
452
562
|
def _try_remedy_with_exception(self, prompt, f_semantic_conditions, **remedy_kwargs):
|
|
453
563
|
try:
|
|
454
|
-
data_model = self.f_type_validation_remedy(
|
|
564
|
+
data_model = self.f_type_validation_remedy(
|
|
565
|
+
prompt, f_semantic_conditions=f_semantic_conditions, **remedy_kwargs
|
|
566
|
+
)
|
|
455
567
|
except Exception as e:
|
|
456
|
-
logger.error(
|
|
568
|
+
logger.error("Type validation failed with exception!")
|
|
457
569
|
raise e
|
|
458
570
|
return data_model
|
|
459
571
|
|
|
460
|
-
def _validate_input(self, wrapped_self,
|
|
572
|
+
def _validate_input(self, wrapped_self, input_value, it, **remedy_kwargs):
|
|
461
573
|
logger.info("Starting input validation...")
|
|
462
574
|
if self.pre_remedy:
|
|
463
575
|
logger.info("Validating pre-conditions with remedy...")
|
|
464
|
-
if not hasattr(wrapped_self,
|
|
576
|
+
if not hasattr(wrapped_self, "pre"):
|
|
465
577
|
logger.error("Pre-condition function not defined!")
|
|
466
|
-
|
|
578
|
+
msg = "Pre-condition function not defined. Please define a `pre` method if you want to enforce pre-conditions through a remedy."
|
|
579
|
+
UserMessage(msg)
|
|
580
|
+
raise Exception(msg)
|
|
467
581
|
|
|
468
582
|
op_start = time.perf_counter()
|
|
469
583
|
try:
|
|
470
|
-
assert wrapped_self.pre(
|
|
584
|
+
assert wrapped_self.pre(input_value)
|
|
471
585
|
logger.success("Pre-condition validation successful!")
|
|
472
|
-
return
|
|
473
|
-
except Exception
|
|
586
|
+
return input_value
|
|
587
|
+
except Exception:
|
|
474
588
|
logger.exception("Pre-condition validation failed!")
|
|
475
|
-
self.f_type_validation_remedy.register_expected_data_model(
|
|
476
|
-
|
|
589
|
+
self.f_type_validation_remedy.register_expected_data_model(
|
|
590
|
+
input_value, attach_to="output", override=True
|
|
591
|
+
)
|
|
592
|
+
input_value = self._try_remedy_with_exception(
|
|
593
|
+
prompt=wrapped_self.prompt,
|
|
594
|
+
f_semantic_conditions=[wrapped_self.pre],
|
|
595
|
+
**remedy_kwargs,
|
|
596
|
+
)
|
|
477
597
|
finally:
|
|
478
|
-
wrapped_self._contract_timing[it]["input_validation"] =
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
598
|
+
wrapped_self._contract_timing[it]["input_validation"] = (
|
|
599
|
+
time.perf_counter() - op_start
|
|
600
|
+
)
|
|
601
|
+
return input_value
|
|
602
|
+
if hasattr(wrapped_self, "pre"):
|
|
603
|
+
logger.info("Validating pre-conditions without remedy...")
|
|
604
|
+
op_start = time.perf_counter()
|
|
605
|
+
try:
|
|
606
|
+
assert wrapped_self.pre(input_value)
|
|
607
|
+
except Exception as e:
|
|
608
|
+
logger.exception("Pre-condition validation failed")
|
|
609
|
+
raise e
|
|
610
|
+
finally:
|
|
611
|
+
wrapped_self._contract_timing[it]["input_validation"] = (
|
|
612
|
+
time.perf_counter() - op_start
|
|
613
|
+
)
|
|
614
|
+
logger.success("Pre-condition validation successful!")
|
|
615
|
+
return input_value
|
|
616
|
+
logger.info("Skip; no pre-condition validation was required!")
|
|
617
|
+
return input_value
|
|
495
618
|
|
|
496
|
-
def _validate_output(self, wrapped_self,
|
|
619
|
+
def _validate_output(self, wrapped_self, input_value, output, it, **remedy_kwargs):
|
|
497
620
|
logger.info("Starting output validation...")
|
|
498
|
-
self.f_type_validation_remedy.register_expected_data_model(
|
|
499
|
-
|
|
621
|
+
self.f_type_validation_remedy.register_expected_data_model(
|
|
622
|
+
input_value, attach_to="input", override=True
|
|
623
|
+
)
|
|
624
|
+
self.f_type_validation_remedy.register_expected_data_model(
|
|
625
|
+
output, attach_to="output", override=True
|
|
626
|
+
)
|
|
500
627
|
|
|
501
628
|
op_start = time.perf_counter()
|
|
502
629
|
try:
|
|
503
630
|
logger.info("Getting a valid output type...")
|
|
504
|
-
output = self._try_remedy_with_exception(
|
|
505
|
-
|
|
631
|
+
output = self._try_remedy_with_exception(
|
|
632
|
+
prompt=wrapped_self.prompt, f_semantic_conditions=None, **remedy_kwargs
|
|
633
|
+
)
|
|
634
|
+
if output is None: # output is None when graceful mode is enabled
|
|
506
635
|
return output
|
|
507
636
|
except Exception as e:
|
|
508
|
-
logger.exception(
|
|
637
|
+
logger.exception("Type creation failed!")
|
|
509
638
|
raise e
|
|
510
639
|
finally:
|
|
511
640
|
wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
|
|
@@ -515,88 +644,119 @@ class contract:
|
|
|
515
644
|
logger.info("Validating post-conditions with remedy...")
|
|
516
645
|
if not hasattr(wrapped_self, "post"):
|
|
517
646
|
logger.error("Post-condition function not defined!")
|
|
518
|
-
|
|
647
|
+
msg = "Post-condition function not defined. Please define a `post` method if you want to enforce post-conditions through a remedy."
|
|
648
|
+
UserMessage(msg)
|
|
649
|
+
raise Exception(msg)
|
|
519
650
|
|
|
520
651
|
op_start = time.perf_counter()
|
|
521
652
|
try:
|
|
522
653
|
assert wrapped_self.post(output)
|
|
523
654
|
logger.success("Post-condition validation successful!")
|
|
524
655
|
return output
|
|
656
|
+
except Exception:
|
|
657
|
+
logger.exception("Post-condition validation failed!")
|
|
658
|
+
output = self._try_remedy_with_exception(
|
|
659
|
+
prompt=wrapped_self.prompt,
|
|
660
|
+
f_semantic_conditions=[wrapped_self.post],
|
|
661
|
+
**remedy_kwargs,
|
|
662
|
+
)
|
|
663
|
+
finally:
|
|
664
|
+
wrapped_self._contract_timing[it]["output_validation"] += (
|
|
665
|
+
time.perf_counter() - op_start
|
|
666
|
+
)
|
|
667
|
+
logger.success("Post-condition validation successful!")
|
|
668
|
+
return output
|
|
669
|
+
if hasattr(wrapped_self, "post"):
|
|
670
|
+
logger.info("Validating post-conditions without remedy...")
|
|
671
|
+
op_start = time.perf_counter()
|
|
672
|
+
try:
|
|
673
|
+
assert wrapped_self.post(output)
|
|
525
674
|
except Exception as e:
|
|
526
675
|
logger.exception("Post-condition validation failed!")
|
|
527
|
-
|
|
676
|
+
raise e
|
|
528
677
|
finally:
|
|
529
|
-
wrapped_self._contract_timing[it]["output_validation"]
|
|
678
|
+
wrapped_self._contract_timing[it]["output_validation"] = (
|
|
679
|
+
time.perf_counter() - op_start
|
|
680
|
+
)
|
|
530
681
|
logger.success("Post-condition validation successful!")
|
|
531
682
|
return output
|
|
532
|
-
else:
|
|
533
|
-
if hasattr(wrapped_self, "post"):
|
|
534
|
-
logger.info("Validating post-conditions without remedy...")
|
|
535
|
-
op_start = time.perf_counter()
|
|
536
|
-
try:
|
|
537
|
-
assert wrapped_self.post(output)
|
|
538
|
-
except Exception as e:
|
|
539
|
-
logger.exception("Post-condition validation failed!")
|
|
540
|
-
raise e
|
|
541
|
-
finally:
|
|
542
|
-
wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
|
|
543
|
-
logger.success("Post-condition validation successful!")
|
|
544
|
-
return output
|
|
545
683
|
logger.info("Skip; no post-condition validation was required!")
|
|
546
684
|
return output
|
|
547
685
|
|
|
548
686
|
def _validate_act_method(self, act_method):
|
|
549
687
|
act_sig = inspect.signature(act_method)
|
|
550
688
|
params = list(act_sig.parameters.values())
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
689
|
+
|
|
690
|
+
first_param = None
|
|
691
|
+
for param in params:
|
|
692
|
+
if param.name == "self":
|
|
693
|
+
continue
|
|
694
|
+
if param.kind in (
|
|
695
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
696
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
697
|
+
):
|
|
698
|
+
first_param = param
|
|
699
|
+
break
|
|
700
|
+
|
|
701
|
+
if first_param is None:
|
|
702
|
+
msg = "'act' method must accept at least one positional parameter after `self`."
|
|
703
|
+
UserMessage(msg)
|
|
704
|
+
raise TypeError(msg)
|
|
705
|
+
if first_param.annotation == inspect._empty:
|
|
706
|
+
msg = f"'act' method parameter '{first_param.name}' must have a type annotation."
|
|
707
|
+
UserMessage(msg)
|
|
708
|
+
raise TypeError(msg)
|
|
555
709
|
if act_sig.return_annotation == inspect._empty:
|
|
556
|
-
|
|
710
|
+
msg = "'act' method must have a return type annotation'"
|
|
711
|
+
UserMessage(msg)
|
|
712
|
+
raise TypeError(msg)
|
|
557
713
|
return True
|
|
558
714
|
|
|
559
|
-
def _act(self, wrapped_self,
|
|
560
|
-
act_method = getattr(wrapped_self,
|
|
715
|
+
def _act(self, wrapped_self, input_value, it, **act_kwargs):
|
|
716
|
+
act_method = getattr(wrapped_self, "act", None)
|
|
561
717
|
if not callable(act_method):
|
|
562
718
|
# Propagate the input if no act method is defined
|
|
563
|
-
return
|
|
719
|
+
return input_value
|
|
564
720
|
|
|
565
721
|
assert self._validate_act_method(act_method)
|
|
566
722
|
|
|
567
|
-
is_dynamic_model = getattr(
|
|
568
|
-
|
|
723
|
+
is_dynamic_model = getattr(input_value, "_is_dynamic_model", False)
|
|
724
|
+
input_value = input_value if not is_dynamic_model else input_value.value
|
|
569
725
|
|
|
570
726
|
logger.info(f"Executing 'act' method on {wrapped_self.__class__.__name__}…")
|
|
571
727
|
|
|
572
728
|
op_start = time.perf_counter()
|
|
573
729
|
try:
|
|
574
|
-
act_output = act_method(
|
|
730
|
+
act_output = act_method(input_value, **act_kwargs)
|
|
575
731
|
except Exception as e:
|
|
576
|
-
logger.exception(
|
|
732
|
+
logger.exception("'act' method execution failed")
|
|
577
733
|
raise e
|
|
578
734
|
finally:
|
|
579
735
|
wrapped_self._contract_timing[it]["act_execution"] = time.perf_counter() - op_start
|
|
580
736
|
|
|
581
737
|
act_sig = inspect.signature(act_method)
|
|
582
|
-
if
|
|
583
|
-
|
|
584
|
-
|
|
738
|
+
if (
|
|
739
|
+
act_sig.return_annotation != inspect.Signature.empty
|
|
740
|
+
and inspect.isclass(act_sig.return_annotation)
|
|
741
|
+
and not isinstance(act_output, act_sig.return_annotation)
|
|
742
|
+
):
|
|
743
|
+
msg = f"'act' method returned {type(act_output).__name__}, expected {act_sig.return_annotation.__name__}."
|
|
744
|
+
UserMessage(msg)
|
|
745
|
+
raise TypeError(msg)
|
|
585
746
|
|
|
586
747
|
logger.success("'act' method executed successfully!")
|
|
587
748
|
return act_output
|
|
588
749
|
|
|
589
|
-
def
|
|
590
|
-
original_init = cls.__init__
|
|
591
|
-
original_forward = cls.forward
|
|
592
|
-
|
|
750
|
+
def _build_wrapped_init(self, original_init):
|
|
593
751
|
def __init__(wrapped_self, *args, **kwargs):
|
|
594
752
|
logger.info("Initializing contract...")
|
|
595
753
|
original_init(wrapped_self, *args, **kwargs)
|
|
596
754
|
|
|
597
755
|
if not hasattr(wrapped_self, "prompt"):
|
|
598
756
|
logger.error("Prompt attribute not defined!")
|
|
599
|
-
|
|
757
|
+
msg = "Please define a static `prompt` attribute that describes what the contract must do."
|
|
758
|
+
UserMessage(msg)
|
|
759
|
+
raise Exception(msg)
|
|
600
760
|
|
|
601
761
|
wrapped_self.contract_successful = False
|
|
602
762
|
wrapped_self.contract_result = None
|
|
@@ -604,105 +764,217 @@ class contract:
|
|
|
604
764
|
wrapped_self._contract_timing = defaultdict(dict)
|
|
605
765
|
logger.info("Contract initialization complete!")
|
|
606
766
|
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
767
|
+
return __init__
|
|
768
|
+
|
|
769
|
+
def _start_contract_execution(self, wrapped_self):
|
|
770
|
+
logger.info("Starting contract execution...")
|
|
771
|
+
it = len(wrapped_self._contract_timing) # the len is the __call__ op_start
|
|
772
|
+
return it, time.perf_counter()
|
|
773
|
+
|
|
774
|
+
def _find_input_param_name(self, sig: inspect.Signature) -> str | None:
|
|
775
|
+
for param in sig.parameters.values():
|
|
776
|
+
if param.name == "self":
|
|
777
|
+
continue
|
|
778
|
+
if param.kind in (
|
|
779
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
780
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
781
|
+
inspect.Parameter.KEYWORD_ONLY,
|
|
782
|
+
):
|
|
783
|
+
return param.name
|
|
784
|
+
return None
|
|
611
785
|
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
input_type = self._try_dynamic_type_annotation(original_forward, context="input")
|
|
618
|
-
input = input_type(value=input)
|
|
619
|
-
|
|
620
|
-
maybe_payload = getattr(wrapped_self, "payload", None)
|
|
621
|
-
maybe_template = getattr(wrapped_self, "template")
|
|
622
|
-
if inspect.ismethod(maybe_template):
|
|
623
|
-
# `template` is a primitive in symbolicai case in which we actually don't have a template
|
|
624
|
-
maybe_template = None
|
|
625
|
-
|
|
626
|
-
# Create validation kwargs that include all original kwargs plus payload and template
|
|
627
|
-
validation_kwargs = {
|
|
628
|
-
**kwargs,
|
|
629
|
-
"payload": maybe_payload,
|
|
630
|
-
"template_suffix": maybe_template
|
|
631
|
-
}
|
|
786
|
+
def _prepare_forward_args(self, args, kwargs):
|
|
787
|
+
args_list = list(args)
|
|
788
|
+
kwargs_without_input = dict(kwargs)
|
|
789
|
+
original_kwargs = dict(kwargs)
|
|
790
|
+
return args_list, kwargs_without_input, original_kwargs
|
|
632
791
|
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
792
|
+
def _extract_input_value(
|
|
793
|
+
self,
|
|
794
|
+
args_list,
|
|
795
|
+
kwargs_without_input,
|
|
796
|
+
original_kwargs,
|
|
797
|
+
input_param_name: str | None,
|
|
798
|
+
):
|
|
799
|
+
if args_list:
|
|
800
|
+
return args_list[0], ("args", 0)
|
|
801
|
+
if input_param_name and input_param_name in kwargs_without_input:
|
|
802
|
+
return kwargs_without_input.pop(input_param_name), ("kwargs", input_param_name)
|
|
803
|
+
if "input" in kwargs_without_input:
|
|
804
|
+
return kwargs_without_input.pop("input"), ("kwargs", "input")
|
|
805
|
+
return original_kwargs.get("input"), ("fallback_kw", "input")
|
|
806
|
+
|
|
807
|
+
def _coerce_input_value(self, original_forward, input_value):
|
|
808
|
+
try:
|
|
809
|
+
assert self._is_valid_input(input_value)
|
|
810
|
+
return input_value
|
|
811
|
+
except TypeError:
|
|
812
|
+
input_type_model = self._try_dynamic_type_annotation(original_forward, context="input")
|
|
813
|
+
return input_type_model(value=input_value)
|
|
814
|
+
|
|
815
|
+
def _collect_validation_kwargs(self, wrapped_self, kwargs_without_input):
|
|
816
|
+
maybe_payload = getattr(wrapped_self, "payload", None)
|
|
817
|
+
maybe_template = getattr(wrapped_self, "template", None)
|
|
818
|
+
if inspect.ismethod(maybe_template):
|
|
819
|
+
maybe_template = None
|
|
820
|
+
return {
|
|
821
|
+
**kwargs_without_input,
|
|
822
|
+
"payload": maybe_payload,
|
|
823
|
+
"template_suffix": maybe_template,
|
|
824
|
+
}
|
|
825
|
+
|
|
826
|
+
def _resolve_output_type(self, sig: inspect.Signature, original_forward):
|
|
827
|
+
output_type = sig.return_annotation
|
|
828
|
+
try:
|
|
829
|
+
assert self._is_valid_output(output_type)
|
|
830
|
+
except TypeError:
|
|
831
|
+
output_type = self._try_dynamic_type_annotation(original_forward, context="output")
|
|
832
|
+
return output_type
|
|
639
833
|
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
834
|
+
def _run_contract_pipeline(
|
|
835
|
+
self,
|
|
836
|
+
wrapped_self,
|
|
837
|
+
current_input_value,
|
|
838
|
+
output_type,
|
|
839
|
+
it,
|
|
840
|
+
validation_kwargs,
|
|
841
|
+
):
|
|
842
|
+
output = None
|
|
843
|
+
try:
|
|
844
|
+
maybe_new_input = self._validate_input(
|
|
845
|
+
wrapped_self, current_input_value, it, **validation_kwargs
|
|
846
|
+
)
|
|
847
|
+
if maybe_new_input is not None:
|
|
848
|
+
current_input_value = maybe_new_input
|
|
647
849
|
|
|
648
|
-
|
|
649
|
-
|
|
850
|
+
current_input_value = self._act(
|
|
851
|
+
wrapped_self, current_input_value, it, **validation_kwargs
|
|
852
|
+
)
|
|
650
853
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
854
|
+
output = self._validate_output(
|
|
855
|
+
wrapped_self,
|
|
856
|
+
current_input_value,
|
|
857
|
+
output_type,
|
|
858
|
+
it,
|
|
859
|
+
**validation_kwargs,
|
|
860
|
+
)
|
|
861
|
+
wrapped_self.contract_successful = output is not None
|
|
862
|
+
wrapped_self.contract_result = output
|
|
863
|
+
wrapped_self.contract_exception = None
|
|
864
|
+
except Exception as exc:
|
|
865
|
+
logger.exception("Contract execution failed in main path!")
|
|
866
|
+
wrapped_self.contract_successful = False
|
|
867
|
+
wrapped_self.contract_exception = exc
|
|
868
|
+
return output, current_input_value
|
|
656
869
|
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
if not isinstance(output, output_type):
|
|
686
|
-
logger.error(f"Output type mismatch: {type(output)}")
|
|
687
|
-
if self.remedy_retry_params["graceful"]:
|
|
688
|
-
# In graceful mode, skip type mismatch error and return raw output
|
|
689
|
-
if hasattr(output_type, '_is_dynamic_model') and output_type._is_dynamic_model and hasattr(output, 'value'):
|
|
690
|
-
return output.value
|
|
691
|
-
return output
|
|
692
|
-
raise TypeError(
|
|
693
|
-
f"Expected output to be an instance of {output_type}, "
|
|
694
|
-
f"but got {type(output)}! Forward method must return an instance of {output_type}!"
|
|
695
|
-
)
|
|
696
|
-
if not wrapped_self.contract_successful:
|
|
697
|
-
logger.warning("Contract validation failed!")
|
|
698
|
-
else:
|
|
699
|
-
logger.success("Contract validation successful!")
|
|
870
|
+
def _execute_forward_call(
|
|
871
|
+
self,
|
|
872
|
+
wrapped_self,
|
|
873
|
+
original_forward,
|
|
874
|
+
args_list,
|
|
875
|
+
original_kwargs,
|
|
876
|
+
input_param_name,
|
|
877
|
+
input_source,
|
|
878
|
+
forward_input_value,
|
|
879
|
+
it,
|
|
880
|
+
contract_start,
|
|
881
|
+
):
|
|
882
|
+
forward_kwargs = original_kwargs.copy()
|
|
883
|
+
for internal_kw in self._internal_forward_kwargs:
|
|
884
|
+
forward_kwargs.pop(internal_kw, None)
|
|
885
|
+
|
|
886
|
+
logger.info("Executing original forward method...")
|
|
887
|
+
|
|
888
|
+
if input_param_name:
|
|
889
|
+
if input_param_name in forward_kwargs or input_source == ("kwargs", input_param_name):
|
|
890
|
+
forward_kwargs[input_param_name] = forward_input_value
|
|
891
|
+
elif input_source == ("args", 0) and args_list:
|
|
892
|
+
args_list[0] = forward_input_value
|
|
893
|
+
else:
|
|
894
|
+
forward_kwargs[input_param_name] = forward_input_value
|
|
895
|
+
else:
|
|
896
|
+
forward_kwargs["input"] = forward_input_value
|
|
700
897
|
|
|
701
|
-
|
|
702
|
-
|
|
898
|
+
if input_param_name and input_param_name != "input" and "input" in forward_kwargs:
|
|
899
|
+
forward_kwargs.pop("input")
|
|
703
900
|
|
|
704
|
-
|
|
901
|
+
try:
|
|
902
|
+
op_start = time.perf_counter()
|
|
903
|
+
output = original_forward(wrapped_self, *args_list, **forward_kwargs)
|
|
904
|
+
finally:
|
|
905
|
+
wrapped_self._contract_timing[it]["forward_execution"] = time.perf_counter() - op_start
|
|
906
|
+
wrapped_self._contract_timing[it]["contract_execution"] = (
|
|
907
|
+
time.perf_counter() - contract_start
|
|
908
|
+
)
|
|
909
|
+
return output
|
|
910
|
+
|
|
911
|
+
def _finalize_contract_output(self, output, output_type, wrapped_self):
|
|
912
|
+
if not isinstance(output, output_type):
|
|
913
|
+
logger.error(f"Output type mismatch: {type(output)}")
|
|
914
|
+
if self.remedy_retry_params["graceful"]:
|
|
915
|
+
if getattr(output_type, "_is_dynamic_model", False) and hasattr(output, "value"):
|
|
916
|
+
return output.value
|
|
917
|
+
return output
|
|
918
|
+
msg = (
|
|
919
|
+
f"Expected output to be an instance of {output_type}, "
|
|
920
|
+
f"but got {type(output)}! Forward method must return an instance of {output_type}!"
|
|
921
|
+
)
|
|
922
|
+
UserMessage(msg)
|
|
923
|
+
raise TypeError(msg)
|
|
924
|
+
if not wrapped_self.contract_successful:
|
|
925
|
+
logger.warning("Contract validation failed!")
|
|
926
|
+
else:
|
|
927
|
+
logger.success("Contract validation successful!")
|
|
928
|
+
|
|
929
|
+
if getattr(output_type, "_is_dynamic_model", False):
|
|
930
|
+
return output.value
|
|
931
|
+
return output
|
|
705
932
|
|
|
933
|
+
def _contract_forward_impl(self, wrapped_self, original_forward, *args, **kwargs):
|
|
934
|
+
it, contract_start = self._start_contract_execution(wrapped_self)
|
|
935
|
+
sig = inspect.signature(original_forward)
|
|
936
|
+
input_param_name = self._find_input_param_name(sig)
|
|
937
|
+
args_list, kwargs_without_input, original_kwargs = self._prepare_forward_args(args, kwargs)
|
|
938
|
+
input_value, input_source = self._extract_input_value(
|
|
939
|
+
args_list, kwargs_without_input, original_kwargs, input_param_name
|
|
940
|
+
)
|
|
941
|
+
current_input_value = self._coerce_input_value(original_forward, input_value)
|
|
942
|
+
input_value = current_input_value
|
|
943
|
+
validation_kwargs = self._collect_validation_kwargs(wrapped_self, kwargs_without_input)
|
|
944
|
+
output_type = self._resolve_output_type(sig, original_forward)
|
|
945
|
+
|
|
946
|
+
output, current_input_value = self._run_contract_pipeline(
|
|
947
|
+
wrapped_self,
|
|
948
|
+
current_input_value,
|
|
949
|
+
output_type,
|
|
950
|
+
it,
|
|
951
|
+
validation_kwargs,
|
|
952
|
+
)
|
|
953
|
+
|
|
954
|
+
forward_input_value = (
|
|
955
|
+
current_input_value if wrapped_self.contract_successful else input_value
|
|
956
|
+
)
|
|
957
|
+
output = self._execute_forward_call(
|
|
958
|
+
wrapped_self,
|
|
959
|
+
original_forward,
|
|
960
|
+
args_list,
|
|
961
|
+
original_kwargs,
|
|
962
|
+
input_param_name,
|
|
963
|
+
input_source,
|
|
964
|
+
forward_input_value,
|
|
965
|
+
it,
|
|
966
|
+
contract_start,
|
|
967
|
+
)
|
|
968
|
+
|
|
969
|
+
return self._finalize_contract_output(output, output_type, wrapped_self)
|
|
970
|
+
|
|
971
|
+
def _build_wrapped_forward(self, original_forward):
|
|
972
|
+
def wrapped_forward(wrapped_self, *args, **kwargs):
|
|
973
|
+
return self._contract_forward_impl(wrapped_self, original_forward, *args, **kwargs)
|
|
974
|
+
|
|
975
|
+
return wrapped_forward
|
|
976
|
+
|
|
977
|
+
def _build_contract_perf_stats(self):
|
|
706
978
|
def contract_perf_stats(wrapped_self):
|
|
707
979
|
"""Analyzes and prints timing statistics across all forward calls."""
|
|
708
980
|
console = Console()
|
|
@@ -717,7 +989,7 @@ class contract:
|
|
|
717
989
|
"act_execution",
|
|
718
990
|
"output_validation",
|
|
719
991
|
"forward_execution",
|
|
720
|
-
"contract_execution"
|
|
992
|
+
"contract_execution",
|
|
721
993
|
]
|
|
722
994
|
|
|
723
995
|
stats = {}
|
|
@@ -739,40 +1011,41 @@ class contract:
|
|
|
739
1011
|
max_time = max(non_zero_times) if non_zero_times else 0
|
|
740
1012
|
|
|
741
1013
|
stats[op] = {
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
1014
|
+
"count": actual_count,
|
|
1015
|
+
"total": total_time,
|
|
1016
|
+
"mean": mean_time,
|
|
1017
|
+
"std": std_time,
|
|
1018
|
+
"min": min_time,
|
|
1019
|
+
"max": max_time,
|
|
748
1020
|
}
|
|
749
1021
|
|
|
750
|
-
total_execution_time = stats[
|
|
1022
|
+
total_execution_time = stats["contract_execution"]["total"]
|
|
751
1023
|
for op in ordered_operations[:-1]:
|
|
752
1024
|
if total_execution_time > 0:
|
|
753
|
-
stats[op][
|
|
1025
|
+
stats[op]["percentage"] = (stats[op]["total"] / total_execution_time) * 100
|
|
754
1026
|
else:
|
|
755
|
-
stats[op][
|
|
1027
|
+
stats[op]["percentage"] = 0
|
|
756
1028
|
|
|
757
|
-
sum_tracked_times = sum(stats[op][
|
|
1029
|
+
sum_tracked_times = sum(stats[op]["total"] for op in ordered_operations[:-1])
|
|
758
1030
|
overhead_time = total_execution_time - sum_tracked_times
|
|
759
|
-
overhead_percentage = (
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
1031
|
+
overhead_percentage = (
|
|
1032
|
+
(overhead_time / total_execution_time) * 100 if total_execution_time > 0 else 0
|
|
1033
|
+
)
|
|
1034
|
+
|
|
1035
|
+
stats["overhead"] = {
|
|
1036
|
+
"count": num_calls,
|
|
1037
|
+
"total": overhead_time,
|
|
1038
|
+
"mean": overhead_time / num_calls if num_calls > 0 else 0,
|
|
1039
|
+
"std": 0,
|
|
1040
|
+
"min": 0,
|
|
1041
|
+
"max": 0,
|
|
1042
|
+
"percentage": overhead_percentage,
|
|
769
1043
|
}
|
|
770
1044
|
|
|
771
|
-
stats[
|
|
1045
|
+
stats["contract_execution"]["percentage"] = 100.0
|
|
772
1046
|
|
|
773
1047
|
table = Table(
|
|
774
|
-
title=f"Contract Execution Summary ({num_calls} Forward Calls)",
|
|
775
|
-
show_header=True
|
|
1048
|
+
title=f"Contract Execution Summary ({num_calls} Forward Calls)", show_header=True
|
|
776
1049
|
)
|
|
777
1050
|
table.add_column("Operation", style="cyan")
|
|
778
1051
|
table.add_column("Count", justify="right", style="blue")
|
|
@@ -787,29 +1060,29 @@ class contract:
|
|
|
787
1060
|
s = stats[op]
|
|
788
1061
|
table.add_row(
|
|
789
1062
|
op.replace("_", " ").title(),
|
|
790
|
-
str(s[
|
|
1063
|
+
str(s["count"]),
|
|
791
1064
|
f"{s['total']:.3f}",
|
|
792
1065
|
f"{s['mean']:.3f}",
|
|
793
1066
|
f"{s['std']:.3f}",
|
|
794
1067
|
f"{s['min']:.3f}",
|
|
795
1068
|
f"{s['max']:.3f}",
|
|
796
|
-
f"{s['percentage']:.1f}%"
|
|
1069
|
+
f"{s['percentage']:.1f}%",
|
|
797
1070
|
)
|
|
798
1071
|
|
|
799
|
-
s = stats[
|
|
1072
|
+
s = stats["overhead"]
|
|
800
1073
|
table.add_row(
|
|
801
1074
|
"Overhead",
|
|
802
|
-
str(s[
|
|
1075
|
+
str(s["count"]),
|
|
803
1076
|
f"{s['total']:.3f}",
|
|
804
1077
|
f"{s['mean']:.3f}",
|
|
805
1078
|
f"{s['std']:.3f}",
|
|
806
1079
|
f"{s['min']:.3f}",
|
|
807
1080
|
f"{s['max']:.3f}",
|
|
808
1081
|
f"{s['percentage']:.1f}%",
|
|
809
|
-
style="bold blue"
|
|
1082
|
+
style="bold blue",
|
|
810
1083
|
)
|
|
811
1084
|
|
|
812
|
-
s = stats[
|
|
1085
|
+
s = stats["contract_execution"]
|
|
813
1086
|
table.add_row(
|
|
814
1087
|
"Total Execution",
|
|
815
1088
|
"N/A",
|
|
@@ -819,7 +1092,7 @@ class contract:
|
|
|
819
1092
|
f"{s['min']:.3f}",
|
|
820
1093
|
f"{s['max']:.3f}",
|
|
821
1094
|
"100.0%",
|
|
822
|
-
style="bold magenta"
|
|
1095
|
+
style="bold magenta",
|
|
823
1096
|
)
|
|
824
1097
|
|
|
825
1098
|
console.print("\n")
|
|
@@ -828,17 +1101,29 @@ class contract:
|
|
|
828
1101
|
|
|
829
1102
|
return stats
|
|
830
1103
|
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
1104
|
+
return contract_perf_stats
|
|
1105
|
+
|
|
1106
|
+
def __call__(self, cls):
|
|
1107
|
+
original_init = cls.__init__
|
|
1108
|
+
original_forward = cls.forward
|
|
834
1109
|
|
|
1110
|
+
cls.__init__ = self._build_wrapped_init(original_init)
|
|
1111
|
+
cls.forward = self._build_wrapped_forward(original_forward)
|
|
1112
|
+
cls.contract_perf_stats = self._build_contract_perf_stats()
|
|
835
1113
|
return cls
|
|
836
1114
|
|
|
837
1115
|
|
|
838
1116
|
class BaseStrategy(TypeValidationFunction):
|
|
839
|
-
def __init__(self, data_model: BaseModel, *
|
|
1117
|
+
def __init__(self, data_model: BaseModel, *_args, **kwargs):
|
|
840
1118
|
super().__init__(
|
|
841
|
-
retry_params=
|
|
1119
|
+
retry_params={
|
|
1120
|
+
"tries": 8,
|
|
1121
|
+
"delay": 0.015,
|
|
1122
|
+
"backoff": 1.25,
|
|
1123
|
+
"jitter": 0.0,
|
|
1124
|
+
"max_delay": 0.25,
|
|
1125
|
+
"graceful": False,
|
|
1126
|
+
},
|
|
842
1127
|
**kwargs,
|
|
843
1128
|
)
|
|
844
1129
|
super().register_expected_data_model(data_model, attach_to="output")
|
|
@@ -851,14 +1136,13 @@ class BaseStrategy(TypeValidationFunction):
|
|
|
851
1136
|
pass
|
|
852
1137
|
|
|
853
1138
|
def forward(self, *args, **kwargs):
|
|
854
|
-
|
|
1139
|
+
return super().forward(
|
|
855
1140
|
*args,
|
|
856
1141
|
payload=self.payload,
|
|
857
1142
|
template_suffix=self.template,
|
|
858
1143
|
response_format={"type": "json_object"},
|
|
859
1144
|
**kwargs,
|
|
860
1145
|
)
|
|
861
|
-
return result
|
|
862
1146
|
|
|
863
1147
|
@property
|
|
864
1148
|
def payload(self):
|
|
@@ -866,7 +1150,7 @@ class BaseStrategy(TypeValidationFunction):
|
|
|
866
1150
|
|
|
867
1151
|
@property
|
|
868
1152
|
def static_context(self):
|
|
869
|
-
raise NotImplementedError
|
|
1153
|
+
raise NotImplementedError
|
|
870
1154
|
|
|
871
1155
|
@property
|
|
872
1156
|
def template(self):
|
|
@@ -878,7 +1162,7 @@ class Strategy(Expression):
|
|
|
878
1162
|
super().__init__(*args, **kwargs)
|
|
879
1163
|
self.logger = logging.getLogger(__name__)
|
|
880
1164
|
|
|
881
|
-
def __new__(
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
return Strategy.load_module_class(
|
|
1165
|
+
def __new__(cls, module: str, *_args, **_kwargs):
|
|
1166
|
+
cls._module = module
|
|
1167
|
+
cls.module_path = "symai.extended.strategies"
|
|
1168
|
+
return Strategy.load_module_class(cls.module)
|