symbolicai 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. symai/__init__.py +198 -134
  2. symai/backend/base.py +51 -51
  3. symai/backend/engines/drawing/engine_bfl.py +33 -33
  4. symai/backend/engines/drawing/engine_gpt_image.py +4 -10
  5. symai/backend/engines/embedding/engine_llama_cpp.py +50 -35
  6. symai/backend/engines/embedding/engine_openai.py +22 -16
  7. symai/backend/engines/execute/engine_python.py +16 -16
  8. symai/backend/engines/files/engine_io.py +51 -49
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +27 -23
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +53 -46
  11. symai/backend/engines/index/engine_pinecone.py +116 -88
  12. symai/backend/engines/index/engine_qdrant.py +1011 -0
  13. symai/backend/engines/index/engine_vectordb.py +78 -52
  14. symai/backend/engines/lean/engine_lean4.py +65 -25
  15. symai/backend/engines/neurosymbolic/__init__.py +28 -28
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +137 -135
  17. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +145 -152
  18. symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
  19. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +75 -49
  20. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +199 -155
  21. symai/backend/engines/neurosymbolic/engine_groq.py +106 -72
  22. symai/backend/engines/neurosymbolic/engine_huggingface.py +100 -67
  23. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +121 -93
  24. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +213 -132
  25. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +180 -137
  26. symai/backend/engines/ocr/engine_apilayer.py +18 -20
  27. symai/backend/engines/output/engine_stdout.py +9 -9
  28. symai/backend/engines/{webscraping → scrape}/engine_requests.py +25 -11
  29. symai/backend/engines/search/engine_openai.py +95 -83
  30. symai/backend/engines/search/engine_parallel.py +665 -0
  31. symai/backend/engines/search/engine_perplexity.py +40 -41
  32. symai/backend/engines/search/engine_serpapi.py +33 -28
  33. symai/backend/engines/speech_to_text/engine_local_whisper.py +37 -27
  34. symai/backend/engines/symbolic/engine_wolframalpha.py +14 -8
  35. symai/backend/engines/text_to_speech/engine_openai.py +15 -19
  36. symai/backend/engines/text_vision/engine_clip.py +34 -28
  37. symai/backend/engines/userinput/engine_console.py +3 -4
  38. symai/backend/mixin/anthropic.py +48 -40
  39. symai/backend/mixin/deepseek.py +4 -5
  40. symai/backend/mixin/google.py +5 -4
  41. symai/backend/mixin/groq.py +2 -4
  42. symai/backend/mixin/openai.py +132 -110
  43. symai/backend/settings.py +14 -14
  44. symai/chat.py +164 -94
  45. symai/collect/dynamic.py +13 -11
  46. symai/collect/pipeline.py +39 -31
  47. symai/collect/stats.py +109 -69
  48. symai/components.py +556 -238
  49. symai/constraints.py +14 -5
  50. symai/core.py +1495 -1210
  51. symai/core_ext.py +55 -50
  52. symai/endpoints/api.py +113 -58
  53. symai/extended/api_builder.py +22 -17
  54. symai/extended/arxiv_pdf_parser.py +13 -5
  55. symai/extended/bibtex_parser.py +8 -4
  56. symai/extended/conversation.py +88 -69
  57. symai/extended/document.py +40 -27
  58. symai/extended/file_merger.py +45 -7
  59. symai/extended/graph.py +38 -24
  60. symai/extended/html_style_template.py +17 -11
  61. symai/extended/interfaces/blip_2.py +1 -1
  62. symai/extended/interfaces/clip.py +4 -2
  63. symai/extended/interfaces/console.py +5 -3
  64. symai/extended/interfaces/dall_e.py +3 -1
  65. symai/extended/interfaces/file.py +2 -0
  66. symai/extended/interfaces/flux.py +3 -1
  67. symai/extended/interfaces/gpt_image.py +15 -6
  68. symai/extended/interfaces/input.py +2 -1
  69. symai/extended/interfaces/llava.py +1 -1
  70. symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +3 -2
  71. symai/extended/interfaces/naive_vectordb.py +2 -2
  72. symai/extended/interfaces/ocr.py +4 -2
  73. symai/extended/interfaces/openai_search.py +2 -0
  74. symai/extended/interfaces/parallel.py +30 -0
  75. symai/extended/interfaces/perplexity.py +2 -0
  76. symai/extended/interfaces/pinecone.py +6 -4
  77. symai/extended/interfaces/python.py +2 -0
  78. symai/extended/interfaces/serpapi.py +2 -0
  79. symai/extended/interfaces/terminal.py +0 -1
  80. symai/extended/interfaces/tts.py +2 -1
  81. symai/extended/interfaces/whisper.py +2 -1
  82. symai/extended/interfaces/wolframalpha.py +1 -0
  83. symai/extended/metrics/__init__.py +1 -1
  84. symai/extended/metrics/similarity.py +5 -2
  85. symai/extended/os_command.py +31 -22
  86. symai/extended/packages/symdev.py +39 -34
  87. symai/extended/packages/sympkg.py +30 -27
  88. symai/extended/packages/symrun.py +46 -35
  89. symai/extended/repo_cloner.py +10 -9
  90. symai/extended/seo_query_optimizer.py +15 -12
  91. symai/extended/solver.py +104 -76
  92. symai/extended/summarizer.py +8 -7
  93. symai/extended/taypan_interpreter.py +10 -9
  94. symai/extended/vectordb.py +28 -15
  95. symai/formatter/formatter.py +39 -31
  96. symai/formatter/regex.py +46 -44
  97. symai/functional.py +184 -86
  98. symai/imports.py +85 -51
  99. symai/interfaces.py +1 -1
  100. symai/memory.py +33 -24
  101. symai/menu/screen.py +28 -19
  102. symai/misc/console.py +27 -27
  103. symai/misc/loader.py +4 -3
  104. symai/models/base.py +147 -76
  105. symai/models/errors.py +1 -1
  106. symai/ops/__init__.py +1 -1
  107. symai/ops/measures.py +17 -14
  108. symai/ops/primitives.py +933 -635
  109. symai/post_processors.py +28 -24
  110. symai/pre_processors.py +58 -52
  111. symai/processor.py +15 -9
  112. symai/prompts.py +714 -649
  113. symai/server/huggingface_server.py +115 -32
  114. symai/server/llama_cpp_server.py +14 -6
  115. symai/server/qdrant_server.py +206 -0
  116. symai/shell.py +98 -39
  117. symai/shellsv.py +307 -223
  118. symai/strategy.py +135 -81
  119. symai/symbol.py +276 -225
  120. symai/utils.py +62 -46
  121. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +19 -9
  122. symbolicai-1.1.0.dist-info/RECORD +168 -0
  123. symbolicai-1.0.0.dist-info/RECORD +0 -163
  124. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
  125. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
  126. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
  127. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/strategy.py CHANGED
