symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. symai/__init__.py +269 -173
  2. symai/backend/base.py +123 -110
  3. symai/backend/engines/drawing/engine_bfl.py +45 -44
  4. symai/backend/engines/drawing/engine_gpt_image.py +112 -97
  5. symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
  6. symai/backend/engines/embedding/engine_openai.py +25 -21
  7. symai/backend/engines/execute/engine_python.py +19 -18
  8. symai/backend/engines/files/engine_io.py +104 -95
  9. symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
  10. symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
  11. symai/backend/engines/index/engine_pinecone.py +124 -97
  12. symai/backend/engines/index/engine_qdrant.py +1011 -0
  13. symai/backend/engines/index/engine_vectordb.py +84 -56
  14. symai/backend/engines/lean/engine_lean4.py +96 -52
  15. symai/backend/engines/neurosymbolic/__init__.py +41 -13
  16. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
  17. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
  18. symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
  19. symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
  20. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
  21. symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
  22. symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
  23. symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
  24. symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
  25. symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
  26. symai/backend/engines/ocr/engine_apilayer.py +23 -27
  27. symai/backend/engines/output/engine_stdout.py +10 -13
  28. symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
  29. symai/backend/engines/search/engine_openai.py +100 -88
  30. symai/backend/engines/search/engine_parallel.py +665 -0
  31. symai/backend/engines/search/engine_perplexity.py +44 -45
  32. symai/backend/engines/search/engine_serpapi.py +37 -34
  33. symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
  34. symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
  35. symai/backend/engines/text_to_speech/engine_openai.py +20 -26
  36. symai/backend/engines/text_vision/engine_clip.py +39 -37
  37. symai/backend/engines/userinput/engine_console.py +5 -6
  38. symai/backend/mixin/__init__.py +13 -0
  39. symai/backend/mixin/anthropic.py +48 -38
  40. symai/backend/mixin/deepseek.py +6 -5
  41. symai/backend/mixin/google.py +7 -4
  42. symai/backend/mixin/groq.py +2 -4
  43. symai/backend/mixin/openai.py +140 -110
  44. symai/backend/settings.py +87 -20
  45. symai/chat.py +216 -123
  46. symai/collect/__init__.py +7 -1
  47. symai/collect/dynamic.py +80 -70
  48. symai/collect/pipeline.py +67 -51
  49. symai/collect/stats.py +161 -109
  50. symai/components.py +707 -360
  51. symai/constraints.py +24 -12
  52. symai/core.py +1857 -1233
  53. symai/core_ext.py +83 -80
  54. symai/endpoints/api.py +166 -104
  55. symai/extended/.DS_Store +0 -0
  56. symai/extended/__init__.py +46 -12
  57. symai/extended/api_builder.py +29 -21
  58. symai/extended/arxiv_pdf_parser.py +23 -14
  59. symai/extended/bibtex_parser.py +9 -6
  60. symai/extended/conversation.py +156 -126
  61. symai/extended/document.py +50 -30
  62. symai/extended/file_merger.py +57 -14
  63. symai/extended/graph.py +51 -32
  64. symai/extended/html_style_template.py +18 -14
  65. symai/extended/interfaces/blip_2.py +2 -3
  66. symai/extended/interfaces/clip.py +4 -3
  67. symai/extended/interfaces/console.py +9 -1
  68. symai/extended/interfaces/dall_e.py +4 -2
  69. symai/extended/interfaces/file.py +2 -0
  70. symai/extended/interfaces/flux.py +4 -2
  71. symai/extended/interfaces/gpt_image.py +16 -7
  72. symai/extended/interfaces/input.py +2 -1
  73. symai/extended/interfaces/llava.py +1 -2
  74. symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
  75. symai/extended/interfaces/naive_vectordb.py +9 -10
  76. symai/extended/interfaces/ocr.py +5 -3
  77. symai/extended/interfaces/openai_search.py +2 -0
  78. symai/extended/interfaces/parallel.py +30 -0
  79. symai/extended/interfaces/perplexity.py +2 -0
  80. symai/extended/interfaces/pinecone.py +12 -9
  81. symai/extended/interfaces/python.py +2 -0
  82. symai/extended/interfaces/serpapi.py +3 -1
  83. symai/extended/interfaces/terminal.py +2 -4
  84. symai/extended/interfaces/tts.py +3 -2
  85. symai/extended/interfaces/whisper.py +3 -2
  86. symai/extended/interfaces/wolframalpha.py +2 -1
  87. symai/extended/metrics/__init__.py +11 -1
  88. symai/extended/metrics/similarity.py +14 -13
  89. symai/extended/os_command.py +39 -29
  90. symai/extended/packages/__init__.py +29 -3
  91. symai/extended/packages/symdev.py +51 -43
  92. symai/extended/packages/sympkg.py +41 -35
  93. symai/extended/packages/symrun.py +63 -50
  94. symai/extended/repo_cloner.py +14 -12
  95. symai/extended/seo_query_optimizer.py +15 -13
  96. symai/extended/solver.py +116 -91
  97. symai/extended/summarizer.py +12 -10
  98. symai/extended/taypan_interpreter.py +17 -18
  99. symai/extended/vectordb.py +122 -92
  100. symai/formatter/__init__.py +9 -1
  101. symai/formatter/formatter.py +51 -47
  102. symai/formatter/regex.py +70 -69
  103. symai/functional.py +325 -176
  104. symai/imports.py +190 -147
  105. symai/interfaces.py +57 -28
  106. symai/memory.py +45 -35
  107. symai/menu/screen.py +28 -19
  108. symai/misc/console.py +66 -56
  109. symai/misc/loader.py +8 -5
  110. symai/models/__init__.py +17 -1
  111. symai/models/base.py +395 -236
  112. symai/models/errors.py +1 -2
  113. symai/ops/__init__.py +32 -22
  114. symai/ops/measures.py +24 -25
  115. symai/ops/primitives.py +1149 -731
  116. symai/post_processors.py +58 -50
  117. symai/pre_processors.py +86 -82
  118. symai/processor.py +21 -13
  119. symai/prompts.py +764 -685
  120. symai/server/huggingface_server.py +135 -49
  121. symai/server/llama_cpp_server.py +21 -11
  122. symai/server/qdrant_server.py +206 -0
  123. symai/shell.py +100 -42
  124. symai/shellsv.py +700 -492
  125. symai/strategy.py +630 -346
  126. symai/symbol.py +368 -322
  127. symai/utils.py +100 -78
  128. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
  129. symbolicai-1.1.0.dist-info/RECORD +168 -0
  130. symbolicai-0.21.0.dist-info/RECORD +0 -162
  131. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
  132. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
  133. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
  134. {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/components.py CHANGED
@@ -1,23 +1,23 @@
1
1
  import copy
2
2
  import inspect
3
3
  import json
4
- import os
5
4
  import re
6
5
  import sys
7
- from abc import abstractmethod
8
6
  from collections import defaultdict
7
+ from collections.abc import Callable, Iterator
9
8
  from pathlib import Path
10
9
  from random import sample
11
10
  from string import ascii_lowercase, ascii_uppercase
12
11
  from threading import Lock
13
- from .context import CURRENT_ENGINE_VAR
14
- from typing import Callable, Dict, Iterator, List, Optional, Type, Union
12
+ from typing import TYPE_CHECKING, Union
13
+
14
+ if TYPE_CHECKING:
15
+ from typing import Any
15
16
 
16
17
  import numpy as np
17
- from attr import dataclass
18
+ from beartype import beartype
18
19
  from box import Box
19
20
  from loguru import logger
20
- from pydantic import BaseModel, ValidationError
21
21
  from pyvis.network import Network
22
22
  from tqdm import tqdm
23
23
 
@@ -25,46 +25,64 @@ from . import core, core_ext
25
25
  from .backend.base import Engine
26
26
  from .backend.settings import HOME_PATH
27
27
  from .constraints import DictFormatConstraint
28
+ from .context import CURRENT_ENGINE_VAR
28
29
  from .formatter import ParagraphFormatter
29
- from .post_processors import (CodeExtractPostProcessor,
30
- JsonTruncateMarkdownPostProcessor,
31
- JsonTruncatePostProcessor, PostProcessor,
32
- StripPostProcessor)
30
+ from .post_processors import (
31
+ CodeExtractPostProcessor,
32
+ JsonTruncateMarkdownPostProcessor,
33
+ JsonTruncatePostProcessor,
34
+ PostProcessor,
35
+ StripPostProcessor,
36
+ )
33
37
  from .pre_processors import JsonPreProcessor, PreProcessor
34
38
  from .processor import ProcessorPipeline
35
39
  from .prompts import JsonPromptTemplate, Prompt
36
40
  from .symbol import Expression, Metadata, Symbol
37
- from .utils import CustomUserWarning
41
+ from .utils import UserMessage
42
+
43
+ if TYPE_CHECKING:
44
+ from .backend.engines.index.engine_vectordb import VectorDBResult
45
+
46
+ _DEFAULT_PARAGRAPH_FORMATTER = ParagraphFormatter()
38
47
 
39
48
 
40
49
  class GraphViz(Expression):
41
- def __init__(self,
42
- notebook = True,
43
- cdn_resources = "remote",
44
- bgcolor = "#222222",
45
- font_color = "white",
46
- height = "750px",
47
- width = "100%",
48
- select_menu = True,
49
- filter_menu = True,
50
- **kwargs):
50
+ def __init__(
51
+ self,
52
+ notebook=True,
53
+ cdn_resources="remote",
54
+ bgcolor="#222222",
55
+ font_color="white",
56
+ height="750px",
57
+ width="100%",
58
+ select_menu=True,
59
+ filter_menu=True,
60
+ **kwargs,
61
+ ):
51
62
  super().__init__(**kwargs)
52
- self.net = Network(notebook=notebook,
53
- cdn_resources=cdn_resources,
54
- bgcolor=bgcolor,
55
- font_color=font_color,
56
- height=height,
57
- width=width,
58
- select_menu=select_menu,
59
- filter_menu=filter_menu)
60
-
61
- def forward(self, sym: Symbol, file_path: str, **kwargs):
63
+ self.net = Network(
64
+ notebook=notebook,
65
+ cdn_resources=cdn_resources,
66
+ bgcolor=bgcolor,
67
+ font_color=font_color,
68
+ height=height,
69
+ width=width,
70
+ select_menu=select_menu,
71
+ filter_menu=filter_menu,
72
+ )
73
+
74
+ def forward(self, sym: Symbol, file_path: str, **_kwargs):
62
75
  nodes = [str(n) if n.value else n.__repr__(simplified=True) for n in sym.nodes]
63
- edges = [(str(e[0]) if e[0].value else e[0].__repr__(simplified=True),
64
- str(e[1]) if e[1].value else e[1].__repr__(simplified=True)) for e in sym.edges]
76
+ edges = [
77
+ (
78
+ str(e[0]) if e[0].value else e[0].__repr__(simplified=True),
79
+ str(e[1]) if e[1].value else e[1].__repr__(simplified=True),
80
+ )
81
+ for e in sym.edges
82
+ ]
65
83
  self.net.add_nodes(nodes)
66
84
  self.net.add_edges(edges)
67
- file_path = file_path if file_path.endswith('.html') else file_path + '.html'
85
+ file_path = file_path if file_path.endswith(".html") else file_path + ".html"
68
86
  return self.net.show(file_path)
69
87
 
70
88
 
@@ -73,21 +91,21 @@ class TrackerTraceable(Expression):
73
91
 
74
92
 
75
93
  class Any(Expression):
76
- def __init__(self, *expr: List[Expression], **kwargs):
94
+ def __init__(self, *expr: list[Expression], **kwargs):
77
95
  super().__init__(**kwargs)
78
- self.expr: List[Expression] = expr
96
+ self.expr: list[Expression] = expr
79
97
 
80
98
  def forward(self, *args, **kwargs) -> Symbol:
81
- return self.sym_return_type(any([e() for e in self.expr(*args, **kwargs)]))
99
+ return self.sym_return_type(any(e() for e in self.expr(*args, **kwargs)))
82
100
 
83
101
 
84
102
  class All(Expression):
85
- def __init__(self, *expr: List[Expression], **kwargs):
103
+ def __init__(self, *expr: list[Expression], **kwargs):
86
104
  super().__init__(**kwargs)
87
- self.expr: List[Expression] = expr
105
+ self.expr: list[Expression] = expr
88
106
 
89
107
  def forward(self, *args, **kwargs) -> Symbol:
90
- return self.sym_return_type(all([e() for e in self.expr(*args, **kwargs)]))
108
+ return self.sym_return_type(all(e() for e in self.expr(*args, **kwargs)))
91
109
 
92
110
 
93
111
  class Try(Expression):
@@ -104,12 +122,14 @@ class Try(Expression):
104
122
  class Lambda(Expression):
105
123
  def __init__(self, callable: Callable, **kwargs):
106
124
  super().__init__(**kwargs)
125
+
107
126
  def _callable(*args, **kwargs):
108
127
  kw = {
109
- 'args': args,
110
- 'kwargs': kwargs,
128
+ "args": args,
129
+ "kwargs": kwargs,
111
130
  }
112
131
  return callable(kw)
132
+
113
133
  self.callable: Callable = _callable
114
134
 
115
135
  def forward(self, *args, **kwargs) -> Symbol:
@@ -117,14 +137,14 @@ class Lambda(Expression):
117
137
 
118
138
 
119
139
  class Choice(Expression):
120
- def __init__(self, cases: List[str], default: Optional[str] = None, **kwargs):
140
+ def __init__(self, cases: list[str], default: str | None = None, **kwargs):
121
141
  super().__init__(**kwargs)
122
- self.cases: List[str] = cases
123
- self.default: Optional[str] = default
142
+ self.cases: list[str] = cases
143
+ self.default: str | None = default
124
144
 
125
145
  def forward(self, sym: Symbol, *args, **kwargs) -> Symbol:
126
146
  sym = self._to_symbol(sym)
127
- return sym.choice(cases=self.cases, default=self.default, *args, **kwargs)
147
+ return sym.choice(*args, cases=self.cases, default=self.default, **kwargs)
128
148
 
129
149
 
130
150
  class Output(Expression):
@@ -135,15 +155,15 @@ class Output(Expression):
135
155
  self.verbose: bool = verbose
136
156
 
137
157
  def forward(self, *args, **kwargs) -> Expression:
138
- kwargs['verbose'] = self.verbose
139
- kwargs['handler'] = self.handler
140
- return self.output(expr=self.expr, *args, **kwargs)
158
+ kwargs["verbose"] = self.verbose
159
+ kwargs["handler"] = self.handler
160
+ return self.output(*args, expr=self.expr, **kwargs)
141
161
 
142
162
 
143
163
  class Sequence(TrackerTraceable):
144
- def __init__(self, *expressions: List[Expression], **kwargs):
164
+ def __init__(self, *expressions: list[Expression], **kwargs):
145
165
  super().__init__(**kwargs)
146
- self.expressions: List[Expression] = expressions
166
+ self.expressions: list[Expression] = expressions
147
167
 
148
168
  def forward(self, *args, **kwargs) -> Symbol:
149
169
  sym = self.expressions[0](*args, **kwargs)
@@ -159,34 +179,36 @@ class Sequence(TrackerTraceable):
159
179
 
160
180
 
161
181
  class Parallel(Expression):
162
- def __init__(self, *expr: List[Expression | Callable], sequential: bool = False, **kwargs):
182
+ def __init__(self, *expr: list[Expression | Callable], sequential: bool = False, **kwargs):
163
183
  super().__init__(**kwargs)
164
- self.sequential: bool = sequential
165
- self.expr: List[Expression] = expr
166
- self.results: List[Symbol] = []
184
+ self.sequential: bool = sequential
185
+ self.expr: list[Expression] = expr
186
+ self.results: list[Symbol] = []
167
187
 
168
188
  def forward(self, *args, **kwargs) -> Symbol:
169
189
  # run in sequence
170
190
  if self.sequential:
171
191
  return [e(*args, **kwargs) for e in self.expr]
192
+
172
193
  # run in parallel
173
194
  @core_ext.parallel(self.expr)
174
195
  def _func(e, *args, **kwargs):
175
196
  return e(*args, **kwargs)
197
+
176
198
  self.results = _func(*args, **kwargs)
177
199
  # final result of the parallel execution
178
200
  return self._to_symbol(self.results)
179
201
 
180
202
 
181
- #@TODO: BinPacker(format="...") -> ensure that data packages form a "bin" that's consistent (e.g. never break a sentence in the middle)
203
+ # @TODO: BinPacker(format="...") -> ensure that data packages form a "bin" that's consistent (e.g. never break a sentence in the middle)
182
204
  class Stream(Expression):
183
- def __init__(self, expr: Optional[Expression] = None, retrieval: Optional[str] = None, **kwargs):
205
+ def __init__(self, expr: Expression | None = None, retrieval: str | None = None, **kwargs):
184
206
  super().__init__(**kwargs)
185
- self.char_token_ratio: float = 0.6
186
- self.expr: Optional[Expression] = expr
187
- self.retrieval: Optional[str] = retrieval
188
- self._trace: bool = False
189
- self._previous_frame = None
207
+ self.char_token_ratio: float = 0.6
208
+ self.expr: Expression | None = expr
209
+ self.retrieval: str | None = retrieval
210
+ self._trace: bool = False
211
+ self._previous_frame = None
190
212
 
191
213
  def forward(self, sym: Symbol, **kwargs) -> Iterator:
192
214
  sym = self._to_symbol(sym)
@@ -194,30 +216,31 @@ class Stream(Expression):
194
216
  if self._trace:
195
217
  local_vars = self._previous_frame.f_locals
196
218
  vals = []
197
- for key, var in local_vars.items():
219
+ for _key, var in local_vars.items():
198
220
  if isinstance(var, TrackerTraceable):
199
221
  vals.append(var)
200
222
 
201
223
  if len(vals) == 1:
202
224
  self.expr = vals[0]
203
225
  else:
204
- raise ValueError(f"This component does either not inherit from TrackerTraceable or has an invalid number of component declarations: {len(vals)}! Only one component that inherits from TrackerTraceable is allowed in the with stream clause.")
205
-
206
- res = sym.stream(expr=self.expr,
207
- char_token_ratio=self.char_token_ratio,
208
- **kwargs)
209
-
226
+ UserMessage(
227
+ "This component does either not inherit from TrackerTraceable or has an invalid number of component "
228
+ f"declarations: {len(vals)}! Only one component that inherits from TrackerTraceable is allowed in the "
229
+ "with stream clause.",
230
+ raise_with=ValueError,
231
+ )
232
+
233
+ res = sym.stream(expr=self.expr, char_token_ratio=self.char_token_ratio, **kwargs)
210
234
  if self.retrieval is not None:
211
235
  res = list(res)
212
- if self.retrieval == 'all':
236
+ if self.retrieval == "all":
213
237
  return res
214
- if self.retrieval == 'longest':
238
+ if self.retrieval == "longest":
215
239
  res = sorted(res, key=lambda x: len(x), reverse=True)
216
240
  return res[0]
217
- if self.retrieval == 'contains':
218
- res = [r for r in res if self.expr in r]
219
- return res
220
- raise ValueError(f"Invalid retrieval method: {self.retrieval}")
241
+ if self.retrieval == "contains":
242
+ return [r for r in res if self.expr in r]
243
+ UserMessage(f"Invalid retrieval method: {self.retrieval}", raise_with=ValueError)
221
244
 
222
245
  return res
223
246
 
@@ -231,10 +254,12 @@ class Stream(Expression):
231
254
 
232
255
 
233
256
  class Trace(Expression):
234
- def __init__(self, expr: Optional[Expression] = None, engines=['all'], **kwargs):
257
+ def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
258
+ if engines is None:
259
+ engines = ["all"]
235
260
  super().__init__(**kwargs)
236
261
  self.expr: Expression = expr
237
- self.engines: List[str] = engines
262
+ self.engines: list[str] = engines
238
263
 
239
264
  def forward(self, *args, **kwargs) -> Expression:
240
265
  Expression.command(verbose=True, engines=self.engines)
@@ -252,23 +277,26 @@ class Trace(Expression):
252
277
  Expression.command(verbose=False, engines=self.engines)
253
278
  if self.expr is not None:
254
279
  return self.expr.__exit__(type, value, traceback)
280
+ return None
255
281
 
256
282
 
257
283
  class Analyze(Expression):
258
- def __init__(self, exception: Exception, query: Optional[str] = None, **kwargs):
284
+ def __init__(self, exception: Exception, query: str | None = None, **kwargs):
259
285
  super().__init__(**kwargs)
260
286
  self.exception: Expression = exception
261
- self.query: Optional[str] = query
287
+ self.query: str | None = query
262
288
 
263
289
  def forward(self, sym: Symbol, *args, **kwargs) -> Symbol:
264
- return sym.analyze(exception=self.exception, query=self.query, *args, **kwargs)
290
+ return sym.analyze(*args, exception=self.exception, query=self.query, **kwargs)
265
291
 
266
292
 
267
293
  class Log(Expression):
268
- def __init__(self, expr: Optional[Expression] = None, engines=['all'], **kwargs):
294
+ def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
295
+ if engines is None:
296
+ engines = ["all"]
269
297
  super().__init__(**kwargs)
270
298
  self.expr: Expression = expr
271
- self.engines: List[str] = engines
299
+ self.engines: list[str] = engines
272
300
 
273
301
  def forward(self, *args, **kwargs) -> Expression:
274
302
  Expression.command(logging=True, engines=self.engines)
@@ -286,10 +314,16 @@ class Log(Expression):
286
314
  Expression.command(logging=False, engines=self.engines)
287
315
  if self.expr is not None:
288
316
  return self.expr.__exit__(type, value, traceback)
317
+ return None
289
318
 
290
319
 
291
320
  class Template(Expression):
292
- def __init__(self, template: str = "<html><body>{{placeholder}}</body></html>", placeholder: str = '{{placeholder}}', **kwargs):
321
+ def __init__(
322
+ self,
323
+ template: str = "<html><body>{{placeholder}}</body></html>",
324
+ placeholder: str = "{{placeholder}}",
325
+ **kwargs,
326
+ ):
293
327
  super().__init__(**kwargs)
294
328
  self.placeholder = placeholder
295
329
  self.template_ = template
@@ -319,22 +353,26 @@ class RuntimeExpression(Expression):
319
353
  code = self._to_symbol(code)
320
354
  # declare the runtime expression from the code
321
355
  expr = self.runner(code)
356
+
322
357
  def _func(sym):
323
358
  # execute nested expression
324
- return expr['locals']['_output_'](sym)
359
+ return expr["locals"]["_output_"](sym)
360
+
325
361
  return _func
326
362
 
327
363
 
328
364
  class Metric(Expression):
329
365
  def __init__(self, normalize: bool = False, eps: float = 1e-8, **kwargs):
330
366
  super().__init__(**kwargs)
331
- self.normalize = normalize
332
- self.eps = eps
367
+ self.normalize = normalize
368
+ self.eps = eps
333
369
 
334
- def forward(self, sym: Symbol, **kwargs) -> Symbol:
370
+ def forward(self, sym: Symbol, **_kwargs) -> Symbol:
335
371
  sym = self._to_symbol(sym)
336
- assert sym.value_type == np.ndarray or sym.value_type == list, 'Metric can only be applied to numpy arrays or lists.'
337
- if sym.value_type == list:
372
+ assert sym.value_type is np.ndarray or sym.value_type is list, (
373
+ "Metric can only be applied to numpy arrays or lists."
374
+ )
375
+ if sym.value_type is list:
338
376
  sym._value = np.array(sym.value)
339
377
  # compute normalization between 0 and 1
340
378
  if self.normalize:
@@ -343,17 +381,19 @@ class Metric(Expression):
343
381
  elif len(sym.value.shape) == 2:
344
382
  pass
345
383
  else:
346
- raise ValueError(f'Invalid shape: {sym.value.shape}')
384
+ UserMessage(f"Invalid shape: {sym.value.shape}", raise_with=ValueError)
347
385
  # normalize between 0 and 1 and sum to 1
348
386
  sym._value = np.exp(sym.value) / (np.exp(sym.value).sum() + self.eps)
349
387
  return sym
350
388
 
351
389
 
352
390
  class Style(Expression):
353
- def __init__(self, description: str, libraries: List[str] = [], **kwargs):
391
+ def __init__(self, description: str, libraries: list[str] | None = None, **kwargs):
392
+ if libraries is None:
393
+ libraries = []
354
394
  super().__init__(**kwargs)
355
395
  self.description: str = description
356
- self.libraries: List[str] = libraries
396
+ self.libraries: list[str] = libraries
357
397
 
358
398
  def forward(self, sym: Symbol, **kwargs) -> Symbol:
359
399
  sym = self._to_symbol(sym)
@@ -365,7 +405,7 @@ class Query(TrackerTraceable):
365
405
  super().__init__(**kwargs)
366
406
  self.prompt: str = prompt
367
407
 
368
- def forward(self, sym: Symbol, context: Symbol = None, *args, **kwargs) -> Symbol:
408
+ def forward(self, sym: Symbol, context: Symbol = None, *_args, **kwargs) -> Symbol:
369
409
  sym = self._to_symbol(sym)
370
410
  return sym.query(prompt=self.prompt, context=context, **kwargs)
371
411
 
@@ -397,16 +437,16 @@ _output_ = _func()
397
437
 
398
438
  def forward(self, sym: Symbol, enclosure: bool = False, **kwargs) -> Symbol:
399
439
  if enclosure or self.enclosure:
400
- lines = str(sym).split('\n')
401
- lines = [' ' + line for line in lines]
402
- sym = '\n'.join(lines)
403
- sym = self.template.replace('{sym}', str(sym))
440
+ lines = str(sym).split("\n")
441
+ lines = [" " + line for line in lines]
442
+ sym = "\n".join(lines)
443
+ sym = self.template.replace("{sym}", str(sym))
404
444
  sym = self._to_symbol(sym)
405
445
  return sym.execute(**kwargs)
406
446
 
407
447
 
408
448
  class Convert(Expression):
409
- def __init__(self, format: str = 'Python', **kwargs):
449
+ def __init__(self, format: str = "Python", **kwargs):
410
450
  super().__init__(**kwargs)
411
451
  self.format = format
412
452
 
@@ -440,13 +480,13 @@ class Map(Expression):
440
480
 
441
481
 
442
482
  class Translate(Expression):
443
- def __init__(self, language: str = 'English', **kwargs):
483
+ def __init__(self, language: str = "English", **kwargs):
444
484
  super().__init__(**kwargs)
445
485
  self.language = language
446
486
 
447
487
  def forward(self, sym: Symbol, **kwargs) -> Symbol:
448
488
  sym = self._to_symbol(sym)
449
- if sym.isinstanceof(f'{self.language} text'):
489
+ if sym.isinstanceof(f"{self.language} text"):
450
490
  return sym
451
491
  return sym.translate(language=self.language, **kwargs)
452
492
 
@@ -474,11 +514,11 @@ class ExcludeFilter(Expression):
474
514
  class FileWriter(Expression):
475
515
  def __init__(self, path: str, **kwargs):
476
516
  super().__init__(**kwargs)
477
- self.path = path
517
+ self.path = Path(path)
478
518
 
479
- def forward(self, sym: Symbol, **kwargs) -> Symbol:
519
+ def forward(self, sym: Symbol, **_kwargs) -> Symbol:
480
520
  sym = self._to_symbol(sym)
481
- with open(self.path, 'w') as f:
521
+ with self.path.open("w") as f:
482
522
  f.write(str(sym))
483
523
 
484
524
 
@@ -486,20 +526,18 @@ class FileReader(Expression):
486
526
  @staticmethod
487
527
  def exists(path: str) -> bool:
488
528
  # remove slicing if any
489
- _tmp = path
490
- _splits = _tmp.split('[')
491
- if '[' in _tmp:
529
+ _tmp = path
530
+ _splits = _tmp.split("[")
531
+ if "[" in _tmp:
492
532
  _tmp = _splits[0]
493
- assert len(_splits) == 1 or len(_splits) == 2, 'Invalid file link format.'
494
- _tmp = Path(_tmp)
533
+ assert len(_splits) == 1 or len(_splits) == 2, "Invalid file link format."
534
+ _tmp = Path(_tmp)
495
535
  # check if file exists and is a file
496
- if os.path.exists(_tmp) and os.path.isfile(_tmp):
497
- return True
498
- return False
536
+ return _tmp.is_file()
499
537
 
500
538
  @staticmethod
501
- def get_files(folder_path: str, max_depth: int = 1) -> List[str]:
502
- accepted_formats = ['.pdf', '.md', '.txt']
539
+ def get_files(folder_path: str, max_depth: int = 1) -> list[str]:
540
+ accepted_formats = [".pdf", ".md", ".txt"]
503
541
 
504
542
  folder = Path(folder_path)
505
543
  files = []
@@ -512,10 +550,35 @@ class FileReader(Expression):
512
550
  return files
513
551
 
514
552
  @staticmethod
515
- def extract_files(cmds: str) -> Optional[List[str]]:
516
- # Use the updated regular expression to match quoted and non-quoted paths
517
- pattern = r'''(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))'''
518
- # Use the regular expression to split and handle quoted and non-quoted paths
553
+ def extract_files(cmds: str) -> list[str] | None:
554
+ """
555
+ Extract file paths from a command string, handling various quoting styles.
556
+
557
+ This method is used by the Qdrant RAG implementation when processing document paths.
558
+ It uses regex to parse file paths that may be quoted in different ways.
559
+
560
+ Regex patterns used:
561
+ 1. Main pattern: Matches file paths in four formats:
562
+ - Double-quoted: "path/to/file" (handles escaped characters)
563
+ - Single-quoted: 'path/to/file' (handles escaped characters)
564
+ - Backtick-quoted: `path/to/file` (handles escaped characters)
565
+ - Non-quoted: path/to/file (handles escaped spaces)
566
+
567
+ 2. Escape removal pattern: r"\\(.)" -> r"\1"
568
+ - Removes backslash escape sequences from quoted paths
569
+ - Example: "path\\/to\\/file" -> "path/to/file"
570
+ - Used for double quotes, single quotes, and backticks
571
+ """
572
+ # Regex pattern to match file paths in various quoting styles
573
+ # Pattern breakdown:
574
+ # - (?:"((?:\\.|[^"\\])*)") : Matches double-quoted paths, capturing content while handling escapes
575
+ # - '((?:\\.|[^'\\])*)' : Matches single-quoted paths, capturing content while handling escapes
576
+ # - `((?:\\.|[^`\\])*)` : Matches backtick-quoted paths, capturing content while handling escapes
577
+ # - ((?:\\ |[^ ])+) : Matches non-quoted paths, allowing escaped spaces
578
+ pattern = (
579
+ r"""(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))"""
580
+ )
581
+ # Use regex to find all file path matches in the command string
519
582
  matches = re.findall(pattern, cmds)
520
583
  # Process the matches to handle quoted paths and normal paths
521
584
  files = []
@@ -523,23 +586,27 @@ class FileReader(Expression):
523
586
  # Each match will have 4 groups due to the pattern; only one will be non-empty
524
587
  quoted_double, quoted_single, quoted_backtick, non_quoted = match
525
588
  if quoted_double:
526
- # Remove backslashes used for escaping inside double quotes
527
- path = re.sub(r'\\(.)', r'\1', quoted_double)
589
+ # Regex substitution: Remove backslashes used for escaping inside double quotes
590
+ # Pattern r"\\(.)" matches a backslash followed by any character and replaces with just the character
591
+ # Example: "path\\/to\\/file" -> "path/to/file"
592
+ path = re.sub(r"\\(.)", r"\1", quoted_double)
528
593
  file = FileReader.expand_user_path(path)
529
594
  files.append(file)
530
595
  elif quoted_single:
531
- # Remove backslashes used for escaping inside single quotes
532
- path = re.sub(r'\\(.)', r'\1', quoted_single)
596
+ # Regex substitution: Remove backslashes used for escaping inside single quotes
597
+ # Same pattern as above, applied to single-quoted paths
598
+ path = re.sub(r"\\(.)", r"\1", quoted_single)
533
599
  file = FileReader.expand_user_path(path)
534
600
  files.append(file)
535
601
  elif quoted_backtick:
536
- # Remove backslashes used for escaping inside backticks
537
- path = re.sub(r'\\(.)', r'\1', quoted_backtick)
602
+ # Regex substitution: Remove backslashes used for escaping inside backticks
603
+ # Same pattern as above, applied to backtick-quoted paths
604
+ path = re.sub(r"\\(.)", r"\1", quoted_backtick)
538
605
  file = FileReader.expand_user_path(path)
539
606
  files.append(file)
540
607
  elif non_quoted:
541
- # Replace escaped spaces with actual spaces
542
- path = non_quoted.replace('\\ ', ' ')
608
+ # Replace escaped spaces with actual spaces (no regex needed here, simple string replace)
609
+ path = non_quoted.replace("\\ ", " ")
543
610
  file = FileReader.expand_user_path(path)
544
611
  files.append(file)
545
612
  # Filter out any files that do not exist
@@ -551,31 +618,34 @@ class FileReader(Expression):
551
618
  return Path(path).expanduser().resolve().as_posix()
552
619
 
553
620
  @staticmethod
554
- def integrity_check(files: List[str]) -> List[str]:
621
+ def integrity_check(files: list[str]) -> list[str]:
555
622
  not_skipped = []
556
623
  for file in tqdm(files):
557
624
  if FileReader.exists(file):
558
625
  not_skipped.append(file)
559
626
  else:
560
- CustomUserWarning(f'Skipping file: {file}')
627
+ UserMessage(f"Skipping file: {file}")
561
628
  return not_skipped
562
629
 
563
- def forward(self, files: Union[str, List[str]], **kwargs) -> Expression:
630
+ def forward(self, files: str | list[str], **kwargs) -> Expression:
564
631
  if isinstance(files, str):
565
632
  # Convert to list for uniform processing; more easily downstream
566
633
  files = [files]
567
- if kwargs.get('run_integrity_check'):
634
+ if kwargs.get("run_integrity_check"):
568
635
  files = self.integrity_check(files)
569
636
  return self.sym_return_type([self.open(f, **kwargs).value for f in files])
570
637
 
638
+
571
639
  class FileQuery(Expression):
572
640
  def __init__(self, path: str, filter: str, **kwargs):
573
641
  super().__init__(**kwargs)
574
642
  self.path = path
575
643
  file_open = FileReader()
576
- self.query_stream = Stream(Sequence(
577
- IncludeFilter(filter),
578
- ))
644
+ self.query_stream = Stream(
645
+ Sequence(
646
+ IncludeFilter(filter),
647
+ )
648
+ )
579
649
  self.file = file_open(path)
580
650
 
581
651
  def forward(self, sym: Symbol, **kwargs) -> Symbol:
@@ -585,40 +655,45 @@ class FileQuery(Expression):
585
655
 
586
656
 
587
657
  class Function(TrackerTraceable):
588
- def __init__(self, prompt: str = '',
589
- examples: Optional[str] = [],
590
- pre_processors: Optional[List[PreProcessor]] = None,
591
- post_processors: Optional[List[PostProcessor]] = None,
592
- default: Optional[object] = None,
593
- constraints: List[Callable] = [],
594
- return_type: Optional[Type] = str,
595
- sym_return_type: Optional[Type] = Symbol,
596
- origin_type: Optional[Type] = Expression,
597
- *args, **kwargs):
658
+ def __init__(
659
+ self,
660
+ prompt: str = "",
661
+ examples: str | None = [],
662
+ pre_processors: list[PreProcessor] | None = None,
663
+ post_processors: list[PostProcessor] | None = None,
664
+ default: object | None = None,
665
+ constraints: list[Callable] | None = None,
666
+ return_type: type | None = str,
667
+ sym_return_type: type | None = Symbol,
668
+ origin_type: type | None = Expression,
669
+ *args,
670
+ **kwargs,
671
+ ):
672
+ if constraints is None:
673
+ constraints = []
598
674
  super().__init__(**kwargs)
599
- chars = ascii_lowercase + ascii_uppercase
600
- self.name = 'func_' + ''.join(sample(chars, 15))
601
- self.args = args
675
+ chars = ascii_lowercase + ascii_uppercase
676
+ self.name = "func_" + "".join(sample(chars, 15))
677
+ self.args = args
602
678
  self.kwargs = kwargs
603
- self._promptTemplate = prompt
604
- self._promptFormatArgs = []
679
+ self._promptTemplate = prompt
680
+ self._promptFormatArgs = []
605
681
  self._promptFormatKwargs = {}
606
- self.examples = Prompt(examples)
607
- self.pre_processors = pre_processors
682
+ self.examples = Prompt(examples)
683
+ self.pre_processors = pre_processors
608
684
  self.post_processors = post_processors
609
- self.constraints = constraints
610
- self.default = default
611
- self.return_type = return_type
685
+ self.constraints = constraints
686
+ self.default = default
687
+ self.return_type = return_type
612
688
  self.sym_return_type = sym_return_type
613
- self.origin_type = origin_type
689
+ self.origin_type = origin_type
614
690
 
615
691
  @property
616
692
  def prompt(self):
617
693
  # return a copy of the prompt template
618
694
  if len(self._promptFormatArgs) == 0 and len(self._promptFormatKwargs) == 0:
619
695
  return self._promptTemplate
620
- return f"{self._promptTemplate}".format(*self._promptFormatArgs,
621
- **self._promptFormatKwargs)
696
+ return f"{self._promptTemplate}".format(*self._promptFormatArgs, **self._promptFormatKwargs)
622
697
 
623
698
  def format(self, *args, **kwargs):
624
699
  self._promptFormatArgs = args
@@ -626,27 +701,36 @@ class Function(TrackerTraceable):
626
701
 
627
702
  def forward(self, *args, **kwargs) -> Expression:
628
703
  # special case for few shot function prompt definition override
629
- if 'fn' in kwargs:
630
- self.prompt = kwargs['fn']
631
- del kwargs['fn']
632
- @core.few_shot(prompt=self.prompt,
633
- examples=self.examples,
634
- pre_processors=self.pre_processors,
635
- post_processors=self.post_processors,
636
- constraints=self.constraints,
637
- default=self.default,
638
- *self.args, **self.kwargs)
704
+ if "fn" in kwargs:
705
+ self.prompt = kwargs["fn"]
706
+ del kwargs["fn"]
707
+
708
+ @core.few_shot(
709
+ *self.args,
710
+ prompt=self.prompt,
711
+ examples=self.examples,
712
+ pre_processors=self.pre_processors,
713
+ post_processors=self.post_processors,
714
+ constraints=self.constraints,
715
+ default=self.default,
716
+ **self.kwargs,
717
+ )
639
718
  def _func(_, *args, **kwargs) -> self.return_type:
640
719
  pass
641
- _type = type(self.name, (self.origin_type, ), {
642
- # constructor
643
- "forward": _func,
644
- "sym_return_type": self.sym_return_type,
645
- "static_context": self.static_context,
646
- "dynamic_context": self.dynamic_context,
647
- "__class__": self.__class__,
648
- "__module__": self.__module__,
649
- })
720
+
721
+ _type = type(
722
+ self.name,
723
+ (self.origin_type,),
724
+ {
725
+ # constructor
726
+ "forward": _func,
727
+ "sym_return_type": self.sym_return_type,
728
+ "static_context": self.static_context,
729
+ "dynamic_context": self.dynamic_context,
730
+ "__class__": self.__class__,
731
+ "__module__": self.__module__,
732
+ },
733
+ )
650
734
  obj = _type()
651
735
 
652
736
  return self._to_symbol(obj(*args, **kwargs))
@@ -657,19 +741,19 @@ class PrepareData(Function):
657
741
  def __call__(self, argument):
658
742
  assert argument.prop.context is not None
659
743
  instruct = argument.prop.prompt
660
- context = argument.prop.context
661
- return """{
662
- 'context': '%s',
663
- 'instruction': '%s',
744
+ context = argument.prop.context
745
+ return f"""{{
746
+ 'context': '{context}',
747
+ 'instruction': '{instruct}',
664
748
  'result': 'TODO: Replace this with the expected result.'
665
- }""" % (context, instruct)
749
+ }}"""
666
750
 
667
751
  def __init__(self, *args, **kwargs):
668
752
  super().__init__(*args, **kwargs)
669
- self.pre_processors = [self.PrepareDataPreProcessor()]
670
- self.constraints = [DictFormatConstraint({ 'result': '<the data>' })]
753
+ self.pre_processors = [self.PrepareDataPreProcessor()]
754
+ self.constraints = [DictFormatConstraint({"result": "<the data>"})]
671
755
  self.post_processors = [JsonTruncateMarkdownPostProcessor()]
672
- self.return_type = dict # constraint to cast the result to a dict
756
+ self.return_type = dict # constraint to cast the result to a dict
673
757
 
674
758
  @property
675
759
  def static_context(self):
@@ -704,10 +788,10 @@ Your goal is to prepare the data for the next task instruction. The data should
704
788
 
705
789
  class ExpressionBuilder(Function):
706
790
  def __init__(self, **kwargs):
707
- super().__init__('Generate the code following the instructions:', **kwargs)
791
+ super().__init__("Generate the code following the instructions:", **kwargs)
708
792
  self.processors = ProcessorPipeline([StripPostProcessor(), CodeExtractPostProcessor()])
709
793
 
710
- def forward(self, instruct, *args, **kwargs):
794
+ def forward(self, instruct, *_args, **_kwargs):
711
795
  result = super().forward(instruct)
712
796
  return self.processors(str(result), None)
713
797
 
@@ -755,10 +839,12 @@ Always produce the entire code to be executed in the same Python process. All ta
755
839
  class JsonParser(Expression):
756
840
  def __init__(self, query: str, json_: dict, **kwargs):
757
841
  super().__init__(**kwargs)
758
- func = Function(prompt=JsonPromptTemplate(query, json_),
759
- constraints=[DictFormatConstraint(json_)],
760
- pre_processors=[JsonPreProcessor()],
761
- post_processors=[JsonTruncatePostProcessor()])
842
+ func = Function(
843
+ prompt=JsonPromptTemplate(query, json_),
844
+ constraints=[DictFormatConstraint(json_)],
845
+ pre_processors=[JsonPreProcessor()],
846
+ post_processors=[JsonTruncatePostProcessor()],
847
+ )
762
848
  self.fn = Try(func, retries=1)
763
849
 
764
850
  def forward(self, sym: Symbol, **kwargs) -> Symbol:
@@ -768,21 +854,27 @@ class JsonParser(Expression):
768
854
 
769
855
 
770
856
  class SimilarityClassification(Expression):
771
- def __init__(self, classes: List[str], metric: str = 'cosine', in_memory: bool = False, **kwargs):
857
+ def __init__(
858
+ self, classes: list[str], metric: str = "cosine", in_memory: bool = False, **kwargs
859
+ ):
772
860
  super().__init__(**kwargs)
773
- self.classes = classes
774
- self.metric = metric
861
+ self.classes = classes
862
+ self.metric = metric
775
863
  self.in_memory = in_memory
776
864
 
777
865
  if self.in_memory:
778
- CustomUserWarning(f'Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache')
866
+ UserMessage(
867
+ f"Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache"
868
+ )
779
869
 
780
870
  def forward(self, x: Symbol) -> Symbol:
781
- x = self._to_symbol(x)
782
- usr_embed = x.embed()
783
- embeddings = self._dynamic_cache()
871
+ x = self._to_symbol(x)
872
+ usr_embed = x.embed()
873
+ embeddings = self._dynamic_cache()
784
874
  similarities = [usr_embed.similarity(emb, metric=self.metric) for emb in embeddings]
785
- similarities = sorted(zip(self.classes, similarities), key=lambda x: x[1], reverse=True)
875
+ similarities = sorted(
876
+ zip(self.classes, similarities, strict=False), key=lambda x: x[1], reverse=True
877
+ )
786
878
 
787
879
  return Symbol(similarities[0][0])
788
880
 
@@ -790,9 +882,7 @@ class SimilarityClassification(Expression):
790
882
  @core_ext.cache(in_memory=self.in_memory)
791
883
  def embed_classes(self):
792
884
  opts = map(Symbol, self.classes)
793
- embeddings = [opt.embed() for opt in opts]
794
-
795
- return embeddings
885
+ return [opt.embed() for opt in opts]
796
886
 
797
887
  return embed_classes(self)
798
888
 
@@ -803,11 +893,7 @@ class InContextClassification(Expression):
803
893
  self.blueprint = blueprint
804
894
 
805
895
  def forward(self, x: Symbol, **kwargs) -> Symbol:
806
- @core.few_shot(
807
- prompt=x,
808
- examples=self.blueprint,
809
- **kwargs
810
- )
896
+ @core.few_shot(prompt=x, examples=self.blueprint, **kwargs)
811
897
  def _func(_):
812
898
  pass
813
899
 
@@ -815,43 +901,38 @@ class InContextClassification(Expression):
815
901
 
816
902
 
817
903
  class Indexer(Expression):
818
- DEFAULT = 'dataindex'
904
+ DEFAULT = "dataindex"
819
905
 
820
906
  @staticmethod
821
907
  def replace_special_chars(index: str):
822
908
  # replace special characters that are not for path
823
- index = str(index)
824
- index = index.replace('-', '')
825
- index = index.replace('_', '')
826
- index = index.replace(' ', '')
827
- index = index.lower()
828
- return index
909
+ return str(index).replace("-", "").replace("_", "").replace(" ", "").lower()
829
910
 
830
911
  def __init__(
831
- self,
832
- index_name: str = DEFAULT,
833
- top_k: int = 8,
834
- batch_size: int = 20,
835
- formatter: Callable = ParagraphFormatter(),
836
- auto_add=False,
837
- raw_result: bool = False,
838
- new_dim: int = 1536,
839
- **kwargs
840
- ):
912
+ self,
913
+ index_name: str = DEFAULT,
914
+ top_k: int = 8,
915
+ batch_size: int = 20,
916
+ formatter: Callable = _DEFAULT_PARAGRAPH_FORMATTER,
917
+ auto_add=False,
918
+ raw_result: bool = False,
919
+ new_dim: int = 1536,
920
+ **kwargs,
921
+ ):
841
922
  super().__init__(**kwargs)
842
923
  index_name = Indexer.replace_special_chars(index_name)
843
924
  self.index_name = index_name
844
- self.elements = []
925
+ self.elements = []
845
926
  self.batch_size = batch_size
846
- self.top_k = top_k
847
- self.retrieval = None
848
- self.formatter = formatter
927
+ self.top_k = top_k
928
+ self.retrieval = None
929
+ self.formatter = formatter
849
930
  self.raw_result = raw_result
850
- self.new_dim = new_dim
931
+ self.new_dim = new_dim
851
932
  self.sym_return_type = Expression
852
933
 
853
934
  # append index name to indices.txt in home directory .symai folder (default)
854
- self.path = HOME_PATH / 'indices.txt'
935
+ self.path = HOME_PATH / "indices.txt"
855
936
  if not self.path.exists():
856
937
  self.path.parent.mkdir(parents=True, exist_ok=True)
857
938
  self.path.touch()
@@ -861,52 +942,63 @@ class Indexer(Expression):
861
942
  def register(self):
862
943
  # check if index already exists in indices.txt and append if not
863
944
  change = False
864
- with open(self.path, 'r') as f:
865
- indices = f.read().split('\n')
945
+ with self.path.open() as f:
946
+ indices = f.read().split("\n")
866
947
  # filter out empty strings
867
948
  indices = [i for i in indices if i]
868
- if self.index_name not in indices:
869
- indices.append(self.index_name)
870
- change = True
949
+ if self.index_name not in indices:
950
+ indices.append(self.index_name)
951
+ change = True
871
952
  if change:
872
- with open(self.path, 'w') as f:
873
- f.write('\n'.join(indices))
953
+ with self.path.open("w") as f:
954
+ f.write("\n".join(indices))
874
955
 
875
956
  def exists(self) -> bool:
876
957
  # check if index exists in home directory .symai folder (default) indices.txt
877
- path = HOME_PATH / 'indices.txt'
958
+ path = HOME_PATH / "indices.txt"
878
959
  if not path.exists():
879
960
  return False
880
- with open(path, 'r') as f:
881
- indices = f.read().split('\n')
961
+ with path.open() as f:
962
+ indices = f.read().split("\n")
882
963
  if self.index_name in indices:
883
964
  return True
965
+ return False
884
966
 
885
967
  def forward(
886
- self,
887
- data: Optional[Symbol] = None,
888
- raw_result: bool = False,
889
- ) -> Symbol:
968
+ self,
969
+ data: Symbol | None = None,
970
+ _raw_result: bool = False,
971
+ ) -> Symbol:
890
972
  that = self
