symbolicai 1.5.0-py3-none-any.whl → 1.6.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- symai/__init__.py +21 -71
- symai/backend/base.py +0 -26
- symai/backend/engines/drawing/engine_gemini_image.py +101 -0
- symai/backend/engines/embedding/engine_openai.py +11 -8
- symai/backend/engines/neurosymbolic/__init__.py +8 -0
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +14 -1
- symai/backend/engines/neurosymbolic/engine_openrouter.py +294 -0
- symai/backend/mixin/__init__.py +4 -0
- symai/backend/mixin/openrouter.py +2 -0
- symai/components.py +203 -13
- symai/extended/interfaces/nanobanana.py +23 -0
- symai/interfaces.py +2 -0
- symai/ops/primitives.py +0 -18
- symai/shellsv.py +2 -7
- {symbolicai-1.5.0.dist-info → symbolicai-1.6.0.dist-info}/METADATA +2 -9
- {symbolicai-1.5.0.dist-info → symbolicai-1.6.0.dist-info}/RECORD +20 -43
- {symbolicai-1.5.0.dist-info → symbolicai-1.6.0.dist-info}/WHEEL +1 -1
- symai/backend/driver/webclient.py +0 -217
- symai/backend/engines/crawler/engine_selenium.py +0 -94
- symai/backend/engines/drawing/engine_dall_e.py +0 -131
- symai/backend/engines/embedding/engine_plugin_embeddings.py +0 -12
- symai/backend/engines/experiments/engine_bard_wrapper.py +0 -131
- symai/backend/engines/experiments/engine_gptfinetuner.py +0 -32
- symai/backend/engines/experiments/engine_llamacpp_completion.py +0 -142
- symai/backend/engines/neurosymbolic/engine_openai_gptX_completion.py +0 -277
- symai/collect/__init__.py +0 -8
- symai/collect/dynamic.py +0 -117
- symai/collect/pipeline.py +0 -156
- symai/collect/stats.py +0 -434
- symai/extended/crawler.py +0 -21
- symai/extended/interfaces/selenium.py +0 -18
- symai/extended/interfaces/vectordb.py +0 -21
- symai/extended/personas/__init__.py +0 -3
- symai/extended/personas/builder.py +0 -105
- symai/extended/personas/dialogue.py +0 -126
- symai/extended/personas/persona.py +0 -154
- symai/extended/personas/research/__init__.py +0 -1
- symai/extended/personas/research/yann_lecun.py +0 -62
- symai/extended/personas/sales/__init__.py +0 -1
- symai/extended/personas/sales/erik_james.py +0 -62
- symai/extended/personas/student/__init__.py +0 -1
- symai/extended/personas/student/max_tenner.py +0 -51
- symai/extended/strategies/__init__.py +0 -1
- symai/extended/strategies/cot.py +0 -40
- {symbolicai-1.5.0.dist-info → symbolicai-1.6.0.dist-info}/entry_points.txt +0 -0
- {symbolicai-1.5.0.dist-info → symbolicai-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-1.5.0.dist-info → symbolicai-1.6.0.dist-info}/top_level.txt +0 -0
symai/collect/dynamic.py
DELETED
@@ -1,117 +0,0 @@
-import ast
-import re
-
-
-class DynamicClass:
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-
-    def __repr__(self):
-        return str(self.__dict__)
-
-    @staticmethod
-    def from_string(s):
-        return create_object_from_string(s)
-
-
-def create_dynamic_class(class_name, **kwargs):
-    return type(class_name, (DynamicClass,), kwargs)()
-
-
-def parse_custom_class_instances(s):
-    pattern = r"(\w+)\((.*?)\)"
-    if not isinstance(s, str):
-        return s
-    matches = re.finditer(pattern, s)
-
-    for match in matches:
-        class_name = match.group(1)
-        class_args = match.group(2)
-        try:
-            parsed_args = ast.literal_eval(f"{{{class_args}}}")
-        except (ValueError, SyntaxError):
-            parsed_args = create_object_from_string(class_args)
-        class_instance = create_dynamic_class(class_name, **parsed_args)
-        s = s.replace(match.group(0), repr(class_instance))
-
-    return s
-
-
-def _strip_quotes(text):
-    if not isinstance(text, str):
-        return text
-    if text.startswith("'") and text.endswith("'"):
-        return text.strip("'")
-    if text.startswith('"') and text.endswith('"'):
-        return text.strip('"')
-    return text
-
-
-def _extract_content(str_class):
-    return str_class.split("ChatCompletionMessage(content=")[-1].split(", role=")[0][1:-1]
-
-
-def _parse_value(value):
-    try:
-        value = parse_custom_class_instances(value)
-        if not isinstance(value, str):
-            return value
-        if value.startswith("["):
-            inner_values = value[1:-1]
-            values = inner_values.split(",")
-            return [_parse_value(v.strip()) for v in values]
-        if value.startswith("{"):
-            inner_values = value[1:-1]
-            values = inner_values.split(",")
-            return {
-                k.strip(): _parse_value(v.strip()) for k, v in [v.split(":", 1) for v in values]
-            }
-        result = ast.literal_eval(value)
-        if isinstance(result, dict):
-            return {k: _parse_value(v) for k, v in result.items()}
-        if isinstance(result, (list, tuple, set)):
-            return [_parse_value(v) for v in result]
-        return result
-    except (ValueError, SyntaxError):
-        return value
-
-
-def _process_list_value(raw_value):
-    parsed_value = _parse_value(raw_value)
-    dir(parsed_value)
-    if hasattr(parsed_value, "__dict__"):
-        for key in parsed_value.__dict__:
-            value = getattr(parsed_value, key)
-            if isinstance(value, str):
-                parsed_value[key.strip("'")] = value.strip("'")
-    return parsed_value
-
-
-def _process_dict_value(raw_value):
-    parsed_value = _parse_value(raw_value)
-    new_value = {}
-    for key, value in parsed_value.items():
-        stripped_value = value.strip("'") if isinstance(value, str) else value
-        new_value[key.strip("'")] = stripped_value
-    return new_value
-
-
-def _collect_attributes(str_class):
-    attr_pattern = r"(\w+)=(\[.*?\]|\{.*?\}|'.*?'|None|\w+)"
-    attributes = re.findall(attr_pattern, str_class)
-    updated_attributes = [("content", _extract_content(str_class))]
-    for key, raw_value in attributes:
-        attr_key = _strip_quotes(key)
-        attr_value = _strip_quotes(raw_value)
-        if attr_value.startswith("[") and attr_value.endswith("]"):
-            attr_value = _process_list_value(attr_value)
-        elif attr_value.startswith("{") and attr_value.endswith("}"):
-            attr_value = _process_dict_value(attr_value)
-        updated_attributes.append((attr_key, attr_value))
-    return updated_attributes
-
-
-# TODO: fix to properly parse nested lists and dicts
-def create_object_from_string(str_class):
-    updated_attributes = _collect_attributes(str_class)
-    return DynamicClass(**{key: _parse_value(value) for key, value in updated_attributes})
symai/collect/pipeline.py
DELETED
@@ -1,156 +0,0 @@
-from __future__ import annotations
-
-import json
-import logging
-from datetime import datetime
-from typing import TYPE_CHECKING, Any
-
-from bson.objectid import ObjectId
-from pymongo.mongo_client import MongoClient
-
-from ..backend.settings import SYMAI_CONFIG
-from ..utils import UserMessage
-
-if TYPE_CHECKING:
-    from pymongo.collection import Collection
-    from pymongo.database import Database
-else:
-    Collection = Database = Any
-
-logger = logging.getLogger(__name__)
-
-
-def rec_serialize(obj):
-    """
-    Recursively serialize a given object into a string representation, handling
-    nested structures like lists and dictionaries.
-
-    :param obj: The object to be serialized.
-    :return: A string representation of the serialized object.
-    """
-    if isinstance(obj, (int, float, bool)):
-        # For simple types, return the string representation directly.
-        return obj
-    if isinstance(obj, dict):
-        # For dictionaries, serialize each value. Keep keys as strings.
-        return {str(key): rec_serialize(value) for key, value in obj.items()}
-    if isinstance(obj, (list, tuple, set)):
-        # For lists, tuples, and sets, serialize each element.
-        return [rec_serialize(elem) for elem in obj]
-    # Attempt JSON serialization first, then fall back to str(...)
-    try:
-        return json.dumps(obj)
-    except TypeError:
-        return str(obj)
-
-
-class CollectionRepository:
-    def __init__(self) -> None:
-        self.support_community: bool = SYMAI_CONFIG["SUPPORT_COMMUNITY"]
-        self.uri: str = SYMAI_CONFIG["COLLECTION_URI"]
-        self.db_name: str = SYMAI_CONFIG["COLLECTION_DB"]
-        self.collection_name: str = SYMAI_CONFIG["COLLECTION_STORAGE"]
-        self.client: MongoClient | None = None
-        self.db: Database | None = None
-        self.collection: Collection | None = None
-
-    def __enter__(self) -> CollectionRepository:
-        self.connect()
-        return self
-
-    def __exit__(
-        self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any | None
-    ) -> None:
-        self.close()
-
-    def ping(self) -> bool:
-        if not self.support_community:
-            return False
-        # Send a ping to confirm a successful connection
-        try:
-            self.client.admin.command("ping")
-            return True
-        except Exception as e:
-            UserMessage(f"Connection failed: {e}")
-            return False
-
-    def add(self, forward: Any, engine: Any, metadata: dict[str, Any] | None = None) -> Any:
-        if metadata is None:
-            metadata = {}
-        if not self.support_community:
-            return None
-        record = {
-            "forward": forward,
-            "engine": engine,
-            "metadata": metadata,
-            "created_at": datetime.now(),
-            "updated_at": datetime.now(),
-        }
-        try: # assure that adding a record does never cause a system error
-            return self.collection.insert_one(record).inserted_id if self.collection else None
-        except Exception:
-            return None
-
-    def get(self, record_id: str) -> dict[str, Any] | None:
-        if not self.support_community:
-            return None
-        return self.collection.find_one({"_id": ObjectId(record_id)}) if self.collection else None
-
-    def update(
-        self,
-        record_id: str,
-        forward: Any | None = None,
-        engine: str | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> Any:
-        if not self.support_community:
-            return None
-        updates: dict[str, Any] = {"updated_at": datetime.now()}
-        if forward is not None:
-            updates["forward"] = forward
-        if engine is not None:
-            updates["engine"] = engine
-        if metadata is not None:
-            updates["metadata"] = metadata
-
-        return (
-            self.collection.update_one({"_id": ObjectId(record_id)}, {"$set": updates})
-            if self.collection
-            else None
-        )
-
-    def delete(self, record_id: str) -> Any:
-        if not self.support_community:
-            return None
-        return self.collection.delete_one({"_id": ObjectId(record_id)}) if self.collection else None
-
-    def list(self, filters: dict[str, Any] | None = None, limit: int = 0) -> list[dict[str, Any]]:
-        if not self.support_community:
-            return None
-        if filters is None:
-            filters = {}
-        return list(self.collection.find(filters).limit(limit)) if self.collection else []
-
-    def count(self, filters: dict[str, Any] | None = None) -> int:
-        if not self.support_community:
-            return None
-        if filters is None:
-            filters = {}
-        return self.collection.count_documents(filters) if self.collection else 0
-
-    def connect(self) -> None:
-        try:
-            if self.client is None and self.support_community:
-                self.client = MongoClient(self.uri)
-                self.db = self.client[self.db_name]
-                self.collection = self.db[self.collection_name]
-        except Exception as e:
-            # disable retries
-            self.client = False
-            self.db = None
-            self.collection = None
-            UserMessage(f"[WARN] MongoClient: Connection failed: {e}")
-
-    def close(self) -> None:
-        if self.client is not None:
-            self.client.close()
symai/collect/stats.py
DELETED
@@ -1,434 +0,0 @@
-import json
-import re
-from collections.abc import Callable
-from json import JSONEncoder
-from pathlib import Path
-from typing import Any, Union
-
-import numpy as np
-import torch
-
-from ..ops.primitives import OperatorPrimitives
-from ..symbol import Symbol
-from ..utils import UserMessage
-
-SPECIAL_CONSTANT = "__aggregate_"
-EXCLUDE_LIST = ["_ipython_canary_method_should_not_exist_", "__custom_documentations__"]
-
-
-def _normalize_name(name: str) -> str:
-    # Replace any character that is not a letter or a number with an underscore
-    normalized_name = re.sub(r"[^a-zA-Z0-9]", "_", name)
-    return normalized_name.lower()
-
-
-class AggregatorJSONEncoder(JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, np.ndarray):
-            return obj.tolist()
-        # drop active from state
-        if isinstance(obj, Aggregator):
-            state = obj.__dict__.copy()
-            state.pop("_raise_error", None)
-            state.pop("_active", None)
-            state.pop("_finalized", None)
-            state.pop("_map", None)
-            # drop everything that starts with SPECIAL_CONSTANT
-            for key in list(state.keys()):
-                if (
-                    (not key.startswith(SPECIAL_CONSTANT) and key != "_value")
-                    or (key == "_value" and obj._value == [])
-                    or key.replace(SPECIAL_CONSTANT, "") in EXCLUDE_LIST
-                ):
-                    state.pop(key, None)
-            return state
-        return obj.__dict__
-
-
-class Aggregator(Symbol):
-    def __init__(
-        self,
-        value: Union["Aggregator", Symbol] | None = None,
-        path: str | None = None,
-        active: bool = True,
-        raise_error: bool = False,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-        # disable nesy engine to avoid side effects
-        self.__disable_nesy_engine__ = True
-        if value is not None and isinstance(value, Symbol):
-            # use this to avoid recursion on map setter
-            self._value = value._value
-            if isinstance(self._value, np.ndarray):
-                self._value = self._value.tolist()
-            elif isinstance(self._value, torch.Tensor):
-                self._value = self._value.detach().cpu().numpy().tolist()
-            elif not isinstance(self._value, (list, tuple)):
-                self._value = [self._value]
-        elif value is not None:
-            UserMessage(
-                f"Aggregator object must be of type Aggregator or Symbol! Got: {type(value)}",
-                raise_with=Exception,
-            )
-        else:
-            self._value = []
-        self._raise_error = raise_error
-        self._active = active
-        self._finalized = False
-        self._map = None
-        self._path = path
-
-    def __new__(
-        cls,
-        *args,
-        mixin: bool | None = None,
-        primitives: list[type] | None = None,  # only inherit arithmetic primitives
-        callables: list[tuple[str, Callable]] | None = None,
-        semantic: bool = False,
-        **kwargs,
-    ) -> "Symbol":
-        if primitives is None:
-            primitives = [OperatorPrimitives]
-        return super().__new__(
-            cls,
-            *args,
-            mixin=mixin,
-            primitives=primitives,
-            callables=callables,
-            semantic=semantic,
-            **kwargs,
-        )
-
-    def __getattr__(self, name):
-        # replace name special characters and spaces with underscores
-        name = _normalize_name(name)
-        # Dynamically create new aggregator instance if it does not exist
-        if self._active and name not in self.__dict__:
-            aggregator = Aggregator(path=name)
-            aggregator._parent = self
-            self._children.append(aggregator)
-            # create a new aggregate aggregator
-            # named {SPECIAL_CONSTANT}{name} for automatic aggregation
-            self.__dict__[f"{SPECIAL_CONSTANT}{name}"] = aggregator
-            # add also a property with the same name but without the SPECIAL_CONSTANT prefix as a shortcut
-            self.__dict__[name] = self.__dict__[f"{SPECIAL_CONSTANT}{name}"]
-            return self.__dict__.get(name)
-        if not self._active and name not in self.__dict__:
-            UserMessage(
-                f"Aggregator object is frozen! No attribute {name} found!", raise_with=Exception
-            )
-        return self.__dict__.get(name)
-
-    def __setattr__(self, name, value):
-        # replace name special characters and spaces with underscores
-        name = _normalize_name(name)
-        return super().__setattr__(name, value)
-
-    def __delattr__(self, name):
-        # replace name special characters and spaces with underscores
-        name = _normalize_name(name)
-        return super().__delattr__(name)
-
-    def __getitem__(self, name):
-        # replace name special characters and spaces with underscores
-        name = _normalize_name(name)
-        return self.__getattr__(name)
-
-    def __setitem__(self, name, value):
-        # replace name special characters and spaces with underscores
-        name = _normalize_name(name)
-        return self.__setattr__(name, value)
-
-    def __delitem__(self, name):
-        # replace name special characters and spaces with underscores
-        name = _normalize_name(name)
-        return self.__delattr__(name)
-
-    def __setstate__(self, state):
-        # replace name special characters and spaces with underscores
-        # drop active from state
-        state.pop("_raise_error", None)
-        state.pop("_active", None)
-        state.pop("_finalized", None)
-        state.pop("_map", None)
-        return super().__setstate__(state)
-
-    @staticmethod
-    def _set_values(obj, dictionary, parent, strict: bool = True):
-        # recursively reconstruct the object
-        for key, value in dictionary.items():
-            attr_key = key
-            attr_value = value
-            if isinstance(attr_value, dict):
-                if parent is not None:
-                    obj._path = attr_key
-                attr_value = Aggregator._reconstruct(attr_value, parent=parent, strict=strict)
-            if attr_key.startswith(SPECIAL_CONSTANT):
-                attr_key = attr_key.replace(SPECIAL_CONSTANT, "")
-            if attr_key == "_value":
-                try:
-                    attr_value = np.asarray(attr_value, dtype=np.float32)
-                except Exception as e:
-                    if strict:
-                        msg = f"Could not set value of Aggregator object: {obj.path}! ERROR: {e}"
-                        UserMessage(msg)
-                        raise Exception(msg) from e
-            obj.__setattr__(attr_key, attr_value)
-
-    @staticmethod
-    def _reconstruct(dictionary, parent=None, strict: bool = True):
-        obj = Aggregator()
-        obj._parent = parent
-        if parent is not None:
-            parent._children.append(obj)
-        Aggregator._set_values(obj, dictionary, parent=obj, strict=strict)
-        return obj
-
-    def __str__(self) -> str:
-        """
-        Get the string representation of the Symbol object.
-
-        Returns:
-            str: The string representation of the Symbol object.
-        """
-        return str(self.entries)
-
-    def _to_symbol(self, other) -> Symbol:
-        sym = super()._to_symbol(other)
-        res = Aggregator(sym)
-        res._parent = self
-        self._children.append(res)
-        return
-
-    @property
-    def path(self) -> str:
-        path = ""
-        obj = self
-        while obj is not None:
-            if obj._path is not None:
-                path = obj._path.replace(SPECIAL_CONSTANT, "") + "." + path
-            obj = obj._parent
-        return path[:-1]  # remove last dot
-
-    def __or__(self, other: Any) -> Any:
-        self.add(other)
-        return other
-
-    def __ror__(self, other: Any) -> Any:
-        self.add(other)
-        return other
-
-    def __ior__(self, other: Any) -> Any:
-        self.add(other)
-        return other
-
-    def __len__(self) -> int:
-        return len(self._value)
-
-    @property
-    def entries(self):
-        return self._value
-
-    @property
-    def value(self):
-        if self.map is not None:
-            return np.asarray(self.map(np.asarray(self._value, dtype=np.float32)))
-        return np.asarray(self._value, dtype=np.float32)
-
-    @property
-    def map(self):
-        return self._map if not self.empty() else None
-
-    @map.setter
-    def map(self, value):
-        self._set_map_recursively(value)
-
-    def _set_map_recursively(self, map):
-        self._map = map
-        for key, value in self.__dict__.items():
-            if isinstance(value, Aggregator) and (
-                not key.startswith("_") or key.startswith(SPECIAL_CONSTANT)
-            ):
-                value.map = map
-
-    def shape(self):
-        if len(self.entries) > 0:
-            return np.asarray(self.entries).shape
-        return ()
-
-    def serialize(self):
-        return json.dumps(self, cls=AggregatorJSONEncoder)
-
-    def save(self, path: str):
-        with Path(path).open("w") as f:
-            json.dump(self, f, cls=AggregatorJSONEncoder)
-
-    @staticmethod
-    def load(path: str, strict: bool = True):
-        with Path(path).open() as f:
-            json_ = json.load(f)
-        return Aggregator._reconstruct(json_, strict=strict)
-
-    def empty(self) -> bool:
-        return len(self) == 0
-
-    def add(self, entries):
-        # Add entries to the aggregator
-        if not self.active and self._finalized:
-            UserMessage("Aggregator object is frozen!", raise_with=Exception)
-            return
-        try:
-            processed_entries = self._prepare_entries(entries)
-            if processed_entries is None:
-                return
-            processed_entries = self._squeeze_entries(processed_entries)
-            self.entries.append(processed_entries)
-        except Exception as e:
-            msg = f"Could not add entries to Aggregator object! Please verify type or original error: {e}"
-            if self._raise_error:
-                UserMessage(msg)
-                raise Exception(msg) from e
-            UserMessage(msg)
-
-    def _prepare_entries(self, entries):
-        valid_types = (
-            tuple,
-            list,
-            np.float32,
-            np.float64,
-            np.ndarray,
-            torch.Tensor,
-            int,
-            float,
-            bool,
-            str,
-            Symbol,
-        )
-        assert isinstance(entries, valid_types), (
-            f"Entries must be a tuple, list, numpy array, torch tensor, integer, float, boolean, string, or Symbol! Got: {type(entries)}"
-        )
-        if isinstance(entries, torch.Tensor):
-            return entries.detach().cpu().numpy().astype(np.float32)
-        if isinstance(entries, (tuple, list)):
-            return np.asarray(entries, dtype=np.float32)
-        if isinstance(entries, bool):
-            return int(entries)
-        if isinstance(entries, str):
-            return Symbol(entries).embedding.astype(np.float32)
-        if isinstance(entries, Symbol):
-            # Use this to avoid recursion on map setter
-            self.add(entries._value)
-            return None
-        if isinstance(entries, Aggregator):
-            self.add(entries.get())
-            return None
-        return entries
-
-    def _squeeze_entries(self, entries):
-        if isinstance(entries, (np.ndarray, np.float32)):
-            return entries.squeeze()
-        return entries
-
-    def keys(self):
-        # Get all key names of items that have the SPECIAL_CONSTANT prefix
-        return [
-            key.replace(SPECIAL_CONSTANT, "")
-            for key in self.__dict__
-            if not key.startswith("_") and key.replace(SPECIAL_CONSTANT, "") not in EXCLUDE_LIST
-        ]
-
-    @property
-    def active(self):
-        # Get the active status of the aggregator
-        return self._active
-
-    @active.setter
-    def active(self, value):
-        # Set the active status of the aggregator
-        assert isinstance(value, bool), f"Active status must be a boolean! Got: {type(value)}"
-        self._active = value
-
-    @property
-    def finalized(self):
-        # Get the finalized status of the aggregator
-        return self._finalized
-
-    @finalized.setter
-    def finalized(self, value):
-        # Set the finalized status of the aggregator
-        assert isinstance(value, bool), f"Finalized status must be a boolean! Got: {type(value)}"
-        self._finalized = value
-
-    def finalize(self):
-        # Finalizes the dynamic creation of the aggregators and freezes the object to prevent further changes
-        self._active = False
-        self._finalized = True
-
-        def raise_exception(name, value):
-            if name == "map":
-                self.__setattr__(name, value)
-            else:
-                UserMessage("Aggregator object is frozen!", raise_with=Exception)
-
-        self.__setattr__ = raise_exception
-
-        def get_attribute(*args, **kwargs):
-            return self.__dict__.get(*args, **kwargs)
-
-        self.__getattr__ = get_attribute
-        # Do the same recursively for all properties of type Aggregator
-        for key, value in self.__dict__.items():
-            if isinstance(value, Aggregator) and (
-                not key.startswith("_") or key.startswith(SPECIAL_CONSTANT)
-            ):
-                value.finalize()
-
-    def get(self, *args, **kwargs):
-        if self._map is not None:
-            return self._map(self.entries, *args, **kwargs)
-        # Get the entries of the aggregator
-        return self.entries
-
-    def clear(self):
-        # Clear the entries of the aggregator
-        if self._finalized:
-            UserMessage("Aggregator object is frozen!", raise_with=Exception)
-        self._value = []
-
-    def sum(self, axis=0):
-        # Get the sum of the entries of the aggregator
-        return np.sum(self.entries, axis=axis)
-
-    def mean(self, axis=0):
-        # Get the mean of the entries of the aggregator
-        return np.mean(self.entries, axis=axis)
-
-    def median(self, axis=0):
-        # Get the median of the entries of the aggregator
-        return np.median(self.entries, axis=axis)
-
-    def var(self, axis=0):
-        # Get the variance of the entries of the aggregator
-        return np.var(self.entries, axis=axis)
-
-    def cov(self, rowvar=False):
-        # Get the covariance of the entries of the aggregator
-        return np.cov(self.entries, rowvar=rowvar)
-
-    def moment(self, moment=2, axis=0):
-        # Get the moment of the entries of the aggregator
-        return np.mean(np.power(self.entries, moment), axis=axis)
-
-    def std(self, axis=0):
-        # Get the standard deviation of the entries of the aggregator
-        return np.std(self.entries, axis=axis)
-
-    def min(self, axis=0):
-        # Get the minimum of the entries of the aggregator
-        return np.min(self.entries, axis=axis)
-
-    def max(self, axis=0):
-        # Get the maximum of the entries of the aggregator
-        return np.max(self.entries, axis=axis)