symbolicai 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +198 -134
- symai/backend/base.py +51 -51
- symai/backend/engines/drawing/engine_bfl.py +33 -33
- symai/backend/engines/drawing/engine_gpt_image.py +4 -10
- symai/backend/engines/embedding/engine_llama_cpp.py +50 -35
- symai/backend/engines/embedding/engine_openai.py +22 -16
- symai/backend/engines/execute/engine_python.py +16 -16
- symai/backend/engines/files/engine_io.py +51 -49
- symai/backend/engines/imagecaptioning/engine_blip2.py +27 -23
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +53 -46
- symai/backend/engines/index/engine_pinecone.py +116 -88
- symai/backend/engines/index/engine_qdrant.py +1011 -0
- symai/backend/engines/index/engine_vectordb.py +78 -52
- symai/backend/engines/lean/engine_lean4.py +65 -25
- symai/backend/engines/neurosymbolic/__init__.py +28 -28
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +137 -135
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +145 -152
- symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +75 -49
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +199 -155
- symai/backend/engines/neurosymbolic/engine_groq.py +106 -72
- symai/backend/engines/neurosymbolic/engine_huggingface.py +100 -67
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +121 -93
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +213 -132
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +180 -137
- symai/backend/engines/ocr/engine_apilayer.py +18 -20
- symai/backend/engines/output/engine_stdout.py +9 -9
- symai/backend/engines/{webscraping → scrape}/engine_requests.py +25 -11
- symai/backend/engines/search/engine_openai.py +95 -83
- symai/backend/engines/search/engine_parallel.py +665 -0
- symai/backend/engines/search/engine_perplexity.py +40 -41
- symai/backend/engines/search/engine_serpapi.py +33 -28
- symai/backend/engines/speech_to_text/engine_local_whisper.py +37 -27
- symai/backend/engines/symbolic/engine_wolframalpha.py +14 -8
- symai/backend/engines/text_to_speech/engine_openai.py +15 -19
- symai/backend/engines/text_vision/engine_clip.py +34 -28
- symai/backend/engines/userinput/engine_console.py +3 -4
- symai/backend/mixin/anthropic.py +48 -40
- symai/backend/mixin/deepseek.py +4 -5
- symai/backend/mixin/google.py +5 -4
- symai/backend/mixin/groq.py +2 -4
- symai/backend/mixin/openai.py +132 -110
- symai/backend/settings.py +14 -14
- symai/chat.py +164 -94
- symai/collect/dynamic.py +13 -11
- symai/collect/pipeline.py +39 -31
- symai/collect/stats.py +109 -69
- symai/components.py +556 -238
- symai/constraints.py +14 -5
- symai/core.py +1495 -1210
- symai/core_ext.py +55 -50
- symai/endpoints/api.py +113 -58
- symai/extended/api_builder.py +22 -17
- symai/extended/arxiv_pdf_parser.py +13 -5
- symai/extended/bibtex_parser.py +8 -4
- symai/extended/conversation.py +88 -69
- symai/extended/document.py +40 -27
- symai/extended/file_merger.py +45 -7
- symai/extended/graph.py +38 -24
- symai/extended/html_style_template.py +17 -11
- symai/extended/interfaces/blip_2.py +1 -1
- symai/extended/interfaces/clip.py +4 -2
- symai/extended/interfaces/console.py +5 -3
- symai/extended/interfaces/dall_e.py +3 -1
- symai/extended/interfaces/file.py +2 -0
- symai/extended/interfaces/flux.py +3 -1
- symai/extended/interfaces/gpt_image.py +15 -6
- symai/extended/interfaces/input.py +2 -1
- symai/extended/interfaces/llava.py +1 -1
- symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +3 -2
- symai/extended/interfaces/naive_vectordb.py +2 -2
- symai/extended/interfaces/ocr.py +4 -2
- symai/extended/interfaces/openai_search.py +2 -0
- symai/extended/interfaces/parallel.py +30 -0
- symai/extended/interfaces/perplexity.py +2 -0
- symai/extended/interfaces/pinecone.py +6 -4
- symai/extended/interfaces/python.py +2 -0
- symai/extended/interfaces/serpapi.py +2 -0
- symai/extended/interfaces/terminal.py +0 -1
- symai/extended/interfaces/tts.py +2 -1
- symai/extended/interfaces/whisper.py +2 -1
- symai/extended/interfaces/wolframalpha.py +1 -0
- symai/extended/metrics/__init__.py +1 -1
- symai/extended/metrics/similarity.py +5 -2
- symai/extended/os_command.py +31 -22
- symai/extended/packages/symdev.py +39 -34
- symai/extended/packages/sympkg.py +30 -27
- symai/extended/packages/symrun.py +46 -35
- symai/extended/repo_cloner.py +10 -9
- symai/extended/seo_query_optimizer.py +15 -12
- symai/extended/solver.py +104 -76
- symai/extended/summarizer.py +8 -7
- symai/extended/taypan_interpreter.py +10 -9
- symai/extended/vectordb.py +28 -15
- symai/formatter/formatter.py +39 -31
- symai/formatter/regex.py +46 -44
- symai/functional.py +184 -86
- symai/imports.py +85 -51
- symai/interfaces.py +1 -1
- symai/memory.py +33 -24
- symai/menu/screen.py +28 -19
- symai/misc/console.py +27 -27
- symai/misc/loader.py +4 -3
- symai/models/base.py +147 -76
- symai/models/errors.py +1 -1
- symai/ops/__init__.py +1 -1
- symai/ops/measures.py +17 -14
- symai/ops/primitives.py +933 -635
- symai/post_processors.py +28 -24
- symai/pre_processors.py +58 -52
- symai/processor.py +15 -9
- symai/prompts.py +714 -649
- symai/server/huggingface_server.py +115 -32
- symai/server/llama_cpp_server.py +14 -6
- symai/server/qdrant_server.py +206 -0
- symai/shell.py +98 -39
- symai/shellsv.py +307 -223
- symai/strategy.py +135 -81
- symai/symbol.py +276 -225
- symai/utils.py +62 -46
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +19 -9
- symbolicai-1.1.0.dist-info/RECORD +168 -0
- symbolicai-1.0.0.dist-info/RECORD +0 -163
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/components.py
CHANGED
|
@@ -11,7 +11,11 @@ from string import ascii_lowercase, ascii_uppercase
|
|
|
11
11
|
from threading import Lock
|
|
12
12
|
from typing import TYPE_CHECKING, Union
|
|
13
13
|
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
14
17
|
import numpy as np
|
|
18
|
+
from beartype import beartype
|
|
15
19
|
from box import Box
|
|
16
20
|
from loguru import logger
|
|
17
21
|
from pyvis.network import Network
|
|
@@ -43,33 +47,42 @@ _DEFAULT_PARAGRAPH_FORMATTER = ParagraphFormatter()
|
|
|
43
47
|
|
|
44
48
|
|
|
45
49
|
class GraphViz(Expression):
|
|
46
|
-
def __init__(
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
50
|
+
def __init__(
|
|
51
|
+
self,
|
|
52
|
+
notebook=True,
|
|
53
|
+
cdn_resources="remote",
|
|
54
|
+
bgcolor="#222222",
|
|
55
|
+
font_color="white",
|
|
56
|
+
height="750px",
|
|
57
|
+
width="100%",
|
|
58
|
+
select_menu=True,
|
|
59
|
+
filter_menu=True,
|
|
60
|
+
**kwargs,
|
|
61
|
+
):
|
|
56
62
|
super().__init__(**kwargs)
|
|
57
|
-
self.net
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
63
|
+
self.net = Network(
|
|
64
|
+
notebook=notebook,
|
|
65
|
+
cdn_resources=cdn_resources,
|
|
66
|
+
bgcolor=bgcolor,
|
|
67
|
+
font_color=font_color,
|
|
68
|
+
height=height,
|
|
69
|
+
width=width,
|
|
70
|
+
select_menu=select_menu,
|
|
71
|
+
filter_menu=filter_menu,
|
|
72
|
+
)
|
|
65
73
|
|
|
66
74
|
def forward(self, sym: Symbol, file_path: str, **_kwargs):
|
|
67
75
|
nodes = [str(n) if n.value else n.__repr__(simplified=True) for n in sym.nodes]
|
|
68
|
-
edges = [
|
|
69
|
-
|
|
76
|
+
edges = [
|
|
77
|
+
(
|
|
78
|
+
str(e[0]) if e[0].value else e[0].__repr__(simplified=True),
|
|
79
|
+
str(e[1]) if e[1].value else e[1].__repr__(simplified=True),
|
|
80
|
+
)
|
|
81
|
+
for e in sym.edges
|
|
82
|
+
]
|
|
70
83
|
self.net.add_nodes(nodes)
|
|
71
84
|
self.net.add_edges(edges)
|
|
72
|
-
file_path = file_path if file_path.endswith(
|
|
85
|
+
file_path = file_path if file_path.endswith(".html") else file_path + ".html"
|
|
73
86
|
return self.net.show(file_path)
|
|
74
87
|
|
|
75
88
|
|
|
@@ -109,12 +122,14 @@ class Try(Expression):
|
|
|
109
122
|
class Lambda(Expression):
|
|
110
123
|
def __init__(self, callable: Callable, **kwargs):
|
|
111
124
|
super().__init__(**kwargs)
|
|
125
|
+
|
|
112
126
|
def _callable(*args, **kwargs):
|
|
113
127
|
kw = {
|
|
114
|
-
|
|
115
|
-
|
|
128
|
+
"args": args,
|
|
129
|
+
"kwargs": kwargs,
|
|
116
130
|
}
|
|
117
131
|
return callable(kw)
|
|
132
|
+
|
|
118
133
|
self.callable: Callable = _callable
|
|
119
134
|
|
|
120
135
|
def forward(self, *args, **kwargs) -> Symbol:
|
|
@@ -140,8 +155,8 @@ class Output(Expression):
|
|
|
140
155
|
self.verbose: bool = verbose
|
|
141
156
|
|
|
142
157
|
def forward(self, *args, **kwargs) -> Expression:
|
|
143
|
-
kwargs[
|
|
144
|
-
kwargs[
|
|
158
|
+
kwargs["verbose"] = self.verbose
|
|
159
|
+
kwargs["handler"] = self.handler
|
|
145
160
|
return self.output(*args, expr=self.expr, **kwargs)
|
|
146
161
|
|
|
147
162
|
|
|
@@ -166,32 +181,34 @@ class Sequence(TrackerTraceable):
|
|
|
166
181
|
class Parallel(Expression):
|
|
167
182
|
def __init__(self, *expr: list[Expression | Callable], sequential: bool = False, **kwargs):
|
|
168
183
|
super().__init__(**kwargs)
|
|
169
|
-
self.sequential: bool
|
|
184
|
+
self.sequential: bool = sequential
|
|
170
185
|
self.expr: list[Expression] = expr
|
|
171
|
-
self.results: list[Symbol]
|
|
186
|
+
self.results: list[Symbol] = []
|
|
172
187
|
|
|
173
188
|
def forward(self, *args, **kwargs) -> Symbol:
|
|
174
189
|
# run in sequence
|
|
175
190
|
if self.sequential:
|
|
176
191
|
return [e(*args, **kwargs) for e in self.expr]
|
|
192
|
+
|
|
177
193
|
# run in parallel
|
|
178
194
|
@core_ext.parallel(self.expr)
|
|
179
195
|
def _func(e, *args, **kwargs):
|
|
180
196
|
return e(*args, **kwargs)
|
|
197
|
+
|
|
181
198
|
self.results = _func(*args, **kwargs)
|
|
182
199
|
# final result of the parallel execution
|
|
183
200
|
return self._to_symbol(self.results)
|
|
184
201
|
|
|
185
202
|
|
|
186
|
-
|
|
203
|
+
# @TODO: BinPacker(format="...") -> ensure that data packages form a "bin" that's consistent (e.g. never break a sentence in the middle)
|
|
187
204
|
class Stream(Expression):
|
|
188
205
|
def __init__(self, expr: Expression | None = None, retrieval: str | None = None, **kwargs):
|
|
189
206
|
super().__init__(**kwargs)
|
|
190
|
-
self.char_token_ratio:
|
|
207
|
+
self.char_token_ratio: float = 0.6
|
|
191
208
|
self.expr: Expression | None = expr
|
|
192
|
-
self.retrieval:
|
|
193
|
-
self._trace:
|
|
194
|
-
self._previous_frame
|
|
209
|
+
self.retrieval: str | None = retrieval
|
|
210
|
+
self._trace: bool = False
|
|
211
|
+
self._previous_frame = None
|
|
195
212
|
|
|
196
213
|
def forward(self, sym: Symbol, **kwargs) -> Iterator:
|
|
197
214
|
sym = self._to_symbol(sym)
|
|
@@ -213,17 +230,15 @@ class Stream(Expression):
|
|
|
213
230
|
raise_with=ValueError,
|
|
214
231
|
)
|
|
215
232
|
|
|
216
|
-
res = sym.stream(expr=self.expr,
|
|
217
|
-
char_token_ratio=self.char_token_ratio,
|
|
218
|
-
**kwargs)
|
|
233
|
+
res = sym.stream(expr=self.expr, char_token_ratio=self.char_token_ratio, **kwargs)
|
|
219
234
|
if self.retrieval is not None:
|
|
220
235
|
res = list(res)
|
|
221
|
-
if self.retrieval ==
|
|
236
|
+
if self.retrieval == "all":
|
|
222
237
|
return res
|
|
223
|
-
if self.retrieval ==
|
|
238
|
+
if self.retrieval == "longest":
|
|
224
239
|
res = sorted(res, key=lambda x: len(x), reverse=True)
|
|
225
240
|
return res[0]
|
|
226
|
-
if self.retrieval ==
|
|
241
|
+
if self.retrieval == "contains":
|
|
227
242
|
return [r for r in res if self.expr in r]
|
|
228
243
|
UserMessage(f"Invalid retrieval method: {self.retrieval}", raise_with=ValueError)
|
|
229
244
|
|
|
@@ -241,7 +256,7 @@ class Stream(Expression):
|
|
|
241
256
|
class Trace(Expression):
|
|
242
257
|
def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
|
|
243
258
|
if engines is None:
|
|
244
|
-
engines = [
|
|
259
|
+
engines = ["all"]
|
|
245
260
|
super().__init__(**kwargs)
|
|
246
261
|
self.expr: Expression = expr
|
|
247
262
|
self.engines: list[str] = engines
|
|
@@ -278,7 +293,7 @@ class Analyze(Expression):
|
|
|
278
293
|
class Log(Expression):
|
|
279
294
|
def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
|
|
280
295
|
if engines is None:
|
|
281
|
-
engines = [
|
|
296
|
+
engines = ["all"]
|
|
282
297
|
super().__init__(**kwargs)
|
|
283
298
|
self.expr: Expression = expr
|
|
284
299
|
self.engines: list[str] = engines
|
|
@@ -303,7 +318,12 @@ class Log(Expression):
|
|
|
303
318
|
|
|
304
319
|
|
|
305
320
|
class Template(Expression):
|
|
306
|
-
def __init__(
|
|
321
|
+
def __init__(
|
|
322
|
+
self,
|
|
323
|
+
template: str = "<html><body>{{placeholder}}</body></html>",
|
|
324
|
+
placeholder: str = "{{placeholder}}",
|
|
325
|
+
**kwargs,
|
|
326
|
+
):
|
|
307
327
|
super().__init__(**kwargs)
|
|
308
328
|
self.placeholder = placeholder
|
|
309
329
|
self.template_ = template
|
|
@@ -333,21 +353,25 @@ class RuntimeExpression(Expression):
|
|
|
333
353
|
code = self._to_symbol(code)
|
|
334
354
|
# declare the runtime expression from the code
|
|
335
355
|
expr = self.runner(code)
|
|
356
|
+
|
|
336
357
|
def _func(sym):
|
|
337
358
|
# execute nested expression
|
|
338
|
-
return expr[
|
|
359
|
+
return expr["locals"]["_output_"](sym)
|
|
360
|
+
|
|
339
361
|
return _func
|
|
340
362
|
|
|
341
363
|
|
|
342
364
|
class Metric(Expression):
|
|
343
365
|
def __init__(self, normalize: bool = False, eps: float = 1e-8, **kwargs):
|
|
344
366
|
super().__init__(**kwargs)
|
|
345
|
-
self.normalize
|
|
346
|
-
self.eps
|
|
367
|
+
self.normalize = normalize
|
|
368
|
+
self.eps = eps
|
|
347
369
|
|
|
348
370
|
def forward(self, sym: Symbol, **_kwargs) -> Symbol:
|
|
349
371
|
sym = self._to_symbol(sym)
|
|
350
|
-
assert sym.value_type is np.ndarray or sym.value_type is list,
|
|
372
|
+
assert sym.value_type is np.ndarray or sym.value_type is list, (
|
|
373
|
+
"Metric can only be applied to numpy arrays or lists."
|
|
374
|
+
)
|
|
351
375
|
if sym.value_type is list:
|
|
352
376
|
sym._value = np.array(sym.value)
|
|
353
377
|
# compute normalization between 0 and 1
|
|
@@ -357,7 +381,7 @@ class Metric(Expression):
|
|
|
357
381
|
elif len(sym.value.shape) == 2:
|
|
358
382
|
pass
|
|
359
383
|
else:
|
|
360
|
-
UserMessage(f
|
|
384
|
+
UserMessage(f"Invalid shape: {sym.value.shape}", raise_with=ValueError)
|
|
361
385
|
# normalize between 0 and 1 and sum to 1
|
|
362
386
|
sym._value = np.exp(sym.value) / (np.exp(sym.value).sum() + self.eps)
|
|
363
387
|
return sym
|
|
@@ -413,16 +437,16 @@ _output_ = _func()
|
|
|
413
437
|
|
|
414
438
|
def forward(self, sym: Symbol, enclosure: bool = False, **kwargs) -> Symbol:
|
|
415
439
|
if enclosure or self.enclosure:
|
|
416
|
-
lines = str(sym).split(
|
|
417
|
-
lines = [
|
|
418
|
-
sym =
|
|
419
|
-
sym = self.template.replace(
|
|
440
|
+
lines = str(sym).split("\n")
|
|
441
|
+
lines = [" " + line for line in lines]
|
|
442
|
+
sym = "\n".join(lines)
|
|
443
|
+
sym = self.template.replace("{sym}", str(sym))
|
|
420
444
|
sym = self._to_symbol(sym)
|
|
421
445
|
return sym.execute(**kwargs)
|
|
422
446
|
|
|
423
447
|
|
|
424
448
|
class Convert(Expression):
|
|
425
|
-
def __init__(self, format: str =
|
|
449
|
+
def __init__(self, format: str = "Python", **kwargs):
|
|
426
450
|
super().__init__(**kwargs)
|
|
427
451
|
self.format = format
|
|
428
452
|
|
|
@@ -456,13 +480,13 @@ class Map(Expression):
|
|
|
456
480
|
|
|
457
481
|
|
|
458
482
|
class Translate(Expression):
|
|
459
|
-
def __init__(self, language: str =
|
|
483
|
+
def __init__(self, language: str = "English", **kwargs):
|
|
460
484
|
super().__init__(**kwargs)
|
|
461
485
|
self.language = language
|
|
462
486
|
|
|
463
487
|
def forward(self, sym: Symbol, **kwargs) -> Symbol:
|
|
464
488
|
sym = self._to_symbol(sym)
|
|
465
|
-
if sym.isinstanceof(f
|
|
489
|
+
if sym.isinstanceof(f"{self.language} text"):
|
|
466
490
|
return sym
|
|
467
491
|
return sym.translate(language=self.language, **kwargs)
|
|
468
492
|
|
|
@@ -494,7 +518,7 @@ class FileWriter(Expression):
|
|
|
494
518
|
|
|
495
519
|
def forward(self, sym: Symbol, **_kwargs) -> Symbol:
|
|
496
520
|
sym = self._to_symbol(sym)
|
|
497
|
-
with self.path.open(
|
|
521
|
+
with self.path.open("w") as f:
|
|
498
522
|
f.write(str(sym))
|
|
499
523
|
|
|
500
524
|
|
|
@@ -502,18 +526,18 @@ class FileReader(Expression):
|
|
|
502
526
|
@staticmethod
|
|
503
527
|
def exists(path: str) -> bool:
|
|
504
528
|
# remove slicing if any
|
|
505
|
-
_tmp
|
|
506
|
-
_splits
|
|
507
|
-
if
|
|
529
|
+
_tmp = path
|
|
530
|
+
_splits = _tmp.split("[")
|
|
531
|
+
if "[" in _tmp:
|
|
508
532
|
_tmp = _splits[0]
|
|
509
|
-
assert len(_splits) == 1 or len(_splits) == 2,
|
|
510
|
-
_tmp
|
|
533
|
+
assert len(_splits) == 1 or len(_splits) == 2, "Invalid file link format."
|
|
534
|
+
_tmp = Path(_tmp)
|
|
511
535
|
# check if file exists and is a file
|
|
512
536
|
return _tmp.is_file()
|
|
513
537
|
|
|
514
538
|
@staticmethod
|
|
515
539
|
def get_files(folder_path: str, max_depth: int = 1) -> list[str]:
|
|
516
|
-
accepted_formats = [
|
|
540
|
+
accepted_formats = [".pdf", ".md", ".txt"]
|
|
517
541
|
|
|
518
542
|
folder = Path(folder_path)
|
|
519
543
|
files = []
|
|
@@ -527,9 +551,34 @@ class FileReader(Expression):
|
|
|
527
551
|
|
|
528
552
|
@staticmethod
|
|
529
553
|
def extract_files(cmds: str) -> list[str] | None:
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
554
|
+
"""
|
|
555
|
+
Extract file paths from a command string, handling various quoting styles.
|
|
556
|
+
|
|
557
|
+
This method is used by the Qdrant RAG implementation when processing document paths.
|
|
558
|
+
It uses regex to parse file paths that may be quoted in different ways.
|
|
559
|
+
|
|
560
|
+
Regex patterns used:
|
|
561
|
+
1. Main pattern: Matches file paths in four formats:
|
|
562
|
+
- Double-quoted: "path/to/file" (handles escaped characters)
|
|
563
|
+
- Single-quoted: 'path/to/file' (handles escaped characters)
|
|
564
|
+
- Backtick-quoted: `path/to/file` (handles escaped characters)
|
|
565
|
+
- Non-quoted: path/to/file (handles escaped spaces)
|
|
566
|
+
|
|
567
|
+
2. Escape removal pattern: r"\\(.)" -> r"\1"
|
|
568
|
+
- Removes backslash escape sequences from quoted paths
|
|
569
|
+
- Example: "path\\/to\\/file" -> "path/to/file"
|
|
570
|
+
- Used for double quotes, single quotes, and backticks
|
|
571
|
+
"""
|
|
572
|
+
# Regex pattern to match file paths in various quoting styles
|
|
573
|
+
# Pattern breakdown:
|
|
574
|
+
# - (?:"((?:\\.|[^"\\])*)") : Matches double-quoted paths, capturing content while handling escapes
|
|
575
|
+
# - '((?:\\.|[^'\\])*)' : Matches single-quoted paths, capturing content while handling escapes
|
|
576
|
+
# - `((?:\\.|[^`\\])*)` : Matches backtick-quoted paths, capturing content while handling escapes
|
|
577
|
+
# - ((?:\\ |[^ ])+) : Matches non-quoted paths, allowing escaped spaces
|
|
578
|
+
pattern = (
|
|
579
|
+
r"""(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))"""
|
|
580
|
+
)
|
|
581
|
+
# Use regex to find all file path matches in the command string
|
|
533
582
|
matches = re.findall(pattern, cmds)
|
|
534
583
|
# Process the matches to handle quoted paths and normal paths
|
|
535
584
|
files = []
|
|
@@ -537,23 +586,27 @@ class FileReader(Expression):
|
|
|
537
586
|
# Each match will have 4 groups due to the pattern; only one will be non-empty
|
|
538
587
|
quoted_double, quoted_single, quoted_backtick, non_quoted = match
|
|
539
588
|
if quoted_double:
|
|
540
|
-
# Remove backslashes used for escaping inside double quotes
|
|
541
|
-
|
|
589
|
+
# Regex substitution: Remove backslashes used for escaping inside double quotes
|
|
590
|
+
# Pattern r"\\(.)" matches a backslash followed by any character and replaces with just the character
|
|
591
|
+
# Example: "path\\/to\\/file" -> "path/to/file"
|
|
592
|
+
path = re.sub(r"\\(.)", r"\1", quoted_double)
|
|
542
593
|
file = FileReader.expand_user_path(path)
|
|
543
594
|
files.append(file)
|
|
544
595
|
elif quoted_single:
|
|
545
|
-
# Remove backslashes used for escaping inside single quotes
|
|
546
|
-
|
|
596
|
+
# Regex substitution: Remove backslashes used for escaping inside single quotes
|
|
597
|
+
# Same pattern as above, applied to single-quoted paths
|
|
598
|
+
path = re.sub(r"\\(.)", r"\1", quoted_single)
|
|
547
599
|
file = FileReader.expand_user_path(path)
|
|
548
600
|
files.append(file)
|
|
549
601
|
elif quoted_backtick:
|
|
550
|
-
# Remove backslashes used for escaping inside backticks
|
|
551
|
-
|
|
602
|
+
# Regex substitution: Remove backslashes used for escaping inside backticks
|
|
603
|
+
# Same pattern as above, applied to backtick-quoted paths
|
|
604
|
+
path = re.sub(r"\\(.)", r"\1", quoted_backtick)
|
|
552
605
|
file = FileReader.expand_user_path(path)
|
|
553
606
|
files.append(file)
|
|
554
607
|
elif non_quoted:
|
|
555
|
-
# Replace escaped spaces with actual spaces
|
|
556
|
-
path = non_quoted.replace(
|
|
608
|
+
# Replace escaped spaces with actual spaces (no regex needed here, simple string replace)
|
|
609
|
+
path = non_quoted.replace("\\ ", " ")
|
|
557
610
|
file = FileReader.expand_user_path(path)
|
|
558
611
|
files.append(file)
|
|
559
612
|
# Filter out any files that do not exist
|
|
@@ -571,25 +624,28 @@ class FileReader(Expression):
|
|
|
571
624
|
if FileReader.exists(file):
|
|
572
625
|
not_skipped.append(file)
|
|
573
626
|
else:
|
|
574
|
-
UserMessage(f
|
|
627
|
+
UserMessage(f"Skipping file: {file}")
|
|
575
628
|
return not_skipped
|
|
576
629
|
|
|
577
630
|
def forward(self, files: str | list[str], **kwargs) -> Expression:
|
|
578
631
|
if isinstance(files, str):
|
|
579
632
|
# Convert to list for uniform processing; more easily downstream
|
|
580
633
|
files = [files]
|
|
581
|
-
if kwargs.get(
|
|
634
|
+
if kwargs.get("run_integrity_check"):
|
|
582
635
|
files = self.integrity_check(files)
|
|
583
636
|
return self.sym_return_type([self.open(f, **kwargs).value for f in files])
|
|
584
637
|
|
|
638
|
+
|
|
585
639
|
class FileQuery(Expression):
|
|
586
640
|
def __init__(self, path: str, filter: str, **kwargs):
|
|
587
641
|
super().__init__(**kwargs)
|
|
588
642
|
self.path = path
|
|
589
643
|
file_open = FileReader()
|
|
590
|
-
self.query_stream = Stream(
|
|
591
|
-
|
|
592
|
-
|
|
644
|
+
self.query_stream = Stream(
|
|
645
|
+
Sequence(
|
|
646
|
+
IncludeFilter(filter),
|
|
647
|
+
)
|
|
648
|
+
)
|
|
593
649
|
self.file = file_open(path)
|
|
594
650
|
|
|
595
651
|
def forward(self, sym: Symbol, **kwargs) -> Symbol:
|
|
@@ -599,42 +655,45 @@ class FileQuery(Expression):
|
|
|
599
655
|
|
|
600
656
|
|
|
601
657
|
class Function(TrackerTraceable):
|
|
602
|
-
def __init__(
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
658
|
+
def __init__(
|
|
659
|
+
self,
|
|
660
|
+
prompt: str = "",
|
|
661
|
+
examples: str | None = [],
|
|
662
|
+
pre_processors: list[PreProcessor] | None = None,
|
|
663
|
+
post_processors: list[PostProcessor] | None = None,
|
|
664
|
+
default: object | None = None,
|
|
665
|
+
constraints: list[Callable] | None = None,
|
|
666
|
+
return_type: type | None = str,
|
|
667
|
+
sym_return_type: type | None = Symbol,
|
|
668
|
+
origin_type: type | None = Expression,
|
|
669
|
+
*args,
|
|
670
|
+
**kwargs,
|
|
671
|
+
):
|
|
612
672
|
if constraints is None:
|
|
613
673
|
constraints = []
|
|
614
674
|
super().__init__(**kwargs)
|
|
615
|
-
chars
|
|
616
|
-
self.name
|
|
617
|
-
self.args
|
|
675
|
+
chars = ascii_lowercase + ascii_uppercase
|
|
676
|
+
self.name = "func_" + "".join(sample(chars, 15))
|
|
677
|
+
self.args = args
|
|
618
678
|
self.kwargs = kwargs
|
|
619
|
-
self._promptTemplate
|
|
620
|
-
self._promptFormatArgs
|
|
679
|
+
self._promptTemplate = prompt
|
|
680
|
+
self._promptFormatArgs = []
|
|
621
681
|
self._promptFormatKwargs = {}
|
|
622
|
-
self.examples
|
|
623
|
-
self.pre_processors
|
|
682
|
+
self.examples = Prompt(examples)
|
|
683
|
+
self.pre_processors = pre_processors
|
|
624
684
|
self.post_processors = post_processors
|
|
625
|
-
self.constraints
|
|
626
|
-
self.default
|
|
627
|
-
self.return_type
|
|
685
|
+
self.constraints = constraints
|
|
686
|
+
self.default = default
|
|
687
|
+
self.return_type = return_type
|
|
628
688
|
self.sym_return_type = sym_return_type
|
|
629
|
-
self.origin_type
|
|
689
|
+
self.origin_type = origin_type
|
|
630
690
|
|
|
631
691
|
@property
|
|
632
692
|
def prompt(self):
|
|
633
693
|
# return a copy of the prompt template
|
|
634
694
|
if len(self._promptFormatArgs) == 0 and len(self._promptFormatKwargs) == 0:
|
|
635
695
|
return self._promptTemplate
|
|
636
|
-
return f"{self._promptTemplate}".format(*self._promptFormatArgs,
|
|
637
|
-
**self._promptFormatKwargs)
|
|
696
|
+
return f"{self._promptTemplate}".format(*self._promptFormatArgs, **self._promptFormatKwargs)
|
|
638
697
|
|
|
639
698
|
def format(self, *args, **kwargs):
|
|
640
699
|
self._promptFormatArgs = args
|
|
@@ -642,9 +701,10 @@ class Function(TrackerTraceable):
|
|
|
642
701
|
|
|
643
702
|
def forward(self, *args, **kwargs) -> Expression:
|
|
644
703
|
# special case for few shot function prompt definition override
|
|
645
|
-
if
|
|
646
|
-
self.prompt = kwargs[
|
|
647
|
-
del kwargs[
|
|
704
|
+
if "fn" in kwargs:
|
|
705
|
+
self.prompt = kwargs["fn"]
|
|
706
|
+
del kwargs["fn"]
|
|
707
|
+
|
|
648
708
|
@core.few_shot(
|
|
649
709
|
*self.args,
|
|
650
710
|
prompt=self.prompt,
|
|
@@ -653,19 +713,24 @@ class Function(TrackerTraceable):
|
|
|
653
713
|
post_processors=self.post_processors,
|
|
654
714
|
constraints=self.constraints,
|
|
655
715
|
default=self.default,
|
|
656
|
-
**self.kwargs
|
|
716
|
+
**self.kwargs,
|
|
657
717
|
)
|
|
658
718
|
def _func(_, *args, **kwargs) -> self.return_type:
|
|
659
719
|
pass
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
720
|
+
|
|
721
|
+
_type = type(
|
|
722
|
+
self.name,
|
|
723
|
+
(self.origin_type,),
|
|
724
|
+
{
|
|
725
|
+
# constructor
|
|
726
|
+
"forward": _func,
|
|
727
|
+
"sym_return_type": self.sym_return_type,
|
|
728
|
+
"static_context": self.static_context,
|
|
729
|
+
"dynamic_context": self.dynamic_context,
|
|
730
|
+
"__class__": self.__class__,
|
|
731
|
+
"__module__": self.__module__,
|
|
732
|
+
},
|
|
733
|
+
)
|
|
669
734
|
obj = _type()
|
|
670
735
|
|
|
671
736
|
return self._to_symbol(obj(*args, **kwargs))
|
|
@@ -676,7 +741,7 @@ class PrepareData(Function):
|
|
|
676
741
|
def __call__(self, argument):
|
|
677
742
|
assert argument.prop.context is not None
|
|
678
743
|
instruct = argument.prop.prompt
|
|
679
|
-
context
|
|
744
|
+
context = argument.prop.context
|
|
680
745
|
return f"""{{
|
|
681
746
|
'context': '{context}',
|
|
682
747
|
'instruction': '{instruct}',
|
|
@@ -685,10 +750,10 @@ class PrepareData(Function):
|
|
|
685
750
|
|
|
686
751
|
def __init__(self, *args, **kwargs):
|
|
687
752
|
super().__init__(*args, **kwargs)
|
|
688
|
-
self.pre_processors
|
|
689
|
-
self.constraints
|
|
753
|
+
self.pre_processors = [self.PrepareDataPreProcessor()]
|
|
754
|
+
self.constraints = [DictFormatConstraint({"result": "<the data>"})]
|
|
690
755
|
self.post_processors = [JsonTruncateMarkdownPostProcessor()]
|
|
691
|
-
self.return_type
|
|
756
|
+
self.return_type = dict # constraint to cast the result to a dict
|
|
692
757
|
|
|
693
758
|
@property
|
|
694
759
|
def static_context(self):
|
|
@@ -723,7 +788,7 @@ Your goal is to prepare the data for the next task instruction. The data should
|
|
|
723
788
|
|
|
724
789
|
class ExpressionBuilder(Function):
|
|
725
790
|
def __init__(self, **kwargs):
|
|
726
|
-
super().__init__(
|
|
791
|
+
super().__init__("Generate the code following the instructions:", **kwargs)
|
|
727
792
|
self.processors = ProcessorPipeline([StripPostProcessor(), CodeExtractPostProcessor()])
|
|
728
793
|
|
|
729
794
|
def forward(self, instruct, *_args, **_kwargs):
|
|
@@ -774,10 +839,12 @@ Always produce the entire code to be executed in the same Python process. All ta
|
|
|
774
839
|
class JsonParser(Expression):
|
|
775
840
|
def __init__(self, query: str, json_: dict, **kwargs):
|
|
776
841
|
super().__init__(**kwargs)
|
|
777
|
-
func = Function(
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
842
|
+
func = Function(
|
|
843
|
+
prompt=JsonPromptTemplate(query, json_),
|
|
844
|
+
constraints=[DictFormatConstraint(json_)],
|
|
845
|
+
pre_processors=[JsonPreProcessor()],
|
|
846
|
+
post_processors=[JsonTruncatePostProcessor()],
|
|
847
|
+
)
|
|
781
848
|
self.fn = Try(func, retries=1)
|
|
782
849
|
|
|
783
850
|
def forward(self, sym: Symbol, **kwargs) -> Symbol:
|
|
@@ -787,21 +854,27 @@ class JsonParser(Expression):
|
|
|
787
854
|
|
|
788
855
|
|
|
789
856
|
class SimilarityClassification(Expression):
|
|
790
|
-
def __init__(
|
|
857
|
+
def __init__(
|
|
858
|
+
self, classes: list[str], metric: str = "cosine", in_memory: bool = False, **kwargs
|
|
859
|
+
):
|
|
791
860
|
super().__init__(**kwargs)
|
|
792
|
-
self.classes
|
|
793
|
-
self.metric
|
|
861
|
+
self.classes = classes
|
|
862
|
+
self.metric = metric
|
|
794
863
|
self.in_memory = in_memory
|
|
795
864
|
|
|
796
865
|
if self.in_memory:
|
|
797
|
-
UserMessage(
|
|
866
|
+
UserMessage(
|
|
867
|
+
f"Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache"
|
|
868
|
+
)
|
|
798
869
|
|
|
799
870
|
def forward(self, x: Symbol) -> Symbol:
|
|
800
|
-
x
|
|
801
|
-
usr_embed
|
|
802
|
-
embeddings
|
|
871
|
+
x = self._to_symbol(x)
|
|
872
|
+
usr_embed = x.embed()
|
|
873
|
+
embeddings = self._dynamic_cache()
|
|
803
874
|
similarities = [usr_embed.similarity(emb, metric=self.metric) for emb in embeddings]
|
|
804
|
-
similarities = sorted(
|
|
875
|
+
similarities = sorted(
|
|
876
|
+
zip(self.classes, similarities, strict=False), key=lambda x: x[1], reverse=True
|
|
877
|
+
)
|
|
805
878
|
|
|
806
879
|
return Symbol(similarities[0][0])
|
|
807
880
|
|
|
@@ -820,11 +893,7 @@ class InContextClassification(Expression):
|
|
|
820
893
|
self.blueprint = blueprint
|
|
821
894
|
|
|
822
895
|
def forward(self, x: Symbol, **kwargs) -> Symbol:
|
|
823
|
-
@core.few_shot(
|
|
824
|
-
prompt=x,
|
|
825
|
-
examples=self.blueprint,
|
|
826
|
-
**kwargs
|
|
827
|
-
)
|
|
896
|
+
@core.few_shot(prompt=x, examples=self.blueprint, **kwargs)
|
|
828
897
|
def _func(_):
|
|
829
898
|
pass
|
|
830
899
|
|
|
@@ -832,38 +901,38 @@ class InContextClassification(Expression):
|
|
|
832
901
|
|
|
833
902
|
|
|
834
903
|
class Indexer(Expression):
|
|
835
|
-
DEFAULT =
|
|
904
|
+
DEFAULT = "dataindex"
|
|
836
905
|
|
|
837
906
|
@staticmethod
|
|
838
907
|
def replace_special_chars(index: str):
|
|
839
908
|
# replace special characters that are not for path
|
|
840
|
-
return str(index).replace(
|
|
909
|
+
return str(index).replace("-", "").replace("_", "").replace(" ", "").lower()
|
|
841
910
|
|
|
842
911
|
def __init__(
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
912
|
+
self,
|
|
913
|
+
index_name: str = DEFAULT,
|
|
914
|
+
top_k: int = 8,
|
|
915
|
+
batch_size: int = 20,
|
|
916
|
+
formatter: Callable = _DEFAULT_PARAGRAPH_FORMATTER,
|
|
917
|
+
auto_add=False,
|
|
918
|
+
raw_result: bool = False,
|
|
919
|
+
new_dim: int = 1536,
|
|
920
|
+
**kwargs,
|
|
921
|
+
):
|
|
853
922
|
super().__init__(**kwargs)
|
|
854
923
|
index_name = Indexer.replace_special_chars(index_name)
|
|
855
924
|
self.index_name = index_name
|
|
856
|
-
self.elements
|
|
925
|
+
self.elements = []
|
|
857
926
|
self.batch_size = batch_size
|
|
858
|
-
self.top_k
|
|
859
|
-
self.retrieval
|
|
860
|
-
self.formatter
|
|
927
|
+
self.top_k = top_k
|
|
928
|
+
self.retrieval = None
|
|
929
|
+
self.formatter = formatter
|
|
861
930
|
self.raw_result = raw_result
|
|
862
|
-
self.new_dim
|
|
931
|
+
self.new_dim = new_dim
|
|
863
932
|
self.sym_return_type = Expression
|
|
864
933
|
|
|
865
934
|
# append index name to indices.txt in home directory .symai folder (default)
|
|
866
|
-
self.path = HOME_PATH /
|
|
935
|
+
self.path = HOME_PATH / "indices.txt"
|
|
867
936
|
if not self.path.exists():
|
|
868
937
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
|
869
938
|
self.path.touch()
|
|
@@ -874,51 +943,62 @@ class Indexer(Expression):
|
|
|
874
943
|
# check if index already exists in indices.txt and append if not
|
|
875
944
|
change = False
|
|
876
945
|
with self.path.open() as f:
|
|
877
|
-
indices = f.read().split(
|
|
946
|
+
indices = f.read().split("\n")
|
|
878
947
|
# filter out empty strings
|
|
879
948
|
indices = [i for i in indices if i]
|
|
880
949
|
if self.index_name not in indices:
|
|
881
|
-
|
|
882
|
-
|
|
950
|
+
indices.append(self.index_name)
|
|
951
|
+
change = True
|
|
883
952
|
if change:
|
|
884
|
-
with self.path.open(
|
|
885
|
-
f.write(
|
|
953
|
+
with self.path.open("w") as f:
|
|
954
|
+
f.write("\n".join(indices))
|
|
886
955
|
|
|
887
956
|
def exists(self) -> bool:
|
|
888
957
|
# check if index exists in home directory .symai folder (default) indices.txt
|
|
889
|
-
path = HOME_PATH /
|
|
958
|
+
path = HOME_PATH / "indices.txt"
|
|
890
959
|
if not path.exists():
|
|
891
960
|
return False
|
|
892
961
|
with path.open() as f:
|
|
893
|
-
indices = f.read().split(
|
|
962
|
+
indices = f.read().split("\n")
|
|
894
963
|
if self.index_name in indices:
|
|
895
964
|
return True
|
|
896
965
|
return False
|
|
897
966
|
|
|
898
967
|
def forward(
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
968
|
+
self,
|
|
969
|
+
data: Symbol | None = None,
|
|
970
|
+
_raw_result: bool = False,
|
|
971
|
+
) -> Symbol:
|
|
903
972
|
that = self
|
|
904
973
|
if data is not None:
|
|
905
974
|
data = self._to_symbol(data)
|
|
906
975
|
self.elements = self.formatter(data).value
|
|
907
976
|
# run over the elments in batches
|
|
908
977
|
for i in tqdm(range(0, len(self.elements), self.batch_size)):
|
|
909
|
-
val = Symbol(self.elements[i:i+self.batch_size]).zip(new_dim=self.new_dim)
|
|
978
|
+
val = Symbol(self.elements[i : i + self.batch_size]).zip(new_dim=self.new_dim)
|
|
910
979
|
that.add(val, index_name=that.index_name, index_dims=that.new_dim)
|
|
911
980
|
# we save the index
|
|
912
981
|
that.config(None, save=True, index_name=that.index_name, index_dims=that.new_dim)
|
|
913
982
|
|
|
914
|
-
def _func(query, *_args, **kwargs) -> Union[Symbol,
|
|
915
|
-
raw_result = kwargs.get(
|
|
983
|
+
def _func(query, *_args, **kwargs) -> Union[Symbol, "VectorDBResult"]:
|
|
984
|
+
raw_result = kwargs.get("raw_result") or that.raw_result
|
|
916
985
|
query_emb = Symbol(query).embed(new_dim=that.new_dim).value
|
|
917
|
-
res = that.get(
|
|
986
|
+
res = that.get(
|
|
987
|
+
query_emb,
|
|
988
|
+
index_name=that.index_name,
|
|
989
|
+
index_top_k=that.top_k,
|
|
990
|
+
ori_query=query,
|
|
991
|
+
index_dims=that.new_dim,
|
|
992
|
+
**kwargs,
|
|
993
|
+
)
|
|
918
994
|
that.retrieval = res
|
|
919
995
|
if raw_result:
|
|
920
996
|
return res
|
|
921
|
-
return Symbol(res).query(
|
|
997
|
+
return Symbol(res).query(
|
|
998
|
+
prompt="From the retrieved data, select the most relevant information.",
|
|
999
|
+
context=query,
|
|
1000
|
+
)
|
|
1001
|
+
|
|
922
1002
|
return _func
|
|
923
1003
|
|
|
924
1004
|
|
|
@@ -930,7 +1010,7 @@ class PrimitiveDisabler(Expression):
|
|
|
930
1010
|
|
|
931
1011
|
def __enter__(self):
|
|
932
1012
|
# Import Symbol lazily so components does not clash with symbol during load.
|
|
933
|
-
from .symbol import Symbol
|
|
1013
|
+
from .symbol import Symbol # noqa
|
|
934
1014
|
|
|
935
1015
|
frame = inspect.currentframe()
|
|
936
1016
|
f_locals = frame.f_back.f_locals
|
|
@@ -957,7 +1037,7 @@ class PrimitiveDisabler(Expression):
|
|
|
957
1037
|
for sym in self._symbols.values():
|
|
958
1038
|
for primitive in sym._primitives:
|
|
959
1039
|
for method, _ in inspect.getmembers(primitive, predicate=inspect.isfunction):
|
|
960
|
-
if method in self._primitives or method.startswith(
|
|
1040
|
+
if method in self._primitives or method.startswith("_"):
|
|
961
1041
|
continue
|
|
962
1042
|
self._primitives.add(method)
|
|
963
1043
|
|
|
@@ -1002,9 +1082,7 @@ class FunctionWithUsage(Function):
|
|
|
1002
1082
|
self.total_tokens += usage.total_tokens
|
|
1003
1083
|
|
|
1004
1084
|
def get_usage(self):
|
|
1005
|
-
return self._format_usage(
|
|
1006
|
-
self.prompt_tokens, self.completion_tokens, self.total_tokens
|
|
1007
|
-
)
|
|
1085
|
+
return self._format_usage(self.prompt_tokens, self.completion_tokens, self.total_tokens)
|
|
1008
1086
|
|
|
1009
1087
|
def forward(self, *args, **kwargs):
|
|
1010
1088
|
if "return_metadata" not in kwargs:
|
|
@@ -1015,9 +1093,7 @@ class FunctionWithUsage(Function):
|
|
|
1015
1093
|
raw_output = metadata.get("raw_output")
|
|
1016
1094
|
if hasattr(raw_output, "usage"):
|
|
1017
1095
|
usage = raw_output.usage
|
|
1018
|
-
prompt_tokens = (
|
|
1019
|
-
usage.prompt_tokens if hasattr(usage, "prompt_tokens") else 0
|
|
1020
|
-
)
|
|
1096
|
+
prompt_tokens = usage.prompt_tokens if hasattr(usage, "prompt_tokens") else 0
|
|
1021
1097
|
completion_tokens = (
|
|
1022
1098
|
usage.completion_tokens if hasattr(usage, "completion_tokens") else 0
|
|
1023
1099
|
)
|
|
@@ -1033,7 +1109,9 @@ class FunctionWithUsage(Function):
|
|
|
1033
1109
|
self.total_tokens += total_tokens
|
|
1034
1110
|
else:
|
|
1035
1111
|
if self.missing_usage_exception and "preview" not in kwargs:
|
|
1036
|
-
UserMessage(
|
|
1112
|
+
UserMessage(
|
|
1113
|
+
"Missing usage in metadata of neursymbolic engine", raise_with=Exception
|
|
1114
|
+
)
|
|
1037
1115
|
prompt_tokens = 0
|
|
1038
1116
|
completion_tokens = 0
|
|
1039
1117
|
total_tokens = 0
|
|
@@ -1042,12 +1120,12 @@ class FunctionWithUsage(Function):
|
|
|
1042
1120
|
|
|
1043
1121
|
|
|
1044
1122
|
class SelfPrompt(Expression):
|
|
1045
|
-
_default_retry_tries
|
|
1046
|
-
_default_retry_delay
|
|
1123
|
+
_default_retry_tries = 20
|
|
1124
|
+
_default_retry_delay = 0.5
|
|
1047
1125
|
_default_retry_max_delay = -1
|
|
1048
|
-
_default_retry_backoff
|
|
1049
|
-
_default_retry_jitter
|
|
1050
|
-
_default_retry_graceful
|
|
1126
|
+
_default_retry_backoff = 1
|
|
1127
|
+
_default_retry_jitter = 0
|
|
1128
|
+
_default_retry_graceful = True
|
|
1051
1129
|
|
|
1052
1130
|
def __init__(self, *args, **kwargs):
|
|
1053
1131
|
super().__init__(*args, **kwargs)
|
|
@@ -1061,14 +1139,21 @@ class SelfPrompt(Expression):
|
|
|
1061
1139
|
:return: A dictionary containing the new prompts in the same format:
|
|
1062
1140
|
{'user': '...', 'system': '...'}
|
|
1063
1141
|
"""
|
|
1064
|
-
tries
|
|
1065
|
-
delay
|
|
1066
|
-
max_delay = kwargs.get(
|
|
1067
|
-
backoff
|
|
1068
|
-
jitter
|
|
1069
|
-
graceful
|
|
1070
|
-
|
|
1071
|
-
@core_ext.retry(
|
|
1142
|
+
tries = kwargs.get("tries", self._default_retry_tries)
|
|
1143
|
+
delay = kwargs.get("delay", self._default_retry_delay)
|
|
1144
|
+
max_delay = kwargs.get("max_delay", self._default_retry_max_delay)
|
|
1145
|
+
backoff = kwargs.get("backoff", self._default_retry_backoff)
|
|
1146
|
+
jitter = kwargs.get("jitter", self._default_retry_jitter)
|
|
1147
|
+
graceful = kwargs.get("graceful", self._default_retry_graceful)
|
|
1148
|
+
|
|
1149
|
+
@core_ext.retry(
|
|
1150
|
+
tries=tries,
|
|
1151
|
+
delay=delay,
|
|
1152
|
+
max_delay=max_delay,
|
|
1153
|
+
backoff=backoff,
|
|
1154
|
+
jitter=jitter,
|
|
1155
|
+
graceful=graceful,
|
|
1156
|
+
)
|
|
1072
1157
|
@core.zero_shot(
|
|
1073
1158
|
prompt=(
|
|
1074
1159
|
"Based on the following prompt, generate a new system (or developer) prompt and a new user prompt. "
|
|
@@ -1077,18 +1162,19 @@ class SelfPrompt(Expression):
|
|
|
1077
1162
|
"The new user prompt should contain the user's requirements. "
|
|
1078
1163
|
"Check if the input contains a 'system' or 'developer' key and use the same key in your output. "
|
|
1079
1164
|
"Only output the new prompts in JSON format as shown:\n\n"
|
|
1080
|
-
|
|
1165
|
+
'{"system": "<new system prompt>", "user": "<new user prompt>"}\n\n'
|
|
1081
1166
|
"OR\n\n"
|
|
1082
|
-
|
|
1167
|
+
'{"developer": "<new developer prompt>", "user": "<new user prompt>"}\n\n'
|
|
1083
1168
|
"Maintain the same key structure as in the input prompt. Do not include any additional text."
|
|
1084
1169
|
),
|
|
1085
1170
|
response_format={"type": "json_object"},
|
|
1086
1171
|
post_processors=[
|
|
1087
1172
|
lambda res, _: json.loads(res),
|
|
1088
1173
|
],
|
|
1089
|
-
**kwargs
|
|
1174
|
+
**kwargs,
|
|
1090
1175
|
)
|
|
1091
|
-
def _func(self, sym: Symbol):
|
|
1176
|
+
def _func(self, sym: Symbol):
|
|
1177
|
+
pass
|
|
1092
1178
|
|
|
1093
1179
|
return _func(self, self._to_symbol(existing_prompt))
|
|
1094
1180
|
|
|
@@ -1104,15 +1190,19 @@ class MetadataTracker(Expression):
|
|
|
1104
1190
|
def __str__(self, value=None):
|
|
1105
1191
|
value = value or self.metadata
|
|
1106
1192
|
if isinstance(value, dict):
|
|
1107
|
-
return
|
|
1193
|
+
return (
|
|
1194
|
+
"{\n\t"
|
|
1195
|
+
+ ", \n\t".join(f'"{k}": {self.__str__(v)}' for k, v in value.items())
|
|
1196
|
+
+ "\n}"
|
|
1197
|
+
)
|
|
1108
1198
|
if isinstance(value, list):
|
|
1109
|
-
return
|
|
1199
|
+
return "[" + ", ".join(self.__str__(item) for item in value) + "]"
|
|
1110
1200
|
if isinstance(value, str):
|
|
1111
1201
|
return f'"{value}"'
|
|
1112
1202
|
return f"\n\t {value}"
|
|
1113
1203
|
|
|
1114
1204
|
def __new__(cls, *_args, **_kwargs):
|
|
1115
|
-
cls._lock = getattr(cls,
|
|
1205
|
+
cls._lock = getattr(cls, "_lock", Lock())
|
|
1116
1206
|
with cls._lock:
|
|
1117
1207
|
instance = super().__new__(cls)
|
|
1118
1208
|
instance._metadata = {}
|
|
@@ -1135,14 +1225,14 @@ class MetadataTracker(Expression):
|
|
|
1135
1225
|
return None
|
|
1136
1226
|
|
|
1137
1227
|
if (
|
|
1138
|
-
event ==
|
|
1139
|
-
and frame.f_code.co_name ==
|
|
1140
|
-
and
|
|
1141
|
-
and isinstance(frame.f_locals[
|
|
1228
|
+
event == "return"
|
|
1229
|
+
and frame.f_code.co_name == "forward"
|
|
1230
|
+
and "self" in frame.f_locals
|
|
1231
|
+
and isinstance(frame.f_locals["self"], Engine)
|
|
1142
1232
|
):
|
|
1143
1233
|
_, metadata = arg # arg contains return value on 'return' event
|
|
1144
|
-
engine_name = frame.f_locals[
|
|
1145
|
-
model_name = frame.f_locals[
|
|
1234
|
+
engine_name = frame.f_locals["self"].__class__.__name__
|
|
1235
|
+
model_name = frame.f_locals["self"].model
|
|
1146
1236
|
self._metadata[(self._metadata_id, engine_name, model_name)] = metadata
|
|
1147
1237
|
self._metadata_id += 1
|
|
1148
1238
|
|
|
@@ -1162,38 +1252,91 @@ class MetadataTracker(Expression):
|
|
|
1162
1252
|
try:
|
|
1163
1253
|
if engine_name == "GroqEngine":
|
|
1164
1254
|
usage = metadata["raw_output"].usage
|
|
1165
|
-
token_details[(engine_name, model_name)]["usage"]["completion_tokens"] +=
|
|
1166
|
-
|
|
1167
|
-
|
|
1255
|
+
token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
|
|
1256
|
+
usage.completion_tokens
|
|
1257
|
+
)
|
|
1258
|
+
token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
|
|
1259
|
+
usage.prompt_tokens
|
|
1260
|
+
)
|
|
1261
|
+
token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
|
|
1262
|
+
usage.total_tokens
|
|
1263
|
+
)
|
|
1168
1264
|
token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
|
|
1169
1265
|
#!: Backward compatibility for components like `RuntimeInfo`
|
|
1170
|
-
token_details[(engine_name, model_name)]["prompt_breakdown"][
|
|
1171
|
-
|
|
1266
|
+
token_details[(engine_name, model_name)]["prompt_breakdown"][
|
|
1267
|
+
"cached_tokens"
|
|
1268
|
+
] += 0 # Assignment not allowed with defualtdict
|
|
1269
|
+
token_details[(engine_name, model_name)]["completion_breakdown"][
|
|
1270
|
+
"reasoning_tokens"
|
|
1271
|
+
] += 0
|
|
1272
|
+
elif engine_name == "ParallelEngine":
|
|
1273
|
+
token_details[(engine_name, None)]["usage"]["total_calls"] += 1
|
|
1274
|
+
# There are no model-specific tokens for this engine
|
|
1275
|
+
token_details[(engine_name, None)]["usage"]["completion_tokens"] += 0
|
|
1276
|
+
token_details[(engine_name, None)]["usage"]["prompt_tokens"] += 0
|
|
1277
|
+
token_details[(engine_name, None)]["usage"]["total_tokens"] += 0
|
|
1278
|
+
#!: Backward compatibility for components like `RuntimeInfo`
|
|
1279
|
+
token_details[(engine_name, None)]["prompt_breakdown"]["cached_tokens"] += (
|
|
1280
|
+
0 # Assignment not allowed with defualtdict
|
|
1281
|
+
)
|
|
1282
|
+
token_details[(engine_name, None)]["completion_breakdown"][
|
|
1283
|
+
"reasoning_tokens"
|
|
1284
|
+
] += 0
|
|
1172
1285
|
elif engine_name in ("GPTXChatEngine", "GPTXReasoningEngine"):
|
|
1173
1286
|
usage = metadata["raw_output"].usage
|
|
1174
|
-
token_details[(engine_name, model_name)]["usage"]["completion_tokens"] +=
|
|
1175
|
-
|
|
1176
|
-
|
|
1287
|
+
token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
|
|
1288
|
+
usage.completion_tokens
|
|
1289
|
+
)
|
|
1290
|
+
token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
|
|
1291
|
+
usage.prompt_tokens
|
|
1292
|
+
)
|
|
1293
|
+
token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
|
|
1294
|
+
usage.total_tokens
|
|
1295
|
+
)
|
|
1177
1296
|
token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
|
|
1178
|
-
token_details[(engine_name, model_name)]["completion_breakdown"][
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
token_details[(engine_name, model_name)]["completion_breakdown"][
|
|
1182
|
-
|
|
1183
|
-
|
|
1297
|
+
token_details[(engine_name, model_name)]["completion_breakdown"][
|
|
1298
|
+
"accepted_prediction_tokens"
|
|
1299
|
+
] += usage.completion_tokens_details.accepted_prediction_tokens
|
|
1300
|
+
token_details[(engine_name, model_name)]["completion_breakdown"][
|
|
1301
|
+
"rejected_prediction_tokens"
|
|
1302
|
+
] += usage.completion_tokens_details.rejected_prediction_tokens
|
|
1303
|
+
token_details[(engine_name, model_name)]["completion_breakdown"][
|
|
1304
|
+
"audio_tokens"
|
|
1305
|
+
] += usage.completion_tokens_details.audio_tokens
|
|
1306
|
+
token_details[(engine_name, model_name)]["completion_breakdown"][
|
|
1307
|
+
"reasoning_tokens"
|
|
1308
|
+
] += usage.completion_tokens_details.reasoning_tokens
|
|
1309
|
+
token_details[(engine_name, model_name)]["prompt_breakdown"][
|
|
1310
|
+
"audio_tokens"
|
|
1311
|
+
] += usage.prompt_tokens_details.audio_tokens
|
|
1312
|
+
token_details[(engine_name, model_name)]["prompt_breakdown"][
|
|
1313
|
+
"cached_tokens"
|
|
1314
|
+
] += usage.prompt_tokens_details.cached_tokens
|
|
1184
1315
|
elif engine_name == "GPTXSearchEngine":
|
|
1185
1316
|
usage = metadata["raw_output"].usage
|
|
1186
|
-
token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] +=
|
|
1187
|
-
|
|
1188
|
-
|
|
1317
|
+
token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
|
|
1318
|
+
usage.input_tokens
|
|
1319
|
+
)
|
|
1320
|
+
token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
|
|
1321
|
+
usage.output_tokens
|
|
1322
|
+
)
|
|
1323
|
+
token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
|
|
1324
|
+
usage.total_tokens
|
|
1325
|
+
)
|
|
1189
1326
|
token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
|
|
1190
|
-
token_details[(engine_name, model_name)]["prompt_breakdown"][
|
|
1191
|
-
|
|
1327
|
+
token_details[(engine_name, model_name)]["prompt_breakdown"][
|
|
1328
|
+
"cached_tokens"
|
|
1329
|
+
] += usage.input_tokens_details.cached_tokens
|
|
1330
|
+
token_details[(engine_name, model_name)]["completion_breakdown"][
|
|
1331
|
+
"reasoning_tokens"
|
|
1332
|
+
] += usage.output_tokens_details.reasoning_tokens
|
|
1192
1333
|
else:
|
|
1193
1334
|
logger.warning(f"Tracking {engine_name} is not supported.")
|
|
1194
1335
|
continue
|
|
1195
1336
|
except Exception as e:
|
|
1196
|
-
UserMessage(
|
|
1337
|
+
UserMessage(
|
|
1338
|
+
f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError
|
|
1339
|
+
)
|
|
1197
1340
|
|
|
1198
1341
|
# Convert to normal dict
|
|
1199
1342
|
return {**token_details}
|
|
@@ -1203,22 +1346,24 @@ class MetadataTracker(Expression):
|
|
|
1203
1346
|
return engine_name in supported_engines
|
|
1204
1347
|
|
|
1205
1348
|
def _accumulate_time_field(self, accumulated: dict, metadata: dict) -> None:
|
|
1206
|
-
if
|
|
1207
|
-
accumulated[
|
|
1349
|
+
if "time" in metadata and "time" in accumulated:
|
|
1350
|
+
accumulated["time"] += metadata["time"]
|
|
1208
1351
|
|
|
1209
1352
|
def _accumulate_usage_fields(self, accumulated: dict, metadata: dict) -> None:
|
|
1210
|
-
if
|
|
1353
|
+
if "raw_output" not in metadata or "raw_output" not in accumulated:
|
|
1211
1354
|
return
|
|
1212
1355
|
|
|
1213
|
-
metadata_raw_output = metadata[
|
|
1214
|
-
accumulated_raw_output = accumulated[
|
|
1215
|
-
if not hasattr(metadata_raw_output,
|
|
1356
|
+
metadata_raw_output = metadata["raw_output"]
|
|
1357
|
+
accumulated_raw_output = accumulated["raw_output"]
|
|
1358
|
+
if not hasattr(metadata_raw_output, "usage") or not hasattr(
|
|
1359
|
+
accumulated_raw_output, "usage"
|
|
1360
|
+
):
|
|
1216
1361
|
return
|
|
1217
1362
|
|
|
1218
1363
|
current_usage = metadata_raw_output.usage
|
|
1219
1364
|
accumulated_usage = accumulated_raw_output.usage
|
|
1220
1365
|
|
|
1221
|
-
for attr in [
|
|
1366
|
+
for attr in ["completion_tokens", "prompt_tokens", "total_tokens"]:
|
|
1222
1367
|
if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
|
|
1223
1368
|
setattr(
|
|
1224
1369
|
accumulated_usage,
|
|
@@ -1226,20 +1371,24 @@ class MetadataTracker(Expression):
|
|
|
1226
1371
|
getattr(accumulated_usage, attr) + getattr(current_usage, attr),
|
|
1227
1372
|
)
|
|
1228
1373
|
|
|
1229
|
-
for detail_attr in [
|
|
1230
|
-
if not hasattr(current_usage, detail_attr) or not hasattr(
|
|
1374
|
+
for detail_attr in ["completion_tokens_details", "prompt_tokens_details"]:
|
|
1375
|
+
if not hasattr(current_usage, detail_attr) or not hasattr(
|
|
1376
|
+
accumulated_usage, detail_attr
|
|
1377
|
+
):
|
|
1231
1378
|
continue
|
|
1232
1379
|
|
|
1233
1380
|
current_details = getattr(current_usage, detail_attr)
|
|
1234
1381
|
accumulated_details = getattr(accumulated_usage, detail_attr)
|
|
1235
1382
|
|
|
1236
1383
|
for attr in dir(current_details):
|
|
1237
|
-
if attr.startswith(
|
|
1384
|
+
if attr.startswith("_") or not hasattr(accumulated_details, attr):
|
|
1238
1385
|
continue
|
|
1239
1386
|
|
|
1240
1387
|
current_val = getattr(current_details, attr)
|
|
1241
1388
|
accumulated_val = getattr(accumulated_details, attr)
|
|
1242
|
-
if isinstance(current_val, (int, float)) and isinstance(
|
|
1389
|
+
if isinstance(current_val, (int, float)) and isinstance(
|
|
1390
|
+
accumulated_val, (int, float)
|
|
1391
|
+
):
|
|
1243
1392
|
setattr(accumulated_details, attr, accumulated_val + current_val)
|
|
1244
1393
|
|
|
1245
1394
|
def _accumulate_metadata(self):
|
|
@@ -1255,7 +1404,9 @@ class MetadataTracker(Expression):
|
|
|
1255
1404
|
# Skipz first entry
|
|
1256
1405
|
for (_, engine_name), metadata in list(self._metadata.items())[1:]:
|
|
1257
1406
|
if not self._can_accumulate_engine(engine_name):
|
|
1258
|
-
logger.warning(
|
|
1407
|
+
logger.warning(
|
|
1408
|
+
f"Metadata accumulation for {engine_name} is not supported. Try `.usage` instead for now."
|
|
1409
|
+
)
|
|
1259
1410
|
continue
|
|
1260
1411
|
|
|
1261
1412
|
self._accumulate_time_field(accumulated, metadata)
|
|
@@ -1278,6 +1429,7 @@ class MetadataTracker(Expression):
|
|
|
1278
1429
|
|
|
1279
1430
|
class DynamicEngine(Expression):
|
|
1280
1431
|
"""Context manager for dynamically switching neurosymbolic engine models."""
|
|
1432
|
+
|
|
1281
1433
|
def __init__(self, model: str, api_key: str, _debug: bool = False, **_kwargs):
|
|
1282
1434
|
super().__init__()
|
|
1283
1435
|
self.model = model
|
|
@@ -1288,7 +1440,7 @@ class DynamicEngine(Expression):
|
|
|
1288
1440
|
self._ctx_token = None
|
|
1289
1441
|
|
|
1290
1442
|
def __new__(cls, *_args, **_kwargs):
|
|
1291
|
-
cls._lock = getattr(cls,
|
|
1443
|
+
cls._lock = getattr(cls, "_lock", Lock())
|
|
1292
1444
|
with cls._lock:
|
|
1293
1445
|
instance = super().__new__(cls)
|
|
1294
1446
|
instance._metadata = {}
|
|
@@ -1322,11 +1474,177 @@ class DynamicEngine(Expression):
|
|
|
1322
1474
|
def _create_engine_instance(self):
|
|
1323
1475
|
"""Create an engine instance based on the model name."""
|
|
1324
1476
|
# Deferred to avoid components <-> neurosymbolic engine circular imports.
|
|
1325
|
-
from .backend.engines.neurosymbolic import ENGINE_MAPPING
|
|
1477
|
+
from .backend.engines.neurosymbolic import ENGINE_MAPPING # noqa
|
|
1478
|
+
|
|
1326
1479
|
try:
|
|
1327
1480
|
engine_class = ENGINE_MAPPING.get(self.model)
|
|
1328
1481
|
if engine_class is None:
|
|
1329
1482
|
UserMessage(f"Unsupported model '{self.model}'", raise_with=ValueError)
|
|
1330
1483
|
return engine_class(api_key=self.api_key, model=self.model)
|
|
1331
1484
|
except Exception as e:
|
|
1332
|
-
UserMessage(
|
|
1485
|
+
UserMessage(
|
|
1486
|
+
f"Failed to create engine for model '{self.model}': {e!s}", raise_with=ValueError
|
|
1487
|
+
)
|
|
1488
|
+
|
|
1489
|
+
|
|
1490
|
+
# Chonkie chunker imports - lazy loaded
|
|
1491
|
+
_CHONKIE_MODULES = None
|
|
1492
|
+
_CHUNKER_MAPPING = None
|
|
1493
|
+
_CHONKIE_AVAILABLE = None
|
|
1494
|
+
|
|
1495
|
+
|
|
1496
|
+
def _lazy_import_chonkie():
|
|
1497
|
+
"""Lazily import chonkie modules when needed."""
|
|
1498
|
+
global _CHONKIE_MODULES, _CHUNKER_MAPPING, _CHONKIE_AVAILABLE
|
|
1499
|
+
|
|
1500
|
+
if _CHONKIE_MODULES is not None:
|
|
1501
|
+
return _CHONKIE_MODULES
|
|
1502
|
+
|
|
1503
|
+
try:
|
|
1504
|
+
from chonkie import ( # noqa
|
|
1505
|
+
CodeChunker,
|
|
1506
|
+
LateChunker,
|
|
1507
|
+
NeuralChunker,
|
|
1508
|
+
RecursiveChunker,
|
|
1509
|
+
SemanticChunker,
|
|
1510
|
+
SentenceChunker,
|
|
1511
|
+
SlumberChunker,
|
|
1512
|
+
TableChunker,
|
|
1513
|
+
TokenChunker,
|
|
1514
|
+
)
|
|
1515
|
+
from chonkie.embeddings.base import BaseEmbeddings # noqa
|
|
1516
|
+
from tokenizers import Tokenizer # noqa
|
|
1517
|
+
|
|
1518
|
+
_CHONKIE_MODULES = {
|
|
1519
|
+
"CodeChunker": CodeChunker,
|
|
1520
|
+
"LateChunker": LateChunker,
|
|
1521
|
+
"NeuralChunker": NeuralChunker,
|
|
1522
|
+
"RecursiveChunker": RecursiveChunker,
|
|
1523
|
+
"SemanticChunker": SemanticChunker,
|
|
1524
|
+
"SentenceChunker": SentenceChunker,
|
|
1525
|
+
"SlumberChunker": SlumberChunker,
|
|
1526
|
+
"TableChunker": TableChunker,
|
|
1527
|
+
"TokenChunker": TokenChunker,
|
|
1528
|
+
"BaseEmbeddings": BaseEmbeddings,
|
|
1529
|
+
"Tokenizer": Tokenizer,
|
|
1530
|
+
}
|
|
1531
|
+
_CHUNKER_MAPPING = {
|
|
1532
|
+
"TokenChunker": TokenChunker,
|
|
1533
|
+
"SentenceChunker": SentenceChunker,
|
|
1534
|
+
"RecursiveChunker": RecursiveChunker,
|
|
1535
|
+
"SemanticChunker": SemanticChunker,
|
|
1536
|
+
"CodeChunker": CodeChunker,
|
|
1537
|
+
"LateChunker": LateChunker,
|
|
1538
|
+
"NeuralChunker": NeuralChunker,
|
|
1539
|
+
"SlumberChunker": SlumberChunker,
|
|
1540
|
+
"TableChunker": TableChunker,
|
|
1541
|
+
}
|
|
1542
|
+
_CHONKIE_AVAILABLE = True
|
|
1543
|
+
except ImportError:
|
|
1544
|
+
_CHONKIE_MODULES = {}
|
|
1545
|
+
_CHUNKER_MAPPING = {}
|
|
1546
|
+
_CHONKIE_AVAILABLE = False
|
|
1547
|
+
|
|
1548
|
+
return _CHONKIE_MODULES
|
|
1549
|
+
|
|
1550
|
+
|
|
1551
|
+
def _get_chunker_mapping():
|
|
1552
|
+
"""Get the chunker mapping, lazily importing chonkie if needed."""
|
|
1553
|
+
if _CHUNKER_MAPPING is None:
|
|
1554
|
+
_lazy_import_chonkie()
|
|
1555
|
+
return _CHUNKER_MAPPING or {}
|
|
1556
|
+
|
|
1557
|
+
|
|
1558
|
+
def _is_chonkie_available():
|
|
1559
|
+
"""Check if chonkie is available, lazily importing if needed."""
|
|
1560
|
+
if _CHONKIE_AVAILABLE is None:
|
|
1561
|
+
_lazy_import_chonkie()
|
|
1562
|
+
return _CHONKIE_AVAILABLE or False
|
|
1563
|
+
|
|
1564
|
+
|
|
1565
|
+
@beartype
|
|
1566
|
+
class ChonkieChunker(Expression):
|
|
1567
|
+
def __init__(
|
|
1568
|
+
self,
|
|
1569
|
+
tokenizer_name: str | None = "gpt2",
|
|
1570
|
+
embedding_model_name: str | None = "minishlab/potion-base-8M",
|
|
1571
|
+
**symai_kwargs,
|
|
1572
|
+
):
|
|
1573
|
+
super().__init__(**symai_kwargs)
|
|
1574
|
+
self.tokenizer_name = tokenizer_name
|
|
1575
|
+
self.embedding_model_name = embedding_model_name
|
|
1576
|
+
|
|
1577
|
+
def forward(
|
|
1578
|
+
self, data: Symbol, chunker_name: str | None = "RecursiveChunker", **chunker_kwargs
|
|
1579
|
+
) -> Symbol:
|
|
1580
|
+
if not _is_chonkie_available():
|
|
1581
|
+
UserMessage(
|
|
1582
|
+
"chonkie library is not installed. Please install it with `pip install chonkie tokenizers`.",
|
|
1583
|
+
raise_with=ImportError,
|
|
1584
|
+
)
|
|
1585
|
+
chunker = self._resolve_chunker(chunker_name, **chunker_kwargs)
|
|
1586
|
+
chunks = [ChonkieChunker.clean_text(chunk.text) for chunk in chunker(data.value)]
|
|
1587
|
+
return self._to_symbol(chunks)
|
|
1588
|
+
|
|
1589
|
+
def _resolve_chunker(self, chunker_name: str, **chunker_kwargs):
|
|
1590
|
+
"""Resolve and instantiate a chunker by name."""
|
|
1591
|
+
chunker_mapping = _get_chunker_mapping()
|
|
1592
|
+
|
|
1593
|
+
if chunker_name not in chunker_mapping:
|
|
1594
|
+
msg = (
|
|
1595
|
+
f"Chunker {chunker_name} not found. Available chunkers: {list(chunker_mapping.keys())}. "
|
|
1596
|
+
f"See docs (https://docs.chonkie.ai/getting-started/introduction) for more info."
|
|
1597
|
+
)
|
|
1598
|
+
raise ValueError(msg)
|
|
1599
|
+
|
|
1600
|
+
chunker_class = chunker_mapping[chunker_name]
|
|
1601
|
+
chonkie_modules = _lazy_import_chonkie()
|
|
1602
|
+
Tokenizer = chonkie_modules.get("Tokenizer")
|
|
1603
|
+
|
|
1604
|
+
# Tokenizer-based chunkers (use tokenizer_name)
|
|
1605
|
+
if chunker_name in ["TokenChunker", "SentenceChunker", "RecursiveChunker"]:
|
|
1606
|
+
if Tokenizer is None:
|
|
1607
|
+
UserMessage(
|
|
1608
|
+
"Tokenizers library is not installed. Please install it with `pip install tokenizers`.",
|
|
1609
|
+
raise_with=ImportError,
|
|
1610
|
+
)
|
|
1611
|
+
tokenizer = Tokenizer.from_pretrained(self.tokenizer_name)
|
|
1612
|
+
return chunker_class(tokenizer, **chunker_kwargs)
|
|
1613
|
+
|
|
1614
|
+
# Embedding-based chunkers (use embedding_model_name)
|
|
1615
|
+
if chunker_name in ["SemanticChunker", "LateChunker"]:
|
|
1616
|
+
return chunker_class(embedding_model=self.embedding_model_name, **chunker_kwargs)
|
|
1617
|
+
|
|
1618
|
+
# CodeChunker and TableChunker use tokenizer (can use string or Tokenizer object)
|
|
1619
|
+
if chunker_name in ["CodeChunker", "TableChunker"]:
|
|
1620
|
+
# These can accept tokenizer as string (default 'character') or Tokenizer object
|
|
1621
|
+
# If tokenizer not provided in kwargs, use tokenizer_name
|
|
1622
|
+
if "tokenizer" not in chunker_kwargs:
|
|
1623
|
+
chunker_kwargs["tokenizer"] = self.tokenizer_name
|
|
1624
|
+
return chunker_class(**chunker_kwargs)
|
|
1625
|
+
|
|
1626
|
+
# SlumberChunker uses tokenizer (can use string or Tokenizer object)
|
|
1627
|
+
if chunker_name == "SlumberChunker":
|
|
1628
|
+
# SlumberChunker can accept tokenizer as string or Tokenizer object
|
|
1629
|
+
# If tokenizer not provided in kwargs, use tokenizer_name
|
|
1630
|
+
if "tokenizer" not in chunker_kwargs:
|
|
1631
|
+
chunker_kwargs["tokenizer"] = self.tokenizer_name
|
|
1632
|
+
return chunker_class(**chunker_kwargs)
|
|
1633
|
+
|
|
1634
|
+
# NeuralChunker uses model parameter (defaults provided by chonkie)
|
|
1635
|
+
if chunker_name == "NeuralChunker":
|
|
1636
|
+
return chunker_class(**chunker_kwargs)
|
|
1637
|
+
|
|
1638
|
+
msg = (
|
|
1639
|
+
f"Chunker {chunker_name} not properly configured. "
|
|
1640
|
+
f"Available chunkers: {list(chunker_mapping.keys())}."
|
|
1641
|
+
)
|
|
1642
|
+
raise ValueError(msg)
|
|
1643
|
+
|
|
1644
|
+
@staticmethod
|
|
1645
|
+
def clean_text(text: str) -> str:
|
|
1646
|
+
"""Cleans text by removing problematic characters."""
|
|
1647
|
+
text = text.replace("\x00", "") # Remove null bytes (\x00)
|
|
1648
|
+
return text.encode("utf-8", errors="ignore").decode(
|
|
1649
|
+
"utf-8"
|
|
1650
|
+
) # Replace invalid UTF-8 sequences
|