891
973
  if data is not None:
892
974
  data = self._to_symbol(data)
893
975
  self.elements = self.formatter(data).value
894
976
  # run over the elments in batches
895
977
  for i in tqdm(range(0, len(self.elements), self.batch_size)):
896
- val = Symbol(self.elements[i:i+self.batch_size]).zip(new_dim=self.new_dim)
978
+ val = Symbol(self.elements[i : i + self.batch_size]).zip(new_dim=self.new_dim)
897
979
  that.add(val, index_name=that.index_name, index_dims=that.new_dim)
898
980
  # we save the index
899
981
  that.config(None, save=True, index_name=that.index_name, index_dims=that.new_dim)
900
982
 
901
- def _func(query, *args, **kwargs) -> Union[Symbol, 'VectorDBResult']:
902
- raw_result = kwargs.get('raw_result') or that.raw_result
983
+ def _func(query, *_args, **kwargs) -> Union[Symbol, "VectorDBResult"]:
984
+ raw_result = kwargs.get("raw_result") or that.raw_result
903
985
  query_emb = Symbol(query).embed(new_dim=that.new_dim).value
904
- res = that.get(query_emb, index_name=that.index_name, index_top_k=that.top_k, ori_query=query, index_dims=that.new_dim, **kwargs)
986
+ res = that.get(
987
+ query_emb,
988
+ index_name=that.index_name,
989
+ index_top_k=that.top_k,
990
+ ori_query=query,
991
+ index_dims=that.new_dim,
992
+ **kwargs,
993
+ )
905
994
  that.retrieval = res
