symbolicai 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127) hide show
  1. symai/__init__.py +198 -134
  2. symai/backend/base.py +51 -51
  3. symai/backend/engines/drawing/engine_bfl.py +33 -33
  4. symai/backend/engines/drawing/engine_gpt_image.py +4 -10
  5. symai/backend/engines/embedding/engine_llama_cpp.py +50 -35
  6. symai/backend/engines/embedding/engine_openai.py +22 -16
  7. symai/backend/engines/execute/engine_python.py +16 -16
  8. symai/backend/engines/files/engine_io.py +51 -49
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +27 -23
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +53 -46
  11. symai/backend/engines/index/engine_pinecone.py +116 -88
  12. symai/backend/engines/index/engine_qdrant.py +1011 -0
  13. symai/backend/engines/index/engine_vectordb.py +78 -52
  14. symai/backend/engines/lean/engine_lean4.py +65 -25
  15. symai/backend/engines/neurosymbolic/__init__.py +28 -28
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +137 -135
  17. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +145 -152
  18. symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
  19. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +75 -49
  20. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +199 -155
  21. symai/backend/engines/neurosymbolic/engine_groq.py +106 -72
  22. symai/backend/engines/neurosymbolic/engine_huggingface.py +100 -67
  23. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +121 -93
  24. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +213 -132
  25. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +180 -137
  26. symai/backend/engines/ocr/engine_apilayer.py +18 -20
  27. symai/backend/engines/output/engine_stdout.py +9 -9
  28. symai/backend/engines/{webscraping → scrape}/engine_requests.py +25 -11
  29. symai/backend/engines/search/engine_openai.py +95 -83
  30. symai/backend/engines/search/engine_parallel.py +665 -0
  31. symai/backend/engines/search/engine_perplexity.py +40 -41
  32. symai/backend/engines/search/engine_serpapi.py +33 -28
  33. symai/backend/engines/speech_to_text/engine_local_whisper.py +37 -27
  34. symai/backend/engines/symbolic/engine_wolframalpha.py +14 -8
  35. symai/backend/engines/text_to_speech/engine_openai.py +15 -19
  36. symai/backend/engines/text_vision/engine_clip.py +34 -28
  37. symai/backend/engines/userinput/engine_console.py +3 -4
  38. symai/backend/mixin/anthropic.py +48 -40
  39. symai/backend/mixin/deepseek.py +4 -5
  40. symai/backend/mixin/google.py +5 -4
  41. symai/backend/mixin/groq.py +2 -4
  42. symai/backend/mixin/openai.py +132 -110
  43. symai/backend/settings.py +14 -14
  44. symai/chat.py +164 -94
  45. symai/collect/dynamic.py +13 -11
  46. symai/collect/pipeline.py +39 -31
  47. symai/collect/stats.py +109 -69
  48. symai/components.py +556 -238
  49. symai/constraints.py +14 -5
  50. symai/core.py +1495 -1210
  51. symai/core_ext.py +55 -50
  52. symai/endpoints/api.py +113 -58
  53. symai/extended/api_builder.py +22 -17
  54. symai/extended/arxiv_pdf_parser.py +13 -5
  55. symai/extended/bibtex_parser.py +8 -4
  56. symai/extended/conversation.py +88 -69
  57. symai/extended/document.py +40 -27
  58. symai/extended/file_merger.py +45 -7
  59. symai/extended/graph.py +38 -24
  60. symai/extended/html_style_template.py +17 -11
  61. symai/extended/interfaces/blip_2.py +1 -1
  62. symai/extended/interfaces/clip.py +4 -2
  63. symai/extended/interfaces/console.py +5 -3
  64. symai/extended/interfaces/dall_e.py +3 -1
  65. symai/extended/interfaces/file.py +2 -0
  66. symai/extended/interfaces/flux.py +3 -1
  67. symai/extended/interfaces/gpt_image.py +15 -6
  68. symai/extended/interfaces/input.py +2 -1
  69. symai/extended/interfaces/llava.py +1 -1
  70. symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +3 -2
  71. symai/extended/interfaces/naive_vectordb.py +2 -2
  72. symai/extended/interfaces/ocr.py +4 -2
  73. symai/extended/interfaces/openai_search.py +2 -0
  74. symai/extended/interfaces/parallel.py +30 -0
  75. symai/extended/interfaces/perplexity.py +2 -0
  76. symai/extended/interfaces/pinecone.py +6 -4
  77. symai/extended/interfaces/python.py +2 -0
  78. symai/extended/interfaces/serpapi.py +2 -0
  79. symai/extended/interfaces/terminal.py +0 -1
  80. symai/extended/interfaces/tts.py +2 -1
  81. symai/extended/interfaces/whisper.py +2 -1
  82. symai/extended/interfaces/wolframalpha.py +1 -0
  83. symai/extended/metrics/__init__.py +1 -1
  84. symai/extended/metrics/similarity.py +5 -2
  85. symai/extended/os_command.py +31 -22
  86. symai/extended/packages/symdev.py +39 -34
  87. symai/extended/packages/sympkg.py +30 -27
  88. symai/extended/packages/symrun.py +46 -35
  89. symai/extended/repo_cloner.py +10 -9
  90. symai/extended/seo_query_optimizer.py +15 -12
  91. symai/extended/solver.py +104 -76
  92. symai/extended/summarizer.py +8 -7
  93. symai/extended/taypan_interpreter.py +10 -9
  94. symai/extended/vectordb.py +28 -15
  95. symai/formatter/formatter.py +39 -31
  96. symai/formatter/regex.py +46 -44
  97. symai/functional.py +184 -86
  98. symai/imports.py +85 -51
  99. symai/interfaces.py +1 -1
  100. symai/memory.py +33 -24
  101. symai/menu/screen.py +28 -19
  102. symai/misc/console.py +27 -27
  103. symai/misc/loader.py +4 -3
  104. symai/models/base.py +147 -76
  105. symai/models/errors.py +1 -1
  106. symai/ops/__init__.py +1 -1
  107. symai/ops/measures.py +17 -14
  108. symai/ops/primitives.py +933 -635
  109. symai/post_processors.py +28 -24
  110. symai/pre_processors.py +58 -52
  111. symai/processor.py +15 -9
  112. symai/prompts.py +714 -649
  113. symai/server/huggingface_server.py +115 -32
  114. symai/server/llama_cpp_server.py +14 -6
  115. symai/server/qdrant_server.py +206 -0
  116. symai/shell.py +98 -39
  117. symai/shellsv.py +307 -223
  118. symai/strategy.py +135 -81
  119. symai/symbol.py +276 -225
  120. symai/utils.py +62 -46
  121. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +19 -9
  122. symbolicai-1.1.0.dist-info/RECORD +168 -0
  123. symbolicai-1.0.0.dist-info/RECORD +0 -163
  124. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
  125. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
  126. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
  127. {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/components.py CHANGED
@@ -11,7 +11,11 @@ from string import ascii_lowercase, ascii_uppercase
11
11
  from threading import Lock
12
12
  from typing import TYPE_CHECKING, Union
13
13
 
14
+ if TYPE_CHECKING:
15
+ from typing import Any
16
+
14
17
  import numpy as np
18
+ from beartype import beartype
15
19
  from box import Box
16
20
  from loguru import logger
17
21
  from pyvis.network import Network
@@ -43,33 +47,42 @@ _DEFAULT_PARAGRAPH_FORMATTER = ParagraphFormatter()
43
47
 
44
48
 
45
49
  class GraphViz(Expression):
46
- def __init__(self,
47
- notebook = True,
48
- cdn_resources = "remote",
49
- bgcolor = "#222222",
50
- font_color = "white",
51
- height = "750px",
52
- width = "100%",
53
- select_menu = True,
54
- filter_menu = True,
55
- **kwargs):
50
+ def __init__(
51
+ self,
52
+ notebook=True,
53
+ cdn_resources="remote",
54
+ bgcolor="#222222",
55
+ font_color="white",
56
+ height="750px",
57
+ width="100%",
58
+ select_menu=True,
59
+ filter_menu=True,
60
+ **kwargs,
61
+ ):
56
62
  super().__init__(**kwargs)
57
- self.net = Network(notebook=notebook,
58
- cdn_resources=cdn_resources,
59
- bgcolor=bgcolor,
60
- font_color=font_color,
61
- height=height,
62
- width=width,
63
- select_menu=select_menu,
64
- filter_menu=filter_menu)
63
+ self.net = Network(
64
+ notebook=notebook,
65
+ cdn_resources=cdn_resources,
66
+ bgcolor=bgcolor,
67
+ font_color=font_color,
68
+ height=height,
69
+ width=width,
70
+ select_menu=select_menu,
71
+ filter_menu=filter_menu,
72
+ )
65
73
 
66
74
  def forward(self, sym: Symbol, file_path: str, **_kwargs):
67
75
  nodes = [str(n) if n.value else n.__repr__(simplified=True) for n in sym.nodes]
68
- edges = [(str(e[0]) if e[0].value else e[0].__repr__(simplified=True),
69
- str(e[1]) if e[1].value else e[1].__repr__(simplified=True)) for e in sym.edges]
76
+ edges = [
77
+ (
78
+ str(e[0]) if e[0].value else e[0].__repr__(simplified=True),
79
+ str(e[1]) if e[1].value else e[1].__repr__(simplified=True),
80
+ )
81
+ for e in sym.edges
82
+ ]
70
83
  self.net.add_nodes(nodes)
71
84
  self.net.add_edges(edges)
72
- file_path = file_path if file_path.endswith('.html') else file_path + '.html'
85
+ file_path = file_path if file_path.endswith(".html") else file_path + ".html"
73
86
  return self.net.show(file_path)
74
87
 
75
88
 
@@ -109,12 +122,14 @@ class Try(Expression):
109
122
  class Lambda(Expression):
110
123
  def __init__(self, callable: Callable, **kwargs):
111
124
  super().__init__(**kwargs)
125
+
112
126
  def _callable(*args, **kwargs):
113
127
  kw = {
114
- 'args': args,
115
- 'kwargs': kwargs,
128
+ "args": args,
129
+ "kwargs": kwargs,
116
130
  }
117
131
  return callable(kw)
132
+
118
133
  self.callable: Callable = _callable
119
134
 
120
135
  def forward(self, *args, **kwargs) -> Symbol:
@@ -140,8 +155,8 @@ class Output(Expression):
140
155
  self.verbose: bool = verbose
141
156
 
142
157
  def forward(self, *args, **kwargs) -> Expression:
143
- kwargs['verbose'] = self.verbose
144
- kwargs['handler'] = self.handler
158
+ kwargs["verbose"] = self.verbose
159
+ kwargs["handler"] = self.handler
145
160
  return self.output(*args, expr=self.expr, **kwargs)
146
161
 
147
162
 
@@ -166,32 +181,34 @@ class Sequence(TrackerTraceable):
166
181
  class Parallel(Expression):
167
182
  def __init__(self, *expr: list[Expression | Callable], sequential: bool = False, **kwargs):
168
183
  super().__init__(**kwargs)
169
- self.sequential: bool = sequential
184
+ self.sequential: bool = sequential
170
185
  self.expr: list[Expression] = expr
171
- self.results: list[Symbol] = []
186
+ self.results: list[Symbol] = []
172
187
 
173
188
  def forward(self, *args, **kwargs) -> Symbol:
174
189
  # run in sequence
175
190
  if self.sequential:
176
191
  return [e(*args, **kwargs) for e in self.expr]
192
+
177
193
  # run in parallel
178
194
  @core_ext.parallel(self.expr)
179
195
  def _func(e, *args, **kwargs):
180
196
  return e(*args, **kwargs)
197
+
181
198
  self.results = _func(*args, **kwargs)
182
199
  # final result of the parallel execution
183
200
  return self._to_symbol(self.results)
184
201
 
185
202
 
186
- #@TODO: BinPacker(format="...") -> ensure that data packages form a "bin" that's consistent (e.g. never break a sentence in the middle)
203
+ # @TODO: BinPacker(format="...") -> ensure that data packages form a "bin" that's consistent (e.g. never break a sentence in the middle)
187
204
  class Stream(Expression):
188
205
  def __init__(self, expr: Expression | None = None, retrieval: str | None = None, **kwargs):
189
206
  super().__init__(**kwargs)
190
- self.char_token_ratio: float = 0.6
207
+ self.char_token_ratio: float = 0.6
191
208
  self.expr: Expression | None = expr
192
- self.retrieval: str | None = retrieval
193
- self._trace: bool = False
194
- self._previous_frame = None
209
+ self.retrieval: str | None = retrieval
210
+ self._trace: bool = False
211
+ self._previous_frame = None
195
212
 
196
213
  def forward(self, sym: Symbol, **kwargs) -> Iterator:
197
214
  sym = self._to_symbol(sym)
@@ -213,17 +230,15 @@ class Stream(Expression):
213
230
  raise_with=ValueError,
214
231
  )
