symbolicai 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +198 -134
- symai/backend/base.py +51 -51
- symai/backend/engines/drawing/engine_bfl.py +33 -33
- symai/backend/engines/drawing/engine_gpt_image.py +4 -10
- symai/backend/engines/embedding/engine_llama_cpp.py +50 -35
- symai/backend/engines/embedding/engine_openai.py +22 -16
- symai/backend/engines/execute/engine_python.py +16 -16
- symai/backend/engines/files/engine_io.py +51 -49
- symai/backend/engines/imagecaptioning/engine_blip2.py +27 -23
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +53 -46
- symai/backend/engines/index/engine_pinecone.py +116 -88
- symai/backend/engines/index/engine_qdrant.py +1011 -0
- symai/backend/engines/index/engine_vectordb.py +78 -52
- symai/backend/engines/lean/engine_lean4.py +65 -25
- symai/backend/engines/neurosymbolic/__init__.py +28 -28
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +137 -135
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +145 -152
- symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +75 -49
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +199 -155
- symai/backend/engines/neurosymbolic/engine_groq.py +106 -72
- symai/backend/engines/neurosymbolic/engine_huggingface.py +100 -67
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +121 -93
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +213 -132
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +180 -137
- symai/backend/engines/ocr/engine_apilayer.py +18 -20
- symai/backend/engines/output/engine_stdout.py +9 -9
- symai/backend/engines/{webscraping → scrape}/engine_requests.py +25 -11
- symai/backend/engines/search/engine_openai.py +95 -83
- symai/backend/engines/search/engine_parallel.py +665 -0
- symai/backend/engines/search/engine_perplexity.py +40 -41
- symai/backend/engines/search/engine_serpapi.py +33 -28
- symai/backend/engines/speech_to_text/engine_local_whisper.py +37 -27
- symai/backend/engines/symbolic/engine_wolframalpha.py +14 -8
- symai/backend/engines/text_to_speech/engine_openai.py +15 -19
- symai/backend/engines/text_vision/engine_clip.py +34 -28
- symai/backend/engines/userinput/engine_console.py +3 -4
- symai/backend/mixin/anthropic.py +48 -40
- symai/backend/mixin/deepseek.py +4 -5
- symai/backend/mixin/google.py +5 -4
- symai/backend/mixin/groq.py +2 -4
- symai/backend/mixin/openai.py +132 -110
- symai/backend/settings.py +14 -14
- symai/chat.py +164 -94
- symai/collect/dynamic.py +13 -11
- symai/collect/pipeline.py +39 -31
- symai/collect/stats.py +109 -69
- symai/components.py +556 -238
- symai/constraints.py +14 -5
- symai/core.py +1495 -1210
- symai/core_ext.py +55 -50
- symai/endpoints/api.py +113 -58
- symai/extended/api_builder.py +22 -17
- symai/extended/arxiv_pdf_parser.py +13 -5
- symai/extended/bibtex_parser.py +8 -4
- symai/extended/conversation.py +88 -69
- symai/extended/document.py +40 -27
- symai/extended/file_merger.py +45 -7
- symai/extended/graph.py +38 -24
- symai/extended/html_style_template.py +17 -11
- symai/extended/interfaces/blip_2.py +1 -1
- symai/extended/interfaces/clip.py +4 -2
- symai/extended/interfaces/console.py +5 -3
- symai/extended/interfaces/dall_e.py +3 -1
- symai/extended/interfaces/file.py +2 -0
- symai/extended/interfaces/flux.py +3 -1
- symai/extended/interfaces/gpt_image.py +15 -6
- symai/extended/interfaces/input.py +2 -1
- symai/extended/interfaces/llava.py +1 -1
- symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +3 -2
- symai/extended/interfaces/naive_vectordb.py +2 -2
- symai/extended/interfaces/ocr.py +4 -2
- symai/extended/interfaces/openai_search.py +2 -0
- symai/extended/interfaces/parallel.py +30 -0
- symai/extended/interfaces/perplexity.py +2 -0
- symai/extended/interfaces/pinecone.py +6 -4
- symai/extended/interfaces/python.py +2 -0
- symai/extended/interfaces/serpapi.py +2 -0
- symai/extended/interfaces/terminal.py +0 -1
- symai/extended/interfaces/tts.py +2 -1
- symai/extended/interfaces/whisper.py +2 -1
- symai/extended/interfaces/wolframalpha.py +1 -0
- symai/extended/metrics/__init__.py +1 -1
- symai/extended/metrics/similarity.py +5 -2
- symai/extended/os_command.py +31 -22
- symai/extended/packages/symdev.py +39 -34
- symai/extended/packages/sympkg.py +30 -27
- symai/extended/packages/symrun.py +46 -35
- symai/extended/repo_cloner.py +10 -9
- symai/extended/seo_query_optimizer.py +15 -12
- symai/extended/solver.py +104 -76
- symai/extended/summarizer.py +8 -7
- symai/extended/taypan_interpreter.py +10 -9
- symai/extended/vectordb.py +28 -15
- symai/formatter/formatter.py +39 -31
- symai/formatter/regex.py +46 -44
- symai/functional.py +184 -86
- symai/imports.py +85 -51
- symai/interfaces.py +1 -1
- symai/memory.py +33 -24
- symai/menu/screen.py +28 -19
- symai/misc/console.py +27 -27
- symai/misc/loader.py +4 -3
- symai/models/base.py +147 -76
- symai/models/errors.py +1 -1
- symai/ops/__init__.py +1 -1
- symai/ops/measures.py +17 -14
- symai/ops/primitives.py +933 -635
- symai/post_processors.py +28 -24
- symai/pre_processors.py +58 -52
- symai/processor.py +15 -9
- symai/prompts.py +714 -649
- symai/server/huggingface_server.py +115 -32
- symai/server/llama_cpp_server.py +14 -6
- symai/server/qdrant_server.py +206 -0
- symai/shell.py +98 -39
- symai/shellsv.py +307 -223
- symai/strategy.py +135 -81
- symai/symbol.py +276 -225
- symai/utils.py +62 -46
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +19 -9
- symbolicai-1.1.0.dist-info/RECORD +168 -0
- symbolicai-1.0.0.dist-info/RECORD +0 -163
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/strategy.py
CHANGED
|
@@ -29,6 +29,7 @@ class ValidationFunction(Function):
|
|
|
29
29
|
• Pause/backoff logic
|
|
30
30
|
• Error simplification
|
|
31
31
|
"""
|
|
32
|
+
|
|
32
33
|
# Have some default retry params that don't add overhead
|
|
33
34
|
_default_retry_params: ClassVar[dict[str, int | float | bool]] = {
|
|
34
35
|
"tries": 8,
|
|
@@ -93,9 +94,7 @@ class ValidationFunction(Function):
|
|
|
93
94
|
seed = 42
|
|
94
95
|
|
|
95
96
|
rnd = np.random.RandomState(seed=seed)
|
|
96
|
-
return rnd.randint(
|
|
97
|
-
0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16
|
|
98
|
-
).tolist()
|
|
97
|
+
return rnd.randint(0, np.iinfo(np.int16).max, size=num_seeds, dtype=np.int16).tolist()
|
|
99
98
|
|
|
100
99
|
def simplify_validation_errors(self, error: ValidationError) -> str:
|
|
101
100
|
"""
|
|
@@ -123,11 +122,13 @@ class ValidationFunction(Function):
|
|
|
123
122
|
return "\n".join(simplified_errors)
|
|
124
123
|
|
|
125
124
|
def _pause(self, attempt):
|
|
126
|
-
base = self.retry_params[
|
|
127
|
-
jit = (
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
125
|
+
base = self.retry_params["delay"] * (self.retry_params["backoff"] ** attempt)
|
|
126
|
+
jit = (
|
|
127
|
+
np.random.uniform(*self.retry_params["jitter"])
|
|
128
|
+
if isinstance(self.retry_params["jitter"], tuple)
|
|
129
|
+
else self.retry_params["jitter"]
|
|
130
|
+
)
|
|
131
|
+
_delay = min(base + jit, self.retry_params["max_delay"])
|
|
131
132
|
time.sleep(_delay)
|
|
132
133
|
|
|
133
134
|
def remedy_prompt(self, *_args, **_kwargs):
|
|
@@ -139,7 +140,7 @@ class ValidationFunction(Function):
|
|
|
139
140
|
UserMessage(msg)
|
|
140
141
|
raise NotImplementedError(msg)
|
|
141
142
|
|
|
142
|
-
def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1,2)):
|
|
143
|
+
def display_panel(self, content, title, border_style="cyan", style="#f0eee6", padding=(1, 2)):
|
|
143
144
|
"""
|
|
144
145
|
Display content in a rich panel with consistent formatting.
|
|
145
146
|
|
|
@@ -151,11 +152,13 @@ class ValidationFunction(Function):
|
|
|
151
152
|
padding: Padding for the panel (default: (1,2))
|
|
152
153
|
"""
|
|
153
154
|
body = escape(content)
|
|
154
|
-
panel = Panel.fit(
|
|
155
|
+
panel = Panel.fit(
|
|
156
|
+
body, title=title, padding=padding, border_style=border_style, style=style
|
|
157
|
+
)
|
|
155
158
|
self.console.print(panel)
|
|
156
159
|
|
|
157
160
|
def forward(self, *args, **kwargs):
|
|
158
|
-
return super().forward(*args, **kwargs)
|
|
161
|
+
return super().forward(*args, **kwargs) # Just propagate to Function
|
|
159
162
|
|
|
160
163
|
|
|
161
164
|
class TypeValidationFunction(ValidationFunction):
|
|
@@ -165,6 +168,7 @@ class TypeValidationFunction(ValidationFunction):
|
|
|
165
168
|
if a user provides a callable designed to semantically validate the
|
|
166
169
|
structure of the type-validated data.
|
|
167
170
|
"""
|
|
171
|
+
|
|
168
172
|
def __init__(
|
|
169
173
|
self,
|
|
170
174
|
retry_params: dict[str, int | float | bool] = ValidationFunction._default_retry_params,
|
|
@@ -179,8 +183,12 @@ class TypeValidationFunction(ValidationFunction):
|
|
|
179
183
|
self.accumulate_errors = accumulate_errors
|
|
180
184
|
self.verbose = verbose
|
|
181
185
|
|
|
182
|
-
def register_expected_data_model(
|
|
183
|
-
|
|
186
|
+
def register_expected_data_model(
|
|
187
|
+
self, data_model: LLMDataModel, attach_to: str, override: bool = False
|
|
188
|
+
):
|
|
189
|
+
assert attach_to in ["input", "output"], (
|
|
190
|
+
f"Invalid attach_to value: {attach_to}; must be either 'input' or 'output'"
|
|
191
|
+
)
|
|
184
192
|
if attach_to == "input":
|
|
185
193
|
if self.input_data_model is not None and not override:
|
|
186
194
|
msg = "There is already a data model attached to the input. If you want to override it, set `override=True`."
|
|
@@ -206,12 +214,12 @@ Your prompt was:
|
|
|
206
214
|
|
|
207
215
|
The input data model is:
|
|
208
216
|
<input_data_model>
|
|
209
|
-
{self.input_data_model.simplify_json_schema() if self.input_data_model is not None else
|
|
217
|
+
{self.input_data_model.simplify_json_schema() if self.input_data_model is not None else "N/A"}
|
|
210
218
|
</input_data_model>
|
|
211
219
|
|
|
212
220
|
The given input was:
|
|
213
221
|
<input>
|
|
214
|
-
{str(self.input_data_model) if self.input_data_model is not None else
|
|
222
|
+
{str(self.input_data_model) if self.input_data_model is not None else "N/A"}
|
|
215
223
|
</input>
|
|
216
224
|
|
|
217
225
|
The output data model is:
|
|
@@ -253,12 +261,12 @@ You are given the following prompt:
|
|
|
253
261
|
|
|
254
262
|
The input data model is:
|
|
255
263
|
<input_data_model>
|
|
256
|
-
{self.input_data_model.simplify_json_schema() if self.input_data_model is not None else
|
|
264
|
+
{self.input_data_model.simplify_json_schema() if self.input_data_model is not None else "N/A"}
|
|
257
265
|
</input_data_model>
|
|
258
266
|
|
|
259
267
|
The given input is:
|
|
260
268
|
<input>
|
|
261
|
-
{str(self.input_data_model) if self.input_data_model is not None else
|
|
269
|
+
{str(self.input_data_model) if self.input_data_model is not None else "N/A"}
|
|
262
270
|
</input>
|
|
263
271
|
|
|
264
272
|
The output data model is:
|
|
@@ -288,7 +296,10 @@ Important guidelines:
|
|
|
288
296
|
return
|
|
289
297
|
for label, body in [
|
|
290
298
|
("Prompt", prompt),
|
|
291
|
-
(
|
|
299
|
+
(
|
|
300
|
+
"Input data model",
|
|
301
|
+
self.input_data_model.simplify_json_schema() if self.input_data_model else "N/A",
|
|
302
|
+
),
|
|
292
303
|
("Output data model", self.output_data_model.simplify_json_schema()),
|
|
293
304
|
]:
|
|
294
305
|
self.display_panel(body, title=label)
|
|
@@ -302,7 +313,11 @@ Important guidelines:
|
|
|
302
313
|
return None
|
|
303
314
|
try:
|
|
304
315
|
assert all(
|
|
305
|
-
f(
|
|
316
|
+
f(
|
|
317
|
+
result
|
|
318
|
+
if not getattr(self.output_data_model, "_is_dynamic_model", False)
|
|
319
|
+
else result.value
|
|
320
|
+
)
|
|
306
321
|
for f in f_semantic_conditions
|
|
307
322
|
)
|
|
308
323
|
except Exception as err:
|
|
@@ -364,7 +379,9 @@ Important guidelines:
|
|
|
364
379
|
total_attempts = self.retry_params["tries"] + 1
|
|
365
380
|
for attempt in range(total_attempts):
|
|
366
381
|
if attempt != self.retry_params["tries"]:
|
|
367
|
-
logger.info(
|
|
382
|
+
logger.info(
|
|
383
|
+
f"Attempt {attempt + 1}/{self.retry_params['tries']}: Attempting validation…"
|
|
384
|
+
)
|
|
368
385
|
try:
|
|
369
386
|
result = self.output_data_model.model_validate_json(
|
|
370
387
|
json_str,
|
|
@@ -393,7 +410,7 @@ Important guidelines:
|
|
|
393
410
|
|
|
394
411
|
def _handle_validation_failure(self, prompt: str, json_str: str, errors: list[str]):
|
|
395
412
|
logger.error("All validation attempts failed!")
|
|
396
|
-
if self.retry_params[
|
|
413
|
+
if self.retry_params["graceful"]:
|
|
397
414
|
return
|
|
398
415
|
raise TypeValidationError(
|
|
399
416
|
prompt=prompt,
|
|
@@ -401,9 +418,11 @@ Important guidelines:
|
|
|
401
418
|
violations=errors,
|
|
402
419
|
)
|
|
403
420
|
|
|
404
|
-
def forward(
|
|
421
|
+
def forward(
|
|
422
|
+
self, prompt: str, f_semantic_conditions: list[Callable] | None = None, *args, **kwargs
|
|
423
|
+
):
|
|
405
424
|
self._ensure_output_model()
|
|
406
|
-
validation_context = kwargs.pop(
|
|
425
|
+
validation_context = kwargs.pop("validation_context", {})
|
|
407
426
|
kwargs["response_format"] = {"type": "json_object"}
|
|
408
427
|
logger.info("Initializing validation…")
|
|
409
428
|
self._display_verbose_panels(prompt)
|
|
@@ -451,15 +470,17 @@ class contract:
|
|
|
451
470
|
verbose: bool = False,
|
|
452
471
|
remedy_retry_params: dict[str, int | float | bool] = _default_remedy_retry_params,
|
|
453
472
|
):
|
|
454
|
-
|
|
473
|
+
"""
|
|
455
474
|
A contract class decorator inspired by DbC principles. It ensures that the function's input and output
|
|
456
475
|
adhere to specified data models both syntactically and semantically. This implementation includes retry
|
|
457
476
|
logic to handle transient errors and gracefully handle failures.
|
|
458
|
-
|
|
477
|
+
"""
|
|
459
478
|
self.pre_remedy = pre_remedy
|
|
460
479
|
self.post_remedy = post_remedy
|
|
461
480
|
self.remedy_retry_params = remedy_retry_params
|
|
462
|
-
self.f_type_validation_remedy = TypeValidationFunction(
|
|
481
|
+
self.f_type_validation_remedy = TypeValidationFunction(
|
|
482
|
+
accumulate_errors=accumulate_errors, verbose=verbose, retry_params=remedy_retry_params
|
|
483
|
+
)
|
|
463
484
|
|
|
464
485
|
if not verbose:
|
|
465
486
|
logger.disable(__name__)
|
|
@@ -540,7 +561,9 @@ class contract:
|
|
|
540
561
|
|
|
541
562
|
def _try_remedy_with_exception(self, prompt, f_semantic_conditions, **remedy_kwargs):
|
|
542
563
|
try:
|
|
543
|
-
data_model = self.f_type_validation_remedy(
|
|
564
|
+
data_model = self.f_type_validation_remedy(
|
|
565
|
+
prompt, f_semantic_conditions=f_semantic_conditions, **remedy_kwargs
|
|
566
|
+
)
|
|
544
567
|
except Exception as e:
|
|
545
568
|
logger.error("Type validation failed with exception!")
|
|
546
569
|
raise e
|
|
@@ -550,7 +573,7 @@ class contract:
|
|
|
550
573
|
logger.info("Starting input validation...")
|
|
551
574
|
if self.pre_remedy:
|
|
552
575
|
logger.info("Validating pre-conditions with remedy...")
|
|
553
|
-
if not hasattr(wrapped_self,
|
|
576
|
+
if not hasattr(wrapped_self, "pre"):
|
|
554
577
|
logger.error("Pre-condition function not defined!")
|
|
555
578
|
msg = "Pre-condition function not defined. Please define a `pre` method if you want to enforce pre-conditions through a remedy."
|
|
556
579
|
UserMessage(msg)
|
|
@@ -563,16 +586,20 @@ class contract:
|
|
|
563
586
|
return input_value
|
|
564
587
|
except Exception:
|
|
565
588
|
logger.exception("Pre-condition validation failed!")
|
|
566
|
-
self.f_type_validation_remedy.register_expected_data_model(
|
|
589
|
+
self.f_type_validation_remedy.register_expected_data_model(
|
|
590
|
+
input_value, attach_to="output", override=True
|
|
591
|
+
)
|
|
567
592
|
input_value = self._try_remedy_with_exception(
|
|
568
593
|
prompt=wrapped_self.prompt,
|
|
569
594
|
f_semantic_conditions=[wrapped_self.pre],
|
|
570
595
|
**remedy_kwargs,
|
|
571
596
|
)
|
|
572
597
|
finally:
|
|
573
|
-
wrapped_self._contract_timing[it]["input_validation"] =
|
|
598
|
+
wrapped_self._contract_timing[it]["input_validation"] = (
|
|
599
|
+
time.perf_counter() - op_start
|
|
600
|
+
)
|
|
574
601
|
return input_value
|
|
575
|
-
if hasattr(wrapped_self,
|
|
602
|
+
if hasattr(wrapped_self, "pre"):
|
|
576
603
|
logger.info("Validating pre-conditions without remedy...")
|
|
577
604
|
op_start = time.perf_counter()
|
|
578
605
|
try:
|
|
@@ -581,7 +608,9 @@ class contract:
|
|
|
581
608
|
logger.exception("Pre-condition validation failed")
|
|
582
609
|
raise e
|
|
583
610
|
finally:
|
|
584
|
-
wrapped_self._contract_timing[it]["input_validation"] =
|
|
611
|
+
wrapped_self._contract_timing[it]["input_validation"] = (
|
|
612
|
+
time.perf_counter() - op_start
|
|
613
|
+
)
|
|
585
614
|
logger.success("Pre-condition validation successful!")
|
|
586
615
|
return input_value
|
|
587
616
|
logger.info("Skip; no pre-condition validation was required!")
|
|
@@ -589,14 +618,20 @@ class contract:
|
|
|
589
618
|
|
|
590
619
|
def _validate_output(self, wrapped_self, input_value, output, it, **remedy_kwargs):
|
|
591
620
|
logger.info("Starting output validation...")
|
|
592
|
-
self.f_type_validation_remedy.register_expected_data_model(
|
|
593
|
-
|
|
621
|
+
self.f_type_validation_remedy.register_expected_data_model(
|
|
622
|
+
input_value, attach_to="input", override=True
|
|
623
|
+
)
|
|
624
|
+
self.f_type_validation_remedy.register_expected_data_model(
|
|
625
|
+
output, attach_to="output", override=True
|
|
626
|
+
)
|
|
594
627
|
|
|
595
628
|
op_start = time.perf_counter()
|
|
596
629
|
try:
|
|
597
630
|
logger.info("Getting a valid output type...")
|
|
598
|
-
output = self._try_remedy_with_exception(
|
|
599
|
-
|
|
631
|
+
output = self._try_remedy_with_exception(
|
|
632
|
+
prompt=wrapped_self.prompt, f_semantic_conditions=None, **remedy_kwargs
|
|
633
|
+
)
|
|
634
|
+
if output is None: # output is None when graceful mode is enabled
|
|
600
635
|
return output
|
|
601
636
|
except Exception as e:
|
|
602
637
|
logger.exception("Type creation failed!")
|
|
@@ -620,9 +655,15 @@ class contract:
|
|
|
620
655
|
return output
|
|
621
656
|
except Exception:
|
|
622
657
|
logger.exception("Post-condition validation failed!")
|
|
623
|
-
output = self._try_remedy_with_exception(
|
|
658
|
+
output = self._try_remedy_with_exception(
|
|
659
|
+
prompt=wrapped_self.prompt,
|
|
660
|
+
f_semantic_conditions=[wrapped_self.post],
|
|
661
|
+
**remedy_kwargs,
|
|
662
|
+
)
|
|
624
663
|
finally:
|
|
625
|
-
wrapped_self._contract_timing[it]["output_validation"] += (
|
|
664
|
+
wrapped_self._contract_timing[it]["output_validation"] += (
|
|
665
|
+
time.perf_counter() - op_start
|
|
666
|
+
)
|
|
626
667
|
logger.success("Post-condition validation successful!")
|
|
627
668
|
return output
|
|
628
669
|
if hasattr(wrapped_self, "post"):
|
|
@@ -634,7 +675,9 @@ class contract:
|
|
|
634
675
|
logger.exception("Post-condition validation failed!")
|
|
635
676
|
raise e
|
|
636
677
|
finally:
|
|
637
|
-
wrapped_self._contract_timing[it]["output_validation"] =
|
|
678
|
+
wrapped_self._contract_timing[it]["output_validation"] = (
|
|
679
|
+
time.perf_counter() - op_start
|
|
680
|
+
)
|
|
638
681
|
logger.success("Post-condition validation successful!")
|
|
639
682
|
return output
|
|
640
683
|
logger.info("Skip; no post-condition validation was required!")
|
|
@@ -670,14 +713,14 @@ class contract:
|
|
|
670
713
|
return True
|
|
671
714
|
|
|
672
715
|
def _act(self, wrapped_self, input_value, it, **act_kwargs):
|
|
673
|
-
act_method = getattr(wrapped_self,
|
|
716
|
+
act_method = getattr(wrapped_self, "act", None)
|
|
674
717
|
if not callable(act_method):
|
|
675
718
|
# Propagate the input if no act method is defined
|
|
676
719
|
return input_value
|
|
677
720
|
|
|
678
721
|
assert self._validate_act_method(act_method)
|
|
679
722
|
|
|
680
|
-
is_dynamic_model = getattr(input_value,
|
|
723
|
+
is_dynamic_model = getattr(input_value, "_is_dynamic_model", False)
|
|
681
724
|
input_value = input_value if not is_dynamic_model else input_value.value
|
|
682
725
|
|
|
683
726
|
logger.info(f"Executing 'act' method on {wrapped_self.__class__.__name__}…")
|
|
@@ -798,11 +841,15 @@ class contract:
|
|
|
798
841
|
):
|
|
799
842
|
output = None
|
|
800
843
|
try:
|
|
801
|
-
maybe_new_input = self._validate_input(
|
|
844
|
+
maybe_new_input = self._validate_input(
|
|
845
|
+
wrapped_self, current_input_value, it, **validation_kwargs
|
|
846
|
+
)
|
|
802
847
|
if maybe_new_input is not None:
|
|
803
848
|
current_input_value = maybe_new_input
|
|
804
849
|
|
|
805
|
-
current_input_value = self._act(
|
|
850
|
+
current_input_value = self._act(
|
|
851
|
+
wrapped_self, current_input_value, it, **validation_kwargs
|
|
852
|
+
)
|
|
806
853
|
|
|
807
854
|
output = self._validate_output(
|
|
808
855
|
wrapped_self,
|
|
@@ -846,7 +893,7 @@ class contract:
|
|
|
846
893
|
else:
|
|
847
894
|
forward_kwargs[input_param_name] = forward_input_value
|
|
848
895
|
else:
|
|
849
|
-
forward_kwargs[
|
|
896
|
+
forward_kwargs["input"] = forward_input_value
|
|
850
897
|
|
|
851
898
|
if input_param_name and input_param_name != "input" and "input" in forward_kwargs:
|
|
852
899
|
forward_kwargs.pop("input")
|
|
@@ -856,14 +903,16 @@ class contract:
|
|
|
856
903
|
output = original_forward(wrapped_self, *args_list, **forward_kwargs)
|
|
857
904
|
finally:
|
|
858
905
|
wrapped_self._contract_timing[it]["forward_execution"] = time.perf_counter() - op_start
|
|
859
|
-
wrapped_self._contract_timing[it]["contract_execution"] =
|
|
906
|
+
wrapped_self._contract_timing[it]["contract_execution"] = (
|
|
907
|
+
time.perf_counter() - contract_start
|
|
908
|
+
)
|
|
860
909
|
return output
|
|
861
910
|
|
|
862
911
|
def _finalize_contract_output(self, output, output_type, wrapped_self):
|
|
863
912
|
if not isinstance(output, output_type):
|
|
864
913
|
logger.error(f"Output type mismatch: {type(output)}")
|
|
865
914
|
if self.remedy_retry_params["graceful"]:
|
|
866
|
-
if getattr(output_type,
|
|
915
|
+
if getattr(output_type, "_is_dynamic_model", False) and hasattr(output, "value"):
|
|
867
916
|
return output.value
|
|
868
917
|
return output
|
|
869
918
|
msg = (
|
|
@@ -877,7 +926,7 @@ class contract:
|
|
|
877
926
|
else:
|
|
878
927
|
logger.success("Contract validation successful!")
|
|
879
928
|
|
|
880
|
-
if getattr(output_type,
|
|
929
|
+
if getattr(output_type, "_is_dynamic_model", False):
|
|
881
930
|
return output.value
|
|
882
931
|
return output
|
|
883
932
|
|
|
@@ -886,7 +935,9 @@ class contract:
|
|
|
886
935
|
sig = inspect.signature(original_forward)
|
|
887
936
|
input_param_name = self._find_input_param_name(sig)
|
|
888
937
|
args_list, kwargs_without_input, original_kwargs = self._prepare_forward_args(args, kwargs)
|
|
889
|
-
input_value, input_source = self._extract_input_value(
|
|
938
|
+
input_value, input_source = self._extract_input_value(
|
|
939
|
+
args_list, kwargs_without_input, original_kwargs, input_param_name
|
|
940
|
+
)
|
|
890
941
|
current_input_value = self._coerce_input_value(original_forward, input_value)
|
|
891
942
|
input_value = current_input_value
|
|
892
943
|
validation_kwargs = self._collect_validation_kwargs(wrapped_self, kwargs_without_input)
|
|
@@ -900,7 +951,9 @@ class contract:
|
|
|
900
951
|
validation_kwargs,
|
|
901
952
|
)
|
|
902
953
|
|
|
903
|
-
forward_input_value =
|
|
954
|
+
forward_input_value = (
|
|
955
|
+
current_input_value if wrapped_self.contract_successful else input_value
|
|
956
|
+
)
|
|
904
957
|
output = self._execute_forward_call(
|
|
905
958
|
wrapped_self,
|
|
906
959
|
original_forward,
|
|
@@ -936,7 +989,7 @@ class contract:
|
|
|
936
989
|
"act_execution",
|
|
937
990
|
"output_validation",
|
|
938
991
|
"forward_execution",
|
|
939
|
-
"contract_execution"
|
|
992
|
+
"contract_execution",
|
|
940
993
|
]
|
|
941
994
|
|
|
942
995
|
stats = {}
|
|
@@ -958,40 +1011,41 @@ class contract:
|
|
|
958
1011
|
max_time = max(non_zero_times) if non_zero_times else 0
|
|
959
1012
|
|
|
960
1013
|
stats[op] = {
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
1014
|
+
"count": actual_count,
|
|
1015
|
+
"total": total_time,
|
|
1016
|
+
"mean": mean_time,
|
|
1017
|
+
"std": std_time,
|
|
1018
|
+
"min": min_time,
|
|
1019
|
+
"max": max_time,
|
|
967
1020
|
}
|
|
968
1021
|
|
|
969
|
-
total_execution_time = stats[
|
|
1022
|
+
total_execution_time = stats["contract_execution"]["total"]
|
|
970
1023
|
for op in ordered_operations[:-1]:
|
|
971
1024
|
if total_execution_time > 0:
|
|
972
|
-
stats[op][
|
|
1025
|
+
stats[op]["percentage"] = (stats[op]["total"] / total_execution_time) * 100
|
|
973
1026
|
else:
|
|
974
|
-
stats[op][
|
|
1027
|
+
stats[op]["percentage"] = 0
|
|
975
1028
|
|
|
976
|
-
sum_tracked_times = sum(stats[op][
|
|
1029
|
+
sum_tracked_times = sum(stats[op]["total"] for op in ordered_operations[:-1])
|
|
977
1030
|
overhead_time = total_execution_time - sum_tracked_times
|
|
978
|
-
overhead_percentage = (
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
1031
|
+
overhead_percentage = (
|
|
1032
|
+
(overhead_time / total_execution_time) * 100 if total_execution_time > 0 else 0
|
|
1033
|
+
)
|
|
1034
|
+
|
|
1035
|
+
stats["overhead"] = {
|
|
1036
|
+
"count": num_calls,
|
|
1037
|
+
"total": overhead_time,
|
|
1038
|
+
"mean": overhead_time / num_calls if num_calls > 0 else 0,
|
|
1039
|
+
"std": 0,
|
|
1040
|
+
"min": 0,
|
|
1041
|
+
"max": 0,
|
|
1042
|
+
"percentage": overhead_percentage,
|
|
988
1043
|
}
|
|
989
1044
|
|
|
990
|
-
stats[
|
|
1045
|
+
stats["contract_execution"]["percentage"] = 100.0
|
|
991
1046
|
|
|
992
1047
|
table = Table(
|
|
993
|
-
title=f"Contract Execution Summary ({num_calls} Forward Calls)",
|
|
994
|
-
show_header=True
|
|
1048
|
+
title=f"Contract Execution Summary ({num_calls} Forward Calls)", show_header=True
|
|
995
1049
|
)
|
|
996
1050
|
table.add_column("Operation", style="cyan")
|
|
997
1051
|
table.add_column("Count", justify="right", style="blue")
|
|
@@ -1006,29 +1060,29 @@ class contract:
|
|
|
1006
1060
|
s = stats[op]
|
|
1007
1061
|
table.add_row(
|
|
1008
1062
|
op.replace("_", " ").title(),
|
|
1009
|
-
str(s[
|
|
1063
|
+
str(s["count"]),
|
|
1010
1064
|
f"{s['total']:.3f}",
|
|
1011
1065
|
f"{s['mean']:.3f}",
|
|
1012
1066
|
f"{s['std']:.3f}",
|
|
1013
1067
|
f"{s['min']:.3f}",
|
|
1014
1068
|
f"{s['max']:.3f}",
|
|
1015
|
-
f"{s['percentage']:.1f}%"
|
|
1069
|
+
f"{s['percentage']:.1f}%",
|
|
1016
1070
|
)
|
|
1017
1071
|
|
|
1018
|
-
s = stats[
|
|
1072
|
+
s = stats["overhead"]
|
|
1019
1073
|
table.add_row(
|
|
1020
1074
|
"Overhead",
|
|
1021
|
-
str(s[
|
|
1075
|
+
str(s["count"]),
|
|
1022
1076
|
f"{s['total']:.3f}",
|
|
1023
1077
|
f"{s['mean']:.3f}",
|
|
1024
1078
|
f"{s['std']:.3f}",
|
|
1025
1079
|
f"{s['min']:.3f}",
|
|
1026
1080
|
f"{s['max']:.3f}",
|
|
1027
1081
|
f"{s['percentage']:.1f}%",
|
|
1028
|
-
style="bold blue"
|
|
1082
|
+
style="bold blue",
|
|
1029
1083
|
)
|
|
1030
1084
|
|
|
1031
|
-
s = stats[
|
|
1085
|
+
s = stats["contract_execution"]
|
|
1032
1086
|
table.add_row(
|
|
1033
1087
|
"Total Execution",
|
|
1034
1088
|
"N/A",
|
|
@@ -1038,7 +1092,7 @@ class contract:
|
|
|
1038
1092
|
f"{s['min']:.3f}",
|
|
1039
1093
|
f"{s['max']:.3f}",
|
|
1040
1094
|
"100.0%",
|
|
1041
|
-
style="bold magenta"
|
|
1095
|
+
style="bold magenta",
|
|
1042
1096
|
)
|
|
1043
1097
|
|
|
1044
1098
|
console.print("\n")
|
|
@@ -1110,5 +1164,5 @@ class Strategy(Expression):
|
|
|
1110
1164
|
|
|
1111
1165
|
def __new__(cls, module: str, *_args, **_kwargs):
|
|
1112
1166
|
cls._module = module
|
|
1113
|
-
cls.module_path =
|
|
1167
|
+
cls.module_path = "symai.extended.strategies"
|
|
1114
1168
|
return Strategy.load_module_class(cls.module)
|