symbolicai 0.20.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123):
  1. symai/__init__.py +96 -64
  2. symai/backend/base.py +93 -80
  3. symai/backend/engines/drawing/engine_bfl.py +12 -11
  4. symai/backend/engines/drawing/engine_gpt_image.py +108 -87
  5. symai/backend/engines/embedding/engine_llama_cpp.py +25 -28
  6. symai/backend/engines/embedding/engine_openai.py +3 -5
  7. symai/backend/engines/execute/engine_python.py +6 -5
  8. symai/backend/engines/files/engine_io.py +74 -67
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +3 -3
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +54 -38
  11. symai/backend/engines/index/engine_pinecone.py +23 -24
  12. symai/backend/engines/index/engine_vectordb.py +16 -14
  13. symai/backend/engines/lean/engine_lean4.py +38 -34
  14. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  15. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +262 -182
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +263 -191
  17. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +53 -49
  18. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +212 -211
  19. symai/backend/engines/neurosymbolic/engine_groq.py +87 -63
  20. symai/backend/engines/neurosymbolic/engine_huggingface.py +21 -24
  21. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +117 -48
  22. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +256 -229
  23. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +270 -150
  24. symai/backend/engines/ocr/engine_apilayer.py +6 -8
  25. symai/backend/engines/output/engine_stdout.py +1 -4
  26. symai/backend/engines/search/engine_openai.py +7 -7
  27. symai/backend/engines/search/engine_perplexity.py +5 -5
  28. symai/backend/engines/search/engine_serpapi.py +12 -14
  29. symai/backend/engines/speech_to_text/engine_local_whisper.py +20 -27
  30. symai/backend/engines/symbolic/engine_wolframalpha.py +3 -3
  31. symai/backend/engines/text_to_speech/engine_openai.py +5 -7
  32. symai/backend/engines/text_vision/engine_clip.py +7 -11
  33. symai/backend/engines/userinput/engine_console.py +3 -3
  34. symai/backend/engines/webscraping/engine_requests.py +81 -48
  35. symai/backend/mixin/__init__.py +13 -0
  36. symai/backend/mixin/anthropic.py +4 -2
  37. symai/backend/mixin/deepseek.py +2 -0
  38. symai/backend/mixin/google.py +2 -0
  39. symai/backend/mixin/openai.py +11 -3
  40. symai/backend/settings.py +83 -16
  41. symai/chat.py +101 -78
  42. symai/collect/__init__.py +7 -1
  43. symai/collect/dynamic.py +77 -69
  44. symai/collect/pipeline.py +35 -27
  45. symai/collect/stats.py +75 -63
  46. symai/components.py +198 -169
  47. symai/constraints.py +15 -12
  48. symai/core.py +698 -359
  49. symai/core_ext.py +32 -34
  50. symai/endpoints/api.py +80 -73
  51. symai/extended/.DS_Store +0 -0
  52. symai/extended/__init__.py +46 -12
  53. symai/extended/api_builder.py +11 -8
  54. symai/extended/arxiv_pdf_parser.py +13 -12
  55. symai/extended/bibtex_parser.py +2 -3
  56. symai/extended/conversation.py +101 -90
  57. symai/extended/document.py +17 -10
  58. symai/extended/file_merger.py +18 -13
  59. symai/extended/graph.py +18 -13
  60. symai/extended/html_style_template.py +2 -4
  61. symai/extended/interfaces/blip_2.py +1 -2
  62. symai/extended/interfaces/clip.py +1 -2
  63. symai/extended/interfaces/console.py +7 -1
  64. symai/extended/interfaces/dall_e.py +1 -1
  65. symai/extended/interfaces/flux.py +1 -1
  66. symai/extended/interfaces/gpt_image.py +1 -1
  67. symai/extended/interfaces/input.py +1 -1
  68. symai/extended/interfaces/llava.py +0 -1
  69. symai/extended/interfaces/naive_vectordb.py +7 -8
  70. symai/extended/interfaces/naive_webscraping.py +1 -1
  71. symai/extended/interfaces/ocr.py +1 -1
  72. symai/extended/interfaces/pinecone.py +6 -5
  73. symai/extended/interfaces/serpapi.py +1 -1
  74. symai/extended/interfaces/terminal.py +2 -3
  75. symai/extended/interfaces/tts.py +1 -1
  76. symai/extended/interfaces/whisper.py +1 -1
  77. symai/extended/interfaces/wolframalpha.py +1 -1
  78. symai/extended/metrics/__init__.py +11 -1
  79. symai/extended/metrics/similarity.py +11 -13
  80. symai/extended/os_command.py +17 -16
  81. symai/extended/packages/__init__.py +29 -3
  82. symai/extended/packages/symdev.py +19 -16
  83. symai/extended/packages/sympkg.py +12 -9
  84. symai/extended/packages/symrun.py +21 -19
  85. symai/extended/repo_cloner.py +11 -10
  86. symai/extended/seo_query_optimizer.py +1 -2
  87. symai/extended/solver.py +20 -23
  88. symai/extended/summarizer.py +4 -3
  89. symai/extended/taypan_interpreter.py +10 -12
  90. symai/extended/vectordb.py +99 -82
  91. symai/formatter/__init__.py +9 -1
  92. symai/formatter/formatter.py +12 -16
  93. symai/formatter/regex.py +62 -63
  94. symai/functional.py +176 -122
  95. symai/imports.py +136 -127
  96. symai/interfaces.py +56 -27
  97. symai/memory.py +14 -13
  98. symai/misc/console.py +49 -39
  99. symai/misc/loader.py +5 -3
  100. symai/models/__init__.py +17 -1
  101. symai/models/base.py +269 -181
  102. symai/models/errors.py +0 -1
  103. symai/ops/__init__.py +32 -22
  104. symai/ops/measures.py +11 -15
  105. symai/ops/primitives.py +348 -228
  106. symai/post_processors.py +32 -28
  107. symai/pre_processors.py +39 -41
  108. symai/processor.py +6 -4
  109. symai/prompts.py +59 -45
  110. symai/server/huggingface_server.py +23 -20
  111. symai/server/llama_cpp_server.py +7 -5
  112. symai/shell.py +3 -4
  113. symai/shellsv.py +499 -375
  114. symai/strategy.py +517 -287
  115. symai/symbol.py +111 -116
  116. symai/utils.py +42 -36
  117. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/METADATA +4 -2
  118. symbolicai-1.0.0.dist-info/RECORD +163 -0
  119. symbolicai-0.20.2.dist-info/RECORD +0 -162
  120. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/WHEEL +0 -0
  121. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/entry_points.txt +0 -0
  122. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/licenses/LICENSE +0 -0
  123. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/top_level.txt +0 -0