215
232
 
216
- res = sym.stream(expr=self.expr,
217
- char_token_ratio=self.char_token_ratio,
218
- **kwargs)
233
+ res = sym.stream(expr=self.expr, char_token_ratio=self.char_token_ratio, **kwargs)
219
234
  if self.retrieval is not None:
220
235
  res = list(res)
221
- if self.retrieval == 'all':
236
+ if self.retrieval == "all":
222
237
  return res
223
- if self.retrieval == 'longest':
238
+ if self.retrieval == "longest":
224
239
  res = sorted(res, key=lambda x: len(x), reverse=True)
225
240
  return res[0]
226
- if self.retrieval == 'contains':
241
+ if self.retrieval == "contains":
227
242
  return [r for r in res if self.expr in r]
228
243
  UserMessage(f"Invalid retrieval method: {self.retrieval}", raise_with=ValueError)
229
244
 
@@ -241,7 +256,7 @@ class Stream(Expression):
241
256
  class Trace(Expression):
242
257
  def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
243
258
  if engines is None:
244
- engines = ['all']
259
+ engines = ["all"]
245
260
  super().__init__(**kwargs)
246
261
  self.expr: Expression = expr
247
262
  self.engines: list[str] = engines
@@ -278,7 +293,7 @@ class Analyze(Expression):
278
293
  class Log(Expression):
279
294
  def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
280
295
  if engines is None:
281
- engines = ['all']
296
+ engines = ["all"]
282
297
  super().__init__(**kwargs)
283
298
  self.expr: Expression = expr
284
299
  self.engines: list[str] = engines
@@ -303,7 +318,12 @@ class Log(Expression):
303
318
 
304
319
 
305
320
  class Template(Expression):
306
- def __init__(self, template: str = "<html><body>{{placeholder}}</body></html>", placeholder: str = '{{placeholder}}', **kwargs):
321
+ def __init__(
322
+ self,
323
+ template: str = "<html><body>{{placeholder}}</body></html>",
324
+ placeholder: str = "{{placeholder}}",
325
+ **kwargs,
326
+ ):
307
327
  super().__init__(**kwargs)
308
328
  self.placeholder = placeholder
309
329
  self.template_ = template
@@ -333,21 +353,25 @@ class RuntimeExpression(Expression):
333
353
  code = self._to_symbol(code)
334
354
  # declare the runtime expression from the code
335
355
  expr = self.runner(code)
356
+
336
357
  def _func(sym):
337
358
  # execute nested expression
338
- return expr['locals']['_output_'](sym)
359
+ return expr["locals"]["_output_"](sym)
360
+
339
361
  return _func
340
362
 
341
363
 
342
364
  class Metric(Expression):
343
365
  def __init__(self, normalize: bool = False, eps: float = 1e-8, **kwargs):
344
366
  super().__init__(**kwargs)
345
- self.normalize = normalize
346
- self.eps = eps
367
+ self.normalize = normalize
368
+ self.eps = eps
347
369
 
348
370
  def forward(self, sym: Symbol, **_kwargs) -> Symbol:
349
371
  sym = self._to_symbol(sym)
350
- assert sym.value_type is np.ndarray or sym.value_type is list, 'Metric can only be applied to numpy arrays or lists.'
372
+ assert sym.value_type is np.ndarray or sym.value_type is list, (
373
+ "Metric can only be applied to numpy arrays or lists."
374
+ )
351
375
  if sym.value_type is list:
352
376
  sym._value = np.array(sym.value)
353
377
  # compute normalization between 0 and 1
@@ -357,7 +381,7 @@ class Metric(Expression):
357
381
  elif len(sym.value.shape) == 2:
358
382
  pass
359
383
  else:
360
- UserMessage(f'Invalid shape: {sym.value.shape}', raise_with=ValueError)
384
+ UserMessage(f"Invalid shape: {sym.value.shape}", raise_with=ValueError)
361
385
  # normalize between 0 and 1 and sum to 1
362
386
  sym._value = np.exp(sym.value) / (np.exp(sym.value).sum() + self.eps)
363
387
  return sym
@@ -413,16 +437,16 @@ _output_ = _func()
413
437
 
414
438
  def forward(self, sym: Symbol, enclosure: bool = False, **kwargs) -> Symbol:
415
439
  if enclosure or self.enclosure:
416
- lines = str(sym).split('\n')
417
- lines = [' ' + line for line in lines]
418
- sym = '\n'.join(lines)
419
- sym = self.template.replace('{sym}', str(sym))
440
+ lines = str(sym).split("\n")
441
+ lines = [" " + line for line in lines]
442
+ sym = "\n".join(lines)
443
+ sym = self.template.replace("{sym}", str(sym))
420
444
  sym = self._to_symbol(sym)
