symbolicai 0.20.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. symai/__init__.py +96 -64
  2. symai/backend/base.py +93 -80
  3. symai/backend/engines/drawing/engine_bfl.py +12 -11
  4. symai/backend/engines/drawing/engine_gpt_image.py +108 -87
  5. symai/backend/engines/embedding/engine_llama_cpp.py +25 -28
  6. symai/backend/engines/embedding/engine_openai.py +3 -5
  7. symai/backend/engines/execute/engine_python.py +6 -5
  8. symai/backend/engines/files/engine_io.py +74 -67
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +3 -3
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +54 -38
  11. symai/backend/engines/index/engine_pinecone.py +23 -24
  12. symai/backend/engines/index/engine_vectordb.py +16 -14
  13. symai/backend/engines/lean/engine_lean4.py +38 -34
  14. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  15. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +262 -182
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +263 -191
  17. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +53 -49
  18. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +212 -211
  19. symai/backend/engines/neurosymbolic/engine_groq.py +87 -63
  20. symai/backend/engines/neurosymbolic/engine_huggingface.py +21 -24
  21. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +117 -48
  22. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +256 -229
  23. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +270 -150
  24. symai/backend/engines/ocr/engine_apilayer.py +6 -8
  25. symai/backend/engines/output/engine_stdout.py +1 -4
  26. symai/backend/engines/search/engine_openai.py +7 -7
  27. symai/backend/engines/search/engine_perplexity.py +5 -5
  28. symai/backend/engines/search/engine_serpapi.py +12 -14
  29. symai/backend/engines/speech_to_text/engine_local_whisper.py +20 -27
  30. symai/backend/engines/symbolic/engine_wolframalpha.py +3 -3
  31. symai/backend/engines/text_to_speech/engine_openai.py +5 -7
  32. symai/backend/engines/text_vision/engine_clip.py +7 -11
  33. symai/backend/engines/userinput/engine_console.py +3 -3
  34. symai/backend/engines/webscraping/engine_requests.py +81 -48
  35. symai/backend/mixin/__init__.py +13 -0
  36. symai/backend/mixin/anthropic.py +4 -2
  37. symai/backend/mixin/deepseek.py +2 -0
  38. symai/backend/mixin/google.py +2 -0
  39. symai/backend/mixin/openai.py +11 -3
  40. symai/backend/settings.py +83 -16
  41. symai/chat.py +101 -78
  42. symai/collect/__init__.py +7 -1
  43. symai/collect/dynamic.py +77 -69
  44. symai/collect/pipeline.py +35 -27
  45. symai/collect/stats.py +75 -63
  46. symai/components.py +198 -169
  47. symai/constraints.py +15 -12
  48. symai/core.py +698 -359
  49. symai/core_ext.py +32 -34
  50. symai/endpoints/api.py +80 -73
  51. symai/extended/.DS_Store +0 -0
  52. symai/extended/__init__.py +46 -12
  53. symai/extended/api_builder.py +11 -8
  54. symai/extended/arxiv_pdf_parser.py +13 -12
  55. symai/extended/bibtex_parser.py +2 -3
  56. symai/extended/conversation.py +101 -90
  57. symai/extended/document.py +17 -10
  58. symai/extended/file_merger.py +18 -13
  59. symai/extended/graph.py +18 -13
  60. symai/extended/html_style_template.py +2 -4
  61. symai/extended/interfaces/blip_2.py +1 -2
  62. symai/extended/interfaces/clip.py +1 -2
  63. symai/extended/interfaces/console.py +7 -1
  64. symai/extended/interfaces/dall_e.py +1 -1
  65. symai/extended/interfaces/flux.py +1 -1
  66. symai/extended/interfaces/gpt_image.py +1 -1
  67. symai/extended/interfaces/input.py +1 -1
  68. symai/extended/interfaces/llava.py +0 -1
  69. symai/extended/interfaces/naive_vectordb.py +7 -8
  70. symai/extended/interfaces/naive_webscraping.py +1 -1
  71. symai/extended/interfaces/ocr.py +1 -1
  72. symai/extended/interfaces/pinecone.py +6 -5
  73. symai/extended/interfaces/serpapi.py +1 -1
  74. symai/extended/interfaces/terminal.py +2 -3
  75. symai/extended/interfaces/tts.py +1 -1
  76. symai/extended/interfaces/whisper.py +1 -1
  77. symai/extended/interfaces/wolframalpha.py +1 -1
  78. symai/extended/metrics/__init__.py +11 -1
  79. symai/extended/metrics/similarity.py +11 -13
  80. symai/extended/os_command.py +17 -16
  81. symai/extended/packages/__init__.py +29 -3
  82. symai/extended/packages/symdev.py +19 -16
  83. symai/extended/packages/sympkg.py +12 -9
  84. symai/extended/packages/symrun.py +21 -19
  85. symai/extended/repo_cloner.py +11 -10
  86. symai/extended/seo_query_optimizer.py +1 -2
  87. symai/extended/solver.py +20 -23
  88. symai/extended/summarizer.py +4 -3
  89. symai/extended/taypan_interpreter.py +10 -12
  90. symai/extended/vectordb.py +99 -82
  91. symai/formatter/__init__.py +9 -1
  92. symai/formatter/formatter.py +12 -16
  93. symai/formatter/regex.py +62 -63
  94. symai/functional.py +176 -122
  95. symai/imports.py +136 -127
  96. symai/interfaces.py +56 -27
  97. symai/memory.py +14 -13
  98. symai/misc/console.py +49 -39
  99. symai/misc/loader.py +5 -3
  100. symai/models/__init__.py +17 -1
  101. symai/models/base.py +269 -181
  102. symai/models/errors.py +0 -1
  103. symai/ops/__init__.py +32 -22
  104. symai/ops/measures.py +11 -15
  105. symai/ops/primitives.py +348 -228
  106. symai/post_processors.py +32 -28
  107. symai/pre_processors.py +39 -41
  108. symai/processor.py +6 -4
  109. symai/prompts.py +59 -45
  110. symai/server/huggingface_server.py +23 -20
  111. symai/server/llama_cpp_server.py +7 -5
  112. symai/shell.py +3 -4
  113. symai/shellsv.py +499 -375
  114. symai/strategy.py +517 -287
  115. symai/symbol.py +111 -116
  116. symai/utils.py +42 -36
  117. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/METADATA +4 -2
  118. symbolicai-1.0.0.dist-info/RECORD +163 -0
  119. symbolicai-0.20.2.dist-info/RECORD +0 -162
  120. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/WHEEL +0 -0
  121. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/entry_points.txt +0 -0
  122. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/licenses/LICENSE +0 -0
  123. {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/top_level.txt +0 -0
symai/components.py CHANGED
@@ -1,23 +1,19 @@
1
1
  import copy
2
2
  import inspect
3
3
  import json
4
- import os
5
4
  import re
6
5
  import sys
7
- from abc import abstractmethod
8
6
  from collections import defaultdict
7
+ from collections.abc import Callable, Iterator
9
8
  from pathlib import Path
10
9
  from random import sample
11
10
  from string import ascii_lowercase, ascii_uppercase
12
11
  from threading import Lock
13
- from .context import CURRENT_ENGINE_VAR
14
- from typing import Callable, Dict, Iterator, List, Optional, Type, Union
12
+ from typing import TYPE_CHECKING, Union
15
13
 
16
14
  import numpy as np
17
- from attr import dataclass
18
15
  from box import Box
19
16
  from loguru import logger
20
- from pydantic import BaseModel, ValidationError
21
17
  from pyvis.network import Network
22
18
  from tqdm import tqdm
23
19
 
@@ -25,16 +21,25 @@ from . import core, core_ext
25
21
  from .backend.base import Engine
26
22
  from .backend.settings import HOME_PATH
27
23
  from .constraints import DictFormatConstraint
24
+ from .context import CURRENT_ENGINE_VAR
28
25
  from .formatter import ParagraphFormatter
29
- from .post_processors import (CodeExtractPostProcessor,
30
- JsonTruncateMarkdownPostProcessor,
31
- JsonTruncatePostProcessor, PostProcessor,
32
- StripPostProcessor)
26
+ from .post_processors import (
27
+ CodeExtractPostProcessor,
28
+ JsonTruncateMarkdownPostProcessor,
29
+ JsonTruncatePostProcessor,
30
+ PostProcessor,
31
+ StripPostProcessor,
32
+ )
33
33
  from .pre_processors import JsonPreProcessor, PreProcessor
34
34
  from .processor import ProcessorPipeline
35
35
  from .prompts import JsonPromptTemplate, Prompt
36
36
  from .symbol import Expression, Metadata, Symbol
37
- from .utils import CustomUserWarning
37
+ from .utils import UserMessage
38
+
39
+ if TYPE_CHECKING:
40
+ from .backend.engines.index.engine_vectordb import VectorDBResult
41
+
42
+ _DEFAULT_PARAGRAPH_FORMATTER = ParagraphFormatter()
38
43
 
39
44
 
40
45
  class GraphViz(Expression):
@@ -58,7 +63,7 @@ class GraphViz(Expression):
58
63
  select_menu=select_menu,
59
64
  filter_menu=filter_menu)
