symbolicai 0.20.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +96 -64
- symai/backend/base.py +93 -80
- symai/backend/engines/drawing/engine_bfl.py +12 -11
- symai/backend/engines/drawing/engine_gpt_image.py +108 -87
- symai/backend/engines/embedding/engine_llama_cpp.py +25 -28
- symai/backend/engines/embedding/engine_openai.py +3 -5
- symai/backend/engines/execute/engine_python.py +6 -5
- symai/backend/engines/files/engine_io.py +74 -67
- symai/backend/engines/imagecaptioning/engine_blip2.py +3 -3
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +54 -38
- symai/backend/engines/index/engine_pinecone.py +23 -24
- symai/backend/engines/index/engine_vectordb.py +16 -14
- symai/backend/engines/lean/engine_lean4.py +38 -34
- symai/backend/engines/neurosymbolic/__init__.py +41 -13
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +262 -182
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +263 -191
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +53 -49
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +212 -211
- symai/backend/engines/neurosymbolic/engine_groq.py +87 -63
- symai/backend/engines/neurosymbolic/engine_huggingface.py +21 -24
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +117 -48
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +256 -229
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +270 -150
- symai/backend/engines/ocr/engine_apilayer.py +6 -8
- symai/backend/engines/output/engine_stdout.py +1 -4
- symai/backend/engines/search/engine_openai.py +7 -7
- symai/backend/engines/search/engine_perplexity.py +5 -5
- symai/backend/engines/search/engine_serpapi.py +12 -14
- symai/backend/engines/speech_to_text/engine_local_whisper.py +20 -27
- symai/backend/engines/symbolic/engine_wolframalpha.py +3 -3
- symai/backend/engines/text_to_speech/engine_openai.py +5 -7
- symai/backend/engines/text_vision/engine_clip.py +7 -11
- symai/backend/engines/userinput/engine_console.py +3 -3
- symai/backend/engines/webscraping/engine_requests.py +81 -48
- symai/backend/mixin/__init__.py +13 -0
- symai/backend/mixin/anthropic.py +4 -2
- symai/backend/mixin/deepseek.py +2 -0
- symai/backend/mixin/google.py +2 -0
- symai/backend/mixin/openai.py +11 -3
- symai/backend/settings.py +83 -16
- symai/chat.py +101 -78
- symai/collect/__init__.py +7 -1
- symai/collect/dynamic.py +77 -69
- symai/collect/pipeline.py +35 -27
- symai/collect/stats.py +75 -63
- symai/components.py +198 -169
- symai/constraints.py +15 -12
- symai/core.py +698 -359
- symai/core_ext.py +32 -34
- symai/endpoints/api.py +80 -73
- symai/extended/.DS_Store +0 -0
- symai/extended/__init__.py +46 -12
- symai/extended/api_builder.py +11 -8
- symai/extended/arxiv_pdf_parser.py +13 -12
- symai/extended/bibtex_parser.py +2 -3
- symai/extended/conversation.py +101 -90
- symai/extended/document.py +17 -10
- symai/extended/file_merger.py +18 -13
- symai/extended/graph.py +18 -13
- symai/extended/html_style_template.py +2 -4
- symai/extended/interfaces/blip_2.py +1 -2
- symai/extended/interfaces/clip.py +1 -2
- symai/extended/interfaces/console.py +7 -1
- symai/extended/interfaces/dall_e.py +1 -1
- symai/extended/interfaces/flux.py +1 -1
- symai/extended/interfaces/gpt_image.py +1 -1
- symai/extended/interfaces/input.py +1 -1
- symai/extended/interfaces/llava.py +0 -1
- symai/extended/interfaces/naive_vectordb.py +7 -8
- symai/extended/interfaces/naive_webscraping.py +1 -1
- symai/extended/interfaces/ocr.py +1 -1
- symai/extended/interfaces/pinecone.py +6 -5
- symai/extended/interfaces/serpapi.py +1 -1
- symai/extended/interfaces/terminal.py +2 -3
- symai/extended/interfaces/tts.py +1 -1
- symai/extended/interfaces/whisper.py +1 -1
- symai/extended/interfaces/wolframalpha.py +1 -1
- symai/extended/metrics/__init__.py +11 -1
- symai/extended/metrics/similarity.py +11 -13
- symai/extended/os_command.py +17 -16
- symai/extended/packages/__init__.py +29 -3
- symai/extended/packages/symdev.py +19 -16
- symai/extended/packages/sympkg.py +12 -9
- symai/extended/packages/symrun.py +21 -19
- symai/extended/repo_cloner.py +11 -10
- symai/extended/seo_query_optimizer.py +1 -2
- symai/extended/solver.py +20 -23
- symai/extended/summarizer.py +4 -3
- symai/extended/taypan_interpreter.py +10 -12
- symai/extended/vectordb.py +99 -82
- symai/formatter/__init__.py +9 -1
- symai/formatter/formatter.py +12 -16
- symai/formatter/regex.py +62 -63
- symai/functional.py +176 -122
- symai/imports.py +136 -127
- symai/interfaces.py +56 -27
- symai/memory.py +14 -13
- symai/misc/console.py +49 -39
- symai/misc/loader.py +5 -3
- symai/models/__init__.py +17 -1
- symai/models/base.py +269 -181
- symai/models/errors.py +0 -1
- symai/ops/__init__.py +32 -22
- symai/ops/measures.py +11 -15
- symai/ops/primitives.py +348 -228
- symai/post_processors.py +32 -28
- symai/pre_processors.py +39 -41
- symai/processor.py +6 -4
- symai/prompts.py +59 -45
- symai/server/huggingface_server.py +23 -20
- symai/server/llama_cpp_server.py +7 -5
- symai/shell.py +3 -4
- symai/shellsv.py +499 -375
- symai/strategy.py +517 -287
- symai/symbol.py +111 -116
- symai/utils.py +42 -36
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/METADATA +4 -2
- symbolicai-1.0.0.dist-info/RECORD +163 -0
- symbolicai-0.20.2.dist-info/RECORD +0 -162
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/WHEEL +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/top_level.txt +0 -0
symai/components.py
CHANGED
|
@@ -1,23 +1,19 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
import inspect
|
|
3
3
|
import json
|
|
4
|
-
import os
|
|
5
4
|
import re
|
|
6
5
|
import sys
|
|
7
|
-
from abc import abstractmethod
|
|
8
6
|
from collections import defaultdict
|
|
7
|
+
from collections.abc import Callable, Iterator
|
|
9
8
|
from pathlib import Path
|
|
10
9
|
from random import sample
|
|
11
10
|
from string import ascii_lowercase, ascii_uppercase
|
|
12
11
|
from threading import Lock
|
|
13
|
-
from
|
|
14
|
-
from typing import Callable, Dict, Iterator, List, Optional, Type, Union
|
|
12
|
+
from typing import TYPE_CHECKING, Union
|
|
15
13
|
|
|
16
14
|
import numpy as np
|
|
17
|
-
from attr import dataclass
|
|
18
15
|
from box import Box
|
|
19
16
|
from loguru import logger
|
|
20
|
-
from pydantic import BaseModel, ValidationError
|
|
21
17
|
from pyvis.network import Network
|
|
22
18
|
from tqdm import tqdm
|
|
23
19
|
|
|
@@ -25,16 +21,25 @@ from . import core, core_ext
|
|
|
25
21
|
from .backend.base import Engine
|
|
26
22
|
from .backend.settings import HOME_PATH
|
|
27
23
|
from .constraints import DictFormatConstraint
|
|
24
|
+
from .context import CURRENT_ENGINE_VAR
|
|
28
25
|
from .formatter import ParagraphFormatter
|
|
29
|
-
from .post_processors import (
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
26
|
+
from .post_processors import (
|
|
27
|
+
CodeExtractPostProcessor,
|
|
28
|
+
JsonTruncateMarkdownPostProcessor,
|
|
29
|
+
JsonTruncatePostProcessor,
|
|
30
|
+
PostProcessor,
|
|
31
|
+
StripPostProcessor,
|
|
32
|
+
)
|
|
33
33
|
from .pre_processors import JsonPreProcessor, PreProcessor
|
|
34
34
|
from .processor import ProcessorPipeline
|
|
35
35
|
from .prompts import JsonPromptTemplate, Prompt
|
|
36
36
|
from .symbol import Expression, Metadata, Symbol
|
|
37
|
-
from .utils import
|
|
37
|
+
from .utils import UserMessage
|
|
38
|
+
|
|
39
|
+
if TYPE_CHECKING:
|
|
40
|
+
from .backend.engines.index.engine_vectordb import VectorDBResult
|
|
41
|
+
|
|
42
|
+
_DEFAULT_PARAGRAPH_FORMATTER = ParagraphFormatter()
|
|
38
43
|
|
|
39
44
|
|
|
40
45
|
class GraphViz(Expression):
|
|
@@ -58,7 +63,7 @@ class GraphViz(Expression):
|
|
|
58
63
|
select_menu=select_menu,
|
|
59
64
|
filter_menu=filter_menu)
|
|
60
65
|
|
|
61
|
-
def forward(self, sym: Symbol, file_path: str, **
|
|
66
|
+
def forward(self, sym: Symbol, file_path: str, **_kwargs):
|
|
62
67
|
nodes = [str(n) if n.value else n.__repr__(simplified=True) for n in sym.nodes]
|
|
63
68
|
edges = [(str(e[0]) if e[0].value else e[0].__repr__(simplified=True),
|
|
64
69
|
str(e[1]) if e[1].value else e[1].__repr__(simplified=True)) for e in sym.edges]
|
|
@@ -73,21 +78,21 @@ class TrackerTraceable(Expression):
|
|
|
73
78
|
|
|
74
79
|
|
|
75
80
|
class Any(Expression):
|
|
76
|
-
def __init__(self, *expr:
|
|
81
|
+
def __init__(self, *expr: list[Expression], **kwargs):
|
|
77
82
|
super().__init__(**kwargs)
|
|
78
|
-
self.expr:
|
|
83
|
+
self.expr: list[Expression] = expr
|
|
79
84
|
|
|
80
85
|
def forward(self, *args, **kwargs) -> Symbol:
|
|
81
|
-
return self.sym_return_type(any(
|
|
86
|
+
return self.sym_return_type(any(e() for e in self.expr(*args, **kwargs)))
|
|
82
87
|
|
|
83
88
|
|
|
84
89
|
class All(Expression):
|
|
85
|
-
def __init__(self, *expr:
|
|
90
|
+
def __init__(self, *expr: list[Expression], **kwargs):
|
|
86
91
|
super().__init__(**kwargs)
|
|
87
|
-
self.expr:
|
|
92
|
+
self.expr: list[Expression] = expr
|
|
88
93
|
|
|
89
94
|
def forward(self, *args, **kwargs) -> Symbol:
|
|
90
|
-
return self.sym_return_type(all(
|
|
95
|
+
return self.sym_return_type(all(e() for e in self.expr(*args, **kwargs)))
|
|
91
96
|
|
|
92
97
|
|
|
93
98
|
class Try(Expression):
|
|
@@ -117,14 +122,14 @@ class Lambda(Expression):
|
|
|
117
122
|
|
|
118
123
|
|
|
119
124
|
class Choice(Expression):
|
|
120
|
-
def __init__(self, cases:
|
|
125
|
+
def __init__(self, cases: list[str], default: str | None = None, **kwargs):
|
|
121
126
|
super().__init__(**kwargs)
|
|
122
|
-
self.cases:
|
|
123
|
-
self.default:
|
|
127
|
+
self.cases: list[str] = cases
|
|
128
|
+
self.default: str | None = default
|
|
124
129
|
|
|
125
130
|
def forward(self, sym: Symbol, *args, **kwargs) -> Symbol:
|
|
126
131
|
sym = self._to_symbol(sym)
|
|
127
|
-
return sym.choice(cases=self.cases, default=self.default,
|
|
132
|
+
return sym.choice(*args, cases=self.cases, default=self.default, **kwargs)
|
|
128
133
|
|
|
129
134
|
|
|
130
135
|
class Output(Expression):
|
|
@@ -137,13 +142,13 @@ class Output(Expression):
|
|
|
137
142
|
def forward(self, *args, **kwargs) -> Expression:
|
|
138
143
|
kwargs['verbose'] = self.verbose
|
|
139
144
|
kwargs['handler'] = self.handler
|
|
140
|
-
return self.output(expr=self.expr,
|
|
145
|
+
return self.output(*args, expr=self.expr, **kwargs)
|
|
141
146
|
|
|
142
147
|
|
|
143
148
|
class Sequence(TrackerTraceable):
|
|
144
|
-
def __init__(self, *expressions:
|
|
149
|
+
def __init__(self, *expressions: list[Expression], **kwargs):
|
|
145
150
|
super().__init__(**kwargs)
|
|
146
|
-
self.expressions:
|
|
151
|
+
self.expressions: list[Expression] = expressions
|
|
147
152
|
|
|
148
153
|
def forward(self, *args, **kwargs) -> Symbol:
|
|
149
154
|
sym = self.expressions[0](*args, **kwargs)
|
|
@@ -159,11 +164,11 @@ class Sequence(TrackerTraceable):
|
|
|
159
164
|
|
|
160
165
|
|
|
161
166
|
class Parallel(Expression):
|
|
162
|
-
def __init__(self, *expr:
|
|
167
|
+
def __init__(self, *expr: list[Expression | Callable], sequential: bool = False, **kwargs):
|
|
163
168
|
super().__init__(**kwargs)
|
|
164
169
|
self.sequential: bool = sequential
|
|
165
|
-
self.expr:
|
|
166
|
-
self.results:
|
|
170
|
+
self.expr: list[Expression] = expr
|
|
171
|
+
self.results: list[Symbol] = []
|
|
167
172
|
|
|
168
173
|
def forward(self, *args, **kwargs) -> Symbol:
|
|
169
174
|
# run in sequence
|
|
@@ -180,11 +185,11 @@ class Parallel(Expression):
|
|
|
180
185
|
|
|
181
186
|
#@TODO: BinPacker(format="...") -> ensure that data packages form a "bin" that's consistent (e.g. never break a sentence in the middle)
|
|
182
187
|
class Stream(Expression):
|
|
183
|
-
def __init__(self, expr:
|
|
188
|
+
def __init__(self, expr: Expression | None = None, retrieval: str | None = None, **kwargs):
|
|
184
189
|
super().__init__(**kwargs)
|
|
185
190
|
self.char_token_ratio: float = 0.6
|
|
186
|
-
self.expr:
|
|
187
|
-
self.retrieval:
|
|
191
|
+
self.expr: Expression | None = expr
|
|
192
|
+
self.retrieval: str | None = retrieval
|
|
188
193
|
self._trace: bool = False
|
|
189
194
|
self._previous_frame = None
|
|
190
195
|
|
|
@@ -194,19 +199,23 @@ class Stream(Expression):
|
|
|
194
199
|
if self._trace:
|
|
195
200
|
local_vars = self._previous_frame.f_locals
|
|
196
201
|
vals = []
|
|
197
|
-
for
|
|
202
|
+
for _key, var in local_vars.items():
|
|
198
203
|
if isinstance(var, TrackerTraceable):
|
|
199
204
|
vals.append(var)
|
|
200
205
|
|
|
201
206
|
if len(vals) == 1:
|
|
202
207
|
self.expr = vals[0]
|
|
203
208
|
else:
|
|
204
|
-
|
|
209
|
+
UserMessage(
|
|
210
|
+
"This component does either not inherit from TrackerTraceable or has an invalid number of component "
|
|
211
|
+
f"declarations: {len(vals)}! Only one component that inherits from TrackerTraceable is allowed in the "
|
|
212
|
+
"with stream clause.",
|
|
213
|
+
raise_with=ValueError,
|
|
214
|
+
)
|
|
205
215
|
|
|
206
216
|
res = sym.stream(expr=self.expr,
|
|
207
217
|
char_token_ratio=self.char_token_ratio,
|
|
208
218
|
**kwargs)
|
|
209
|
-
|
|
210
219
|
if self.retrieval is not None:
|
|
211
220
|
res = list(res)
|
|
212
221
|
if self.retrieval == 'all':
|
|
@@ -215,9 +224,8 @@ class Stream(Expression):
|
|
|
215
224
|
res = sorted(res, key=lambda x: len(x), reverse=True)
|
|
216
225
|
return res[0]
|
|
217
226
|
if self.retrieval == 'contains':
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
raise ValueError(f"Invalid retrieval method: {self.retrieval}")
|
|
227
|
+
return [r for r in res if self.expr in r]
|
|
228
|
+
UserMessage(f"Invalid retrieval method: {self.retrieval}", raise_with=ValueError)
|
|
221
229
|
|
|
222
230
|
return res
|
|
223
231
|
|
|
@@ -231,10 +239,12 @@ class Stream(Expression):
|
|
|
231
239
|
|
|
232
240
|
|
|
233
241
|
class Trace(Expression):
|
|
234
|
-
def __init__(self, expr:
|
|
242
|
+
def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
|
|
243
|
+
if engines is None:
|
|
244
|
+
engines = ['all']
|
|
235
245
|
super().__init__(**kwargs)
|
|
236
246
|
self.expr: Expression = expr
|
|
237
|
-
self.engines:
|
|
247
|
+
self.engines: list[str] = engines
|
|
238
248
|
|
|
239
249
|
def forward(self, *args, **kwargs) -> Expression:
|
|
240
250
|
Expression.command(verbose=True, engines=self.engines)
|
|
@@ -252,23 +262,26 @@ class Trace(Expression):
|
|
|
252
262
|
Expression.command(verbose=False, engines=self.engines)
|
|
253
263
|
if self.expr is not None:
|
|
254
264
|
return self.expr.__exit__(type, value, traceback)
|
|
265
|
+
return None
|
|
255
266
|
|
|
256
267
|
|
|
257
268
|
class Analyze(Expression):
|
|
258
|
-
def __init__(self, exception: Exception, query:
|
|
269
|
+
def __init__(self, exception: Exception, query: str | None = None, **kwargs):
|
|
259
270
|
super().__init__(**kwargs)
|
|
260
271
|
self.exception: Expression = exception
|
|
261
|
-
self.query:
|
|
272
|
+
self.query: str | None = query
|
|
262
273
|
|
|
263
274
|
def forward(self, sym: Symbol, *args, **kwargs) -> Symbol:
|
|
264
|
-
return sym.analyze(exception=self.exception, query=self.query,
|
|
275
|
+
return sym.analyze(*args, exception=self.exception, query=self.query, **kwargs)
|
|
265
276
|
|
|
266
277
|
|
|
267
278
|
class Log(Expression):
|
|
268
|
-
def __init__(self, expr:
|
|
279
|
+
def __init__(self, expr: Expression | None = None, engines=None, **kwargs):
|
|
280
|
+
if engines is None:
|
|
281
|
+
engines = ['all']
|
|
269
282
|
super().__init__(**kwargs)
|
|
270
283
|
self.expr: Expression = expr
|
|
271
|
-
self.engines:
|
|
284
|
+
self.engines: list[str] = engines
|
|
272
285
|
|
|
273
286
|
def forward(self, *args, **kwargs) -> Expression:
|
|
274
287
|
Expression.command(logging=True, engines=self.engines)
|
|
@@ -286,6 +299,7 @@ class Log(Expression):
|
|
|
286
299
|
Expression.command(logging=False, engines=self.engines)
|
|
287
300
|
if self.expr is not None:
|
|
288
301
|
return self.expr.__exit__(type, value, traceback)
|
|
302
|
+
return None
|
|
289
303
|
|
|
290
304
|
|
|
291
305
|
class Template(Expression):
|
|
@@ -331,10 +345,10 @@ class Metric(Expression):
|
|
|
331
345
|
self.normalize = normalize
|
|
332
346
|
self.eps = eps
|
|
333
347
|
|
|
334
|
-
def forward(self, sym: Symbol, **
|
|
348
|
+
def forward(self, sym: Symbol, **_kwargs) -> Symbol:
|
|
335
349
|
sym = self._to_symbol(sym)
|
|
336
|
-
assert sym.value_type
|
|
337
|
-
if sym.value_type
|
|
350
|
+
assert sym.value_type is np.ndarray or sym.value_type is list, 'Metric can only be applied to numpy arrays or lists.'
|
|
351
|
+
if sym.value_type is list:
|
|
338
352
|
sym._value = np.array(sym.value)
|
|
339
353
|
# compute normalization between 0 and 1
|
|
340
354
|
if self.normalize:
|
|
@@ -343,17 +357,19 @@ class Metric(Expression):
|
|
|
343
357
|
elif len(sym.value.shape) == 2:
|
|
344
358
|
pass
|
|
345
359
|
else:
|
|
346
|
-
|
|
360
|
+
UserMessage(f'Invalid shape: {sym.value.shape}', raise_with=ValueError)
|
|
347
361
|
# normalize between 0 and 1 and sum to 1
|
|
348
362
|
sym._value = np.exp(sym.value) / (np.exp(sym.value).sum() + self.eps)
|
|
349
363
|
return sym
|
|
350
364
|
|
|
351
365
|
|
|
352
366
|
class Style(Expression):
|
|
353
|
-
def __init__(self, description: str, libraries:
|
|
367
|
+
def __init__(self, description: str, libraries: list[str] | None = None, **kwargs):
|
|
368
|
+
if libraries is None:
|
|
369
|
+
libraries = []
|
|
354
370
|
super().__init__(**kwargs)
|
|
355
371
|
self.description: str = description
|
|
356
|
-
self.libraries:
|
|
372
|
+
self.libraries: list[str] = libraries
|
|
357
373
|
|
|
358
374
|
def forward(self, sym: Symbol, **kwargs) -> Symbol:
|
|
359
375
|
sym = self._to_symbol(sym)
|
|
@@ -365,7 +381,7 @@ class Query(TrackerTraceable):
|
|
|
365
381
|
super().__init__(**kwargs)
|
|
366
382
|
self.prompt: str = prompt
|
|
367
383
|
|
|
368
|
-
def forward(self, sym: Symbol, context: Symbol = None, *
|
|
384
|
+
def forward(self, sym: Symbol, context: Symbol = None, *_args, **kwargs) -> Symbol:
|
|
369
385
|
sym = self._to_symbol(sym)
|
|
370
386
|
return sym.query(prompt=self.prompt, context=context, **kwargs)
|
|
371
387
|
|
|
@@ -474,11 +490,11 @@ class ExcludeFilter(Expression):
|
|
|
474
490
|
class FileWriter(Expression):
|
|
475
491
|
def __init__(self, path: str, **kwargs):
|
|
476
492
|
super().__init__(**kwargs)
|
|
477
|
-
self.path = path
|
|
493
|
+
self.path = Path(path)
|
|
478
494
|
|
|
479
|
-
def forward(self, sym: Symbol, **
|
|
495
|
+
def forward(self, sym: Symbol, **_kwargs) -> Symbol:
|
|
480
496
|
sym = self._to_symbol(sym)
|
|
481
|
-
with
|
|
497
|
+
with self.path.open('w') as f:
|
|
482
498
|
f.write(str(sym))
|
|
483
499
|
|
|
484
500
|
|
|
@@ -493,12 +509,10 @@ class FileReader(Expression):
|
|
|
493
509
|
assert len(_splits) == 1 or len(_splits) == 2, 'Invalid file link format.'
|
|
494
510
|
_tmp = Path(_tmp)
|
|
495
511
|
# check if file exists and is a file
|
|
496
|
-
|
|
497
|
-
return True
|
|
498
|
-
return False
|
|
512
|
+
return _tmp.is_file()
|
|
499
513
|
|
|
500
514
|
@staticmethod
|
|
501
|
-
def get_files(folder_path: str, max_depth: int = 1) ->
|
|
515
|
+
def get_files(folder_path: str, max_depth: int = 1) -> list[str]:
|
|
502
516
|
accepted_formats = ['.pdf', '.md', '.txt']
|
|
503
517
|
|
|
504
518
|
folder = Path(folder_path)
|
|
@@ -512,7 +526,7 @@ class FileReader(Expression):
|
|
|
512
526
|
return files
|
|
513
527
|
|
|
514
528
|
@staticmethod
|
|
515
|
-
def extract_files(cmds: str) ->
|
|
529
|
+
def extract_files(cmds: str) -> list[str] | None:
|
|
516
530
|
# Use the updated regular expression to match quoted and non-quoted paths
|
|
517
531
|
pattern = r'''(?:"((?:\\.|[^"\\])*)"|'((?:\\.|[^'\\])*)'|`((?:\\.|[^`\\])*)`|((?:\\ |[^ ])+))'''
|
|
518
532
|
# Use the regular expression to split and handle quoted and non-quoted paths
|
|
@@ -551,16 +565,16 @@ class FileReader(Expression):
|
|
|
551
565
|
return Path(path).expanduser().resolve().as_posix()
|
|
552
566
|
|
|
553
567
|
@staticmethod
|
|
554
|
-
def integrity_check(files:
|
|
568
|
+
def integrity_check(files: list[str]) -> list[str]:
|
|
555
569
|
not_skipped = []
|
|
556
570
|
for file in tqdm(files):
|
|
557
571
|
if FileReader.exists(file):
|
|
558
572
|
not_skipped.append(file)
|
|
559
573
|
else:
|
|
560
|
-
|
|
574
|
+
UserMessage(f'Skipping file: {file}')
|
|
561
575
|
return not_skipped
|
|
562
576
|
|
|
563
|
-
def forward(self, files:
|
|
577
|
+
def forward(self, files: str | list[str], **kwargs) -> Expression:
|
|
564
578
|
if isinstance(files, str):
|
|
565
579
|
# Convert to list for uniform processing; more easily downstream
|
|
566
580
|
files = [files]
|
|
@@ -586,15 +600,17 @@ class FileQuery(Expression):
|
|
|
586
600
|
|
|
587
601
|
class Function(TrackerTraceable):
|
|
588
602
|
def __init__(self, prompt: str = '',
|
|
589
|
-
examples:
|
|
590
|
-
pre_processors:
|
|
591
|
-
post_processors:
|
|
592
|
-
default:
|
|
593
|
-
constraints:
|
|
594
|
-
return_type:
|
|
595
|
-
sym_return_type:
|
|
596
|
-
origin_type:
|
|
603
|
+
examples: str | None = [],
|
|
604
|
+
pre_processors: list[PreProcessor] | None = None,
|
|
605
|
+
post_processors: list[PostProcessor] | None = None,
|
|
606
|
+
default: object | None = None,
|
|
607
|
+
constraints: list[Callable] | None = None,
|
|
608
|
+
return_type: type | None = str,
|
|
609
|
+
sym_return_type: type | None = Symbol,
|
|
610
|
+
origin_type: type | None = Expression,
|
|
597
611
|
*args, **kwargs):
|
|
612
|
+
if constraints is None:
|
|
613
|
+
constraints = []
|
|
598
614
|
super().__init__(**kwargs)
|
|
599
615
|
chars = ascii_lowercase + ascii_uppercase
|
|
600
616
|
self.name = 'func_' + ''.join(sample(chars, 15))
|
|
@@ -629,13 +645,16 @@ class Function(TrackerTraceable):
|
|
|
629
645
|
if 'fn' in kwargs:
|
|
630
646
|
self.prompt = kwargs['fn']
|
|
631
647
|
del kwargs['fn']
|
|
632
|
-
@core.few_shot(
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
648
|
+
@core.few_shot(
|
|
649
|
+
*self.args,
|
|
650
|
+
prompt=self.prompt,
|
|
651
|
+
examples=self.examples,
|
|
652
|
+
pre_processors=self.pre_processors,
|
|
653
|
+
post_processors=self.post_processors,
|
|
654
|
+
constraints=self.constraints,
|
|
655
|
+
default=self.default,
|
|
656
|
+
**self.kwargs
|
|
657
|
+
)
|
|
639
658
|
def _func(_, *args, **kwargs) -> self.return_type:
|
|
640
659
|
pass
|
|
641
660
|
_type = type(self.name, (self.origin_type, ), {
|
|
@@ -658,11 +677,11 @@ class PrepareData(Function):
|
|
|
658
677
|
assert argument.prop.context is not None
|
|
659
678
|
instruct = argument.prop.prompt
|
|
660
679
|
context = argument.prop.context
|
|
661
|
-
return """{
|
|
662
|
-
'context': '
|
|
663
|
-
'instruction': '
|
|
680
|
+
return f"""{{
|
|
681
|
+
'context': '{context}',
|
|
682
|
+
'instruction': '{instruct}',
|
|
664
683
|
'result': 'TODO: Replace this with the expected result.'
|
|
665
|
-
}"""
|
|
684
|
+
}}"""
|
|
666
685
|
|
|
667
686
|
def __init__(self, *args, **kwargs):
|
|
668
687
|
super().__init__(*args, **kwargs)
|
|
@@ -707,7 +726,7 @@ class ExpressionBuilder(Function):
|
|
|
707
726
|
super().__init__('Generate the code following the instructions:', **kwargs)
|
|
708
727
|
self.processors = ProcessorPipeline([StripPostProcessor(), CodeExtractPostProcessor()])
|
|
709
728
|
|
|
710
|
-
def forward(self, instruct, *
|
|
729
|
+
def forward(self, instruct, *_args, **_kwargs):
|
|
711
730
|
result = super().forward(instruct)
|
|
712
731
|
return self.processors(str(result), None)
|
|
713
732
|
|
|
@@ -768,21 +787,21 @@ class JsonParser(Expression):
|
|
|
768
787
|
|
|
769
788
|
|
|
770
789
|
class SimilarityClassification(Expression):
|
|
771
|
-
def __init__(self, classes:
|
|
790
|
+
def __init__(self, classes: list[str], metric: str = 'cosine', in_memory: bool = False, **kwargs):
|
|
772
791
|
super().__init__(**kwargs)
|
|
773
792
|
self.classes = classes
|
|
774
793
|
self.metric = metric
|
|
775
794
|
self.in_memory = in_memory
|
|
776
795
|
|
|
777
796
|
if self.in_memory:
|
|
778
|
-
|
|
797
|
+
UserMessage(f'Caching mode is enabled! It is your responsability to empty the .cache folder if you did changes to the classes. The cache is located at {HOME_PATH}/cache')
|
|
779
798
|
|
|
780
799
|
def forward(self, x: Symbol) -> Symbol:
|
|
781
800
|
x = self._to_symbol(x)
|
|
782
801
|
usr_embed = x.embed()
|
|
783
802
|
embeddings = self._dynamic_cache()
|
|
784
803
|
similarities = [usr_embed.similarity(emb, metric=self.metric) for emb in embeddings]
|
|
785
|
-
similarities = sorted(zip(self.classes, similarities), key=lambda x: x[1], reverse=True)
|
|
804
|
+
similarities = sorted(zip(self.classes, similarities, strict=False), key=lambda x: x[1], reverse=True)
|
|
786
805
|
|
|
787
806
|
return Symbol(similarities[0][0])
|
|
788
807
|
|
|
@@ -790,9 +809,7 @@ class SimilarityClassification(Expression):
|
|
|
790
809
|
@core_ext.cache(in_memory=self.in_memory)
|
|
791
810
|
def embed_classes(self):
|
|
792
811
|
opts = map(Symbol, self.classes)
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
return embeddings
|
|
812
|
+
return [opt.embed() for opt in opts]
|
|
796
813
|
|
|
797
814
|
return embed_classes(self)
|
|
798
815
|
|
|
@@ -820,19 +837,14 @@ class Indexer(Expression):
|
|
|
820
837
|
@staticmethod
|
|
821
838
|
def replace_special_chars(index: str):
|
|
822
839
|
# replace special characters that are not for path
|
|
823
|
-
|
|
824
|
-
index = index.replace('-', '')
|
|
825
|
-
index = index.replace('_', '')
|
|
826
|
-
index = index.replace(' ', '')
|
|
827
|
-
index = index.lower()
|
|
828
|
-
return index
|
|
840
|
+
return str(index).replace('-', '').replace('_', '').replace(' ', '').lower()
|
|
829
841
|
|
|
830
842
|
def __init__(
|
|
831
843
|
self,
|
|
832
844
|
index_name: str = DEFAULT,
|
|
833
845
|
top_k: int = 8,
|
|
834
846
|
batch_size: int = 20,
|
|
835
|
-
formatter: Callable =
|
|
847
|
+
formatter: Callable = _DEFAULT_PARAGRAPH_FORMATTER,
|
|
836
848
|
auto_add=False,
|
|
837
849
|
raw_result: bool = False,
|
|
838
850
|
new_dim: int = 1536,
|
|
@@ -861,15 +873,15 @@ class Indexer(Expression):
|
|
|
861
873
|
def register(self):
|
|
862
874
|
# check if index already exists in indices.txt and append if not
|
|
863
875
|
change = False
|
|
864
|
-
with
|
|
876
|
+
with self.path.open() as f:
|
|
865
877
|
indices = f.read().split('\n')
|
|
866
878
|
# filter out empty strings
|
|
867
879
|
indices = [i for i in indices if i]
|
|
868
|
-
|
|
880
|
+
if self.index_name not in indices:
|
|
869
881
|
indices.append(self.index_name)
|
|
870
882
|
change = True
|
|
871
883
|
if change:
|
|
872
|
-
with
|
|
884
|
+
with self.path.open('w') as f:
|
|
873
885
|
f.write('\n'.join(indices))
|
|
874
886
|
|
|
875
887
|
def exists(self) -> bool:
|
|
@@ -877,15 +889,16 @@ class Indexer(Expression):
|
|
|
877
889
|
path = HOME_PATH / 'indices.txt'
|
|
878
890
|
if not path.exists():
|
|
879
891
|
return False
|
|
880
|
-
with open(
|
|
892
|
+
with path.open() as f:
|
|
881
893
|
indices = f.read().split('\n')
|
|
882
894
|
if self.index_name in indices:
|
|
883
895
|
return True
|
|
896
|
+
return False
|
|
884
897
|
|
|
885
898
|
def forward(
|
|
886
899
|
self,
|
|
887
|
-
data:
|
|
888
|
-
|
|
900
|
+
data: Symbol | None = None,
|
|
901
|
+
_raw_result: bool = False,
|
|
889
902
|
) -> Symbol:
|
|
890
903
|
that = self
|
|
891
904
|
if data is not None:
|
|
@@ -898,15 +911,14 @@ class Indexer(Expression):
|
|
|
898
911
|
# we save the index
|
|
899
912
|
that.config(None, save=True, index_name=that.index_name, index_dims=that.new_dim)
|
|
900
913
|
|
|
901
|
-
def _func(query, *
|
|
914
|
+
def _func(query, *_args, **kwargs) -> Union[Symbol, 'VectorDBResult']:
|
|
902
915
|
raw_result = kwargs.get('raw_result') or that.raw_result
|
|
903
916
|
query_emb = Symbol(query).embed(new_dim=that.new_dim).value
|
|
904
917
|
res = that.get(query_emb, index_name=that.index_name, index_top_k=that.top_k, ori_query=query, index_dims=that.new_dim, **kwargs)
|
|
905
918
|
that.retrieval = res
|
|
906
919
|
if raw_result:
|
|
907
920
|
return res
|
|
908
|
-
|
|
909
|
-
return rsp
|
|
921
|
+
return Symbol(res).query(prompt='From the retrieved data, select the most relevant information.', context=query)
|
|
910
922
|
return _func
|
|
911
923
|
|
|
912
924
|
|
|
@@ -917,8 +929,8 @@ class PrimitiveDisabler(Expression):
|
|
|
917
929
|
self._original_primitives = defaultdict(list)
|
|
918
930
|
|
|
919
931
|
def __enter__(self):
|
|
920
|
-
#
|
|
921
|
-
from .symbol import Symbol
|
|
932
|
+
# Import Symbol lazily so components does not clash with symbol during load.
|
|
933
|
+
from .symbol import Symbol # noqa
|
|
922
934
|
|
|
923
935
|
frame = inspect.currentframe()
|
|
924
936
|
f_locals = frame.f_back.f_locals
|
|
@@ -934,7 +946,7 @@ class PrimitiveDisabler(Expression):
|
|
|
934
946
|
for func in self._primitives:
|
|
935
947
|
if hasattr(sym, func):
|
|
936
948
|
self._original_primitives[sym_name].append((func, getattr(sym, func)))
|
|
937
|
-
setattr(sym, func, lambda *
|
|
949
|
+
setattr(sym, func, lambda *_args, **_kwargs: None)
|
|
938
950
|
|
|
939
951
|
def _enable_primitives(self):
|
|
940
952
|
for sym_name, sym in self._symbols.items():
|
|
@@ -968,7 +980,7 @@ class FunctionWithUsage(Function):
|
|
|
968
980
|
|
|
969
981
|
def print_verbose(self, msg):
|
|
970
982
|
if self.verbose:
|
|
971
|
-
|
|
983
|
+
UserMessage(msg)
|
|
972
984
|
|
|
973
985
|
def _format_usage(self, prompt_tokens, completion_tokens, total_tokens):
|
|
974
986
|
return Box(
|
|
@@ -1020,12 +1032,11 @@ class FunctionWithUsage(Function):
|
|
|
1020
1032
|
self.completion_tokens += completion_tokens
|
|
1021
1033
|
self.total_tokens += total_tokens
|
|
1022
1034
|
else:
|
|
1023
|
-
if self.missing_usage_exception and
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
total_tokens = 0
|
|
1035
|
+
if self.missing_usage_exception and "preview" not in kwargs:
|
|
1036
|
+
UserMessage("Missing usage in metadata of neursymbolic engine", raise_with=Exception)
|
|
1037
|
+
prompt_tokens = 0
|
|
1038
|
+
completion_tokens = 0
|
|
1039
|
+
total_tokens = 0
|
|
1029
1040
|
|
|
1030
1041
|
return res, self._format_usage(prompt_tokens, completion_tokens, total_tokens)
|
|
1031
1042
|
|
|
@@ -1041,7 +1052,7 @@ class SelfPrompt(Expression):
|
|
|
1041
1052
|
def __init__(self, *args, **kwargs):
|
|
1042
1053
|
super().__init__(*args, **kwargs)
|
|
1043
1054
|
|
|
1044
|
-
def forward(self, existing_prompt:
|
|
1055
|
+
def forward(self, existing_prompt: dict[str, str], **kwargs) -> dict[str, str]:
|
|
1045
1056
|
"""
|
|
1046
1057
|
Generate new system and user prompts based on the existing prompt.
|
|
1047
1058
|
|
|
@@ -1094,14 +1105,13 @@ class MetadataTracker(Expression):
|
|
|
1094
1105
|
value = value or self.metadata
|
|
1095
1106
|
if isinstance(value, dict):
|
|
1096
1107
|
return '{\n\t' + ', \n\t'.join(f'"{k}": {self.__str__(v)}' for k,v in value.items()) + '\n}'
|
|
1097
|
-
|
|
1108
|
+
if isinstance(value, list):
|
|
1098
1109
|
return '[' + ', '.join(self.__str__(item) for item in value) + ']'
|
|
1099
|
-
|
|
1110
|
+
if isinstance(value, str):
|
|
1100
1111
|
return f'"{value}"'
|
|
1101
|
-
|
|
1102
|
-
return f"\n\t {value}"
|
|
1112
|
+
return f"\n\t {value}"
|
|
1103
1113
|
|
|
1104
|
-
def __new__(cls, *
|
|
1114
|
+
def __new__(cls, *_args, **_kwargs):
|
|
1105
1115
|
cls._lock = getattr(cls, '_lock', Lock())
|
|
1106
1116
|
with cls._lock:
|
|
1107
1117
|
instance = super().__new__(cls)
|
|
@@ -1122,25 +1132,26 @@ class MetadataTracker(Expression):
|
|
|
1122
1132
|
|
|
1123
1133
|
def _trace_calls(self, frame, event, arg):
|
|
1124
1134
|
if not self._trace:
|
|
1125
|
-
return
|
|
1135
|
+
return None
|
|
1126
1136
|
|
|
1127
|
-
if
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
+
if (
|
|
1138
|
+
event == 'return'
|
|
1139
|
+
and frame.f_code.co_name == 'forward'
|
|
1140
|
+
and 'self' in frame.f_locals
|
|
1141
|
+
and isinstance(frame.f_locals['self'], Engine)
|
|
1142
|
+
):
|
|
1143
|
+
_, metadata = arg # arg contains return value on 'return' event
|
|
1144
|
+
engine_name = frame.f_locals['self'].__class__.__name__
|
|
1145
|
+
model_name = frame.f_locals['self'].model
|
|
1146
|
+
self._metadata[(self._metadata_id, engine_name, model_name)] = metadata
|
|
1147
|
+
self._metadata_id += 1
|
|
1137
1148
|
|
|
1138
1149
|
return self._trace_calls
|
|
1139
1150
|
|
|
1140
1151
|
def _accumulate_completion_token_details(self):
|
|
1141
1152
|
"""Parses the return object and accumulates completion token details per token type"""
|
|
1142
1153
|
if not self._metadata:
|
|
1143
|
-
|
|
1154
|
+
UserMessage("No metadata available to generate usage details.")
|
|
1144
1155
|
return {}
|
|
1145
1156
|
|
|
1146
1157
|
token_details = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
|
|
@@ -1182,15 +1193,59 @@ class MetadataTracker(Expression):
|
|
|
1182
1193
|
logger.warning(f"Tracking {engine_name} is not supported.")
|
|
1183
1194
|
continue
|
|
1184
1195
|
except Exception as e:
|
|
1185
|
-
|
|
1196
|
+
UserMessage(f"Failed to parse metadata for {engine_name}: {e}", raise_with=AttributeError)
|
|
1186
1197
|
|
|
1187
1198
|
# Convert to normal dict
|
|
1188
1199
|
return {**token_details}
|
|
1189
1200
|
|
|
1201
|
+
def _can_accumulate_engine(self, engine_name: str) -> bool:
|
|
1202
|
+
supported_engines = ("GPTXChatEngine", "GPTXReasoningEngine", "GPTXSearchEngine")
|
|
1203
|
+
return engine_name in supported_engines
|
|
1204
|
+
|
|
1205
|
+
def _accumulate_time_field(self, accumulated: dict, metadata: dict) -> None:
|
|
1206
|
+
if 'time' in metadata and 'time' in accumulated:
|
|
1207
|
+
accumulated['time'] += metadata['time']
|
|
1208
|
+
|
|
1209
|
+
def _accumulate_usage_fields(self, accumulated: dict, metadata: dict) -> None:
|
|
1210
|
+
if 'raw_output' not in metadata or 'raw_output' not in accumulated:
|
|
1211
|
+
return
|
|
1212
|
+
|
|
1213
|
+
metadata_raw_output = metadata['raw_output']
|
|
1214
|
+
accumulated_raw_output = accumulated['raw_output']
|
|
1215
|
+
if not hasattr(metadata_raw_output, 'usage') or not hasattr(accumulated_raw_output, 'usage'):
|
|
1216
|
+
return
|
|
1217
|
+
|
|
1218
|
+
current_usage = metadata_raw_output.usage
|
|
1219
|
+
accumulated_usage = accumulated_raw_output.usage
|
|
1220
|
+
|
|
1221
|
+
for attr in ['completion_tokens', 'prompt_tokens', 'total_tokens']:
|
|
1222
|
+
if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
|
|
1223
|
+
setattr(
|
|
1224
|
+
accumulated_usage,
|
|
1225
|
+
attr,
|
|
1226
|
+
getattr(accumulated_usage, attr) + getattr(current_usage, attr),
|
|
1227
|
+
)
|
|
1228
|
+
|
|
1229
|
+
for detail_attr in ['completion_tokens_details', 'prompt_tokens_details']:
|
|
1230
|
+
if not hasattr(current_usage, detail_attr) or not hasattr(accumulated_usage, detail_attr):
|
|
1231
|
+
continue
|
|
1232
|
+
|
|
1233
|
+
current_details = getattr(current_usage, detail_attr)
|
|
1234
|
+
accumulated_details = getattr(accumulated_usage, detail_attr)
|
|
1235
|
+
|
|
1236
|
+
for attr in dir(current_details):
|
|
1237
|
+
if attr.startswith('_') or not hasattr(accumulated_details, attr):
|
|
1238
|
+
continue
|
|
1239
|
+
|
|
1240
|
+
current_val = getattr(current_details, attr)
|
|
1241
|
+
accumulated_val = getattr(accumulated_details, attr)
|
|
1242
|
+
if isinstance(current_val, (int, float)) and isinstance(accumulated_val, (int, float)):
|
|
1243
|
+
setattr(accumulated_details, attr, accumulated_val + current_val)
|
|
1244
|
+
|
|
1190
1245
|
def _accumulate_metadata(self):
|
|
1191
1246
|
"""Accumulates metadata across all tracked engine calls."""
|
|
1192
1247
|
if not self._metadata:
|
|
1193
|
-
|
|
1248
|
+
UserMessage("No metadata available to generate usage details.")
|
|
1194
1249
|
return {}
|
|
1195
1250
|
|
|
1196
1251
|
# Use first entry as base
|
|
@@ -1199,39 +1254,12 @@ class MetadataTracker(Expression):
|
|
|
1199
1254
|
|
|
1200
1255
|
# Skipz first entry
|
|
1201
1256
|
for (_, engine_name), metadata in list(self._metadata.items())[1:]:
|
|
1202
|
-
if
|
|
1257
|
+
if not self._can_accumulate_engine(engine_name):
|
|
1203
1258
|
logger.warning(f"Metadata accumulation for {engine_name} is not supported. Try `.usage` instead for now.")
|
|
1204
1259
|
continue
|
|
1205
1260
|
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
accumulated['time'] += metadata['time']
|
|
1209
|
-
|
|
1210
|
-
# Handle usage stats accumulation
|
|
1211
|
-
if 'raw_output' in metadata and 'raw_output' in accumulated:
|
|
1212
|
-
if hasattr(metadata['raw_output'], 'usage') and hasattr(accumulated['raw_output'], 'usage'):
|
|
1213
|
-
current_usage = metadata['raw_output'].usage
|
|
1214
|
-
accumulated_usage = accumulated['raw_output'].usage
|
|
1215
|
-
|
|
1216
|
-
# Accumulate token counts
|
|
1217
|
-
for attr in ['completion_tokens', 'prompt_tokens', 'total_tokens']:
|
|
1218
|
-
if hasattr(current_usage, attr) and hasattr(accumulated_usage, attr):
|
|
1219
|
-
setattr(accumulated_usage, attr,
|
|
1220
|
-
getattr(accumulated_usage, attr) + getattr(current_usage, attr))
|
|
1221
|
-
|
|
1222
|
-
# Handle nested token details if they exist
|
|
1223
|
-
for detail_attr in ['completion_tokens_details', 'prompt_tokens_details']:
|
|
1224
|
-
if hasattr(current_usage, detail_attr) and hasattr(accumulated_usage, detail_attr):
|
|
1225
|
-
current_details = getattr(current_usage, detail_attr)
|
|
1226
|
-
accumulated_details = getattr(accumulated_usage, detail_attr)
|
|
1227
|
-
|
|
1228
|
-
# Accumulate all numeric attributes in the details
|
|
1229
|
-
for attr in dir(current_details):
|
|
1230
|
-
if not attr.startswith('_') and hasattr(accumulated_details, attr):
|
|
1231
|
-
current_val = getattr(current_details, attr)
|
|
1232
|
-
accumulated_val = getattr(accumulated_details, attr)
|
|
1233
|
-
if isinstance(current_val, (int, float)) and isinstance(accumulated_val, (int, float)):
|
|
1234
|
-
setattr(accumulated_details, attr, accumulated_val + current_val)
|
|
1261
|
+
self._accumulate_time_field(accumulated, metadata)
|
|
1262
|
+
self._accumulate_usage_fields(accumulated, metadata)
|
|
1235
1263
|
|
|
1236
1264
|
return accumulated
|
|
1237
1265
|
|
|
@@ -1250,7 +1278,7 @@ class MetadataTracker(Expression):
|
|
|
1250
1278
|
|
|
1251
1279
|
class DynamicEngine(Expression):
|
|
1252
1280
|
"""Context manager for dynamically switching neurosymbolic engine models."""
|
|
1253
|
-
def __init__(self, model: str, api_key: str,
|
|
1281
|
+
def __init__(self, model: str, api_key: str, _debug: bool = False, **_kwargs):
|
|
1254
1282
|
super().__init__()
|
|
1255
1283
|
self.model = model
|
|
1256
1284
|
self.api_key = api_key
|
|
@@ -1259,7 +1287,7 @@ class DynamicEngine(Expression):
|
|
|
1259
1287
|
self.engine_instance = None
|
|
1260
1288
|
self._ctx_token = None
|
|
1261
1289
|
|
|
1262
|
-
def __new__(cls, *
|
|
1290
|
+
def __new__(cls, *_args, **_kwargs):
|
|
1263
1291
|
cls._lock = getattr(cls, '_lock', Lock())
|
|
1264
1292
|
with cls._lock:
|
|
1265
1293
|
instance = super().__new__(cls)
|
|
@@ -1293,11 +1321,12 @@ class DynamicEngine(Expression):
|
|
|
1293
1321
|
|
|
1294
1322
|
def _create_engine_instance(self):
|
|
1295
1323
|
"""Create an engine instance based on the model name."""
|
|
1296
|
-
|
|
1324
|
+
# Deferred to avoid components <-> neurosymbolic engine circular imports.
|
|
1325
|
+
from .backend.engines.neurosymbolic import ENGINE_MAPPING # noqa
|
|
1297
1326
|
try:
|
|
1298
1327
|
engine_class = ENGINE_MAPPING.get(self.model)
|
|
1299
1328
|
if engine_class is None:
|
|
1300
|
-
|
|
1329
|
+
UserMessage(f"Unsupported model '{self.model}'", raise_with=ValueError)
|
|
1301
1330
|
return engine_class(api_key=self.api_key, model=self.model)
|
|
1302
1331
|
except Exception as e:
|
|
1303
|
-
|
|
1332
|
+
UserMessage(f"Failed to create engine for model '{self.model}': {e!s}", raise_with=ValueError)
|