@@ -29,6 +29,7 @@ class ValidationFunction(Function):
29
29
  • Pause/backoff logic
30
30
  • Error simplification
31
31
  """
32
+
32
33
  # Have some default retry params that don't add overhead
33
34
  _default_retry_params: ClassVar[dict[str, int | float | bool]] = {
34
35
  "tries": 8,
@@ -93,9 +94,7 @@ class ValidationFunction(Function):
93
94
  seed = 42
94
95
 
95
96
  rnd = np.random.RandomState(seed=seed)
96
- return rnd.randint(
97
- 0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16
98
- ).tolist()
97
+ return rnd.randint(0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16).tolist()
99
98
 
100
99
  def simplify_validation_errors(self, error: ValidationError) -> str:
101
100
  """
@@ -123,11 +122,13 @@ class ValidationFunction(Function):
123
122
  return "\n".join(simplified_errors)
124
123
 
125
124
  def _pause(self, attempt):
126
- base = self.retry_params['delay'] * (self.retry_params['backoff'] ** attempt)
127
- jit = (np.random.uniform(*self.retry_params['jitter'])
128
- if isinstance(self.retry_params['jitter'], tuple)
129
- else self.retry_params['jitter'])
130
- _delay = min(base + jit, self.retry_params['max_delay'])
125
+ base = self.retry_params["delay"] * (self.retry_params["backoff"] ** attempt)
126
+ jit = (
127
+ np.random.uniform(*self.retry_params["jitter"])
128
+ if isinstance(self.retry_params["jitter"], tuple)
129
+ else self.retry_params["jitter"]
130
+ )
131
+ _delay = min(base + jit, self.retry_params["max_delay"])
131
132
  time.sleep(_delay)