60
65
 
61
- def forward(self, sym: Symbol, file_path: str, **kwargs):
66
+ def forward(self, sym: Symbol, file_path: str, **_kwargs):
62
67
  nodes = [str(n) if n.value else n.__repr__(simplified=True) for n in sym.nodes]
63
68
  edges = [(str(e[0]) if e[0].value else e[0].__repr__(simplified=True),
64
69
  str(e[1]) if e[1].value else e[1].__repr__(simplified=True)) for e in sym.edges]
@@ -73,21 +78,21 @@ class TrackerTraceable(Expression):
73
78
 
74
79
 
75
80
  class Any(Expression):
76
- def __init__(self, *expr: List[Expression], **kwargs):
81
+ def __init__(self, *expr: list[Expression], **kwargs):
77
82
  super().__init__(**kwargs)
78
- self.expr: List[Expression] = expr
83
+ self.expr: list[Expression] = expr
79
84
 
80
85
  def forward(self, *args, **kwargs) -> Symbol:
81
- return self.sym_return_type(any([e() for e in self.expr(*args, **kwargs)]))
86
+ return self.sym_return_type(any(e() for e in self.expr(*args, **kwargs)))
82
87
 
83
88
 
84
89
  class All(Expression):
85
- def __init__(self, *expr: List[Expression], **kwargs):
90
+ def __init__(self, *expr: list[Expression], **kwargs):
86
91
  super().__init__(**kwargs)
87
- self.expr: List[Expression] = expr
92
+ self.expr: list[Expression] = expr
88
93
 
89
94
  def forward(self, *args, **kwargs) -> Symbol:
90
- return self.sym_return_type(all([e() for e in self.expr(*args, **kwargs)]))
95
+ return self.sym_return_type(all(e() for e in self.expr(*args, **kwargs)))
91
96
 
92
97
 
93
98
  class Try(Expression):
@@ -117,14 +122,14 @@ class Lambda(Expression):
117
122
 
118
123
 
119
124
  class Choice(Expression):
120
- def __init__(self, cases: List[str], default: Optional[str] = None, **kwargs):
125
+ def __init__(self, cases: list[str], default: str | None = None, **kwargs):
121
126
  super().__init__(**kwargs)