421
445
  return sym.execute(**kwargs)
422
446
 
423
447
 
424
448
  class Convert(Expression):
425
- def __init__(self, format: str = 'Python', **kwargs):
449
+ def __init__(self, format: str = "Python", **kwargs):
426
450
  super().__init__(**kwargs)
427
451
  self.format = format
428
452
 
@@ -456,13 +480,13 @@ class Map(Expression):
456
480
 
457
481
 
458
482
  class Translate(Expression):
459
- def __init__(self, language: str = 'English', **kwargs):
483
+ def __init__(self, language: str = "English", **kwargs):
460
484
  super().__init__(**kwargs)
461
485
  self.language = language
462
486
 
463
487
  def forward(self, sym: Symbol, **kwargs) -> Symbol:
464
488
  sym = self._to_symbol(sym)
465
- if sym.isinstanceof(f'{self.language} text'):
489
+ if sym.isinstanceof(f"{self.language} text"):
466
490
  return sym
467
491
  return sym.translate(language=self.language, **kwargs)
468
492
 
@@ -494,7 +518,7 @@ class FileWriter(Expression):
494
518
 
495
519
  def forward(self, sym: Symbol, **_kwargs) -> Symbol:
496
520
  sym = self._to_symbol(sym)
497
- with self.path.open('w') as f:
521
+ with self.path.open("w") as f:
498
522
  f.write(str(sym))
499
523
 
500
524
 
@@ -502,18 +526,18 @@ class FileReader(Expression):
502
526
  @staticmethod
503
527
  def exists(path: str) -> bool:
504
528
  # remove slicing if any
505
- _tmp = path
506
- _splits = _tmp.split('[')
507
- if '[' in _tmp:
529
+ _tmp = path
530
+ _splits = _tmp.split("[")
531
+ if "[" in _tmp:
508
532
  _tmp = _splits[0]
509
- assert len(_splits) == 1 or len(_splits) == 2, 'Invalid file link format.'
510
- _tmp = Path(_tmp)
533
+ assert len(_splits) == 1 or len(_splits) == 2, "Invalid file link format."
534
+ _tmp = Path(_tmp)
511
535
  # check if file exists and is a file
512
536
  return _tmp.is_file()
513
537
 
514
538
  @staticmethod
515
539
  def get_files(folder_path: str, max_depth: int = 1) -> list[str]:
516
- accepted_formats = ['.pdf', '.md', '.txt']
540
+ accepted_formats = [".pdf", ".md", ".txt"]
517
541
 
518
542
  folder = Path(folder_path)
519
543
  files = []
@@ -527,9 +551,34 @@ class FileReader(Expression):
527
551
 
528
552
  @staticmethod
529
553
  def extract_files(cmds: str) -> list[str] | None:
530
- # Use the updated regular expression to match quoted and non-quoted paths
531
- pattern = r'''(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))'''
532
- # Use the regular expression to split and handle quoted and non-quoted paths
554
+ """
555
+ Extract file paths from a command string, handling various quoting styles.
556
+
557
+ This method is used by the Qdrant RAG implementation when processing document paths.
558
+ It uses regex to parse file paths that may be quoted in different ways.
559
+
560
+ Regex patterns used:
561
+ 1. Main pattern: Matches file paths in four formats:
562
+ - Double-quoted: "path/to/file" (handles escaped characters)
563
+ - Single-quoted: 'path/to/file' (handles escaped characters)
564
+ - Backtick-quoted: `path/to/file` (handles escaped characters)
565
+ - Non-quoted: path/to/file (handles escaped spaces)
566
+
567
+ 2. Escape removal pattern: r"\\(.)" -> r"\1"
568
+ - Removes backslash escape sequences from quoted paths
569
+ - Example: "path\\/to\\/file" -> "path/to/file"
570
+ - Used for double quotes, single quotes, and backticks
571
+ """
572
+ # Regex pattern to match file paths in various quoting styles
573
+ # Pattern breakdown:
574
+ # - (?:"((?:\\.|[^"\\])*)") : Matches double-quoted paths, capturing content while handling escapes
575
+ # - '((?:\\.|[^'\\])*)' : Matches single-quoted paths, capturing content while handling escapes
576
+ # - `((?:\\.|[^`\\])*)` : Matches backtick-quoted paths, capturing content while handling escapes
577
+ # - ((?:\\ |[^ ])+) : Matches non-quoted paths, allowing escaped spaces
578
+ pattern = (
579
+ r"""(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))"""
580
+ )
581
+ # Use regex to find all file path matches in the command string
533
582
  matches = re.findall(pattern, cmds)
534
583
  # Process the matches to handle quoted paths and normal paths
535
584
  files = []
@@ -537,23 +586,27 @@ class FileReader(Expression):
537
586
  # Each match will have 4 groups due to the pattern; only one will be non-empty
538
587
  quoted_double, quoted_single, quoted_backtick, non_quoted = match
539
588
  if quoted_double:
540
- # Remove backslashes used for escaping inside double quotes
541
- path = re.sub(r'\\(.)', r'\1', quoted_double)
589
+ # Regex substitution: Remove backslashes used for escaping inside double quotes
590
+ # Pattern r"\\(.)" matches a backslash followed by any character and replaces with just the character
591
+ # Example: "path\\/to\\/file" -> "path/to/file"
592
+ path = re.sub(r"\\(.)", r"\1", quoted_double)
542
593
  file = FileReader.expand_user_path(path)
543
594
  files.append(file)
544
595
  elif quoted_single:
545
- # Remove backslashes used for escaping inside single quotes
546
- path = re.sub(r'\\(.)', r'\1', quoted_single)
596
+ # Regex substitution: Remove backslashes used for escaping inside single quotes
597
+ # Same pattern as above, applied to single-quoted paths
598
+ path = re.sub(r"\\(.)", r"\1", quoted_single)
547
599
  file = FileReader.expand_user_path(path)
548
600
  files.append(file)
549
601
  elif quoted_backtick:
550
- # Remove backslashes used for escaping inside backticks
551
- path = re.sub(r'\\(.)', r'\1', quoted_backtick)
602
+ # Regex substitution: Remove backslashes used for escaping inside backticks
603
+ # Same pattern as above, applied to backtick-quoted paths
604
+ path = re.sub(r"\\(.)", r"\1", quoted_backtick)
552
605
  file = FileReader.expand_user_path(path)
553
606
  files.append(file)
554
607
  elif non_quoted:
555
- # Replace escaped spaces with actual spaces
556
- path = non_quoted.replace('\\ ', ' ')
608
+ # Replace escaped spaces with actual spaces (no regex needed here, simple string replace)
609
+ path = non_quoted.replace("\\ ", " ")
557
610
  file = FileReader.expand_user_path(path)
558
611
  files.append(file)
559
612
  # Filter out any files that do not exist
@@ -571,25 +624,28 @@ class FileReader(Expression):
571
624
  if FileReader.exists(file):
572
625
  not_skipped.append(file)
573
626
  else:
574
- UserMessage(f'Skipping file: {file}')
627
+ UserMessage(f"Skipping file: {file}")
575
628
  return not_skipped
576
629
 
577
630
  def forward(self, files: str | list[str], **kwargs) -> Expression:
578
631
  if isinstance(files, str):
579
632
  # Convert to list for uniform processing; more easily downstream
580
633
  files = [files]
581
- if kwargs.get('run_integrity_check'):
634
+ if kwargs.get("run_integrity_check"):
582
635
  files = self.integrity_check(files)
583
636
  return self.sym_return_type([self.open(f, **kwargs).value for f in files])
584
637
 
638
+
585
639
  class FileQuery(Expression):
586
640
  def __init__(self, path: str, filter: str, **kwargs):
587
641
  super().__init__(**kwargs)
588
642
  self.path = path
589
643
  file_open = FileReader()
590
- self.query_stream = Stream(Sequence(
591
- IncludeFilter(filter),
592
- ))
644
+ self.query_stream = Stream(
645
+ Sequence(
646
+ IncludeFilter(filter),
647
+ )
648
+ )
593
649
  self.file = file_open(path)
594
650
 
595
651
  def forward(self, sym: Symbol, **kwargs) -> Symbol:
@@ -599,42 +655,45 @@ class FileQuery(Expression):
599
655
 
600
656
 
601
657
  class Function(TrackerTraceable):