132
133
 
133
134
  def remedy_prompt(self, *_args, **_kwargs):
@@ -139,7 +140,7 @@ class ValidationFunction(Function):
139
140
  UserMessage(msg)
140
141
  raise NotImplementedError(msg)
141
142
 
142
- def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1,2)):
143
+ def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1, 2)):
143
144
  """
144
145
  Display content in a rich panel with consistent formatting.
145
146
 
@@ -151,11 +152,13 @@ class ValidationFunction(Function):
151
152
  padding: Padding for the panel (default: (1,2))
152
153
  """
153
154
  body = escape(content)
154
- panel = Panel.fit(body, title=title, padding=padding, border_style=border_style, style=style)
155
+ panel = Panel.fit(
156
+ body, title=title, padding=padding, border_style=border_style, style=style
157
+ )
155
158
  self.console.print(panel)
156
159
 
157
160
  def forward(self, *args, **kwargs):
158
- return super().forward(*args, **kwargs) # Just propagate to Function
161
+ return super().forward(*args, **kwargs) # Just propagate to Function
159
162
 
160
163
 
161
164
  class TypeValidationFunction(ValidationFunction):
@@ -165,6 +168,7 @@ class TypeValidationFunction(ValidationFunction):
165
168
  if a user provides a callable designed to semantically validate the
166
169
  structure of the type-validated data.
