symbolicai 1.5.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. symai/__init__.py +21 -71
  2. symai/backend/base.py +0 -26
  3. symai/backend/engines/drawing/engine_gemini_image.py +101 -0
  4. symai/backend/engines/embedding/engine_openai.py +11 -8
  5. symai/backend/engines/neurosymbolic/__init__.py +8 -0
  6. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +1 -0
  7. symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +48 -1
  8. symai/backend/engines/neurosymbolic/engine_cerebras.py +1 -0
  9. symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +14 -1
  10. symai/backend/engines/neurosymbolic/engine_openrouter.py +294 -0
  11. symai/backend/mixin/__init__.py +4 -0
  12. symai/backend/mixin/anthropic.py +37 -16
  13. symai/backend/mixin/openrouter.py +2 -0
  14. symai/components.py +203 -13
  15. symai/extended/interfaces/nanobanana.py +23 -0
  16. symai/interfaces.py +2 -0
  17. symai/ops/primitives.py +0 -18
  18. symai/shellsv.py +2 -7
  19. symai/strategy.py +44 -4
  20. {symbolicai-1.5.0.dist-info → symbolicai-1.7.0.dist-info}/METADATA +3 -10
  21. {symbolicai-1.5.0.dist-info → symbolicai-1.7.0.dist-info}/RECORD +25 -48
  22. {symbolicai-1.5.0.dist-info → symbolicai-1.7.0.dist-info}/WHEEL +1 -1
  23. symai/backend/driver/webclient.py +0 -217
  24. symai/backend/engines/crawler/engine_selenium.py +0 -94
  25. symai/backend/engines/drawing/engine_dall_e.py +0 -131
  26. symai/backend/engines/embedding/engine_plugin_embeddings.py +0 -12
  27. symai/backend/engines/experiments/engine_bard_wrapper.py +0 -131
  28. symai/backend/engines/experiments/engine_gptfinetuner.py +0 -32
  29. symai/backend/engines/experiments/engine_llamacpp_completion.py +0 -142
  30. symai/backend/engines/neurosymbolic/engine_openai_gptX_completion.py +0 -277
  31. symai/collect/__init__.py +0 -8
  32. symai/collect/dynamic.py +0 -117
  33. symai/collect/pipeline.py +0 -156
  34. symai/collect/stats.py +0 -434
  35. symai/extended/crawler.py +0 -21
  36. symai/extended/interfaces/selenium.py +0 -18
  37. symai/extended/interfaces/vectordb.py +0 -21
  38. symai/extended/personas/__init__.py +0 -3
  39. symai/extended/personas/builder.py +0 -105
  40. symai/extended/personas/dialogue.py +0 -126
  41. symai/extended/personas/persona.py +0 -154
  42. symai/extended/personas/research/__init__.py +0 -1
  43. symai/extended/personas/research/yann_lecun.py +0 -62
  44. symai/extended/personas/sales/__init__.py +0 -1
  45. symai/extended/personas/sales/erik_james.py +0 -62
  46. symai/extended/personas/student/__init__.py +0 -1
  47. symai/extended/personas/student/max_tenner.py +0 -51
  48. symai/extended/strategies/__init__.py +0 -1
  49. symai/extended/strategies/cot.py +0 -40
  50. {symbolicai-1.5.0.dist-info → symbolicai-1.7.0.dist-info}/entry_points.txt +0 -0
  51. {symbolicai-1.5.0.dist-info → symbolicai-1.7.0.dist-info}/licenses/LICENSE +0 -0
  52. {symbolicai-1.5.0.dist-info → symbolicai-1.7.0.dist-info}/top_level.txt +0 -0