602
- def __init__(self, prompt: str = '',
603
- examples: str | None = [],
604
- pre_processors: list[PreProcessor] | None = None,
605
- post_processors: list[PostProcessor] | None = None,
606
- default: object | None = None,
607
- constraints: list[Callable] | None = None,
608
- return_type: type | None = str,
609
- sym_return_type: type | None = Symbol,
610
- origin_type: type | None = Expression,
611
- *args, **kwargs):
658
+ def __init__(
659
+ self,
660
+ prompt: str = "",
661
+ examples: str | None = [],
662
+ pre_processors: list[PreProcessor] | None = None,
663
+ post_processors: list[PostProcessor] | None = None,
664
+ default: object | None = None,
665
+ constraints: list[Callable] | None = None,
666
+ return_type: type | None = str,
667
+ sym_return_type: type | None = Symbol,
668
+ origin_type: type | None = Expression,
669
+ *args,
670
+ **kwargs,
671
+ ):
612
672
  if constraints is None:
613
673
  constraints = []
614
674
  super().__init__(**kwargs)
615
- chars = ascii_lowercase + ascii_uppercase
616
- self.name = 'func_' + ''.join(sample(chars, 15))
617
- self.args = args
675
+ chars = ascii_lowercase + ascii_uppercase
676
+ self.name = "func_" + "".join(sample(chars, 15))
677
+ self.args = args
618
678
  self.kwargs = kwargs
619
- self._promptTemplate = prompt
620
- self._promptFormatArgs = []
679
+ self._promptTemplate = prompt
680
+ self._promptFormatArgs = []
621
681
  self._promptFormatKwargs = {}
622
- self.examples = Prompt(examples)
623
- self.pre_processors = pre_processors
682
+ self.examples = Prompt(examples)
683
+ self.pre_processors = pre_processors
624
684
  self.post_processors = post_processors
625
- self.constraints = constraints
626
- self.default = default
627
- self.return_type = return_type
685
+ self.constraints = constraints
686
+ self.default = default
687
+ self.return_type = return_type
628
688
  self.sym_return_type = sym_return_type
629
- self.origin_type = origin_type
689
+ self.origin_type = origin_type
630
690
 
631
691
  @property
632
692
  def prompt(self):
633
693
  # return a copy of the prompt template
634
694
  if len(self._promptFormatArgs) == 0 and len(self._promptFormatKwargs) == 0:
635
695
  return self._promptTemplate
636
- return f"{self._promptTemplate}".format(*self._promptFormatArgs,
637
- **self._promptFormatKwargs)
696
+ return f"{self._promptTemplate}".format(*self._promptFormatArgs, **self._promptFormatKwargs)
638
697
 
639
698
  def format(self, *args, **kwargs):
640
699
  self._promptFormatArgs = args
@@ -642,9 +701,10 @@ class Function(TrackerTraceable):
642
701
 
643
702
  def forward(self, *args, **kwargs) -> Expression:
644
703
  # special case for few shot function prompt definition override
645
- if 'fn' in kwargs:
646
- self.prompt = kwargs['fn']
647
- del kwargs['fn']
704
+ if "fn" in kwargs:
705
+ self.prompt = kwargs["fn"]
706
+ del kwargs["fn"]
707
+
648
708
  @core.few_shot(
649
709
  *self.args,
650
710
  prompt=self.prompt,
@@ -653,19 +713,24 @@ class Function(TrackerTraceable):
653
713
  post_processors=self.post_processors,
654
714
  constraints=self.constraints,
655
715
  default=self.default,
656
- **self.kwargs
716
+ **self.kwargs,
657
717
  )
658
718
  def _func(_, *args, **kwargs) -> self.return_type:
659
719
  pass
660
- _type = type(self.name, (self.origin_type, ), {
661
- # constructor
662
- "forward": _func,
663
- "sym_return_type": self.sym_return_type,
664
- "static_context": self.static_context,
665
- "dynamic_context": self.dynamic_context,
666
- "__class__": self.__class__,
667
- "__module__": self.__module__,
668
- })
720
+
721
+ _type = type(
722
+ self.name,
723
+ (self.origin_type,),
724
+ {
725
+ # constructor
726
+ "forward": _func,
727
+ "sym_return_type": self.sym_return_type,
728
+ "static_context": self.static_context,
729
+ "dynamic_context": self.dynamic_context,
730
+ "__class__": self.__class__,
731
+ "__module__": self.__module__,
732
+ },
733
+ )
669
734
  obj = _type()
670
735
 
671
736
  return self._to_symbol(obj(*args, **kwargs))
@@ -676,7 +741,7 @@ class PrepareData(Function):
676
741
  def __call__(self, argument):
677
742
  assert argument.prop.context is not None
678
743
  instruct = argument.prop.prompt