906
995
  if raw_result:
907
996
  return res
908
- rsp = Symbol(res).query(prompt='From the retrieved data, select the most relevant information.', context=query)
909
- return rsp
997
+ return Symbol(res).query(
998
+ prompt="From the retrieved data, select the most relevant information.",
999
+ context=query,
1000
+ )
1001
+
910
1002
  return _func
911
1003
 
912
1004
 
@@ -917,8 +1009,8 @@ class PrimitiveDisabler(Expression):
917
1009
  self._original_primitives = defaultdict(list)
918
1010
 
919
1011
  def __enter__(self):
920
- # Avoid circular imports; import locally
921
- from .symbol import Symbol
1012
+ # Import Symbol lazily so components does not clash with symbol during load.
1013
+ from .symbol import Symbol # noqa
922
1014
 
923
1015
  frame = inspect.currentframe()
924
1016
  f_locals = frame.f_back.f_locals
@@ -934,7 +1026,7 @@ class PrimitiveDisabler(Expression):
934
1026
  for func in self._primitives:
935
1027
  if hasattr(sym, func):
936
1028
  self._original_primitives[sym_name].append((func, getattr(sym, func)))
937
- setattr(sym, func, lambda *args, **kwargs: None)
1029
+ setattr(sym, func, lambda *_args, **_kwargs: None)
938
1030
 
