lionagi 0.0.112__py3-none-any.whl → 0.0.113__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +3 -3
- lionagi/bridge/__init__.py +7 -0
- lionagi/bridge/langchain.py +131 -0
- lionagi/bridge/llama_index.py +157 -0
- lionagi/configs/__init__.py +7 -0
- lionagi/configs/oai_configs.py +49 -0
- lionagi/configs/openrouter_config.py +49 -0
- lionagi/core/__init__.py +8 -2
- lionagi/core/instruction_sets.py +1 -3
- lionagi/core/messages.py +2 -2
- lionagi/core/sessions.py +174 -27
- lionagi/datastore/__init__.py +1 -0
- lionagi/loader/__init__.py +9 -4
- lionagi/loader/chunker.py +157 -0
- lionagi/loader/reader.py +124 -0
- lionagi/objs/__init__.py +7 -0
- lionagi/objs/messenger.py +163 -0
- lionagi/objs/tool_registry.py +247 -0
- lionagi/schema/__init__.py +11 -0
- lionagi/schema/base_schema.py +239 -0
- lionagi/schema/base_tool.py +9 -0
- lionagi/schema/data_logger.py +94 -0
- lionagi/services/__init__.py +14 -0
- lionagi/{service_/oai.py → services/base_api_service.py} +49 -82
- lionagi/{endpoint/base_endpoint.py → services/chatcompletion.py} +19 -22
- lionagi/services/oai.py +34 -0
- lionagi/services/openrouter.py +32 -0
- lionagi/{service_/service_utils.py → services/service_objs.py} +0 -1
- lionagi/structure/__init__.py +7 -0
- lionagi/structure/relationship.py +128 -0
- lionagi/structure/structure.py +160 -0
- lionagi/tests/test_flatten_util.py +426 -0
- lionagi/tools/__init__.py +0 -5
- lionagi/tools/coder.py +1 -0
- lionagi/tools/scorer.py +1 -0
- lionagi/tools/validator.py +1 -0
- lionagi/utils/__init__.py +46 -20
- lionagi/utils/api_util.py +86 -0
- lionagi/utils/call_util.py +347 -0
- lionagi/utils/flat_util.py +540 -0
- lionagi/utils/io_util.py +102 -0
- lionagi/utils/load_utils.py +190 -0
- lionagi/utils/sys_util.py +191 -0
- lionagi/utils/tool_util.py +92 -0
- lionagi/utils/type_util.py +81 -0
- lionagi/version.py +1 -1
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/METADATA +37 -13
- lionagi-0.0.113.dist-info/RECORD +84 -0
- lionagi/endpoint/chat_completion.py +0 -20
- lionagi/endpoint/endpoint_utils.py +0 -0
- lionagi/llm_configs.py +0 -21
- lionagi/loader/load_utils.py +0 -161
- lionagi/schema.py +0 -275
- lionagi/service_/__init__.py +0 -6
- lionagi/service_/base_service.py +0 -48
- lionagi/service_/openrouter.py +0 -1
- lionagi/services.py +0 -1
- lionagi/tools/tool_utils.py +0 -75
- lionagi/utils/sys_utils.py +0 -799
- lionagi-0.0.112.dist-info/RECORD +0 -67
- /lionagi/{core/responses.py → datastore/chroma.py} +0 -0
- /lionagi/{endpoint/assistants.py → datastore/deeplake.py} +0 -0
- /lionagi/{endpoint/audio.py → datastore/elasticsearch.py} +0 -0
- /lionagi/{endpoint/embeddings.py → datastore/lantern.py} +0 -0
- /lionagi/{endpoint/files.py → datastore/pinecone.py} +0 -0
- /lionagi/{endpoint/fine_tuning.py → datastore/postgres.py} +0 -0
- /lionagi/{endpoint/images.py → datastore/qdrant.py} +0 -0
- /lionagi/{endpoint/messages.py → schema/base_condition.py} +0 -0
- /lionagi/{service_ → services}/anthropic.py +0 -0
- /lionagi/{service_ → services}/anyscale.py +0 -0
- /lionagi/{service_ → services}/azure.py +0 -0
- /lionagi/{service_ → services}/bedrock.py +0 -0
- /lionagi/{service_ → services}/everlyai.py +0 -0
- /lionagi/{service_ → services}/gemini.py +0 -0
- /lionagi/{service_ → services}/gpt4all.py +0 -0
- /lionagi/{service_ → services}/huggingface.py +0 -0
- /lionagi/{service_ → services}/litellm.py +0 -0
- /lionagi/{service_ → services}/localai.py +0 -0
- /lionagi/{service_ → services}/mistralai.py +0 -0
- /lionagi/{service_ → services}/ollama.py +0 -0
- /lionagi/{service_ → services}/openllm.py +0 -0
- /lionagi/{service_ → services}/perplexity.py +0 -0
- /lionagi/{service_ → services}/predibase.py +0 -0
- /lionagi/{service_ → services}/rungpt.py +0 -0
- /lionagi/{service_ → services}/vllm.py +0 -0
- /lionagi/{service_ → services}/xinference.py +0 -0
- /lionagi/{endpoint → tests}/__init__.py +0 -0
- /lionagi/{endpoint/models.py → tools/planner.py} +0 -0
- /lionagi/{endpoint/moderations.py → tools/prompter.py} +0 -0
- /lionagi/{endpoint/runs.py → tools/sandbox.py} +0 -0
- /lionagi/{endpoint/threads.py → tools/summarizer.py} +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/LICENSE +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/WHEEL +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/top_level.txt +0 -0
lionagi/schema/data_logger.py ADDED
@@ -0,0 +1,94 @@
+from collections import deque
+from typing import List, Optional
+from ..utils.sys_util import create_path
+from ..utils.io_util import to_csv
+
+
+class DataLogger:
+    """
+    A class for logging data entries and exporting them as CSV files.
+
+    This class provides functionality to log data entries in a deque and
+    supports exporting the logged data to a CSV file. The DataLogger can
+    be configured to use a specific directory for saving files.
+
+    Attributes:
+        dir (Optional[str]):
+            The default directory where CSV files will be saved.
+        log (deque):
+            A deque object that stores the logged data entries.
+
+    Methods:
+        __call__:
+            Adds an entry to the log.
+        to_csv:
+            Exports the logged data to a CSV file and clears the log.
+        set_dir:
+            Sets the default directory for saving CSV files.
+    """
+
+    def __init__(self, dir= None, log: list = None) -> None:
+        """
+        Initializes the DataLogger with an optional directory and initial log.
+
+        Parameters:
+            dir (Optional[str]):
+                The directory where CSV files will be saved. Defaults to None.
+            log (Optional[List]):
+                An initial list of log entries. Defaults to an empty list.
+        """
+        self.dir = dir
+        self.log = deque(log) if log else deque()
+
+    def __call__(self, entry):
+        """
+        Adds a new entry to the log.
+
+        Parameters:
+            entry:
+                The data entry to be added to the log.
+        """
+        self.log.append(entry)
+
+    def to_csv(self, filename: str, dir: Optional[str] = None, verbose: bool = True,
+               timestamp: bool = True, dir_exist_ok: bool = True, file_exist_ok: bool = False) -> None:
+        """
+        Exports the logged data to a CSV file and optionally clears the log.
+
+        Parameters:
+            filename (str):
+                The name of the CSV file.
+            dir (Optional[str]):
+                The directory to save the file. Defaults to the instance's dir attribute.
+            verbose (bool):
+                If True, prints a message upon completion. Defaults to True.
+            timestamp (bool):
+                If True, appends a timestamp to the filename. Defaults to True.
+            dir_exist_ok (bool):
+                If True, will not raise an error if the directory already exists. Defaults to True.
+            file_exist_ok (bool):
+                If True, overwrites the file if it exists. Defaults to False.
+
+        Side Effects:
+            Clears the log after saving the CSV file.
+            Prints a message indicating the save location and number of logs saved if verbose is True.
+        """
+        dir = dir or self.dir
+        filepath = create_path(
+            dir=dir, filename=filename, timestamp=timestamp, dir_exist_ok=dir_exist_ok)
+        to_csv(list(self.log), filepath, file_exist_ok=file_exist_ok)
+        n_logs = len(list(self.log))
+        self.log = deque()
+        if verbose:
+            print(f"{n_logs} logs saved to {filepath}")
+
+    def set_dir(self, dir: str) -> None:
+        """
+        Sets the default directory for saving CSV files.

+        Parameters:
+            dir (str):
+                The directory to be set as the default for saving files.
+        """
+        self.dir = dir
+
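For orientation, a minimal usage sketch of the new `DataLogger` (the directory name and entry values here are illustrative, not taken from the release):

```python
from lionagi.schema.data_logger import DataLogger

logger = DataLogger(dir="logs/")               # default output directory (illustrative)
logger({"input": "hello", "output": "world"})  # __call__ appends an entry to the deque
logger({"input": "foo", "output": "bar"})
logger.to_csv("session.csv")                   # writes both entries, then clears the log
```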
lionagi/services/__init__.py ADDED
@@ -0,0 +1,14 @@
+from .chatcompletion import ChatCompletion
+from .base_api_service import BaseAPIService, BaseAPIRateLimiter
+from .oai import OpenAIService
+from .openrouter import OpenRouterService
+
+
+
+__all__ = [
+    "BaseAPIService",
+    "OpenAIService",
+    "OpenRouterService",
+    "ChatCompletion",
+    "BaseAPIRateLimiter"
+]
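With this `__init__.py` in place, the new service layer is importable directly from the package, for example:

```python
from lionagi.services import OpenAIService, OpenRouterService, ChatCompletion
```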
lionagi/{service_/oai.py → services/base_api_service.py} RENAMED
@@ -1,55 +1,18 @@
-import
-import dotenv
+import re
 import asyncio
-import
+import os
 import tiktoken
+import logging
 import aiohttp
-from typing import
-
-from .base_service import BaseAPIService
-
-dotenv.load_dotenv()
-
-from .base_service import BaseAPIService
-from .service_utils import StatusTracker, AsyncQueue, RateLimiter
-
+from typing import Generator, NoReturn, Dict, Any, Optional
+from .service_objs import BaseService, RateLimiter, StatusTracker, AsyncQueue
 
-
-class OpenAIRateLimiter(RateLimiter):
-    """
-    A specialized RateLimiter for managing requests to the OpenAI API.
-
-    Extends the generic RateLimiter to enforce specific rate-limiting rules and limits
-    as required by the OpenAI API. This includes maximum requests and tokens per minute
-    and replenishing these limits at regular intervals.
-
-    Attributes:
-        max_requests_per_minute (int):
-            Maximum number of requests allowed per minute.
-        max_tokens_per_minute (int):
-            Maximum number of tokens allowed per minute.
-
-    Methods:
-        rate_limit_replenisher:
-            Coroutine to replenish rate limits over time.
-        calculate_num_token:
-            Calculates the required tokens for a request.
-    """
+class BaseAPIRateLimiter(RateLimiter):
 
     def __init__(
         self, max_requests_per_minute: int, max_tokens_per_minute: int
     ) -> None:
-        """
-        Initializes the rate limiter with specific limits for OpenAI API.
-
-        Parameters:
-            max_requests_per_minute (int): The maximum number of requests allowed per minute.
-
-            max_tokens_per_minute (int): The maximum number of tokens that can accumulate per minute.
-        """
         super().__init__(max_requests_per_minute, max_tokens_per_minute)
-        if not os.getenv('env_readthedocs'):
-            self.rate_limit_replenisher_task = asyncio.create_task(self.rate_limit_replenisher())
 
     @classmethod
     async def create(
@@ -160,48 +123,53 @@ class OpenAIRateLimiter(RateLimiter):
         raise NotImplementedError(
             f'API endpoint "{api_endpoint}" not implemented in this script'
         )
-
-
-class OpenAIService(BaseAPIService):
-    """
-    Service class for interacting with the OpenAI API.
 
-    This class provides methods for calling OpenAI's API endpoints, handling the responses,
-    and managing rate limits and asynchronous tasks associated with API calls.
 
-
-
-
-
-
-
+class BaseAPIService(BaseService):
+
+    def __init__(self, api_key: str = None,
+                 status_tracker = None,
+                 queue = None, endpoint=None, schema=None,
+                 ratelimiter=None, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
+        self.api_key = api_key
+        self.status_tracker = status_tracker or StatusTracker()
+        self.queue = queue or AsyncQueue()
+        self.endpoint=endpoint
+        self.schema = schema
+        self.rate_limiter = ratelimiter(max_requests_per_minute, max_tokens_per_minute)
+
+    @staticmethod
+    def api_methods(http_session, method="post"):
+        if method not in ["post", "delete", "head", "options", "patch"]:
+            raise ValueError("Invalid request, method must be in ['post', 'delete', 'head', 'options', 'patch']")
+        elif method == "post":
+            return http_session.post
+        elif method == "delete":
+            return http_session.delete
+        elif method == "head":
+            return http_session.head
+        elif method == "options":
+            return http_session.options
+        elif method == "patch":
+            return http_session.patch
+
+    @staticmethod
+    def api_endpoint_from_url(request_url: str) -> str:
+        match = re.search(r"^https://[^/]+/v\d+/(.+)$", request_url)
+        if match:
+            return match.group(1)
+        else:
+            return ""
 
-
+    @staticmethod
+    def task_id_generator_function() -> Generator[int, None, None]:
+        task_id = 0
+        while True:
+            yield task_id
+            task_id += 1
 
-    def __init__(
-        self,
-        api_key: str = None,
-        token_encoding_name: str = "cl100k_base",
-        max_attempts: int = 3,
-        max_requests_per_minute: int = 500,
-        max_tokens_per_minute: int = 150_000,
-        ratelimiter = OpenAIRateLimiter ,
-        status_tracker = None,
-        queue = None,
-    ):
-        super().__init__(
-            api_key = api_key or os.getenv("OPENAI_API_KEY"),
-            status_tracker = status_tracker or StatusTracker(),
-            queue = queue or AsyncQueue(),
-            ratelimiter=ratelimiter,
-            max_requests_per_minute=max_requests_per_minute,
-            max_tokens_per_minute=max_tokens_per_minute),
-        self.token_encoding_name=token_encoding_name
-        self.max_attempts = max_attempts
-
-
     async def _call_api(self, http_session, endpoint_, method="post", payload: Dict[str, any] =None) -> Optional[Dict[str, any]]:
-        endpoint_ = self.api_endpoint_from_url(
+        endpoint_ = self.api_endpoint_from_url("https://api.openai.com/v1/"+endpoint_)
 
         while True:
             if self.rate_limiter.available_request_capacity < 1 or self.rate_limiter.available_token_capacity < 10:  # Minimum token count
@@ -244,7 +212,7 @@ class OpenAIService(BaseAPIService):
             else:
                 await asyncio.sleep(1)
 
-    async def
+    async def _serve(self, payload, endpoint_="chat/completions", method="post"):
 
         async def call_api():
             async with aiohttp.ClientSession() as http_session:
@@ -256,5 +224,4 @@ class OpenAIService(BaseAPIService):
         except Exception as e:
             self.status_tracker.num_tasks_failed += 1
             raise e
-
 
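The two static helpers on the new `BaseAPIService` are self-contained, so their behavior as defined in this diff is easy to sanity-check; a small sketch:

```python
from lionagi.services.base_api_service import BaseAPIService

# api_endpoint_from_url strips the scheme, host, and version prefix from a request URL.
assert BaseAPIService.api_endpoint_from_url(
    "https://api.openai.com/v1/chat/completions") == "chat/completions"
assert BaseAPIService.api_endpoint_from_url("not a url") == ""

# task_id_generator_function yields monotonically increasing task ids.
gen = BaseAPIService.task_id_generator_function()
assert [next(gen) for _ in range(3)] == [0, 1, 2]
```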
lionagi/{endpoint/base_endpoint.py → services/chatcompletion.py} RENAMED
@@ -1,5 +1,6 @@
 import abc
 
+
 class BaseEndpoint(abc.ABC):
     endpoint: str = abc.abstractproperty()
 
@@ -16,17 +17,6 @@ class BaseEndpoint(abc.ABC):
     """
     pass
 
-    # @abc.abstractmethod
-    # async def call_api(self, session, **kwargs):
-    #     """
-    #     Make a call to the API endpoint and process the response.
-
-    #     Parameters:
-    #         session: The aiohttp client session.
-    #         **kwargs: Additional keyword arguments for configuration.
-    #     """
-    #     pass
-
     @abc.abstractmethod
     def process_response(self, response):
         """
@@ -38,14 +28,21 @@ class BaseEndpoint(abc.ABC):
     pass
 
 
-
-
-
-
-
-
-
-
-
-
+class ChatCompletion(BaseEndpoint):
+    endpoint: str = "chat/completion"
+
+    @classmethod
+    def create_payload(scls, messages, llmconfig, schema, **kwargs):
+        config = {**llmconfig, **kwargs}
+        payload = {"messages": messages}
+        for key in schema['required']:
+            payload.update({key: config[key]})
+
+        for key in schema['optional']:
+            if bool(config[key]) is True and str(config[key]).lower() != "none":
+                payload.update({key: config[key]})
+        return payload
+
+    def process_response(self, session, payload, completion):
+        ...
+
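`ChatCompletion.create_payload` merges `llmconfig` with keyword overrides, copies every `required` key into the payload, and keeps an `optional` key only when its value is truthy and not a "none"-like string. A sketch with a hypothetical schema and config (these values are illustrative, not from the release):

```python
schema = {"required": ["model"], "optional": ["temperature", "stop"]}
llmconfig = {"model": "gpt-4", "temperature": 0.7, "stop": None}
messages = [{"role": "user", "content": "hi"}]

payload = ChatCompletion.create_payload(messages, llmconfig, schema)
# -> {"messages": [...], "model": "gpt-4", "temperature": 0.7}
# "stop" is dropped because bool(None) is False.
```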
lionagi/services/oai.py ADDED
@@ -0,0 +1,34 @@
+from os import getenv
+import dotenv
+from .base_api_service import BaseAPIService, BaseAPIRateLimiter
+
+dotenv.load_dotenv()
+
+class OpenAIService(BaseAPIService):
+
+    base_url = "https://api.openai.com/v1/"
+
+    def __init__(
+        self,
+        api_key: str = None,
+        token_encoding_name: str = "cl100k_base",
+        max_attempts: int = 3,
+        max_requests_per_minute: int = 500,
+        max_tokens_per_minute: int = 150_000,
+        ratelimiter = BaseAPIRateLimiter ,
+        status_tracker = None,
+        queue = None,
+    ):
+        super().__init__(
+            api_key = api_key or getenv("OPENAI_API_KEY"),
+            status_tracker = status_tracker,
+            queue = queue,
+            ratelimiter=ratelimiter,
+            max_requests_per_minute=max_requests_per_minute,
+            max_tokens_per_minute=max_tokens_per_minute),
+        self.token_encoding_name=token_encoding_name
+        self.max_attempts = max_attempts
+
+    async def serve(self, payload, endpoint_="chat/completions", method="post"):
+        return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
+
lionagi/services/openrouter.py ADDED
@@ -0,0 +1,32 @@
+from os import getenv
+from .base_api_service import BaseAPIService, BaseAPIRateLimiter
+
+class OpenRouterService(BaseAPIService):
+    _key_scheme = "OPENROUTER_API_KEY"
+
+    base_url = "https://openrouter.ai/api/v1/"
+
+    def __init__(
+        self,
+        api_key: str = None,
+        token_encoding_name: str = "cl100k_base",
+        max_attempts: int = 3,
+        max_requests_per_minute: int = 500,
+        max_tokens_per_minute: int = 150_000,
+        ratelimiter = BaseAPIRateLimiter ,
+        status_tracker = None,
+        queue = None,
+    ):
+        super().__init__(
+            api_key = api_key or getenv(self._key_scheme),
+            status_tracker = status_tracker,
+            queue = queue,
+            ratelimiter=ratelimiter,
+            max_requests_per_minute=max_requests_per_minute,
+            max_tokens_per_minute=max_tokens_per_minute),
+        self.token_encoding_name=token_encoding_name
+        self.max_attempts = max_attempts
+
+    async def serve(self, payload, endpoint_="chat/completions"):
+        return await self._serve(payload=payload, endpoint_=endpoint_)
+
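Both new service classes follow the same construction pattern: resolve the API key from the environment, attach the shared `BaseAPIRateLimiter`, and expose an async `serve` that delegates to `BaseAPIService._serve`. A hedged usage sketch (assumes `OPENAI_API_KEY` is set; the model name is illustrative):

```python
import asyncio
from lionagi.services import OpenAIService

async def main():
    service = OpenAIService()  # api_key falls back to getenv("OPENAI_API_KEY")
    payload = {
        "model": "gpt-3.5-turbo",  # illustrative model name
        "messages": [{"role": "user", "content": "hi"}],
    }
    completion = await service.serve(payload)  # posts to chat/completions by default
    print(completion)

asyncio.run(main())
```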
lionagi/structure/relationship.py ADDED
@@ -0,0 +1,128 @@
+from pydantic import Field
+from typing import Dict, Optional, Any
+from ..schema.base_schema import BaseNode
+
+
+class Relationship(BaseNode):
+    """
+    Relationship class represents a relationship between two nodes in a graph.
+
+    Inherits from BaseNode and adds functionality to manage conditions and relationships
+    between source and target nodes.
+
+    Attributes:
+        source_node_id (str): The identifier of the source node.
+        target_node_id (str): The identifier of the target node.
+        condition (Dict[str, Any]): A dictionary representing conditions for the relationship.
+    """
+
+    source_node_id: str
+    target_node_id: str
+    condition: dict = Field(default={})
+
+    def add_condition(self, condition: Dict[str, Any]) -> None:
+        """
+        Adds a condition to the relationship.
+
+        Parameters:
+            condition (Dict[str, Any]): The condition to be added.
+        """
+        self.condition.update(condition)
+
+    def remove_condition(self, condition_key: str) -> Any:
+        """
+        Removes a condition from the relationship.
+
+        Parameters:
+            condition_key (str): The key of the condition to be removed.
+
+        Returns:
+            Any: The value of the removed condition.
+
+        Raises:
+            KeyError: If the condition key is not found.
+        """
+        if condition_key not in self.condition.keys():
+            raise KeyError(f'condition {condition_key} is not found')
+        return self.condition.pop(condition_key)
+
+    def condition_exists(self, condition_key: str) -> bool:
+        """
+        Checks if a condition exists in the relationship.
+
+        Parameters:
+            condition_key (str): The key of the condition to check.
+
+        Returns:
+            bool: True if the condition exists, False otherwise.
+        """
+        if condition_key in self.condition.keys():
+            return True
+        else:
+            return False
+
+    def get_condition(self, condition_key: Optional[str] = None) -> Any:
+        """
+        Retrieves a specific condition or all conditions of the relationship.
+
+        Parameters:
+            condition_key (Optional[str]): The key of the specific condition. Defaults to None.
+
+        Returns:
+            Any: The requested condition or all conditions if no key is provided.
+
+        Raises:
+            ValueError: If the specified condition key does not exist.
+        """
+        if condition_key is None:
+            return self.condition
+        if self.condition_exists(condition_key=condition_key):
+            return self.condition[condition_key]
+        else:
+            raise ValueError(f"Condition {condition_key} does not exist")
+
+    def _source_existed(self, obj: Dict[str, Any]) -> bool:
+        """
+        Checks if the source node exists in a given object.
+
+        Parameters:
+            obj (Dict[str, Any]): The object to check.
+
+        Returns:
+            bool: True if the source node exists, False otherwise.
+        """
+        return self.source_node_id in obj.keys()
+
+    def _target_existed(self, obj: Dict[str, Any]) -> bool:
+        """
+        Checks if the target node exists in a given object.
+
+        Parameters:
+            obj (Dict[str, Any]): The object to check.
+
+        Returns:
+            bool: True if the target node exists, False otherwise.
+        """
+        return self.target_node_id in obj.keys()
+
+    def _is_in(self, obj: Dict[str, Any]) -> bool:
+        """
+        Validates the existence of both source and target nodes in a given object.
+
+        Parameters:
+            obj (Dict[str, Any]): The object to check.
+
+        Returns:
+            bool: True if both nodes exist.
+
+        Raises:
+            ValueError: If either the source or target node does not exist.
+        """
+        if self._source_existed(obj) and self._target_existed(obj):
+            return True
+
+        elif self._source_existed(obj):
+            raise ValueError(f"Target node {self.source_node_id} does not exist")
+        else :
+            raise ValueError(f"Source node {self.target_node_id} does not exist")
+
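A short usage sketch of `Relationship`'s condition API (the node ids are hypothetical, and this assumes `BaseNode` supplies defaults for its own pydantic fields so the constructor call below is valid):

```python
from lionagi.structure.relationship import Relationship

rel = Relationship(source_node_id="node_a", target_node_id="node_b")
rel.add_condition({"weight": 0.8})

rel.condition_exists("weight")   # True
rel.get_condition("weight")      # 0.8
rel.remove_condition("weight")   # 0.8, and the key is popped
rel.get_condition()              # {} -- with no key, the whole condition dict is returned
```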