679
- context = argument.prop.context
744
+ context = argument.prop.context
680
745
  return f"""{{
681
746
  'context': '{context}',
682
747
  'instruction': '{instruct}',
@@ -685,10 +750,10 @@ class PrepareData(Function):
685
750
 
686
751
  def __init__(self, *args, **kwargs):
687
752
  super().__init__(*args, **kwargs)
688
- self.pre_processors = [self.PrepareDataPreProcessor()]
689
- self.constraints = [DictFormatConstraint({ 'result': '<the data>' })]
753
+ self.pre_processors = [self.PrepareDataPreProcessor()]
754
+ self.constraints = [DictFormatConstraint({"result": "<the data>"})]
690
755
  self.post_processors = [JsonTruncateMarkdownPostProcessor()]
691
- self.return_type = dict # constraint to cast the result to a dict
756
+ self.return_type = dict # constraint to cast the result to a dict
692
757
 
693
758
  @property
694
759
  def static_context(self):
@@ -723,7 +788,7 @@ Your goal is to prepare the data for the next task instruction. The data should
723
788
 
724
789
  class ExpressionBuilder(Function):
725
790
  def __init__(self, **kwargs):
726
- super().__init__('Generate the code following the instructions:', **kwargs)
791
+ super().__init__("Generate the code following the instructions:", **kwargs)
727
792
  self.processors = ProcessorPipeline([StripPostProcessor(), CodeExtractPostProcessor()])
728
793
 
729
794
  def forward(self, instruct, *_args, **_kwargs):
@@ -774,10 +839,12 @@ Always produce the entire code to be executed in the same Python process. All ta
774
839
  class JsonParser(Expression):
775
840
  def __init__(self, query: str, json_: dict, **kwargs):
776
841
  super().__init__(**kwargs)
777
- func = Function(prompt=JsonPromptTemplate(query, json_),
778
- constraints=[DictFormatConstraint(json_)],
779
- pre_processors=[JsonPreProcessor()],
780
- post_processors=[JsonTruncatePostProcessor()])
842
+ func = Function(
843
+ prompt=JsonPromptTemplate(query, json_),
844
+ constraints=[DictFormatConstraint(json_)],
845
+ pre_processors=[JsonPreProcessor()],
846
+ post_processors=[JsonTruncatePostProcessor()],
847
+ )
781
848
  self.fn = Try(func, retries=1)
782
849
 
783
850
  def forward(self, sym: Symbol, **kwargs) -> Symbol:
@@ -787,21 +854,27 @@ class JsonParser(Expression):
787
854
 
788
855
 
789
856
  class SimilarityClassification(Expression):
790
- def __init__(self, classes: list[str], metric: str = 'cosine', in_memory: bool = False, **kwargs):
857
+ def __init__(
858
+ self, classes: list[str], metric: str = "cosine", in_memory: bool = False, **kwargs
859
+ ):
791
860
  super().__init__(**kwargs)
792
- self.classes = classes
793
- self.metric = metric
861
+ self.classes = classes
862
+ self.metric = metric
794
863
  self.in_memory = in_memory
795
864
 
796
865
  if self.in_memory:
797
- UserMessage(f'Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache')
866
+ UserMessage(
867
+ f"Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache"
868
+ )
798
869
 
799
870
  def forward(self, x: Symbol) -> Symbol:
800
- x = self._to_symbol(x)
801
- usr_embed = x.embed()
802
- embeddings = self._dynamic_cache()
871
+ x = self._to_symbol(x)
872
+ usr_embed = x.embed()
873
+ embeddings = self._dynamic_cache()
803
874
  similarities = [usr_embed.similarity(emb, metric=self.metric) for emb in embeddings]
804
- similarities = sorted(zip(self.classes, similarities, strict=False), key=lambda x: x[1], reverse=True)
875
+ similarities = sorted(
876
+ zip(self.classes, similarities, strict=False), key=lambda x: x[1], reverse=True
877
+ )
805
878
 
806
879
  return Symbol(similarities[0][0])
807
880
 
@@ -820,11 +893,7 @@ class InContextClassification(Expression):
820
893
  self.blueprint = blueprint
821
894
 
822
895
  def forward(self, x: Symbol, **kwargs) -> Symbol:
823
- @core.few_shot(
824
- prompt=x,
825
- examples=self.blueprint,
826
- **kwargs
827
- )
896
+ @core.few_shot(prompt=x, examples=self.blueprint, **kwargs)
828
897
  def _func(_):
829
898
  pass
830
899
 
@@ -832,38 +901,38 @@ class InContextClassification(Expression):
832
901
 
833
902
 
834
903
  class Indexer(Expression):
835
- DEFAULT = 'dataindex'
904
+ DEFAULT = "dataindex"
836
905
 
837
906
  @staticmethod
838
907
  def replace_special_chars(index: str):
839
908
  # replace special characters that are not for path
840
- return str(index).replace('-', '').replace('_', '').replace(' ', '').lower()
909
+ return str(index).replace("-", "").replace("_", "").replace(" ", "").lower()
841
910
 
842
911
  def __init__(
843
- self,
844
- index_name: str = DEFAULT,
845
- top_k: int = 8,
846
- batch_size: int = 20,
847
- formatter: Callable = _DEFAULT_PARAGRAPH_FORMATTER,
848
- auto_add=False,
849
- raw_result: bool = False,
850
- new_dim: int = 1536,
851
- **kwargs
852
- ):
912
+ self,
913
+ index_name: str = DEFAULT,
914
+ top_k: int = 8,
915
+ batch_size: int = 20,
916
+ formatter: Callable = _DEFAULT_PARAGRAPH_FORMATTER,
917
+ auto_add=False,
918
+ raw_result: bool = False,
919
+ new_dim: int = 1536,
920
+ **kwargs,
921
+ ):
853
922
  super().__init__(**kwargs)
854
923
  index_name = Indexer.replace_special_chars(index_name)
855
924
  self.index_name = index_name
856
- self.elements = []
925
+ self.elements = []
857
926
  self.batch_size = batch_size
858
- self.top_k = top_k
859
- self.retrieval = None
860
- self.formatter = formatter
927
+ self.top_k = top_k
928
+ self.retrieval = None
929
+ self.formatter = formatter
861
930
  self.raw_result = raw_result
862
- self.new_dim = new_dim
931
+ self.new_dim = new_dim
863
932
  self.sym_return_type = Expression
864
933
 
865
934
  # append index name to indices.txt in home directory .symai folder (default)
866
- self.path = HOME_PATH / 'indices.txt'
935
+ self.path = HOME_PATH / "indices.txt"
867
936
  if not self.path.exists():
868
937
  self.path.parent.mkdir(parents=True, exist_ok=True)
869
938
  self.path.touch()
@@ -874,51 +943,62 @@ class Indexer(Expression):
874
943
  # check if index already exists in indices.txt and append if not
875
944
  change = False
876
945
  with self.path.open() as f:
877
- indices = f.read().split('\n')
946
+ indices = f.read().split("\n")
878
947
  # filter out empty strings
879
948
  indices = [i for i in indices if i]
880
949
  if self.index_name not in indices:
881
- indices.append(self.index_name)
882
- change = True
950
+ indices.append(self.index_name)
951
+ change = True
883
952
  if change:
884
- with self.path.open('w') as f:
885
- f.write('\n'.join(indices))
953
+ with self.path.open("w") as f:
954
+ f.write("\n".join(indices))
886
955
 
def exists(self) -> bool:
    """Report whether ``self.index_name`` is registered in ``indices.txt``.

    The registry file lives at ``HOME_PATH / "indices.txt"`` and holds one
    index name per line.

    Returns:
        ``True`` when the file exists and one of its non-empty lines equals
        ``self.index_name``; ``False`` otherwise (including when the file is
        missing).
    """
    # check if index exists in home directory .symai folder (default) indices.txt
    registry = HOME_PATH / "indices.txt"
    if not registry.exists():
        return False

    with registry.open() as f:
        # Ignore blank lines, exactly as ``__init__`` does when it reads this
        # file. Otherwise an empty index name would match the empty string
        # produced by an empty file or a stray newline.
        registered = {name for name in f.read().split("\n") if name}

    return self.index_name in registered
897
966
 
def forward(
    self,
    data: Symbol | None = None,
    _raw_result: bool = False,
) -> Symbol:
    """Optionally index ``data`` and return a query function for this index.

    When ``data`` is given it is converted with ``self._to_symbol``, split
    by ``self.formatter`` and added to the index in slices of
    ``self.batch_size``; afterwards the index is saved through
    ``self.config(..., save=True)``.

    Args:
        data: Content to add to the index, or ``None`` to only build the
            query function.
        _raw_result: Not read by this method.

    Returns:
        The nested ``_func`` callable (note: the annotation says ``Symbol``).
        ``_func(query, **kwargs)`` embeds the query, calls ``self.get`` and
        returns either the raw result or a ``Symbol`` produced by
        ``Symbol(res).query(...)``.
    """
    if data is not None:
        data = self._to_symbol(data)
        self.elements = self.formatter(data).value
        # Walk over the elements one batch at a time.
        for start in tqdm(range(0, len(self.elements), self.batch_size)):
            batch = self.elements[start : start + self.batch_size]
            packed = Symbol(batch).zip(new_dim=self.new_dim)
            self.add(packed, index_name=self.index_name, index_dims=self.new_dim)
        # Save the index.
        # NOTE(review): assumed to run once after the batch loop -- confirm
        # against the upstream file.
        self.config(None, save=True, index_name=self.index_name, index_dims=self.new_dim)

    def _func(query, *_args, **kwargs) -> Union[Symbol, "VectorDBResult"]:
        # A truthy per-call ``raw_result`` wins; otherwise the instance default applies.
        want_raw = kwargs.get("raw_result") or self.raw_result
        query_emb = Symbol(query).embed(new_dim=self.new_dim).value
        res = self.get(
            query_emb,
            index_name=self.index_name,
            index_top_k=self.top_k,
            ori_query=query,
            index_dims=self.new_dim,
            **kwargs,
        )
        # Keep the latest retrieval on the instance.
        self.retrieval = res
        if not want_raw:
            return Symbol(res).query(
                prompt="From the retrieved data, select the most relevant information.",
                context=query,
            )
        return res

    return _func
923
1003
 
924
1004
 
@@ -930,7 +1010,7 @@ class PrimitiveDisabler(Expression):
930
1010
 
931
1011
  def __enter__(self):
932
1012
  # Import Symbol lazily so components does not clash with symbol during load.
933
- from .symbol import Symbol # noqa
1013
+ from .symbol import Symbol # noqa
934
1014
 
935
1015
  frame = inspect.currentframe()
936
1016
  f_locals = frame.f_back.f_locals
@@ -957,7 +1037,7 @@ class PrimitiveDisabler(Expression):
957
1037
  for sym in self._symbols.values():
958
1038
  for primitive in sym._primitives:
959
1039
  for method, _ in inspect.getmembers(primitive, predicate=inspect.isfunction):
960
- if method in self._primitives or method.startswith('_'):
1040
+ if method in self._primitives or method.startswith("_"):
961
1041
  continue
962
1042
  self._primitives.add(method)
963
1043
 
@@ -1002,9 +1082,7 @@ class FunctionWithUsage(Function):
1002
1082
  self.total_tokens += usage.total_tokens
1003
1083
 
def get_usage(self):
    """Return the token counters of this instance, formatted by ``_format_usage``.

    The prompt, completion and total token counters are passed in that order.
    """
    counters = (self.prompt_tokens, self.completion_tokens, self.total_tokens)
    return self._format_usage(*counters)
1008
1086
 
1009
1087
  def forward(self, *args, **kwargs):
1010
1088
  if "return_metadata" not in kwargs:
@@ -1015,9 +1093,7 @@ class FunctionWithUsage(Function):
1015
1093
  raw_output = metadata.get("raw_output")
1016
1094
  if hasattr(raw_output, "usage"):
1017
1095
  usage = raw_output.usage
1018
- prompt_tokens = (
1019
- usage.prompt_tokens if hasattr(usage, "prompt_tokens") else 0
1020
- )
1096
+ prompt_tokens = usage.prompt_tokens if hasattr(usage, "prompt_tokens") else 0
1021
1097
  completion_tokens = (
1022
1098
  usage.completion_tokens if hasattr(usage, "completion_tokens") else 0
1023
1099
  )
@@ -1033,7 +1109,9 @@ class FunctionWithUsage(Function):
1033
1109
  self.total_tokens += total_tokens
1034
1110
  else:
1035
1111
  if self.missing_usage_exception and "preview" not in kwargs:
1036
- UserMessage("Missing usage in metadata of neursymbolic engine", raise_with=Exception)
1112
+ UserMessage(
1113
+ "Missing usage in metadata of neursymbolic engine", raise_with=Exception
1114
+ )
1037
1115
  prompt_tokens = 0
1038
1116
  completion_tokens = 0
1039
1117
  total_tokens = 0
@@ -1042,12 +1120,12 @@ class FunctionWithUsage(Function):
1042
1120
 
1043
1121
 
1044
1122
  class SelfPrompt(Expression):
1045
- _default_retry_tries = 20
1046
- _default_retry_delay = 0.5
1123
+ _default_retry_tries = 20
1124
+ _default_retry_delay = 0.5
1047
1125
  _default_retry_max_delay = -1
1048
- _default_retry_backoff = 1
1049
- _default_retry_jitter = 0
1050
- _default_retry_graceful = True
1126
+ _default_retry_backoff = 1
1127
+ _default_retry_jitter = 0
1128
+ _default_retry_graceful = True
1051
1129
 
1052
1130
  def __init__(self, *args, **kwargs):
1053
1131
  super().__init__(*args, **kwargs)
@@ -1061,14 +1139,21 @@ class SelfPrompt(Expression):
1061
1139
  :return: A dictionary containing the new prompts in the same format:
1062
1140
  {'user': '...', 'system': '...'}
1063
1141
  """
