symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134):
  1. symai/__init__.py +269 -173
  2. symai/backend/base.py +123 -110
  3. symai/backend/engines/drawing/engine_bfl.py +45 -44
  4. symai/backend/engines/drawing/engine_gpt_image.py +112 -97
  5. symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
  6. symai/backend/engines/embedding/engine_openai.py +25 -21
  7. symai/backend/engines/execute/engine_python.py +19 -18
  8. symai/backend/engines/files/engine_io.py +104 -95
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
  11. symai/backend/engines/index/engine_pinecone.py +124 -97
  12. symai/backend/engines/index/engine_qdrant.py +1011 -0
  13. symai/backend/engines/index/engine_vectordb.py +84 -56
  14. symai/backend/engines/lean/engine_lean4.py +96 -52
  15. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
  17. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
  18. symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
  19. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
  20. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
  21. symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
  22. symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
  23. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
  24. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
  25. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
  26. symai/backend/engines/ocr/engine_apilayer.py +23 -27
  27. symai/backend/engines/output/engine_stdout.py +10 -13
  28. symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
  29. symai/backend/engines/search/engine_openai.py +100 -88
  30. symai/backend/engines/search/engine_parallel.py +665 -0
  31. symai/backend/engines/search/engine_perplexity.py +44 -45
  32. symai/backend/engines/search/engine_serpapi.py +37 -34
  33. symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
  34. symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
  35. symai/backend/engines/text_to_speech/engine_openai.py +20 -26
  36. symai/backend/engines/text_vision/engine_clip.py +39 -37
  37. symai/backend/engines/userinput/engine_console.py +5 -6
  38. symai/backend/mixin/__init__.py +13 -0
  39. symai/backend/mixin/anthropic.py +48 -38
  40. symai/backend/mixin/deepseek.py +6 -5
  41. symai/backend/mixin/google.py +7 -4
  42. symai/backend/mixin/groq.py +2 -4
  43. symai/backend/mixin/openai.py +140 -110
  44. symai/backend/settings.py +87 -20
  45. symai/chat.py +216 -123
  46. symai/collect/__init__.py +7 -1
  47. symai/collect/dynamic.py +80 -70
  48. symai/collect/pipeline.py +67 -51
  49. symai/collect/stats.py +161 -109
  50. symai/components.py +707 -360
  51. symai/constraints.py +24 -12
  52. symai/core.py +1857 -1233
  53. symai/core_ext.py +83 -80
  54. symai/endpoints/api.py +166 -104
  55. symai/extended/.DS_Store +0 -0
  56. symai/extended/__init__.py +46 -12
  57. symai/extended/api_builder.py +29 -21
  58. symai/extended/arxiv_pdf_parser.py +23 -14
  59. symai/extended/bibtex_parser.py +9 -6
  60. symai/extended/conversation.py +156 -126
  61. symai/extended/document.py +50 -30
  62. symai/extended/file_merger.py +57 -14
  63. symai/extended/graph.py +51 -32
  64. symai/extended/html_style_template.py +18 -14
  65. symai/extended/interfaces/blip_2.py +2 -3
  66. symai/extended/interfaces/clip.py +4 -3
  67. symai/extended/interfaces/console.py +9 -1
  68. symai/extended/interfaces/dall_e.py +4 -2
  69. symai/extended/interfaces/file.py +2 -0
  70. symai/extended/interfaces/flux.py +4 -2
  71. symai/extended/interfaces/gpt_image.py +16 -7
  72. symai/extended/interfaces/input.py +2 -1
  73. symai/extended/interfaces/llava.py +1 -2
  74. symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
  75. symai/extended/interfaces/naive_vectordb.py +9 -10
  76. symai/extended/interfaces/ocr.py +5 -3
  77. symai/extended/interfaces/openai_search.py +2 -0
  78. symai/extended/interfaces/parallel.py +30 -0
  79. symai/extended/interfaces/perplexity.py +2 -0
  80. symai/extended/interfaces/pinecone.py +12 -9
  81. symai/extended/interfaces/python.py +2 -0
  82. symai/extended/interfaces/serpapi.py +3 -1
  83. symai/extended/interfaces/terminal.py +2 -4
  84. symai/extended/interfaces/tts.py +3 -2
  85. symai/extended/interfaces/whisper.py +3 -2
  86. symai/extended/interfaces/wolframalpha.py +2 -1
  87. symai/extended/metrics/__init__.py +11 -1
  88. symai/extended/metrics/similarity.py +14 -13
  89. symai/extended/os_command.py +39 -29
  90. symai/extended/packages/__init__.py +29 -3
  91. symai/extended/packages/symdev.py +51 -43
  92. symai/extended/packages/sympkg.py +41 -35
  93. symai/extended/packages/symrun.py +63 -50
  94. symai/extended/repo_cloner.py +14 -12
  95. symai/extended/seo_query_optimizer.py +15 -13
  96. symai/extended/solver.py +116 -91
  97. symai/extended/summarizer.py +12 -10
  98. symai/extended/taypan_interpreter.py +17 -18
  99. symai/extended/vectordb.py +122 -92
  100. symai/formatter/__init__.py +9 -1
  101. symai/formatter/formatter.py +51 -47
  102. symai/formatter/regex.py +70 -69
  103. symai/functional.py +325 -176
  104. symai/imports.py +190 -147
  105. symai/interfaces.py +57 -28
  106. symai/memory.py +45 -35
  107. symai/menu/screen.py +28 -19
  108. symai/misc/console.py +66 -56
  109. symai/misc/loader.py +8 -5
  110. symai/models/__init__.py +17 -1
  111. symai/models/base.py +395 -236
  112. symai/models/errors.py +1 -2
  113. symai/ops/__init__.py +32 -22
  114. symai/ops/measures.py +24 -25
  115. symai/ops/primitives.py +1149 -731
  116. symai/post_processors.py +58 -50
  117. symai/pre_processors.py +86 -82
  118. symai/processor.py +21 -13
  119. symai/prompts.py +764 -685
  120. symai/server/huggingface_server.py +135 -49
  121. symai/server/llama_cpp_server.py +21 -11
  122. symai/server/qdrant_server.py +206 -0
  123. symai/shell.py +100 -42
  124. symai/shellsv.py +700 -492
  125. symai/strategy.py +630 -346
  126. symai/symbol.py +368 -322
  127. symai/utils.py +100 -78
  128. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
  129. symbolicai-1.1.0.dist-info/RECORD +168 -0
  130. symbolicai-0.21.0.dist-info/RECORD +0 -162
  131. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
  132. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
  133. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
  134. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/strategy.py CHANGED
@@ -2,7 +2,8 @@ import inspect
2
2
  import logging
3
3
  import time
4
4
  from collections import defaultdict
5
- from typing import Callable
5
+ from collections.abc import Callable
6
+ from typing import Any, ClassVar
6
7
 
7
8
  import numpy as np
8
9
  from beartype import beartype
@@ -14,9 +15,9 @@ from rich.panel import Panel
14
15
  from rich.table import Table
15
16
 
16
17
  from .components import Function
17
- from .models import (LLMDataModel, TypeValidationError,
18
- build_dynamic_llm_datamodel)
18
+ from .models import LLMDataModel, TypeValidationError, build_dynamic_llm_datamodel
19
19
  from .symbol import Expression
20
+ from .utils import UserMessage
20
21
 
21
22
 
22
23
  class ValidationFunction(Function):
@@ -28,19 +29,20 @@ class ValidationFunction(Function):
28
29
  • Pause/backoff logic
29
30
  • Error simplification
30
31
  """
32
+
31
33
  # Have some default retry params that don't add overhead
32
- _default_retry_params = dict(
33
- tries=8,
34
- delay=0.015,
35
- backoff=1.25,
36
- jitter=0.0,
37
- max_delay=0.25,
38
- graceful=False
39
- )
34
+ _default_retry_params: ClassVar[dict[str, int | float | bool]] = {
35
+ "tries": 8,
36
+ "delay": 0.015,
37
+ "backoff": 1.25,
38
+ "jitter": 0.0,
39
+ "max_delay": 0.25,
40
+ "graceful": False,
41
+ }
40
42
 
41
43
  def __init__(
42
44
  self,
43
- retry_params: dict[str, int | float | bool] = None,
45
+ retry_params: dict[str, int | float | bool] | None = None,
44
46
  verbose: bool = False,
45
47
  *args,
46
48
  **kwargs,
@@ -92,10 +94,7 @@ class ValidationFunction(Function):
92
94
  seed = 42
93
95
 
94
96
  rnd = np.random.RandomState(seed=seed)
95
- seeds = rnd.randint(
96
- 0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16
97
- ).tolist()
98
- return seeds
97
+ return rnd.randint(0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16).tolist()
99
98
 
100
99
  def simplify_validation_errors(self, error: ValidationError) -> str:
101
100
  """