939
1031
  def _enable_primitives(self):
940
1032
  for sym_name, sym in self._symbols.items():
@@ -945,7 +1037,7 @@ class PrimitiveDisabler(Expression):
945
1037
  for sym in self._symbols.values():
946
1038
  for primitive in sym._primitives:
947
1039
  for method, _ in inspect.getmembers(primitive, predicate=inspect.isfunction):
948
- if method in self._primitives or method.startswith('_'):
1040
+ if method in self._primitives or method.startswith("_"):
949
1041
  continue
950
1042
  self._primitives.add(method)
951
1043
 
@@ -968,7 +1060,7 @@ class FunctionWithUsage(Function):
968
1060
 
969
1061
  def print_verbose(self, msg):
970
1062
  if self.verbose:
971
- print(msg)
1063
+ UserMessage(msg)
972
1064
 
973
1065
  def _format_usage(self, prompt_tokens, completion_tokens, total_tokens):
974
1066
  return Box(
@@ -990,9 +1082,7 @@ class FunctionWithUsage(Function):
990
1082
  self.total_tokens += usage.total_tokens
991
1083
 
992
1084
  def get_usage(self):
993
- return self._format_usage(
994
- self.prompt_tokens, self.completion_tokens, self.total_tokens
995
- )
1085
+ return self._format_usage(self.prompt_tokens, self.completion_tokens, self.total_tokens)
996
1086
 
997
1087
  def forward(self, *args, **kwargs):
998
1088
  if "return_metadata" not in kwargs:
@@ -1003,9 +1093,7 @@ class FunctionWithUsage(Function):
1003
1093
  raw_output = metadata.get("raw_output")
1004
1094
  if hasattr(raw_output, "usage"):
1005
1095
  usage = raw_output.usage
1006
- prompt_tokens = (
1007
- usage.prompt_tokens if hasattr(usage, "prompt_tokens") else 0
1008
- )
1096
+ prompt_tokens = usage.prompt_tokens if hasattr(usage, "prompt_tokens") else 0
1009
1097
  completion_tokens = (
1010
1098
  usage.completion_tokens if hasattr(usage, "completion_tokens") else 0
1011
1099
  )
@@ -1020,28 +1108,29 @@ class FunctionWithUsage(Function):
1020
1108
  self.completion_tokens += completion_tokens
1021
1109
  self.total_tokens += total_tokens
1022
1110
  else:
1023
- if self.missing_usage_exception and not "preview" in kwargs:
1024
- raise Exception("Missing usage in metadata of neursymbolic engine")
1025
- else:
1026
- prompt_tokens = 0
1027
- completion_tokens = 0
1028
- total_tokens = 0
1111
+ if self.missing_usage_exception and "preview" not in kwargs:
1112
+ UserMessage(
1113
+ "Missing usage in metadata of neursymbolic engine", raise_with=Exception
1114
+ )
1115
+ prompt_tokens = 0
1116
+ completion_tokens = 0
1117
+ total_tokens = 0
1029
1118
 
1030
1119
  return res, self._format_usage(prompt_tokens, completion_tokens, total_tokens)
1031
1120
 
1032
1121
 
1033
1122
  class SelfPrompt(Expression):
1034
- _default_retry_tries = 20
1035
- _default_retry_delay = 0.5
1123
+ _default_retry_tries = 20
1124
+ _default_retry_delay = 0.5
1036
1125
  _default_retry_max_delay = -1
1037
- _default_retry_backoff = 1
1038
- _default_retry_jitter = 0
1039
- _default_retry_graceful = True
1126
+ _default_retry_backoff = 1
1127
+ _default_retry_jitter = 0
1128
+ _default_retry_graceful = True
1040
1129
 
1041
1130
  def __init__(self, *args, **kwargs):
1042
1131
  super().__init__(*args, **kwargs)
1043
1132
 
1044
- def forward(self, existing_prompt: Dict[str, str], **kwargs) -> Dict[str, str]:
1133
+ def forward(self, existing_prompt: dict[str, str], **kwargs) -> dict[str, str]:
1045
1134
  """
1046
1135
  Generate new system and user prompts based on the existing prompt.
1047
1136
 
@@ -1050,14 +1139,21 @@ class SelfPrompt(Expression):
1050
1139
  :return: A dictionary containing the new prompts in the same format:
1051
1140
  {'user': '...', 'system': '...'}
1052
1141
  """
1053
- tries = kwargs.get('tries', self._default_retry_tries)
1054
- delay = kwargs.get('delay', self._default_retry_delay)
1055
- max_delay = kwargs.get('max_delay', self._default_retry_max_delay)
1056
- backoff = kwargs.get('backoff', self._default_retry_backoff)
1057
- jitter = kwargs.get('jitter', self._default_retry_jitter)
1058
- graceful = kwargs.get('graceful', self._default_retry_graceful)
1059
-
1060
- @core_ext.retry(tries=tries, delay=delay, max_delay=max_delay, backoff=backoff, jitter=jitter, graceful=graceful)
1142
+ tries = kwargs.get("tries", self._default_retry_tries)
1143
+ delay = kwargs.get("delay", self._default_retry_delay)
1144
+ max_delay = kwargs.get("max_delay", self._default_retry_max_delay)
1145
+ backoff = kwargs.get("backoff", self._default_retry_backoff)
1146
+ jitter = kwargs.get("jitter", self._default_retry_jitter)
1147
+ graceful = kwargs.get("graceful", self._default_retry_graceful)
1148
+
1149
+ @core_ext.retry(
1150
+ tries=tries,
1151
+ delay=delay,
1152
+ max_delay=max_delay,
1153
+ backoff=backoff,
1154
+ jitter=jitter,
1155
+ graceful=graceful,
1156
+ )
1061
1157
  @core.zero_shot(
1062
1158
  prompt=(
1063
1159
  "Based on the following prompt, generate a new system (or developer) prompt and a new user prompt. "
@@ -1066,18 +1162,19 @@ class SelfPrompt(Expression):
1066
1162
  "The new user prompt should contain the user's requirements. "
1067
1163
  "Check if the input contains a 'system' or 'developer' key and use the same key in your output. "
1068
1164
  "Only output the new prompts in JSON format as shown:\n\n"
1069
- "{\"system\": \"<new system prompt>\", \"user\": \"<new user prompt>\"}\n\n"
1165
+ '{"system": "<new system prompt>", "user": "<new user prompt>"}\n\n'
1070
1166
  "OR\n\n"
1071
- "{\"developer\": \"<new developer prompt>\", \"user\": \"<new user prompt>\"}\n\n"
1167
+ '{"developer": "<new developer prompt>", "user": "<new user prompt>"}\n\n'
1072
1168
  "Maintain the same key structure as in the input prompt. Do not include any additional text."
1073
1169
  ),
1074
1170
  response_format={"type": "json_object"},
1075
1171
  post_processors=[
1076
1172
  lambda res, _: json.loads(res),
1077
1173
  ],
1078
- **kwargs
1174
+ **kwargs,
1079
1175
  )
1080
- def _func(self, sym: Symbol): pass
1176
+ def _func(self, sym: Symbol):
1177
+ pass
1081
1178
 
1082
1179
  return _func(self, self._to_symbol(existing_prompt))
1083
1180
 
@@ -1093,16 +1190,19 @@ class MetadataTracker(Expression):
1093
1190
  def __str__(self, value=None):
1094
1191
  value = value or self.metadata
1095
1192
  if isinstance(value, dict):
1096
- return '{\n\t' + ', \n\t'.join(f'"{k}": {self.__str__(v)}' for k,v in value.items()) + '\n}'
1097
- elif isinstance(value, list):
1098
- return '[' + ', '.join(self.__str__(item) for item in value) + ']'
1099
- elif isinstance(value, str):
1193
+ return (
1194
+ "{\n\t"
1195
+ + ", \n\t".join(f'"{k}": {self.__str__(v)}' for k, v in value.items())
1196
+ + "\n}"
1197
+ )
1198
+ if isinstance(value, list):
1199
+ return "[" + ", ".join(self.__str__(item) for item in value) + "]"
1200
+ if isinstance(value, str):
1100
1201
  return f'"{value}"'
1101
- else:
1102
- return f"\n\t {value}"
1202
+ return f"\n\t {value}"
1103
1203
 
1104
- def __new__(cls, *args, **kwargs):
1105
- cls._lock = getattr(cls, '_lock', Lock())
1204
+ def __new__(cls, *_args, **_kwargs):
1205
+ cls._lock = getattr(cls, "_lock", Lock())
1106
1206
  with cls._lock:
1107
1207
  instance = super().__new__(cls)
1108
1208
  instance._metadata = {}
@@ -1122,25 +1222,26 @@ class MetadataTracker(Expression):
1122
1222
 
1123
1223
  def _trace_calls(self, frame, event, arg):
1124
1224
  if not self._trace:
1125
- return
1225
+ return None
1126
1226
 
1127
- if event == 'return' and frame.f_code.co_name == 'forward':
1128
- # Check if this is an engine forward call
1129
- if ('self' in frame.f_locals
1130
- and
1131
- isinstance(frame.f_locals['self'], Engine)):
1132
- _, metadata = arg # arg contains return value on 'return' event
1133
- engine_name = frame.f_locals['self'].__class__.__name__
1134
- model_name = frame.f_locals['self'].model
1135
- self._metadata[(self._metadata_id, engine_name, model_name)] = metadata
1136
- self._metadata_id += 1
1227
+ if (
1228
+ event == "return"
1229
+ and frame.f_code.co_name == "forward"
1230
+ and "self" in frame.f_locals
1231
+ and isinstance(frame.f_locals["self"], Engine)
1232
+ ):
1233
+ _, metadata = arg # arg contains return value on 'return' event
1234
+ engine_name = frame.f_locals["self"].__class__.__name__
1235
+ model_name = frame.f_locals["self"].model
1236
+ self._metadata[(self._metadata_id, engine_name, model_name)] = metadata
1237
+ self._metadata_id += 1
1137
1238
 
1138
1239
  return self._trace_calls
1139
1240
 
1140
1241
  def _accumulate_completion_token_details(self):
1141
1242
  """Parses the return object and accumulates completion token details per token type"""
1142
1243
  if not self._metadata:
1143
- CustomUserWarning("No metadata available to generate usage details.")
1244
+ UserMessage("No metadata available to generate usage details.")
1144
1245
  return {}
1145
1246
 
1146
1247
  token_details = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
@@ -1151,46 +1252,149 @@ class MetadataTracker(Expression):
1151
1252
  try:
1152
1253
  if engine_name == "GroqEngine":
1153
1254
  usage = metadata["raw_output"].usage
1154
- token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.completion_tokens
1155
- token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.prompt_tokens
1156
- token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
1255
+ token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
1256
+ usage.completion_tokens
1257
+ )
1258
+ token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
1259
+ usage.prompt_tokens
1260
+ )
1261
+ token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
1262
+ usage.total_tokens
1263
+ )
1157
1264
  token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