122
- self.cases: List[str] = cases
123
- self.default: Optional[str] = default
127
+ self.cases: list[str] = cases
128
+ self.default: str | None = default
124
129
 
125
130
  def forward(self, sym: Symbol, *args, **kwargs) -> Symbol:
126
131
  sym = self._to_symbol(sym)
127
- return sym.choice(cases=self.cases, default=self.default, *args, **kwargs)
132
+ return sym.choice(*args, cases=self.cases, default=self.default, **kwargs)
128
133
 
129
134
 
130
135
  class Output(Expression):
@@ -137,13 +142,13 @@ class Output(Expression):
137
142
  def forward(self, *args, **kwargs) -> Expression:
138
143
  kwargs['verbose'] = self.verbose
139
144
  kwargs['handler'] = self.handler
140
- return self.output(expr=self.expr, *args, **kwargs)
145
+ return self.output(*args, expr=self.expr, **kwargs)
141
146
 
142
147
 
143
148
  class Sequence(TrackerTraceable):
144
- def __init__(self, *expressions: List[Expression], **kwargs):
149
+ def __init__(self, *expressions: list[Expression], **kwargs):
145
150
  super().__init__(**kwargs)
146
- self.expressions: List[Expression] = expressions
151
+ self.expressions: list[Expression] = expressions
147
152
 
148
153
  def forward(self, *args, **kwargs) -> Symbol:
149
154
  sym = self.expressions[0](*args, **kwargs)
@@ -159,11 +164,11 @@ class Sequence(TrackerTraceable):
159
164
 
160
165
 
161
166
  class Parallel(Expression):
162
- def __init__(self, *expr: List[Expression | Callable], sequential: bool = False, **kwargs):
167
+ def __init__(self, *expr: list[Expression | Callable], sequential: bool = False, **kwargs):
163
168
  super().__init__(**kwargs)
164
169
  self.sequential: bool = sequential
165
- self.expr: List[Expression] = expr
166
- self.results: List[Symbol] = []
170
+ self.expr: list[Expression] = expr
171
+ self.results: list[Symbol] = []
167
172
 
168
173
  def forward(self, *args, **kwargs) -> Symbol:
169
174
  # run in sequence
@@ -180,11 +185,11 @@ class Parallel(Expression):
180
185
 
181
186
  #@TODO: BinPacker(format="...") -> ensure that data packages form a "bin" that's consistent (e.g. never break a sentence in the middle)
182
187
  class Stream(Expression):
183
- def __init__(self, expr: Optional[Expression] = None, retrieval: Optional[str] = None, **kwargs):
188
+ def __init__(self, expr: Expression | None = None, retrieval: str | None = None, **kwargs):
184
189
  super().__init__(**kwargs)
185
190
  self.char_token_ratio: float = 0.6
186
- self.expr: Optional[Expression] = expr
187
- self.retrieval: Optional[str] = retrieval
191
+ self.expr: Expression | None = expr
192
+ self.retrieval: str | None = retrieval
188
193
  self._trace: bool = False
189
194
  self._previous_frame = None
190
195
 
@@ -194,19 +199,23 @@ class Stream(Expression):
194
199
  if self._trace:
195
200
  local_vars = self._previous_frame.f_locals
196
201
  vals = []
197
- for key, var in local_vars.items():
202
+ for _key, var in local_vars.items():
198
203
  if isinstance(var, TrackerTraceable):
199
204
  vals.append(var)
200
205
 
201
206
  if len(vals) == 1:
202
207
  self.expr = vals[0]
203
208
  else:
204
- raise ValueError(f"This component does either not inherit from TrackerTraceable or has an invalid number of component declarations: {len(vals)}! Only one component that inherits from TrackerTraceable is allowed in the with stream clause.")
209
+ UserMessage(
210
+ "This component does either not inherit from TrackerTraceable or has an invalid number of component "
211
+ f"declarations: {len(vals)}! Only one component that inherits from TrackerTraceable is allowed in the "
212
+ "with stream clause.",
213
+ raise_with=ValueError,
214
+ )
205
215
 
206
216
  res = sym.stream(expr=self.expr,
207
217
  char_token_ratio=self.char_token_ratio,
208
218
  **kwargs)
209
-
210
219
  if self.retrieval is not None:
211
220
  res = list(res)
212
221
  if self.retrieval == 'all':
@@ -215,9 +224,8 @@ class Stream(Expression):
215
224
  res = sorted(res, key=lambda x: len(x), reverse=True)
216
225
  return res[0]
217
226
  if self.retrieval == 'contains':
218
- res = [r for r in res if self.expr in r]
219
- return res
220
- raise ValueError(f"Invalid retrieval method: {self.retrieval}")
227
+ return [r for r in res if self.expr in r]
228
+ UserMessage(f"Invalid retrieval method: {self.retrieval}", raise_with=ValueError)
221
229
 
222
230
  return res
223
231
 
@@ -231,10 +239,12 @@ class Stream(Expression):
231
239
 
232
240
 
233
241
  class Trace(Expression):
234
- def __init__(self, expr: Optional[Expression] = None, engines=['all'], **kwargs):
242
+ def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
243
+ if engines is None:
244
+ engines = ['all']
235
245
  super().__init__(**kwargs)
236
246
  self.expr: Expression = expr
237
- self.engines: List[str] = engines
247
+ self.engines: list[str] = engines
238
248
 
239
249
  def forward(self, *args, **kwargs) -> Expression:
240
250
  Expression.command(verbose=True, engines=self.engines)
@@ -252,23 +262,26 @@ class Trace(Expression):
252
262
  Expression.command(verbose=False, engines=self.engines)
253
263
  if self.expr is not None:
254
264
  return self.expr.__exit__(type, value, traceback)