@@ -123,21 +122,25 @@ class ValidationFunction(Function):
123
122
  return "\n".join(simplified_errors)
124
123
 
125
124
  def _pause(self, attempt):
126
- base = self.retry_params['delay'] * (self.retry_params['backoff'] ** attempt)
127
- jit = (np.random.uniform(*self.retry_params['jitter'])
128
- if isinstance(self.retry_params['jitter'], tuple)
129
- else self.retry_params['jitter'])
130
- _delay = min(base + jit, self.retry_params['max_delay'])
125
+ base = self.retry_params["delay"] * (self.retry_params["backoff"] ** attempt)
126
+ jit = (
127
+ np.random.uniform(*self.retry_params["jitter"])
128
+ if isinstance(self.retry_params["jitter"], tuple)
129
+ else self.retry_params["jitter"]
130
+ )
131
+ _delay = min(base + jit, self.retry_params["max_delay"])
131
132
  time.sleep(_delay)
132
133
 
133
- def remedy_prompt(self, *args, **kwargs):
134
+ def remedy_prompt(self, *_args, **_kwargs):
134
135
  """
135
136
  Abstract or base remedy prompt method.
136
137
  Child classes typically override this to include additional context needed for correction.
137
138
  """
138
- raise NotImplementedError("Each child class needs its own remedy_prompt implementation.")
139
+ msg = "Each child class needs its own remedy_prompt implementation."
140
+ UserMessage(msg)
141
+ raise NotImplementedError(msg)
139
142
 
140
- def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1,2)):
143
+ def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1, 2)):
141
144
  """
142
145
  Display content in a rich panel with consistent formatting.
143
146
 
@@ -149,11 +152,13 @@ class ValidationFunction(Function):
149
152
  padding: Padding for the panel (default: (1,2))
150
153
  """
151
154
  body = escape(content)
152
- panel = Panel.fit(body, title=title, padding=padding, border_style=border_style, style=style)
155
+ panel = Panel.fit(
156
+ body, title=title, padding=padding, border_style=border_style, style=style
157
+ )
153
158
  self.console.print(panel)
154
159
 
155
160
  def forward(self, *args, **kwargs):
156
- return super().forward(*args, **kwargs) # Just propagate to Function
161
+ return super().forward(*args, **kwargs) # Just propagate to Function
157
162
 
158
163
 
159
164
  class TypeValidationFunction(ValidationFunction):
@@ -163,6 +168,7 @@ class TypeValidationFunction(ValidationFunction):
163
168
  if a user provides a callable designed to semantically validate the
164
169
  structure of the type-validated data.
165
170
  """
171
+
166
172
  def __init__(
167
173
  self,
168
174
  retry_params: dict[str, int | float | bool] = ValidationFunction._default_retry_params,
@@ -171,21 +177,29 @@ class TypeValidationFunction(ValidationFunction):
171
177
  *args,
172
178
  **kwargs,
173
179
  ):
174
- super().__init__(retry_params=retry_params, verbose=verbose, *args, **kwargs)
180
+ super().__init__(*args, retry_params=retry_params, verbose=verbose, **kwargs)
175
181
  self.input_data_model = None
176
182
  self.output_data_model = None
177
183
  self.accumulate_errors = accumulate_errors
178
184
  self.verbose = verbose
179
185
 
180
- def register_expected_data_model(self, data_model: LLMDataModel, attach_to: str, override: bool = False):
181
- assert attach_to in ["input", "output"], f"Invalid attach_to value: {attach_to}; must be either 'input' or 'output'"
186
+ def register_expected_data_model(
187
+ self, data_model: LLMDataModel, attach_to: str, override: bool = False
188
+ ):
189
+ assert attach_to in ["input", "output"], (
190
+ f"Invalid attach_to value: {attach_to}; must be either 'input' or 'output'"
191
+ )
182
192
  if attach_to == "input":
183
193
  if self.input_data_model is not None and not override:
184
- raise ValueError("There is already a data model attached to the input. If you want to override it, set `override=True`.")
194
+ msg = "There is already a data model attached to the input. If you want to override it, set `override=True`."
195
+ UserMessage(msg)
196
+ raise ValueError(msg)
185
197
  self.input_data_model = data_model
186
198
  elif attach_to == "output":
187
199
  if self.output_data_model is not None and not override:
188
- raise ValueError("There is already a data model attached to the output. If you want to override it, set `override=True`.")
200
+ msg = "There is already a data model attached to the output. If you want to override it, set `override=True`."
201
+ UserMessage(msg)
202
+ raise ValueError(msg)
189
203
  self.output_data_model = data_model
190
204
 
191
205
  def remedy_prompt(self, prompt: str, output: str, errors: str) -> str:
@@ -200,12 +214,12 @@ Your prompt was:
200
214
 
201
215
  The input data model is:
202
216
  <input_data_model>
203
- {self.input_data_model.simplify_json_schema() if self.input_data_model is not None else 'N/A'}
217
+ {self.input_data_model.simplify_json_schema() if self.input_data_model is not None else "N/A"}
204
218
  </input_data_model>
205
219
 
206
220
  The given input was:
207
221
  <input>
208
- {str(self.input_data_model) if self.input_data_model is not None else 'N/A'}
222
+ {str(self.input_data_model) if self.input_data_model is not None else "N/A"}
209
223
  </input>
210
224
 
211
225
  The output data model is:
@@ -247,12 +261,12 @@ You are given the following prompt:
247
261
 
248
262
  The input data model is:
249
263
  <input_data_model>
250
- {self.input_data_model.simplify_json_schema() if self.input_data_model is not None else 'N/A'}
264
+ {self.input_data_model.simplify_json_schema() if self.input_data_model is not None else "N/A"}
251
265
  </input_data_model>
252
266
 
253
267
  The given input is:
254
268
  <input>
255
- {str(self.input_data_model) if self.input_data_model is not None else 'N/A'}
269
+ {str(self.input_data_model) if self.input_data_model is not None else "N/A"}
256
270
  </input>
257
271
 
258
272
  The output data model is:
@@ -268,118 +282,185 @@ Important guidelines:
268
282
  </guidelines>
269
283
  """
270
284
 
271
- def forward(self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs):
285
+ def _ensure_output_model(self):
272
286
  if self.output_data_model is None:
273
- raise ValueError("While the input data model is optional, the output data model must be provided. Please register it before calling the `forward` method.")
274
- validation_context = kwargs.pop('validation_context', {})
275
- # Force JSON mode
276
- kwargs["response_format"] = {"type": "json_object"}
277
- logger.info("Initializing validation…")
287
+ msg = (
288
+ "While the input data model is optional, the output data model must be provided. "
289
+ "Please register it before calling the `forward` method."
290
+ )
291
+ UserMessage(msg)
292
+ raise ValueError(msg)
293
+
294
+ def _display_verbose_panels(self, prompt: str):
295
+ if not self.verbose:
296
+ return
297
+ for label, body in [
298
+ ("Prompt", prompt),
299
+ (
300
+ "Input data model",
301
+ self.input_data_model.simplify_json_schema() if self.input_data_model else "N/A",
302
+ ),
303
+ ("Output data model", self.output_data_model.simplify_json_schema()),
304
+ ]:
305
+ self.display_panel(body, title=label)
306
+
307
+ def _check_semantic_conditions(
308
+ self,
309
+ result,
310
+ f_semantic_conditions: list[Callable] | None,
311
+ ) -> str | None:
312
+ if f_semantic_conditions is None:
313
+ return None
314
+ try:
315
+ assert all(
316
+ f(
317
+ result
318
+ if not getattr(self.output_data_model, "_is_dynamic_model", False)
319
+ else result.value
320
+ )
321
+ for f in f_semantic_conditions
322
+ )
323
+ except Exception as err:
324
+ return f"Semantic validation failed with:\n{err!s}"
325
+ return None
326
+
327
+ def _format_validation_error(self, error: Exception) -> str:
328
+ if isinstance(error, ValidationError):
329
+ return self.simplify_validation_errors(error)
330
+ return str(error)
331
+
332
+ def _handle_failed_validation_attempt(
333
+ self,
334
+ attempt_index: int,
335
+ prompt: str,
336
+ json_str: str,
337
+ errors: list[str],
338
+ error: Exception,
339
+ remedy_seeds: list[Any],
340
+ kwargs: dict,
341
+ ) -> str:
342
+ logger.info(f"Validation attempt {attempt_index + 1} failed, pausing before retry…")
343
+ self._pause(attempt_index)
344
+ error_str = self._format_validation_error(error)
345
+ errors.append(error_str)
346
+
347
+ logger.error("Validation errors identified!")
278
348
  if self.verbose:
279
- for label, body in [
280
- ("Prompt", prompt),
281
- ("Input data model", self.input_data_model.simplify_json_schema() if self.input_data_model else 'N/A'),
282
- ("Output data model", self.output_data_model.simplify_json_schema()),
283
- ]:
284
- self.display_panel(body, title=label)
285
-
286
- # Zero shot the task
287
- context = self.zero_shot_prompt(prompt=prompt)
288
- json_str = super().forward(context, *args, **kwargs).value
349
+ errors_report = "\n".join(errors) if self.accumulate_errors else error_str
350
+ title = f"Validation Errors ({'accumulated errors' if self.accumulate_errors else 'last error'})"
351
+ self.display_panel(errors_report, title=title, border_style="red")
352
+
353
+ logger.info("Updating remedy function context…")
354
+ context = self.remedy_prompt(
355
+ prompt=prompt,
356
+ output=json_str,
357
+ errors="\n".join(errors) if self.accumulate_errors else error_str,
358
+ )
359
+ self.remedy_function.clear()
360
+ self.remedy_function.adapt(context)
361
+ if self.verbose:
362
+ self.display_panel(self.remedy_function.dynamic_context, title="New Context")
289
363
 
290
- remedy_seeds = self.prepare_seeds(self.retry_params["tries"] + 1, **kwargs)
291
- logger.info(f"Prepared {len(remedy_seeds)} remedy seeds for validation attempts…")
364
+ json_str = self.remedy_function(seed=remedy_seeds[attempt_index], **kwargs).value
365
+ logger.info("Applied remedy function with updated context!")
366
+ return json_str
292
367
 
368
+ def _run_validation_attempts(
369
+ self,
370
+ prompt: str,
371
+ f_semantic_conditions: list[Callable] | None,
372
+ validation_context: dict,
373
+ remedy_seeds: list[Any],
374
+ json_str: str,
375
+ kwargs: dict,
376
+ ) -> tuple[Any | None, str, list[str]]:
377
+ errors: list[str] = []
293
378
  result = None
294
- errors = []
295
- for i in range(self.retry_params["tries"] + 1):
296
- if i != self.retry_params["tries"]:
297
- logger.info(f"Attempt {i+1}/{self.retry_params['tries']}: Attempting validation…")
379
+ total_attempts = self.retry_params["tries"] + 1
380
+ for attempt in range(total_attempts):
381
+ if attempt != self.retry_params["tries"]:
382
+ logger.info(
383
+ f"Attempt {attempt + 1}/{self.retry_params['tries']}: Attempting validation…"
384
+ )
298
385
  try:
299
- #@NOTE: We use strict=False (default) to allow Pydantic's type coercion.
300
- # This handles common LLM output issues like:
301
- # - String keys "123" -> integer keys 123 for dict[int, ...] types
302
- # - String numbers "42" -> int 42
303
- # - Float 1.0 -> int 1
304
- # These coercions are safe and helpful when dealing with JSON from LLMs,
305
- # while still catching actual type errors (e.g., "not_a_number" -> int fails).
306
- result = self.output_data_model.model_validate_json(json_str, strict=False, context=validation_context)
307
- if f_semantic_conditions is not None:
308
- try:
309
- assert all(
310
- f(result if not getattr(self.output_data_model, '_is_dynamic_model', False) else result.value)
311
- for f in f_semantic_conditions
312
- )
313
- except Exception as e:
314
- # If we are in the last attempt and semantic validation fails, result will be None and we propagate the error
315
- if i == self.retry_params["tries"]:
316
- result = None
317
- errors.append(f"Semantic validation failed with:\n{str(e)}")
318
- break # We break to avoid going into the remedy loop
319
- raise e
386
+ result = self.output_data_model.model_validate_json(
387
+ json_str,
388
+ strict=False,
389
+ context=validation_context,
390
+ )
391
+ semantic_error = self._check_semantic_conditions(result, f_semantic_conditions)
392
+ if semantic_error is not None:
393
+ if attempt == total_attempts - 1:
394
+ result = None
395
+ errors.append(semantic_error)
396
+ break
397
+ raise AssertionError(semantic_error)
320
398
  break
321
- except Exception as e:
322
- logger.info(f"Validation attempt {i+1} failed, pausing before retry…")
399
+ except Exception as error:
400
+ json_str = self._handle_failed_validation_attempt(
401
+ attempt,
402
+ prompt,
403
+ json_str,
404
+ errors,
405
+ error,
406
+ remedy_seeds,
407
+ kwargs,
408
+ )
409
+ return result, json_str, errors
410
+
411
+ def _handle_validation_failure(self, prompt: str, json_str: str, errors: list[str]):
412
+ logger.error("All validation attempts failed!")
413
+ if self.retry_params["graceful"]:
414
+ return
415
+ raise TypeValidationError(
416
+ prompt=prompt,
417
+ result=json_str,
418
+ violations=errors,
419
+ )
323
420
 
324
- self._pause(i)
421
+ def forward(
422
+ self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs
423
+ ):
424
+ self._ensure_output_model()
425
+ validation_context = kwargs.pop("validation_context", {})
426
+ kwargs["response_format"] = {"type": "json_object"}
427
+ logger.info("Initializing validation…")
428
+ self._display_verbose_panels(prompt)
325
429
 
326
- if isinstance(e, ValidationError):
327
- error_str = self.simplify_validation_errors(e)
328
- else:
329
- error_str = str(e)
330
-
331
- errors.append(error_str)
332
-
333
- logger.error(f"Validation errors identified!")
334
- if self.verbose:
335
- self.display_panel(
336
- "\n".join(errors) if self.accumulate_errors else error_str,
337
- title=f"Validation Errors ({'accumulated errors' if self.accumulate_errors else 'last error'})",
338
- border_style="red"
339
- )
340
-
341
- # Update remedy function context
342
- logger.info("Updating remedy function context…")
343
- context = self.remedy_prompt(prompt=prompt, output=json_str, errors="\n".join(errors) if self.accumulate_errors else error_str)
344
- self.remedy_function.clear()
345
- self.remedy_function.adapt(context)
346
- if self.verbose:
347
- self.display_panel(
348
- self.remedy_function.dynamic_context,
349
- title="New Context"
350
- )
351
-
352
- # Apply the remedy function
353
- json_str = self.remedy_function(seed=remedy_seeds[i], **kwargs).value
354
- logger.info("Applied remedy function with updated context!")
430
+ context = self.zero_shot_prompt(prompt=prompt)
431
+ json_str = super().forward(context, *args, **kwargs).value
432
+
433
+ remedy_seeds = self.prepare_seeds(self.retry_params["tries"] + 1, **kwargs)
434
+ logger.info(f"Prepared {len(remedy_seeds)} remedy seeds for validation attempts…")
435
+
436
+ result, json_str, errors = self._run_validation_attempts(
437
+ prompt,
438
+ f_semantic_conditions,
439
+ validation_context,
440
+ remedy_seeds,
441
+ json_str,
442
+ kwargs,
443
+ )
355
444
 
356
445
  if result is None:
357
- logger.error(f"All validation attempts failed!")
358
- if self.retry_params['graceful']:
359
- return
360
- raise TypeValidationError(
361
- prompt=prompt,
362
- result=json_str,
363
- violations=errors,
364
- )
446
+ return self._handle_validation_failure(prompt, json_str, errors)
365
447
 
366
448
  logger.success("Validation completed successfully!")
367
- # Clear artifacts from the remedy function
368
449
  self.remedy_function.clear()
369
-
370
450
  return result
371
451
 
372
452
 
373
453
  @beartype
374
454
  class contract:
375
- _default_remedy_retry_params = dict(
376
- tries=8,
377
- delay=0.015,
378
- backoff=1.25,
379
- jitter=0.0,
380
- max_delay=0.25,
381
- graceful=False
382
- )
455
+ _default_remedy_retry_params: ClassVar[dict[str, int | float | bool]] = {
456
+ "tries": 8,
457
+ "delay": 0.015,
458
+ "backoff": 1.25,
459
+ "jitter": 0.0,
460
+ "max_delay": 0.25,
461
+ "graceful": False,
462
+ }
463
+ _internal_forward_kwargs: ClassVar[set[str]] = {"validation_context"}
383
464
 