1158
1265
  #!: Backward compatibility for components like `RuntimeInfo`
1159
- token_details[(engine_name, model_name)]["prompt_breakdown"]["cached_tokens"] += 0 # Assignment not allowed with defualtdict
1160
- token_details[(engine_name, model_name)]["completion_breakdown"]["reasoning_tokens"] += 0
1266
+ token_details[(engine_name, model_name)]["prompt_breakdown"][
1267
+ "cached_tokens"
1268
+ ] += 0 # Assignment not allowed with defualtdict
1269
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1270
+ "reasoning_tokens"
1271
+ ] += 0
1272
+ elif engine_name == "ParallelEngine":
1273
+ token_details[(engine_name, None)]["usage"]["total_calls"] += 1
1274
+ # There are no model-specific tokens for this engine
1275
+ token_details[(engine_name, None)]["usage"]["completion_tokens"] += 0
1276
+ token_details[(engine_name, None)]["usage"]["prompt_tokens"] += 0
1277
+ token_details[(engine_name, None)]["usage"]["total_tokens"] += 0
1278
+ #!: Backward compatibility for components like `RuntimeInfo`
1279
+ token_details[(engine_name, None)]["prompt_breakdown"]["cached_tokens"] += (
1280
+ 0 # Assignment not allowed with defualtdict
1281
+ )
1282
+ token_details[(engine_name, None)]["completion_breakdown"][
1283
+ "reasoning_tokens"
1284
+ ] += 0
1161
1285
  elif engine_name in ("GPTXChatEngine", "GPTXReasoningEngine"):