265
+ return None
255
266
 
256
267
 
257
268
  class Analyze(Expression):
258
- def __init__(self, exception: Exception, query: Optional[str] = None, **kwargs):
269
+ def __init__(self, exception: Exception, query: str | None = None, **kwargs):
259
270
  super().__init__(**kwargs)
260
271
  self.exception: Expression = exception
261
- self.query: Optional[str] = query
272
+ self.query: str | None = query
262
273
 
263
274
  def forward(self, sym: Symbol, *args, **kwargs) -> Symbol:
264
- return sym.analyze(exception=self.exception, query=self.query, *args, **kwargs)
275
+ return sym.analyze(*args, exception=self.exception, query=self.query, **kwargs)
265
276
 
266
277
 
267
278
  class Log(Expression):
268
- def __init__(self, expr: Optional[Expression] = None, engines=['all'], **kwargs):
279
+ def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
280
+ if engines is None:
281
+ engines = ['all']
269
282
  super().__init__(**kwargs)
270
283
  self.expr: Expression = expr
271
- self.engines: List[str] = engines
284
+ self.engines: list[str] = engines
272
285
 
273
286
  def forward(self, *args, **kwargs) -> Expression:
274
287
  Expression.command(logging=True, engines=self.engines)
@@ -286,6 +299,7 @@ class Log(Expression):
286
299
  Expression.command(logging=False, engines=self.engines)
287
300
  if self.expr is not None:
288
301
  return self.expr.__exit__(type, value, traceback)
302
+ return None
289
303
 
290
304
 
291
305
  class Template(Expression):
@@ -331,10 +345,10 @@ class Metric(Expression):
331
345
  self.normalize = normalize
332
346
  self.eps = eps
333
347
 
334
- def forward(self, sym: Symbol, **kwargs) -> Symbol:
348
+ def forward(self, sym: Symbol, **_kwargs) -> Symbol:
335
349
  sym = self._to_symbol(sym)
336
- assert sym.value_type == np.ndarray or sym.value_type == list, 'Metric can only be applied to numpy arrays or lists.'
337
- if sym.value_type == list:
350
+ assert sym.value_type is np.ndarray or sym.value_type is list, 'Metric can only be applied to numpy arrays or lists.'
351
+ if sym.value_type is list:
338
352
  sym._value = np.array(sym.value)
339
353
  # compute normalization between 0 and 1
340
354
  if self.normalize:
@@ -343,17 +357,19 @@ class Metric(Expression):
343
357
  elif len(sym.value.shape) == 2:
344
358
  pass
345
359
  else:
346
- raise ValueError(f'Invalid shape: {sym.value.shape}')
360
+ UserMessage(f'Invalid shape: {sym.value.shape}', raise_with=ValueError)
347
361
  # normalize between 0 and 1 and sum to 1
348
362
  sym._value = np.exp(sym.value) / (np.exp(sym.value).sum() + self.eps)
349
363
  return sym
350
364
 
351
365
 
352
366
  class Style(Expression):
353
- def __init__(self, description: str, libraries: List[str] = [], **kwargs):
367
+ def __init__(self, description: str, libraries: list[str] | None = None, **kwargs):
368
+ if libraries is None:
369
+ libraries = []
354
370
  super().__init__(**kwargs)
355
371
  self.description: str = description
356
- self.libraries: List[str] = libraries
372
+ self.libraries: list[str] = libraries
357
373
 
358
374
  def forward(self, sym: Symbol, **kwargs) -> Symbol:
359
375
  sym = self._to_symbol(sym)
@@ -365,7 +381,7 @@ class Query(TrackerTraceable):
365
381
  super().__init__(**kwargs)
366
382
  self.prompt: str = prompt
367
383
 
368
- def forward(self, sym: Symbol, context: Symbol = None, *args, **kwargs) -> Symbol:
384
+ def forward(self, sym: Symbol, context: Symbol = None, *_args, **kwargs) -> Symbol:
369
385
  sym = self._to_symbol(sym)
370
386
  return sym.query(prompt=self.prompt, context=context, **kwargs)
371
387
 
@@ -474,11 +490,11 @@ class ExcludeFilter(Expression):
474
490
  class FileWriter(Expression):
475
491
  def __init__(self, path: str, **kwargs):
476
492
  super().__init__(**kwargs)
477
- self.path = path
493
+ self.path = Path(path)
478
494
 
479
- def forward(self, sym: Symbol, **kwargs) -> Symbol:
495
+ def forward(self, sym: Symbol, **_kwargs) -> Symbol:
480
496
  sym = self._to_symbol(sym)
481
- with open(self.path, 'w') as f:
497
+ with self.path.open('w') as f:
482
498
  f.write(str(sym))
483
499
 
484
500
 
@@ -493,12 +509,10 @@ class FileReader(Expression):
493
509
  assert len(_splits) == 1 or len(_splits) == 2, 'Invalid file link format.'
494
510
  _tmp = Path(_tmp)
495
511
  # check if file exists and is a file
496
- if os.path.exists(_tmp) and os.path.isfile(_tmp):
497
- return True
498
- return False
512
+ return _tmp.is_file()
499
513
 
500
514
  @staticmethod
501
- def get_files(folder_path: str, max_depth: int = 1) -> List[str]:
515
+ def get_files(folder_path: str, max_depth: int = 1) -> list[str]:
502
516
  accepted_formats = ['.pdf', '.md', '.txt']
503
517
 
504
518
  folder = Path(folder_path)
@@ -512,7 +526,7 @@ class FileReader(Expression):
512
526
  return files
513
527
 
514
528
  @staticmethod
515
- def extract_files(cmds: str) -> Optional[List[str]]:
529
+ def extract_files(cmds: str) -> list[str] | None:
516
530
  # Use the updated regular expression to match quoted and non-quoted paths