symai/backend/engines/neurosymbolic/engine_openai_gptX_completion.py DELETED
@@ -1,277 +0,0 @@
- import logging
- from copy import deepcopy
- from typing import List, Optional
-
- import openai
- import tiktoken
-
- from ....misc.console import ConsoleStyle
- from ....utils import CustomUserWarning
- from ...base import Engine
- from ...mixin.openai import OpenAIMixin
- from ...settings import SYMAI_CONFIG
-
- logging.getLogger("openai").setLevel(logging.ERROR)
- logging.getLogger("requests").setLevel(logging.ERROR)
- logging.getLogger("urllib").setLevel(logging.ERROR)
- logging.getLogger("httpx").setLevel(logging.ERROR)
- logging.getLogger("httpcore").setLevel(logging.ERROR)
-
-
- class InvalidRequestErrorRemedyCompletionStrategy:
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-
-     def __call__(self, engine, error, callback, argument):
-         openai_kwargs = {}
-         kwargs = argument.kwargs
-         prompts_ = argument.prop.prepared_input
-         # send prompt to GPT-X Completion-based
-         stop = kwargs['stop'] if 'stop' in kwargs else None
-         model = kwargs['model'] if 'model' in kwargs else None
-
-         msg = str(error)
-         handle = None
-         try:
-             if "This model's maximum context length is" in msg:
-                 handle = 'type1'
-                 max_ = engine.max_context_tokens
-                 usr = msg.split('tokens. ')[1].split(' ')[-1]
-                 overflow_tokens = int(usr) - int(max_)
-             elif "is less than the minimum" in msg:
-                 handle = 'type2'
-                 overflow_tokens = engine.max_response_tokens
-             else:
-                 raise Exception(msg) from error
-         except Exception as e:
-             raise e from error
-
-         # unify the format to use same remedy strategy for both chat and completion
-         values = prompts_[0].replace('---------SYSTEM BEHAVIOR--------\n', '').split('\n\n---------USER REQUEST--------\n')
-         prompts_ = [{'role': 'system', 'content': values[0]}, {'role': 'user', 'content': values[1]}]
-
-         prompts = [p for p in prompts_ if p['role'] == 'user']
-         system_prompt = [p for p in prompts_ if p['role'] == 'system']
-
-         def compute_required_tokens(prompts: dict) -> int:
-             # iterate over prompts and compute number of tokens
-             prompts_ = [role['content'] for role in prompts]
-             prompt = ''.join(prompts_)
-             val = len(engine.tokenizer.encode(prompt, disallowed_special=()))
-             return val
-
-         def compute_remaining_tokens(prompts: list) -> int:
-             val = compute_required_tokens(prompts)
-             return int((engine.max_context_tokens - val) * 0.99)
-
-         if handle == 'type1':
-             truncated_content_ = [p['content'][overflow_tokens:] for p in prompts]
-             truncated_prompts_ = [{'role': p['role'], 'content': c} for p, c in zip(prompts, truncated_content_)]
-             CustomUserWarning(f"WARNING: Overflow tokens detected. Reducing prompt size by {overflow_tokens} characters.")
-         elif handle == 'type2':
-             user_prompts = [p['content'] for p in prompts]
-             new_prompt = [*system_prompt]
-             new_prompt.extend([{'role': p['role'], 'content': c} for p, c in zip(prompts, user_prompts)])
-             overflow_tokens = compute_required_tokens(new_prompt) - int(engine.max_context_tokens * 0.70)
-             if overflow_tokens > 0:
-                 CustomUserWarning(f'WARNING: Overflow tokens detected. Reducing prompt size to 70% of model context size ({engine.max_context_tokens}).')
-                 for i, content in enumerate(user_prompts):
-                     token_ids = engine.tokenizer.encode(content)
-                     if overflow_tokens >= len(token_ids):
-                         overflow_tokens -= len(token_ids)
-                         user_prompts[i] = ''
-                     else:
-                         new_content = engine.tokenizer.decode(token_ids[:-overflow_tokens])
-                         user_prompts[i] = new_content
-                         overflow_tokens = 0
-                         break
-
-             new_prompt = [*system_prompt]
-             new_prompt.extend([{'role': p['role'], 'content': c} for p, c in zip(prompts, user_prompts)])
-             assert compute_required_tokens(new_prompt) <= engine.max_context_tokens, \
-                 f"Token overflow: prompts exceed {engine.max_context_tokens} tokens after truncation"
-
-             truncated_prompts_ = [{'role': p['role'], 'content': c.strip()} for p, c in zip(prompts, user_prompts) if c.strip()]
-         else:
-             raise Exception('Invalid handle case for remedy strategy.') from error
-
-         truncated_prompts_ = [*system_prompt, *truncated_prompts_]
-
-         # convert map to list of strings
-         max_tokens = kwargs['max_tokens'] if 'max_tokens' in kwargs else compute_remaining_tokens(truncated_prompts_)
-         temperature = kwargs['temperature'] if 'temperature' in kwargs else 1
-         frequency_penalty = kwargs['frequency_penalty'] if 'frequency_penalty' in kwargs else 0
-         presence_penalty = kwargs['presence_penalty'] if 'presence_penalty' in kwargs else 0
-         top_p = kwargs['top_p'] if 'top_p' in kwargs else 1
-         suffix = kwargs['template_suffix'] if 'template_suffix' in kwargs else None
-
-         system = truncated_prompts_[0]['content']
-         user = truncated_prompts_[1]['content']
-         truncated_prompts_ = [f'---------SYSTEM BEHAVIOR--------\n{system}\n\n---------USER REQUEST--------\n{user}']
-
-         if stop is not None:
-             openai_kwargs['stop'] = stop
-
-         return callback(model=model,
-                         prompt=truncated_prompts_,
-                         suffix=suffix,
-                         max_tokens=max_tokens,
-                         temperature=temperature,
-                         frequency_penalty=frequency_penalty,
-                         presence_penalty=presence_penalty,
-                         top_p=top_p,
-                         n=1,
-                         **openai_kwargs)
-
-
-
- class GPTXCompletionEngine(Engine, OpenAIMixin):
-     def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
-         super().__init__()
-         logger = logging.getLogger('openai')
-         logger.setLevel(logging.WARNING)
-         self.config = deepcopy(SYMAI_CONFIG)
-         if self.id() != 'neurosymbolic':
-             return  # do not initialize if not neurosymbolic; avoids conflict with llama.cpp check in EngineRepository.register_from_package
-         openai.api_key = self.config['NEUROSYMBOLIC_ENGINE_API_KEY'] if api_key is None else api_key
-         self.model = self.config['NEUROSYMBOLIC_ENGINE_MODEL'] if model is None else model
-         self.tokenizer = tiktoken.encoding_for_model(self.model)
-         self.max_context_tokens = self.api_max_context_tokens()
-         self.max_response_tokens = self.api_max_response_tokens()
-         self.except_remedy = None
-
-     def id(self) -> str:
-         if self.config['NEUROSYMBOLIC_ENGINE_MODEL'] and \
-            (self.config['NEUROSYMBOLIC_ENGINE_MODEL'].startswith('text-') or \
-             self.config['NEUROSYMBOLIC_ENGINE_MODEL'].startswith('davinci') or \
-             self.config['NEUROSYMBOLIC_ENGINE_MODEL'].startswith('curie') or \
-             self.config['NEUROSYMBOLIC_ENGINE_MODEL'].startswith('babbage') or \
-             self.config['NEUROSYMBOLIC_ENGINE_MODEL'].startswith('ada')):
-             return 'neurosymbolic'
-         return super().id()  # default to unregistered
-
-     def command(self, *args, **kwargs):
-         super().command(*args, **kwargs)
-         if 'NEUROSYMBOLIC_ENGINE_API_KEY' in kwargs:
-             openai.api_key = kwargs['NEUROSYMBOLIC_ENGINE_API_KEY']
-         if 'NEUROSYMBOLIC_ENGINE_MODEL' in kwargs:
-             self.model = kwargs['NEUROSYMBOLIC_ENGINE_MODEL']
-         if 'except_remedy' in kwargs:
-             self.except_remedy = kwargs['except_remedy']
-
-     def compute_required_tokens(self, prompts: list) -> int:
-         # iterate over prompts and compute number of tokens
-         prompt = prompts[0]  # index 0 is correct since we only have one prompt in legacy mode
-         val = len(self.tokenizer.encode(prompt, disallowed_special=()))
-         return val
-
-     def compute_remaining_tokens(self, prompts: list) -> int:
-         val = self.compute_required_tokens(prompts)
-         return min(self.max_context_tokens - val, self.max_response_tokens)
-
-     def forward(self, argument):
-         kwargs = argument.kwargs
-         prompts_ = argument.prop.prepared_input
-
-         # send prompt to GPT-3
-         max_tokens = kwargs['max_tokens'] if 'max_tokens' in kwargs else self.compute_remaining_tokens(prompts_)
-         stop = kwargs['stop'] if 'stop' in kwargs else None
-         model = kwargs['model'] if 'model' in kwargs else self.model
-         temperature = kwargs['temperature'] if 'temperature' in kwargs else 0.7
-         frequency_penalty = kwargs['frequency_penalty'] if 'frequency_penalty' in kwargs else 0
-         presence_penalty = kwargs['presence_penalty'] if 'presence_penalty' in kwargs else 0
-         top_p = kwargs['top_p'] if 'top_p' in kwargs else 1
-         except_remedy = kwargs['except_remedy'] if 'except_remedy' in kwargs else self.except_remedy
-
-         try:
-             res = openai.completions.create(model=model,
-                                             prompt=prompts_,
-                                             max_tokens=max_tokens,
-                                             temperature=temperature,
-                                             frequency_penalty=frequency_penalty,
-                                             presence_penalty=presence_penalty,
-                                             top_p=top_p,
-                                             stop=stop,
-                                             n=1)
-         except Exception as e:
-             callback = openai.completions.create
-             kwargs['model'] = kwargs['model'] if 'model' in kwargs else self.model
-             if except_remedy is not None:
-                 res = except_remedy(self, e, callback, argument)
-             else:
-                 try:
-                     # implicit remedy strategy
-                     except_remedy = InvalidRequestErrorRemedyCompletionStrategy()
-                     res = except_remedy(self, e, callback, argument)
-                 except Exception as e2:
-                     ex = Exception(f'Failed to handle exception: {e}. Also failed implicit remedy strategy after retry: {e2}')
-                     raise ex from e
-
-         metadata = {}
-         # TODO: remove system behavior and user request from output. consider post-processing
-         def replace_verbose(rsp):
-             rsp = rsp.replace('---------SYSTEM BEHAVIOR--------\n', '')
-             rsp = rsp.replace('\n\n---------USER REQUEST--------\n', '')
-             return rsp
-
-         rsp = [replace_verbose(r.text) for r in res.choices]
-         output = rsp if isinstance(prompts_, list) else rsp[0]
-         return output, metadata
-
-     def prepare(self, argument):
-         if argument.prop.raw_input:
-             if not argument.prop.processed_input:
-                 raise ValueError('Need to provide a prompt instruction to the engine if raw_input is enabled.')
-             value = argument.prop.processed_input
-             if type(value) is not list:
-                 value = [str(value)]
-             argument.prop.prepared_input = value
-             return
-
-         _non_verbose_output = """[META INSTRUCTIONS START]\nYou do not output anything else, like verbose preambles or post explanation, such as "Sure, let me...", "Hope that was helpful...", "Yes, I can help you with that...", etc. Consider well formatted output, e.g. for sentences use punctuation, spaces etc. or for code use indentation, etc. Never add meta instructions information to your output!\n"""
-
-         user: str = ""
-         system: str = ""
-
-         if argument.prop.suppress_verbose_output:
-             system += _non_verbose_output
-         system = f'{system}\n' if system and len(system) > 0 else ''
-
-         ref = argument.prop.instance
-         static_ctxt, dyn_ctxt = ref.global_context
-         if len(static_ctxt) > 0:
-             system += f"[STATIC CONTEXT]\n{static_ctxt}\n\n"
-
-         if len(dyn_ctxt) > 0:
-             system += f"[DYNAMIC CONTEXT]\n{dyn_ctxt}\n\n"
-
-         payload = argument.prop.payload
-         if payload is not None:
-             system += f"[ADDITIONAL CONTEXT]\n{payload}\n\n"
-
-         examples: List[str] = argument.prop.examples
-         if examples and len(examples) > 0:
-             system += f"[EXAMPLES]\n{str(examples)}\n\n"
-
-         if argument.prop.prompt is not None and len(argument.prop.prompt) > 0:
-             val = str(argument.prop.prompt)
-             system += f"[INSTRUCTION]\n{val}"
-
-         suffix: str = str(argument.prop.processed_input)
-         if '=>' in suffix:
-             user += f"[LAST TASK]\n"
-
-         if '[SYSTEM_INSTRUCTION::]: <<<' in suffix and argument.prop.parse_system_instructions:
-             parts = suffix.split('\n>>>\n')
-             # first parts are the system instructions
-             for p in parts[:-1]:
-                 system += f"{p}\n"
-             # last part is the user input
-             suffix = parts[-1]
-         user += f"{suffix}"
-
-         if argument.prop.template_suffix is not None:
-             user += f"\n[[PLACEHOLDER]]\n{argument.prop.template_suffix}\n\n"
-             user += f"Only generate content for the placeholder `[[PLACEHOLDER]]` following the instructions and context information. Do NOT write `[[PLACEHOLDER]]` or anything else in your output.\n\n"
-
-         argument.prop.prepared_input = [f'---------SYSTEM BEHAVIOR--------\n{system}\n\n---------USER REQUEST--------\n{user}']
symai/collect/__init__.py DELETED
@@ -1,8 +0,0 @@
- from .dynamic import create_object_from_string
- from .pipeline import CollectionRepository, rec_serialize
-
- __all__ = [
-     "CollectionRepository",
-     "create_object_from_string",
-     "rec_serialize",
- ]
symai/collect/dynamic.py DELETED
@@ -1,117 +0,0 @@
- import ast
- import re
-
-
- class DynamicClass:
-     def __init__(self, **kwargs):
-         self.__dict__.update(kwargs)
-
-     def __repr__(self):
-         return str(self.__dict__)
-
-     @staticmethod
-     def from_string(s):
-         return create_object_from_string(s)
-
-
- def create_dynamic_class(class_name, **kwargs):
-     return type(class_name, (DynamicClass,), kwargs)()
-
-
- def parse_custom_class_instances(s):
-     pattern = r"(\w+)\((.*?)\)"
-     if not isinstance(s, str):
-         return s
-     matches = re.finditer(pattern, s)
-
-     for match in matches:
-         class_name = match.group(1)
-         class_args = match.group(2)
-         try:
-             parsed_args = ast.literal_eval(f"{{{class_args}}}")
-         except (ValueError, SyntaxError):
-             parsed_args = create_object_from_string(class_args)
-         class_instance = create_dynamic_class(class_name, **parsed_args)
-         s = s.replace(match.group(0), repr(class_instance))
-
-     return s
-
-
- def _strip_quotes(text):
-     if not isinstance(text, str):
-         return text
-     if text.startswith("'") and text.endswith("'"):
-         return text.strip("'")
-     if text.startswith('"') and text.endswith('"'):
-         return text.strip('"')
-     return text
-
-
- def _extract_content(str_class):
-     return str_class.split("ChatCompletionMessage(content=")[-1].split(", role=")[0][1:-1]
-
-
- def _parse_value(value):
-     try:
-         value = parse_custom_class_instances(value)
-         if not isinstance(value, str):
-             return value
-         if value.startswith("["):
-             inner_values = value[1:-1]
-             values = inner_values.split(",")
-             return [_parse_value(v.strip()) for v in values]
-         if value.startswith("{"):
-             inner_values = value[1:-1]
-             values = inner_values.split(",")
-             return {
-                 k.strip(): _parse_value(v.strip()) for k, v in [v.split(":", 1) for v in values]
-             }
-         result = ast.literal_eval(value)
-         if isinstance(result, dict):
-             return {k: _parse_value(v) for k, v in result.items()}
-         if isinstance(result, (list, tuple, set)):
-             return [_parse_value(v) for v in result]
-         return result
-     except (ValueError, SyntaxError):
-         return value
-
-
- def _process_list_value(raw_value):
-     parsed_value = _parse_value(raw_value)
-     dir(parsed_value)
-     if hasattr(parsed_value, "__dict__"):
-         for key in parsed_value.__dict__:
-             value = getattr(parsed_value, key)
-             if isinstance(value, str):
-                 parsed_value[key.strip("'")] = value.strip("'")
-     return parsed_value
-
-
- def _process_dict_value(raw_value):
-     parsed_value = _parse_value(raw_value)
-     new_value = {}
-     for key, value in parsed_value.items():
-         stripped_value = value.strip("'") if isinstance(value, str) else value
-         new_value[key.strip("'")] = stripped_value
-     return new_value
-
-
- def _collect_attributes(str_class):
-     attr_pattern = r"(\w+)=(\[.*?\]|\{.*?\}|'.*?'|None|\w+)"
-     attributes = re.findall(attr_pattern, str_class)
-     updated_attributes = [("content", _extract_content(str_class))]
-     for key, raw_value in attributes:
-         attr_key = _strip_quotes(key)
-         attr_value = _strip_quotes(raw_value)
-         if attr_value.startswith("[") and attr_value.endswith("]"):
-             attr_value = _process_list_value(attr_value)
-         elif attr_value.startswith("{") and attr_value.endswith("}"):
-             attr_value = _process_dict_value(attr_value)
-         updated_attributes.append((attr_key, attr_value))
-     return updated_attributes
-
-
- # TODO: fix to properly parse nested lists and dicts
- def create_object_from_string(str_class):
-     updated_attributes = _collect_attributes(str_class)
-     return DynamicClass(**{key: _parse_value(value) for key, value in updated_attributes})
symai/collect/pipeline.py DELETED
@@ -1,156 +0,0 @@
- from __future__ import annotations
-
- import json
- import logging
- from datetime import datetime
- from typing import TYPE_CHECKING, Any
-
- from bson.objectid import ObjectId
- from pymongo.mongo_client import MongoClient
-
- from ..backend.settings import SYMAI_CONFIG
- from ..utils import UserMessage
-
- if TYPE_CHECKING:
-     from pymongo.collection import Collection
-     from pymongo.database import Database
- else:
-     Collection = Database = Any
-
- logger = logging.getLogger(__name__)
-
-
- def rec_serialize(obj):
-     """
-     Recursively serialize a given object into a string representation, handling
-     nested structures like lists and dictionaries.
-
-     :param obj: The object to be serialized.
-     :return: A string representation of the serialized object.
-     """
-     if isinstance(obj, (int, float, bool)):
-         # For simple types, return the string representation directly.
-         return obj
-     if isinstance(obj, dict):
-         # For dictionaries, serialize each value. Keep keys as strings.
-         return {str(key): rec_serialize(value) for key, value in obj.items()}
-     if isinstance(obj, (list, tuple, set)):
-         # For lists, tuples, and sets, serialize each element.
-         return [rec_serialize(elem) for elem in obj]
-     # Attempt JSON serialization first, then fall back to str(...)
-     try:
-         return json.dumps(obj)
-     except TypeError:
-         return str(obj)
-
-
- class CollectionRepository:
-     def __init__(self) -> None:
-         self.support_community: bool = SYMAI_CONFIG["SUPPORT_COMMUNITY"]
-         self.uri: str = SYMAI_CONFIG["COLLECTION_URI"]
-         self.db_name: str = SYMAI_CONFIG["COLLECTION_DB"]
-         self.collection_name: str = SYMAI_CONFIG["COLLECTION_STORAGE"]
-         self.client: MongoClient | None = None
-         self.db: Database | None = None
-         self.collection: Collection | None = None
-
-     def __enter__(self) -> CollectionRepository:
-         self.connect()
-         return self
-
-     def __exit__(
-         self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any | None
-     ) -> None:
-         self.close()
-
-     def ping(self) -> bool:
-         if not self.support_community:
-             return False
-         # Send a ping to confirm a successful connection
-         try:
-             self.client.admin.command("ping")
-             return True
-         except Exception as e:
-             UserMessage(f"Connection failed: {e}")
-             return False
-
-     def add(self, forward: Any, engine: Any, metadata: dict[str, Any] | None = None) -> Any:
-         if metadata is None:
-             metadata = {}
-         if not self.support_community:
-             return None
-         record = {
-             "forward": forward,
-             "engine": engine,
-             "metadata": metadata,
-             "created_at": datetime.now(),
-             "updated_at": datetime.now(),
-         }
-         try:  # assure that adding a record does never cause a system error
-             return self.collection.insert_one(record).inserted_id if self.collection else None
-         except Exception:
-             return None
-
-     def get(self, record_id: str) -> dict[str, Any] | None:
-         if not self.support_community:
-             return None
-         return self.collection.find_one({"_id": ObjectId(record_id)}) if self.collection else None
-
-     def update(
-         self,
-         record_id: str,
-         forward: Any | None = None,
-         engine: str | None = None,
-         metadata: dict[str, Any] | None = None,
-     ) -> Any:
-         if not self.support_community:
-             return None
-         updates: dict[str, Any] = {"updated_at": datetime.now()}
-         if forward is not None:
-             updates["forward"] = forward
-         if engine is not None:
-             updates["engine"] = engine
-         if metadata is not None:
-             updates["metadata"] = metadata
-
-         return (
-             self.collection.update_one({"_id": ObjectId(record_id)}, {"$set": updates})
-             if self.collection
-             else None
-         )
-
-     def delete(self, record_id: str) -> Any:
-         if not self.support_community:
-             return None
-         return self.collection.delete_one({"_id": ObjectId(record_id)}) if self.collection else None
-
-     def list(self, filters: dict[str, Any] | None = None, limit: int = 0) -> list[dict[str, Any]]:
-         if not self.support_community:
-             return None
-         if filters is None:
-             filters = {}
-         return list(self.collection.find(filters).limit(limit)) if self.collection else []
-
-     def count(self, filters: dict[str, Any] | None = None) -> int:
-         if not self.support_community:
-             return None
-         if filters is None:
-             filters = {}
-         return self.collection.count_documents(filters) if self.collection else 0
-
-     def connect(self) -> None:
-         try:
-             if self.client is None and self.support_community:
-                 self.client = MongoClient(self.uri)
-                 self.db = self.client[self.db_name]
-                 self.collection = self.db[self.collection_name]
-         except Exception as e:
-             # disable retries
-             self.client = False
-             self.db = None
-             self.collection = None
-             UserMessage(f"[WARN] MongoClient: Connection failed: {e}")
-
-     def close(self) -> None:
-         if self.client is not None:
-             self.client.close()