384
465
  def __init__(
385
466
  self,
@@ -389,37 +470,47 @@ class contract:
389
470
  verbose: bool = False,
390
471
  remedy_retry_params: dict[str, int | float | bool] = _default_remedy_retry_params,
391
472
  ):
392
- '''
473
+ """
393
474
  A contract class decorator inspired by DbC principles. It ensures that the function's input and output
394
475
  adhere to specified data models both syntactically and semantically. This implementation includes retry
395
476
  logic to handle transient errors and gracefully handle failures.
396
- '''
477
+ """
397
478
  self.pre_remedy = pre_remedy
398
479
  self.post_remedy = post_remedy
399
480
  self.remedy_retry_params = remedy_retry_params
400
- self.f_type_validation_remedy = TypeValidationFunction(accumulate_errors=accumulate_errors, verbose=verbose, retry_params=remedy_retry_params)
481
+ self.f_type_validation_remedy = TypeValidationFunction(
482
+ accumulate_errors=accumulate_errors, verbose=verbose, retry_params=remedy_retry_params
483
+ )
401
484
 
402
485
  if not verbose:
403
486
  logger.disable(__name__)
404
487
  else:
405
488
  logger.enable(__name__)
406
489
 
407
- def _is_valid_input(self, input):
408
- if input is None:
490
+ def _is_valid_input(self, input_value):
491
+ if input_value is None:
409
492
  logger.error("No `input` argument provided!")
410
- raise ValueError("Please provide an `input` argument.")
411
- if not isinstance(input, LLMDataModel):
412
- logger.error(f"Invalid input type: {type(input)}")
413
- raise TypeError(f"Expected input to be of type `LLMDataModel`, got {type(input)}")
493
+ msg = "Please provide an `input` argument."
494
+ UserMessage(msg)
495
+ raise ValueError(msg)
496
+ if not isinstance(input_value, LLMDataModel):
497
+ logger.error(f"Invalid input type: {type(input_value)}")
498
+ msg = f"Expected input to be of type `LLMDataModel`, got {type(input_value)}"
499
+ UserMessage(msg)
500
+ raise TypeError(msg)
414
501
  return True
415
502
 
416
503
  def _is_valid_output(self, output_type):
417
504
  if output_type == inspect._empty:
418
505
  logger.error("Missing return type annotation!")
419
- raise ValueError("The contract requires a return type annotation.")
506
+ msg = "The contract requires a return type annotation."
507
+ UserMessage(msg)
508
+ raise ValueError(msg)
420
509
  if not issubclass(output_type, LLMDataModel):
421
510
  logger.error(f"Invalid return type: {output_type}")
422
- raise TypeError("The return type annotation must be a subclass of `LLMDataModel`.")
511
+ msg = "The return type annotation must be a subclass of `LLMDataModel`."
512
+ UserMessage(msg)
513
+ raise TypeError(msg)
423
514
  return True
424
515
 
425
516
  def _try_dynamic_type_annotation(self, original_forward, *, context: str = "input"):
@@ -428,84 +519,122 @@ class contract:
428
519
  )
429
520
  sig = inspect.signature(original_forward)
430
521
  try:
522
+ resolved_param = None
431
523
  # Fallback: look at the relevant part of the function signature
432
524
  # depending on whether we deal with an *input* or *output*
433
525
  if context == "input":
434
526
  param = sig.parameters.get("input")
527
+ if param is None:
528
+ for candidate in sig.parameters.values():
529
+ if candidate.name == "self":
530
+ continue
531
+ if candidate.kind in (
532
+ inspect.Parameter.POSITIONAL_ONLY,
533
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
534
+ inspect.Parameter.KEYWORD_ONLY,
535
+ ):
536
+ param = candidate
537
+ break
538
+ resolved_param = param
435
539
  if param is None or param.annotation == inspect._empty:
436
- raise TypeError("Failed to infer type from input parameter annotation")
540
+ msg = "Failed to infer type from input parameter annotation"
541
+ UserMessage(msg)
542
+ raise TypeError(msg)
437
543
  dynamic_model = build_dynamic_llm_datamodel(param.annotation)
438
544
  else: # context == "output"
439
545
  if sig.return_annotation == inspect._empty:
440
- raise TypeError("Failed to infer type from return annotation")
546
+ msg = "Failed to infer type from return annotation"
547
+ UserMessage(msg)
548
+ raise TypeError(msg)
441
549
  dynamic_model = build_dynamic_llm_datamodel(sig.return_annotation)
442
- except Exception:
443
- logger.exception(f"Failed to build dynamic LLMDataModel from {sig.parameters.get('input')}!")
444
- raise TypeError(
550
+ except Exception as err:
551
+ logger.exception(f"Failed to build dynamic LLMDataModel from {resolved_param}!")
552
+ msg = (
445
553
  "The type annotation must be a subclass of `LLMDataModel` or a "
446
554
  "valid Python typing object supported by Pydantic."
447
555
  )
556
+ UserMessage(msg)
557
+ raise TypeError(msg) from err
448
558
 
449
559
  dynamic_model._is_dynamic_model = True
450
560
  return dynamic_model
451
561
 
452
562
  def _try_remedy_with_exception(self, prompt, f_semantic_conditions, **remedy_kwargs):
453
563
  try:
454
- data_model = self.f_type_validation_remedy(prompt, f_semantic_conditions=f_semantic_conditions, **remedy_kwargs)
564
+ data_model = self.f_type_validation_remedy(
565
+ prompt, f_semantic_conditions=f_semantic_conditions, **remedy_kwargs
566
+ )
455
567
  except Exception as e:
456
- logger.error(f"Type validation failed with exception!")
568
+ logger.error("Type validation failed with exception!")
457
569
  raise e
458
570
  return data_model
459
571
 
460
- def _validate_input(self, wrapped_self, input, it, **remedy_kwargs):
572
+ def _validate_input(self, wrapped_self, input_value, it, **remedy_kwargs):
461
573
  logger.info("Starting input validation...")
462
574
  if self.pre_remedy:
463
575
  logger.info("Validating pre-conditions with remedy...")
464
- if not hasattr(wrapped_self, 'pre'):
576
+ if not hasattr(wrapped_self, "pre"):
465
577
  logger.error("Pre-condition function not defined!")
466
- raise Exception("Pre-condition function not defined. Please define a `pre` method if you want to enforce pre-conditions through a remedy.")
578
+ msg = "Pre-condition function not defined. Please define a `pre` method if you want to enforce pre-conditions through a remedy."
579
+ UserMessage(msg)
580
+ raise Exception(msg)
467
581
 
468
582
  op_start = time.perf_counter()
469
583
  try:
470
- assert wrapped_self.pre(input)
584
+ assert wrapped_self.pre(input_value)
471
585
  logger.success("Pre-condition validation successful!")
472
- return input
473
- except Exception as e:
586
+ return input_value
587
+ except Exception:
474
588
  logger.exception("Pre-condition validation failed!")
475
- self.f_type_validation_remedy.register_expected_data_model(input, attach_to="output", override=True)
476
- input = self._try_remedy_with_exception(prompt=wrapped_self.prompt, f_semantic_conditions=[wrapped_self.pre], **remedy_kwargs)
589
+ self.f_type_validation_remedy.register_expected_data_model(
590
+ input_value, attach_to="output", override=True
591
+ )
592
+ input_value = self._try_remedy_with_exception(
593
+ prompt=wrapped_self.prompt,
594
+ f_semantic_conditions=[wrapped_self.pre],
595
+ **remedy_kwargs,
596
+ )
477
597
  finally:
478
- wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
479
- return input
480
- else:
481
- if hasattr(wrapped_self, 'pre'):
482
- logger.info("Validating pre-conditions without remedy...")
483
- op_start = time.perf_counter()
484
- try:
485
- assert wrapped_self.pre(input)
486
- except Exception as e:
487
- logger.exception(f"Pre-condition validation failed")
488
- raise e
489
- finally:
490
- wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
491
- logger.success("Pre-condition validation successful!")
492
- return input
493
- logger.info("Skip; no pre-condition validation was required!")
494
- return input
598
+ wrapped_self._contract_timing[it]["input_validation"] = (
599
+ time.perf_counter() - op_start
600
+ )
601
+ return input_value
602
+ if hasattr(wrapped_self, "pre"):
603
+ logger.info("Validating pre-conditions without remedy...")
604
+ op_start = time.perf_counter()
605
+ try:
606
+ assert wrapped_self.pre(input_value)
607
+ except Exception as e:
608
+ logger.exception("Pre-condition validation failed")
609
+ raise e
610
+ finally:
611
+ wrapped_self._contract_timing[it]["input_validation"] = (
612
+ time.perf_counter() - op_start
613
+ )
614
+ logger.success("Pre-condition validation successful!")
615
+ return input_value
616
+ logger.info("Skip; no pre-condition validation was required!")
617
+ return input_value
495
618
 
496
- def _validate_output(self, wrapped_self, input, output, it, **remedy_kwargs):
619
+ def _validate_output(self, wrapped_self, input_value, output, it, **remedy_kwargs):
497
620
  logger.info("Starting output validation...")
498
- self.f_type_validation_remedy.register_expected_data_model(input, attach_to="input", override=True)
499
- self.f_type_validation_remedy.register_expected_data_model(output, attach_to="output", override=True)
621
+ self.f_type_validation_remedy.register_expected_data_model(
622
+ input_value, attach_to="input", override=True
623
+ )
624
+ self.f_type_validation_remedy.register_expected_data_model(
625
+ output, attach_to="output", override=True
626
+ )
500
627
 
501
628
  op_start = time.perf_counter()
502
629
  try:
503
630
  logger.info("Getting a valid output type...")
504
- output = self._try_remedy_with_exception(prompt=wrapped_self.prompt, f_semantic_conditions=None, **remedy_kwargs)
505
- if output is None: # output is None when graceful mode is enabled
631
+ output = self._try_remedy_with_exception(
632
+ prompt=wrapped_self.prompt, f_semantic_conditions=None, **remedy_kwargs
633
+ )
634
+ if output is None: # output is None when graceful mode is enabled
506
635
  return output
507
636
  except Exception as e:
508
- logger.exception(f"Type creation failed!")
637
+ logger.exception("Type creation failed!")
509
638
  raise e
510
639
  finally:
511
640
  wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
@@ -515,88 +644,119 @@ class contract:
515
644
  logger.info("Validating post-conditions with remedy...")
516
645
  if not hasattr(wrapped_self, "post"):
517
646
  logger.error("Post-condition function not defined!")
518
- raise Exception("Post-condition function not defined. Please define a `post` method if you want to enforce post-conditions through a remedy.")
647
+ msg = "Post-condition function not defined. Please define a `post` method if you want to enforce post-conditions through a remedy."
648
+ UserMessage(msg)
649
+ raise Exception(msg)
519
650
 
520
651
  op_start = time.perf_counter()
521
652
  try:
522
653
  assert wrapped_self.post(output)
523
654
  logger.success("Post-condition validation successful!")
524
655
  return output
656
+ except Exception:
657
+ logger.exception("Post-condition validation failed!")
658
+ output = self._try_remedy_with_exception(
659
+ prompt=wrapped_self.prompt,
660
+ f_semantic_conditions=[wrapped_self.post],
661
+ **remedy_kwargs,
662
+ )
663
+ finally:
664
+ wrapped_self._contract_timing[it]["output_validation"] += (
665
+ time.perf_counter() - op_start
666
+ )
667
+ logger.success("Post-condition validation successful!")
668
+ return output
669
+ if hasattr(wrapped_self, "post"):
670
+ logger.info("Validating post-conditions without remedy...")
671
+ op_start = time.perf_counter()
672
+ try:
673
+ assert wrapped_self.post(output)
525
674
  except Exception as e:
526
675
  logger.exception("Post-condition validation failed!")
527
- output = self._try_remedy_with_exception(prompt=wrapped_self.prompt, f_semantic_conditions=[wrapped_self.post], **remedy_kwargs)
676
+ raise e
528
677
  finally:
529
- wrapped_self._contract_timing[it]["output_validation"] += (time.perf_counter() - op_start)
678
+ wrapped_self._contract_timing[it]["output_validation"] = (
679
+ time.perf_counter() - op_start
680
+ )
530
681
  logger.success("Post-condition validation successful!")
531
682
  return output
532
- else:
533
- if hasattr(wrapped_self, "post"):
534
- logger.info("Validating post-conditions without remedy...")
535
- op_start = time.perf_counter()
536
- try:
537
- assert wrapped_self.post(output)
538
- except Exception as e:
539
- logger.exception("Post-condition validation failed!")
540
- raise e
541
- finally:
542
- wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
543
- logger.success("Post-condition validation successful!")
544
- return output
545
683
  logger.info("Skip; no post-condition validation was required!")
546
684
  return output
547
685
 
def _validate_act_method(self, act_method):
    """Check that an ``act`` method has the signature the contract requires.

    The first positional parameter other than ``self`` receives the contract input and
    must carry a type annotation; the method must also declare a return annotation.

    Args:
        act_method: The ``act`` callable (bound or unbound) to inspect.

    Returns:
        True when the signature is acceptable.

    Raises:
        TypeError: If there is no positional input parameter, that parameter is not
            annotated, or the return annotation is missing.
    """
    act_sig = inspect.signature(act_method)

    # Skip an explicit ``self`` (present for unbound functions) and any *args /
    # keyword-only / **kwargs parameters: the input is always passed positionally.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    first_param = next(
        (
            param
            for param in act_sig.parameters.values()
            if param.name != "self" and param.kind in positional_kinds
        ),
        None,
    )

    if first_param is None:
        msg = "'act' method must accept at least one positional parameter after `self`."
        UserMessage(msg)
        raise TypeError(msg)
    # ``Parameter.empty`` / ``Signature.empty`` are the public aliases of the sentinel;
    # compare by identity so annotation objects with custom ``__eq__`` cannot interfere.
    if first_param.annotation is inspect.Parameter.empty:
        msg = f"'act' method parameter '{first_param.name}' must have a type annotation."
        UserMessage(msg)
        raise TypeError(msg)
    if act_sig.return_annotation is inspect.Signature.empty:
        msg = "'act' method must have a return type annotation."
        UserMessage(msg)
        raise TypeError(msg)
    return True
558
714
 
def _act(self, wrapped_self, input_value, it, **act_kwargs):
    """Run the decorated object's optional ``act`` hook on the current input.

    Args:
        wrapped_self: The contract-decorated instance. Its ``act`` attribute is invoked
            when callable, and ``wrapped_self._contract_timing[it]["act_execution"]``
            receives the elapsed time.
        input_value: The current input. Values whose ``_is_dynamic_model`` attribute is
            truthy are unwrapped to their ``.value`` before being passed to ``act``.
        it: Index of the current contract call, used as the timing key.
        **act_kwargs: Extra keyword arguments forwarded to ``act``.

    Returns:
        ``input_value`` unchanged when no callable ``act`` exists, otherwise the value
        returned by ``act``.

    Raises:
        TypeError: From ``self._validate_act_method`` for an invalid signature, or when
            ``act``'s result is not an instance of its class-typed return annotation.
        Exception: Anything raised by ``act`` itself is logged and re-raised.
    """
    act_method = getattr(wrapped_self, "act", None)
    if not callable(act_method):
        # Propagate the input if no act method is defined
        return input_value

    # Call the validator directly rather than inside ``assert``: under ``python -O``
    # assert statements are removed, which would silently skip the signature check.
    self._validate_act_method(act_method)

    is_dynamic_model = getattr(input_value, "_is_dynamic_model", False)
    input_value = input_value if not is_dynamic_model else input_value.value

    logger.info(f"Executing 'act' method on {wrapped_self.__class__.__name__}…")

    op_start = time.perf_counter()
    try:
        act_output = act_method(input_value, **act_kwargs)
    except Exception as e:
        logger.exception("'act' method execution failed")
        raise e
    finally:
        wrapped_self._contract_timing[it]["act_execution"] = time.perf_counter() - op_start

    # Only concrete classes can be checked with isinstance(); string/typing annotations
    # are left unchecked.
    return_annotation = inspect.signature(act_method).return_annotation
    if (
        return_annotation is not inspect.Signature.empty
        and inspect.isclass(return_annotation)
        and not isinstance(act_output, return_annotation)
    ):
        msg = f"'act' method returned {type(act_output).__name__}, expected {return_annotation.__name__}."
        UserMessage(msg)
        raise TypeError(msg)

    logger.success("'act' method executed successfully!")
    return act_output
588
749
 
589
- def __call__(self, cls):
590
- original_init = cls.__init__
591
- original_forward = cls.forward
592
-
750
+ def _build_wrapped_init(self, original_init):
593
751
  def __init__(wrapped_self, *args, **kwargs):
594
752
  logger.info("Initializing contract...")
595
753
  original_init(wrapped_self, *args, **kwargs)
596
754
 
597
755
  if not hasattr(wrapped_self, "prompt"):
598
756
  logger.error("Prompt attribute not defined!")
599
- raise Exception("Please define a static `prompt` attribute that describes what the contract must do.")
757
+ msg = "Please define a static `prompt` attribute that describes what the contract must do."
758
+ UserMessage(msg)
759
+ raise Exception(msg)
600
760
 
601
761
  wrapped_self.contract_successful = False
602
762
  wrapped_self.contract_result = None
@@ -604,105 +764,217 @@ class contract:
604
764
  wrapped_self._contract_timing = defaultdict(dict)
605
765
  logger.info("Contract initialization complete!")
606
766
 
607
- def wrapped_forward(wrapped_self, **kwargs):
608
- logger.info("Starting contract execution...")
609
- it = len(wrapped_self._contract_timing) # the len is the __call__ op_start
610
- contract_start = time.perf_counter()
767
+ return __init__
768
+
def _start_contract_execution(self, wrapped_self):
    """Log the start of a contract run and allocate its timing slot.

    Returns:
        A pair ``(it, start)``. ``it`` is the number of entries already present in
        ``wrapped_self._contract_timing``, which serves as this call's index. ``start``
        is a ``time.perf_counter()`` reading.
    """
    logger.info("Starting contract execution...")
    # Timings are stored under ``_contract_timing[it]``, so the current size is the
    # index for the call that is starting now.
    call_index = len(wrapped_self._contract_timing)
    started_at = time.perf_counter()
    return call_index, started_at
773
+
774
+ def _find_input_param_name(self, sig: inspect.Signature) -> str | None:
775
+ for param in sig.parameters.values():
776
+ if param.name == "self":
777
+ continue
778
+ if param.kind in (
779
+ inspect.Parameter.POSITIONAL_ONLY,
780
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
781
+ inspect.Parameter.KEYWORD_ONLY,
782
+ ):
783
+ return param.name
784
+ return None
611
785
 
612
- input = kwargs.pop("input", None)
613
- input_type = None
614
- try:
615
- assert self._is_valid_input(input)
616
- except TypeError:
617
- input_type = self._try_dynamic_type_annotation(original_forward, context="input")
618
- input = input_type(value=input)
619
-
620
- maybe_payload = getattr(wrapped_self, "payload", None)
621
- maybe_template = getattr(wrapped_self, "template")
622
- if inspect.ismethod(maybe_template):
623
- # `template` is a primitive in symbolicai case in which we actually don't have a template
624
- maybe_template = None
625
-
626
- # Create validation kwargs that include all original kwargs plus payload and template
627
- validation_kwargs = {
628
- **kwargs,
629
- "payload": maybe_payload,
630
- "template_suffix": maybe_template
631
- }
def _prepare_forward_args(self, args, kwargs):
    """Copy the call arguments so later steps can mutate them without side effects.

    Returns:
        A triple ``(args_list, kwargs_without_input, original_kwargs)``. The first item
        is a mutable list of the positional arguments. The other two are independent
        shallow copies of ``kwargs``: the input keyword may later be popped from the
        second one, while the third stays intact.
    """
    positional = [*args]
    return positional, {**kwargs}, {**kwargs}
632
791
 
633
- sig = inspect.signature(original_forward)
634
- output_type = sig.return_annotation
635
- try:
636
- assert self._is_valid_output(output_type)
637
- except TypeError:
638
- output_type = self._try_dynamic_type_annotation(original_forward, context="output")
792
+ def _extract_input_value(
793
+ self,
794
+ args_list,
795
+ kwargs_without_input,
796
+ original_kwargs,
797
+ input_param_name: str | None,
798
+ ):
799
+ if args_list:
800
+ return args_list[0], ("args", 0)
801
+ if input_param_name and input_param_name in kwargs_without_input:
802
+ return kwargs_without_input.pop(input_param_name), ("kwargs", input_param_name)
803
+ if "input" in kwargs_without_input:
804
+ return kwargs_without_input.pop("input"), ("kwargs", "input")
805
+ return original_kwargs.get("input"), ("fallback_kw", "input")
806
+
def _coerce_input_value(self, original_forward, input_value):
    """Return ``input_value`` in a form the contract can validate.

    ``self._is_valid_input`` is expected to raise ``TypeError`` for unsupported inputs.
    In that case the value is wrapped in the dynamic model that
    ``self._try_dynamic_type_annotation`` derives from ``original_forward``'s input
    annotation.

    Raises:
        AssertionError: If ``_is_valid_input`` returns a falsy value. This preserves the
            behaviour of the previous ``assert``-based check.
    """
    # Call the validator explicitly instead of inside ``assert``: assertions are
    # stripped under ``python -O``, which would skip both validation and wrapping.
    try:
        is_valid = self._is_valid_input(input_value)
    except TypeError:
        input_type_model = self._try_dynamic_type_annotation(original_forward, context="input")
        return input_type_model(value=input_value)
    if not is_valid:
        raise AssertionError
    return input_value
814
+
def _collect_validation_kwargs(self, wrapped_self, kwargs_without_input):
    """Build the keyword arguments passed to the validation, ``act`` and remedy steps.

    Starts from the caller's keyword arguments (with the input already removed). It then
    sets ``payload`` and ``template_suffix`` from ``wrapped_self``'s ``payload`` and
    ``template`` attributes; missing attributes default to ``None``. These two keys
    override any caller keywords of the same name. A ``template`` that is a bound method
    is treated as ``None``.
    """
    merged = dict(kwargs_without_input)
    payload = getattr(wrapped_self, "payload", None)
    template = getattr(wrapped_self, "template", None)
    # A bound ``template`` method is a symbolicai primitive, not an actual template.
    merged["payload"] = payload
    merged["template_suffix"] = None if inspect.ismethod(template) else template
    return merged
825
+
def _resolve_output_type(self, sig: inspect.Signature, original_forward):
    """Return the type that ``forward``'s result is checked against.

    Uses ``sig.return_annotation`` when ``self._is_valid_output`` accepts it. If the
    validator raises ``TypeError``, falls back to the dynamic model produced by
    ``self._try_dynamic_type_annotation(original_forward, context="output")``.

    Raises:
        AssertionError: If ``_is_valid_output`` returns a falsy value. This preserves the
            behaviour of the previous ``assert``-based check.
    """
    output_type = sig.return_annotation
    # Explicit call instead of ``assert``: under ``python -O`` the assert, and with it
    # the validator call and the dynamic-model fallback, would be removed.
    try:
        is_valid = self._is_valid_output(output_type)
    except TypeError:
        return self._try_dynamic_type_annotation(original_forward, context="output")
    if not is_valid:
        raise AssertionError
    return output_type
639
833
 
640
- output = None
641
- current_input = input
642
- try:
643
- # 1. Start with original input and apply pre-validation
644
- maybe_new_input = self._validate_input(wrapped_self, current_input, it, **validation_kwargs)
645
- if maybe_new_input is not None:
646
- current_input = maybe_new_input
def _run_contract_pipeline(
    self,
    wrapped_self,
    current_input_value,
    output_type,
    it,
    validation_kwargs,
):
    """Run pre-validation, the optional ``act`` hook and output validation.

    On success it sets three attributes on ``wrapped_self``:
      * ``contract_successful``: True iff the validated output is not ``None``;
      * ``contract_result``: the validated output;
      * ``contract_exception``: reset to ``None``.

    Any exception is logged and stored in ``contract_exception``, and
    ``contract_successful`` is set to False. The exception is not re-raised, and
    ``contract_result`` keeps its previous value.

    Returns:
        ``(output, current_input_value)``. ``output`` is the validated output, or
        ``None`` if a step raised. ``current_input_value`` is the input as transformed
        by the steps that completed.
    """
    validated_output = None
    try:
        # A ``None`` from pre-validation means "keep the current input".
        if (
            revised := self._validate_input(
                wrapped_self, current_input_value, it, **validation_kwargs
            )
        ) is not None:
            current_input_value = revised

        current_input_value = self._act(
            wrapped_self, current_input_value, it, **validation_kwargs
        )

        validated_output = self._validate_output(
            wrapped_self,
            current_input_value,
            output_type,
            it,
            **validation_kwargs,
        )
        wrapped_self.contract_successful = validated_output is not None
        wrapped_self.contract_result = validated_output
        wrapped_self.contract_exception = None
    except Exception as failure:
        logger.exception("Contract execution failed in main path!")
        wrapped_self.contract_successful = False
        wrapped_self.contract_exception = failure
    return validated_output, current_input_value
656
869
 
657
- except Exception as e:
658
- logger.exception(f"Contract execution failed in main path!")
659
- wrapped_self.contract_successful = False
660
- wrapped_self.contract_exception = e
661
- # contract_result remains None or its value before the exception.
662
- # final_output remains None or its value before the exception.
663
- # The finally block's execution of original_forward will determine the actual returned value.
664
- finally:
665
- # Execute the original forward method with appropriate input
666
- logger.info("Executing original forward method...")
667
-
668
- # If contract was successful, use the processed input (after pre-validation and act, both optional)
669
- # `current_input` at this stage is the result of the try block's processing up to the point of exception,
670
- # or the full processing if successful.
671
- # If contract failed, use original_input (fallback).
672
- forward_input = current_input if wrapped_self.contract_successful else input
673
-
674
- # Prepare kwargs for original_forward
675
- forward_kwargs = kwargs.copy()
676
- forward_kwargs['input'] = forward_input
677
-
678
- try:
679
- op_start = time.perf_counter()
680
- output = original_forward(wrapped_self, **forward_kwargs)
681
- finally:
682
- wrapped_self._contract_timing[it]["forward_execution"] = time.perf_counter() - op_start
683
- wrapped_self._contract_timing[it]["contract_execution"] = time.perf_counter() - contract_start
684
-
685
- if not isinstance(output, output_type):
686
- logger.error(f"Output type mismatch: {type(output)}")
687
- if self.remedy_retry_params["graceful"]:
688
- # In graceful mode, skip type mismatch error and return raw output
689
- if hasattr(output_type, '_is_dynamic_model') and output_type._is_dynamic_model and hasattr(output, 'value'):
690
- return output.value
691
- return output
692
- raise TypeError(
693
- f"Expected output to be an instance of {output_type}, "
694
- f"but got {type(output)}! Forward method must return an instance of {output_type}!"
695
- )
696
- if not wrapped_self.contract_successful:
697
- logger.warning("Contract validation failed!")
698
- else:
699
- logger.success("Contract validation successful!")
def _execute_forward_call(
    self,
    wrapped_self,
    original_forward,
    args_list,
    original_kwargs,
    input_param_name,
    input_source,
    forward_input_value,
    it,
    contract_start,
):
    """Invoke the decorated class's original ``forward`` with the contract's input.

    The keyword arguments are copied from ``original_kwargs``, minus any names listed in
    ``self._internal_forward_kwargs``. ``forward_input_value`` is then placed where the
    caller supplied the input:
      * the named input keyword, if present or if the input came from it;
      * the first positional slot, if it came positionally;
      * otherwise the keyword ``input_param_name`` (or ``input`` when ``forward``
        declares no named input parameter).

    ``args_list`` may be modified in place. The elapsed ``forward_execution`` and total
    ``contract_execution`` times are written to ``wrapped_self._contract_timing[it]``,
    even if ``forward`` raises.

    Returns:
        Whatever ``original_forward`` returns.
    """
    forward_kwargs = original_kwargs.copy()
    # Contract-internal keywords must not leak into the user's ``forward``.
    for internal_kw in self._internal_forward_kwargs:
        forward_kwargs.pop(internal_kw, None)

    logger.info("Executing original forward method...")

    came_from_args = input_source == ("args", 0) and bool(args_list)
    if input_param_name:
        if input_param_name in forward_kwargs or input_source == ("kwargs", input_param_name):
            forward_kwargs[input_param_name] = forward_input_value
        elif came_from_args:
            args_list[0] = forward_input_value
        else:
            forward_kwargs[input_param_name] = forward_input_value
    elif came_from_args:
        # ``forward`` declares no named input parameter (e.g. ``*args, **kwargs``). Put
        # the processed value back into the positional slot. Otherwise ``forward`` would
        # receive the raw value positionally plus a duplicate ``input=`` keyword.
        args_list[0] = forward_input_value
    else:
        forward_kwargs["input"] = forward_input_value

    # A caller-supplied ``input=`` was already consumed as the named parameter.
    if input_param_name and input_param_name != "input" and "input" in forward_kwargs:
        forward_kwargs.pop("input")

    op_start = time.perf_counter()
    try:
        output = original_forward(wrapped_self, *args_list, **forward_kwargs)
    finally:
        timing = wrapped_self._contract_timing[it]
        timing["forward_execution"] = time.perf_counter() - op_start
        timing["contract_execution"] = time.perf_counter() - contract_start
    return output
910
+
def _finalize_contract_output(self, output, output_type, wrapped_self):
    """Type-check ``forward``'s result and unwrap dynamic models.

    If ``output`` is an instance of ``output_type``:
      * logs success or failure according to ``wrapped_self.contract_successful``;
      * returns ``output.value`` when ``output_type._is_dynamic_model`` is truthy,
        else ``output``.

    On a type mismatch the error is logged, and:
      * if ``self.remedy_retry_params["graceful"]`` is set, the raw result is returned
        (its ``.value`` for dynamic models that have one);
      * otherwise ``TypeError`` is raised.
    """
    if isinstance(output, output_type):
        if wrapped_self.contract_successful:
            logger.success("Contract validation successful!")
        else:
            logger.warning("Contract validation failed!")
        return output.value if getattr(output_type, "_is_dynamic_model", False) else output

    logger.error(f"Output type mismatch: {type(output)}")
    if not self.remedy_retry_params["graceful"]:
        msg = (
            f"Expected output to be an instance of {output_type}, "
            f"but got {type(output)}! Forward method must return an instance of {output_type}!"
        )
        UserMessage(msg)
        raise TypeError(msg)

    # Graceful mode: tolerate the mismatch and hand back the raw result.
    unwrap = getattr(output_type, "_is_dynamic_model", False) and hasattr(output, "value")
    return output.value if unwrap else output
705
932
 
def _contract_forward_impl(self, wrapped_self, original_forward, *args, **kwargs):
    """Execute one contract-guarded ``forward`` call.

    Steps:
      1. Locate and coerce the input.
      2. Collect the validation keyword arguments and resolve the expected output type.
      3. Run the validation pipeline.
      4. Call the original ``forward``. It receives the processed input when the
         contract succeeded, and the coerced original input otherwise.
      5. Finalize the result.
    """
    it, contract_start = self._start_contract_execution(wrapped_self)
    signature = inspect.signature(original_forward)
    param_name = self._find_input_param_name(signature)
    args_list, remaining_kwargs, original_kwargs = self._prepare_forward_args(args, kwargs)
    raw_input, input_source = self._extract_input_value(
        args_list, remaining_kwargs, original_kwargs, param_name
    )
    # The coerced value also serves as the fallback input if the contract fails.
    fallback_input = self._coerce_input_value(original_forward, raw_input)
    validation_kwargs = self._collect_validation_kwargs(wrapped_self, remaining_kwargs)
    output_type = self._resolve_output_type(signature, original_forward)

    _, processed_input = self._run_contract_pipeline(
        wrapped_self,
        fallback_input,
        output_type,
        it,
        validation_kwargs,
    )

    if wrapped_self.contract_successful:
        chosen_input = processed_input
    else:
        chosen_input = fallback_input

    raw_output = self._execute_forward_call(
        wrapped_self,
        original_forward,
        args_list,
        original_kwargs,
        param_name,
        input_source,
        chosen_input,
        it,
        contract_start,
    )
    return self._finalize_contract_output(raw_output, output_type, wrapped_self)
970
+
def _build_wrapped_forward(self, original_forward):
    """Return a replacement ``forward`` that routes every call through the contract.

    ``functools.wraps`` copies ``original_forward``'s ``__name__``, ``__qualname__``,
    ``__doc__``, ``__module__`` and ``__annotations__`` onto the wrapper and sets
    ``__wrapped__``. As a result, ``inspect.signature`` on the decorated ``forward``
    reports the user's signature rather than ``(wrapped_self, *args, **kwargs)``. This
    matters when an already-decorated ``forward`` is wrapped again, because the
    contract inspects that signature to find the input parameter.
    """
    from functools import wraps

    @wraps(original_forward)
    def wrapped_forward(wrapped_self, *args, **kwargs):
        return self._contract_forward_impl(wrapped_self, original_forward, *args, **kwargs)

    return wrapped_forward
976
+
977
+ def _build_contract_perf_stats(self):
706
978
  def contract_perf_stats(wrapped_self):
707
979
  """Analyzes and prints timing statistics across all forward calls."""
708
980
  console = Console()
@@ -717,7 +989,7 @@ class contract:
717
989
  "act_execution",
718
990
  "output_validation",
719
991
  "forward_execution",
720
- "contract_execution"
992
+ "contract_execution",
721
993
  ]
722
994
 
723
995
  stats = {}
@@ -739,40 +1011,41 @@ class contract:
739
1011
  max_time = max(non_zero_times) if non_zero_times else 0
740
1012
 
741
1013
  stats[op] = {
742
- 'count': actual_count,
743
- 'total': total_time,
744
- 'mean': mean_time,
745
- 'std': std_time,
746
- 'min': min_time,
747
- 'max': max_time
1014
+ "count": actual_count,
1015
+ "total": total_time,
1016
+ "mean": mean_time,
1017
+ "std": std_time,
1018
+ "min": min_time,
1019
+ "max": max_time,
748
1020
  }
749
1021
 
750
- total_execution_time = stats['contract_execution']['total']
1022
+ total_execution_time = stats["contract_execution"]["total"]
751
1023
  for op in ordered_operations[:-1]:
752
1024
  if total_execution_time > 0:
753
- stats[op]['percentage'] = (stats[op]['total'] / total_execution_time) * 100
1025
+ stats[op]["percentage"] = (stats[op]["total"] / total_execution_time) * 100
754
1026
  else:
755
- stats[op]['percentage'] = 0
1027
+ stats[op]["percentage"] = 0
756
1028
 
757
- sum_tracked_times = sum(stats[op]['total'] for op in ordered_operations[:-1])
1029
+ sum_tracked_times = sum(stats[op]["total"] for op in ordered_operations[:-1])
758
1030
  overhead_time = total_execution_time - sum_tracked_times
759
- overhead_percentage = (overhead_time / total_execution_time) * 100 if total_execution_time > 0 else 0
760
-
761
- stats['overhead'] = {
762
- 'count': num_calls,
763
- 'total': overhead_time,
764
- 'mean': overhead_time / num_calls if num_calls > 0 else 0,
765
- 'std': 0,
766
- 'min': 0,
767
- 'max': 0,
768
- 'percentage': overhead_percentage
1031
+ overhead_percentage = (
1032
+ (overhead_time / total_execution_time) * 100 if total_execution_time > 0 else 0
1033
+ )
1034
+
1035
+ stats["overhead"] = {
1036
+ "count": num_calls,
1037
+ "total": overhead_time,
1038
+ "mean": overhead_time / num_calls if num_calls > 0 else 0,
1039
+ "std": 0,
1040
+ "min": 0,
1041
+ "max": 0,
1042
+ "percentage": overhead_percentage,
769
1043
  }
770
1044
 
771
- stats['contract_execution']['percentage'] = 100.0
1045
+ stats["contract_execution"]["percentage"] = 100.0
772
1046
 
773
1047
  table = Table(
774
- title=f"Contract Execution Summary ({num_calls} Forward Calls)",
775
- show_header=True
1048
+ title=f"Contract Execution Summary ({num_calls} Forward Calls)", show_header=True
776
1049
  )
777
1050
  table.add_column("Operation", style="cyan")
778
1051
  table.add_column("Count", justify="right", style="blue")
@@ -787,29 +1060,29 @@ class contract:
787
1060
  s = stats[op]
788
1061
  table.add_row(
789
1062
  op.replace("_", " ").title(),
790
- str(s['count']),
1063
+ str(s["count"]),
791
1064
  f"{s['total']:.3f}",
792
1065
  f"{s['mean']:.3f}",
793
1066
  f"{s['std']:.3f}",
794
1067
  f"{s['min']:.3f}",
795
1068
  f"{s['max']:.3f}",
796
- f"{s['percentage']:.1f}%"
1069
+ f"{s['percentage']:.1f}%",
797
1070
  )
798
1071
 
799
- s = stats['overhead']
1072
+ s = stats["overhead"]
800
1073
  table.add_row(
801
1074
  "Overhead",
802
- str(s['count']),
1075
+ str(s["count"]),
803
1076
  f"{s['total']:.3f}",
804
1077
  f"{s['mean']:.3f}",
805
1078
  f"{s['std']:.3f}",
806
1079
  f"{s['min']:.3f}",
807
1080
  f"{s['max']:.3f}",
808
1081
  f"{s['percentage']:.1f}%",
809
- style="bold blue"
1082
+ style="bold blue",
810
1083
  )
811
1084
 
812
- s = stats['contract_execution']
1085
+ s = stats["contract_execution"]
813
1086
  table.add_row(
814
1087
  "Total Execution",
815
1088
  "N/A",
@@ -819,7 +1092,7 @@ class contract:
819
1092
  f"{s['min']:.3f}",
820
1093
  f"{s['max']:.3f}",
821
1094
  "100.0%",
822
- style="bold magenta"
1095
+ style="bold magenta",
823
1096
  )
824
1097
 
825
1098
  console.print("\n")
@@ -828,17 +1101,29 @@ class contract:
828
1101
 
829
1102
  return stats
830
1103
 
831
- cls.__init__ = __init__
832
- cls.forward = wrapped_forward
833
- cls.contract_perf_stats = contract_perf_stats
1104
+ return contract_perf_stats
1105
+
def __call__(self, cls):
    """Decorate ``cls`` in place and return it.

    ``cls.__init__`` and ``cls.forward`` are replaced with contract-aware wrappers
    around the originals, and a ``contract_perf_stats`` method is attached.
    """
    # Capture the originals before patching so the wrappers can delegate to them.
    init_impl, forward_impl = cls.__init__, cls.forward

    cls.__init__ = self._build_wrapped_init(init_impl)
    cls.forward = self._build_wrapped_forward(forward_impl)
    cls.contract_perf_stats = self._build_contract_perf_stats()
    return cls
836
1114
 
837
1115
 
838
1116
  class BaseStrategy(TypeValidationFunction):
839
- def __init__(self, data_model: BaseModel, *args, **kwargs):
1117
+ def __init__(self, data_model: BaseModel, *_args, **kwargs):
840
1118
  super().__init__(
841
- retry_params=dict(tries=8, delay=0.015, backoff=1.25, jitter=0.0, max_delay=0.25, graceful=False),
1119
+ retry_params={
1120
+ "tries": 8,
1121
+ "delay": 0.015,
1122
+ "backoff": 1.25,
1123
+ "jitter": 0.0,
1124
+ "max_delay": 0.25,
1125
+ "graceful": False,
1126
+ },
842
1127
  **kwargs,
843
1128
  )
844
1129
  super().register_expected_data_model(data_model, attach_to="output")
@@ -851,14 +1136,13 @@ class BaseStrategy(TypeValidationFunction):
851
1136
  pass
852
1137
 
def forward(self, *args, **kwargs):
    """Run the parent type-validation ``forward`` with this strategy's settings.

    Adds ``payload=self.payload``, ``template_suffix=self.template`` and a JSON-object
    ``response_format`` to the caller's arguments. Passing any of those keywords
    explicitly raises ``TypeError`` (duplicate keyword argument).
    """
    strategy_kwargs = {
        "payload": self.payload,
        "template_suffix": self.template,
        "response_format": {"type": "json_object"},
    }
    return super().forward(*args, **strategy_kwargs, **kwargs)
862
1146
 
863
1147
  @property
864
1148
  def payload(self):
@@ -866,7 +1150,7 @@ class BaseStrategy(TypeValidationFunction):
866
1150
 
@property
def static_context(self):
    """Strategy-specific static context; there is no default implementation.

    Raises:
        NotImplementedError: Always, unless a subclass overrides this property.
    """
    raise NotImplementedError()
870
1154
 
871
1155
  @property
872
1156
  def template(self):
@@ -878,7 +1162,7 @@ class Strategy(Expression):
878
1162
  super().__init__(*args, **kwargs)
879
1163
  self.logger = logging.getLogger(__name__)
880
1164
 
def __new__(cls, module: str, *_args, **_kwargs):
    """Return whatever ``Strategy.load_module_class`` yields for the requested module.

    Records ``module`` as ``_module`` and sets ``module_path`` to
    ``"symai.extended.strategies"`` on the class before delegating. Extra positional
    and keyword arguments are ignored here.

    NOTE(review): both values are written on the class object, not on an instance, so
    every caller shares them. Confirm this is not used concurrently.
    NOTE(review): the lookup reads ``cls.module`` rather than the ``module`` argument;
    presumably a class-level accessor backed by ``_module``. Verify.
    """
    cls._module = module
    cls.module_path = "symai.extended.strategies"
    requested_module = cls.module
    return Strategy.load_module_class(requested_module)