symbolicai 0.20.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- symai/__init__.py +96 -64
- symai/backend/base.py +93 -80
- symai/backend/engines/drawing/engine_bfl.py +12 -11
- symai/backend/engines/drawing/engine_gpt_image.py +108 -87
- symai/backend/engines/embedding/engine_llama_cpp.py +25 -28
- symai/backend/engines/embedding/engine_openai.py +3 -5
- symai/backend/engines/execute/engine_python.py +6 -5
- symai/backend/engines/files/engine_io.py +74 -67
- symai/backend/engines/imagecaptioning/engine_blip2.py +3 -3
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +54 -38
- symai/backend/engines/index/engine_pinecone.py +23 -24
- symai/backend/engines/index/engine_vectordb.py +16 -14
- symai/backend/engines/lean/engine_lean4.py +38 -34
- symai/backend/engines/neurosymbolic/__init__.py +41 -13
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +262 -182
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +263 -191
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +53 -49
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +212 -211
- symai/backend/engines/neurosymbolic/engine_groq.py +87 -63
- symai/backend/engines/neurosymbolic/engine_huggingface.py +21 -24
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +117 -48
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +256 -229
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +270 -150
- symai/backend/engines/ocr/engine_apilayer.py +6 -8
- symai/backend/engines/output/engine_stdout.py +1 -4
- symai/backend/engines/search/engine_openai.py +7 -7
- symai/backend/engines/search/engine_perplexity.py +5 -5
- symai/backend/engines/search/engine_serpapi.py +12 -14
- symai/backend/engines/speech_to_text/engine_local_whisper.py +20 -27
- symai/backend/engines/symbolic/engine_wolframalpha.py +3 -3
- symai/backend/engines/text_to_speech/engine_openai.py +5 -7
- symai/backend/engines/text_vision/engine_clip.py +7 -11
- symai/backend/engines/userinput/engine_console.py +3 -3
- symai/backend/engines/webscraping/engine_requests.py +81 -48
- symai/backend/mixin/__init__.py +13 -0
- symai/backend/mixin/anthropic.py +4 -2
- symai/backend/mixin/deepseek.py +2 -0
- symai/backend/mixin/google.py +2 -0
- symai/backend/mixin/openai.py +11 -3
- symai/backend/settings.py +83 -16
- symai/chat.py +101 -78
- symai/collect/__init__.py +7 -1
- symai/collect/dynamic.py +77 -69
- symai/collect/pipeline.py +35 -27
- symai/collect/stats.py +75 -63
- symai/components.py +198 -169
- symai/constraints.py +15 -12
- symai/core.py +698 -359
- symai/core_ext.py +32 -34
- symai/endpoints/api.py +80 -73
- symai/extended/.DS_Store +0 -0
- symai/extended/__init__.py +46 -12
- symai/extended/api_builder.py +11 -8
- symai/extended/arxiv_pdf_parser.py +13 -12
- symai/extended/bibtex_parser.py +2 -3
- symai/extended/conversation.py +101 -90
- symai/extended/document.py +17 -10
- symai/extended/file_merger.py +18 -13
- symai/extended/graph.py +18 -13
- symai/extended/html_style_template.py +2 -4
- symai/extended/interfaces/blip_2.py +1 -2
- symai/extended/interfaces/clip.py +1 -2
- symai/extended/interfaces/console.py +7 -1
- symai/extended/interfaces/dall_e.py +1 -1
- symai/extended/interfaces/flux.py +1 -1
- symai/extended/interfaces/gpt_image.py +1 -1
- symai/extended/interfaces/input.py +1 -1
- symai/extended/interfaces/llava.py +0 -1
- symai/extended/interfaces/naive_vectordb.py +7 -8
- symai/extended/interfaces/naive_webscraping.py +1 -1
- symai/extended/interfaces/ocr.py +1 -1
- symai/extended/interfaces/pinecone.py +6 -5
- symai/extended/interfaces/serpapi.py +1 -1
- symai/extended/interfaces/terminal.py +2 -3
- symai/extended/interfaces/tts.py +1 -1
- symai/extended/interfaces/whisper.py +1 -1
- symai/extended/interfaces/wolframalpha.py +1 -1
- symai/extended/metrics/__init__.py +11 -1
- symai/extended/metrics/similarity.py +11 -13
- symai/extended/os_command.py +17 -16
- symai/extended/packages/__init__.py +29 -3
- symai/extended/packages/symdev.py +19 -16
- symai/extended/packages/sympkg.py +12 -9
- symai/extended/packages/symrun.py +21 -19
- symai/extended/repo_cloner.py +11 -10
- symai/extended/seo_query_optimizer.py +1 -2
- symai/extended/solver.py +20 -23
- symai/extended/summarizer.py +4 -3
- symai/extended/taypan_interpreter.py +10 -12
- symai/extended/vectordb.py +99 -82
- symai/formatter/__init__.py +9 -1
- symai/formatter/formatter.py +12 -16
- symai/formatter/regex.py +62 -63
- symai/functional.py +176 -122
- symai/imports.py +136 -127
- symai/interfaces.py +56 -27
- symai/memory.py +14 -13
- symai/misc/console.py +49 -39
- symai/misc/loader.py +5 -3
- symai/models/__init__.py +17 -1
- symai/models/base.py +269 -181
- symai/models/errors.py +0 -1
- symai/ops/__init__.py +32 -22
- symai/ops/measures.py +11 -15
- symai/ops/primitives.py +348 -228
- symai/post_processors.py +32 -28
- symai/pre_processors.py +39 -41
- symai/processor.py +6 -4
- symai/prompts.py +59 -45
- symai/server/huggingface_server.py +23 -20
- symai/server/llama_cpp_server.py +7 -5
- symai/shell.py +3 -4
- symai/shellsv.py +499 -375
- symai/strategy.py +517 -287
- symai/symbol.py +111 -116
- symai/utils.py +42 -36
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/METADATA +4 -2
- symbolicai-1.0.0.dist-info/RECORD +163 -0
- symbolicai-0.20.2.dist-info/RECORD +0 -162
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/WHEEL +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-0.20.2.dist-info → symbolicai-1.0.0.dist-info}/top_level.txt +0 -0
symai/collect/pipeline.py
CHANGED
(Removed lines whose content the diff viewer truncated or omitted are marked with `…`.)

