symbolicai 0.21.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symai/__init__.py +269 -173
- symai/backend/base.py +123 -110
- symai/backend/engines/drawing/engine_bfl.py +45 -44
- symai/backend/engines/drawing/engine_gpt_image.py +112 -97
- symai/backend/engines/embedding/engine_llama_cpp.py +63 -52
- symai/backend/engines/embedding/engine_openai.py +25 -21
- symai/backend/engines/execute/engine_python.py +19 -18
- symai/backend/engines/files/engine_io.py +104 -95
- symai/backend/engines/imagecaptioning/engine_blip2.py +28 -24
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +102 -79
- symai/backend/engines/index/engine_pinecone.py +124 -97
- symai/backend/engines/index/engine_qdrant.py +1011 -0
- symai/backend/engines/index/engine_vectordb.py +84 -56
- symai/backend/engines/lean/engine_lean4.py +96 -52
- symai/backend/engines/neurosymbolic/__init__.py +41 -13
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +330 -248
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +329 -264
- symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +118 -88
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +344 -299
- symai/backend/engines/neurosymbolic/engine_groq.py +173 -115
- symai/backend/engines/neurosymbolic/engine_huggingface.py +114 -84
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +144 -118
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +415 -307
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +394 -231
- symai/backend/engines/ocr/engine_apilayer.py +23 -27
- symai/backend/engines/output/engine_stdout.py +10 -13
- symai/backend/engines/{webscraping → scrape}/engine_requests.py +101 -54
- symai/backend/engines/search/engine_openai.py +100 -88
- symai/backend/engines/search/engine_parallel.py +665 -0
- symai/backend/engines/search/engine_perplexity.py +44 -45
- symai/backend/engines/search/engine_serpapi.py +37 -34
- symai/backend/engines/speech_to_text/engine_local_whisper.py +54 -51
- symai/backend/engines/symbolic/engine_wolframalpha.py +15 -9
- symai/backend/engines/text_to_speech/engine_openai.py +20 -26
- symai/backend/engines/text_vision/engine_clip.py +39 -37
- symai/backend/engines/userinput/engine_console.py +5 -6
- symai/backend/mixin/__init__.py +13 -0
- symai/backend/mixin/anthropic.py +48 -38
- symai/backend/mixin/deepseek.py +6 -5
- symai/backend/mixin/google.py +7 -4
- symai/backend/mixin/groq.py +2 -4
- symai/backend/mixin/openai.py +140 -110
- symai/backend/settings.py +87 -20
- symai/chat.py +216 -123
- symai/collect/__init__.py +7 -1
- symai/collect/dynamic.py +80 -70
- symai/collect/pipeline.py +67 -51
- symai/collect/stats.py +161 -109
- symai/components.py +707 -360
- symai/constraints.py +24 -12
- symai/core.py +1857 -1233
- symai/core_ext.py +83 -80
- symai/endpoints/api.py +166 -104
- symai/extended/.DS_Store +0 -0
- symai/extended/__init__.py +46 -12
- symai/extended/api_builder.py +29 -21
- symai/extended/arxiv_pdf_parser.py +23 -14
- symai/extended/bibtex_parser.py +9 -6
- symai/extended/conversation.py +156 -126
- symai/extended/document.py +50 -30
- symai/extended/file_merger.py +57 -14
- symai/extended/graph.py +51 -32
- symai/extended/html_style_template.py +18 -14
- symai/extended/interfaces/blip_2.py +2 -3
- symai/extended/interfaces/clip.py +4 -3
- symai/extended/interfaces/console.py +9 -1
- symai/extended/interfaces/dall_e.py +4 -2
- symai/extended/interfaces/file.py +2 -0
- symai/extended/interfaces/flux.py +4 -2
- symai/extended/interfaces/gpt_image.py +16 -7
- symai/extended/interfaces/input.py +2 -1
- symai/extended/interfaces/llava.py +1 -2
- symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +4 -3
- symai/extended/interfaces/naive_vectordb.py +9 -10
- symai/extended/interfaces/ocr.py +5 -3
- symai/extended/interfaces/openai_search.py +2 -0
- symai/extended/interfaces/parallel.py +30 -0
- symai/extended/interfaces/perplexity.py +2 -0
- symai/extended/interfaces/pinecone.py +12 -9
- symai/extended/interfaces/python.py +2 -0
- symai/extended/interfaces/serpapi.py +3 -1
- symai/extended/interfaces/terminal.py +2 -4
- symai/extended/interfaces/tts.py +3 -2
- symai/extended/interfaces/whisper.py +3 -2
- symai/extended/interfaces/wolframalpha.py +2 -1
- symai/extended/metrics/__init__.py +11 -1
- symai/extended/metrics/similarity.py +14 -13
- symai/extended/os_command.py +39 -29
- symai/extended/packages/__init__.py +29 -3
- symai/extended/packages/symdev.py +51 -43
- symai/extended/packages/sympkg.py +41 -35
- symai/extended/packages/symrun.py +63 -50
- symai/extended/repo_cloner.py +14 -12
- symai/extended/seo_query_optimizer.py +15 -13
- symai/extended/solver.py +116 -91
- symai/extended/summarizer.py +12 -10
- symai/extended/taypan_interpreter.py +17 -18
- symai/extended/vectordb.py +122 -92
- symai/formatter/__init__.py +9 -1
- symai/formatter/formatter.py +51 -47
- symai/formatter/regex.py +70 -69
- symai/functional.py +325 -176
- symai/imports.py +190 -147
- symai/interfaces.py +57 -28
- symai/memory.py +45 -35
- symai/menu/screen.py +28 -19
- symai/misc/console.py +66 -56
- symai/misc/loader.py +8 -5
- symai/models/__init__.py +17 -1
- symai/models/base.py +395 -236
- symai/models/errors.py +1 -2
- symai/ops/__init__.py +32 -22
- symai/ops/measures.py +24 -25
- symai/ops/primitives.py +1149 -731
- symai/post_processors.py +58 -50
- symai/pre_processors.py +86 -82
- symai/processor.py +21 -13
- symai/prompts.py +764 -685
- symai/server/huggingface_server.py +135 -49
- symai/server/llama_cpp_server.py +21 -11
- symai/server/qdrant_server.py +206 -0
- symai/shell.py +100 -42
- symai/shellsv.py +700 -492
- symai/strategy.py +630 -346
- symai/symbol.py +368 -322
- symai/utils.py +100 -78
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/METADATA +22 -10
- symbolicai-1.1.0.dist-info/RECORD +168 -0
- symbolicai-0.21.0.dist-info/RECORD +0 -162
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/WHEEL +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-0.21.0.dist-info → symbolicai-1.1.0.dist-info}/top_level.txt +0 -0
symai/collect/stats.py
CHANGED
|
@@ -1,21 +1,24 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import re
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from json import JSONEncoder
|
|
4
|
-
from
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Union
|
|
5
7
|
|
|
6
8
|
import numpy as np
|
|
7
9
|
import torch
|
|
8
10
|
|
|
9
11
|
from ..ops.primitives import OperatorPrimitives
|
|
10
12
|
from ..symbol import Symbol
|
|
13
|
+
from ..utils import UserMessage
|
|
11
14
|
|
|
12
|
-
SPECIAL_CONSTANT =
|
|
13
|
-
EXCLUDE_LIST
|
|
15
|
+
SPECIAL_CONSTANT = "__aggregate_"
|
|
16
|
+
EXCLUDE_LIST = ["_ipython_canary_method_should_not_exist_", "__custom_documentations__"]
|
|
14
17
|
|
|
15
18
|
|
|
16
19
|
def _normalize_name(name: str) -> str:
|
|
17
20
|
# Replace any character that is not a letter or a number with an underscore
|
|
18
|
-
normalized_name = re.sub(r
|
|
21
|
+
normalized_name = re.sub(r"[^a-zA-Z0-9]", "_", name)
|
|
19
22
|
return normalized_name.lower()
|
|
20
23
|
|
|
21
24
|
|
|
@@ -24,31 +27,35 @@ class AggregatorJSONEncoder(JSONEncoder):
|
|
|
24
27
|
if isinstance(obj, np.ndarray):
|
|
25
28
|
return obj.tolist()
|
|
26
29
|
# drop active from state
|
|
27
|
-
|
|
30
|
+
if isinstance(obj, Aggregator):
|
|
28
31
|
state = obj.__dict__.copy()
|
|
29
|
-
state.pop(
|
|
30
|
-
state.pop(
|
|
31
|
-
state.pop(
|
|
32
|
-
state.pop(
|
|
32
|
+
state.pop("_raise_error", None)
|
|
33
|
+
state.pop("_active", None)
|
|
34
|
+
state.pop("_finalized", None)
|
|
35
|
+
state.pop("_map", None)
|
|
33
36
|
# drop everything that starts with SPECIAL_CONSTANT
|
|
34
37
|
for key in list(state.keys()):
|
|
35
|
-
if
|
|
36
|
-
key
|
|
37
|
-
key.
|
|
38
|
+
if (
|
|
39
|
+
(not key.startswith(SPECIAL_CONSTANT) and key != "_value")
|
|
40
|
+
or (key == "_value" and obj._value == [])
|
|
41
|
+
or key.replace(SPECIAL_CONSTANT, "") in EXCLUDE_LIST
|
|
42
|
+
):
|
|
38
43
|
state.pop(key, None)
|
|
39
44
|
return state
|
|
40
45
|
return obj.__dict__
|
|
41
46
|
|
|
42
47
|
|
|
43
48
|
class Aggregator(Symbol):
|
|
44
|
-
def __init__(
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
value: Union["Aggregator", Symbol] | None = None,
|
|
52
|
+
path: str | None = None,
|
|
53
|
+
active: bool = True,
|
|
54
|
+
raise_error: bool = False,
|
|
55
|
+
*args,
|
|
56
|
+
**kwargs,
|
|
57
|
+
):
|
|
58
|
+
super().__init__(*args, **kwargs)
|
|
52
59
|
# disable nesy engine to avoid side effects
|
|
53
60
|
self.__disable_nesy_engine__ = True
|
|
54
61
|
if value is not None and isinstance(value, Symbol):
|
|
@@ -61,27 +68,38 @@ class Aggregator(Symbol):
|
|
|
61
68
|
elif not isinstance(self._value, (list, tuple)):
|
|
62
69
|
self._value = [self._value]
|
|
63
70
|
elif value is not None:
|
|
64
|
-
|
|
71
|
+
UserMessage(
|
|
72
|
+
f"Aggregator object must be of type Aggregator or Symbol! Got: {type(value)}",
|
|
73
|
+
raise_with=Exception,
|
|
74
|
+
)
|
|
65
75
|
else:
|
|
66
76
|
self._value = []
|
|
67
|
-
self._raise_error
|
|
68
|
-
self._active
|
|
77
|
+
self._raise_error = raise_error
|
|
78
|
+
self._active = active
|
|
69
79
|
self._finalized = False
|
|
70
|
-
self._map
|
|
71
|
-
self._path
|
|
72
|
-
|
|
73
|
-
def __new__(
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
80
|
+
self._map = None
|
|
81
|
+
self._path = path
|
|
82
|
+
|
|
83
|
+
def __new__(
|
|
84
|
+
cls,
|
|
85
|
+
*args,
|
|
86
|
+
mixin: bool | None = None,
|
|
87
|
+
primitives: list[type] | None = None, # only inherit arithmetic primitives
|
|
88
|
+
callables: list[tuple[str, Callable]] | None = None,
|
|
89
|
+
semantic: bool = False,
|
|
90
|
+
**kwargs,
|
|
91
|
+
) -> "Symbol":
|
|
92
|
+
if primitives is None:
|
|
93
|
+
primitives = [OperatorPrimitives]
|
|
94
|
+
return super().__new__(
|
|
95
|
+
cls,
|
|
96
|
+
*args,
|
|
97
|
+
mixin=mixin,
|
|
98
|
+
primitives=primitives,
|
|
99
|
+
callables=callables,
|
|
100
|
+
semantic=semantic,
|
|
101
|
+
**kwargs,
|
|
102
|
+
)
|
|
85
103
|
|
|
86
104
|
def __getattr__(self, name):
|
|
87
105
|
# replace name special characters and spaces with underscores
|
|
@@ -93,12 +111,14 @@ class Aggregator(Symbol):
|
|
|
93
111
|
self._children.append(aggregator)
|
|
94
112
|
# create a new aggregate aggregator
|
|
95
113
|
# named {SPECIAL_CONSTANT}{name} for automatic aggregation
|
|
96
|
-
self.__dict__[f
|
|
114
|
+
self.__dict__[f"{SPECIAL_CONSTANT}{name}"] = aggregator
|
|
97
115
|
# add also a property with the same name but without the SPECIAL_CONSTANT prefix as a shortcut
|
|
98
|
-
self.__dict__[name] = self.__dict__[f
|
|
116
|
+
self.__dict__[name] = self.__dict__[f"{SPECIAL_CONSTANT}{name}"]
|
|
99
117
|
return self.__dict__.get(name)
|
|
100
|
-
|
|
101
|
-
|
|
118
|
+
if not self._active and name not in self.__dict__:
|
|
119
|
+
UserMessage(
|
|
120
|
+
f"Aggregator object is frozen! No attribute {name} found!", raise_with=Exception
|
|
121
|
+
)
|
|
102
122
|
return self.__dict__.get(name)
|
|
103
123
|
|
|
104
124
|
def __setattr__(self, name, value):
|
|
@@ -129,32 +149,36 @@ class Aggregator(Symbol):
|
|
|
129
149
|
def __setstate__(self, state):
|
|
130
150
|
# replace name special characters and spaces with underscores
|
|
131
151
|
# drop active from state
|
|
132
|
-
state.pop(
|
|
133
|
-
state.pop(
|
|
134
|
-
state.pop(
|
|
135
|
-
state.pop(
|
|
152
|
+
state.pop("_raise_error", None)
|
|
153
|
+
state.pop("_active", None)
|
|
154
|
+
state.pop("_finalized", None)
|
|
155
|
+
state.pop("_map", None)
|
|
136
156
|
return super().__setstate__(state)
|
|
137
157
|
|
|
138
158
|
@staticmethod
|
|
139
159
|
def _set_values(obj, dictionary, parent, strict: bool = True):
|
|
140
160
|
# recursively reconstruct the object
|
|
141
161
|
for key, value in dictionary.items():
|
|
142
|
-
|
|
162
|
+
attr_key = key
|
|
163
|
+
attr_value = value
|
|
164
|
+
if isinstance(attr_value, dict):
|
|
143
165
|
if parent is not None:
|
|
144
|
-
obj._path =
|
|
145
|
-
|
|
146
|
-
if
|
|
147
|
-
|
|
148
|
-
if
|
|
166
|
+
obj._path = attr_key
|
|
167
|
+
attr_value = Aggregator._reconstruct(attr_value, parent=parent, strict=strict)
|
|
168
|
+
if attr_key.startswith(SPECIAL_CONSTANT):
|
|
169
|
+
attr_key = attr_key.replace(SPECIAL_CONSTANT, "")
|
|
170
|
+
if attr_key == "_value":
|
|
149
171
|
try:
|
|
150
|
-
|
|
172
|
+
attr_value = np.asarray(attr_value, dtype=np.float32)
|
|
151
173
|
except Exception as e:
|
|
152
174
|
if strict:
|
|
153
|
-
|
|
154
|
-
|
|
175
|
+
msg = f"Could not set value of Aggregator object: {obj.path}! ERROR: {e}"
|
|
176
|
+
UserMessage(msg)
|
|
177
|
+
raise Exception(msg) from e
|
|
178
|
+
obj.__setattr__(attr_key, attr_value)
|
|
155
179
|
|
|
156
180
|
@staticmethod
|
|
157
|
-
def _reconstruct(dictionary, parent
|
|
181
|
+
def _reconstruct(dictionary, parent=None, strict: bool = True):
|
|
158
182
|
obj = Aggregator()
|
|
159
183
|
obj._parent = parent
|
|
160
184
|
if parent is not None:
|
|
@@ -163,12 +187,12 @@ class Aggregator(Symbol):
|
|
|
163
187
|
return obj
|
|
164
188
|
|
|
165
189
|
def __str__(self) -> str:
|
|
166
|
-
|
|
190
|
+
"""
|
|
167
191
|
Get the string representation of the Symbol object.
|
|
168
192
|
|
|
169
193
|
Returns:
|
|
170
194
|
str: The string representation of the Symbol object.
|
|
171
|
-
|
|
195
|
+
"""
|
|
172
196
|
return str(self.entries)
|
|
173
197
|
|
|
174
198
|
def _to_symbol(self, other) -> Symbol:
|
|
@@ -180,14 +204,13 @@ class Aggregator(Symbol):
|
|
|
180
204
|
|
|
181
205
|
@property
|
|
182
206
|
def path(self) -> str:
|
|
183
|
-
path =
|
|
184
|
-
obj
|
|
207
|
+
path = ""
|
|
208
|
+
obj = self
|
|
185
209
|
while obj is not None:
|
|
186
210
|
if obj._path is not None:
|
|
187
|
-
path = obj._path.replace(SPECIAL_CONSTANT,
|
|
211
|
+
path = obj._path.replace(SPECIAL_CONSTANT, "") + "." + path
|
|
188
212
|
obj = obj._parent
|
|
189
|
-
|
|
190
|
-
return path
|
|
213
|
+
return path[:-1] # remove last dot
|
|
191
214
|
|
|
192
215
|
def __or__(self, other: Any) -> Any:
|
|
193
216
|
self.add(other)
|
|
@@ -211,8 +234,7 @@ class Aggregator(Symbol):
|
|
|
211
234
|
@property
|
|
212
235
|
def value(self):
|
|
213
236
|
if self.map is not None:
|
|
214
|
-
|
|
215
|
-
return res
|
|
237
|
+
return np.asarray(self.map(np.asarray(self._value, dtype=np.float32)))
|
|
216
238
|
return np.asarray(self._value, dtype=np.float32)
|
|
217
239
|
|
|
218
240
|
@property
|
|
@@ -226,25 +248,26 @@ class Aggregator(Symbol):
|
|
|
226
248
|
def _set_map_recursively(self, map):
|
|
227
249
|
self._map = map
|
|
228
250
|
for key, value in self.__dict__.items():
|
|
229
|
-
if isinstance(value, Aggregator) and (
|
|
251
|
+
if isinstance(value, Aggregator) and (
|
|
252
|
+
not key.startswith("_") or key.startswith(SPECIAL_CONSTANT)
|
|
253
|
+
):
|
|
230
254
|
value.map = map
|
|
231
255
|
|
|
232
256
|
def shape(self):
|
|
233
257
|
if len(self.entries) > 0:
|
|
234
258
|
return np.asarray(self.entries).shape
|
|
235
|
-
|
|
236
|
-
return ()
|
|
259
|
+
return ()
|
|
237
260
|
|
|
238
261
|
def serialize(self):
|
|
239
262
|
return json.dumps(self, cls=AggregatorJSONEncoder)
|
|
240
263
|
|
|
241
264
|
def save(self, path: str):
|
|
242
|
-
with open(
|
|
265
|
+
with Path(path).open("w") as f:
|
|
243
266
|
json.dump(self, f, cls=AggregatorJSONEncoder)
|
|
244
267
|
|
|
245
268
|
@staticmethod
|
|
246
269
|
def load(path: str, strict: bool = True):
|
|
247
|
-
with
|
|
270
|
+
with Path(path).open() as f:
|
|
248
271
|
json_ = json.load(f)
|
|
249
272
|
return Aggregator._reconstruct(json_, strict=strict)
|
|
250
273
|
|
|
@@ -253,45 +276,68 @@ class Aggregator(Symbol):
|
|
|
253
276
|
|
|
254
277
|
def add(self, entries):
|
|
255
278
|
# Add entries to the aggregator
|
|
256
|
-
if not self.active:
|
|
257
|
-
|
|
258
|
-
raise Exception('Aggregator object is frozen!')
|
|
279
|
+
if not self.active and self._finalized:
|
|
280
|
+
UserMessage("Aggregator object is frozen!", raise_with=Exception)
|
|
259
281
|
return
|
|
260
282
|
try:
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
if type(entries) == torch.Tensor:
|
|
264
|
-
entries = entries.detach().cpu().numpy().astype(np.float32)
|
|
265
|
-
elif type(entries) in [tuple, list]:
|
|
266
|
-
entries = np.asarray(entries, dtype=np.float32)
|
|
267
|
-
elif type(entries) in [int, float]:
|
|
268
|
-
entries = entries
|
|
269
|
-
elif type(entries) == bool:
|
|
270
|
-
entries = int(entries)
|
|
271
|
-
elif type(entries) == str:
|
|
272
|
-
entries = Symbol(entries).embedding.astype(np.float32)
|
|
273
|
-
elif isinstance(entries, Symbol):
|
|
274
|
-
# use this to avoid recursion on map setter
|
|
275
|
-
self.add(entries._value)
|
|
276
|
-
return
|
|
277
|
-
elif isinstance(entries, Aggregator):
|
|
278
|
-
self.add(entries.get())
|
|
283
|
+
processed_entries = self._prepare_entries(entries)
|
|
284
|
+
if processed_entries is None:
|
|
279
285
|
return
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
entries = entries.squeeze()
|
|
283
|
-
|
|
284
|
-
self.entries.append(entries)
|
|
286
|
+
processed_entries = self._squeeze_entries(processed_entries)
|
|
287
|
+
self.entries.append(processed_entries)
|
|
285
288
|
except Exception as e:
|
|
289
|
+
msg = f"Could not add entries to Aggregator object! Please verify type or original error: {e}"
|
|
286
290
|
if self._raise_error:
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
291
|
+
UserMessage(msg)
|
|
292
|
+
raise Exception(msg) from e
|
|
293
|
+
UserMessage(msg)
|
|
294
|
+
|
|
295
|
+
def _prepare_entries(self, entries):
|
|
296
|
+
valid_types = (
|
|
297
|
+
tuple,
|
|
298
|
+
list,
|
|
299
|
+
np.float32,
|
|
300
|
+
np.float64,
|
|
301
|
+
np.ndarray,
|
|
302
|
+
torch.Tensor,
|
|
303
|
+
int,
|
|
304
|
+
float,
|
|
305
|
+
bool,
|
|
306
|
+
str,
|
|
307
|
+
Symbol,
|
|
308
|
+
)
|
|
309
|
+
assert isinstance(entries, valid_types), (
|
|
310
|
+
f"Entries must be a tuple, list, numpy array, torch tensor, integer, float, boolean, string, or Symbol! Got: {type(entries)}"
|
|
311
|
+
)
|
|
312
|
+
if isinstance(entries, torch.Tensor):
|
|
313
|
+
return entries.detach().cpu().numpy().astype(np.float32)
|
|
314
|
+
if isinstance(entries, (tuple, list)):
|
|
315
|
+
return np.asarray(entries, dtype=np.float32)
|
|
316
|
+
if isinstance(entries, bool):
|
|
317
|
+
return int(entries)
|
|
318
|
+
if isinstance(entries, str):
|
|
319
|
+
return Symbol(entries).embedding.astype(np.float32)
|
|
320
|
+
if isinstance(entries, Symbol):
|
|
321
|
+
# Use this to avoid recursion on map setter
|
|
322
|
+
self.add(entries._value)
|
|
323
|
+
return None
|
|
324
|
+
if isinstance(entries, Aggregator):
|
|
325
|
+
self.add(entries.get())
|
|
326
|
+
return None
|
|
327
|
+
return entries
|
|
328
|
+
|
|
329
|
+
def _squeeze_entries(self, entries):
|
|
330
|
+
if isinstance(entries, (np.ndarray, np.float32)):
|
|
331
|
+
return entries.squeeze()
|
|
332
|
+
return entries
|
|
290
333
|
|
|
291
334
|
def keys(self):
|
|
292
335
|
# Get all key names of items that have the SPECIAL_CONSTANT prefix
|
|
293
|
-
return [
|
|
294
|
-
|
|
336
|
+
return [
|
|
337
|
+
key.replace(SPECIAL_CONSTANT, "")
|
|
338
|
+
for key in self.__dict__
|
|
339
|
+
if not key.startswith("_") and key.replace(SPECIAL_CONSTANT, "") not in EXCLUDE_LIST
|
|
340
|
+
]
|
|
295
341
|
|
|
296
342
|
@property
|
|
297
343
|
def active(self):
|
|
@@ -301,7 +347,7 @@ class Aggregator(Symbol):
|
|
|
301
347
|
@active.setter
|
|
302
348
|
def active(self, value):
|
|
303
349
|
# Set the active status of the aggregator
|
|
304
|
-
assert isinstance(value, bool),
|
|
350
|
+
assert isinstance(value, bool), f"Active status must be a boolean! Got: {type(value)}"
|
|
305
351
|
self._active = value
|
|
306
352
|
|
|
307
353
|
@property
|
|
@@ -312,25 +358,31 @@ class Aggregator(Symbol):
|
|
|
312
358
|
@finalized.setter
|
|
313
359
|
def finalized(self, value):
|
|
314
360
|
# Set the finalized status of the aggregator
|
|
315
|
-
assert isinstance(value, bool),
|
|
361
|
+
assert isinstance(value, bool), f"Finalized status must be a boolean! Got: {type(value)}"
|
|
316
362
|
self._finalized = value
|
|
317
363
|
|
|
318
364
|
def finalize(self):
|
|
319
365
|
# Finalizes the dynamic creation of the aggregators and freezes the object to prevent further changes
|
|
320
|
-
self._active
|
|
321
|
-
self._finalized
|
|
366
|
+
self._active = False
|
|
367
|
+
self._finalized = True
|
|
368
|
+
|
|
322
369
|
def raise_exception(name, value):
|
|
323
|
-
if name ==
|
|
370
|
+
if name == "map":
|
|
324
371
|
self.__setattr__(name, value)
|
|
325
372
|
else:
|
|
326
|
-
|
|
373
|
+
UserMessage("Aggregator object is frozen!", raise_with=Exception)
|
|
374
|
+
|
|
327
375
|
self.__setattr__ = raise_exception
|
|
376
|
+
|
|
328
377
|
def get_attribute(*args, **kwargs):
|
|
329
378
|
return self.__dict__.get(*args, **kwargs)
|
|
379
|
+
|
|
330
380
|
self.__getattr__ = get_attribute
|
|
331
381
|
# Do the same recursively for all properties of type Aggregator
|
|
332
382
|
for key, value in self.__dict__.items():
|
|
333
|
-
if isinstance(value, Aggregator) and (
|
|
383
|
+
if isinstance(value, Aggregator) and (
|
|
384
|
+
not key.startswith("_") or key.startswith(SPECIAL_CONSTANT)
|
|
385
|
+
):
|
|
334
386
|
value.finalize()
|
|
335
387
|
|
|
336
388
|
def get(self, *args, **kwargs):
|
|
@@ -342,7 +394,7 @@ class Aggregator(Symbol):
|
|
|
342
394
|
def clear(self):
|
|
343
395
|
# Clear the entries of the aggregator
|
|
344
396
|
if self._finalized:
|
|
345
|
-
|
|
397
|
+
UserMessage("Aggregator object is frozen!", raise_with=Exception)
|
|
346
398
|
self._value = []
|
|
347
399
|
|
|
348
400
|
def sum(self, axis=0):
|