symbolicai 0.20.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +96 -64
- symai/backend/base.py +93 -80
- symai/backend/engines/drawing/engine_bfl.py +12 -11
- symai/backend/engines/drawing/engine_gpt_image.py +108 -87
- symai/backend/engines/embedding/engine_llama_cpp.py +25 -28
- symai/backend/engines/embedding/engine_openai.py +3 -5
- symai/backend/engines/execute/engine_python.py +6 -5
- symai/backend/engines/files/engine_io.py +74 -67
- symai/backend/engines/imagecaptioning/engine_blip2.py +3 -3
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +54 -38
- symai/backend/engines/index/engine_pinecone.py +23 -24
- symai/backend/engines/index/engine_vectordb.py +16 -14
- symai/backend/engines/lean/engine_lean4.py +38 -34
- symai/backend/engines/neurosymbolic/__init__.py +41 -13
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +262 -182
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +263 -191
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +53 -49
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +212 -211
- symai/backend/engines/neurosymbolic/engine_groq.py +87 -63
- symai/backend/engines/neurosymbolic/engine_huggingface.py +21 -24
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +117 -48
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +256 -229
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +270 -150
- symai/backend/engines/ocr/engine_apilayer.py +6 -8
- symai/backend/engines/output/engine_stdout.py +1 -4
- symai/backend/engines/search/engine_openai.py +7 -7
- symai/backend/engines/search/engine_perplexity.py +5 -5
- symai/backend/engines/search/engine_serpapi.py +12 -14
- symai/backend/engines/speech_to_text/engine_local_whisper.py +20 -27
- symai/backend/engines/symbolic/engine_wolframalpha.py +3 -3
- symai/backend/engines/text_to_speech/engine_openai.py +5 -7
- symai/backend/engines/text_vision/engine_clip.py +7 -11
- symai/backend/engines/userinput/engine_console.py +3 -3
- symai/backend/engines/webscraping/engine_requests.py +81 -48
- symai/backend/mixin/__init__.py +13 -0
- symai/backend/mixin/anthropic.py +4 -2
- symai/backend/mixin/deepseek.py +2 -0
- symai/backend/mixin/google.py +2 -0
- symai/backend/mixin/openai.py +11 -3
- symai/backend/settings.py +83 -16
- symai/chat.py +101 -78
- symai/collect/__init__.py +7 -1
- symai/collect/dynamic.py +77 -69
- symai/collect/pipeline.py +35 -27
- symai/collect/stats.py +75 -63
- symai/components.py +198 -169
- symai/constraints.py +15 -12
- symai/core.py +698 -359
- symai/core_ext.py +32 -34
- symai/endpoints/api.py +80 -73
- symai/extended/.DS_Store +0 -0
- symai/extended/__init__.py +46 -12
- symai/extended/api_builder.py +11 -8
- symai/extended/arxiv_pdf_parser.py +13 -12
- symai/extended/bibtex_parser.py +2 -3
- symai/extended/conversation.py +101 -90
- symai/extended/document.py +17 -10
- symai/extended/file_merger.py +18 -13
- symai/extended/graph.py +18 -13
- symai/extended/html_style_template.py +2 -4
- symai/extended/interfaces/blip_2.py +1 -2
- symai/extended/interfaces/clip.py +1 -2
- symai/extended/interfaces/console.py +7 -1
- symai/extended/interfaces/dall_e.py +1 -1
- symai/extended/interfaces/flux.py +1 -1
- symai/extended/interfaces/gpt_image.py +1 -1
- symai/extended/interfaces/input.py +1 -1
- symai/extended/interfaces/llava.py +0 -1
- symai/extended/interfaces/naive_vectordb.py +7 -8
- symai/extended/interfaces/naive_webscraping.py +1 -1
- symai/extended/interfaces/ocr.py +1 -1
- symai/extended/interfaces/pinecone.py +6 -5
- symai/extended/interfaces/serpapi.py +1 -1
- symai/extended/interfaces/terminal.py +2 -3
- symai/extended/interfaces/tts.py +1 -1
- symai/extended/interfaces/whisper.py +1 -1
- symai/extended/interfaces/wolframalpha.py +1 -1
- symai/extended/metrics/__init__.py +11 -1
- symai/extended/metrics/similarity.py +11 -13
- symai/extended/os_command.py +17 -16
- symai/extended/packages/__init__.py +29 -3
- symai/extended/packages/symdev.py +19 -16
- symai/extended/packages/sympkg.py +12 -9
- symai/extended/packages/symrun.py +21 -19
- symai/extended/repo_cloner.py +11 -10
- symai/extended/seo_query_optimizer.py +1 -2
- symai/extended/solver.py +20 -23
- symai/extended/summarizer.py +4 -3
- symai/extended/taypan_interpreter.py +10 -12
- symai/extended/vectordb.py +99 -82
- symai/formatter/__init__.py +9 -1
- symai/formatter/formatter.py +12 -16
- symai/formatter/regex.py +62 -63
- symai/functional.py +176 -122
- symai/imports.py +136 -127
- symai/interfaces.py +56 -27
- symai/memory.py +14 -13
- symai/misc/console.py +49 -39
- symai/misc/loader.py +5 -3
- symai/models/__init__.py +17 -1
- symai/models/base.py +269 -181
- symai/models/errors.py +0 -1
- symai/ops/__init__.py +32 -22
- symai/ops/measures.py +11 -15
- symai/ops/primitives.py +348 -228
- symai/post_processors.py +32 -28
- symai/pre_processors.py +39 -41
- symai/processor.py +6 -4
- symai/prompts.py +59 -45
- symai/server/huggingface_server.py +23 -20
- symai/server/llama_cpp_server.py +7 -5
- symai/shell.py +3 -4
- symai/shellsv.py +499 -375
- symai/strategy.py +517 -287
- symai/symbol.py +111 -116
- symai/utils.py +42 -36
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/METADATA +4 -2
- symbolicai-1.0.0.dist-info/RECORD +163 -0
- symbolicai-0.20.2.dist-info/RECORD +0 -162
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/WHEEL +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/top_level.txt +0 -0
symai/strategy.py
CHANGED
|
@@ -2,7 +2,8 @@ import inspect
|
|
|
2
2
|
import logging
|
|
3
3
|
import time
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any, ClassVar
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
from beartype import beartype
|
|
@@ -14,9 +15,9 @@ from rich.panel import Panel
|
|
|
14
15
|
from rich.table import Table
|
|
15
16
|
|
|
16
17
|
from .components import Function
|
|
17
|
-
from .models import
|
|
18
|
-
build_dynamic_llm_datamodel)
|
|
18
|
+
from .models import LLMDataModel, TypeValidationError, build_dynamic_llm_datamodel
|
|
19
19
|
from .symbol import Expression
|
|
20
|
+
from .utils import UserMessage
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class ValidationFunction(Function):
|
|
@@ -29,18 +30,18 @@ class ValidationFunction(Function):
|
|
|
29
30
|
• Error simplification
|
|
30
31
|
"""
|
|
31
32
|
# Have some default retry params that don't add overhead
|
|
32
|
-
_default_retry_params =
|
|
33
|
-
tries
|
|
34
|
-
delay
|
|
35
|
-
backoff
|
|
36
|
-
jitter
|
|
37
|
-
max_delay
|
|
38
|
-
graceful
|
|
39
|
-
|
|
33
|
+
_default_retry_params: ClassVar[dict[str, int | float | bool]] = {
|
|
34
|
+
"tries": 8,
|
|
35
|
+
"delay": 0.015,
|
|
36
|
+
"backoff": 1.25,
|
|
37
|
+
"jitter": 0.0,
|
|
38
|
+
"max_delay": 0.25,
|
|
39
|
+
"graceful": False,
|
|
40
|
+
}
|
|
40
41
|
|
|
41
42
|
def __init__(
|
|
42
43
|
self,
|
|
43
|
-
retry_params: dict[str, int | float | bool] = None,
|
|
44
|
+
retry_params: dict[str, int | float | bool] | None = None,
|
|
44
45
|
verbose: bool = False,
|
|
45
46
|
*args,
|
|
46
47
|
**kwargs,
|
|
@@ -92,10 +93,9 @@ class ValidationFunction(Function):
|
|
|
92
93
|
seed = 42
|
|
93
94
|
|
|
94
95
|
rnd = np.random.RandomState(seed=seed)
|
|
95
|
-
|
|
96
|
+
return rnd.randint(
|
|
96
97
|
0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16
|
|
97
98
|
).tolist()
|
|
98
|
-
return seeds
|
|
99
99
|
|
|
100
100
|
def simplify_validation_errors(self, error: ValidationError) -> str:
|
|
101
101
|
"""
|
|
@@ -130,12 +130,14 @@ class ValidationFunction(Function):
|
|
|
130
130
|
_delay = min(base + jit, self.retry_params['max_delay'])
|
|
131
131
|
time.sleep(_delay)
|
|
132
132
|
|
|
133
|
-
def remedy_prompt(self, *
|
|
133
|
+
def remedy_prompt(self, *_args, **_kwargs):
|
|
134
134
|
"""
|
|
135
135
|
Abstract or base remedy prompt method.
|
|
136
136
|
Child classes typically override this to include additional context needed for correction.
|
|
137
137
|
"""
|
|
138
|
-
|
|
138
|
+
msg = "Each child class needs its own remedy_prompt implementation."
|
|
139
|
+
UserMessage(msg)
|
|
140
|
+
raise NotImplementedError(msg)
|
|
139
141
|
|
|
140
142
|
def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1,2)):
|
|
141
143
|
"""
|
|
@@ -171,7 +173,7 @@ class TypeValidationFunction(ValidationFunction):
|
|
|
171
173
|
*args,
|
|
172
174
|
**kwargs,
|
|
173
175
|
):
|
|
174
|
-
super().__init__(retry_params=retry_params, verbose=verbose,
|
|
176
|
+
super().__init__(*args, retry_params=retry_params, verbose=verbose, **kwargs)
|
|
175
177
|
self.input_data_model = None
|
|
176
178
|
self.output_data_model = None
|
|
177
179
|
self.accumulate_errors = accumulate_errors
|
|
@@ -181,11 +183,15 @@ class TypeValidationFunction(ValidationFunction):
|
|
|
181
183
|
assert attach_to in ["input", "output"], f"Invalid attach_to value: {attach_to}; must be either 'input' or 'output'"
|
|
182
184
|
if attach_to == "input":
|
|
183
185
|
if self.input_data_model is not None and not override:
|
|
184
|
-
|
|
186
|
+
msg = "There is already a data model attached to the input. If you want to override it, set `override=True`."
|
|
187
|
+
UserMessage(msg)
|
|
188
|
+
raise ValueError(msg)
|
|
185
189
|
self.input_data_model = data_model
|
|
186
190
|
elif attach_to == "output":
|
|
187
191
|
if self.output_data_model is not None and not override:
|
|
188
|
-
|
|
192
|
+
msg = "There is already a data model attached to the output. If you want to override it, set `override=True`."
|
|
193
|
+
UserMessage(msg)
|
|
194
|
+
raise ValueError(msg)
|
|
189
195
|
self.output_data_model = data_model
|
|
190
196
|
|
|
191
197
|
def remedy_prompt(self, prompt: str, output: str, errors: str) -> str:
|
|
@@ -268,118 +274,174 @@ Important guidelines:
|
|
|
268
274
|
</guidelines>
|
|
269
275
|
"""
|
|
270
276
|
|
|
271
|
-
def
|
|
277
|
+
def _ensure_output_model(self):
|
|
272
278
|
if self.output_data_model is None:
|
|
273
|
-
|
|
279
|
+
msg = (
|
|
280
|
+
"While the input data model is optional, the output data model must be provided. "
|
|
281
|
+
"Please register it before calling the `forward` method."
|
|
282
|
+
)
|
|
283
|
+
UserMessage(msg)
|
|
284
|
+
raise ValueError(msg)
|
|
285
|
+
|
|
286
|
+
def _display_verbose_panels(self, prompt: str):
|
|
287
|
+
if not self.verbose:
|
|
288
|
+
return
|
|
289
|
+
for label, body in [
|
|
290
|
+
("Prompt", prompt),
|
|
291
|
+
("Input data model", self.input_data_model.simplify_json_schema() if self.input_data_model else 'N/A'),
|
|
292
|
+
("Output data model", self.output_data_model.simplify_json_schema()),
|
|
293
|
+
]:
|
|
294
|
+
self.display_panel(body, title=label)
|
|
295
|
+
|
|
296
|
+
def _check_semantic_conditions(
|
|
297
|
+
self,
|
|
298
|
+
result,
|
|
299
|
+
f_semantic_conditions: list[Callable] | None,
|
|
300
|
+
) -> str | None:
|
|
301
|
+
if f_semantic_conditions is None:
|
|
302
|
+
return None
|
|
303
|
+
try:
|
|
304
|
+
assert all(
|
|
305
|
+
f(result if not getattr(self.output_data_model, '_is_dynamic_model', False) else result.value)
|
|
306
|
+
for f in f_semantic_conditions
|
|
307
|
+
)
|
|
308
|
+
except Exception as err:
|
|
309
|
+
return f"Semantic validation failed with:\n{err!s}"
|
|
310
|
+
return None
|
|
311
|
+
|
|
312
|
+
def _format_validation_error(self, error: Exception) -> str:
|
|
313
|
+
if isinstance(error, ValidationError):
|
|
314
|
+
return self.simplify_validation_errors(error)
|
|
315
|
+
return str(error)
|
|
316
|
+
|
|
317
|
+
def _handle_failed_validation_attempt(
|
|
318
|
+
self,
|
|
319
|
+
attempt_index: int,
|
|
320
|
+
prompt: str,
|
|
321
|
+
json_str: str,
|
|
322
|
+
errors: list[str],
|
|
323
|
+
error: Exception,
|
|
324
|
+
remedy_seeds: list[Any],
|
|
325
|
+
kwargs: dict,
|
|
326
|
+
) -> str:
|
|
327
|
+
logger.info(f"Validation attempt {attempt_index + 1} failed, pausing before retry…")
|
|
328
|
+
self._pause(attempt_index)
|
|
329
|
+
error_str = self._format_validation_error(error)
|
|
330
|
+
errors.append(error_str)
|
|
331
|
+
|
|
332
|
+
logger.error("Validation errors identified!")
|
|
333
|
+
if self.verbose:
|
|
334
|
+
errors_report = "\n".join(errors) if self.accumulate_errors else error_str
|
|
335
|
+
title = f"Validation Errors ({'accumulated errors' if self.accumulate_errors else 'last error'})"
|
|
336
|
+
self.display_panel(errors_report, title=title, border_style="red")
|
|
337
|
+
|
|
338
|
+
logger.info("Updating remedy function context…")
|
|
339
|
+
context = self.remedy_prompt(
|
|
340
|
+
prompt=prompt,
|
|
341
|
+
output=json_str,
|
|
342
|
+
errors="\n".join(errors) if self.accumulate_errors else error_str,
|
|
343
|
+
)
|
|
344
|
+
self.remedy_function.clear()
|
|
345
|
+
self.remedy_function.adapt(context)
|
|
346
|
+
if self.verbose:
|
|
347
|
+
self.display_panel(self.remedy_function.dynamic_context, title="New Context")
|
|
348
|
+
|
|
349
|
+
json_str = self.remedy_function(seed=remedy_seeds[attempt_index], **kwargs).value
|
|
350
|
+
logger.info("Applied remedy function with updated context!")
|
|
351
|
+
return json_str
|
|
352
|
+
|
|
353
|
+
def _run_validation_attempts(
|
|
354
|
+
self,
|
|
355
|
+
prompt: str,
|
|
356
|
+
f_semantic_conditions: list[Callable] | None,
|
|
357
|
+
validation_context: dict,
|
|
358
|
+
remedy_seeds: list[Any],
|
|
359
|
+
json_str: str,
|
|
360
|
+
kwargs: dict,
|
|
361
|
+
) -> tuple[Any | None, str, list[str]]:
|
|
362
|
+
errors: list[str] = []
|
|
363
|
+
result = None
|
|
364
|
+
total_attempts = self.retry_params["tries"] + 1
|
|
365
|
+
for attempt in range(total_attempts):
|
|
366
|
+
if attempt != self.retry_params["tries"]:
|
|
367
|
+
logger.info(f"Attempt {attempt + 1}/{self.retry_params['tries']}: Attempting validation…")
|
|
368
|
+
try:
|
|
369
|
+
result = self.output_data_model.model_validate_json(
|
|
370
|
+
json_str,
|
|
371
|
+
strict=False,
|
|
372
|
+
context=validation_context,
|
|
373
|
+
)
|
|
374
|
+
semantic_error = self._check_semantic_conditions(result, f_semantic_conditions)
|
|
375
|
+
if semantic_error is not None:
|
|
376
|
+
if attempt == total_attempts - 1:
|
|
377
|
+
result = None
|
|
378
|
+
errors.append(semantic_error)
|
|
379
|
+
break
|
|
380
|
+
raise AssertionError(semantic_error)
|
|
381
|
+
break
|
|
382
|
+
except Exception as error:
|
|
383
|
+
json_str = self._handle_failed_validation_attempt(
|
|
384
|
+
attempt,
|
|
385
|
+
prompt,
|
|
386
|
+
json_str,
|
|
387
|
+
errors,
|
|
388
|
+
error,
|
|
389
|
+
remedy_seeds,
|
|
390
|
+
kwargs,
|
|
391
|
+
)
|
|
392
|
+
return result, json_str, errors
|
|
393
|
+
|
|
394
|
+
def _handle_validation_failure(self, prompt: str, json_str: str, errors: list[str]):
|
|
395
|
+
logger.error("All validation attempts failed!")
|
|
396
|
+
if self.retry_params['graceful']:
|
|
397
|
+
return
|
|
398
|
+
raise TypeValidationError(
|
|
399
|
+
prompt=prompt,
|
|
400
|
+
result=json_str,
|
|
401
|
+
violations=errors,
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
def forward(self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs):
|
|
405
|
+
self._ensure_output_model()
|
|
274
406
|
validation_context = kwargs.pop('validation_context', {})
|
|
275
|
-
# Force JSON mode
|
|
276
407
|
kwargs["response_format"] = {"type": "json_object"}
|
|
277
408
|
logger.info("Initializing validation…")
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
("Prompt", prompt),
|
|
281
|
-
("Input data model", self.input_data_model.simplify_json_schema() if self.input_data_model else 'N/A'),
|
|
282
|
-
("Output data model", self.output_data_model.simplify_json_schema()),
|
|
283
|
-
]:
|
|
284
|
-
self.display_panel(body, title=label)
|
|
285
|
-
|
|
286
|
-
# Zero shot the task
|
|
409
|
+
self._display_verbose_panels(prompt)
|
|
410
|
+
|
|
287
411
|
context = self.zero_shot_prompt(prompt=prompt)
|
|
288
412
|
json_str = super().forward(context, *args, **kwargs).value
|
|
289
413
|
|
|
290
414
|
remedy_seeds = self.prepare_seeds(self.retry_params["tries"] + 1, **kwargs)
|
|
291
415
|
logger.info(f"Prepared {len(remedy_seeds)} remedy seeds for validation attempts…")
|
|
292
416
|
|
|
293
|
-
result =
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
# - String keys "123" -> integer keys 123 for dict[int, ...] types
|
|
302
|
-
# - String numbers "42" -> int 42
|
|
303
|
-
# - Float 1.0 -> int 1
|
|
304
|
-
# These coercions are safe and helpful when dealing with JSON from LLMs,
|
|
305
|
-
# while still catching actual type errors (e.g., "not_a_number" -> int fails).
|
|
306
|
-
result = self.output_data_model.model_validate_json(json_str, strict=False, context=validation_context)
|
|
307
|
-
if f_semantic_conditions is not None:
|
|
308
|
-
try:
|
|
309
|
-
assert all(
|
|
310
|
-
f(result if not getattr(self.output_data_model, '_is_dynamic_model', False) else result.value)
|
|
311
|
-
for f in f_semantic_conditions
|
|
312
|
-
)
|
|
313
|
-
except Exception as e:
|
|
314
|
-
# If we are in the last attempt and semantic validation fails, result will be None and we propagate the error
|
|
315
|
-
if i == self.retry_params["tries"]:
|
|
316
|
-
result = None
|
|
317
|
-
errors.append(f"Semantic validation failed with:\n{str(e)}")
|
|
318
|
-
break # We break to avoid going into the remedy loop
|
|
319
|
-
raise e
|
|
320
|
-
break
|
|
321
|
-
except Exception as e:
|
|
322
|
-
logger.info(f"Validation attempt {i+1} failed, pausing before retry…")
|
|
323
|
-
|
|
324
|
-
self._pause(i)
|
|
325
|
-
|
|
326
|
-
if isinstance(e, ValidationError):
|
|
327
|
-
error_str = self.simplify_validation_errors(e)
|
|
328
|
-
else:
|
|
329
|
-
error_str = str(e)
|
|
330
|
-
|
|
331
|
-
errors.append(error_str)
|
|
332
|
-
|
|
333
|
-
logger.error(f"Validation errors identified!")
|
|
334
|
-
if self.verbose:
|
|
335
|
-
self.display_panel(
|
|
336
|
-
"\n".join(errors) if self.accumulate_errors else error_str,
|
|
337
|
-
title=f"Validation Errors ({'accumulated errors' if self.accumulate_errors else 'last error'})",
|
|
338
|
-
border_style="red"
|
|
339
|
-
)
|
|
340
|
-
|
|
341
|
-
# Update remedy function context
|
|
342
|
-
logger.info("Updating remedy function context…")
|
|
343
|
-
context = self.remedy_prompt(prompt=prompt, output=json_str, errors="\n".join(errors) if self.accumulate_errors else error_str)
|
|
344
|
-
self.remedy_function.clear()
|
|
345
|
-
self.remedy_function.adapt(context)
|
|
346
|
-
if self.verbose:
|
|
347
|
-
self.display_panel(
|
|
348
|
-
self.remedy_function.dynamic_context,
|
|
349
|
-
title="New Context"
|
|
350
|
-
)
|
|
351
|
-
|
|
352
|
-
# Apply the remedy function
|
|
353
|
-
json_str = self.remedy_function(seed=remedy_seeds[i], **kwargs).value
|
|
354
|
-
logger.info("Applied remedy function with updated context!")
|
|
417
|
+
result, json_str, errors = self._run_validation_attempts(
|
|
418
|
+
prompt,
|
|
419
|
+
f_semantic_conditions,
|
|
420
|
+
validation_context,
|
|
421
|
+
remedy_seeds,
|
|
422
|
+
json_str,
|
|
423
|
+
kwargs,
|
|
424
|
+
)
|
|
355
425
|
|
|
356
426
|
if result is None:
|
|
357
|
-
|
|
358
|
-
if self.retry_params['graceful']:
|
|
359
|
-
return
|
|
360
|
-
raise TypeValidationError(
|
|
361
|
-
prompt=prompt,
|
|
362
|
-
result=json_str,
|
|
363
|
-
violations=errors,
|
|
364
|
-
)
|
|
427
|
+
return self._handle_validation_failure(prompt, json_str, errors)
|
|
365
428
|
|
|
366
429
|
logger.success("Validation completed successfully!")
|
|
367
|
-
# Clear artifacts from the remedy function
|
|
368
430
|
self.remedy_function.clear()
|
|
369
|
-
|
|
370
431
|
return result
|
|
371
432
|
|
|
372
433
|
|
|
373
434
|
@beartype
|
|
374
435
|
class contract:
|
|
375
|
-
_default_remedy_retry_params =
|
|
376
|
-
tries
|
|
377
|
-
delay
|
|
378
|
-
backoff
|
|
379
|
-
jitter
|
|
380
|
-
max_delay
|
|
381
|
-
graceful
|
|
382
|
-
|
|
436
|
+
_default_remedy_retry_params: ClassVar[dict[str, int | float | bool]] = {
|
|
437
|
+
"tries": 8,
|
|
438
|
+
"delay": 0.015,
|
|
439
|
+
"backoff": 1.25,
|
|
440
|
+
"jitter": 0.0,
|
|
441
|
+
"max_delay": 0.25,
|
|
442
|
+
"graceful": False,
|
|
443
|
+
}
|
|
444
|
+
_internal_forward_kwargs: ClassVar[set[str]] = {"validation_context"}
|
|
383
445
|
|
|
384
446
|
def __init__(
|
|
385
447
|
self,
|
|
@@ -404,22 +466,30 @@ class contract:
|
|
|
404
466
|
else:
|
|
405
467
|
logger.enable(__name__)
|
|
406
468
|
|
|
407
|
-
def _is_valid_input(self,
|
|
408
|
-
if
|
|
469
|
+
def _is_valid_input(self, input_value):
|
|
470
|
+
if input_value is None:
|
|
409
471
|
logger.error("No `input` argument provided!")
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
472
|
+
msg = "Please provide an `input` argument."
|
|
473
|
+
UserMessage(msg)
|
|
474
|
+
raise ValueError(msg)
|
|
475
|
+
if not isinstance(input_value, LLMDataModel):
|
|
476
|
+
logger.error(f"Invalid input type: {type(input_value)}")
|
|
477
|
+
msg = f"Expected input to be of type `LLMDataModel`, got {type(input_value)}"
|
|
478
|
+
UserMessage(msg)
|
|
479
|
+
raise TypeError(msg)
|
|
414
480
|
return True
|
|
415
481
|
|
|
416
482
|
def _is_valid_output(self, output_type):
|
|
417
483
|
if output_type == inspect._empty:
|
|
418
484
|
logger.error("Missing return type annotation!")
|
|
419
|
-
|
|
485
|
+
msg = "The contract requires a return type annotation."
|
|
486
|
+
UserMessage(msg)
|
|
487
|
+
raise ValueError(msg)
|
|
420
488
|
if not issubclass(output_type, LLMDataModel):
|
|
421
489
|
logger.error(f"Invalid return type: {output_type}")
|
|
422
|
-
|
|
490
|
+
msg = "The return type annotation must be a subclass of `LLMDataModel`."
|
|
491
|
+
UserMessage(msg)
|
|
492
|
+
raise TypeError(msg)
|
|
423
493
|
return True
|
|
424
494
|
|
|
425
495
|
def _try_dynamic_type_annotation(self, original_forward, *, context: str = "input"):
|
|
@@ -428,23 +498,42 @@ class contract:
|
|
|
428
498
|
)
|
|
429
499
|
sig = inspect.signature(original_forward)
|
|
430
500
|
try:
|
|
501
|
+
resolved_param = None
|
|
431
502
|
# Fallback: look at the relevant part of the function signature
|
|
432
503
|
# depending on whether we deal with an *input* or *output*
|
|
433
504
|
if context == "input":
|
|
434
505
|
param = sig.parameters.get("input")
|
|
506
|
+
if param is None:
|
|
507
|
+
for candidate in sig.parameters.values():
|
|
508
|
+
if candidate.name == "self":
|
|
509
|
+
continue
|
|
510
|
+
if candidate.kind in (
|
|
511
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
512
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
513
|
+
inspect.Parameter.KEYWORD_ONLY,
|
|
514
|
+
):
|
|
515
|
+
param = candidate
|
|
516
|
+
break
|
|
517
|
+
resolved_param = param
|
|
435
518
|
if param is None or param.annotation == inspect._empty:
|
|
436
|
-
|
|
519
|
+
msg = "Failed to infer type from input parameter annotation"
|
|
520
|
+
UserMessage(msg)
|
|
521
|
+
raise TypeError(msg)
|
|
437
522
|
dynamic_model = build_dynamic_llm_datamodel(param.annotation)
|
|
438
523
|
else: # context == "output"
|
|
439
524
|
if sig.return_annotation == inspect._empty:
|
|
440
|
-
|
|
525
|
+
msg = "Failed to infer type from return annotation"
|
|
526
|
+
UserMessage(msg)
|
|
527
|
+
raise TypeError(msg)
|
|
441
528
|
dynamic_model = build_dynamic_llm_datamodel(sig.return_annotation)
|
|
442
|
-
except Exception:
|
|
443
|
-
logger.exception(f"Failed to build dynamic LLMDataModel from {
|
|
444
|
-
|
|
529
|
+
except Exception as err:
|
|
530
|
+
logger.exception(f"Failed to build dynamic LLMDataModel from {resolved_param}!")
|
|
531
|
+
msg = (
|
|
445
532
|
"The type annotation must be a subclass of `LLMDataModel` or a "
|
|
446
533
|
"valid Python typing object supported by Pydantic."
|
|
447
534
|
)
|
|
535
|
+
UserMessage(msg)
|
|
536
|
+
raise TypeError(msg) from err
|
|
448
537
|
|
|
449
538
|
dynamic_model._is_dynamic_model = True
|
|
450
539
|
return dynamic_model
|
|
@@ -453,49 +542,54 @@ class contract:
|
|
|
453
542
|
try:
|
|
454
543
|
data_model = self.f_type_validation_remedy(prompt, f_semantic_conditions=f_semantic_conditions, **remedy_kwargs)
|
|
455
544
|
except Exception as e:
|
|
456
|
-
logger.error(
|
|
545
|
+
logger.error("Type validation failed with exception!")
|
|
457
546
|
raise e
|
|
458
547
|
return data_model
|
|
459
548
|
|
|
460
|
-
def _validate_input(self, wrapped_self,
|
|
549
|
+
def _validate_input(self, wrapped_self, input_value, it, **remedy_kwargs):
|
|
461
550
|
logger.info("Starting input validation...")
|
|
462
551
|
if self.pre_remedy:
|
|
463
552
|
logger.info("Validating pre-conditions with remedy...")
|
|
464
553
|
if not hasattr(wrapped_self, 'pre'):
|
|
465
554
|
logger.error("Pre-condition function not defined!")
|
|
466
|
-
|
|
555
|
+
msg = "Pre-condition function not defined. Please define a `pre` method if you want to enforce pre-conditions through a remedy."
|
|
556
|
+
UserMessage(msg)
|
|
557
|
+
raise Exception(msg)
|
|
467
558
|
|
|
468
559
|
op_start = time.perf_counter()
|
|
469
560
|
try:
|
|
470
|
-
assert wrapped_self.pre(
|
|
561
|
+
assert wrapped_self.pre(input_value)
|
|
471
562
|
logger.success("Pre-condition validation successful!")
|
|
472
|
-
return
|
|
473
|
-
except Exception
|
|
563
|
+
return input_value
|
|
564
|
+
except Exception:
|
|
474
565
|
logger.exception("Pre-condition validation failed!")
|
|
475
|
-
self.f_type_validation_remedy.register_expected_data_model(
|
|
476
|
-
|
|
566
|
+
self.f_type_validation_remedy.register_expected_data_model(input_value, attach_to="output", override=True)
|
|
567
|
+
input_value = self._try_remedy_with_exception(
|
|
568
|
+
prompt=wrapped_self.prompt,
|
|
569
|
+
f_semantic_conditions=[wrapped_self.pre],
|
|
570
|
+
**remedy_kwargs,
|
|
571
|
+
)
|
|
477
572
|
finally:
|
|
478
573
|
wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
|
|
479
|
-
return
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
return input
|
|
574
|
+
return input_value
|
|
575
|
+
if hasattr(wrapped_self, 'pre'):
|
|
576
|
+
logger.info("Validating pre-conditions without remedy...")
|
|
577
|
+
op_start = time.perf_counter()
|
|
578
|
+
try:
|
|
579
|
+
assert wrapped_self.pre(input_value)
|
|
580
|
+
except Exception as e:
|
|
581
|
+
logger.exception("Pre-condition validation failed")
|
|
582
|
+
raise e
|
|
583
|
+
finally:
|
|
584
|
+
wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
|
|
585
|
+
logger.success("Pre-condition validation successful!")
|
|
586
|
+
return input_value
|
|
587
|
+
logger.info("Skip; no pre-condition validation was required!")
|
|
588
|
+
return input_value
|
|
495
589
|
|
|
496
|
-
def _validate_output(self, wrapped_self,
|
|
590
|
+
def _validate_output(self, wrapped_self, input_value, output, it, **remedy_kwargs):
|
|
497
591
|
logger.info("Starting output validation...")
|
|
498
|
-
self.f_type_validation_remedy.register_expected_data_model(
|
|
592
|
+
self.f_type_validation_remedy.register_expected_data_model(input_value, attach_to="input", override=True)
|
|
499
593
|
self.f_type_validation_remedy.register_expected_data_model(output, attach_to="output", override=True)
|
|
500
594
|
|
|
501
595
|
op_start = time.perf_counter()
|
|
@@ -505,7 +599,7 @@ class contract:
|
|
|
505
599
|
if output is None: # output is None when graceful mode is enabled
|
|
506
600
|
return output
|
|
507
601
|
except Exception as e:
|
|
508
|
-
logger.exception(
|
|
602
|
+
logger.exception("Type creation failed!")
|
|
509
603
|
raise e
|
|
510
604
|
finally:
|
|
511
605
|
wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
|
|
@@ -515,88 +609,111 @@ class contract:
|
|
|
515
609
|
logger.info("Validating post-conditions with remedy...")
|
|
516
610
|
if not hasattr(wrapped_self, "post"):
|
|
517
611
|
logger.error("Post-condition function not defined!")
|
|
518
|
-
|
|
612
|
+
msg = "Post-condition function not defined. Please define a `post` method if you want to enforce post-conditions through a remedy."
|
|
613
|
+
UserMessage(msg)
|
|
614
|
+
raise Exception(msg)
|
|
519
615
|
|
|
520
616
|
op_start = time.perf_counter()
|
|
521
617
|
try:
|
|
522
618
|
assert wrapped_self.post(output)
|
|
523
619
|
logger.success("Post-condition validation successful!")
|
|
524
620
|
return output
|
|
525
|
-
except Exception
|
|
621
|
+
except Exception:
|
|
526
622
|
logger.exception("Post-condition validation failed!")
|
|
527
623
|
output = self._try_remedy_with_exception(prompt=wrapped_self.prompt, f_semantic_conditions=[wrapped_self.post], **remedy_kwargs)
|
|
528
624
|
finally:
|
|
529
625
|
wrapped_self._contract_timing[it]["output_validation"] += (time.perf_counter() - op_start)
|
|
530
626
|
logger.success("Post-condition validation successful!")
|
|
531
627
|
return output
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
return output
|
|
628
|
+
if hasattr(wrapped_self, "post"):
|
|
629
|
+
logger.info("Validating post-conditions without remedy...")
|
|
630
|
+
op_start = time.perf_counter()
|
|
631
|
+
try:
|
|
632
|
+
assert wrapped_self.post(output)
|
|
633
|
+
except Exception as e:
|
|
634
|
+
logger.exception("Post-condition validation failed!")
|
|
635
|
+
raise e
|
|
636
|
+
finally:
|
|
637
|
+
wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
|
|
638
|
+
logger.success("Post-condition validation successful!")
|
|
639
|
+
return output
|
|
545
640
|
logger.info("Skip; no post-condition validation was required!")
|
|
546
641
|
return output
|
|
547
642
|
|
|
548
643
|
def _validate_act_method(self, act_method):
|
|
549
644
|
act_sig = inspect.signature(act_method)
|
|
550
645
|
params = list(act_sig.parameters.values())
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
646
|
+
|
|
647
|
+
first_param = None
|
|
648
|
+
for param in params:
|
|
649
|
+
if param.name == "self":
|
|
650
|
+
continue
|
|
651
|
+
if param.kind in (
|
|
652
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
653
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
654
|
+
):
|
|
655
|
+
first_param = param
|
|
656
|
+
break
|
|
657
|
+
|
|
658
|
+
if first_param is None:
|
|
659
|
+
msg = "'act' method must accept at least one positional parameter after `self`."
|
|
660
|
+
UserMessage(msg)
|
|
661
|
+
raise TypeError(msg)
|
|
662
|
+
if first_param.annotation == inspect._empty:
|
|
663
|
+
msg = f"'act' method parameter '{first_param.name}' must have a type annotation."
|
|
664
|
+
UserMessage(msg)
|
|
665
|
+
raise TypeError(msg)
|
|
555
666
|
if act_sig.return_annotation == inspect._empty:
|
|
556
|
-
|
|
667
|
+
msg = "'act' method must have a return type annotation'"
|
|
668
|
+
UserMessage(msg)
|
|
669
|
+
raise TypeError(msg)
|
|
557
670
|
return True
|
|
558
671
|
|
|
559
|
-
def _act(self, wrapped_self,
|
|
672
|
+
def _act(self, wrapped_self, input_value, it, **act_kwargs):
|
|
560
673
|
act_method = getattr(wrapped_self, 'act', None)
|
|
561
674
|
if not callable(act_method):
|
|
562
675
|
# Propagate the input if no act method is defined
|
|
563
|
-
return
|
|
676
|
+
return input_value
|
|
564
677
|
|
|
565
678
|
assert self._validate_act_method(act_method)
|
|
566
679
|
|
|
567
|
-
is_dynamic_model = getattr(
|
|
568
|
-
|
|
680
|
+
is_dynamic_model = getattr(input_value, '_is_dynamic_model', False)
|
|
681
|
+
input_value = input_value if not is_dynamic_model else input_value.value
|
|
569
682
|
|
|
570
683
|
logger.info(f"Executing 'act' method on {wrapped_self.__class__.__name__}…")
|
|
571
684
|
|
|
572
685
|
op_start = time.perf_counter()
|
|
573
686
|
try:
|
|
574
|
-
act_output = act_method(
|
|
687
|
+
act_output = act_method(input_value, **act_kwargs)
|
|
575
688
|
except Exception as e:
|
|
576
|
-
logger.exception(
|
|
689
|
+
logger.exception("'act' method execution failed")
|
|
577
690
|
raise e
|
|
578
691
|
finally:
|
|
579
692
|
wrapped_self._contract_timing[it]["act_execution"] = time.perf_counter() - op_start
|
|
580
693
|
|
|
581
694
|
act_sig = inspect.signature(act_method)
|
|
582
|
-
if
|
|
583
|
-
|
|
584
|
-
|
|
695
|
+
if (
|
|
696
|
+
act_sig.return_annotation != inspect.Signature.empty
|
|
697
|
+
and inspect.isclass(act_sig.return_annotation)
|
|
698
|
+
and not isinstance(act_output, act_sig.return_annotation)
|
|
699
|
+
):
|
|
700
|
+
msg = f"'act' method returned {type(act_output).__name__}, expected {act_sig.return_annotation.__name__}."
|
|
701
|
+
UserMessage(msg)
|
|
702
|
+
raise TypeError(msg)
|
|
585
703
|
|
|
586
704
|
logger.success("'act' method executed successfully!")
|
|
587
705
|
return act_output
|
|
588
706
|
|
|
589
|
-
def
|
|
590
|
-
original_init = cls.__init__
|
|
591
|
-
original_forward = cls.forward
|
|
592
|
-
|
|
707
|
+
def _build_wrapped_init(self, original_init):
|
|
593
708
|
def __init__(wrapped_self, *args, **kwargs):
|
|
594
709
|
logger.info("Initializing contract...")
|
|
595
710
|
original_init(wrapped_self, *args, **kwargs)
|
|
596
711
|
|
|
597
712
|
if not hasattr(wrapped_self, "prompt"):
|
|
598
713
|
logger.error("Prompt attribute not defined!")
|
|
599
|
-
|
|
714
|
+
msg = "Please define a static `prompt` attribute that describes what the contract must do."
|
|
715
|
+
UserMessage(msg)
|
|
716
|
+
raise Exception(msg)
|
|
600
717
|
|
|
601
718
|
wrapped_self.contract_successful = False
|
|
602
719
|
wrapped_self.contract_result = None
|
|
@@ -604,105 +721,207 @@ class contract:
|
|
|
604
721
|
wrapped_self._contract_timing = defaultdict(dict)
|
|
605
722
|
logger.info("Contract initialization complete!")
|
|
606
723
|
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
724
|
+
return __init__
|
|
725
|
+
|
|
726
|
+
def _start_contract_execution(self, wrapped_self):
|
|
727
|
+
logger.info("Starting contract execution...")
|
|
728
|
+
it = len(wrapped_self._contract_timing) # the len is the __call__ op_start
|
|
729
|
+
return it, time.perf_counter()
|
|
730
|
+
|
|
731
|
+
def _find_input_param_name(self, sig: inspect.Signature) -> str | None:
|
|
732
|
+
for param in sig.parameters.values():
|
|
733
|
+
if param.name == "self":
|
|
734
|
+
continue
|
|
735
|
+
if param.kind in (
|
|
736
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
737
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
738
|
+
inspect.Parameter.KEYWORD_ONLY,
|
|
739
|
+
):
|
|
740
|
+
return param.name
|
|
741
|
+
return None
|
|
611
742
|
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
input_type = self._try_dynamic_type_annotation(original_forward, context="input")
|
|
618
|
-
input = input_type(value=input)
|
|
619
|
-
|
|
620
|
-
maybe_payload = getattr(wrapped_self, "payload", None)
|
|
621
|
-
maybe_template = getattr(wrapped_self, "template")
|
|
622
|
-
if inspect.ismethod(maybe_template):
|
|
623
|
-
# `template` is a primitive in symbolicai case in which we actually don't have a template
|
|
624
|
-
maybe_template = None
|
|
625
|
-
|
|
626
|
-
# Create validation kwargs that include all original kwargs plus payload and template
|
|
627
|
-
validation_kwargs = {
|
|
628
|
-
**kwargs,
|
|
629
|
-
"payload": maybe_payload,
|
|
630
|
-
"template_suffix": maybe_template
|
|
631
|
-
}
|
|
743
|
+
def _prepare_forward_args(self, args, kwargs):
|
|
744
|
+
args_list = list(args)
|
|
745
|
+
kwargs_without_input = dict(kwargs)
|
|
746
|
+
original_kwargs = dict(kwargs)
|
|
747
|
+
return args_list, kwargs_without_input, original_kwargs
|
|
632
748
|
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
749
|
+
def _extract_input_value(
|
|
750
|
+
self,
|
|
751
|
+
args_list,
|
|
752
|
+
kwargs_without_input,
|
|
753
|
+
original_kwargs,
|
|
754
|
+
input_param_name: str | None,
|
|
755
|
+
):
|
|
756
|
+
if args_list:
|
|
757
|
+
return args_list[0], ("args", 0)
|
|
758
|
+
if input_param_name and input_param_name in kwargs_without_input:
|
|
759
|
+
return kwargs_without_input.pop(input_param_name), ("kwargs", input_param_name)
|
|
760
|
+
if "input" in kwargs_without_input:
|
|
761
|
+
return kwargs_without_input.pop("input"), ("kwargs", "input")
|
|
762
|
+
return original_kwargs.get("input"), ("fallback_kw", "input")
|
|
763
|
+
|
|
764
|
+
def _coerce_input_value(self, original_forward, input_value):
|
|
765
|
+
try:
|
|
766
|
+
assert self._is_valid_input(input_value)
|
|
767
|
+
return input_value
|
|
768
|
+
except TypeError:
|
|
769
|
+
input_type_model = self._try_dynamic_type_annotation(original_forward, context="input")
|
|
770
|
+
return input_type_model(value=input_value)
|
|
771
|
+
|
|
772
|
+
def _collect_validation_kwargs(self, wrapped_self, kwargs_without_input):
|
|
773
|
+
maybe_payload = getattr(wrapped_self, "payload", None)
|
|
774
|
+
maybe_template = getattr(wrapped_self, "template", None)
|
|
775
|
+
if inspect.ismethod(maybe_template):
|
|
776
|
+
maybe_template = None
|
|
777
|
+
return {
|
|
778
|
+
**kwargs_without_input,
|
|
779
|
+
"payload": maybe_payload,
|
|
780
|
+
"template_suffix": maybe_template,
|
|
781
|
+
}
|
|
782
|
+
|
|
783
|
+
def _resolve_output_type(self, sig: inspect.Signature, original_forward):
|
|
784
|
+
output_type = sig.return_annotation
|
|
785
|
+
try:
|
|
786
|
+
assert self._is_valid_output(output_type)
|
|
787
|
+
except TypeError:
|
|
788
|
+
output_type = self._try_dynamic_type_annotation(original_forward, context="output")
|
|
789
|
+
return output_type
|
|
639
790
|
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
791
|
+
def _run_contract_pipeline(
|
|
792
|
+
self,
|
|
793
|
+
wrapped_self,
|
|
794
|
+
current_input_value,
|
|
795
|
+
output_type,
|
|
796
|
+
it,
|
|
797
|
+
validation_kwargs,
|
|
798
|
+
):
|
|
799
|
+
output = None
|
|
800
|
+
try:
|
|
801
|
+
maybe_new_input = self._validate_input(wrapped_self, current_input_value, it, **validation_kwargs)
|
|
802
|
+
if maybe_new_input is not None:
|
|
803
|
+
current_input_value = maybe_new_input
|
|
804
|
+
|
|
805
|
+
current_input_value = self._act(wrapped_self, current_input_value, it, **validation_kwargs)
|
|
806
|
+
|
|
807
|
+
output = self._validate_output(
|
|
808
|
+
wrapped_self,
|
|
809
|
+
current_input_value,
|
|
810
|
+
output_type,
|
|
811
|
+
it,
|
|
812
|
+
**validation_kwargs,
|
|
813
|
+
)
|
|
814
|
+
wrapped_self.contract_successful = output is not None
|
|
815
|
+
wrapped_self.contract_result = output
|
|
816
|
+
wrapped_self.contract_exception = None
|
|
817
|
+
except Exception as exc:
|
|
818
|
+
logger.exception("Contract execution failed in main path!")
|
|
819
|
+
wrapped_self.contract_successful = False
|
|
820
|
+
wrapped_self.contract_exception = exc
|
|
821
|
+
return output, current_input_value
|
|
647
822
|
|
|
648
|
-
|
|
649
|
-
|
|
823
|
+
def _execute_forward_call(
|
|
824
|
+
self,
|
|
825
|
+
wrapped_self,
|
|
826
|
+
original_forward,
|
|
827
|
+
args_list,
|
|
828
|
+
original_kwargs,
|
|
829
|
+
input_param_name,
|
|
830
|
+
input_source,
|
|
831
|
+
forward_input_value,
|
|
832
|
+
it,
|
|
833
|
+
contract_start,
|
|
834
|
+
):
|
|
835
|
+
forward_kwargs = original_kwargs.copy()
|
|
836
|
+
for internal_kw in self._internal_forward_kwargs:
|
|
837
|
+
forward_kwargs.pop(internal_kw, None)
|
|
838
|
+
|
|
839
|
+
logger.info("Executing original forward method...")
|
|
840
|
+
|
|
841
|
+
if input_param_name:
|
|
842
|
+
if input_param_name in forward_kwargs or input_source == ("kwargs", input_param_name):
|
|
843
|
+
forward_kwargs[input_param_name] = forward_input_value
|
|
844
|
+
elif input_source == ("args", 0) and args_list:
|
|
845
|
+
args_list[0] = forward_input_value
|
|
846
|
+
else:
|
|
847
|
+
forward_kwargs[input_param_name] = forward_input_value
|
|
848
|
+
else:
|
|
849
|
+
forward_kwargs['input'] = forward_input_value
|
|
650
850
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
wrapped_self.contract_successful = True if output is not None else False
|
|
654
|
-
wrapped_self.contract_result = output
|
|
655
|
-
wrapped_self.contract_exception = None
|
|
851
|
+
if input_param_name and input_param_name != "input" and "input" in forward_kwargs:
|
|
852
|
+
forward_kwargs.pop("input")
|
|
656
853
|
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
finally:
|
|
665
|
-
# Execute the original forward method with appropriate input
|
|
666
|
-
logger.info("Executing original forward method...")
|
|
667
|
-
|
|
668
|
-
# If contract was successful, use the processed input (after pre-validation and act, both optional)
|
|
669
|
-
# `current_input` at this stage is the result of the try block's processing up to the point of exception,
|
|
670
|
-
# or the full processing if successful.
|
|
671
|
-
# If contract failed, use original_input (fallback).
|
|
672
|
-
forward_input = current_input if wrapped_self.contract_successful else input
|
|
673
|
-
|
|
674
|
-
# Prepare kwargs for original_forward
|
|
675
|
-
forward_kwargs = kwargs.copy()
|
|
676
|
-
forward_kwargs['input'] = forward_input
|
|
677
|
-
|
|
678
|
-
try:
|
|
679
|
-
op_start = time.perf_counter()
|
|
680
|
-
output = original_forward(wrapped_self, **forward_kwargs)
|
|
681
|
-
finally:
|
|
682
|
-
wrapped_self._contract_timing[it]["forward_execution"] = time.perf_counter() - op_start
|
|
683
|
-
wrapped_self._contract_timing[it]["contract_execution"] = time.perf_counter() - contract_start
|
|
684
|
-
|
|
685
|
-
if not isinstance(output, output_type):
|
|
686
|
-
logger.error(f"Output type mismatch: {type(output)}")
|
|
687
|
-
if self.remedy_retry_params["graceful"]:
|
|
688
|
-
# In graceful mode, skip type mismatch error and return raw output
|
|
689
|
-
if hasattr(output_type, '_is_dynamic_model') and output_type._is_dynamic_model and hasattr(output, 'value'):
|
|
690
|
-
return output.value
|
|
691
|
-
return output
|
|
692
|
-
raise TypeError(
|
|
693
|
-
f"Expected output to be an instance of {output_type}, "
|
|
694
|
-
f"but got {type(output)}! Forward method must return an instance of {output_type}!"
|
|
695
|
-
)
|
|
696
|
-
if not wrapped_self.contract_successful:
|
|
697
|
-
logger.warning("Contract validation failed!")
|
|
698
|
-
else:
|
|
699
|
-
logger.success("Contract validation successful!")
|
|
854
|
+
try:
|
|
855
|
+
op_start = time.perf_counter()
|
|
856
|
+
output = original_forward(wrapped_self, *args_list, **forward_kwargs)
|
|
857
|
+
finally:
|
|
858
|
+
wrapped_self._contract_timing[it]["forward_execution"] = time.perf_counter() - op_start
|
|
859
|
+
wrapped_self._contract_timing[it]["contract_execution"] = time.perf_counter() - contract_start
|
|
860
|
+
return output
|
|
700
861
|
|
|
701
|
-
|
|
702
|
-
|
|
862
|
+
def _finalize_contract_output(self, output, output_type, wrapped_self):
|
|
863
|
+
if not isinstance(output, output_type):
|
|
864
|
+
logger.error(f"Output type mismatch: {type(output)}")
|
|
865
|
+
if self.remedy_retry_params["graceful"]:
|
|
866
|
+
if getattr(output_type, '_is_dynamic_model', False) and hasattr(output, 'value'):
|
|
867
|
+
return output.value
|
|
868
|
+
return output
|
|
869
|
+
msg = (
|
|
870
|
+
f"Expected output to be an instance of {output_type}, "
|
|
871
|
+
f"but got {type(output)}! Forward method must return an instance of {output_type}!"
|
|
872
|
+
)
|
|
873
|
+
UserMessage(msg)
|
|
874
|
+
raise TypeError(msg)
|
|
875
|
+
if not wrapped_self.contract_successful:
|
|
876
|
+
logger.warning("Contract validation failed!")
|
|
877
|
+
else:
|
|
878
|
+
logger.success("Contract validation successful!")
|
|
703
879
|
|
|
704
|
-
|
|
880
|
+
if getattr(output_type, '_is_dynamic_model', False):
|
|
881
|
+
return output.value
|
|
882
|
+
return output
|
|
705
883
|
|
|
884
|
+
def _contract_forward_impl(self, wrapped_self, original_forward, *args, **kwargs):
|
|
885
|
+
it, contract_start = self._start_contract_execution(wrapped_self)
|
|
886
|
+
sig = inspect.signature(original_forward)
|
|
887
|
+
input_param_name = self._find_input_param_name(sig)
|
|
888
|
+
args_list, kwargs_without_input, original_kwargs = self._prepare_forward_args(args, kwargs)
|
|
889
|
+
input_value, input_source = self._extract_input_value(args_list, kwargs_without_input, original_kwargs, input_param_name)
|
|
890
|
+
current_input_value = self._coerce_input_value(original_forward, input_value)
|
|
891
|
+
input_value = current_input_value
|
|
892
|
+
validation_kwargs = self._collect_validation_kwargs(wrapped_self, kwargs_without_input)
|
|
893
|
+
output_type = self._resolve_output_type(sig, original_forward)
|
|
894
|
+
|
|
895
|
+
output, current_input_value = self._run_contract_pipeline(
|
|
896
|
+
wrapped_self,
|
|
897
|
+
current_input_value,
|
|
898
|
+
output_type,
|
|
899
|
+
it,
|
|
900
|
+
validation_kwargs,
|
|
901
|
+
)
|
|
902
|
+
|
|
903
|
+
forward_input_value = current_input_value if wrapped_self.contract_successful else input_value
|
|
904
|
+
output = self._execute_forward_call(
|
|
905
|
+
wrapped_self,
|
|
906
|
+
original_forward,
|
|
907
|
+
args_list,
|
|
908
|
+
original_kwargs,
|
|
909
|
+
input_param_name,
|
|
910
|
+
input_source,
|
|
911
|
+
forward_input_value,
|
|
912
|
+
it,
|
|
913
|
+
contract_start,
|
|
914
|
+
)
|
|
915
|
+
|
|
916
|
+
return self._finalize_contract_output(output, output_type, wrapped_self)
|
|
917
|
+
|
|
918
|
+
def _build_wrapped_forward(self, original_forward):
|
|
919
|
+
def wrapped_forward(wrapped_self, *args, **kwargs):
|
|
920
|
+
return self._contract_forward_impl(wrapped_self, original_forward, *args, **kwargs)
|
|
921
|
+
|
|
922
|
+
return wrapped_forward
|
|
923
|
+
|
|
924
|
+
def _build_contract_perf_stats(self):
|
|
706
925
|
def contract_perf_stats(wrapped_self):
|
|
707
926
|
"""Analyzes and prints timing statistics across all forward calls."""
|
|
708
927
|
console = Console()
|
|
@@ -828,17 +1047,29 @@ class contract:
|
|
|
828
1047
|
|
|
829
1048
|
return stats
|
|
830
1049
|
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
1050
|
+
return contract_perf_stats
|
|
1051
|
+
|
|
1052
|
+
def __call__(self, cls):
|
|
1053
|
+
original_init = cls.__init__
|
|
1054
|
+
original_forward = cls.forward
|
|
834
1055
|
|
|
1056
|
+
cls.__init__ = self._build_wrapped_init(original_init)
|
|
1057
|
+
cls.forward = self._build_wrapped_forward(original_forward)
|
|
1058
|
+
cls.contract_perf_stats = self._build_contract_perf_stats()
|
|
835
1059
|
return cls
|
|
836
1060
|
|
|
837
1061
|
|
|
838
1062
|
class BaseStrategy(TypeValidationFunction):
|
|
839
|
-
def __init__(self, data_model: BaseModel, *
|
|
1063
|
+
def __init__(self, data_model: BaseModel, *_args, **kwargs):
|
|
840
1064
|
super().__init__(
|
|
841
|
-
retry_params=
|
|
1065
|
+
retry_params={
|
|
1066
|
+
"tries": 8,
|
|
1067
|
+
"delay": 0.015,
|
|
1068
|
+
"backoff": 1.25,
|
|
1069
|
+
"jitter": 0.0,
|
|
1070
|
+
"max_delay": 0.25,
|
|
1071
|
+
"graceful": False,
|
|
1072
|
+
},
|
|
842
1073
|
**kwargs,
|
|
843
1074
|
)
|
|
844
1075
|
super().register_expected_data_model(data_model, attach_to="output")
|
|
@@ -851,14 +1082,13 @@ class BaseStrategy(TypeValidationFunction):
|
|
|
851
1082
|
pass
|
|
852
1083
|
|
|
853
1084
|
def forward(self, *args, **kwargs):
|
|
854
|
-
|
|
1085
|
+
return super().forward(
|
|
855
1086
|
*args,
|
|
856
1087
|
payload=self.payload,
|
|
857
1088
|
template_suffix=self.template,
|
|
858
1089
|
response_format={"type": "json_object"},
|
|
859
1090
|
**kwargs,
|
|
860
1091
|
)
|
|
861
|
-
return result
|
|
862
1092
|
|
|
863
1093
|
@property
|
|
864
1094
|
def payload(self):
|
|
@@ -866,7 +1096,7 @@ class BaseStrategy(TypeValidationFunction):
|
|
|
866
1096
|
|
|
867
1097
|
@property
|
|
868
1098
|
def static_context(self):
|
|
869
|
-
raise NotImplementedError
|
|
1099
|
+
raise NotImplementedError
|
|
870
1100
|
|
|
871
1101
|
@property
|
|
872
1102
|
def template(self):
|
|
@@ -878,7 +1108,7 @@ class Strategy(Expression):
|
|
|
878
1108
|
super().__init__(*args, **kwargs)
|
|
879
1109
|
self.logger = logging.getLogger(__name__)
|
|
880
1110
|
|
|
881
|
-
def __new__(
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
return Strategy.load_module_class(
|
|
1111
|
+
def __new__(cls, module: str, *_args, **_kwargs):
|
|
1112
|
+
cls._module = module
|
|
1113
|
+
cls.module_path = 'symai.extended.strategies'
|
|
1114
|
+
return Strategy.load_module_class(cls.module)
|