1064
- tries = kwargs.get('tries', self._default_retry_tries)
1065
- delay = kwargs.get('delay', self._default_retry_delay)
1066
- max_delay = kwargs.get('max_delay', self._default_retry_max_delay)
1067
- backoff = kwargs.get('backoff', self._default_retry_backoff)
1068
- jitter = kwargs.get('jitter', self._default_retry_jitter)
1069
- graceful = kwargs.get('graceful', self._default_retry_graceful)
1070
-
1071
- @core_ext.retry(tries=tries, delay=delay, max_delay=max_delay, backoff=backoff, jitter=jitter, graceful=graceful)
1142
+ tries = kwargs.get("tries", self._default_retry_tries)
1143
+ delay = kwargs.get("delay", self._default_retry_delay)
1144
+ max_delay = kwargs.get("max_delay", self._default_retry_max_delay)
1145
+ backoff = kwargs.get("backoff", self._default_retry_backoff)
1146
+ jitter = kwargs.get("jitter", self._default_retry_jitter)
1147
+ graceful = kwargs.get("graceful", self._default_retry_graceful)
1148
+
1149
+ @core_ext.retry(
1150
+ tries=tries,
1151
+ delay=delay,
1152
+ max_delay=max_delay,
1153
+ backoff=backoff,
1154
+ jitter=jitter,
1155
+ graceful=graceful,
1156
+ )
1072
1157
  @core.zero_shot(
1073
1158
  prompt=(
1074
1159
  "Based on the following prompt, generate a new system (or developer) prompt and a new user prompt. "
@@ -1077,18 +1162,19 @@ class SelfPrompt(Expression):
1077
1162
  "The new user prompt should contain the user's requirements. "
1078
1163
  "Check if the input contains a 'system' or 'developer' key and use the same key in your output. "
1079
1164
  "Only output the new prompts in JSON format as shown:\n\n"
1080
- "{\"system\": \"<new system prompt>\", \"user\": \"<new user prompt>\"}\n\n"
1165
+ '{"system": "<new system prompt>", "user": "<new user prompt>"}\n\n'
1081
1166
  "OR\n\n"
1082
- "{\"developer\": \"<new developer prompt>\", \"user\": \"<new user prompt>\"}\n\n"
1167
+ '{"developer": "<new developer prompt>", "user": "<new user prompt>"}\n\n'
1083
1168
  "Maintain the same key structure as in the input prompt. Do not include any additional text."
1084
1169
  ),
1085
1170
  response_format={"type": "json_object"},
1086
1171
  post_processors=[
1087
1172
  lambda res, _: json.loads(res),
1088
1173
  ],
1089
- **kwargs
1174
+ **kwargs,
1090
1175
  )
1091
- def _func(self, sym: Symbol): pass
1176
+ def _func(self, sym: Symbol):
1177
+ pass
1092
1178
 
1093
1179
  return _func(self, self._to_symbol(existing_prompt))
1094
1180
 
@@ -1104,15 +1190,19 @@ class MetadataTracker(Expression):
def __str__(self, value=None):
    """Render ``value`` (default: ``self.metadata``) as an indented, JSON-like string.

    Dicts become ``{`` ... ``}`` with one tab-indented ``"key": value`` entry
    per line, lists become ``[a, b, ...]``, strings are double-quoted and any
    other object is rendered via its ``str()`` on a new tab-indented line.

    Args:
        value: Object to render. A falsy value (including the default
            ``None``) selects ``self.metadata``.

    Returns:
        The formatted string.
    """

    def _render(item):
        # Nested values are rendered as-is. They must never fall back to
        # ``self.metadata``, or a falsy nested value (0, "", [], {}, None)
        # would restart rendering from the top and recurse without end.
        if isinstance(item, dict):
            return (
                "{\n\t"
                + ", \n\t".join(f'"{k}": {_render(v)}' for k, v in item.items())
                + "\n}"
            )
        if isinstance(item, list):
            return "[" + ", ".join(_render(entry) for entry in item) + "]"
        if isinstance(item, str):
            return f'"{item}"'
        return f"\n\t {item}"

    # Only the top-level call applies the metadata fallback.
    return _render(value or self.metadata)
1113
1203
 
1114
1204
  def __new__(cls, *_args, **_kwargs):
1115
- cls._lock = getattr(cls, '_lock', Lock())
1205
+ cls._lock = getattr(cls, "_lock", Lock())
1116
1206
  with cls._lock:
1117
1207
  instance = super().__new__(cls)
1118
1208
  instance._metadata = {}
@@ -1135,14 +1225,14 @@ class MetadataTracker(Expression):
1135
1225
  return None
1136
1226
 
1137
1227
  if (
1138
- event == 'return'
1139
- and frame.f_code.co_name == 'forward'
1140
- and 'self' in frame.f_locals
1141
- and isinstance(frame.f_locals['self'], Engine)
1228
+ event == "return"
1229
+ and frame.f_code.co_name == "forward"
1230
+ and "self" in frame.f_locals
1231
+ and isinstance(frame.f_locals["self"], Engine)
1142
1232
  ):
1143
1233
  _, metadata = arg # arg contains return value on 'return' event
1144
- engine_name = frame.f_locals['self'].__class__.__name__
1145
- model_name = frame.f_locals['self'].model
1234
+ engine_name = frame.f_locals["self"].__class__.__name__
1235
+ model_name = frame.f_locals["self"].model
1146
1236
  self._metadata[(self._metadata_id, engine_name, model_name)] = metadata
1147
1237
  self._metadata_id += 1
1148
1238
 
@@ -1162,38 +1252,91 @@ class MetadataTracker(Expression):
1162
1252
  try:
1163
1253
  if engine_name == "GroqEngine":
1164
1254
  usage = metadata["raw_output"].usage
1165
- token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.completion_tokens
1166
- token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.prompt_tokens
1167
- token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
1255
+ token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
1256
+ usage.completion_tokens
1257
+ )
1258
+ token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
1259
+ usage.prompt_tokens
1260
+ )
1261
+ token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
1262
+ usage.total_tokens
1263
+ )
1168
1264
  token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
1169
1265
  #!: Backward compatibility for components like `RuntimeInfo`
1170
- token_details[(engine_name, model_name)]["prompt_breakdown"]["cached_tokens"] += 0 # Assignment not allowed with defualtdict
1171
- token_details[(engine_name, model_name)]["completion_breakdown"]["reasoning_tokens"] += 0
1266
+ token_details[(engine_name, model_name)]["prompt_breakdown"][
1267
+ "cached_tokens"
1268
+ ] += 0 # Assignment not allowed with defualtdict
1269
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1270
+ "reasoning_tokens"
1271
+ ] += 0
1272
+ elif engine_name == "ParallelEngine":
1273
+ token_details[(engine_name, None)]["usage"]["total_calls"] += 1
1274
+ # There are no model-specific tokens for this engine
1275
+ token_details[(engine_name, None)]["usage"]["completion_tokens"] += 0
1276
+ token_details[(engine_name, None)]["usage"]["prompt_tokens"] += 0
1277
+ token_details[(engine_name, None)]["usage"]["total_tokens"] += 0
1278
+ #!: Backward compatibility for components like `RuntimeInfo`
1279
+ token_details[(engine_name, None)]["prompt_breakdown"]["cached_tokens"] += (
1280
+ 0 # Assignment not allowed with defualtdict
1281
+ )
1282
+ token_details[(engine_name, None)]["completion_breakdown"][
1283
+ "reasoning_tokens"
1284
+ ] += 0
1172
1285
  elif engine_name in ("GPTXChatEngine", "GPTXReasoningEngine"):
1173
1286
  usage = metadata["raw_output"].usage
1174
- token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.completion_tokens
1175
- token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.prompt_tokens
1176
- token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
1287
+ token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
1288
+ usage.completion_tokens
1289
+ )
1290
+ token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
1291
+ usage.prompt_tokens
1292
+ )
1293
+ token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
1294
+ usage.total_tokens
1295
+ )
1177
1296
  token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