517
531
  pattern = r'''(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))'''
518
532
  # Use the regular expression to split and handle quoted and non-quoted paths
@@ -551,16 +565,16 @@ class FileReader(Expression):
551
565
  return Path(path).expanduser().resolve().as_posix()
552
566
 
553
567
  @staticmethod
554
- def integrity_check(files: List[str]) -> List[str]:
568
+ def integrity_check(files: list[str]) -> list[str]:
555
569
  not_skipped = []
556
570
  for file in tqdm(files):
557
571
  if FileReader.exists(file):
558
572
  not_skipped.append(file)
559
573
  else:
560
- CustomUserWarning(f'Skipping file: {file}')
574
+ UserMessage(f'Skipping file: {file}')
561
575
  return not_skipped
562
576
 
563
- def forward(self, files: Union[str, List[str]], **kwargs) -> Expression:
577
+ def forward(self, files: str | list[str], **kwargs) -> Expression:
564
578
  if isinstance(files, str):
565
579
  # Convert to list for uniform processing; more easily downstream
566
580
  files = [files]
@@ -586,15 +600,17 @@ class FileQuery(Expression):
586
600
 
587
601
  class Function(TrackerTraceable):
588
602
  def __init__(self, prompt: str = '',
589
- examples: Optional[str] = [],
590
- pre_processors: Optional[List[PreProcessor]] = None,
591
- post_processors: Optional[List[PostProcessor]] = None,
592
- default: Optional[object] = None,
593
- constraints: List[Callable] = [],
594
- return_type: Optional[Type] = str,
595
- sym_return_type: Optional[Type] = Symbol,
596
- origin_type: Optional[Type] = Expression,
603
+ examples: str | None = [],
604
+ pre_processors: list[PreProcessor] | None = None,
605
+ post_processors: list[PostProcessor] | None = None,
606
+ default: object | None = None,
607
+ constraints: list[Callable] | None = None,
608
+ return_type: type | None = str,
609
+ sym_return_type: type | None = Symbol,
610
+ origin_type: type | None = Expression,
597
611
  *args, **kwargs):
612
+ if constraints is None:
613
+ constraints = []
598
614
  super().__init__(**kwargs)
599
615
  chars = ascii_lowercase + ascii_uppercase
600
616
  self.name = 'func_' + ''.join(sample(chars, 15))
@@ -629,13 +645,16 @@ class Function(TrackerTraceable):
629
645
  if 'fn' in kwargs:
630
646
  self.prompt = kwargs['fn']
631
647
  del kwargs['fn']
632
- @core.few_shot(prompt=self.prompt,
633
- examples=self.examples,
634
- pre_processors=self.pre_processors,
635
- post_processors=self.post_processors,
636
- constraints=self.constraints,
637
- default=self.default,
638
- *self.args, **self.kwargs)
648
+ @core.few_shot(
649
+ *self.args,
650
+ prompt=self.prompt,
651
+ examples=self.examples,
652
+ pre_processors=self.pre_processors,
653
+ post_processors=self.post_processors,
654
+ constraints=self.constraints,
655
+ default=self.default,
656
+ **self.kwargs
657
+ )
639
658
  def _func(_, *args, **kwargs) -> self.return_type:
640
659
  pass