@@ -1,14 +1,21 @@
+from __future__ import annotations
+
 import json
 import logging
 from datetime import datetime
-from typing import …
+from typing import TYPE_CHECKING, Any
 
 from bson.objectid import ObjectId
-from pymongo.collection import Collection
-from pymongo.database import Database
 from pymongo.mongo_client import MongoClient
 
 from ..backend.settings import SYMAI_CONFIG
+from ..utils import UserMessage
+
+if TYPE_CHECKING:
+    from pymongo.collection import Collection
+    from pymongo.database import Database
+else:
+    Collection = Database = Any
 
 logger = logging.getLogger(__name__)
 
@@ -24,18 +31,17 @@ def rec_serialize(obj):
     if isinstance(obj, (int, float, bool)):
         # For simple types, return the string representation directly.
         return obj
-    …
+    if isinstance(obj, dict):
         # For dictionaries, serialize each value. Keep keys as strings.
         return {str(key): rec_serialize(value) for key, value in obj.items()}
-    …
+    if isinstance(obj, (list, tuple, set)):
         # For lists, tuples, and sets, serialize each element.
         return [rec_serialize(elem) for elem in obj]
-    …
-    …
-    …
-    …
-    …
-    return str(obj)
+    # Attempt JSON serialization first, then fall back to str(...)
+    try:
+        return json.dumps(obj)
+    except TypeError:
+        return str(obj)
 
 
 class CollectionRepository:
@@ -44,15 +50,15 @@ class CollectionRepository:
         self.uri: str = SYMAI_CONFIG["COLLECTION_URI"]
         self.db_name: str = SYMAI_CONFIG["COLLECTION_DB"]
         self.collection_name: str = SYMAI_CONFIG["COLLECTION_STORAGE"]
-        self.client: …
-        self.db: …
-        self.collection: …
+        self.client: MongoClient | None = None
+        self.db: Database | None = None
+        self.collection: Collection | None = None
 
-    def __enter__(self) -> …
+    def __enter__(self) -> CollectionRepository:
         self.connect()
         return self
 
-    def __exit__(self, exc_type: …
+    def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any | None) -> None:
         self.close()
 
     def ping(self) -> bool:
@@ -63,10 +69,12 @@ class CollectionRepository:
             self.client.admin.command('ping')
             return True
         except Exception as e:
-            …
+            UserMessage(f"Connection failed: {e}")
             return False
 
-    def add(self, forward: Any, engine: Any, metadata: …
+    def add(self, forward: Any, engine: Any, metadata: dict[str, Any] | None = None) -> Any:
+        if metadata is None:
+            metadata = {}
         if not self.support_community:
             return None
         record = {
@@ -78,22 +86,22 @@ class CollectionRepository:
         }
         try: # assure that adding a record does never cause a system error
             return self.collection.insert_one(record).inserted_id if self.collection else None
-        except Exception…
+        except Exception:
             return None
 
-    def get(self, record_id: str) -> …
+    def get(self, record_id: str) -> dict[str, Any] | None:
         if not self.support_community:
             return None
         return self.collection.find_one({'_id': ObjectId(record_id)}) if self.collection else None
 
     def update(self,
                record_id: str,
-               forward: …
-               engine: …
-               metadata: …
+               forward: Any | None = None,
+               engine: str | None = None,
+               metadata: dict[str, Any] | None = None) -> Any:
         if not self.support_community:
             return None
-        updates: …
+        updates: dict[str, Any] = {'updated_at': datetime.now()}
         if forward is not None:
             updates['forward'] = forward
         if engine is not None:
@@ -108,14 +116,14 @@ class CollectionRepository:
             return None
         return self.collection.delete_one({'_id': ObjectId(record_id)}) if self.collection else None
 
-    def list(self, filters: …
+    def list(self, filters: dict[str, Any] | None = None, limit: int = 0) -> list[dict[str, Any]]:
         if not self.support_community:
             return None
         if filters is None:
             filters = {}
         return list(self.collection.find(filters).limit(limit)) if self.collection else []
 
-    def count(self, filters: …
+    def count(self, filters: dict[str, Any] | None = None) -> int:
         if not self.support_community:
             return None
         if filters is None:
@@ -133,7 +141,7 @@ class CollectionRepository:
             self.client = False
             self.db = None
             self.collection = None
-            …
+            UserMessage(f"[WARN] MongoClient: Connection failed: {e}")
 
     def close(self) -> None:
         if self.client is not None:
symai/collect/stats.py
CHANGED
@@ -1,13 +1,16 @@
 import json
 import re
+from collections.abc import Callable
 from json import JSONEncoder
-from …
+from pathlib import Path
+from typing import Any, Union
 
 import numpy as np
 import torch
 
 from ..ops.primitives import OperatorPrimitives
 from ..symbol import Symbol
+from ..utils import UserMessage
 
 SPECIAL_CONSTANT = '__aggregate_'
 EXCLUDE_LIST = ['_ipython_canary_method_should_not_exist_', '__custom_documentations__']
@@ -24,7 +27,7 @@ class AggregatorJSONEncoder(JSONEncoder):
         if isinstance(obj, np.ndarray):
             return obj.tolist()
         # drop active from state
-        …
+        if isinstance(obj, Aggregator):
             state = obj.__dict__.copy()
             state.pop('_raise_error', None)
             state.pop('_active', None)
@@ -32,8 +35,8 @@ class AggregatorJSONEncoder(JSONEncoder):
             state.pop('_map', None)
             # drop everything that starts with SPECIAL_CONSTANT
             for key in list(state.keys()):
-                if not key.startswith(SPECIAL_CONSTANT) and key != '_value' or \
-                   key == '_value' and obj._value == [] or \
+                if (not key.startswith(SPECIAL_CONSTANT) and key != '_value') or \
+                   (key == '_value' and obj._value == []) or \
                    key.replace(SPECIAL_CONSTANT, '') in EXCLUDE_LIST:
                     state.pop(key, None)
             return state
@@ -42,8 +45,8 @@ class AggregatorJSONEncoder(JSONEncoder):
 
 class Aggregator(Symbol):
     def __init__(self,
-                 value: …
-                 path: …
+                 value: Union["Aggregator", Symbol] | None = None,
+                 path: str | None = None,
                  active: bool = True,
                  raise_error: bool = False,
                  *args, **kwargs):
@@ -61,7 +64,7 @@ class Aggregator(Symbol):
         elif not isinstance(self._value, (list, tuple)):
             self._value = [self._value]
         elif value is not None:
-            …
+            UserMessage(f'Aggregator object must be of type Aggregator or Symbol! Got: {type(value)}', raise_with=Exception)
         else:
             self._value = []
         self._raise_error = raise_error
@@ -71,11 +74,13 @@ class Aggregator(Symbol):
         self._path = path
 
     def __new__(cls, *args,
-                mixin: …
-                primitives: …
-                callables: …
+                mixin: bool | None = None,
+                primitives: list[type] | None = None, # only inherit arithmetic primitives
+                callables: list[tuple[str, Callable]] | None = None,
                 semantic: bool = False,
                 **kwargs) -> "Symbol":
+        if primitives is None:
+            primitives = [OperatorPrimitives]
         return super().__new__(cls, *args,
                                mixin=mixin,
                                primitives=primitives,
@@ -97,8 +102,8 @@ class Aggregator(Symbol):
             # add also a property with the same name but without the SPECIAL_CONSTANT prefix as a shortcut
             self.__dict__[name] = self.__dict__[f'{SPECIAL_CONSTANT}{name}']
             return self.__dict__.get(name)
-        …
-        …
+        if not self._active and name not in self.__dict__:
+            UserMessage(f'Aggregator object is frozen! No attribute {name} found!', raise_with=Exception)
         return self.__dict__.get(name)
 
     def __setattr__(self, name, value):
@@ -139,19 +144,23 @@ class Aggregator(Symbol):
     def _set_values(obj, dictionary, parent, strict: bool = True):
         # recursively reconstruct the object
         for key, value in dictionary.items():
-            …
+            attr_key = key
+            attr_value = value
+            if isinstance(attr_value, dict):
                 if parent is not None:
-                    obj._path = …
-                …
-            if …
-                …
-            if …
+                    obj._path = attr_key
+                attr_value = Aggregator._reconstruct(attr_value, parent=parent, strict=strict)
+            if attr_key.startswith(SPECIAL_CONSTANT):
+                attr_key = attr_key.replace(SPECIAL_CONSTANT, '')
+            if attr_key == '_value':
                 try:
-                    …
+                    attr_value = np.asarray(attr_value, dtype=np.float32)
                 except Exception as e:
                     if strict:
-                        …
-                        …
+                        msg = f'Could not set value of Aggregator object: {obj.path}! ERROR: {e}'
+                        UserMessage(msg)
+                        raise Exception(msg) from e
+            obj.__setattr__(attr_key, attr_value)
 
     @staticmethod
     def _reconstruct(dictionary, parent = None, strict: bool = True):
@@ -186,8 +195,7 @@ class Aggregator(Symbol):
         if obj._path is not None:
             path = obj._path.replace(SPECIAL_CONSTANT, '') + '.' + path
             obj = obj._parent
-        …
-        return path
+        return path[:-1] # remove last dot
 
     def __or__(self, other: Any) -> Any:
         self.add(other)
@@ -211,8 +219,7 @@ class Aggregator(Symbol):
     @property
     def value(self):
         if self.map is not None:
-            …
-            return res
+            return np.asarray(self.map(np.asarray(self._value, dtype=np.float32)))
         return np.asarray(self._value, dtype=np.float32)
 
     @property
@@ -232,19 +239,18 @@ class Aggregator(Symbol):
     def shape(self):
         if len(self.entries) > 0:
             return np.asarray(self.entries).shape
-        …
-            return ()
+        return ()
 
     def serialize(self):
         return json.dumps(self, cls=AggregatorJSONEncoder)
 
     def save(self, path: str):
-        with open(…
+        with Path(path).open('w') as f:
             json.dump(self, f, cls=AggregatorJSONEncoder)
 
     @staticmethod
     def load(path: str, strict: bool = True):
-        with …
+        with Path(path).open() as f:
             json_ = json.load(f)
         return Aggregator._reconstruct(json_, strict=strict)
 
@@ -253,44 +259,50 @@ class Aggregator(Symbol):
 
     def add(self, entries):
         # Add entries to the aggregator
-        if not self.active:
-        …
-            raise Exception('Aggregator object is frozen!')
+        if not self.active and self._finalized:
+            UserMessage('Aggregator object is frozen!', raise_with=Exception)
             return
         try:
-            …
-            …
-            if type(entries) == torch.Tensor:
-                entries = entries.detach().cpu().numpy().astype(np.float32)
-            elif type(entries) in [tuple, list]:
-                entries = np.asarray(entries, dtype=np.float32)
-            elif type(entries) in [int, float]:
-                entries = entries
-            elif type(entries) == bool:
-                entries = int(entries)
-            elif type(entries) == str:
-                entries = Symbol(entries).embedding.astype(np.float32)
-            elif isinstance(entries, Symbol):
-                # use this to avoid recursion on map setter
-                self.add(entries._value)
-                return
-            elif isinstance(entries, Aggregator):
-                self.add(entries.get())
+            processed_entries = self._prepare_entries(entries)
+            if processed_entries is None:
                 return
-            …
-            …
-                entries = entries.squeeze()
-            …
-            self.entries.append(entries)
+            processed_entries = self._squeeze_entries(processed_entries)
+            self.entries.append(processed_entries)
         except Exception as e:
+            msg = f'Could not add entries to Aggregator object! Please verify type or original error: {e}'
             if self._raise_error:
-                …
-                …
-                …
+                UserMessage(msg)
+                raise Exception(msg) from e
+            UserMessage(msg)
+
+    def _prepare_entries(self, entries):
+        valid_types = (tuple, list, np.float32, np.float64, np.ndarray, torch.Tensor, int, float, bool, str, Symbol)
+        assert isinstance(entries, valid_types), f'Entries must be a tuple, list, numpy array, torch tensor, integer, float, boolean, string, or Symbol! Got: {type(entries)}'
+        if isinstance(entries, torch.Tensor):
+            return entries.detach().cpu().numpy().astype(np.float32)
+        if isinstance(entries, (tuple, list)):
+            return np.asarray(entries, dtype=np.float32)
+        if isinstance(entries, bool):
+            return int(entries)
+        if isinstance(entries, str):
+            return Symbol(entries).embedding.astype(np.float32)
+        if isinstance(entries, Symbol):
+            # Use this to avoid recursion on map setter
+            self.add(entries._value)
+            return None
+        if isinstance(entries, Aggregator):
+            self.add(entries.get())
+            return None
+        return entries
+
+    def _squeeze_entries(self, entries):
+        if isinstance(entries, (np.ndarray, np.float32)):
+            return entries.squeeze()
+        return entries
 
     def keys(self):
         # Get all key names of items that have the SPECIAL_CONSTANT prefix
-        return [key.replace(SPECIAL_CONSTANT, '') for key in self.__dict__…
+        return [key.replace(SPECIAL_CONSTANT, '') for key in self.__dict__ if not key.startswith('_') and \
                 key.replace(SPECIAL_CONSTANT, '') not in EXCLUDE_LIST]
 
     @property
@@ -301,7 +313,7 @@ class Aggregator(Symbol):
     @active.setter
     def active(self, value):
         # Set the active status of the aggregator
-        assert isinstance(value, bool), 'Active status must be a boolean! Got: {…
+        assert isinstance(value, bool), f'Active status must be a boolean! Got: {type(value)}'
         self._active = value
 
     @property
@@ -312,7 +324,7 @@ class Aggregator(Symbol):
     @finalized.setter
     def finalized(self, value):
         # Set the finalized status of the aggregator
-        assert isinstance(value, bool), 'Finalized status must be a boolean! Got: {…
+        assert isinstance(value, bool), f'Finalized status must be a boolean! Got: {type(value)}'
         self._finalized = value
 
     def finalize(self):
@@ -323,7 +335,7 @@ class Aggregator(Symbol):
             if name == 'map':
                 self.__setattr__(name, value)
             else:
-                …
+                UserMessage('Aggregator object is frozen!', raise_with=Exception)
         self.__setattr__ = raise_exception
         def get_attribute(*args, **kwargs):
             return self.__dict__.get(*args, **kwargs)
@@ -342,7 +354,7 @@ class Aggregator(Symbol):
     def clear(self):
         # Clear the entries of the aggregator
        if self._finalized:
-            …
+            UserMessage('Aggregator object is frozen!', raise_with=Exception)
         self._value = []
 
     def sum(self, axis=0):