symai/strategy.py CHANGED
@@ -2,7 +2,8 @@ import inspect
2
2
  import logging
3
3
  import time
4
4
  from collections import defaultdict
5
- from typing import Callable
5
+ from collections.abc import Callable
6
+ from typing import Any, ClassVar
6
7
 
7
8
  import numpy as np
8
9
  from beartype import beartype
@@ -14,9 +15,9 @@ from rich.panel import Panel
14
15
  from rich.table import Table
15
16
 
16
17
  from .components import Function
17
- from .models import (LLMDataModel, TypeValidationError,
18
- build_dynamic_llm_datamodel)
18
+ from .models import LLMDataModel, TypeValidationError, build_dynamic_llm_datamodel
19
19
  from .symbol import Expression
20
+ from .utils import UserMessage
20
21
 
21
22
 
22
23
  class ValidationFunction(Function):
@@ -29,18 +30,18 @@ class ValidationFunction(Function):
29
30
  • Error simplification
30
31
  """
31
32
  # Have some default retry params that don't add overhead
32
- _default_retry_params = dict(
33
- tries=8,
34
- delay=0.015,
35
- backoff=1.25,
36
- jitter=0.0,
37
- max_delay=0.25,
38
- graceful=False
39
- )
33
+ _default_retry_params: ClassVar[dict[str, int | float | bool]] = {
34
+ "tries": 8,
35
+ "delay": 0.015,
36
+ "backoff": 1.25,
37
+ "jitter": 0.0,
38
+ "max_delay": 0.25,
39
+ "graceful": False,
40
+ }
40
41
 
41
42
  def __init__(
42
43
  self,
43
- retry_params: dict[str, int | float | bool] = None,
44
+ retry_params: dict[str, int | float | bool] | None = None,
44
45
  verbose: bool = False,
45
46
  *args,
46
47
  **kwargs,
@@ -92,10 +93,9 @@ class ValidationFunction(Function):
92
93
  seed = 42
93
94
 
94
95
  rnd = np.random.RandomState(seed=seed)
95
- seeds = rnd.randint(
96
+ return rnd.randint(
96
97
  0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16
97
98
  ).tolist()
98
- return seeds
99
99
 
100
100
  def simplify_validation_errors(self, error: ValidationError) -> str:
101
101
  """
@@ -130,12 +130,14 @@ class ValidationFunction(Function):
130
130
  _delay = min(base + jit, self.retry_params['max_delay'])
131
131
  time.sleep(_delay)
132
132
 
133
- def remedy_prompt(self, *args, **kwargs):
133
+ def remedy_prompt(self, *_args, **_kwargs):
134
134
  """
135
135
  Abstract or base remedy prompt method.
136
136
  Child classes typically override this to include additional context needed for correction.
137
137
  """
138
- raise NotImplementedError("Each child class needs its own remedy_prompt implementation.")
138
+ msg = "Each child class needs its own remedy_prompt implementation."
139
+ UserMessage(msg)
140
+ raise NotImplementedError(msg)
139
141
 
140
142
  def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1,2)):
141
143
  """
@@ -171,7 +173,7 @@ class TypeValidationFunction(ValidationFunction):
171
173
  *args,
172
174
  **kwargs,
173
175
  ):
174
- super().__init__(retry_params=retry_params, verbose=verbose, *args, **kwargs)
176
+ super().__init__(*args, retry_params=retry_params, verbose=verbose, **kwargs)
175
177
  self.input_data_model = None
176
178
  self.output_data_model = None
177
179
  self.accumulate_errors = accumulate_errors
@@ -181,11 +183,15 @@ class TypeValidationFunction(ValidationFunction):
181
183
  assert attach_to in ["input", "output"], f"Invalid attach_to value: {attach_to}; must be either 'input' or 'output'"
182
184
  if attach_to == "input":
183
185
  if self.input_data_model is not None and not override:
184
- raise ValueError("There is already a data model attached to the input. If you want to override it, set `override=True`.")
186
+ msg = "There is already a data model attached to the input. If you want to override it, set `override=True`."
187
+ UserMessage(msg)
188
+ raise ValueError(msg)
185
189
  self.input_data_model = data_model
186
190
  elif attach_to == "output":
187
191
  if self.output_data_model is not None and not override:
188
- raise ValueError("There is already a data model attached to the output. If you want to override it, set `override=True`.")
192
+ msg = "There is already a data model attached to the output. If you want to override it, set `override=True`."
193
+ UserMessage(msg)
194
+ raise ValueError(msg)
189
195
  self.output_data_model = data_model
190
196
 
191
197
  def remedy_prompt(self, prompt: str, output: str, errors: str) -> str:
@@ -268,118 +274,174 @@ Important guidelines:
268
274
  </guidelines>
269
275
  """
270
276
 
271
- def forward(self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs):
277
+ def _ensure_output_model(self):
272
278
  if self.output_data_model is None:
273
- raise ValueError("While the input data model is optional, the output data model must be provided. Please register it before calling the `forward` method.")
279
+ msg = (
280
+ "While the input data model is optional, the output data model must be provided. "
281
+ "Please register it before calling the `forward` method."
282
+ )
283
+ UserMessage(msg)
284
+ raise ValueError(msg)
285
+
286
+ def _display_verbose_panels(self, prompt: str):
287
+ if not self.verbose:
288
+ return
289
+ for label, body in [
290
+ ("Prompt", prompt),
291
+ ("Input data model", self.input_data_model.simplify_json_schema() if self.input_data_model else 'N/A'),
292
+ ("Output data model", self.output_data_model.simplify_json_schema()),
293
+ ]:
294
+ self.display_panel(body, title=label)
295
+
296
+ def _check_semantic_conditions(
297
+ self,
298
+ result,
299
+ f_semantic_conditions: list[Callable] | None,
300
+ ) -> str | None:
301
+ if f_semantic_conditions is None:
302
+ return None
303
+ try:
304
+ assert all(
305
+ f(result if not getattr(self.output_data_model, '_is_dynamic_model', False) else result.value)
306
+ for f in f_semantic_conditions
307
+ )
308
+ except Exception as err:
309
+ return f"Semantic validation failed with:\n{err!s}"
310
+ return None
311
+
312
+ def _format_validation_error(self, error: Exception) -> str:
313
+ if isinstance(error, ValidationError):
314
+ return self.simplify_validation_errors(error)
315
+ return str(error)
316
+
317
+ def _handle_failed_validation_attempt(
318
+ self,
319
+ attempt_index: int,
320
+ prompt: str,
321
+ json_str: str,
322
+ errors: list[str],
323
+ error: Exception,
324
+ remedy_seeds: list[Any],
325
+ kwargs: dict,
326
+ ) -> str:
327
+ logger.info(f"Validation attempt {attempt_index + 1} failed, pausing before retry…")
328
+ self._pause(attempt_index)
329
+ error_str = self._format_validation_error(error)
330
+ errors.append(error_str)
331
+
332
+ logger.error("Validation errors identified!")
333
+ if self.verbose:
334
+ errors_report = "\n".join(errors) if self.accumulate_errors else error_str
335
+ title = f"Validation Errors ({'accumulated errors' if self.accumulate_errors else 'last error'})"
336
+ self.display_panel(errors_report, title=title, border_style="red")
337
+
338
+ logger.info("Updating remedy function context…")
339
+ context = self.remedy_prompt(
340
+ prompt=prompt,
341
+ output=json_str,
342
+ errors="\n".join(errors) if self.accumulate_errors else error_str,
343
+ )
344
+ self.remedy_function.clear()
345
+ self.remedy_function.adapt(context)
346
+ if self.verbose:
347
+ self.display_panel(self.remedy_function.dynamic_context, title="New Context")
348
+
349
+ json_str = self.remedy_function(seed=remedy_seeds[attempt_index], **kwargs).value
350
+ logger.info("Applied remedy function with updated context!")
351
+ return json_str
352
+
353
+ def _run_validation_attempts(
354
+ self,
355
+ prompt: str,
356
+ f_semantic_conditions: list[Callable] | None,
357
+ validation_context: dict,
358
+ remedy_seeds: list[Any],
359
+ json_str: str,
360
+ kwargs: dict,
361
+ ) -> tuple[Any | None, str, list[str]]:
362
+ errors: list[str] = []
363
+ result = None
364
+ total_attempts = self.retry_params["tries"] + 1
365
+ for attempt in range(total_attempts):
366
+ if attempt != self.retry_params["tries"]:
367
+ logger.info(f"Attempt {attempt + 1}/{self.retry_params['tries']}: Attempting validation…")
368
+ try:
369
+ result = self.output_data_model.model_validate_json(
370
+ json_str,
371
+ strict=False,
372
+ context=validation_context,
373
+ )
374
+ semantic_error = self._check_semantic_conditions(result, f_semantic_conditions)
375
+ if semantic_error is not None:
376
+ if attempt == total_attempts - 1:
377
+ result = None
378
+ errors.append(semantic_error)
379
+ break
380
+ raise AssertionError(semantic_error)
381
+ break
382
+ except Exception as error:
383
+ json_str = self._handle_failed_validation_attempt(
384
+ attempt,
385
+ prompt,
386
+ json_str,
387
+ errors,
388
+ error,
389
+ remedy_seeds,
390
+ kwargs,
391
+ )
392
+ return result, json_str, errors
393
+
394
+ def _handle_validation_failure(self, prompt: str, json_str: str, errors: list[str]):
395
+ logger.error("All validation attempts failed!")
396
+ if self.retry_params['graceful']:
397
+ return
398
+ raise TypeValidationError(
399
+ prompt=prompt,
400
+ result=json_str,
401
+ violations=errors,
402
+ )
403
+
404
+ def forward(self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs):
405
+ self._ensure_output_model()
274
406
  validation_context = kwargs.pop('validation_context', {})
275
- # Force JSON mode
276
407
  kwargs["response_format"] = {"type": "json_object"}
277
408
  logger.info("Initializing validation…")
278
- if self.verbose:
279
- for label, body in [
280
- ("Prompt", prompt),
281
- ("Input data model", self.input_data_model.simplify_json_schema() if self.input_data_model else 'N/A'),
282
- ("Output data model", self.output_data_model.simplify_json_schema()),
283
- ]:
284
- self.display_panel(body, title=label)
285
-
286
- # Zero shot the task
409
+ self._display_verbose_panels(prompt)
410
+
287
411
  context = self.zero_shot_prompt(prompt=prompt)
288
412
  json_str = super().forward(context, *args, **kwargs).value
289
413
 
290
414
  remedy_seeds = self.prepare_seeds(self.retry_params["tries"] + 1, **kwargs)
291
415
  logger.info(f"Prepared {len(remedy_seeds)} remedy seeds for validation attempts…")
292
416
 
293
- result = None
294
- errors = []
295
- for i in range(self.retry_params["tries"] + 1):
296
- if i != self.retry_params["tries"]:
297
- logger.info(f"Attempt {i+1}/{self.retry_params['tries']}: Attempting validation…")
298
- try:
299
- #@NOTE: We use strict=False (default) to allow Pydantic's type coercion.
300
- # This handles common LLM output issues like:
301
- # - String keys "123" -> integer keys 123 for dict[int, ...] types
302
- # - String numbers "42" -> int 42
303
- # - Float 1.0 -> int 1
304
- # These coercions are safe and helpful when dealing with JSON from LLMs,
305
- # while still catching actual type errors (e.g., "not_a_number" -> int fails).
306
- result = self.output_data_model.model_validate_json(json_str, strict=False, context=validation_context)
307
- if f_semantic_conditions is not None:
308
- try:
309
- assert all(
310
- f(result if not getattr(self.output_data_model, '_is_dynamic_model', False) else result.value)
311
- for f in f_semantic_conditions
312
- )
313
- except Exception as e:
314
- # If we are in the last attempt and semantic validation fails, result will be None and we propagate the error
315
- if i == self.retry_params["tries"]:
316
- result = None
317
- errors.append(f"Semantic validation failed with:\n{str(e)}")
318
- break # We break to avoid going into the remedy loop
319
- raise e
320
- break
321
- except Exception as e:
322
- logger.info(f"Validation attempt {i+1} failed, pausing before retry…")
323
-
324
- self._pause(i)
325
-
326
- if isinstance(e, ValidationError):
327
- error_str = self.simplify_validation_errors(e)
328
- else:
329
- error_str = str(e)
330
-
331
- errors.append(error_str)
332
-
333
- logger.error(f"Validation errors identified!")
334
- if self.verbose:
335
- self.display_panel(
336
- "\n".join(errors) if self.accumulate_errors else error_str,
337
- title=f"Validation Errors ({'accumulated errors' if self.accumulate_errors else 'last error'})",
338
- border_style="red"
339
- )
340
-
341
- # Update remedy function context
342
- logger.info("Updating remedy function context…")
343
- context = self.remedy_prompt(prompt=prompt, output=json_str, errors="\n".join(errors) if self.accumulate_errors else error_str)
344
- self.remedy_function.clear()
345
- self.remedy_function.adapt(context)
346
- if self.verbose:
347
- self.display_panel(
348
- self.remedy_function.dynamic_context,
349
- title="New Context"
350
- )
351
-
352
- # Apply the remedy function
353
- json_str = self.remedy_function(seed=remedy_seeds[i], **kwargs).value
354
- logger.info("Applied remedy function with updated context!")
417
+ result, json_str, errors = self._run_validation_attempts(
418
+ prompt,
419
+ f_semantic_conditions,
420
+ validation_context,
421
+ remedy_seeds,
422
+ json_str,
423
+ kwargs,
424
+ )
355
425
 