167
170
  """
171
+
168
172
  def __init__(
169
173
  self,
170
174
  retry_params: dict[str, int | float | bool] = ValidationFunction._default_retry_params,
@@ -179,8 +183,12 @@ class TypeValidationFunction(ValidationFunction):
179
183
  self.accumulate_errors = accumulate_errors
180
184
  self.verbose = verbose
181
185
 
182
- def register_expected_data_model(self, data_model: LLMDataModel, attach_to: str, override: bool = False):
183
- assert attach_to in ["input", "output"], f"Invalid attach_to value: {attach_to}; must be either 'input' or 'output'"
186
+ def register_expected_data_model(
187
+ self, data_model: LLMDataModel, attach_to: str, override: bool = False
188
+ ):
189
+ assert attach_to in ["input", "output"], (
190
+ f"Invalid attach_to value: {attach_to}; must be either 'input' or 'output'"
191
+ )
184
192
  if attach_to == "input":
185
193
  if self.input_data_model is not None and not override:
186
194
  msg = "There is already a data model attached to the input. If you want to override it, set `override=True`."
@@ -206,12 +214,12 @@ Your prompt was:
206
214
 
207
215
  The input data model is:
208
216
  <input_data_model>
209
- {self.input_data_model.simplify_json_schema() if self.input_data_model is not None else 'N/A'}
217
+ {self.input_data_model.simplify_json_schema() if self.input_data_model is not None else "N/A"}
210
218
  </input_data_model>
211
219
 
212
220
  The given input was:
213
221
  <input>
214
- {str(self.input_data_model) if self.input_data_model is not None else 'N/A'}
222
+ {str(self.input_data_model) if self.input_data_model is not None else "N/A"}
215
223
  </input>
216
224
 
217
225
  The output data model is:
@@ -253,12 +261,12 @@ You are given the following prompt:
253
261
 
254
262
  The input data model is:
255
263
  <input_data_model>
256
- {self.input_data_model.simplify_json_schema() if self.input_data_model is not None else 'N/A'}
264
+ {self.input_data_model.simplify_json_schema() if self.input_data_model is not None else "N/A"}
257
265
  </input_data_model>
258
266
 
259
267
  The given input is:
260
268
  <input>
261
- {str(self.input_data_model) if self.input_data_model is not None else 'N/A'}
269
+ {str(self.input_data_model) if self.input_data_model is not None else "N/A"}
262
270
  </input>
263
271
 
264
272
  The output data model is:
@@ -288,7 +296,10 @@ Important guidelines:
288
296
  return
289
297
  for label, body in [
290
298
  ("Prompt", prompt),
291
- ("Input data model", self.input_data_model.simplify_json_schema() if self.input_data_model else 'N/A'),
299
+ (
300
+ "Input data model",
301
+ self.input_data_model.simplify_json_schema() if self.input_data_model else "N/A",
302
+ ),
292
303
  ("Output data model", self.output_data_model.simplify_json_schema()),
293
304
  ]:
294
305
  self.display_panel(body, title=label)
@@ -302,7 +313,11 @@ Important guidelines:
302
313
  return None
303
314
  try:
304
315
  assert all(
305
- f(result if not getattr(self.output_data_model, '_is_dynamic_model', False) else result.value)
316
+ f(
317
+ result
318
+ if not getattr(self.output_data_model, "_is_dynamic_model", False)
319
+ else result.value
320
+ )
306
321
  for f in f_semantic_conditions
307
322
  )
308
323
  except Exception as err:
@@ -364,7 +379,9 @@ Important guidelines:
364
379
  total_attempts = self.retry_params["tries"] + 1
365
380
  for attempt in range(total_attempts):
366
381
  if attempt != self.retry_params["tries"]:
367
- logger.info(f"Attempt {attempt + 1}/{self.retry_params['tries']}: Attempting validation…")
382
+ logger.info(
383
+ f"Attempt {attempt + 1}/{self.retry_params['tries']}: Attempting validation…"
384
+ )
368
385
  try:
369
386
  result = self.output_data_model.model_validate_json(
370
387
  json_str,
@@ -393,7 +410,7 @@ Important guidelines:
393
410
 
394
411
  def _handle_validation_failure(self, prompt: str, json_str: str, errors: list[str]):
395
412
  logger.error("All validation attempts failed!")
396
- if self.retry_params['graceful']:
413
+ if self.retry_params["graceful"]:
397
414
  return
398
415
  raise TypeValidationError(
399
416
  prompt=prompt,
@@ -401,9 +418,11 @@ Important guidelines:
401
418
  violations=errors,
402
419
  )
403
420
 
404
- def forward(self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs):
421
+ def forward(
422
+ self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs
423
+ ):
405
424
  self._ensure_output_model()
406
- validation_context = kwargs.pop('validation_context', {})
425
+ validation_context = kwargs.pop("validation_context", {})
407
426
  kwargs["response_format"] = {"type": "json_object"}
408
427
  logger.info("Initializing validation…")
409
428
  self._display_verbose_panels(prompt)
@@ -451,15 +470,17 @@ class contract:
451
470
  verbose: bool = False,
452
471
  remedy_retry_params: dict[str, int | float | bool] = _default_remedy_retry_params,
453
472
  ):
454
- '''
473
+ """
455
474
  A contract class decorator inspired by DbC principles. It ensures that the function's input and output
456
475
  adhere to specified data models both syntactically and semantically. This implementation includes retry
457
476
  logic to handle transient errors and gracefully handle failures.
