symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- symai/__init__.py +269 -173
- symai/backend/base.py +123 -110
- symai/backend/engines/drawing/engine_bfl.py +45 -44
- symai/backend/engines/drawing/engine_gpt_image.py +112 -97
- symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
- symai/backend/engines/embedding/engine_openai.py +25 -21
- symai/backend/engines/execute/engine_python.py +19 -18
- symai/backend/engines/files/engine_io.py +104 -95
- symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
- symai/backend/engines/index/engine_pinecone.py +124 -97
- symai/backend/engines/index/engine_qdrant.py +1011 -0
- symai/backend/engines/index/engine_vectordb.py +84 -56
- symai/backend/engines/lean/engine_lean4.py +96 -52
- symai/backend/engines/neurosymbolic/__init__.py +41 -13
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
- symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
- symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
- symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
- symai/backend/engines/ocr/engine_apilayer.py +23 -27
- symai/backend/engines/output/engine_stdout.py +10 -13
- symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
- symai/backend/engines/search/engine_openai.py +100 -88
- symai/backend/engines/search/engine_parallel.py +665 -0
- symai/backend/engines/search/engine_perplexity.py +44 -45
- symai/backend/engines/search/engine_serpapi.py +37 -34
- symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
- symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
- symai/backend/engines/text_to_speech/engine_openai.py +20 -26
- symai/backend/engines/text_vision/engine_clip.py +39 -37
- symai/backend/engines/userinput/engine_console.py +5 -6
- symai/backend/mixin/__init__.py +13 -0
- symai/backend/mixin/anthropic.py +48 -38
- symai/backend/mixin/deepseek.py +6 -5
- symai/backend/mixin/google.py +7 -4
- symai/backend/mixin/groq.py +2 -4
- symai/backend/mixin/openai.py +140 -110
- symai/backend/settings.py +87 -20
- symai/chat.py +216 -123
- symai/collect/__init__.py +7 -1
- symai/collect/dynamic.py +80 -70
- symai/collect/pipeline.py +67 -51
- symai/collect/stats.py +161 -109
- symai/components.py +707 -360
- symai/constraints.py +24 -12
- symai/core.py +1857 -1233
- symai/core_ext.py +83 -80
- symai/endpoints/api.py +166 -104
- symai/extended/.DS_Store +0 -0
- symai/extended/__init__.py +46 -12
- symai/extended/api_builder.py +29 -21
- symai/extended/arxiv_pdf_parser.py +23 -14
- symai/extended/bibtex_parser.py +9 -6
- symai/extended/conversation.py +156 -126
- symai/extended/document.py +50 -30
- symai/extended/file_merger.py +57 -14
- symai/extended/graph.py +51 -32
- symai/extended/html_style_template.py +18 -14
- symai/extended/interfaces/blip_2.py +2 -3
- symai/extended/interfaces/clip.py +4 -3
- symai/extended/interfaces/console.py +9 -1
- symai/extended/interfaces/dall_e.py +4 -2
- symai/extended/interfaces/file.py +2 -0
- symai/extended/interfaces/flux.py +4 -2
- symai/extended/interfaces/gpt_image.py +16 -7
- symai/extended/interfaces/input.py +2 -1
- symai/extended/interfaces/llava.py +1 -2
- symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
- symai/extended/interfaces/naive_vectordb.py +9 -10
- symai/extended/interfaces/ocr.py +5 -3
- symai/extended/interfaces/openai_search.py +2 -0
- symai/extended/interfaces/parallel.py +30 -0
- symai/extended/interfaces/perplexity.py +2 -0
- symai/extended/interfaces/pinecone.py +12 -9
- symai/extended/interfaces/python.py +2 -0
- symai/extended/interfaces/serpapi.py +3 -1
- symai/extended/interfaces/terminal.py +2 -4
- symai/extended/interfaces/tts.py +3 -2
- symai/extended/interfaces/whisper.py +3 -2
- symai/extended/interfaces/wolframalpha.py +2 -1
- symai/extended/metrics/__init__.py +11 -1
- symai/extended/metrics/similarity.py +14 -13
- symai/extended/os_command.py +39 -29
- symai/extended/packages/__init__.py +29 -3
- symai/extended/packages/symdev.py +51 -43
- symai/extended/packages/sympkg.py +41 -35
- symai/extended/packages/symrun.py +63 -50
- symai/extended/repo_cloner.py +14 -12
- symai/extended/seo_query_optimizer.py +15 -13
- symai/extended/solver.py +116 -91
- symai/extended/summarizer.py +12 -10
- symai/extended/taypan_interpreter.py +17 -18
- symai/extended/vectordb.py +122 -92
- symai/formatter/__init__.py +9 -1
- symai/formatter/formatter.py +51 -47
- symai/formatter/regex.py +70 -69
- symai/functional.py +325 -176
- symai/imports.py +190 -147
- symai/interfaces.py +57 -28
- symai/memory.py +45 -35
- symai/menu/screen.py +28 -19
- symai/misc/console.py +66 -56
- symai/misc/loader.py +8 -5
- symai/models/__init__.py +17 -1
- symai/models/base.py +395 -236
- symai/models/errors.py +1 -2
- symai/ops/__init__.py +32 -22
- symai/ops/measures.py +24 -25
- symai/ops/primitives.py +1149 -731
- symai/post_processors.py +58 -50
- symai/pre_processors.py +86 -82
- symai/processor.py +21 -13
- symai/prompts.py +764 -685
- symai/server/huggingface_server.py +135 -49
- symai/server/llama_cpp_server.py +21 -11
- symai/server/qdrant_server.py +206 -0
- symai/shell.py +100 -42
- symai/shellsv.py +700 -492
- symai/strategy.py +630 -346
- symai/symbol.py +368 -322
- symai/utils.py +100 -78
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
- symbolicai-1.1.0.dist-info/RECORD +168 -0
- symbolicai-0.21.0.dist-info/RECORD +0 -162
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/components.py
CHANGED
@@ -1,23 +1,23 @@
 import copy
 import inspect
 import json
-import os
 import re
 import sys
-from abc import abstractmethod
 from collections import defaultdict
+from collections.abc import Callable, Iterator
 from pathlib import Path
 from random import sample
 from string import ascii_lowercase, ascii_uppercase
 from threading import Lock
-from
-
+from typing import TYPE_CHECKING, Union
+
+if TYPE_CHECKING:
+    from typing import Any
 
 import numpy as np
-from
+from beartype import beartype
 from box import Box
 from loguru import logger
-from pydantic import BaseModel, ValidationError
 from pyvis.network import Network
 from tqdm import tqdm
 
@@ -25,46 +25,64 @@ from . import core, core_ext
 from .backend.base import Engine
 from .backend.settings import HOME_PATH
 from .constraints import DictFormatConstraint
+from .context import CURRENT_ENGINE_VAR
 from .formatter import ParagraphFormatter
-from .post_processors import (
-
-
-
+from .post_processors import (
+    CodeExtractPostProcessor,
+    JsonTruncateMarkdownPostProcessor,
+    JsonTruncatePostProcessor,
+    PostProcessor,
+    StripPostProcessor,
+)
 from .pre_processors import JsonPreProcessor, PreProcessor
 from .processor import ProcessorPipeline
 from .prompts import JsonPromptTemplate, Prompt
 from .symbol import Expression, Metadata, Symbol
-from .utils import
+from .utils import UserMessage
+
+if TYPE_CHECKING:
+    from .backend.engines.index.engine_vectordb import VectorDBResult
+
+_DEFAULT_PARAGRAPH_FORMATTER = ParagraphFormatter()
 
 
 class GraphViz(Expression):
-    def __init__(
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        notebook=True,
+        cdn_resources="remote",
+        bgcolor="#222222",
+        font_color="white",
+        height="750px",
+        width="100%",
+        select_menu=True,
+        filter_menu=True,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
-        self.net
-
-
-
-
-
-
-
-
+        self.net = Network(
+            notebook=notebook,
+            cdn_resources=cdn_resources,
+            bgcolor=bgcolor,
+            font_color=font_color,
+            height=height,
+            width=width,
+            select_menu=select_menu,
+            filter_menu=filter_menu,
+        )
+
+    def forward(self, sym: Symbol, file_path: str, **_kwargs):
         nodes = [str(n) if n.value else n.__repr__(simplified=True) for n in sym.nodes]
-        edges = [
-
+        edges = [
+            (
+                str(e[0]) if e[0].value else e[0].__repr__(simplified=True),
+                str(e[1]) if e[1].value else e[1].__repr__(simplified=True),
+            )
+            for e in sym.edges
+        ]
         self.net.add_nodes(nodes)
         self.net.add_edges(edges)
-        file_path = file_path if file_path.endswith(
+        file_path = file_path if file_path.endswith(".html") else file_path + ".html"
         return self.net.show(file_path)
 
 
@@ -73,21 +91,21 @@ class TrackerTraceable(Expression):
 
 
 class Any(Expression):
-    def __init__(self, *expr:
+    def __init__(self, *expr: list[Expression], **kwargs):
         super().__init__(**kwargs)
-        self.expr:
+        self.expr: list[Expression] = expr
 
     def forward(self, *args, **kwargs) -> Symbol:
-        return self.sym_return_type(any(
+        return self.sym_return_type(any(e() for e in self.expr(*args, **kwargs)))
 
 
 class All(Expression):
-    def __init__(self, *expr:
+    def __init__(self, *expr: list[Expression], **kwargs):
         super().__init__(**kwargs)
-        self.expr:
+        self.expr: list[Expression] = expr
 
     def forward(self, *args, **kwargs) -> Symbol:
-        return self.sym_return_type(all(
+        return self.sym_return_type(all(e() for e in self.expr(*args, **kwargs)))
 
 
 class Try(Expression):
@@ -104,12 +122,14 @@ class Try(Expression):
 class Lambda(Expression):
     def __init__(self, callable: Callable, **kwargs):
         super().__init__(**kwargs)
+
        def _callable(*args, **kwargs):
            kw = {
-
-
+                "args": args,
+                "kwargs": kwargs,
            }
            return callable(kw)
+
        self.callable: Callable = _callable
 
     def forward(self, *args, **kwargs) -> Symbol:
@@ -117,14 +137,14 @@ class Lambda(Expression):
 
 
 class Choice(Expression):
-    def __init__(self, cases:
+    def __init__(self, cases: list[str], default: str | None = None, **kwargs):
         super().__init__(**kwargs)
-        self.cases:
-        self.default:
+        self.cases: list[str] = cases
+        self.default: str | None = default
 
     def forward(self, sym: Symbol, *args, **kwargs) -> Symbol:
         sym = self._to_symbol(sym)
-        return sym.choice(cases=self.cases, default=self.default,
+        return sym.choice(*args, cases=self.cases, default=self.default, **kwargs)
 
 
 class Output(Expression):
@@ -135,15 +155,15 @@ class Output(Expression):
         self.verbose: bool = verbose
 
     def forward(self, *args, **kwargs) -> Expression:
-        kwargs[
-        kwargs[
-        return self.output(expr=self.expr,
+        kwargs["verbose"] = self.verbose
+        kwargs["handler"] = self.handler
+        return self.output(*args, expr=self.expr, **kwargs)
 
 
 class Sequence(TrackerTraceable):
-    def __init__(self, *expressions:
+    def __init__(self, *expressions: list[Expression], **kwargs):
         super().__init__(**kwargs)
-        self.expressions:
+        self.expressions: list[Expression] = expressions
 
     def forward(self, *args, **kwargs) -> Symbol:
         sym = self.expressions[0](*args, **kwargs)
@@ -159,34 +179,36 @@ class Sequence(TrackerTraceable):
 
 
 class Parallel(Expression):
-    def __init__(self, *expr:
+    def __init__(self, *expr: list[Expression | Callable], sequential: bool = False, **kwargs):
         super().__init__(**kwargs)
-        self.sequential: bool
-        self.expr:
-        self.results:
+        self.sequential: bool = sequential
+        self.expr: list[Expression] = expr
+        self.results: list[Symbol] = []
 
     def forward(self, *args, **kwargs) -> Symbol:
         # run in sequence
         if self.sequential:
             return [e(*args, **kwargs) for e in self.expr]
+
         # run in parallel
         @core_ext.parallel(self.expr)
         def _func(e, *args, **kwargs):
             return e(*args, **kwargs)
+
         self.results = _func(*args, **kwargs)
         # final result of the parallel execution
         return self._to_symbol(self.results)
 
 
-
+# @TODO: BinPacker(format="...") -> ensure that data packages form a "bin" that's consistent (e.g. never break a sentence in the middle)
 class Stream(Expression):
-    def __init__(self, expr:
+    def __init__(self, expr: Expression | None = None, retrieval: str | None = None, **kwargs):
         super().__init__(**kwargs)
-        self.char_token_ratio:
-        self.expr:
-        self.retrieval:
-        self._trace:
-        self._previous_frame
+        self.char_token_ratio: float = 0.6
+        self.expr: Expression | None = expr
+        self.retrieval: str | None = retrieval
+        self._trace: bool = False
+        self._previous_frame = None
 
     def forward(self, sym: Symbol, **kwargs) -> Iterator:
         sym = self._to_symbol(sym)
@@ -194,30 +216,31 @@ class Stream(Expression):
         if self._trace:
             local_vars = self._previous_frame.f_locals
             vals = []
-            for
+            for _key, var in local_vars.items():
                 if isinstance(var, TrackerTraceable):
                     vals.append(var)
 
             if len(vals) == 1:
                 self.expr = vals[0]
             else:
-
-
-
-
-
-
+                UserMessage(
+                    "This component does either not inherit from TrackerTraceable or has an invalid number of component "
+                    f"declarations: {len(vals)}! Only one component that inherits from TrackerTraceable is allowed in the "
+                    "with stream clause.",
+                    raise_with=ValueError,
+                )
+
+        res = sym.stream(expr=self.expr, char_token_ratio=self.char_token_ratio, **kwargs)
         if self.retrieval is not None:
             res = list(res)
-            if self.retrieval ==
+            if self.retrieval == "all":
                 return res
-            if self.retrieval ==
+            if self.retrieval == "longest":
                 res = sorted(res, key=lambda x: len(x), reverse=True)
                 return res[0]
-            if self.retrieval ==
-
-
-            raise ValueError(f"Invalid retrieval method: {self.retrieval}")
+            if self.retrieval == "contains":
+                return [r for r in res if self.expr in r]
+            UserMessage(f"Invalid retrieval method: {self.retrieval}", raise_with=ValueError)
 
         return res
 
@@ -231,10 +254,12 @@ class Stream(Expression):
 
 
 class Trace(Expression):
-    def __init__(self, expr:
+    def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
+        if engines is None:
+            engines = ["all"]
         super().__init__(**kwargs)
         self.expr: Expression = expr
-        self.engines:
+        self.engines: list[str] = engines
 
     def forward(self, *args, **kwargs) -> Expression:
         Expression.command(verbose=True, engines=self.engines)
@@ -252,23 +277,26 @@ class Trace(Expression):
         Expression.command(verbose=False, engines=self.engines)
         if self.expr is not None:
             return self.expr.__exit__(type, value, traceback)
+        return None
 
 
 class Analyze(Expression):
-    def __init__(self, exception: Exception, query:
+    def __init__(self, exception: Exception, query: str | None = None, **kwargs):
         super().__init__(**kwargs)
         self.exception: Expression = exception
-        self.query:
+        self.query: str | None = query
 
     def forward(self, sym: Symbol, *args, **kwargs) -> Symbol:
-        return sym.analyze(exception=self.exception, query=self.query,
+        return sym.analyze(*args, exception=self.exception, query=self.query, **kwargs)
 
 
 class Log(Expression):
-    def __init__(self, expr:
+    def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
+        if engines is None:
+            engines = ["all"]
         super().__init__(**kwargs)
         self.expr: Expression = expr
-        self.engines:
+        self.engines: list[str] = engines
 
     def forward(self, *args, **kwargs) -> Expression:
         Expression.command(logging=True, engines=self.engines)
@@ -286,10 +314,16 @@ class Log(Expression):
         Expression.command(logging=False, engines=self.engines)
         if self.expr is not None:
             return self.expr.__exit__(type, value, traceback)
+        return None
 
 
 class Template(Expression):
-    def __init__(
+    def __init__(
+        self,
+        template: str = "<html><body>{{placeholder}}</body></html>",
+        placeholder: str = "{{placeholder}}",
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.placeholder = placeholder
         self.template_ = template
@@ -319,22 +353,26 @@ class RuntimeExpression(Expression):
         code = self._to_symbol(code)
         # declare the runtime expression from the code
         expr = self.runner(code)
+
        def _func(sym):
            # execute nested expression
-            return expr[
+            return expr["locals"]["_output_"](sym)
+
        return _func
 
 
 class Metric(Expression):
     def __init__(self, normalize: bool = False, eps: float = 1e-8, **kwargs):
         super().__init__(**kwargs)
-        self.normalize
-        self.eps
+        self.normalize = normalize
+        self.eps = eps
 
-    def forward(self, sym: Symbol, **
+    def forward(self, sym: Symbol, **_kwargs) -> Symbol:
         sym = self._to_symbol(sym)
-        assert sym.value_type
-
+        assert sym.value_type is np.ndarray or sym.value_type is list, (
+            "Metric can only be applied to numpy arrays or lists."
+        )
+        if sym.value_type is list:
            sym._value = np.array(sym.value)
         # compute normalization between 0 and 1
         if self.normalize:
@@ -343,17 +381,19 @@ class Metric(Expression):
            elif len(sym.value.shape) == 2:
                pass
            else:
-
+                UserMessage(f"Invalid shape: {sym.value.shape}", raise_with=ValueError)
            # normalize between 0 and 1 and sum to 1
            sym._value = np.exp(sym.value) / (np.exp(sym.value).sum() + self.eps)
         return sym
 
 
 class Style(Expression):
-    def __init__(self, description: str, libraries:
+    def __init__(self, description: str, libraries: list[str] | None = None, **kwargs):
+        if libraries is None:
+            libraries = []
         super().__init__(**kwargs)
         self.description: str = description
-        self.libraries:
+        self.libraries: list[str] = libraries
 
     def forward(self, sym: Symbol, **kwargs) -> Symbol:
         sym = self._to_symbol(sym)
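Note: the normalization in `Metric.forward` above is a softmax with an epsilon guard on the denominator. A minimal standalone sketch of the same computation (plain NumPy, outside symai):

    import numpy as np

    def normalize(values: np.ndarray, eps: float = 1e-8) -> np.ndarray:
        # Mirrors Metric.forward: exponentiate, then divide by the eps-guarded sum,
        # so the output lies in [0, 1] and sums to ~1 (a softmax over the raw values).
        exp = np.exp(values)
        return exp / (exp.sum() + eps)

    print(normalize(np.array([0.1, 0.5, 0.4])))  # ~[0.26, 0.39, 0.35]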
@@ -365,7 +405,7 @@ class Query(TrackerTraceable):
         super().__init__(**kwargs)
         self.prompt: str = prompt
 
-    def forward(self, sym: Symbol, context: Symbol = None, *
+    def forward(self, sym: Symbol, context: Symbol = None, *_args, **kwargs) -> Symbol:
         sym = self._to_symbol(sym)
         return sym.query(prompt=self.prompt, context=context, **kwargs)
 
@@ -397,16 +437,16 @@ _output_ = _func()
 
     def forward(self, sym: Symbol, enclosure: bool = False, **kwargs) -> Symbol:
         if enclosure or self.enclosure:
-            lines = str(sym).split(
-            lines = [
-            sym =
-            sym = self.template.replace(
+            lines = str(sym).split("\n")
+            lines = [" " + line for line in lines]
+            sym = "\n".join(lines)
+            sym = self.template.replace("{sym}", str(sym))
         sym = self._to_symbol(sym)
         return sym.execute(**kwargs)
 
 
 class Convert(Expression):
-    def __init__(self, format: str =
+    def __init__(self, format: str = "Python", **kwargs):
         super().__init__(**kwargs)
         self.format = format
 
@@ -440,13 +480,13 @@ class Map(Expression):
 
 
 class Translate(Expression):
-    def __init__(self, language: str =
+    def __init__(self, language: str = "English", **kwargs):
         super().__init__(**kwargs)
         self.language = language
 
     def forward(self, sym: Symbol, **kwargs) -> Symbol:
         sym = self._to_symbol(sym)
-        if sym.isinstanceof(f
+        if sym.isinstanceof(f"{self.language} text"):
             return sym
         return sym.translate(language=self.language, **kwargs)
 
@@ -474,11 +514,11 @@ class ExcludeFilter(Expression):
 class FileWriter(Expression):
     def __init__(self, path: str, **kwargs):
         super().__init__(**kwargs)
-        self.path = path
+        self.path = Path(path)
 
-    def forward(self, sym: Symbol, **
+    def forward(self, sym: Symbol, **_kwargs) -> Symbol:
         sym = self._to_symbol(sym)
-        with
+        with self.path.open("w") as f:
             f.write(str(sym))
 
 
@@ -486,20 +526,18 @@ class FileReader(Expression):
     @staticmethod
     def exists(path: str) -> bool:
         # remove slicing if any
-        _tmp
-        _splits
-        if
+        _tmp = path
+        _splits = _tmp.split("[")
+        if "[" in _tmp:
             _tmp = _splits[0]
-        assert len(_splits) == 1 or len(_splits) == 2,
-        _tmp
+        assert len(_splits) == 1 or len(_splits) == 2, "Invalid file link format."
+        _tmp = Path(_tmp)
         # check if file exists and is a file
-
-            return True
-        return False
+        return _tmp.is_file()
 
     @staticmethod
-    def get_files(folder_path: str, max_depth: int = 1) ->
-        accepted_formats = [
+    def get_files(folder_path: str, max_depth: int = 1) -> list[str]:
+        accepted_formats = [".pdf", ".md", ".txt"]
 
         folder = Path(folder_path)
         files = []
@@ -512,10 +550,35 @@ class FileReader(Expression):
         return files
 
     @staticmethod
-    def extract_files(cmds: str) ->
-
-
-
+    def extract_files(cmds: str) -> list[str] | None:
+        """
+        Extract file paths from a command string, handling various quoting styles.
+
+        This method is used by the Qdrant RAG implementation when processing document paths.
+        It uses regex to parse file paths that may be quoted in different ways.
+
+        Regex patterns used:
+        1. Main pattern: Matches file paths in four formats:
+           - Double-quoted: "path/to/file" (handles escaped characters)
+           - Single-quoted: 'path/to/file' (handles escaped characters)
+           - Backtick-quoted: `path/to/file` (handles escaped characters)
+           - Non-quoted: path/to/file (handles escaped spaces)
+
+        2. Escape removal pattern: r"\\(.)" -> r"\1"
+           - Removes backslash escape sequences from quoted paths
+           - Example: "path\\/to\\/file" -> "path/to/file"
+           - Used for double quotes, single quotes, and backticks
+        """
+        # Regex pattern to match file paths in various quoting styles
+        # Pattern breakdown:
+        # - (?:"((?:\\.|[^"\\])*)") : Matches double-quoted paths, capturing content while handling escapes
+        # - '((?:\\.|[^'\\])*)' : Matches single-quoted paths, capturing content while handling escapes
+        # - `((?:\\.|[^`\\])*)` : Matches backtick-quoted paths, capturing content while handling escapes
+        # - ((?:\\ |[^ ])+) : Matches non-quoted paths, allowing escaped spaces
+        pattern = (
+            r"""(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))"""
+        )
+        # Use regex to find all file path matches in the command string
         matches = re.findall(pattern, cmds)
         # Process the matches to handle quoted paths and normal paths
         files = []
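Note: the path-extraction pattern added in `extract_files` above can be exercised on its own. The pattern string below is copied from the diff; the sample command string is illustrative:

    import re

    # Pattern copied from FileReader.extract_files; one capture group per quoting style.
    PATTERN = r"""(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))"""

    cmds = '"my docs/a.pdf" notes\\ 2.md plain.txt'
    for double, single, backtick, bare in re.findall(PATTERN, cmds):
        quoted = double or single or backtick
        if quoted:
            print(re.sub(r"\\(.)", r"\1", quoted))  # drop escape backslashes inside quotes
        else:
            print(bare.replace("\\ ", " "))  # bare paths only support escaped spaces
    # prints: my docs/a.pdf, notes 2.md, plain.txt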
@@ -523,23 +586,27 @@ class FileReader(Expression):
             # Each match will have 4 groups due to the pattern; only one will be non-empty
             quoted_double, quoted_single, quoted_backtick, non_quoted = match
             if quoted_double:
-                # Remove backslashes used for escaping inside double quotes
-
+                # Regex substitution: Remove backslashes used for escaping inside double quotes
+                # Pattern r"\\(.)" matches a backslash followed by any character and replaces with just the character
+                # Example: "path\\/to\\/file" -> "path/to/file"
+                path = re.sub(r"\\(.)", r"\1", quoted_double)
                 file = FileReader.expand_user_path(path)
                 files.append(file)
             elif quoted_single:
-                # Remove backslashes used for escaping inside single quotes
-
+                # Regex substitution: Remove backslashes used for escaping inside single quotes
+                # Same pattern as above, applied to single-quoted paths
+                path = re.sub(r"\\(.)", r"\1", quoted_single)
                 file = FileReader.expand_user_path(path)
                 files.append(file)
             elif quoted_backtick:
-                # Remove backslashes used for escaping inside backticks
-
+                # Regex substitution: Remove backslashes used for escaping inside backticks
+                # Same pattern as above, applied to backtick-quoted paths
+                path = re.sub(r"\\(.)", r"\1", quoted_backtick)
                 file = FileReader.expand_user_path(path)
                 files.append(file)
             elif non_quoted:
-                # Replace escaped spaces with actual spaces
-                path = non_quoted.replace(
+                # Replace escaped spaces with actual spaces (no regex needed here, simple string replace)
+                path = non_quoted.replace("\\ ", " ")
                 file = FileReader.expand_user_path(path)
                 files.append(file)
         # Filter out any files that do not exist
@@ -551,31 +618,34 @@ class FileReader(Expression):
         return Path(path).expanduser().resolve().as_posix()
 
     @staticmethod
-    def integrity_check(files:
+    def integrity_check(files: list[str]) -> list[str]:
         not_skipped = []
         for file in tqdm(files):
             if FileReader.exists(file):
                 not_skipped.append(file)
             else:
-
+                UserMessage(f"Skipping file: {file}")
         return not_skipped
 
-    def forward(self, files:
+    def forward(self, files: str | list[str], **kwargs) -> Expression:
         if isinstance(files, str):
             # Convert to list for uniform processing; more easily downstream
             files = [files]
-        if kwargs.get(
+        if kwargs.get("run_integrity_check"):
             files = self.integrity_check(files)
         return self.sym_return_type([self.open(f, **kwargs).value for f in files])
 
+
 class FileQuery(Expression):
     def __init__(self, path: str, filter: str, **kwargs):
         super().__init__(**kwargs)
         self.path = path
         file_open = FileReader()
-        self.query_stream = Stream(
-
-
+        self.query_stream = Stream(
+            Sequence(
+                IncludeFilter(filter),
+            )
+        )
         self.file = file_open(path)
 
     def forward(self, sym: Symbol, **kwargs) -> Symbol:
@@ -585,40 +655,45 @@ class FileQuery(Expression):
 
 
 class Function(TrackerTraceable):
-    def __init__(
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        prompt: str = "",
+        examples: str | None = [],
+        pre_processors: list[PreProcessor] | None = None,
+        post_processors: list[PostProcessor] | None = None,
+        default: object | None = None,
+        constraints: list[Callable] | None = None,
+        return_type: type | None = str,
+        sym_return_type: type | None = Symbol,
+        origin_type: type | None = Expression,
+        *args,
+        **kwargs,
+    ):
+        if constraints is None:
+            constraints = []
         super().__init__(**kwargs)
-        chars
-        self.name
-        self.args
+        chars = ascii_lowercase + ascii_uppercase
+        self.name = "func_" + "".join(sample(chars, 15))
+        self.args = args
         self.kwargs = kwargs
-        self._promptTemplate
-        self._promptFormatArgs
+        self._promptTemplate = prompt
+        self._promptFormatArgs = []
         self._promptFormatKwargs = {}
-        self.examples
-        self.pre_processors
+        self.examples = Prompt(examples)
+        self.pre_processors = pre_processors
         self.post_processors = post_processors
-        self.constraints
-        self.default
-        self.return_type
+        self.constraints = constraints
+        self.default = default
+        self.return_type = return_type
         self.sym_return_type = sym_return_type
-        self.origin_type
+        self.origin_type = origin_type
 
     @property
     def prompt(self):
         # return a copy of the prompt template
         if len(self._promptFormatArgs) == 0 and len(self._promptFormatKwargs) == 0:
             return self._promptTemplate
-        return f"{self._promptTemplate}".format(*self._promptFormatArgs,
-                                                **self._promptFormatKwargs)
+        return f"{self._promptTemplate}".format(*self._promptFormatArgs, **self._promptFormatKwargs)
 
     def format(self, *args, **kwargs):
         self._promptFormatArgs = args
@@ -626,27 +701,36 @@ class Function(TrackerTraceable):
 
     def forward(self, *args, **kwargs) -> Expression:
         # special case for few shot function prompt definition override
-        if
-            self.prompt = kwargs[
-            del kwargs[
-
-
-
-
-
-
+        if "fn" in kwargs:
+            self.prompt = kwargs["fn"]
+            del kwargs["fn"]
+
+        @core.few_shot(
+            *self.args,
+            prompt=self.prompt,
+            examples=self.examples,
+            pre_processors=self.pre_processors,
+            post_processors=self.post_processors,
+            constraints=self.constraints,
+            default=self.default,
+            **self.kwargs,
+        )
         def _func(_, *args, **kwargs) -> self.return_type:
             pass
-
-
-
-
-
-
-
-
+
+        _type = type(
+            self.name,
+            (self.origin_type,),
+            {
+                # constructor
+                "forward": _func,
+                "sym_return_type": self.sym_return_type,
+                "static_context": self.static_context,
+                "dynamic_context": self.dynamic_context,
+                "__class__": self.__class__,
+                "__module__": self.__module__,
+            },
+        )
         obj = _type()
 
         return self._to_symbol(obj(*args, **kwargs))
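Note: the rewritten `Function.forward` above builds a one-off class at runtime with the three-argument `type()` built-in. A minimal standalone sketch of that mechanism (`Base` and `greet` are illustrative names, not symai APIs):

    class Base:
        def forward(self):
            raise NotImplementedError

    def greet(self):
        return "hello"

    # Equivalent to: class Dynamic(Base): forward = greet
    Dynamic = type("Dynamic", (Base,), {"forward": greet})

    obj = Dynamic()
    print(obj.forward())          # -> hello
    print(isinstance(obj, Base))  # -> True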
@@ -657,19 +741,19 @@ class PrepareData(Function):
|
|
|
657
741
|
def __call__(self, argument):
|
|
658
742
|
assert argument.prop.context is not None
|
|
659
743
|
instruct = argument.prop.prompt
|
|
660
|
-
context
|
|
661
|
-
return """{
|
|
662
|
-
'context': '
|
|
663
|
-
'instruction': '
|
|
744
|
+
context = argument.prop.context
|
|
745
|
+
return f"""{{
|
|
746
|
+
'context': '{context}',
|
|
747
|
+
'instruction': '{instruct}',
|
|
664
748
|
'result': 'TODO: Replace this with the expected result.'
|
|
665
|
-
}"""
|
|
749
|
+
}}"""
|
|
666
750
|
|
|
667
751
|
def __init__(self, *args, **kwargs):
|
|
668
752
|
super().__init__(*args, **kwargs)
|
|
669
|
-
self.pre_processors
|
|
670
|
-
self.constraints
|
|
753
|
+
self.pre_processors = [self.PrepareDataPreProcessor()]
|
|
754
|
+
self.constraints = [DictFormatConstraint({"result": "<the data>"})]
|
|
671
755
|
self.post_processors = [JsonTruncateMarkdownPostProcessor()]
|
|
672
|
-
self.return_type
|
|
756
|
+
self.return_type = dict # constraint to cast the result to a dict
|
|
673
757
|
|
|
674
758
|
@property
|
|
675
759
|
def static_context(self):
|
|
@@ -704,10 +788,10 @@ Your goal is to prepare the data for the next task instruction. The data should
|
|
|
704
788
|
|
|
705
789
|
class ExpressionBuilder(Function):
|
|
706
790
|
def __init__(self, **kwargs):
|
|
707
|
-
super().__init__(
|
|
791
|
+
super().__init__("Generate the code following the instructions:", **kwargs)
|
|
708
792
|
self.processors = ProcessorPipeline([StripPostProcessor(), CodeExtractPostProcessor()])
|
|
709
793
|
|
|
710
|
-
def forward(self, instruct, *
|
|
794
|
+
def forward(self, instruct, *_args, **_kwargs):
|
|
711
795
|
result = super().forward(instruct)
|
|
712
796
|
return self.processors(str(result), None)
|
|
713
797
|
|
|
@@ -755,10 +839,12 @@ Always produce the entire code to be executed in the same Python process. All ta
|
|
|
755
839
|
class JsonParser(Expression):
|
|
756
840
|
def __init__(self, query: str, json_: dict, **kwargs):
|
|
757
841
|
super().__init__(**kwargs)
|
|
758
|
-
func = Function(
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
842
|
+
func = Function(
|
|
843
|
+
prompt=JsonPromptTemplate(query, json_),
|
|
844
|
+
constraints=[DictFormatConstraint(json_)],
|
|
845
|
+
pre_processors=[JsonPreProcessor()],
|
|
846
|
+
post_processors=[JsonTruncatePostProcessor()],
|
|
847
|
+
)
|
|
762
848
|
self.fn = Try(func, retries=1)
|
|
763
849
|
|
|
764
850
|
def forward(self, sym: Symbol, **kwargs) -> Symbol:
|
|
@@ -768,21 +854,27 @@ class JsonParser(Expression):
|
|
|
768
854
|
|
|
769
855
|
|
|
770
856
|
class SimilarityClassification(Expression):
|
|
771
|
-
def __init__(
|
|
857
|
+
def __init__(
|
|
858
|
+
self, classes: list[str], metric: str = "cosine", in_memory: bool = False, **kwargs
|
|
859
|
+
):
|
|
772
860
|
super().__init__(**kwargs)
|
|
773
|
-
self.classes
|
|
774
|
-
self.metric
|
|
861
|
+
self.classes = classes
|
|
862
|
+
self.metric = metric
|
|
775
863
|
self.in_memory = in_memory
|
|
776
864
|
|
|
777
865
|
if self.in_memory:
|
|
778
|
-
|
|
866
|
+
UserMessage(
|
|
867
|
+
f"Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache"
|
|
868
|
+
)
|
|
779
869
|
|
|
780
870
|
def forward(self, x: Symbol) -> Symbol:
|
|
781
|
-
x
|
|
782
|
-
usr_embed
|
|
783
|
-
embeddings
|
|
871
|
+
x = self._to_symbol(x)
|
|
872
|
+
usr_embed = x.embed()
|
|
873
|
+
embeddings = self._dynamic_cache()
|
|
784
874
|
similarities = [usr_embed.similarity(emb, metric=self.metric) for emb in embeddings]
|
|
785
|
-
similarities = sorted(
|
|
875
|
+
similarities = sorted(
|
|
876
|
+
zip(self.classes, similarities, strict=False), key=lambda x: x[1], reverse=True
|
|
877
|
+
)
|
|
786
878
|
|
|
787
879
|
return Symbol(similarities[0][0])
|
|
788
880
|
|
|
@@ -790,9 +882,7 @@ class SimilarityClassification(Expression):
|
|
|
790
882
|
@core_ext.cache(in_memory=self.in_memory)
|
|
791
883
|
def embed_classes(self):
|
|
792
884
|
opts = map(Symbol, self.classes)
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
return embeddings
|
|
885
|
+
return [opt.embed() for opt in opts]
|
|
796
886
|
|
|
797
887
|
return embed_classes(self)
|
|
798
888
|
|
|
@@ -803,11 +893,7 @@ class InContextClassification(Expression):
|
|
|
803
893
|
self.blueprint = blueprint
|
|
804
894
|
|
|
805
895
|
def forward(self, x: Symbol, **kwargs) -> Symbol:
|
|
806
|
-
@core.few_shot(
|
|
807
|
-
prompt=x,
|
|
808
|
-
examples=self.blueprint,
|
|
809
|
-
**kwargs
|
|
810
|
-
)
|
|
896
|
+
@core.few_shot(prompt=x, examples=self.blueprint, **kwargs)
|
|
811
897
|
def _func(_):
|
|
812
898
|
pass
|
|
813
899
|
|
|
@@ -815,43 +901,38 @@ class InContextClassification(Expression):
|
|
|
815
901
|
|
|
816
902
|
|
|
817
903
|
class Indexer(Expression):
|
|
818
|
-
DEFAULT =
|
|
904
|
+
DEFAULT = "dataindex"
|
|
819
905
|
|
|
820
906
|
@staticmethod
|
|
821
907
|
def replace_special_chars(index: str):
|
|
822
908
|
# replace special characters that are not for path
|
|
823
|
-
|
|
824
|
-
index = index.replace('-', '')
|
|
825
|
-
index = index.replace('_', '')
|
|
826
|
-
index = index.replace(' ', '')
|
|
827
|
-
index = index.lower()
|
|
828
|
-
return index
|
|
909
|
+
return str(index).replace("-", "").replace("_", "").replace(" ", "").lower()
|
|
829
910
|
|
|
830
911
|
def __init__(
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
912
|
+
self,
|
|
913
|
+
index_name: str = DEFAULT,
|
|
914
|
+
top_k: int = 8,
|
|
915
|
+
batch_size: int = 20,
|
|
916
|
+
formatter: Callable = _DEFAULT_PARAGRAPH_FORMATTER,
|
|
917
|
+
auto_add=False,
|
|
918
|
+
raw_result: bool = False,
|
|
919
|
+
new_dim: int = 1536,
|
|
920
|
+
**kwargs,
|
|
921
|
+
):
|
|
841
922
|
super().__init__(**kwargs)
|
|
842
923
|
index_name = Indexer.replace_special_chars(index_name)
|
|
843
924
|
self.index_name = index_name
|
|
844
|
-
self.elements
|
|
925
|
+
self.elements = []
|
|
845
926
|
self.batch_size = batch_size
|
|
846
|
-
self.top_k
|
|
847
|
-
self.retrieval
|
|
848
|
-
self.formatter
|
|
927
|
+
self.top_k = top_k
|
|
928
|
+
self.retrieval = None
|
|
929
|
+
self.formatter = formatter
|
|
849
930
|
self.raw_result = raw_result
|
|
850
|
-
self.new_dim
|
|
931
|
+
self.new_dim = new_dim
|
|
851
932
|
self.sym_return_type = Expression
|
|
852
933
|
|
|
853
934
|
# append index name to indices.txt in home directory .symai folder (default)
|
|
854
|
-
self.path = HOME_PATH /
|
|
935
|
+
self.path = HOME_PATH / "indices.txt"
|
|
855
936
|
if not self.path.exists():
|
|
856
937
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
|
857
938
|
self.path.touch()
|
|
@@ -861,52 +942,63 @@ class Indexer(Expression):
|
|
|
861
942
|
def register(self):
|
|
862
943
|
# check if index already exists in indices.txt and append if not
|
|
863
944
|
change = False
|
|
864
|
-
with
|
|
865
|
-
indices = f.read().split(
|
|
945
|
+
with self.path.open() as f:
|
|
946
|
+
indices = f.read().split("\n")
|
|
866
947
|
# filter out empty strings
|
|
867
948
|
indices = [i for i in indices if i]
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
949
|
+
if self.index_name not in indices:
|
|
950
|
+
indices.append(self.index_name)
|
|
951
|
+
change = True
|
|
871
952
|
if change:
|
|
872
|
-
with
|
|
873
|
-
f.write(
|
|
953
|
+
with self.path.open("w") as f:
|
|
954
|
+
f.write("\n".join(indices))
|
|
874
955
|
|
|
875
956
|
def exists(self) -> bool:
|
|
876
957
|
# check if index exists in home directory .symai folder (default) indices.txt
|
|
877
|
-
path = HOME_PATH /
|
|
958
|
+
path = HOME_PATH / "indices.txt"
|
|
878
959
|
if not path.exists():
|
|
879
960
|
return False
|
|
880
|
-
with open(
|
|
881
|
-
indices = f.read().split(
|
|
961
|
+
with path.open() as f:
|
|
962
|
+
indices = f.read().split("\n")
|
|
882
963
|
if self.index_name in indices:
|
|
883
964
|
return True
|
|
965
|
+
return False
|
|
884
966
|
|
|
885
967
|
def forward(
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
968
|
+
self,
|
|
969
|
+
data: Symbol | None = None,
|
|
970
|
+
_raw_result: bool = False,
|
|
971
|
+
) -> Symbol:
|
|
890
972
|
that = self
|
|
891
973
|
if data is not None:
|
|
892
974
|
data = self._to_symbol(data)
|
|
893
975
|
self.elements = self.formatter(data).value
|
|
894
976
|
# run over the elments in batches
|
|
895
977
|
for i in tqdm(range(0, len(self.elements), self.batch_size)):
|
|
896
|
-
val = Symbol(self.elements[i:i+self.batch_size]).zip(new_dim=self.new_dim)
|
|
978
|
+
val = Symbol(self.elements[i : i + self.batch_size]).zip(new_dim=self.new_dim)
|
|
897
979
|
that.add(val, index_name=that.index_name, index_dims=that.new_dim)
|
|
898
980
|
# we save the index
|
|
899
981
|
that.config(None, save=True, index_name=that.index_name, index_dims=that.new_dim)
|
|
900
982
|
|
|
901
|
-
def _func(query, *
|
|
902
|
-
raw_result = kwargs.get(
|
|
983
|
+
def _func(query, *_args, **kwargs) -> Union[Symbol, "VectorDBResult"]:
|
|
984
|
+
raw_result = kwargs.get("raw_result") or that.raw_result
|
|
903
985
|
query_emb = Symbol(query).embed(new_dim=that.new_dim).value
|
|
904
|
-
res = that.get(
|
|
986
|
+
res = that.get(
|
|
987
|
+
query_emb,
|
|
988
|
+
index_name=that.index_name,
|
|
989
|
+
index_top_k=that.top_k,
|
|
990
|
+
ori_query=query,
|
|
991
|
+
index_dims=that.new_dim,
|
|
992
|
+
**kwargs,
|
|
993
|
+
)
|
|
905
994
|
that.retrieval = res
|
|
906
995
|
if raw_result:
|
|
907
996
|
return res
|
|
908
|
-
|
|
909
|
-
|
|
997
|
+
return Symbol(res).query(
|
|
998
|
+
prompt="From the retrieved data, select the most relevant information.",
|
|
999
|
+
context=query,
|
|
1000
|
+
)
|
|
1001
|
+
|
|
910
1002
|
return _func
|
|
911
1003
|
|
|
912
1004
|
|
|
@@ -917,8 +1009,8 @@ class PrimitiveDisabler(Expression):
|
|
|
917
1009
|
self._original_primitives = defaultdict(list)
|
|
918
1010
|
|
|
919
1011
|
def __enter__(self):
|
|
920
|
-
#
|
|
921
|
-
from .symbol import Symbol
|
|
1012
|
+
# Import Symbol lazily so components does not clash with symbol during load.
|
|
1013
|
+
from .symbol import Symbol # noqa
|
|
922
1014
|
|
|
923
1015
|
frame = inspect.currentframe()
|
|
924
1016
|
f_locals = frame.f_back.f_locals
|
|
@@ -934,7 +1026,7 @@ class PrimitiveDisabler(Expression):
|
|
|
934
1026
|
for func in self._primitives:
|
|
935
1027
|
if hasattr(sym, func):
|
|
936
1028
|
self._original_primitives[sym_name].append((func, getattr(sym, func)))
|
|
937
|
-
setattr(sym, func, lambda *
|
|
1029
|
+
setattr(sym, func, lambda *_args, **_kwargs: None)
|
|
938
1030
|
|
|
939
1031
|
def _enable_primitives(self):
|
|
940
1032
|
for sym_name, sym in self._symbols.items():
|
|
@@ -945,7 +1037,7 @@ class PrimitiveDisabler(Expression):
|
|
|
945
1037
|
for sym in self._symbols.values():
|
|
946
1038
|
for primitive in sym._primitives:
|
|
947
1039
|
for method, _ in inspect.getmembers(primitive, predicate=inspect.isfunction):
|
|
948
|
-
if method in self._primitives or method.startswith(
|
|
1040
|
+
if method in self._primitives or method.startswith("_"):
|
|
949
1041
|
continue
|
|
950
1042
|
self._primitives.add(method)
|
|
951
1043
|
|
|
@@ -968,7 +1060,7 @@ class FunctionWithUsage(Function):
|
|
|
968
1060
|
|
|
969
1061
|
def print_verbose(self, msg):
|
|
970
1062
|
if self.verbose:
|
|
971
|
-
|
|
1063
|
+
UserMessage(msg)
|
|
972
1064
|
|
|
973
1065
|
def _format_usage(self, prompt_tokens, completion_tokens, total_tokens):
|
|
974
1066
|
return Box(
|
|
@@ -990,9 +1082,7 @@ class FunctionWithUsage(Function):
|
|
|
990
1082
|
self.total_tokens += usage.total_tokens
|
|
991
1083
|
|
|
992
1084
|
def get_usage(self):
|
|
993
|
-
return self._format_usage(
|
|
994
|
-
self.prompt_tokens, self.completion_tokens, self.total_tokens
|
|
995
|
-
)
|
|
1085
|
+
return self._format_usage(self.prompt_tokens, self.completion_tokens, self.total_tokens)
|
|
996
1086
|
|
|
997
1087
|
def forward(self, *args, **kwargs):
|
|
998
1088
|
if "return_metadata" not in kwargs:
|
|
@@ -1003,9 +1093,7 @@ class FunctionWithUsage(Function):
|
|
|
1003
1093
|
raw_output = metadata.get("raw_output")
|
|
1004
1094
|
if hasattr(raw_output, "usage"):
|
|
1005
1095
|
usage = raw_output.usage
|
|
1006
|
-
prompt_tokens = (
|
|
1007
|
-
usage.prompt_tokens if hasattr(usage, "prompt_tokens") else 0
|
|
1008
|
-
)
|
|
1096
|
+
prompt_tokens = usage.prompt_tokens if hasattr(usage, "prompt_tokens") else 0
|
|
1009
1097
|
completion_tokens = (
|
|
1010
1098
|
usage.completion_tokens if hasattr(usage, "completion_tokens") else 0
|
|
1011
1099
|
)
|
|
@@ -1020,28 +1108,29 @@ class FunctionWithUsage(Function):
|
|
|
1020
1108
|
self.completion_tokens += completion_tokens
|
|
1021
1109
|
self.total_tokens += total_tokens
|
|
1022
1110
|
else:
|
|
1023
|
-
if self.missing_usage_exception and
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1111
|
+
if self.missing_usage_exception and "preview" not in kwargs:
|
|
1112
|
+
UserMessage(
|
|
1113
|
+
"Missing usage in metadata of neursymbolic engine", raise_with=Exception
|
|
1114
|
+
)
|
|
1115
|
+
prompt_tokens = 0
|
|
1116
|
+
completion_tokens = 0
|
|
1117
|
+
total_tokens = 0
|
|
1029
1118
|
|
|
1030
1119
|
return res, self._format_usage(prompt_tokens, completion_tokens, total_tokens)
|
|
1031
1120
|
|
|
1032
1121
|
|
|
1033
1122
|
class SelfPrompt(Expression):
|
|
1034
|
-
_default_retry_tries
|
|
1035
|
-
_default_retry_delay
|
|
1123
|
+
_default_retry_tries = 20
|
|
1124
|
+
_default_retry_delay = 0.5
|
|
1036
1125
|
_default_retry_max_delay = -1
|
|
1037
|
-
_default_retry_backoff
|
|
1038
|
-
_default_retry_jitter
|
|
1039
|
-
_default_retry_graceful
|
|
1126
|
+
_default_retry_backoff = 1
|
|
1127
|
+
_default_retry_jitter = 0
|
|
1128
|
+
_default_retry_graceful = True
|
|
1040
1129
|
|
|
1041
1130
|
def __init__(self, *args, **kwargs):
|
|
1042
1131
|
super().__init__(*args, **kwargs)
|
|
1043
1132
|
|
|
1044
|
-
def forward(self, existing_prompt:
|
|
1133
|
+
def forward(self, existing_prompt: dict[str, str], **kwargs) -> dict[str, str]:
|
|
1045
1134
|
"""
|
|
1046
1135
|
Generate new system and user prompts based on the existing prompt.
|
|
1047
1136
|
|
|
@@ -1050,14 +1139,21 @@ class SelfPrompt(Expression):
|
|
|
1050
1139
|
:return: A dictionary containing the new prompts in the same format:
|
|
1051
1140
|
{'user': '...', 'system': '...'}
|
|
1052
1141
|
"""
|
|
1053
|
-
tries
|
|
1054
|
-
delay
|
|
1055
|
-
max_delay = kwargs.get(
|
|
1056
|
-
backoff
|
|
1057
|
-
jitter
|
|
1058
|
-
graceful
|
|
1059
|
-
|
|
1060
|
-
@core_ext.retry(
|
|
1142
|
+
tries = kwargs.get("tries", self._default_retry_tries)
|
|
1143
|
+
delay = kwargs.get("delay", self._default_retry_delay)
|
|
1144
|
+
max_delay = kwargs.get("max_delay", self._default_retry_max_delay)
|
|
1145
|
+
backoff = kwargs.get("backoff", self._default_retry_backoff)
|
|
1146
|
+
jitter = kwargs.get("jitter", self._default_retry_jitter)
|
|
1147
|
+
graceful = kwargs.get("graceful", self._default_retry_graceful)
|
|
1148
|
+
|
|
1149
|
+
@core_ext.retry(
|
|
1150
|
+
tries=tries,
|
|
1151
|
+
delay=delay,
|
|
1152
|
+
max_delay=max_delay,
|
|
1153
|
+
backoff=backoff,
|
|
1154
|
+
jitter=jitter,
|
|
1155
|
+
graceful=graceful,
|
|
1156
|
+
)
|
|
1061
1157
|
@core.zero_shot(
|
|
1062
1158
|
prompt=(
|
|
1063
1159
|
"Based on the following prompt, generate a new system (or developer) prompt and a new user prompt. "
|
|
@@ -1066,18 +1162,19 @@ class SelfPrompt(Expression):
|
|
|
1066
1162
|
"The new user prompt should contain the user's requirements. "
|
|
1067
1163
|
"Check if the input contains a 'system' or 'developer' key and use the same key in your output. "
|
|
1068
1164
|
"Only output the new prompts in JSON format as shown:\n\n"
|
|
1069
|
-
|
|
1165
|
+
'{"system": "<new system prompt>", "user": "<new user prompt>"}\n\n'
|
|
1070
1166
|
"OR\n\n"
|
|
1071
|
-
|
|
1167
|
+
'{"developer": "<new developer prompt>", "user": "<new user prompt>"}\n\n'
|
|
1072
1168
|
"Maintain the same key structure as in the input prompt. Do not include any additional text."
|
|
1073
1169
|
),
|
|
1074
1170
|
response_format={"type": "json_object"},
|
|
1075
1171
|
post_processors=[
|
|
1076
1172
|
lambda res, _: json.loads(res),
|
|
1077
1173
|
],
|
|
1078
|
-
**kwargs
|
|
1174
|
+
**kwargs,
|
|
1079
1175
|
)
|
|
1080
|
-
def _func(self, sym: Symbol):
|
|
1176
|
+
def _func(self, sym: Symbol):
|
|
1177
|
+
pass
|
|
1081
1178
|
|
|
1082
1179
|
return _func(self, self._to_symbol(existing_prompt))
|
|
1083
1180
|
|
|
@@ -1093,16 +1190,19 @@ class MetadataTracker(Expression):
|
|
|
1093
1190
|
def __str__(self, value=None):
|
|
1094
1191
|
value = value or self.metadata
|
|
1095
1192
|
if isinstance(value, dict):
|
|
1096
|
-
return
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1193
|
+
return (
|
|
1194
|
+
"{\n\t"
|
|
1195
|
+
+ ", \n\t".join(f'"{k}": {self.__str__(v)}' for k, v in value.items())
|
|
1196
|
+
+ "\n}"
|
|
1197
|
+
)
|
|
1198
|
+
if isinstance(value, list):
|
|
1199
|
+
return "[" + ", ".join(self.__str__(item) for item in value) + "]"
|
|
1200
|
+
if isinstance(value, str):
|
|
1100
1201
|
return f'"{value}"'
|
|
1101
|
-
|
|
1102
|
-
return f"\n\t {value}"
|
|
1202
|
+
return f"\n\t {value}"
|
|
1103
1203
|
|
|
1104
|
-
def __new__(cls, *
|
|
1105
|
-
cls._lock = getattr(cls,
|
|
1204
|
+
def __new__(cls, *_args, **_kwargs):
|
|
1205
|
+
cls._lock = getattr(cls, "_lock", Lock())
|
|
1106
1206
|
with cls._lock:
|
|
1107
1207
|
instance = super().__new__(cls)
|
|
1108
1208
|
instance._metadata = {}
|
|
@@ -1122,25 +1222,26 @@ class MetadataTracker(Expression):
|
|
|
1122
1222
|
|
|
1123
1223
|
def _trace_calls(self, frame, event, arg):
|
|
1124
1224
|
if not self._trace:
|
|
1125
|
-
return
|
|
1225
|
+
return None
|
|
1126
1226
|
|
|
1127
|
-
if
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1227
|
+
if (
|
|
1228
|
+
event == "return"
|
|
1229
|
+
and frame.f_code.co_name == "forward"
|
|
1230
|
+
and "self" in frame.f_locals
|
|
1231
|
+
and isinstance(frame.f_locals["self"], Engine)
|
|
1232
|
+
):
|
|
1233
|
+
_, metadata = arg # arg contains return value on 'return' event
|
|
1234
|
+
engine_name = frame.f_locals["self"].__class__.__name__
|
|
1235
|
+
model_name = frame.f_locals["self"].model
|
|
1236
|
+
self._metadata[(self._metadata_id, engine_name, model_name)] = metadata
|
|
1237
|
+
self._metadata_id += 1
|
|
1137
1238
|
|
|
1138
1239
|
return self._trace_calls
|
|
1139
1240
|
|
|
1140
1241
|
def _accumulate_completion_token_details(self):
|
|
1141
1242
|
"""Parses the return object and accumulates completion token details per token type"""
|
|
1142
1243
|
if not self._metadata:
|
|
1143
|
-
|
|
1244
|
+
UserMessage("No metadata available to generate usage details.")
|
|
1144
1245
|
return {}
|
|
1145
1246
|
|
|
1146
1247
|
token_details = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
|
|
@@ -1151,46 +1252,149 @@ class MetadataTracker(Expression):
             try:
                 if engine_name == "GroqEngine":
                     usage = metadata["raw_output"].usage
-                    token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.completion_tokens
-                    token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.prompt_tokens
-                    token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
+                    token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
+                        usage.completion_tokens
+                    )
+                    token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
+                        usage.prompt_tokens
+                    )
+                    token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
+                        usage.total_tokens
+                    )
                     token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
                     #!: Backward compatibility for components like `RuntimeInfo`
-                    token_details[(engine_name, model_name)]["prompt_breakdown"][
-
+                    token_details[(engine_name, model_name)]["prompt_breakdown"][
+                        "cached_tokens"
+                    ] += 0  # Assignment not allowed with defaultdict
+                    token_details[(engine_name, model_name)]["completion_breakdown"][
+                        "reasoning_tokens"
+                    ] += 0
+                elif engine_name == "ParallelEngine":
+                    token_details[(engine_name, None)]["usage"]["total_calls"] += 1
+                    # There are no model-specific tokens for this engine
+                    token_details[(engine_name, None)]["usage"]["completion_tokens"] += 0
+                    token_details[(engine_name, None)]["usage"]["prompt_tokens"] += 0
+                    token_details[(engine_name, None)]["usage"]["total_tokens"] += 0
+                    #!: Backward compatibility for components like `RuntimeInfo`
+                    token_details[(engine_name, None)]["prompt_breakdown"]["cached_tokens"] += (
+                        0  # Assignment not allowed with defaultdict
+                    )
+                    token_details[(engine_name, None)]["completion_breakdown"][
+                        "reasoning_tokens"
+                    ] += 0
                 elif engine_name in ("GPTXChatEngine", "GPTXReasoningEngine"):
                     usage = metadata["raw_output"].usage
-                    token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.completion_tokens
-                    token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.prompt_tokens
-                    token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
+                    token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
+                        usage.completion_tokens
+                    )
+                    token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
+                        usage.prompt_tokens
+                    )
+                    token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
+                        usage.total_tokens
+                    )
                     token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
-                    token_details[(engine_name, model_name)]["completion_breakdown"][
-
-
-                    token_details[(engine_name, model_name)]["completion_breakdown"][
-
-
+                    token_details[(engine_name, model_name)]["completion_breakdown"][
+                        "accepted_prediction_tokens"
+                    ] += usage.completion_tokens_details.accepted_prediction_tokens
+                    token_details[(engine_name, model_name)]["completion_breakdown"][
+                        "rejected_prediction_tokens"
+                    ] += usage.completion_tokens_details.rejected_prediction_tokens
+                    token_details[(engine_name, model_name)]["completion_breakdown"][
+                        "audio_tokens"
+                    ] += usage.completion_tokens_details.audio_tokens
+                    token_details[(engine_name, model_name)]["completion_breakdown"][
+                        "reasoning_tokens"
+                    ] += usage.completion_tokens_details.reasoning_tokens
+                    token_details[(engine_name, model_name)]["prompt_breakdown"][
+                        "audio_tokens"
+                    ] += usage.prompt_tokens_details.audio_tokens
+                    token_details[(engine_name, model_name)]["prompt_breakdown"][
+                        "cached_tokens"
+                    ] += usage.prompt_tokens_details.cached_tokens
                 elif engine_name == "GPTXSearchEngine":
                     usage = metadata["raw_output"].usage
-                    token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += usage.input_tokens
-                    token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += usage.output_tokens
-                    token_details[(engine_name, model_name)]["usage"]["total_tokens"] += usage.total_tokens
+                    token_details[(engine_name, model_name)]["usage"]["prompt_tokens"] += (
+                        usage.input_tokens
+                    )
+                    token_details[(engine_name, model_name)]["usage"]["completion_tokens"] += (
+                        usage.output_tokens
+                    )
+                    token_details[(engine_name, model_name)]["usage"]["total_tokens"] += (
+                        usage.total_tokens
+                    )
                     token_details[(engine_name, model_name)]["usage"]["total_calls"] += 1
-                    token_details[(engine_name, model_name)]["prompt_breakdown"][
-
+                    token_details[(engine_name, model_name)]["prompt_breakdown"][
+                        "cached_tokens"
+                    ] += usage.input_tokens_details.cached_tokens
+                    token_details[(engine_name, model_name)]["completion_breakdown"][
+                        "reasoning_tokens"
+                    ] += usage.output_tokens_details.reasoning_tokens
                 else:
                     logger.warning(f"Tracking {engine_name} is not supported.")
                     continue
             except Exception as e:
-
+                UserMessage(
+                    f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError
+                )

         # Convert to normal dict
         return {**token_details}

+    def _can_accumulate_engine(self, engine_name: str) -> bool:
+        supported_engines = ("GPTXChatEngine", "GPTXReasoningEngine", "GPTXSearchEngine")
+        return engine_name in supported_engines
+
+    def _accumulate_time_field(self, accumulated: dict, metadata: dict) -> None:
+        if "time" in metadata and "time" in accumulated:
+            accumulated["time"] += metadata["time"]
+
+    def _accumulate_usage_fields(self, accumulated: dict, metadata: dict) -> None:
+        if "raw_output" not in metadata or "raw_output" not in accumulated:
+            return
+
+        metadata_raw_output = metadata["raw_output"]
+        accumulated_raw_output = accumulated["raw_output"]
+        if not hasattr(metadata_raw_output, "usage") or not hasattr(
+            accumulated_raw_output, "usage"
+        ):
+            return
+
+        current_usage = metadata_raw_output.usage
+        accumulated_usage = accumulated_raw_output.usage
+
+        for attr in ["completion_tokens", "prompt_tokens", "total_tokens"]:
+            if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
+                setattr(
+                    accumulated_usage,
+                    attr,
+                    getattr(accumulated_usage, attr) + getattr(current_usage, attr),
+                )
+
+        for detail_attr in ["completion_tokens_details", "prompt_tokens_details"]:
+            if not hasattr(current_usage, detail_attr) or not hasattr(
+                accumulated_usage, detail_attr
+            ):
+                continue
+
+            current_details = getattr(current_usage, detail_attr)
+            accumulated_details = getattr(accumulated_usage, detail_attr)
+
+            for attr in dir(current_details):
+                if attr.startswith("_") or not hasattr(accumulated_details, attr):
+                    continue
+
+                current_val = getattr(current_details, attr)
+                accumulated_val = getattr(accumulated_details, attr)
+                if isinstance(current_val, (int, float)) and isinstance(
+                    accumulated_val, (int, float)
+                ):
+                    setattr(accumulated_details, attr, accumulated_val + current_val)
+
     def _accumulate_metadata(self):
         """Accumulates metadata across all tracked engine calls."""
         if not self._metadata:
-
+            UserMessage("No metadata available to generate usage details.")
             return {}

         # Use first entry as base
@@ -1199,39 +1403,14 @@ class MetadataTracker(Expression):

         # Skip first entry
         for (_, engine_name), metadata in list(self._metadata.items())[1:]:
-            if
-                logger.warning(
+            if not self._can_accumulate_engine(engine_name):
+                logger.warning(
+                    f"Metadata accumulation for {engine_name} is not supported. Try `.usage` instead for now."
+                )
                 continue

-
-
-            accumulated['time'] += metadata['time']
-
-            # Handle usage stats accumulation
-            if 'raw_output' in metadata and 'raw_output' in accumulated:
-                if hasattr(metadata['raw_output'], 'usage') and hasattr(accumulated['raw_output'], 'usage'):
-                    current_usage = metadata['raw_output'].usage
-                    accumulated_usage = accumulated['raw_output'].usage
-
-                    # Accumulate token counts
-                    for attr in ['completion_tokens', 'prompt_tokens', 'total_tokens']:
-                        if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
-                            setattr(accumulated_usage, attr,
-                                    getattr(accumulated_usage, attr) + getattr(current_usage, attr))
-
-                    # Handle nested token details if they exist
-                    for detail_attr in ['completion_tokens_details', 'prompt_tokens_details']:
-                        if hasattr(current_usage, detail_attr) and hasattr(accumulated_usage, detail_attr):
-                            current_details = getattr(current_usage, detail_attr)
-                            accumulated_details = getattr(accumulated_usage, detail_attr)
-
-                            # Accumulate all numeric attributes in the details
-                            for attr in dir(current_details):
-                                if not attr.startswith('_') and hasattr(accumulated_details, attr):
-                                    current_val = getattr(current_details, attr)
-                                    accumulated_val = getattr(accumulated_details, attr)
-                                    if isinstance(current_val, (int, float)) and isinstance(accumulated_val, (int, float)):
-                                        setattr(accumulated_details, attr, accumulated_val + current_val)
+            self._accumulate_time_field(accumulated, metadata)
+            self._accumulate_usage_fields(accumulated, metadata)

         return accumulated

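For orientation, `_trace_calls` follows the `sys.settrace` protocol: it fires on every `return` event of an `Engine.forward` call and files the returned metadata under a `(call_id, engine_name, model_name)` key, which the accumulator methods above then fold into per-engine totals. A minimal usage sketch, assuming the tracker is entered as a context manager and exposes the aggregate via the `.usage` property that the warning message above refers to; the import path and prompt text are illustrative:

    from symai import Symbol
    from symai.components import MetadataTracker  # assumed import path

    with MetadataTracker() as tracker:
        # Any engine call inside the block triggers a 'return' trace event
        # from Engine.forward, so its usage metadata is recorded.
        Symbol("SymbolicAI").query("What does this package do?")

    # Per-(engine, model) totals, as built by
    # _accumulate_completion_token_details().
    print(tracker.usage)
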
@@ -1250,7 +1429,8 @@ class MetadataTracker(Expression):

 class DynamicEngine(Expression):
     """Context manager for dynamically switching neurosymbolic engine models."""
-
+
+    def __init__(self, model: str, api_key: str, _debug: bool = False, **_kwargs):
         super().__init__()
         self.model = model
         self.api_key = api_key
@@ -1259,8 +1439,8 @@ class DynamicEngine(Expression):
         self.engine_instance = None
         self._ctx_token = None

-    def __new__(cls, *
-        cls._lock = getattr(cls,
+    def __new__(cls, *_args, **_kwargs):
+        cls._lock = getattr(cls, "_lock", Lock())
         with cls._lock:
             instance = super().__new__(cls)
             instance._metadata = {}
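Taken together, the `DynamicEngine` changes above and below (the keyword-typed `__init__`, the lock-guarded `__new__`, and the deferred `ENGINE_MAPPING` import in `_create_engine_instance`) support a usage pattern along these lines; a sketch only, where the model name and key are placeholders and the accepted names depend on `ENGINE_MAPPING`:

    from symai import Symbol
    from symai.components import DynamicEngine  # assumed import path

    # Hypothetical model and key; _create_engine_instance() looks the model
    # up in ENGINE_MAPPING and raises ValueError for unknown names.
    with DynamicEngine(model="gpt-4o", api_key="sk-..."):
        # Inside the block, neurosymbolic calls route through the
        # dynamically created engine instance.
        answer = Symbol("6 * 7").query("Evaluate and answer with a number.")
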
@@ -1293,11 +1473,178 @@ class DynamicEngine(Expression):

     def _create_engine_instance(self):
         """Create an engine instance based on the model name."""
-
+        # Deferred to avoid components <-> neurosymbolic engine circular imports.
+        from .backend.engines.neurosymbolic import ENGINE_MAPPING  # noqa
+
         try:
             engine_class = ENGINE_MAPPING.get(self.model)
             if engine_class is None:
-
+                UserMessage(f"Unsupported model '{self.model}'", raise_with=ValueError)
             return engine_class(api_key=self.api_key, model=self.model)
         except Exception as e:
-
+            UserMessage(
+                f"Failed to create engine for model '{self.model}': {e!s}", raise_with=ValueError
+            )
+
+
+# Chonkie chunker imports - lazy loaded
+_CHONKIE_MODULES = None
+_CHUNKER_MAPPING = None
+_CHONKIE_AVAILABLE = None
+
+
+def _lazy_import_chonkie():
+    """Lazily import chonkie modules when needed."""
+    global _CHONKIE_MODULES, _CHUNKER_MAPPING, _CHONKIE_AVAILABLE
+
+    if _CHONKIE_MODULES is not None:
+        return _CHONKIE_MODULES
+
+    try:
+        from chonkie import (  # noqa
+            CodeChunker,
+            LateChunker,
+            NeuralChunker,
+            RecursiveChunker,
+            SemanticChunker,
+            SentenceChunker,
+            SlumberChunker,
+            TableChunker,
+            TokenChunker,
+        )
+        from chonkie.embeddings.base import BaseEmbeddings  # noqa
+        from tokenizers import Tokenizer  # noqa
+
+        _CHONKIE_MODULES = {
+            "CodeChunker": CodeChunker,
+            "LateChunker": LateChunker,
+            "NeuralChunker": NeuralChunker,
+            "RecursiveChunker": RecursiveChunker,
+            "SemanticChunker": SemanticChunker,
+            "SentenceChunker": SentenceChunker,
+            "SlumberChunker": SlumberChunker,
+            "TableChunker": TableChunker,
+            "TokenChunker": TokenChunker,
+            "BaseEmbeddings": BaseEmbeddings,
+            "Tokenizer": Tokenizer,
+        }
+        _CHUNKER_MAPPING = {
+            "TokenChunker": TokenChunker,
+            "SentenceChunker": SentenceChunker,
+            "RecursiveChunker": RecursiveChunker,
+            "SemanticChunker": SemanticChunker,
+            "CodeChunker": CodeChunker,
+            "LateChunker": LateChunker,
+            "NeuralChunker": NeuralChunker,
+            "SlumberChunker": SlumberChunker,
+            "TableChunker": TableChunker,
+        }
+        _CHONKIE_AVAILABLE = True
+    except ImportError:
+        _CHONKIE_MODULES = {}
+        _CHUNKER_MAPPING = {}
+        _CHONKIE_AVAILABLE = False
+
+    return _CHONKIE_MODULES
+
+
+def _get_chunker_mapping():
+    """Get the chunker mapping, lazily importing chonkie if needed."""
+    if _CHUNKER_MAPPING is None:
+        _lazy_import_chonkie()
+    return _CHUNKER_MAPPING or {}
+
+
+def _is_chonkie_available():
+    """Check if chonkie is available, lazily importing if needed."""
+    if _CHONKIE_AVAILABLE is None:
+        _lazy_import_chonkie()
+    return _CHONKIE_AVAILABLE or False
+
+
+@beartype
+class ChonkieChunker(Expression):
+    def __init__(
+        self,
+        tokenizer_name: str | None = "gpt2",
+        embedding_model_name: str | None = "minishlab/potion-base-8M",
+        **symai_kwargs,
+    ):
+        super().__init__(**symai_kwargs)
+        self.tokenizer_name = tokenizer_name
+        self.embedding_model_name = embedding_model_name
+
+    def forward(
+        self, data: Symbol, chunker_name: str | None = "RecursiveChunker", **chunker_kwargs
+    ) -> Symbol:
+        if not _is_chonkie_available():
+            UserMessage(
+                "chonkie library is not installed. Please install it with `pip install chonkie tokenizers`.",
+                raise_with=ImportError,
+            )
+        chunker = self._resolve_chunker(chunker_name, **chunker_kwargs)
+        chunks = [ChonkieChunker.clean_text(chunk.text) for chunk in chunker(data.value)]
+        return self._to_symbol(chunks)
+
+    def _resolve_chunker(self, chunker_name: str, **chunker_kwargs):
+        """Resolve and instantiate a chunker by name."""
+        chunker_mapping = _get_chunker_mapping()
+
+        if chunker_name not in chunker_mapping:
+            msg = (
+                f"Chunker {chunker_name} not found. Available chunkers: {list(chunker_mapping.keys())}. "
+                f"See docs (https://docs.chonkie.ai/getting-started/introduction) for more info."
+            )
+            raise ValueError(msg)
+
+        chunker_class = chunker_mapping[chunker_name]
+        chonkie_modules = _lazy_import_chonkie()
+        Tokenizer = chonkie_modules.get("Tokenizer")
+
+        # Tokenizer-based chunkers (use tokenizer_name)
+        if chunker_name in ["TokenChunker", "SentenceChunker", "RecursiveChunker"]:
+            if Tokenizer is None:
+                UserMessage(
+                    "Tokenizers library is not installed. Please install it with `pip install tokenizers`.",
+                    raise_with=ImportError,
+                )
+            tokenizer = Tokenizer.from_pretrained(self.tokenizer_name)
+            return chunker_class(tokenizer, **chunker_kwargs)
+
+        # Embedding-based chunkers (use embedding_model_name)
+        if chunker_name in ["SemanticChunker", "LateChunker"]:
+            return chunker_class(embedding_model=self.embedding_model_name, **chunker_kwargs)
+
+        # CodeChunker and TableChunker use tokenizer (can use string or Tokenizer object)
+        if chunker_name in ["CodeChunker", "TableChunker"]:
+            # These can accept tokenizer as string (default 'character') or Tokenizer object
+            # If tokenizer not provided in kwargs, use tokenizer_name
+            if "tokenizer" not in chunker_kwargs:
+                chunker_kwargs["tokenizer"] = self.tokenizer_name
+            return chunker_class(**chunker_kwargs)
+
+        # SlumberChunker uses tokenizer (can use string or Tokenizer object)
+        if chunker_name == "SlumberChunker":
+            # If tokenizer not provided in kwargs, use tokenizer_name
+            if "tokenizer" not in chunker_kwargs:
+                chunker_kwargs["tokenizer"] = self.tokenizer_name
+            return chunker_class(**chunker_kwargs)
+
+        # NeuralChunker uses model parameter (defaults provided by chonkie)
+        if chunker_name == "NeuralChunker":
+            return chunker_class(**chunker_kwargs)
+
+        msg = (
+            f"Chunker {chunker_name} not properly configured. "
+            f"Available chunkers: {list(chunker_mapping.keys())}."
+        )
+        raise ValueError(msg)
+
+    @staticmethod
+    def clean_text(text: str) -> str:
+        """Cleans text by removing problematic characters."""
+        text = text.replace("\x00", "")  # Remove null bytes (\x00)
+        return text.encode("utf-8", errors="ignore").decode(
+            "utf-8"
+        )  # Replace invalid UTF-8 sequences