356
426
  if result is None:
357
- logger.error(f"All validation attempts failed!")
358
- if self.retry_params['graceful']:
359
- return
360
- raise TypeValidationError(
361
- prompt=prompt,
362
- result=json_str,
363
- violations=errors,
364
- )
427
+ return self._handle_validation_failure(prompt, json_str, errors)
365
428
 
366
429
  logger.success("Validation completed successfully!")
367
- # Clear artifacts from the remedy function
368
430
  self.remedy_function.clear()
369
-
370
431
  return result
371
432
 
372
433
 
373
434
  @beartype
374
435
  class contract:
375
- _default_remedy_retry_params = dict(
376
- tries=8,
377
- delay=0.015,
378
- backoff=1.25,
379
- jitter=0.0,
380
- max_delay=0.25,
381
- graceful=False
382
- )
436
+ _default_remedy_retry_params: ClassVar[dict[str, int | float | bool]] = {
437
+ "tries": 8,
438
+ "delay": 0.015,
439
+ "backoff": 1.25,
440
+ "jitter": 0.0,
441
+ "max_delay": 0.25,
442
+ "graceful": False,
443
+ }
444
+ _internal_forward_kwargs: ClassVar[set[str]] = {"validation_context"}
383
445
 
384
446
  def __init__(
385
447
  self,
@@ -404,22 +466,30 @@ class contract:
404
466
  else:
405
467
  logger.enable(__name__)
406
468
 
407
- def _is_valid_input(self, input):
408
- if input is None:
469
+ def _is_valid_input(self, input_value):
470
+ if input_value is None:
409
471
  logger.error("No `input` argument provided!")
410
- raise ValueError("Please provide an `input` argument.")
411
- if not isinstance(input, LLMDataModel):
412
- logger.error(f"Invalid input type: {type(input)}")
413
- raise TypeError(f"Expected input to be of type `LLMDataModel`, got {type(input)}")
472
+ msg = "Please provide an `input` argument."
473
+ UserMessage(msg)
474
+ raise ValueError(msg)
475
+ if not isinstance(input_value, LLMDataModel):
476
+ logger.error(f"Invalid input type: {type(input_value)}")
477
+ msg = f"Expected input to be of type `LLMDataModel`, got {type(input_value)}"
478
+ UserMessage(msg)
479
+ raise TypeError(msg)
414
480
  return True
415
481
 
416
482
  def _is_valid_output(self, output_type):
417
483
  if output_type == inspect._empty:
418
484
  logger.error("Missing return type annotation!")
419
- raise ValueError("The contract requires a return type annotation.")
485
+ msg = "The contract requires a return type annotation."
486
+ UserMessage(msg)
487
+ raise ValueError(msg)
420
488
  if not issubclass(output_type, LLMDataModel):
421
489
  logger.error(f"Invalid return type: {output_type}")
422
- raise TypeError("The return type annotation must be a subclass of `LLMDataModel`.")
490
+ msg = "The return type annotation must be a subclass of `LLMDataModel`."
491
+ UserMessage(msg)
492
+ raise TypeError(msg)
423
493
  return True
424
494
 
425
495
  def _try_dynamic_type_annotation(self, original_forward, *, context: str = "input"):
@@ -428,23 +498,42 @@ class contract:
428
498
  )
429
499
  sig = inspect.signature(original_forward)
430
500
  try:
501
+ resolved_param = None
431
502
  # Fallback: look at the relevant part of the function signature
432
503
  # depending on whether we deal with an *input* or *output*
433
504
  if context == "input":
434
505
  param = sig.parameters.get("input")
506
+ if param is None:
507
+ for candidate in sig.parameters.values():
508
+ if candidate.name == "self":
509
+ continue
510
+ if candidate.kind in (
511
+ inspect.Parameter.POSITIONAL_ONLY,
512
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
513
+ inspect.Parameter.KEYWORD_ONLY,
514
+ ):
515
+ param = candidate
516
+ break
517
+ resolved_param = param
435
518
  if param is None or param.annotation == inspect._empty:
436
- raise TypeError("Failed to infer type from input parameter annotation")
519
+ msg = "Failed to infer type from input parameter annotation"
520
+ UserMessage(msg)
521
+ raise TypeError(msg)
437
522
  dynamic_model = build_dynamic_llm_datamodel(param.annotation)
438
523
  else: # context == "output"
439
524
  if sig.return_annotation == inspect._empty:
440
- raise TypeError("Failed to infer type from return annotation")
525
+ msg = "Failed to infer type from return annotation"
526
+ UserMessage(msg)
527
+ raise TypeError(msg)
441
528
  dynamic_model = build_dynamic_llm_datamodel(sig.return_annotation)