458
- '''
477
+ """
459
478
  self.pre_remedy = pre_remedy
460
479
  self.post_remedy = post_remedy
461
480
  self.remedy_retry_params = remedy_retry_params
462
- self.f_type_validation_remedy = TypeValidationFunction(accumulate_errors=accumulate_errors, verbose=verbose, retry_params=remedy_retry_params)
481
+ self.f_type_validation_remedy = TypeValidationFunction(
482
+ accumulate_errors=accumulate_errors, verbose=verbose, retry_params=remedy_retry_params
483
+ )
463
484
 
464
485
  if not verbose:
465
486
  logger.disable(__name__)
@@ -540,7 +561,9 @@ class contract:
540
561
 
541
562
  def _try_remedy_with_exception(self, prompt, f_semantic_conditions, **remedy_kwargs):
542
563
  try:
543
- data_model = self.f_type_validation_remedy(prompt, f_semantic_conditions=f_semantic_conditions, **remedy_kwargs)
564
+ data_model = self.f_type_validation_remedy(
565
+ prompt, f_semantic_conditions=f_semantic_conditions, **remedy_kwargs
566
+ )
544
567
  except Exception as e:
545
568
  logger.error("Type validation failed with exception!")
546
569
  raise e
@@ -550,7 +573,7 @@ class contract:
550
573
  logger.info("Starting input validation...")
551
574
  if self.pre_remedy:
552
575
  logger.info("Validating pre-conditions with remedy...")
553
- if not hasattr(wrapped_self, 'pre'):
576
+ if not hasattr(wrapped_self, "pre"):
554
577
  logger.error("Pre-condition function not defined!")
555
578
  msg = "Pre-condition function not defined. Please define a `pre` method if you want to enforce pre-conditions through a remedy."
556
579
  UserMessage(msg)
@@ -563,16 +586,20 @@ class contract:
563
586
  return input_value
564
587
  except Exception:
565
588
  logger.exception("Pre-condition validation failed!")
566
- self.f_type_validation_remedy.register_expected_data_model(input_value, attach_to="output", override=True)
589
+ self.f_type_validation_remedy.register_expected_data_model(
590
+ input_value, attach_to="output", override=True
591
+ )
567
592
  input_value = self._try_remedy_with_exception(
568
593
  prompt=wrapped_self.prompt,
569
594
  f_semantic_conditions=[wrapped_self.pre],
570
595
  **remedy_kwargs,
571
596
  )
572
597
  finally:
573
- wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
598
+ wrapped_self._contract_timing[it]["input_validation"] = (
599
+ time.perf_counter() - op_start
600
+ )
574
601
  return input_value
575
- if hasattr(wrapped_self, 'pre'):
602
+ if hasattr(wrapped_self, "pre"):
576
603
  logger.info("Validating pre-conditions without remedy...")
577
604
  op_start = time.perf_counter()
578
605
  try:
@@ -581,7 +608,9 @@ class contract:
581
608
  logger.exception("Pre-condition validation failed")
582
609
  raise e
583
610
  finally:
584
- wrapped_self._contract_timing[it]["input_validation"] = time.perf_counter() - op_start
611
+ wrapped_self._contract_timing[it]["input_validation"] = (
612
+ time.perf_counter() - op_start
613
+ )
585
614
  logger.success("Pre-condition validation successful!")
586
615
  return input_value
587
616
  logger.info("Skip; no pre-condition validation was required!")
@@ -589,14 +618,20 @@ class contract:
589
618
 
590
619
  def _validate_output(self, wrapped_self, input_value, output, it, **remedy_kwargs):
591
620
  logger.info("Starting output validation...")
592
- self.f_type_validation_remedy.register_expected_data_model(input_value, attach_to="input", override=True)
593
- self.f_type_validation_remedy.register_expected_data_model(output, attach_to="output", override=True)
621
+ self.f_type_validation_remedy.register_expected_data_model(
622
+ input_value, attach_to="input", override=True
623
+ )
624
+ self.f_type_validation_remedy.register_expected_data_model(
625
+ output, attach_to="output", override=True
626
+ )
594
627
 
595
628
  op_start = time.perf_counter()
596
629
  try:
597
630
  logger.info("Getting a valid output type...")
598
- output = self._try_remedy_with_exception(prompt=wrapped_self.prompt, f_semantic_conditions=None, **remedy_kwargs)
599
- if output is None: # output is None when graceful mode is enabled
631
+ output = self._try_remedy_with_exception(
632
+ prompt=wrapped_self.prompt, f_semantic_conditions=None, **remedy_kwargs
633
+ )
634
+ if output is None: # output is None when graceful mode is enabled
600
635
  return output
601
636
  except Exception as e:
602
637
  logger.exception("Type creation failed!")
@@ -620,9 +655,15 @@ class contract:
620
655
  return output
621
656
  except Exception:
622
657
  logger.exception("Post-condition validation failed!")
623
- output = self._try_remedy_with_exception(prompt=wrapped_self.prompt, f_semantic_conditions=[wrapped_self.post], **remedy_kwargs)
658
+ output = self._try_remedy_with_exception(
659
+ prompt=wrapped_self.prompt,
660
+ f_semantic_conditions=[wrapped_self.post],
661
+ **remedy_kwargs,
662
+ )
624
663
  finally:
625
- wrapped_self._contract_timing[it]["output_validation"] += (time.perf_counter() - op_start)
664
+ wrapped_self._contract_timing[it]["output_validation"] += (
665
+ time.perf_counter() - op_start
666
+ )
626
667
  logger.success("Post-condition validation successful!")
627
668
  return output
628
669
  if hasattr(wrapped_self, "post"):
@@ -634,7 +675,9 @@ class contract:
634
675
  logger.exception("Post-condition validation failed!")
635
676
  raise e
636
677
  finally:
637
- wrapped_self._contract_timing[it]["output_validation"] = time.perf_counter() - op_start
678
+ wrapped_self._contract_timing[it]["output_validation"] = (
679
+ time.perf_counter() - op_start
680
+ )
638
681
  logger.success("Post-condition validation successful!")
639
682
  return output
640
683
  logger.info("Skip; no post-condition validation was required!")
@@ -670,14 +713,14 @@ class contract:
670
713
  return True
671
714
 
672
715
  def _act(self, wrapped_self, input_value, it, **act_kwargs):
673
- act_method = getattr(wrapped_self, 'act', None)
716
+ act_method = getattr(wrapped_self, "act", None)
674
717
  if not callable(act_method):
675
718
  # Propagate the input if no act method is defined
676
719
  return input_value
677
720
 
678
721
  assert self._validate_act_method(act_method)
679
722
 
680
- is_dynamic_model = getattr(input_value, '_is_dynamic_model', False)
723
+ is_dynamic_model = getattr(input_value, "_is_dynamic_model", False)
681
724
  input_value = input_value if not is_dynamic_model else input_value.value
682
725
 
683
726
  logger.info(f"Executing 'act' method on {wrapped_self.__class__.__name__}…")
@@ -798,11 +841,15 @@ class contract:
798
841
  ):
799
842
  output = None
800
843
  try:
801
- maybe_new_input = self._validate_input(wrapped_self, current_input_value, it, **validation_kwargs)
844
+ maybe_new_input = self._validate_input(
845
+ wrapped_self, current_input_value, it, **validation_kwargs
846
+ )
802
847
  if maybe_new_input is not None:
803
848
  current_input_value = maybe_new_input
804
849
 
805
- current_input_value = self._act(wrapped_self, current_input_value, it, **validation_kwargs)
850
+ current_input_value = self._act(
851
+ wrapped_self, current_input_value, it, **validation_kwargs
852
+ )
806
853
 
807
854
  output = self._validate_output(
808
855
  wrapped_self,
@@ -846,7 +893,7 @@ class contract:
846
893
  else:
847
894
  forward_kwargs[input_param_name] = forward_input_value
848
895
  else:
849
- forward_kwargs['input'] = forward_input_value
896
+ forward_kwargs["input"] = forward_input_value
850
897
 
851
898
  if input_param_name and input_param_name != "input" and "input" in forward_kwargs:
852
899
  forward_kwargs.pop("input")
@@ -856,14 +903,16 @@ class contract:
856
903
  output = original_forward(wrapped_self, *args_list, **forward_kwargs)
857
904
  finally:
858
905
  wrapped_self._contract_timing[it]["forward_execution"] = time.perf_counter() - op_start
859
- wrapped_self._contract_timing[it]["contract_execution"] = time.perf_counter() - contract_start
906
+ wrapped_self._contract_timing[it]["contract_execution"] = (
907
+ time.perf_counter() - contract_start
908
+ )
860
909
  return output
861
910
 
862
911
  def _finalize_contract_output(self, output, output_type, wrapped_self):
863
912
  if not isinstance(output, output_type):
864
913
  logger.error(f"Output type mismatch: {type(output)}")
865
914
  if self.remedy_retry_params["graceful"]:
866
- if getattr(output_type, '_is_dynamic_model', False) and hasattr(output, 'value'):
915
+ if getattr(output_type, "_is_dynamic_model", False) and hasattr(output, "value"):
867
916
  return output.value
868
917
  return output
869
918
  msg = (
@@ -877,7 +926,7 @@ class contract:
877
926
  else:
878
927
  logger.success("Contract validation successful!")
879
928
 
880
- if getattr(output_type, '_is_dynamic_model', False):
929
+ if getattr(output_type, "_is_dynamic_model", False):
881
930
  return output.value
882
931
  return output
883
932
 
@@ -886,7 +935,9 @@ class contract:
886
935
  sig = inspect.signature(original_forward)
887
936
  input_param_name = self._find_input_param_name(sig)
888
937
  args_list, kwargs_without_input, original_kwargs = self._prepare_forward_args(args, kwargs)
889
- input_value, input_source = self._extract_input_value(args_list, kwargs_without_input, original_kwargs, input_param_name)
938
+ input_value, input_source = self._extract_input_value(
939
+ args_list, kwargs_without_input, original_kwargs, input_param_name
940
+ )
890
941
  current_input_value = self._coerce_input_value(original_forward, input_value)
891
942
  input_value = current_input_value
892
943
  validation_kwargs = self._collect_validation_kwargs(wrapped_self, kwargs_without_input)
@@ -900,7 +951,9 @@ class contract:
900
951
  validation_kwargs,
901
952
  )
902
953
 
903
- forward_input_value = current_input_value if wrapped_self.contract_successful else input_value
954
+ forward_input_value = (
955
+ current_input_value if wrapped_self.contract_successful else input_value
956
+ )
904
957
  output = self._execute_forward_call(
905
958
  wrapped_self,
906
959
  original_forward,
@@ -936,7 +989,7 @@ class contract:
936
989
  "act_execution",
937
990
  "output_validation",
938
991
  "forward_execution",
939
- "contract_execution"
992
+ "contract_execution",
940
993
  ]
941
994
 
942
995
  stats = {}
@@ -958,40 +1011,41 @@ class contract:
958
1011
  max_time = max(non_zero_times) if non_zero_times else 0
959
1012
 
960
1013
  stats[op] = {
961
- 'count': actual_count,
962
- 'total': total_time,
963
- 'mean': mean_time,
964
- 'std': std_time,
965
- 'min': min_time,
966
- 'max': max_time
1014
+ "count": actual_count,
1015
+ "total": total_time,
1016
+ "mean": mean_time,
1017
+ "std": std_time,
1018
+ "min": min_time,
1019
+ "max": max_time,
967
1020
  }
968
1021
 
969
- total_execution_time = stats['contract_execution']['total']
1022
+ total_execution_time = stats["contract_execution"]["total"]
970
1023
  for op in ordered_operations[:-1]:
971
1024
  if total_execution_time > 0:
972
- stats[op]['percentage'] = (stats[op]['total'] / total_execution_time) * 100
1025
+ stats[op]["percentage"] = (stats[op]["total"] / total_execution_time) * 100
973
1026
  else:
974
- stats[op]['percentage'] = 0
1027
+ stats[op]["percentage"] = 0
975
1028
 
976
- sum_tracked_times = sum(stats[op]['total'] for op in ordered_operations[:-1])
1029
+ sum_tracked_times = sum(stats[op]["total"] for op in ordered_operations[:-1])
977
1030
  overhead_time = total_execution_time - sum_tracked_times
978
- overhead_percentage = (overhead_time / total_execution_time) * 100 if total_execution_time > 0 else 0
979
-
980
- stats['overhead'] = {
981
- 'count': num_calls,
982
- 'total': overhead_time,
983
- 'mean': overhead_time / num_calls if num_calls > 0 else 0,
984
- 'std': 0,
985
- 'min': 0,
986
- 'max': 0,
987
- 'percentage': overhead_percentage
1031
+ overhead_percentage = (
1032
+ (overhead_time / total_execution_time) * 100 if total_execution_time > 0 else 0
1033
+ )
1034
+
1035
+ stats["overhead"] = {
1036
+ "count": num_calls,
1037
+ "total": overhead_time,
1038
+ "mean": overhead_time / num_calls if num_calls > 0 else 0,
1039
+ "std": 0,
1040
+ "min": 0,
1041
+ "max": 0,
1042
+ "percentage": overhead_percentage,
988
1043
  }
989
1044
 
990
- stats['contract_execution']['percentage'] = 100.0
1045
+ stats["contract_execution"]["percentage"] = 100.0
991
1046
 
992
1047
  table = Table(
993
- title=f"Contract Execution Summary ({num_calls} Forward Calls)",
994
- show_header=True
1048
+ title=f"Contract Execution Summary ({num_calls} Forward Calls)", show_header=True
995
1049
  )
996
1050
  table.add_column("Operation", style="cyan")
997
1051
  table.add_column("Count", justify="right", style="blue")
@@ -1006,29 +1060,29 @@ class contract:
1006
1060
  s = stats[op]
1007
1061
  table.add_row(
1008
1062
  op.replace("_", " ").title(),
1009
- str(s['count']),
1063
+ str(s["count"]),
1010
1064
  f"{s['total']:.3f}",
1011
1065
  f"{s['mean']:.3f}",
1012
1066
  f"{s['std']:.3f}",
1013
1067
  f"{s['min']:.3f}",
1014
1068
  f"{s['max']:.3f}",
1015
- f"{s['percentage']:.1f}%"
1069
+ f"{s['percentage']:.1f}%",
1016
1070
  )
1017
1071
 
1018
- s = stats['overhead']
1072
+ s = stats["overhead"]
1019
1073
  table.add_row(
1020
1074
  "Overhead",
1021
- str(s['count']),
1075
+ str(s["count"]),
1022
1076
  f"{s['total']:.3f}",
1023
1077
  f"{s['mean']:.3f}",
1024
1078
  f"{s['std']:.3f}",
1025
1079
  f"{s['min']:.3f}",
1026
1080
  f"{s['max']:.3f}",
1027
1081
  f"{s['percentage']:.1f}%",
1028
- style="bold blue"
1082
+ style="bold blue",
1029
1083
  )
1030
1084
 
1031
- s = stats['contract_execution']
1085
+ s = stats["contract_execution"]
1032
1086
  table.add_row(
1033
1087
  "Total Execution",
1034
1088
  "N/A",
@@ -1038,7 +1092,7 @@ class contract:
1038
1092
  f"{s['min']:.3f}",
1039
1093
  f"{s['max']:.3f}",
1040
1094
  "100.0%",
1041
- style="bold magenta"
1095
+ style="bold magenta",
1042
1096
  )
1043
1097
 
1044
1098
  console.print("\n")
@@ -1110,5 +1164,5 @@ class Strategy(Expression):
1110
1164
 
1111
1165
  def __new__(cls, module: str, *_args, **_kwargs):
1112
1166
  cls._module = module
1113
- cls.module_path = 'symai.extended.strategies'
1167
+ cls.module_path = "symai.extended.strategies"
1114
1168
  return Strategy.load_module_class(cls.module)