1162
1286
  usage = metadata["raw_output"].usage
1163
- token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.completion_tokens
1164
- token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.prompt_tokens
1165
- token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
1287
+ token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
1288
+ usage.completion_tokens
1289
+ )
1290
+ token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
1291
+ usage.prompt_tokens
1292
+ )
1293
+ token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
1294
+ usage.total_tokens
1295
+ )
1166
1296
  token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
1167
- token_details[(engine_name, model_name)]["completion_breakdown"]["accepted_prediction_tokens"] += usage.completion_tokens_details.accepted_prediction_tokens
1168
- token_details[(engine_name, model_name)]["completion_breakdown"]["rejected_prediction_tokens"] += usage.completion_tokens_details.rejected_prediction_tokens
1169
- token_details[(engine_name, model_name)]["completion_breakdown"]["audio_tokens"] += usage.completion_tokens_details.audio_tokens
1170
- token_details[(engine_name, model_name)]["completion_breakdown"]["reasoning_tokens"] += usage.completion_tokens_details.reasoning_tokens
1171
- token_details[(engine_name, model_name)]["prompt_breakdown"]["audio_tokens"] += usage.prompt_tokens_details.audio_tokens
1172
- token_details[(engine_name, model_name)]["prompt_breakdown"]["cached_tokens"] += usage.prompt_tokens_details.cached_tokens
1297
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1298
+ "accepted_prediction_tokens"
1299
+ ] += usage.completion_tokens_details.accepted_prediction_tokens
1300
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1301
+ "rejected_prediction_tokens"
1302
+ ] += usage.completion_tokens_details.rejected_prediction_tokens
1303
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1304
+ "audio_tokens"
1305
+ ] += usage.completion_tokens_details.audio_tokens
1306
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1307
+ "reasoning_tokens"
1308
+ ] += usage.completion_tokens_details.reasoning_tokens
1309
+ token_details[(engine_name, model_name)]["prompt_breakdown"][
1310
+ "audio_tokens"
1311
+ ] += usage.prompt_tokens_details.audio_tokens
1312
+ token_details[(engine_name, model_name)]["prompt_breakdown"][
1313
+ "cached_tokens"
1314
+ ] += usage.prompt_tokens_details.cached_tokens
1173
1315
  elif engine_name == "GPTXSearchEngine":