1178
- token_details[(engine_name, model_name)]["completion_breakdown"]["accepted_prediction_tokens"] += usage.completion_tokens_details.accepted_prediction_tokens
1179
- token_details[(engine_name, model_name)]["completion_breakdown"]["rejected_prediction_tokens"] += usage.completion_tokens_details.rejected_prediction_tokens
1180
- token_details[(engine_name, model_name)]["completion_breakdown"]["audio_tokens"] += usage.completion_tokens_details.audio_tokens
1181
- token_details[(engine_name, model_name)]["completion_breakdown"]["reasoning_tokens"] += usage.completion_tokens_details.reasoning_tokens
1182
- token_details[(engine_name, model_name)]["prompt_breakdown"]["audio_tokens"] += usage.prompt_tokens_details.audio_tokens
1183
- token_details[(engine_name, model_name)]["prompt_breakdown"]["cached_tokens"] += usage.prompt_tokens_details.cached_tokens
1297
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1298
+ "accepted_prediction_tokens"
1299
+ ] += usage.completion_tokens_details.accepted_prediction_tokens
1300
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1301
+ "rejected_prediction_tokens"
1302
+ ] += usage.completion_tokens_details.rejected_prediction_tokens
1303
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1304
+ "audio_tokens"
1305
+ ] += usage.completion_tokens_details.audio_tokens
1306
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1307
+ "reasoning_tokens"
1308
+ ] += usage.completion_tokens_details.reasoning_tokens
1309
+ token_details[(engine_name, model_name)]["prompt_breakdown"][
1310
+ "audio_tokens"
1311
+ ] += usage.prompt_tokens_details.audio_tokens
1312
+ token_details[(engine_name, model_name)]["prompt_breakdown"][
1313
+ "cached_tokens"
1314
+ ] += usage.prompt_tokens_details.cached_tokens
1184
1315
  elif engine_name == "GPTXSearchEngine":
1185
1316
  usage = metadata["raw_output"].usage
1186
- token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.input_tokens
1187
- token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.output_tokens
1188
- token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
1317
+ token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
1318
+ usage.input_tokens
1319
+ )
1320
+ token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
1321
+ usage.output_tokens
1322
+ )
1323
+ token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
1324
+ usage.total_tokens
1325
+ )
1189
1326
  token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
1190
- token_details[(engine_name, model_name)]["prompt_breakdown"]["cached_tokens"] += usage.input_tokens_details.cached_tokens
1191
- token_details[(engine_name, model_name)]["completion_breakdown"]["reasoning_tokens"] += usage.output_tokens_details.reasoning_tokens
1327
+ token_details[(engine_name, model_name)]["prompt_breakdown"][
1328
+ "cached_tokens"
1329
+ ] += usage.input_tokens_details.cached_tokens
1330
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1331
+ "reasoning_tokens"
1332
+ ] += usage.output_tokens_details.reasoning_tokens
1192
1333
  else:
1193
1334
  logger.warning(f"Tracking {engine_name} is not supported.")
1194
1335
  continue
1195
1336
  except Exception as e:
1196
- UserMessage(f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError)
1337
+ UserMessage(
1338
+ f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError
1339
+ )
1197
1340
 
1198
1341
  # Convert to normal dict
1199
1342
  return {**token_details}
@@ -1203,22 +1346,24 @@ class MetadataTracker(Expression):
1203
1346
  return engine_name in supported_engines
1204
1347
 
def _accumulate_time_field(self, accumulated: dict, metadata: dict) -> None:
    """Add ``metadata["time"]`` onto ``accumulated["time"]`` in place.

    Nothing happens unless both dictionaries contain a ``"time"`` key.
    """
    if "time" not in metadata or "time" not in accumulated:
        return
    accumulated["time"] += metadata["time"]
1208
1351
 
1209
1352
  def _accumulate_usage_fields(self, accumulated: dict, metadata: dict) -> None:
1210
- if 'raw_output' not in metadata or 'raw_output' not in accumulated:
1353
+ if "raw_output" not in metadata or "raw_output" not in accumulated:
1211
1354
  return
1212
1355
 
1213
- metadata_raw_output = metadata['raw_output']
1214
- accumulated_raw_output = accumulated['raw_output']
1215
- if not hasattr(metadata_raw_output, 'usage') or not hasattr(accumulated_raw_output, 'usage'):
1356
+ metadata_raw_output = metadata["raw_output"]
1357
+ accumulated_raw_output = accumulated["raw_output"]
1358
+ if not hasattr(metadata_raw_output, "usage") or not hasattr(
1359
+ accumulated_raw_output, "usage"
1360
+ ):
1216
1361
  return
1217
1362
 
1218
1363
  current_usage = metadata_raw_output.usage
1219
1364
  accumulated_usage = accumulated_raw_output.usage
1220
1365
 
1221
- for attr in ['completion_tokens', 'prompt_tokens', 'total_tokens']:
1366
+ for attr in ["completion_tokens", "prompt_tokens", "total_tokens"]:
1222
1367
  if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
1223
1368
  setattr(
1224
1369
  accumulated_usage,
@@ -1226,20 +1371,24 @@ class MetadataTracker(Expression):
1226
1371
  getattr(accumulated_usage, attr) + getattr(current_usage, attr),
1227
1372
  )
1228
1373
 
1229
- for detail_attr in ['completion_tokens_details', 'prompt_tokens_details']:
1230
- if not hasattr(current_usage, detail_attr) or not hasattr(accumulated_usage, detail_attr):
1374
+ for detail_attr in ["completion_tokens_details", "prompt_tokens_details"]:
1375
+ if not hasattr(current_usage, detail_attr) or not hasattr(
1376
+ accumulated_usage, detail_attr
1377
+ ):
1231
1378
  continue
1232
1379
 
1233
1380
  current_details = getattr(current_usage, detail_attr)
1234
1381
  accumulated_details = getattr(accumulated_usage, detail_attr)
1235
1382
 
1236
1383
  for attr in dir(current_details):
1237
- if attr.startswith('_') or not hasattr(accumulated_details, attr):
1384
+ if attr.startswith("_") or not hasattr(accumulated_details, attr):
1238
1385
  continue
1239
1386
 
1240
1387
  current_val = getattr(current_details, attr)
1241
1388
  accumulated_val = getattr(accumulated_details, attr)
1242
- if isinstance(current_val, (int, float)) and isinstance(accumulated_val, (int, float)):
1389
+ if isinstance(current_val, (int, float)) and isinstance(
1390
+ accumulated_val, (int, float)
1391
+ ):
1243
1392
  setattr(accumulated_details, attr, accumulated_val + current_val)
1244
1393
 
1245
1394
  def _accumulate_metadata(self):
@@ -1255,7 +1404,9 @@ class MetadataTracker(Expression):
1255
1404
  # Skipz first entry
1256
1405
  for (_, engine_name), metadata in list(self._metadata.items())[1:]:
1257
1406
  if not self._can_accumulate_engine(engine_name):
1258
- logger.warning(f"Metadata accumulation for {engine_name} is not supported. Try `.usage` instead for now.")
1407
+ logger.warning(
1408
+ f"Metadata accumulation for {engine_name} is not supported. Try `.usage` instead for now."
1409
+ )
1259
1410
  continue
1260
1411
 
1261
1412
  self._accumulate_time_field(accumulated, metadata)
@@ -1278,6 +1429,7 @@ class MetadataTracker(Expression):
1278
1429
 
1279
1430
  class DynamicEngine(Expression):
1280
1431
  """Context manager for dynamically switching neurosymbolic engine models."""
1432
+
1281
1433
  def __init__(self, model: str, api_key: str, _debug: bool = False, **_kwargs):
1282
1434
  super().__init__()
1283
1435
  self.model = model
@@ -1288,7 +1440,7 @@ class DynamicEngine(Expression):
1288
1440
  self._ctx_token = None
1289
1441
 
1290
1442
  def __new__(cls, *_args, **_kwargs):
1291
- cls._lock = getattr(cls, '_lock', Lock())
1443
+ cls._lock = getattr(cls, "_lock", Lock())
1292
1444
  with cls._lock:
1293
1445
  instance = super().__new__(cls)
1294
1446
  instance._metadata = {}
@@ -1322,11 +1474,177 @@ class DynamicEngine(Expression):
def _create_engine_instance(self):
    """Create an engine instance based on the model name.

    ``self.model`` is looked up in ``ENGINE_MAPPING``; the class found there
    is called with ``api_key=self.api_key`` and ``model=self.model``.

    Returns:
        The engine instance built by the mapped class.

    Raises:
        ValueError: Requested through ``UserMessage(..., raise_with=ValueError)``
            when the model has no mapped engine or when constructing the
            engine fails (presumably ``UserMessage`` raises the given
            exception type -- confirm against its definition).
    """
    # Deferred to avoid components <-> neurosymbolic engine circular imports.
    from .backend.engines.neurosymbolic import ENGINE_MAPPING  # noqa

    engine_class = ENGINE_MAPPING.get(self.model)
    if engine_class is None:
        # Reported outside the ``try`` so this error is not caught by the
        # broad handler below and re-wrapped as a construction failure.
        UserMessage(f"Unsupported model '{self.model}'", raise_with=ValueError)

    try:
        return engine_class(api_key=self.api_key, model=self.model)
    except Exception as e:
        UserMessage(
            f"Failed to create engine for model '{self.model}': {e!s}", raise_with=ValueError
        )