442
- except Exception:
443
- logger.exception(f"Failed to build dynamic LLMDataModel from {sig.parameters.get('input')}!")
444
- raise TypeError(
529
+ except Exception as err:
530
+ logger.exception(f"Failed to build dynamic LLMDataModel from {resolved_param}!")
531
+ msg = (
445
532
  "The type annotation must be a subclass of `LLMDataModel` or a "
446
533
  "valid Python typing object supported by Pydantic."
447
534
  )
535
+ UserMessage(msg)
536
+ raise TypeError(msg) from err
448
537
 
449
538
  dynamic_model._is_dynamic_model = True
450
539
  return dynamic_model
@@ -453,49 +542,54 @@ class contract:
453
542
  try:
454
543
  data_model = self.f_type_validation_remedy(prompt, f_semantic_conditions=f_semantic_conditions, **remedy_kwargs)
455
544
  except Exception as e:
456
- logger.error(f"Type validation failed with exception!")
545
+ logger.error("Type validation failed with exception!")
457
546
  raise e
458
547
  return data_model
459
548
 
460
- def _validate_input(self, wrapped_self, input, it, **remedy_kwargs):
549
+ def _validate_input(self, wrapped_self, input_value, it, **remedy_kwargs):
461
550
  logger.info("Starting input validation...")
462
551
  if self.pre_remedy:
463
552
  logger.info("Validating pre-conditions with remedy...")
464
553
  if not hasattr(wrapped_self, 'pre'):
465
554
  logger.error("Pre-condition function not defined!")
466
- raise Exception("Pre-condition function not defined. Please define a `pre` method if you want to enforce pre-conditions through a remedy.")
555
+ msg = "Pre-condition function not defined. Please define a `pre` method if you want to enforce pre-conditions through a remedy."
556
+ UserMessage(msg)
557
+ raise Exception(msg)
467
558
 
468
559
  op_start = time.perf_counter()
469
560
  try:
470
- assert wrapped_self.pre(input)
561
+ assert wrapped_self.pre(input_value)
471
562
  logger.success("Pre-condition validation successful!")
472
- return input
473
- except Exception as e:
563
+ return input_value
564
+ except Exception:
474
565
  logger.exception("Pre-condition validation failed!")
475
- self.f_type_validation_remedy.register_expected_data_model(input, attach_to="output", override=True)
476
- input = self._try_remedy_with_exception(prompt=wrapped_self.prompt, f_semantic_conditions=[wrapped_self.pre], **remedy_kwargs)
566
+ self.f_type_validation_remedy.register_expected_data_model(input_value, attach_to="output", override=True)
567
+ input_value = self._try_remedy_with_exception(
568
+ prompt=wrapped_self.prompt,
569
+ f_semantic_conditions=[wrapped_self.pre],
570
+ **remedy_kwargs,
571
+ )
477
572
  finally:
478
573
  wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
479
- return input
480
- else:
481
- if hasattr(wrapped_self, 'pre'):
482
- logger.info("Validating pre-conditions without remedy...")
483
- op_start = time.perf_counter()
484
- try:
485
- assert wrapped_self.pre(input)
486
- except Exception as e:
487
- logger.exception(f"Pre-condition validation failed")
488
- raise e
489
- finally:
490
- wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
491
- logger.success("Pre-condition validation successful!")
492
- return input
493
- logger.info("Skip; no pre-condition validation was required!")
494
- return input
574
+ return input_value
575
+ if hasattr(wrapped_self, 'pre'):
576
+ logger.info("Validating pre-conditions without remedy...")
577
+ op_start = time.perf_counter()
578
+ try:
579
+ assert wrapped_self.pre(input_value)
580
+ except Exception as e:
581
+ logger.exception("Pre-condition validation failed")
582
+ raise e
583
+ finally:
584
+ wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
585
+ logger.success("Pre-condition validation successful!")
586
+ return input_value
587
+ logger.info("Skip; no pre-condition validation was required!")
588
+ return input_value
495
589
 
496
- def _validate_output(self, wrapped_self, input, output, it, **remedy_kwargs):
590
+ def _validate_output(self, wrapped_self, input_value, output, it, **remedy_kwargs):
497
591
  logger.info("Starting output validation...")
498
- self.f_type_validation_remedy.register_expected_data_model(input, attach_to="input", override=True)
592
+ self.f_type_validation_remedy.register_expected_data_model(input_value, attach_to="input", override=True)
499
593
  self.f_type_validation_remedy.register_expected_data_model(output, attach_to="output", override=True)
500
594
 
501
595
  op_start = time.perf_counter()
@@ -505,7 +599,7 @@ class contract:
505
599
  if output is None: # output is None when graceful mode is enabled
506
600
  return output
507
601
  except Exception as e:
508
- logger.exception(f"Type creation failed!")
602
+ logger.exception("Type creation failed!")
509
603
  raise e
510
604
  finally:
511
605
  wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
@@ -515,88 +609,111 @@ class contract:
515
609
  logger.info("Validating post-conditions with remedy...")
516
610
  if not hasattr(wrapped_self, "post"):
517
611
  logger.error("Post-condition function not defined!")
518
- raise Exception("Post-condition function not defined. Please define a `post` method if you want to enforce post-conditions through a remedy.")
612
+ msg = "Post-condition function not defined. Please define a `post` method if you want to enforce post-conditions through a remedy."
613
+ UserMessage(msg)
614
+ raise Exception(msg)
519
615
 
520
616
  op_start = time.perf_counter()
521
617
  try:
522
618
  assert wrapped_self.post(output)
523
619
  logger.success("Post-condition validation successful!")
524
620
  return output
525
- except Exception as e:
621
+ except Exception:
526
622
  logger.exception("Post-condition validation failed!")
527
623
  output = self._try_remedy_with_exception(prompt=wrapped_self.prompt, f_semantic_conditions=[wrapped_self.post], **remedy_kwargs)
528
624
  finally:
529
625
  wrapped_self._contract_timing[it]["output_validation"] += (time.perf_counter() - op_start)
530
626
  logger.success("Post-condition validation successful!")
531
627
  return output
532
- else:
533
- if hasattr(wrapped_self, "post"):
534
- logger.info("Validating post-conditions without remedy...")
535
- op_start = time.perf_counter()
536
- try:
537
- assert wrapped_self.post(output)
538
- except Exception as e:
539
- logger.exception("Post-condition validation failed!")
540
- raise e
541
- finally:
542
- wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
543
- logger.success("Post-condition validation successful!")
544
- return output
628
+ if hasattr(wrapped_self, "post"):
629
+ logger.info("Validating post-conditions without remedy...")
630
+ op_start = time.perf_counter()
631
+ try:
632
+ assert wrapped_self.post(output)
633
+ except Exception as e:
634
+ logger.exception("Post-condition validation failed!")
635
+ raise e
636
+ finally:
637
+ wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
638
+ logger.success("Post-condition validation successful!")
639
+ return output
545
640
  logger.info("Skip; no post-condition validation was required!")
546
641
  return output
547
642
 
548
643
  def _validate_act_method(self, act_method):
549
644
  act_sig = inspect.signature(act_method)
550
645
  params = list(act_sig.parameters.values())
551
- if not params or params[0].name != 'input':
552
- raise TypeError("'act' method first parameter must be named 'input'")
553
- if params[0].annotation == inspect._empty:
554
- raise TypeError("'act' method parameter 'input' must have a type annotation'")
646
+
647
+ first_param = None
648
+ for param in params:
649
+ if param.name == "self":
650
+ continue
651
+ if param.kind in (
652
+ inspect.Parameter.POSITIONAL_ONLY,
653
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
654
+ ):
655
+ first_param = param
656
+ break
657
+
658
+ if first_param is None:
659
+ msg = "'act' method must accept at least one positional parameter after `self`."
660
+ UserMessage(msg)
661
+ raise TypeError(msg)
662
+ if first_param.annotation == inspect._empty:
663
+ msg = f"'act' method parameter '{first_param.name}' must have a type annotation."
664
+ UserMessage(msg)
665
+ raise TypeError(msg)
555
666
  if act_sig.return_annotation == inspect._empty:
556
- raise TypeError("'act' method must have a return type annotation'")
667
+ msg = "'act' method must have a return type annotation'"
668
+ UserMessage(msg)
669
+ raise TypeError(msg)
557
670
  return True
558
671
 
559
- def _act(self, wrapped_self, input, it, **act_kwargs):
672
+ def _act(self, wrapped_self, input_value, it, **act_kwargs):
560
673
  act_method = getattr(wrapped_self, 'act', None)
561
674
  if not callable(act_method):
562
675
  # Propagate the input if no act method is defined
563
- return input
676
+ return input_value
564
677
 
565
678
  assert self._validate_act_method(act_method)
566
679
 
567
- is_dynamic_model = getattr(input, '_is_dynamic_model', False)
568
- input = input if not is_dynamic_model else input.value
680
+ is_dynamic_model = getattr(input_value, '_is_dynamic_model', False)
681
+ input_value = input_value if not is_dynamic_model else input_value.value
569
682
 
570
683
  logger.info(f"Executing 'act' method on {wrapped_self.__class__.__name__}…")
571
684
 
572
685
  op_start = time.perf_counter()
573
686
  try:
574
- act_output = act_method(input, **act_kwargs)
687
+ act_output = act_method(input_value, **act_kwargs)
575
688
  except Exception as e:
576
- logger.exception(f"'act' method execution failed")
689
+ logger.exception("'act' method execution failed")
577
690
  raise e
578
691
  finally:
579
692
  wrapped_self._contract_timing[it]["act_execution"] = time.perf_counter() - op_start
580
693
 
581
694
  act_sig = inspect.signature(act_method)
582
- if act_sig.return_annotation != inspect.Signature.empty and inspect.isclass(act_sig.return_annotation):
583
- if not isinstance(act_output, act_sig.return_annotation):
584
- raise TypeError(f"'act' method returned {type(act_output).__name__}, expected {act_sig.return_annotation.__name__}.")
695
+ if (
696
+ act_sig.return_annotation != inspect.Signature.empty
697
+ and inspect.isclass(act_sig.return_annotation)
698
+ and not isinstance(act_output, act_sig.return_annotation)
699
+ ):
700
+ msg = f"'act' method returned {type(act_output).__name__}, expected {act_sig.return_annotation.__name__}."
701
+ UserMessage(msg)
702
+ raise TypeError(msg)
585
703
 
586
704
  logger.success("'act' method executed successfully!")
587
705
  return act_output
588
706
 
589
- def __call__(self, cls):
590
- original_init = cls.__init__
591
- original_forward = cls.forward
592
-
707
+ def _build_wrapped_init(self, original_init):
593
708
  def __init__(wrapped_self, *args, **kwargs):
594
709
  logger.info("Initializing contract...")
595
710
  original_init(wrapped_self, *args, **kwargs)
596
711
 
597
712
  if not hasattr(wrapped_self, "prompt"):
598
713
  logger.error("Prompt attribute not defined!")
599
- raise Exception("Please define a static `prompt` attribute that describes what the contract must do.")
714
+ msg = "Please define a static `prompt` attribute that describes what the contract must do."
715
+ UserMessage(msg)
716
+ raise Exception(msg)
600
717
 
601
718
  wrapped_self.contract_successful = False
602
719
  wrapped_self.contract_result = None
@@ -604,105 +721,207 @@ class contract:
604
721
  wrapped_self._contract_timing = defaultdict(dict)
605
722
  logger.info("Contract initialization complete!")
606
723
 
607
- def wrapped_forward(wrapped_self, **kwargs):
608
- logger.info("Starting contract execution...")
609
- it = len(wrapped_self._contract_timing) # the len is the __call__ op_start
610
- contract_start = time.perf_counter()
724
+ return __init__
725
+
726
+ def _start_contract_execution(self, wrapped_self):
727
+ logger.info("Starting contract execution...")
728
+ it = len(wrapped_self._contract_timing) # the len is the __call__ op_start
729
+ return it, time.perf_counter()
730
+
731
+ def _find_input_param_name(self, sig: inspect.Signature) -> str | None:
732
+ for param in sig.parameters.values():
733
+ if param.name == "self":
734
+ continue
735
+ if param.kind in (
736
+ inspect.Parameter.POSITIONAL_ONLY,
737
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
738
+ inspect.Parameter.KEYWORD_ONLY,
739
+ ):
740
+ return param.name
741
+ return None
611
742
 
612
- input = kwargs.pop("input", None)
613
- input_type = None
614
- try:
615
- assert self._is_valid_input(input)
616
- except TypeError:
617
- input_type = self._try_dynamic_type_annotation(original_forward, context="input")
618
- input = input_type(value=input)
619
-
620
- maybe_payload = getattr(wrapped_self, "payload", None)
621
- maybe_template = getattr(wrapped_self, "template")
622
- if inspect.ismethod(maybe_template):
623
- # `template` is a primitive in symbolicai case in which we actually don't have a template
624
- maybe_template = None
625
-
626
- # Create validation kwargs that include all original kwargs plus payload and template
627
- validation_kwargs = {
628
- **kwargs,
629
- "payload": maybe_payload,
630
- "template_suffix": maybe_template
631
- }
743
+ def _prepare_forward_args(self, args, kwargs):
744
+ args_list = list(args)
745
+ kwargs_without_input = dict(kwargs)
746
+ original_kwargs = dict(kwargs)
747
+ return args_list, kwargs_without_input, original_kwargs
632
748
 
633
- sig = inspect.signature(original_forward)
634
- output_type = sig.return_annotation
635
- try:
636
- assert self._is_valid_output(output_type)
637
- except TypeError:
638
- output_type = self._try_dynamic_type_annotation(original_forward, context="output")
749
+ def _extract_input_value(
750
+ self,
751
+ args_list,
752
+ kwargs_without_input,
753
+ original_kwargs,
754
+ input_param_name: str | None,
755
+ ):
756
+ if args_list:
757
+ return args_list[0], ("args", 0)
758
+ if input_param_name and input_param_name in kwargs_without_input:
759
+ return kwargs_without_input.pop(input_param_name), ("kwargs", input_param_name)
760
+ if "input" in kwargs_without_input:
761
+ return kwargs_without_input.pop("input"), ("kwargs", "input")
762
+ return original_kwargs.get("input"), ("fallback_kw", "input")
763
+
764
+ def _coerce_input_value(self, original_forward, input_value):
765
+ try:
766
+ assert self._is_valid_input(input_value)
767
+ return input_value
768
+ except TypeError:
769
+ input_type_model = self._try_dynamic_type_annotation(original_forward, context="input")
770
+ return input_type_model(value=input_value)
771
+
772
+ def _collect_validation_kwargs(self, wrapped_self, kwargs_without_input):
773
+ maybe_payload = getattr(wrapped_self, "payload", None)
774
+ maybe_template = getattr(wrapped_self, "template", None)
775
+ if inspect.ismethod(maybe_template):
776
+ maybe_template = None
777
+ return {
778
+ **kwargs_without_input,
779
+ "payload": maybe_payload,
780
+ "template_suffix": maybe_template,
781
+ }
782
+
783
+ def _resolve_output_type(self, sig: inspect.Signature, original_forward):
784
+ output_type = sig.return_annotation
785
+ try:
786
+ assert self._is_valid_output(output_type)
787
+ except TypeError:
788
+ output_type = self._try_dynamic_type_annotation(original_forward, context="output")
789
+ return output_type
639
790
 
640
- output = None
641
- current_input = input
642
- try:
643
- # 1. Start with original input and apply pre-validation
644
- maybe_new_input = self._validate_input(wrapped_self, current_input, it, **validation_kwargs)
645
- if maybe_new_input is not None:
646
- current_input = maybe_new_input
791
+ def _run_contract_pipeline(
792
+ self,
793
+ wrapped_self,
794
+ current_input_value,
795
+ output_type,
796
+ it,
797
+ validation_kwargs,
798
+ ):
799
+ output = None
800
+ try:
801
+ maybe_new_input = self._validate_input(wrapped_self, current_input_value, it, **validation_kwargs)
802
+ if maybe_new_input is not None:
803
+ current_input_value = maybe_new_input
804
+
805
+ current_input_value = self._act(wrapped_self, current_input_value, it, **validation_kwargs)
806
+
807
+ output = self._validate_output(
808
+ wrapped_self,
809
+ current_input_value,
810
+ output_type,
811
+ it,
812
+ **validation_kwargs,
813
+ )
814
+ wrapped_self.contract_successful = output is not None
815
+ wrapped_self.contract_result = output
816
+ wrapped_self.contract_exception = None
817
+ except Exception as exc:
818
+ logger.exception("Contract execution failed in main path!")
819
+ wrapped_self.contract_successful = False
820
+ wrapped_self.contract_exception = exc
821
+ return output, current_input_value
647
822
 
648
- # 2. Check if 'act' method exists and execute it
649
- current_input = self._act(wrapped_self, current_input, it, **validation_kwargs)
823
+ def _execute_forward_call(
824
+ self,
825
+ wrapped_self,
826
+ original_forward,
827
+ args_list,
828
+ original_kwargs,
829
+ input_param_name,
830
+ input_source,
831
+ forward_input_value,
832
+ it,
833
+ contract_start,
834
+ ):
835
+ forward_kwargs = original_kwargs.copy()
836
+ for internal_kw in self._internal_forward_kwargs:
837
+ forward_kwargs.pop(internal_kw, None)
838
+
839
+ logger.info("Executing original forward method...")
840
+
841
+ if input_param_name:
842
+ if input_param_name in forward_kwargs or input_source == ("kwargs", input_param_name):
843
+ forward_kwargs[input_param_name] = forward_input_value
844
+ elif input_source == ("args", 0) and args_list:
845
+ args_list[0] = forward_input_value
846
+ else:
847
+ forward_kwargs[input_param_name] = forward_input_value
848
+ else:
849
+ forward_kwargs['input'] = forward_input_value
650
850
 
651
- # 3. Validate output type and prepare for original_forward
652
- output = self._validate_output(wrapped_self, current_input, output_type, it, **validation_kwargs)
653
- wrapped_self.contract_successful = True if output is not None else False
654
- wrapped_self.contract_result = output
655
- wrapped_self.contract_exception = None
851
+ if input_param_name and input_param_name != "input" and "input" in forward_kwargs:
852
+ forward_kwargs.pop("input")
656
853
 
657
- except Exception as e:
658
- logger.exception(f"Contract execution failed in main path!")
659
- wrapped_self.contract_successful = False
660
- wrapped_self.contract_exception = e
661
- # contract_result remains None or its value before the exception.
662
- # final_output remains None or its value before the exception.
663
- # The finally block's execution of original_forward will determine the actual returned value.
664
- finally:
665
- # Execute the original forward method with appropriate input
666
- logger.info("Executing original forward method...")
667
-
668
- # If contract was successful, use the processed input (after pre-validation and act, both optional)
669
- # `current_input` at this stage is the result of the try block's processing up to the point of exception,
670
- # or the full processing if successful.
671
- # If contract failed, use original_input (fallback).
672
- forward_input = current_input if wrapped_self.contract_successful else input
673
-
674
- # Prepare kwargs for original_forward
675
- forward_kwargs = kwargs.copy()
676
- forward_kwargs['input'] = forward_input
677
-
678
- try:
679
- op_start = time.perf_counter()
680
- output = original_forward(wrapped_self, **forward_kwargs)
681
- finally:
682
- wrapped_self._contract_timing[it]["forward_execution"] = time.perf_counter() - op_start
683
- wrapped_self._contract_timing[it]["contract_execution"] = time.perf_counter() - contract_start
684
-
685
- if not isinstance(output, output_type):
686
- logger.error(f"Output type mismatch: {type(output)}")
687
- if self.remedy_retry_params["graceful"]:
688
- # In graceful mode, skip type mismatch error and return raw output
689
- if hasattr(output_type, '_is_dynamic_model') and output_type._is_dynamic_model and hasattr(output, 'value'):
690
- return output.value
691
- return output
692
- raise TypeError(
693
- f"Expected output to be an instance of {output_type}, "
694
- f"but got {type(output)}! Forward method must return an instance of {output_type}!"
695
- )
696
- if not wrapped_self.contract_successful:
697
- logger.warning("Contract validation failed!")
698
- else:
699
- logger.success("Contract validation successful!")
854
+ try:
855
+ op_start = time.perf_counter()
856
+ output = original_forward(wrapped_self, *args_list, **forward_kwargs)
857
+ finally:
858
+ wrapped_self._contract_timing[it]["forward_execution"] = time.perf_counter() - op_start
859
+ wrapped_self._contract_timing[it]["contract_execution"] = time.perf_counter() - contract_start
860
+ return output
700
861
 
701
- if hasattr(output_type, '_is_dynamic_model') and output_type._is_dynamic_model:
702
- return output.value
862
+ def _finalize_contract_output(self, output, output_type, wrapped_self):
863
+ if not isinstance(output, output_type):
864
+ logger.error(f"Output type mismatch: {type(output)}")
865
+ if self.remedy_retry_params["graceful"]:
866
+ if getattr(output_type, '_is_dynamic_model', False) and hasattr(output, 'value'):
867
+ return output.value
868
+ return output
869
+ msg = (
870
+ f"Expected output to be an instance of {output_type}, "
871
+ f"but got {type(output)}! Forward method must return an instance of {output_type}!"
872
+ )
873
+ UserMessage(msg)
874
+ raise TypeError(msg)
875
+ if not wrapped_self.contract_successful:
876
+ logger.warning("Contract validation failed!")
877
+ else:
878
+ logger.success("Contract validation successful!")
703
879
 
704
- return output
880
+ if getattr(output_type, '_is_dynamic_model', False):
881
+ return output.value
882
+ return output
705
883
 
884
+ def _contract_forward_impl(self, wrapped_self, original_forward, *args, **kwargs):
885
+ it, contract_start = self._start_contract_execution(wrapped_self)
886
+ sig = inspect.signature(original_forward)
887
+ input_param_name = self._find_input_param_name(sig)
888
+ args_list, kwargs_without_input, original_kwargs = self._prepare_forward_args(args, kwargs)
889
+ input_value, input_source = self._extract_input_value(args_list, kwargs_without_input, original_kwargs, input_param_name)
890
+ current_input_value = self._coerce_input_value(original_forward, input_value)
891
+ input_value = current_input_value
892
+ validation_kwargs = self._collect_validation_kwargs(wrapped_self, kwargs_without_input)
893
+ output_type = self._resolve_output_type(sig, original_forward)
894
+
895
+ output, current_input_value = self._run_contract_pipeline(
896
+ wrapped_self,
897
+ current_input_value,
898
+ output_type,
899
+ it,
900
+ validation_kwargs,
901
+ )
902
+
903
+ forward_input_value = current_input_value if wrapped_self.contract_successful else input_value
904
+ output = self._execute_forward_call(
905
+ wrapped_self,
906
+ original_forward,
907
+ args_list,
908
+ original_kwargs,
909
+ input_param_name,
910
+ input_source,
911
+ forward_input_value,
912
+ it,
913
+ contract_start,
914
+ )
915
+
916
+ return self._finalize_contract_output(output, output_type, wrapped_self)
917
+
918
+ def _build_wrapped_forward(self, original_forward):
919
+ def wrapped_forward(wrapped_self, *args, **kwargs):
920
+ return self._contract_forward_impl(wrapped_self, original_forward, *args, **kwargs)
921
+
922
+ return wrapped_forward
923
+
924
+ def _build_contract_perf_stats(self):
706
925
  def contract_perf_stats(wrapped_self):
707
926
  """Analyzes and prints timing statistics across all forward calls."""
708
927
  console = Console()
@@ -828,17 +1047,29 @@ class contract:
828
1047
 
829
1048
  return stats
830
1049
 
831
- cls.__init__ = __init__
832
- cls.forward = wrapped_forward
833
- cls.contract_perf_stats = contract_perf_stats
1050
+ return contract_perf_stats
1051
+
1052
+ def __call__(self, cls):
1053
+ original_init = cls.__init__
1054
+ original_forward = cls.forward
834
1055
 
1056
+ cls.__init__ = self._build_wrapped_init(original_init)
1057
+ cls.forward = self._build_wrapped_forward(original_forward)
1058
+ cls.contract_perf_stats = self._build_contract_perf_stats()
835
1059
  return cls
836
1060
 
837
1061
 
838
1062
  class BaseStrategy(TypeValidationFunction):
839
- def __init__(self, data_model: BaseModel, *args, **kwargs):
1063
+ def __init__(self, data_model: BaseModel, *_args, **kwargs):
840
1064
  super().__init__(
841
- retry_params=dict(tries=8, delay=0.015, backoff=1.25, jitter=0.0, max_delay=0.25, graceful=False),
1065
+ retry_params={
1066
+ "tries": 8,
1067
+ "delay": 0.015,
1068
+ "backoff": 1.25,
1069
+ "jitter": 0.0,
1070
+ "max_delay": 0.25,
1071
+ "graceful": False,
1072
+ },
842
1073
  **kwargs,
843
1074
  )
844
1075
  super().register_expected_data_model(data_model, attach_to="output")
@@ -851,14 +1082,13 @@ class BaseStrategy(TypeValidationFunction):
851
1082
  pass
852
1083
 
853
1084
  def forward(self, *args, **kwargs):
854
- result = super().forward(
1085
+ return super().forward(
855
1086
  *args,
856
1087
  payload=self.payload,
857
1088
  template_suffix=self.template,
858
1089
  response_format={"type": "json_object"},
859
1090
  **kwargs,
860
1091
  )
861
- return result
862
1092
 
863
1093
  @property
864
1094
  def payload(self):
@@ -866,7 +1096,7 @@ class BaseStrategy(TypeValidationFunction):
866
1096
 
867
1097
  @property
868
1098
  def static_context(self):
869
- raise NotImplementedError()
1099
+ raise NotImplementedError
870
1100
 
871
1101
  @property
872
1102
  def template(self):
@@ -878,7 +1108,7 @@ class Strategy(Expression):
878
1108
  super().__init__(*args, **kwargs)
879
1109
  self.logger = logging.getLogger(__name__)
880
1110
 
881
- def __new__(self, module: str, *args, **kwargs):
882
- self._module = module
883
- self.module_path = 'symai.extended.strategies'
884
- return Strategy.load_module_class(self.module)
1111
+ def __new__(cls, module: str, *_args, **_kwargs):
1112
+ cls._module = module
1113
+ cls.module_path = 'symai.extended.strategies'
1114
+ return Strategy.load_module_class(cls.module)