1174
1316
  usage = metadata["raw_output"].usage
1175
- token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.input_tokens
1176
- token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.output_tokens
1177
- token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
1317
+ token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
1318
+ usage.input_tokens
1319
+ )
1320
+ token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
1321
+ usage.output_tokens
1322
+ )
1323
+ token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
1324
+ usage.total_tokens
1325
+ )
1178
1326
  token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
1179
- token_details[(engine_name, model_name)]["prompt_breakdown"]["cached_tokens"] += usage.input_tokens_details.cached_tokens
1180
- token_details[(engine_name, model_name)]["completion_breakdown"]["reasoning_tokens"] += usage.output_tokens_details.reasoning_tokens
1327
+ token_details[(engine_name, model_name)]["prompt_breakdown"][
1328
+ "cached_tokens"
1329
+ ] += usage.input_tokens_details.cached_tokens
1330
+ token_details[(engine_name, model_name)]["completion_breakdown"][
1331
+ "reasoning_tokens"
1332
+ ] += usage.output_tokens_details.reasoning_tokens
1181
1333
  else:
1182
1334
  logger.warning(f"Tracking {engine_name} is not supported.")
1183
1335
  continue
1184
1336
  except Exception as e:
1185
- CustomUserWarning(f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError)
1337
+ UserMessage(
1338
+ f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError
1339
+ )
1186
1340
 
1187
1341
  # Convert to normal dict
1188
1342
  return {**token_details}
1189
1343
 
1344
+ def _can_accumulate_engine(self, engine_name: str) -> bool:
1345
+ supported_engines = ("GPTXChatEngine", "GPTXReasoningEngine", "GPTXSearchEngine")
1346
+ return engine_name in supported_engines
1347
+
1348
+ def _accumulate_time_field(self, accumulated: dict, metadata: dict) -> None:
1349
+ if "time" in metadata and "time" in accumulated:
1350
+ accumulated["time"] += metadata["time"]
1351
+
1352
+ def _accumulate_usage_fields(self, accumulated: dict, metadata: dict) -> None:
1353
+ if "raw_output" not in metadata or "raw_output" not in accumulated:
1354
+ return
1355
+
1356
+ metadata_raw_output = metadata["raw_output"]
1357
+ accumulated_raw_output = accumulated["raw_output"]
1358
+ if not hasattr(metadata_raw_output, "usage") or not hasattr(
1359
+ accumulated_raw_output, "usage"
1360
+ ):
1361
+ return
1362
+
1363
+ current_usage = metadata_raw_output.usage
1364
+ accumulated_usage = accumulated_raw_output.usage
1365
+
1366
+ for attr in ["completion_tokens", "prompt_tokens", "total_tokens"]:
1367
+ if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
1368
+ setattr(
1369
+ accumulated_usage,
1370
+ attr,
1371
+ getattr(accumulated_usage, attr) + getattr(current_usage, attr),
1372
+ )
1373
+
1374
+ for detail_attr in ["completion_tokens_details", "prompt_tokens_details"]:
1375
+ if not hasattr(current_usage, detail_attr) or not hasattr(
1376
+ accumulated_usage, detail_attr
1377
+ ):
1378
+ continue
1379
+
1380
+ current_details = getattr(current_usage, detail_attr)
1381
+ accumulated_details = getattr(accumulated_usage, detail_attr)
1382
+
1383
+ for attr in dir(current_details):
1384
+ if attr.startswith("_") or not hasattr(accumulated_details, attr):
1385
+ continue
1386
+
1387
+ current_val = getattr(current_details, attr)
1388
+ accumulated_val = getattr(accumulated_details, attr)
1389
+ if isinstance(current_val, (int, float)) and isinstance(
1390
+ accumulated_val, (int, float)
1391
+ ):
1392
+ setattr(accumulated_details, attr, accumulated_val + current_val)
1393
+
1190
1394
  def _accumulate_metadata(self):
1191
1395
  """Accumulates metadata across all tracked engine calls."""
1192
1396
  if not self._metadata:
1193
- CustomUserWarning("No metadata available to generate usage details.")
1397
+ UserMessage("No metadata available to generate usage details.")
1194
1398
  return {}
1195
1399
 
1196
1400
  # Use first entry as base
@@ -1199,39 +1403,14 @@ class MetadataTracker(Expression):
1199
1403
 
1200
1404
  # Skip the first entry; it already serves as the accumulation base
1201
1405
  for (_, engine_name), metadata in list(self._metadata.items())[1:]:
1202
- if engine_name not in ("GPTXChatEngine", "GPTXReasoningEngine", "GPTXSearchEngine"):
1203
- logger.warning(f"Metadata accumulation for {engine_name} is not supported. Try `.usage` instead for now.")
1406
+ if not self._can_accumulate_engine(engine_name):
1407
+ logger.warning(
1408
+ f"Metadata accumulation for {engine_name} is not supported. Try `.usage` instead for now."
1409
+ )
1204
1410
  continue
1205
1411
 
1206
- # Accumulate time if it exists
1207
- if 'time' in metadata and 'time' in accumulated:
1208
- accumulated['time'] += metadata['time']
1209
-
1210
- # Handle usage stats accumulation
1211
- if 'raw_output' in metadata and 'raw_output' in accumulated:
1212
- if hasattr(metadata['raw_output'], 'usage') and hasattr(accumulated['raw_output'], 'usage'):
1213
- current_usage = metadata['raw_output'].usage
1214
- accumulated_usage = accumulated['raw_output'].usage
1215
-
1216
- # Accumulate token counts
1217
- for attr in ['completion_tokens', 'prompt_tokens', 'total_tokens']:
1218
- if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
1219
- setattr(accumulated_usage, attr,
1220
- getattr(accumulated_usage, attr) + getattr(current_usage, attr))
1221
-
1222
- # Handle nested token details if they exist
1223
- for detail_attr in ['completion_tokens_details', 'prompt_tokens_details']:
1224
- if hasattr(current_usage, detail_attr) and hasattr(accumulated_usage, detail_attr):
1225
- current_details = getattr(current_usage, detail_attr)
1226
- accumulated_details = getattr(accumulated_usage, detail_attr)
1227
-
1228
- # Accumulate all numeric attributes in the details
1229
- for attr in dir(current_details):
1230
- if not attr.startswith('_') and hasattr(accumulated_details, attr):
1231
- current_val = getattr(current_details, attr)
1232
- accumulated_val = getattr(accumulated_details, attr)
1233
- if isinstance(current_val, (int, float)) and isinstance(accumulated_val, (int, float)):
1234
- setattr(accumulated_details, attr, accumulated_val + current_val)
1412
+ self._accumulate_time_field(accumulated, metadata)
1413
+ self._accumulate_usage_fields(accumulated, metadata)
1235
1414
 
1236
1415
  return accumulated
1237
1416
 
@@ -1250,7 +1429,8 @@ class MetadataTracker(Expression):
1250
1429
 
1251
1430
  class DynamicEngine(Expression):
1252
1431
  """Context manager for dynamically switching neurosymbolic engine models."""