+ )
1488
+
1489
+
1490
# Chonkie chunker imports - lazy loaded.
# All three stay ``None`` until the first call to ``_lazy_import_chonkie``.
_CHONKIE_MODULES = None
_CHUNKER_MAPPING = None
_CHONKIE_AVAILABLE = None


def _lazy_import_chonkie():
    """Import chonkie and tokenizers on first use and cache the outcome.

    On success the module globals are filled with the imported objects
    (``_CHONKIE_MODULES``), the chunker classes by name (``_CHUNKER_MAPPING``)
    and ``_CHONKIE_AVAILABLE = True``. If any of the imports raises
    ``ImportError`` the two dictionaries are set to ``{}`` and the flag to
    ``False``. Later calls return the cached dictionary without importing again.

    Returns:
        The ``_CHONKIE_MODULES`` dictionary (empty when the import failed).
    """
    global _CHONKIE_MODULES, _CHUNKER_MAPPING, _CHONKIE_AVAILABLE

    if _CHONKIE_MODULES is None:
        try:
            from chonkie import (  # noqa
                CodeChunker,
                LateChunker,
                NeuralChunker,
                RecursiveChunker,
                SemanticChunker,
                SentenceChunker,
                SlumberChunker,
                TableChunker,
                TokenChunker,
            )
            from chonkie.embeddings.base import BaseEmbeddings  # noqa
            from tokenizers import Tokenizer  # noqa
        except ImportError:
            _CHONKIE_MODULES, _CHUNKER_MAPPING, _CHONKIE_AVAILABLE = {}, {}, False
        else:
            chunkers = {
                "CodeChunker": CodeChunker,
                "LateChunker": LateChunker,
                "NeuralChunker": NeuralChunker,
                "RecursiveChunker": RecursiveChunker,
                "SemanticChunker": SemanticChunker,
                "SentenceChunker": SentenceChunker,
                "SlumberChunker": SlumberChunker,
                "TableChunker": TableChunker,
                "TokenChunker": TokenChunker,
            }
            _CHONKIE_MODULES = {
                **chunkers,
                "BaseEmbeddings": BaseEmbeddings,
                "Tokenizer": Tokenizer,
            }
            # The key order of the mapping shows up in error messages that
            # list the available chunkers, so it is spelled out explicitly.
            mapping_order = (
                "TokenChunker",
                "SentenceChunker",
                "RecursiveChunker",
                "SemanticChunker",
                "CodeChunker",
                "LateChunker",
                "NeuralChunker",
                "SlumberChunker",
                "TableChunker",
            )
            _CHUNKER_MAPPING = {name: chunkers[name] for name in mapping_order}
            _CHONKIE_AVAILABLE = True

    return _CHONKIE_MODULES
1549
+
1550
+
1551
def _get_chunker_mapping():
    """Return the chunker-name -> class mapping, importing chonkie on first use.

    An empty dictionary is returned when the mapping is unset or empty after
    the import attempt.
    """
    if _CHUNKER_MAPPING is None:
        _lazy_import_chonkie()
    mapping = _CHUNKER_MAPPING
    return mapping if mapping else {}
1556
+
1557
+
1558
def _is_chonkie_available():
    """Tell whether the chonkie import succeeded, attempting it on first use.

    Returns ``False`` when the availability flag is unset or falsy after the
    import attempt.
    """
    if _CHONKIE_AVAILABLE is None:
        _lazy_import_chonkie()
    available = _CHONKIE_AVAILABLE
    return available if available else False
1563
+
1564
+
1565
@beartype
class ChonkieChunker(Expression):
    """Expression that splits text into chunks with a chonkie chunker chosen by name.

    The chunker classes come from ``_get_chunker_mapping()``; how each one is
    constructed depends on its name (see ``_resolve_chunker``).
    """

    def __init__(
        self,
        tokenizer_name: str | None = "gpt2",
        embedding_model_name: str | None = "minishlab/potion-base-8M",
        **symai_kwargs,
    ):
        """Remember the tokenizer and embedding model names used to build chunkers.

        Args:
            tokenizer_name: Passed to ``Tokenizer.from_pretrained`` for the
                token/sentence/recursive chunkers and used as the default
                ``tokenizer`` keyword for the code/table/slumber chunkers.
            embedding_model_name: Passed as ``embedding_model`` to the
                semantic and late chunkers.
            **symai_kwargs: Forwarded to the parent initializer.
        """
        super().__init__(**symai_kwargs)
        self.tokenizer_name = tokenizer_name
        self.embedding_model_name = embedding_model_name

    def forward(
        self, data: Symbol, chunker_name: str | None = "RecursiveChunker", **chunker_kwargs
    ) -> Symbol:
        """Chunk ``data.value`` and return the cleaned chunk texts as a Symbol.

        Args:
            data: Symbol whose ``value`` is handed to the chunker.
            chunker_name: Key into the chunker mapping.
            **chunker_kwargs: Extra keyword arguments for the chunker constructor.

        Raises:
            ImportError: Requested via ``UserMessage`` when chonkie is unavailable.
            ValueError: From ``_resolve_chunker`` for an unknown chunker name.
        """
        if not _is_chonkie_available():
            UserMessage(
                "chonkie library is not installed. Please install it with `pip install chonkie tokenizers`.",
                raise_with=ImportError,
            )
        chunker = self._resolve_chunker(chunker_name, **chunker_kwargs)
        cleaned_chunks = []
        for chunk in chunker(data.value):
            cleaned_chunks.append(ChonkieChunker.clean_text(chunk.text))
        return self._to_symbol(cleaned_chunks)

    def _resolve_chunker(self, chunker_name: str, **chunker_kwargs):
        """Resolve and instantiate a chunker by name.

        Raises:
            ValueError: When ``chunker_name`` is not in the chunker mapping, or
                is in the mapping but has no construction rule here.
            ImportError: Requested via ``UserMessage`` when a tokenizer-based
                chunker is requested but no ``Tokenizer`` class was imported.
        """
        chunker_mapping = _get_chunker_mapping()

        if chunker_name not in chunker_mapping:
            raise ValueError(
                f"Chunker {chunker_name} not found. Available chunkers: {list(chunker_mapping.keys())}. "
                f"See docs (https://docs.chonkie.ai/getting-started/introduction) for more info."
            )

        chunker_class = chunker_mapping[chunker_name]
        Tokenizer = _lazy_import_chonkie().get("Tokenizer")

        # Chunkers built from a Tokenizer object loaded via ``tokenizer_name``.
        needs_tokenizer_object = ("TokenChunker", "SentenceChunker", "RecursiveChunker")
        # Chunkers built from ``embedding_model_name``.
        needs_embedding_model = ("SemanticChunker", "LateChunker")
        # Chunkers taking a ``tokenizer`` keyword; a caller-supplied value wins,
        # otherwise ``tokenizer_name`` is passed as-is.
        takes_tokenizer_kwarg = ("CodeChunker", "TableChunker", "SlumberChunker")

        if chunker_name in needs_tokenizer_object:
            if Tokenizer is None:
                UserMessage(
                    "Tokenizers library is not installed. Please install it with `pip install tokenizers`.",
                    raise_with=ImportError,
                )
            loaded_tokenizer = Tokenizer.from_pretrained(self.tokenizer_name)
            return chunker_class(loaded_tokenizer, **chunker_kwargs)

        if chunker_name in needs_embedding_model:
            return chunker_class(embedding_model=self.embedding_model_name, **chunker_kwargs)

        if chunker_name in takes_tokenizer_kwarg:
            chunker_kwargs.setdefault("tokenizer", self.tokenizer_name)
            return chunker_class(**chunker_kwargs)

        # NeuralChunker is built from the caller's keyword arguments only.
        if chunker_name == "NeuralChunker":
            return chunker_class(**chunker_kwargs)

        raise ValueError(
            f"Chunker {chunker_name} not properly configured. "
            f"Available chunkers: {list(chunker_mapping.keys())}."
        )

    @staticmethod
    def clean_text(text: str) -> str:
        """Clean text by removing null bytes and characters that cannot be UTF-8 encoded."""
        without_nulls = text.replace("\x00", "")  # Remove null bytes (\x00)
        # ``errors="ignore"`` drops unencodable characters rather than replacing them.
        return without_nulls.encode("utf-8", errors="ignore").decode("utf-8")