641
660
  _type = type(self.name, (self.origin_type, ), {
@@ -658,11 +677,11 @@ class PrepareData(Function):
658
677
  assert argument.prop.context is not None
659
678
  instruct = argument.prop.prompt
660
679
  context = argument.prop.context
661
- return """{
662
- 'context': '%s',
663
- 'instruction': '%s',
680
+ return f"""{{
681
+ 'context': '{context}',
682
+ 'instruction': '{instruct}',
664
683
  'result': 'TODO: Replace this with the expected result.'
665
- }""" % (context, instruct)
684
+ }}"""
666
685
 
667
686
  def __init__(self, *args, **kwargs):
668
687
  super().__init__(*args, **kwargs)
@@ -707,7 +726,7 @@ class ExpressionBuilder(Function):
707
726
  super().__init__('Generate the code following the instructions:', **kwargs)
708
727
  self.processors = ProcessorPipeline([StripPostProcessor(), CodeExtractPostProcessor()])
709
728
 
710
- def forward(self, instruct, *args, **kwargs):
729
+ def forward(self, instruct, *_args, **_kwargs):
711
730
  result = super().forward(instruct)
712
731
  return self.processors(str(result), None)
713
732
 
@@ -768,21 +787,21 @@ class JsonParser(Expression):
768
787
 
769
788
 
770
789
  class SimilarityClassification(Expression):
771
- def __init__(self, classes: List[str], metric: str = 'cosine', in_memory: bool = False, **kwargs):
790
+ def __init__(self, classes: list[str], metric: str = 'cosine', in_memory: bool = False, **kwargs):
772
791
  super().__init__(**kwargs)
773
792
  self.classes = classes
774
793
  self.metric = metric
775
794
  self.in_memory = in_memory
776
795
 
777
796
  if self.in_memory:
778
- CustomUserWarning(f'Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache')
797
+ UserMessage(f'Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache')
779
798
 
780
799
  def forward(self, x: Symbol) -> Symbol:
781
800
  x = self._to_symbol(x)
782
801
  usr_embed = x.embed()
783
802
  embeddings = self._dynamic_cache()
784
803
  similarities = [usr_embed.similarity(emb, metric=self.metric) for emb in embeddings]
785
- similarities = sorted(zip(self.classes, similarities), key=lambda x: x[1], reverse=True)
804
+ similarities = sorted(zip(self.classes, similarities, strict=False), key=lambda x: x[1], reverse=True)
786
805
 
787
806
  return Symbol(similarities[0][0])
788
807
 
@@ -790,9 +809,7 @@ class SimilarityClassification(Expression):
790
809
  @core_ext.cache(in_memory=self.in_memory)
791
810
  def embed_classes(self):
792
811
  opts = map(Symbol, self.classes)
793
- embeddings = [opt.embed() for opt in opts]
794
-
795
- return embeddings
812
+ return [opt.embed() for opt in opts]
796
813
 
797
814
  return embed_classes(self)
798
815
 
@@ -820,19 +837,14 @@ class Indexer(Expression):
820
837
  @staticmethod
821
838
  def replace_special_chars(index: str):
822
839
  # replace special characters that are not for path
823
- index = str(index)
824
- index = index.replace('-', '')
825
- index = index.replace('_', '')
826
- index = index.replace(' ', '')
827
- index = index.lower()
828
- return index
840
+ return str(index).replace('-', '').replace('_', '').replace(' ', '').lower()
829
841
 
830
842
  def __init__(
831
843
  self,
832
844
  index_name: str = DEFAULT,
833
845
  top_k: int = 8,
834
846
  batch_size: int = 20,
835
- formatter: Callable = ParagraphFormatter(),
847
+ formatter: Callable = _DEFAULT_PARAGRAPH_FORMATTER,
836
848
  auto_add=False,
837
849
  raw_result: bool = False,
838
850
  new_dim: int = 1536,
@@ -861,15 +873,15 @@ class Indexer(Expression):
861
873
  def register(self):
862
874
  # check if index already exists in indices.txt and append if not
863
875
  change = False
864
- with open(self.path, 'r') as f:
876
+ with self.path.open() as f:
865
877
  indices = f.read().split('\n')
866
878
  # filter out empty strings
867
879
  indices = [i for i in indices if i]
868
- if self.index_name not in indices:
880
+ if self.index_name not in indices:
869
881
  indices.append(self.index_name)
870
882
  change = True
871
883
  if change:
872
- with open(self.path, 'w') as f:
884
+ with self.path.open('w') as f:
873
885
  f.write('\n'.join(indices))
874
886
 
875
887
  def exists(self) -> bool:
@@ -877,15 +889,16 @@ class Indexer(Expression):
877
889
  path = HOME_PATH / 'indices.txt'
878
890
  if not path.exists():
879
891
  return False
880
- with open(path, 'r') as f:
892
+ with path.open() as f:
881
893
  indices = f.read().split('\n')
882
894
  if self.index_name in indices:
883
895
  return True
896
+ return False
884
897
 
885
898
  def forward(
886
899
  self,
887
- data: Optional[Symbol] = None,
888
- raw_result: bool = False,
900
+ data: Symbol | None = None,
901
+ _raw_result: bool = False,
889
902
  ) -> Symbol:
890
903
  that = self
891
904
  if data is not None:
@@ -898,15 +911,14 @@ class Indexer(Expression):
898
911
  # we save the index
899
912
  that.config(None, save=True, index_name=that.index_name, index_dims=that.new_dim)
900
913
 
901
- def _func(query, *args, **kwargs) -> Union[Symbol, 'VectorDBResult']:
914
+ def _func(query, *_args, **kwargs) -> Union[Symbol, 'VectorDBResult']:
902
915
  raw_result = kwargs.get('raw_result') or that.raw_result
903
916
  query_emb = Symbol(query).embed(new_dim=that.new_dim).value
904
917
  res = that.get(query_emb, index_name=that.index_name, index_top_k=that.top_k, ori_query=query, index_dims=that.new_dim, **kwargs)
905
918
  that.retrieval = res
906
919
  if raw_result:
907
920
  return res
908
- rsp = Symbol(res).query(prompt='From the retrieved data, select the most relevant information.', context=query)
909
- return rsp
921
+ return Symbol(res).query(prompt='From the retrieved data, select the most relevant information.', context=query)
910
922
  return _func
911
923
 
912
924
 
@@ -917,8 +929,8 @@ class PrimitiveDisabler(Expression):
917
929
  self._original_primitives = defaultdict(list)
918
930
 
919
931
  def __enter__(self):
920
- # Avoid circular imports; import locally
921
- from .symbol import Symbol
932
+ # Import Symbol lazily so components does not clash with symbol during load.
933
+ from .symbol import Symbol # noqa
922
934
 
923
935
  frame = inspect.currentframe()
924
936
  f_locals = frame.f_back.f_locals
@@ -934,7 +946,7 @@ class PrimitiveDisabler(Expression):
934
946
  for func in self._primitives:
935
947
  if hasattr(sym, func):
936
948
  self._original_primitives[sym_name].append((func, getattr(sym, func)))
937
- setattr(sym, func, lambda *args, **kwargs: None)
949
+ setattr(sym, func, lambda *_args, **_kwargs: None)
938
950
 
939
951
  def _enable_primitives(self):
940
952
  for sym_name, sym in self._symbols.items():
@@ -968,7 +980,7 @@ class FunctionWithUsage(Function):
968
980
 
969
981
  def print_verbose(self, msg):
970
982
  if self.verbose:
971
- print(msg)
983
+ UserMessage(msg)
972
984
 
973
985
  def _format_usage(self, prompt_tokens, completion_tokens, total_tokens):
974
986
  return Box(
@@ -1020,12 +1032,11 @@ class FunctionWithUsage(Function):
1020
1032
  self.completion_tokens += completion_tokens
1021
1033
  self.total_tokens += total_tokens
1022
1034
  else:
1023
- if self.missing_usage_exception and not "preview" in kwargs:
1024
- raise Exception("Missing usage in metadata of neursymbolic engine")
1025
- else:
1026
- prompt_tokens = 0
1027
- completion_tokens = 0
1028
- total_tokens = 0
1035
+ if self.missing_usage_exception and "preview" not in kwargs:
1036
+ UserMessage("Missing usage in metadata of neursymbolic engine", raise_with=Exception)
1037
+ prompt_tokens = 0
1038
+ completion_tokens = 0
1039
+ total_tokens = 0
1029
1040
 
1030
1041
  return res, self._format_usage(prompt_tokens, completion_tokens, total_tokens)
1031
1042
 
@@ -1041,7 +1052,7 @@ class SelfPrompt(Expression):
1041
1052
  def __init__(self, *args, **kwargs):
1042
1053
  super().__init__(*args, **kwargs)
1043
1054
 
1044
- def forward(self, existing_prompt: Dict[str, str], **kwargs) -> Dict[str, str]:
1055
+ def forward(self, existing_prompt: dict[str, str], **kwargs) -> dict[str, str]:
1045
1056
  """
1046
1057
  Generate new system and user prompts based on the existing prompt.
1047
1058
 
@@ -1094,14 +1105,13 @@ class MetadataTracker(Expression):
1094
1105
  value = value or self.metadata
1095
1106
  if isinstance(value, dict):
1096
1107
  return '{\n\t' + ', \n\t'.join(f'"{k}": {self.__str__(v)}' for k,v in value.items()) + '\n}'
1097
- elif isinstance(value, list):
1108
+ if isinstance(value, list):
1098
1109
  return '[' + ', '.join(self.__str__(item) for item in value) + ']'
1099
- elif isinstance(value, str):
1110
+ if isinstance(value, str):
1100
1111
  return f'"{value}"'
1101
- else:
1102
- return f"\n\t {value}"
1112
+ return f"\n\t {value}"
1103
1113
 
1104
- def __new__(cls, *args, **kwargs):
1114
+ def __new__(cls, *_args, **_kwargs):
1105
1115
  cls._lock = getattr(cls, '_lock', Lock())
1106
1116
  with cls._lock:
1107
1117
  instance = super().__new__(cls)
@@ -1122,25 +1132,26 @@ class MetadataTracker(Expression):
1122
1132
 
1123
1133
  def _trace_calls(self, frame, event, arg):
1124
1134
  if not self._trace:
1125
- return
1135
+ return None
1126
1136
 
1127
- if event == 'return' and frame.f_code.co_name == 'forward':
1128
- # Check if this is an engine forward call
1129
- if ('self' in frame.f_locals
1130
- and
1131
- isinstance(frame.f_locals['self'], Engine)):
1132
- _, metadata = arg # arg contains return value on 'return' event
1133
- engine_name = frame.f_locals['self'].__class__.__name__
1134
- model_name = frame.f_locals['self'].model
1135
- self._metadata[(self._metadata_id, engine_name, model_name)] = metadata
1136
- self._metadata_id += 1
1137
+ if (
1138
+ event == 'return'
1139
+ and frame.f_code.co_name == 'forward'
1140
+ and 'self' in frame.f_locals
1141
+ and isinstance(frame.f_locals['self'], Engine)
1142
+ ):
1143
+ _, metadata = arg # arg contains return value on 'return' event
1144
+ engine_name = frame.f_locals['self'].__class__.__name__
1145
+ model_name = frame.f_locals['self'].model
1146
+ self._metadata[(self._metadata_id, engine_name, model_name)] = metadata
1147
+ self._metadata_id += 1
1137
1148
 
1138
1149
  return self._trace_calls
1139
1150
 
1140
1151
  def _accumulate_completion_token_details(self):
1141
1152
  """Parses the return object and accumulates completion token details per token type"""
1142
1153
  if not self._metadata:
1143
- CustomUserWarning("No metadata available to generate usage details.")
1154
+ UserMessage("No metadata available to generate usage details.")
1144
1155
  return {}
1145
1156
 
1146
1157
  token_details = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
@@ -1182,15 +1193,59 @@ class MetadataTracker(Expression):
1182
1193
  logger.warning(f"Tracking {engine_name} is not supported.")
1183
1194
  continue
1184
1195
  except Exception as e:
1185
- CustomUserWarning(f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError)
1196
+ UserMessage(f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError)
1186
1197
 
1187
1198
  # Convert to normal dict
1188
1199
  return {**token_details}
1189
1200
 
1201
+ def _can_accumulate_engine(self, engine_name: str) -> bool:
1202
+ supported_engines = ("GPTXChatEngine", "GPTXReasoningEngine", "GPTXSearchEngine")
1203
+ return engine_name in supported_engines
1204
+
1205
+ def _accumulate_time_field(self, accumulated: dict, metadata: dict) -> None:
1206
+ if 'time' in metadata and 'time' in accumulated:
1207
+ accumulated['time'] += metadata['time']
1208
+
1209
+ def _accumulate_usage_fields(self, accumulated: dict, metadata: dict) -> None:
1210
+ if 'raw_output' not in metadata or 'raw_output' not in accumulated:
1211
+ return
1212
+
1213
+ metadata_raw_output = metadata['raw_output']
1214
+ accumulated_raw_output = accumulated['raw_output']
1215
+ if not hasattr(metadata_raw_output, 'usage') or not hasattr(accumulated_raw_output, 'usage'):
1216
+ return
1217
+
1218
+ current_usage = metadata_raw_output.usage
1219
+ accumulated_usage = accumulated_raw_output.usage
1220
+
1221
+ for attr in ['completion_tokens', 'prompt_tokens', 'total_tokens']:
1222
+ if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
1223
+ setattr(
1224
+ accumulated_usage,
1225
+ attr,
1226
+ getattr(accumulated_usage, attr) + getattr(current_usage, attr),
1227
+ )
1228
+
1229
+ for detail_attr in ['completion_tokens_details', 'prompt_tokens_details']:
1230
+ if not hasattr(current_usage, detail_attr) or not hasattr(accumulated_usage, detail_attr):
1231
+ continue
1232
+
1233
+ current_details = getattr(current_usage, detail_attr)
1234
+ accumulated_details = getattr(accumulated_usage, detail_attr)
1235
+
1236
+ for attr in dir(current_details):
1237
+ if attr.startswith('_') or not hasattr(accumulated_details, attr):
1238
+ continue
1239
+
1240
+ current_val = getattr(current_details, attr)
1241
+ accumulated_val = getattr(accumulated_details, attr)
1242
+ if isinstance(current_val, (int, float)) and isinstance(accumulated_val, (int, float)):
1243
+ setattr(accumulated_details, attr, accumulated_val + current_val)
1244
+
1190
1245
  def _accumulate_metadata(self):
1191
1246
  """Accumulates metadata across all tracked engine calls."""
1192
1247
  if not self._metadata:
1193
- CustomUserWarning("No metadata available to generate usage details.")
1248
+ UserMessage("No metadata available to generate usage details.")
1194
1249
  return {}
1195
1250
 
1196
1251
  # Use first entry as base
@@ -1199,39 +1254,12 @@ class MetadataTracker(Expression):
1199
1254
 
1200
1255
  # Skipz first entry
1201
1256
  for (_, engine_name), metadata in list(self._metadata.items())[1:]:
1202
- if engine_name not in ("GPTXChatEngine", "GPTXReasoningEngine", "GPTXSearchEngine"):
1257
+ if not self._can_accumulate_engine(engine_name):
1203
1258
  logger.warning(f"Metadata accumulation for {engine_name} is not supported. Try `.usage` instead for now.")
1204
1259
  continue
1205
1260
 
1206
- # Accumulate time if it exists
1207
- if 'time' in metadata and 'time' in accumulated:
1208
- accumulated['time'] += metadata['time']
1209
-
1210
- # Handle usage stats accumulation
1211
- if 'raw_output' in metadata and 'raw_output' in accumulated:
1212
- if hasattr(metadata['raw_output'], 'usage') and hasattr(accumulated['raw_output'], 'usage'):
1213
- current_usage = metadata['raw_output'].usage
1214
- accumulated_usage = accumulated['raw_output'].usage
1215
-
1216
- # Accumulate token counts
1217
- for attr in ['completion_tokens', 'prompt_tokens', 'total_tokens']:
1218
- if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
1219
- setattr(accumulated_usage, attr,
1220
- getattr(accumulated_usage, attr) + getattr(current_usage, attr))
1221
-
1222
- # Handle nested token details if they exist
1223
- for detail_attr in ['completion_tokens_details', 'prompt_tokens_details']:
1224
- if hasattr(current_usage, detail_attr) and hasattr(accumulated_usage, detail_attr):
1225
- current_details = getattr(current_usage, detail_attr)
1226
- accumulated_details = getattr(accumulated_usage, detail_attr)
1227
-
1228
- # Accumulate all numeric attributes in the details
1229
- for attr in dir(current_details):
1230
- if not attr.startswith('_') and hasattr(accumulated_details, attr):
1231
- current_val = getattr(current_details, attr)
1232
- accumulated_val = getattr(accumulated_details, attr)
1233
- if isinstance(current_val, (int, float)) and isinstance(accumulated_val, (int, float)):
1234
- setattr(accumulated_details, attr, accumulated_val + current_val)
1261
+ self._accumulate_time_field(accumulated, metadata)
1262
+ self._accumulate_usage_fields(accumulated, metadata)
1235
1263
 
1236
1264
  return accumulated
1237
1265
 
@@ -1250,7 +1278,7 @@ class MetadataTracker(Expression):
1250
1278
 
1251
1279
  class DynamicEngine(Expression):
1252
1280
  """Context manager for dynamically switching neurosymbolic engine models."""
1253
- def __init__(self, model: str, api_key: str, debug: bool = False, **kwargs):
1281
+ def __init__(self, model: str, api_key: str, _debug: bool = False, **_kwargs):
1254
1282
  super().__init__()
1255
1283
  self.model = model
1256
1284
  self.api_key = api_key
@@ -1259,7 +1287,7 @@ class DynamicEngine(Expression):
1259
1287
  self.engine_instance = None
1260
1288
  self._ctx_token = None
1261
1289
 
1262
- def __new__(cls, *args, **kwargs):
1290
+ def __new__(cls, *_args, **_kwargs):
1263
1291
  cls._lock = getattr(cls, '_lock', Lock())
1264
1292
  with cls._lock:
1265
1293
  instance = super().__new__(cls)
@@ -1293,11 +1321,12 @@ class DynamicEngine(Expression):
1293
1321
 
1294
1322
  def _create_engine_instance(self):
1295
1323
  """Create an engine instance based on the model name."""
1296
- from .backend.engines.neurosymbolic import ENGINE_MAPPING
1324
+ # Deferred to avoid components <-> neurosymbolic engine circular imports.
1325
+ from .backend.engines.neurosymbolic import ENGINE_MAPPING # noqa
1297
1326
  try:
1298
1327
  engine_class = ENGINE_MAPPING.get(self.model)
1299
1328
  if engine_class is None:
1300
- raise ValueError(f"Unsupported model '{self.model}'")
1329
+ UserMessage(f"Unsupported model '{self.model}'", raise_with=ValueError)
1301
1330
  return engine_class(api_key=self.api_key, model=self.model)
1302
1331
  except Exception as e:
1303
- raise ValueError(f"Failed to create engine for model '{self.model}': {str(e)}")
1332
+ UserMessage(f"Failed to create engine for model '{self.model}': {e!s}", raise_with=ValueError)