1253
- def __init__(self, model: str, api_key: str, debug: bool = False, **kwargs):
1432
+
1433
+ def __init__(self, model: str, api_key: str, _debug: bool = False, **_kwargs):
1254
1434
  super().__init__()
1255
1435
  self.model = model
1256
1436
  self.api_key = api_key
@@ -1259,8 +1439,8 @@ class DynamicEngine(Expression):
1259
1439
  self.engine_instance = None
1260
1440
  self._ctx_token = None
1261
1441
 
1262
- def __new__(cls, *args, **kwargs):
1263
- cls._lock = getattr(cls, '_lock', Lock())
1442
+ def __new__(cls, *_args, **_kwargs):
1443
+ cls._lock = getattr(cls, "_lock", Lock())
1264
1444
  with cls._lock:
1265
1445
  instance = super().__new__(cls)
1266
1446
  instance._metadata = {}
@@ -1293,11 +1473,178 @@ class DynamicEngine(Expression):
1293
1473
 
1294
1474
def _create_engine_instance(self):
    """Create an engine instance based on the model name.

    Looks up ``self.model`` in ``ENGINE_MAPPING`` and instantiates the
    registered class with ``api_key`` and ``model``.

    Returns:
        The constructed engine instance.

    Raises:
        ValueError: Requested through ``UserMessage(..., raise_with=ValueError)``
            when the model has no registered engine or when constructing the
            engine fails.
    """
    # Deferred to avoid components <-> neurosymbolic engine circular imports.
    from .backend.engines.neurosymbolic import ENGINE_MAPPING  # noqa

    engine_class = ENGINE_MAPPING.get(self.model)
    if engine_class is None:
        # Reported before entering the try block so the broad handler below
        # does not catch this error and re-wrap it as a construction failure.
        UserMessage(f"Unsupported model '{self.model}'", raise_with=ValueError)

    try:
        return engine_class(api_key=self.api_key, model=self.model)
    except Exception as e:
        UserMessage(
            f"Failed to create engine for model '{self.model}': {e!s}", raise_with=ValueError
        )
1488
+
1489
+
1490
# Lazily populated chonkie state. All three names start out as ``None``
# ("import not attempted yet") and are filled in together by
# `_lazy_import_chonkie` the first time chonkie is needed:
#   _CHONKIE_MODULES   -> name -> imported object ({} after a failed import)
#   _CHUNKER_MAPPING   -> chunker name -> chunker class ({} after a failed import)
#   _CHONKIE_AVAILABLE -> True / False once the import has been attempted
_CHONKIE_MODULES = _CHUNKER_MAPPING = _CHONKIE_AVAILABLE = None
1494
+
1495
+
1496
def _lazy_import_chonkie():
    """Import chonkie on first use and cache the outcome in module globals.

    Returns:
        The cached name -> object table (chunker classes plus
        ``BaseEmbeddings`` and ``Tokenizer``), or an empty dict when any of
        the imports raised ``ImportError``.
    """
    global _CHONKIE_MODULES, _CHUNKER_MAPPING, _CHONKIE_AVAILABLE

    # Anything other than None (including the empty dict stored after a
    # failed import) means the import was already attempted.
    if _CHONKIE_MODULES is not None:
        return _CHONKIE_MODULES

    try:
        from chonkie import (  # noqa
            CodeChunker,
            LateChunker,
            NeuralChunker,
            RecursiveChunker,
            SemanticChunker,
            SentenceChunker,
            SlumberChunker,
            TableChunker,
            TokenChunker,
        )
        from chonkie.embeddings.base import BaseEmbeddings  # noqa
        from tokenizers import Tokenizer  # noqa
    except ImportError:
        # Remember the failure so the import is not retried on every call.
        _CHONKIE_MODULES, _CHUNKER_MAPPING, _CHONKIE_AVAILABLE = {}, {}, False
        return _CHONKIE_MODULES

    chunkers = {
        "TokenChunker": TokenChunker,
        "SentenceChunker": SentenceChunker,
        "RecursiveChunker": RecursiveChunker,
        "SemanticChunker": SemanticChunker,
        "CodeChunker": CodeChunker,
        "LateChunker": LateChunker,
        "NeuralChunker": NeuralChunker,
        "SlumberChunker": SlumberChunker,
        "TableChunker": TableChunker,
    }

    # The module table lists the chunkers alphabetically, followed by the two
    # helper classes.
    modules = {name: chunkers[name] for name in sorted(chunkers)}
    modules["BaseEmbeddings"] = BaseEmbeddings
    modules["Tokenizer"] = Tokenizer

    _CHONKIE_MODULES = modules
    _CHUNKER_MAPPING = chunkers
    _CHONKIE_AVAILABLE = True
    return _CHONKIE_MODULES
1549
+
1550
+
1551
def _get_chunker_mapping():
    """Return the chunker name -> class table, importing chonkie on first use.

    An empty dict is returned when the table is still unset or empty after
    the import attempt.
    """
    mapping = _CHUNKER_MAPPING
    if mapping is None:
        _lazy_import_chonkie()
        # Re-read the global: the import attempt rebinds it.
        mapping = _CHUNKER_MAPPING
    return mapping if mapping else {}
1556
+
1557
+
1558
def _is_chonkie_available():
    """Report whether chonkie could be imported, attempting the import on first use."""
    available = _CHONKIE_AVAILABLE
    if available is None:
        _lazy_import_chonkie()
        # Re-read the global: the import attempt rebinds it.
        available = _CHONKIE_AVAILABLE
    return available if available else False
1563
+
1564
+
1565
@beartype
class ChonkieChunker(Expression):
    """Expression that splits text into cleaned chunks using a chonkie chunker.

    Args:
        tokenizer_name: Tokenizer identifier used by the tokenizer-aware
            chunkers (loaded via ``Tokenizer.from_pretrained`` or passed on
            as a string, depending on the chunker).
        embedding_model_name: Embedding model identifier passed to the
            embedding-based chunkers.
        **symai_kwargs: Forwarded unchanged to ``Expression.__init__``.
    """

    def __init__(
        self,
        tokenizer_name: str | None = "gpt2",
        embedding_model_name: str | None = "minishlab/potion-base-8M",
        **symai_kwargs,
    ):
        super().__init__(**symai_kwargs)
        self.tokenizer_name = tokenizer_name
        self.embedding_model_name = embedding_model_name

    def forward(
        self, data: Symbol, chunker_name: str | None = "RecursiveChunker", **chunker_kwargs
    ) -> Symbol:
        """Chunk ``data.value`` with the named chunker and return the cleaned chunks.

        Args:
            data: Symbol whose ``value`` is handed to the chunker.
            chunker_name: Key of the chunker to use; ``None`` selects the
                default ``"RecursiveChunker"``.
            **chunker_kwargs: Extra keyword arguments for the chunker constructor.

        Returns:
            The result of ``self._to_symbol`` applied to the list of cleaned
            chunk texts.

        Raises:
            ImportError: Requested through ``UserMessage`` when chonkie is
                not available.
            ValueError: When ``chunker_name`` is not a known chunker.
        """
        if not _is_chonkie_available():
            UserMessage(
                "chonkie library is not installed. Please install it with `pip install chonkie tokenizers`.",
                raise_with=ImportError,
            )
        if chunker_name is None:
            # The signature admits None, but _resolve_chunker requires a str.
            chunker_name = "RecursiveChunker"
        chunker = self._resolve_chunker(chunker_name, **chunker_kwargs)
        chunks = [ChonkieChunker.clean_text(chunk.text) for chunk in chunker(data.value)]
        return self._to_symbol(chunks)

    def _resolve_chunker(self, chunker_name: str, **chunker_kwargs):
        """Resolve and instantiate a chunker by name.

        Args:
            chunker_name: Key into the chunker mapping.
            **chunker_kwargs: Keyword arguments for the chunker constructor.
                An explicit ``tokenizer`` / ``embedding_model`` entry takes
                precedence over the instance defaults where the branch
                passes that argument by keyword.

        Raises:
            ValueError: When ``chunker_name`` is unknown or has no
                construction rule below.
            ImportError: Requested through ``UserMessage`` when the
                ``Tokenizer`` class is missing.
        """
        chunker_mapping = _get_chunker_mapping()

        if chunker_name not in chunker_mapping:
            msg = (
                f"Chunker {chunker_name} not found. Available chunkers: {list(chunker_mapping.keys())}. "
                f"See docs (https://docs.chonkie.ai/getting-started/introduction) for more info."
            )
            raise ValueError(msg)

        chunker_class = chunker_mapping[chunker_name]
        chonkie_modules = _lazy_import_chonkie()
        Tokenizer = chonkie_modules.get("Tokenizer")

        # Tokenizer-based chunkers: load the tokenizer by name and pass it
        # positionally as the first constructor argument.
        if chunker_name in ("TokenChunker", "SentenceChunker", "RecursiveChunker"):
            if Tokenizer is None:
                UserMessage(
                    "Tokenizers library is not installed. Please install it with `pip install tokenizers`.",
                    raise_with=ImportError,
                )
            tokenizer = Tokenizer.from_pretrained(self.tokenizer_name)
            return chunker_class(tokenizer, **chunker_kwargs)

        # Embedding-based chunkers: default to embedding_model_name, but let
        # a caller-supplied ``embedding_model`` win instead of colliding with
        # it as a duplicate keyword argument.
        if chunker_name in ("SemanticChunker", "LateChunker"):
            chunker_kwargs.setdefault("embedding_model", self.embedding_model_name)
            return chunker_class(**chunker_kwargs)

        # CodeChunker, TableChunker and SlumberChunker take the tokenizer by
        # keyword (per the original notes: a string or a Tokenizer object);
        # fall back to tokenizer_name when the caller did not provide one.
        if chunker_name in ("CodeChunker", "TableChunker", "SlumberChunker"):
            chunker_kwargs.setdefault("tokenizer", self.tokenizer_name)
            return chunker_class(**chunker_kwargs)

        # NeuralChunker uses model parameter (defaults provided by chonkie)
        if chunker_name == "NeuralChunker":
            return chunker_class(**chunker_kwargs)

        # Only reachable for a mapped chunker without a construction rule above.
        msg = (
            f"Chunker {chunker_name} not properly configured. "
            f"Available chunkers: {list(chunker_mapping.keys())}."
        )
        raise ValueError(msg)

    @staticmethod
    def clean_text(text: str) -> str:
        """Cleans text by removing problematic characters."""
        text = text.replace("\x00", "")  # Remove null bytes (\x00)
        # errors="ignore" drops (does not replace) characters that cannot be
        # encoded as UTF-8, e.g. lone surrogates.
        return text.encode("utf-8", errors="ignore").decode("utf-8")