lionagi 0.0.112__py3-none-any.whl → 0.0.113__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- lionagi/__init__.py +3 -3
- lionagi/bridge/__init__.py +7 -0
- lionagi/bridge/langchain.py +131 -0
- lionagi/bridge/llama_index.py +157 -0
- lionagi/configs/__init__.py +7 -0
- lionagi/configs/oai_configs.py +49 -0
- lionagi/configs/openrouter_config.py +49 -0
- lionagi/core/__init__.py +8 -2
- lionagi/core/instruction_sets.py +1 -3
- lionagi/core/messages.py +2 -2
- lionagi/core/sessions.py +174 -27
- lionagi/datastore/__init__.py +1 -0
- lionagi/loader/__init__.py +9 -4
- lionagi/loader/chunker.py +157 -0
- lionagi/loader/reader.py +124 -0
- lionagi/objs/__init__.py +7 -0
- lionagi/objs/messenger.py +163 -0
- lionagi/objs/tool_registry.py +247 -0
- lionagi/schema/__init__.py +11 -0
- lionagi/schema/base_schema.py +239 -0
- lionagi/schema/base_tool.py +9 -0
- lionagi/schema/data_logger.py +94 -0
- lionagi/services/__init__.py +14 -0
- lionagi/{service_/oai.py → services/base_api_service.py} +49 -82
- lionagi/{endpoint/base_endpoint.py → services/chatcompletion.py} +19 -22
- lionagi/services/oai.py +34 -0
- lionagi/services/openrouter.py +32 -0
- lionagi/{service_/service_utils.py → services/service_objs.py} +0 -1
- lionagi/structure/__init__.py +7 -0
- lionagi/structure/relationship.py +128 -0
- lionagi/structure/structure.py +160 -0
- lionagi/tests/test_flatten_util.py +426 -0
- lionagi/tools/__init__.py +0 -5
- lionagi/tools/coder.py +1 -0
- lionagi/tools/scorer.py +1 -0
- lionagi/tools/validator.py +1 -0
- lionagi/utils/__init__.py +46 -20
- lionagi/utils/api_util.py +86 -0
- lionagi/utils/call_util.py +347 -0
- lionagi/utils/flat_util.py +540 -0
- lionagi/utils/io_util.py +102 -0
- lionagi/utils/load_utils.py +190 -0
- lionagi/utils/sys_util.py +191 -0
- lionagi/utils/tool_util.py +92 -0
- lionagi/utils/type_util.py +81 -0
- lionagi/version.py +1 -1
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/METADATA +37 -13
- lionagi-0.0.113.dist-info/RECORD +84 -0
- lionagi/endpoint/chat_completion.py +0 -20
- lionagi/endpoint/endpoint_utils.py +0 -0
- lionagi/llm_configs.py +0 -21
- lionagi/loader/load_utils.py +0 -161
- lionagi/schema.py +0 -275
- lionagi/service_/__init__.py +0 -6
- lionagi/service_/base_service.py +0 -48
- lionagi/service_/openrouter.py +0 -1
- lionagi/services.py +0 -1
- lionagi/tools/tool_utils.py +0 -75
- lionagi/utils/sys_utils.py +0 -799
- lionagi-0.0.112.dist-info/RECORD +0 -67
- /lionagi/{core/responses.py → datastore/chroma.py} +0 -0
- /lionagi/{endpoint/assistants.py → datastore/deeplake.py} +0 -0
- /lionagi/{endpoint/audio.py → datastore/elasticsearch.py} +0 -0
- /lionagi/{endpoint/embeddings.py → datastore/lantern.py} +0 -0
- /lionagi/{endpoint/files.py → datastore/pinecone.py} +0 -0
- /lionagi/{endpoint/fine_tuning.py → datastore/postgres.py} +0 -0
- /lionagi/{endpoint/images.py → datastore/qdrant.py} +0 -0
- /lionagi/{endpoint/messages.py → schema/base_condition.py} +0 -0
- /lionagi/{service_ → services}/anthropic.py +0 -0
- /lionagi/{service_ → services}/anyscale.py +0 -0
- /lionagi/{service_ → services}/azure.py +0 -0
- /lionagi/{service_ → services}/bedrock.py +0 -0
- /lionagi/{service_ → services}/everlyai.py +0 -0
- /lionagi/{service_ → services}/gemini.py +0 -0
- /lionagi/{service_ → services}/gpt4all.py +0 -0
- /lionagi/{service_ → services}/huggingface.py +0 -0
- /lionagi/{service_ → services}/litellm.py +0 -0
- /lionagi/{service_ → services}/localai.py +0 -0
- /lionagi/{service_ → services}/mistralai.py +0 -0
- /lionagi/{service_ → services}/ollama.py +0 -0
- /lionagi/{service_ → services}/openllm.py +0 -0
- /lionagi/{service_ → services}/perplexity.py +0 -0
- /lionagi/{service_ → services}/predibase.py +0 -0
- /lionagi/{service_ → services}/rungpt.py +0 -0
- /lionagi/{service_ → services}/vllm.py +0 -0
- /lionagi/{service_ → services}/xinference.py +0 -0
- /lionagi/{endpoint → tests}/__init__.py +0 -0
- /lionagi/{endpoint/models.py → tools/planner.py} +0 -0
- /lionagi/{endpoint/moderations.py → tools/prompter.py} +0 -0
- /lionagi/{endpoint/runs.py → tools/sandbox.py} +0 -0
- /lionagi/{endpoint/threads.py → tools/summarizer.py} +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/LICENSE +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/WHEEL +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,94 @@
|
|
1
|
+
from collections import deque
|
2
|
+
from typing import List, Optional
|
3
|
+
from ..utils.sys_util import create_path
|
4
|
+
from ..utils.io_util import to_csv
|
5
|
+
|
6
|
+
|
7
|
+
class DataLogger:
    """
    A class for logging data entries and exporting them as CSV files.

    Entries are accumulated in a deque. ``to_csv`` writes every accumulated
    entry to a CSV file and then starts a fresh, empty log. A default output
    directory can be configured at construction time or later via ``set_dir``.

    Attributes:
        dir (Optional[str]):
            The default directory where CSV files will be saved.
        log (deque):
            A deque object that stores the logged data entries.

    Methods:
        __call__:
            Adds an entry to the log.
        to_csv:
            Exports the logged data to a CSV file and clears the log.
        set_dir:
            Sets the default directory for saving CSV files.
    """

    def __init__(self, dir: Optional[str] = None, log: Optional[List] = None) -> None:
        """
        Initializes the DataLogger with an optional directory and initial log.

        Parameters:
            dir (Optional[str]):
                The directory where CSV files will be saved. Defaults to None.
            log (Optional[List]):
                An initial list of log entries. Defaults to an empty log. The
                entries are copied into a new deque, so the caller's list is
                not mutated by later logging.
        """
        self.dir = dir
        self.log = deque(log) if log else deque()

    def __call__(self, entry) -> None:
        """
        Adds a new entry to the end of the log.

        Parameters:
            entry:
                The data entry to be added to the log.
        """
        self.log.append(entry)

    def to_csv(self, filename: str, dir: Optional[str] = None, verbose: bool = True,
               timestamp: bool = True, dir_exist_ok: bool = True,
               file_exist_ok: bool = False) -> None:
        """
        Exports the logged data to a CSV file, then clears the log.

        Parameters:
            filename (str):
                The name of the CSV file.
            dir (Optional[str]):
                The directory to save the file. Falls back to the instance's
                ``dir`` attribute when not given (or falsy).
            verbose (bool):
                If True, prints a message upon completion. Defaults to True.
            timestamp (bool):
                Forwarded to ``create_path``; intended to append a timestamp
                to the filename. Defaults to True.
            dir_exist_ok (bool):
                Forwarded to ``create_path``. Defaults to True.
            file_exist_ok (bool):
                Forwarded to the CSV writer. Defaults to False.

        Side Effects:
            After the CSV writer returns, ``self.log`` is replaced with a new
            empty deque. If the export call raises, the exception propagates
            and the log is left untouched.
            Prints the number of saved entries and the file path if verbose.
        """
        dir = dir or self.dir
        filepath = create_path(
            dir=dir, filename=filename, timestamp=timestamp, dir_exist_ok=dir_exist_ok)
        # Snapshot the log once; the same list is written and counted.
        entries = list(self.log)
        # `to_csv` here resolves to the module-level writer imported from
        # io_util, not this method (class scope is not enclosing).
        to_csv(entries, filepath, file_exist_ok=file_exist_ok)
        self.log = deque()
        if verbose:
            print(f"{len(entries)} logs saved to {filepath}")

    def set_dir(self, dir: str) -> None:
        """
        Sets the default directory for saving CSV files.

        Parameters:
            dir (str):
                The directory to be set as the default for saving files.
        """
        self.dir = dir
|
94
|
+
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from .chatcompletion import ChatCompletion
|
2
|
+
from .base_api_service import BaseAPIService, BaseAPIRateLimiter
|
3
|
+
from .oai import OpenAIService
|
4
|
+
from .openrouter import OpenRouterService
|
5
|
+
|
6
|
+
|
7
|
+
|
8
|
+
# Public names re-exported by the services package.
__all__ = [
    "BaseAPIService",
    "OpenAIService",
    "OpenRouterService",
    "ChatCompletion",
    "BaseAPIRateLimiter",
]
|
@@ -1,55 +1,18 @@
|
|
1
|
-
import
|
2
|
-
import dotenv
|
1
|
+
import re
|
3
2
|
import asyncio
|
4
|
-
import
|
3
|
+
import os
|
5
4
|
import tiktoken
|
5
|
+
import logging
|
6
6
|
import aiohttp
|
7
|
-
from typing import
|
8
|
-
|
9
|
-
from .base_service import BaseAPIService
|
10
|
-
|
11
|
-
dotenv.load_dotenv()
|
12
|
-
|
13
|
-
from .base_service import BaseAPIService
|
14
|
-
from .service_utils import StatusTracker, AsyncQueue, RateLimiter
|
15
|
-
|
7
|
+
from typing import Generator, NoReturn, Dict, Any, Optional
|
8
|
+
from .service_objs import BaseService, RateLimiter, StatusTracker, AsyncQueue
|
16
9
|
|
17
|
-
|
18
|
-
class OpenAIRateLimiter(RateLimiter):
|
19
|
-
"""
|
20
|
-
A specialized RateLimiter for managing requests to the OpenAI API.
|
21
|
-
|
22
|
-
Extends the generic RateLimiter to enforce specific rate-limiting rules and limits
|
23
|
-
as required by the OpenAI API. This includes maximum requests and tokens per minute
|
24
|
-
and replenishing these limits at regular intervals.
|
25
|
-
|
26
|
-
Attributes:
|
27
|
-
max_requests_per_minute (int):
|
28
|
-
Maximum number of requests allowed per minute.
|
29
|
-
max_tokens_per_minute (int):
|
30
|
-
Maximum number of tokens allowed per minute.
|
31
|
-
|
32
|
-
Methods:
|
33
|
-
rate_limit_replenisher:
|
34
|
-
Coroutine to replenish rate limits over time.
|
35
|
-
calculate_num_token:
|
36
|
-
Calculates the required tokens for a request.
|
37
|
-
"""
|
10
|
+
class BaseAPIRateLimiter(RateLimiter):
|
38
11
|
|
39
12
|
def __init__(
|
40
13
|
self, max_requests_per_minute: int, max_tokens_per_minute: int
|
41
14
|
) -> None:
|
42
|
-
"""
|
43
|
-
Initializes the rate limiter with specific limits for OpenAI API.
|
44
|
-
|
45
|
-
Parameters:
|
46
|
-
max_requests_per_minute (int): The maximum number of requests allowed per minute.
|
47
|
-
|
48
|
-
max_tokens_per_minute (int): The maximum number of tokens that can accumulate per minute.
|
49
|
-
"""
|
50
15
|
super().__init__(max_requests_per_minute, max_tokens_per_minute)
|
51
|
-
if not os.getenv('env_readthedocs'):
|
52
|
-
self.rate_limit_replenisher_task = asyncio.create_task(self.rate_limit_replenisher())
|
53
16
|
|
54
17
|
@classmethod
|
55
18
|
async def create(
|
@@ -160,48 +123,53 @@ class OpenAIRateLimiter(RateLimiter):
|
|
160
123
|
raise NotImplementedError(
|
161
124
|
f'API endpoint "{api_endpoint}" not implemented in this script'
|
162
125
|
)
|
163
|
-
|
164
|
-
|
165
|
-
class OpenAIService(BaseAPIService):
|
166
|
-
"""
|
167
|
-
Service class for interacting with the OpenAI API.
|
168
126
|
|
169
|
-
This class provides methods for calling OpenAI's API endpoints, handling the responses,
|
170
|
-
and managing rate limits and asynchronous tasks associated with API calls.
|
171
127
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
128
|
+
class BaseAPIService(BaseService):
|
129
|
+
|
130
|
+
def __init__(self, api_key: str = None,
             status_tracker = None,
             queue = None, endpoint=None, schema=None,
             ratelimiter=None, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
    """
    Initialize shared state for an API-backed service.

    Parameters:
        api_key (str): Credential stored on ``self.api_key``.
        status_tracker: Tracker instance; a new ``StatusTracker()`` is
            created when this is falsy.
        queue: Async queue instance; a new ``AsyncQueue()`` is created when
            this is falsy.
        endpoint: Stored unchanged as ``self.endpoint``.
        schema: Stored unchanged as ``self.schema``.
        ratelimiter: A callable (rate-limiter class) invoked with the two
            limits below; its result is stored as ``self.rate_limiter``.
        max_requests_per_minute: First positional argument to ``ratelimiter``.
        max_tokens_per_minute: Second positional argument to ``ratelimiter``.

    NOTE(review): ``ratelimiter`` defaults to None but is called
    unconditionally, so constructing without it raises TypeError. The
    subclasses in this package (OpenAIService, OpenRouterService) always pass
    ``BaseAPIRateLimiter``. Confirm whether a default should be supplied here.
    """
    self.api_key = api_key
    self.status_tracker = status_tracker or StatusTracker()
    self.queue = queue or AsyncQueue()
    self.endpoint=endpoint
    self.schema = schema
    self.rate_limiter = ratelimiter(max_requests_per_minute, max_tokens_per_minute)
|
140
|
+
|
141
|
+
@staticmethod
|
142
|
+
def api_methods(http_session, method="post"):
    """
    Look up the HTTP-verb method on ``http_session`` named by ``method``.

    Only "post", "delete", "head", "options" and "patch" are accepted.

    Parameters:
        http_session: Object exposing one attribute per accepted verb
            (for example an aiohttp ClientSession).
        method (str): Verb name. Defaults to "post".

    Returns:
        The attribute of ``http_session`` with the given verb name.

    Raises:
        ValueError: If ``method`` is not an accepted verb.
    """
    supported_verbs = ("post", "delete", "head", "options", "patch")
    if method in supported_verbs:
        return getattr(http_session, method)
    raise ValueError("Invalid request, method must be in ['post', 'delete', 'head', 'options', 'patch']")
|
155
|
+
|
156
|
+
@staticmethod
|
157
|
+
def api_endpoint_from_url(request_url: str) -> str:
    """
    Extract the endpoint path that follows the API version segment of a URL.

    The version segment (``v1``, ``v2``, ...) may sit directly after the host,
    as in ``https://api.openai.com/v1/chat/completions``, or after a path
    prefix, as in ``https://openrouter.ai/api/v1/chat/completions``. Both
    return ``"chat/completions"``.

    Parameters:
        request_url (str): An ``https://`` URL.

    Returns:
        str: The path after the first version segment, or ``""`` if the URL
        does not match (including when nothing follows the version segment).
    """
    # The lazy, optional prefix group lets APIs mounted under a sub-path
    # (e.g. "/api/v1/") be parsed. For URLs whose version segment follows the
    # host directly it matches empty, so previous results are unchanged.
    match = re.search(r"^https://[^/]+/(?:[^/]+/)*?v\d+/(.+)$", request_url)
    return match.group(1) if match else ""
|
178
163
|
|
179
|
-
|
164
|
+
@staticmethod
|
165
|
+
def task_id_generator_function() -> Generator[int, None, None]:
    """
    Endlessly yield consecutive integer task ids, starting from 0.

    Each call returns an independent generator with its own counter.
    """
    last_issued = -1
    while True:
        last_issued += 1
        yield last_issued
|
180
170
|
|
181
|
-
def __init__(
|
182
|
-
self,
|
183
|
-
api_key: str = None,
|
184
|
-
token_encoding_name: str = "cl100k_base",
|
185
|
-
max_attempts: int = 3,
|
186
|
-
max_requests_per_minute: int = 500,
|
187
|
-
max_tokens_per_minute: int = 150_000,
|
188
|
-
ratelimiter = OpenAIRateLimiter ,
|
189
|
-
status_tracker = None,
|
190
|
-
queue = None,
|
191
|
-
):
|
192
|
-
super().__init__(
|
193
|
-
api_key = api_key or os.getenv("OPENAI_API_KEY"),
|
194
|
-
status_tracker = status_tracker or StatusTracker(),
|
195
|
-
queue = queue or AsyncQueue(),
|
196
|
-
ratelimiter=ratelimiter,
|
197
|
-
max_requests_per_minute=max_requests_per_minute,
|
198
|
-
max_tokens_per_minute=max_tokens_per_minute),
|
199
|
-
self.token_encoding_name=token_encoding_name
|
200
|
-
self.max_attempts = max_attempts
|
201
|
-
|
202
|
-
|
203
171
|
async def _call_api(self, http_session, endpoint_, method="post", payload: Dict[str, any] =None) -> Optional[Dict[str, any]]:
|
204
|
-
endpoint_ = self.api_endpoint_from_url(
|
172
|
+
endpoint_ = self.api_endpoint_from_url("https://api.openai.com/v1/"+endpoint_)
|
205
173
|
|
206
174
|
while True:
|
207
175
|
if self.rate_limiter.available_request_capacity < 1 or self.rate_limiter.available_token_capacity < 10: # Minimum token count
|
@@ -244,7 +212,7 @@ class OpenAIService(BaseAPIService):
|
|
244
212
|
else:
|
245
213
|
await asyncio.sleep(1)
|
246
214
|
|
247
|
-
async def
|
215
|
+
async def _serve(self, payload, endpoint_="chat/completions", method="post"):
|
248
216
|
|
249
217
|
async def call_api():
|
250
218
|
async with aiohttp.ClientSession() as http_session:
|
@@ -256,5 +224,4 @@ class OpenAIService(BaseAPIService):
|
|
256
224
|
except Exception as e:
|
257
225
|
self.status_tracker.num_tasks_failed += 1
|
258
226
|
raise e
|
259
|
-
|
260
227
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
import abc
|
2
2
|
|
3
|
+
|
3
4
|
class BaseEndpoint(abc.ABC):
|
4
5
|
endpoint: str = abc.abstractproperty()
|
5
6
|
|
@@ -16,17 +17,6 @@ class BaseEndpoint(abc.ABC):
|
|
16
17
|
"""
|
17
18
|
pass
|
18
19
|
|
19
|
-
# @abc.abstractmethod
|
20
|
-
# async def call_api(self, session, **kwargs):
|
21
|
-
# """
|
22
|
-
# Make a call to the API endpoint and process the response.
|
23
|
-
|
24
|
-
# Parameters:
|
25
|
-
# session: The aiohttp client session.
|
26
|
-
# **kwargs: Additional keyword arguments for configuration.
|
27
|
-
# """
|
28
|
-
# pass
|
29
|
-
|
30
20
|
@abc.abstractmethod
|
31
21
|
def process_response(self, response):
|
32
22
|
"""
|
@@ -38,14 +28,21 @@ class BaseEndpoint(abc.ABC):
|
|
38
28
|
pass
|
39
29
|
|
40
30
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
31
|
+
class ChatCompletion(BaseEndpoint):
    """
    Endpoint descriptor for chat-completion requests.

    Attributes:
        endpoint (str): Path of the chat-completions endpoint, relative to an
            API base URL. It matches the ``"chat/completions"`` default used
            by the service classes' ``serve`` methods.
    """

    endpoint: str = "chat/completions"

    @classmethod
    def create_payload(cls, messages, llmconfig, schema, **kwargs):
        """
        Build a request payload from messages and model configuration.

        Parameters:
            messages: Chat messages, placed under the ``"messages"`` key.
            llmconfig (dict): Base configuration values.
            schema (dict): Mapping with ``"required"`` and ``"optional"``
                iterables of configuration keys.
            **kwargs: Overrides merged on top of ``llmconfig``.

        Returns:
            dict: ``{"messages": messages}`` plus every required key, plus each
            optional key whose value is truthy and does not stringify to "none"
            (case-insensitive).

        Raises:
            KeyError: If a required key is absent from the merged config.
        """
        config = {**llmconfig, **kwargs}
        payload = {"messages": messages}
        for key in schema['required']:
            payload[key] = config[key]

        for key in schema['optional']:
            # Optional keys may be absent from the config entirely; treat that
            # the same as an unset (None) value instead of raising KeyError.
            value = config.get(key)
            if value and str(value).lower() != "none":
                payload[key] = value
        return payload

    def process_response(self, session, payload, completion):
        """
        Placeholder for response handling; currently not implemented and
        returns None.
        """
        ...
|
48
|
+
|
lionagi/services/oai.py
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
from os import getenv
|
2
|
+
import dotenv
|
3
|
+
from .base_api_service import BaseAPIService, BaseAPIRateLimiter
|
4
|
+
|
5
|
+
dotenv.load_dotenv()
|
6
|
+
|
7
|
+
class OpenAIService(BaseAPIService):
    """
    Service for calling the OpenAI API.

    The API key falls back to the ``OPENAI_API_KEY`` environment variable when
    not passed explicitly.

    NOTE(review): ``base_url`` is declared here, but the base class's request
    path builds URLs from a hard-coded OpenAI prefix; confirm whether
    ``base_url`` is meant to be consulted there.
    """

    base_url = "https://api.openai.com/v1/"

    def __init__(
        self,
        api_key: str = None,
        token_encoding_name: str = "cl100k_base",
        max_attempts: int = 3,
        max_requests_per_minute: int = 500,
        max_tokens_per_minute: int = 150_000,
        ratelimiter = BaseAPIRateLimiter,
        status_tracker = None,
        queue = None,
    ):
        """
        Parameters:
            api_key (str): OpenAI API key; defaults to ``$OPENAI_API_KEY``.
            token_encoding_name (str): Name of the tokenizer encoding, stored
                on the instance.
            max_attempts (int): Stored on the instance as ``max_attempts``.
            max_requests_per_minute (int): Passed to the rate limiter.
            max_tokens_per_minute (int): Passed to the rate limiter.
            ratelimiter: Rate-limiter class instantiated by the base class.
            status_tracker: Optional tracker; the base class creates one if None.
            queue: Optional async queue; the base class creates one if None.
        """
        # No trailing comma after this call: it previously wrapped the call in
        # a discarded one-element tuple expression.
        super().__init__(
            api_key=api_key or getenv("OPENAI_API_KEY"),
            status_tracker=status_tracker,
            queue=queue,
            ratelimiter=ratelimiter,
            max_requests_per_minute=max_requests_per_minute,
            max_tokens_per_minute=max_tokens_per_minute,
        )
        self.token_encoding_name = token_encoding_name
        self.max_attempts = max_attempts

    async def serve(self, payload, endpoint_="chat/completions", method="post"):
        """Send ``payload`` to ``endpoint_`` using HTTP ``method`` via the base class."""
        return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
|
34
|
+
|
@@ -0,0 +1,32 @@
|
|
1
|
+
from os import getenv
|
2
|
+
from .base_api_service import BaseAPIService, BaseAPIRateLimiter
|
3
|
+
|
4
|
+
class OpenRouterService(BaseAPIService):
    """
    Service for calling the OpenRouter API.

    The API key falls back to the environment variable named by
    ``_key_scheme`` (``OPENROUTER_API_KEY``) when not passed explicitly.
    """

    # Name of the environment variable holding the API key.
    _key_scheme = "OPENROUTER_API_KEY"

    base_url = "https://openrouter.ai/api/v1/"

    def __init__(
        self,
        api_key: str = None,
        token_encoding_name: str = "cl100k_base",
        max_attempts: int = 3,
        max_requests_per_minute: int = 500,
        max_tokens_per_minute: int = 150_000,
        ratelimiter = BaseAPIRateLimiter,
        status_tracker = None,
        queue = None,
    ):
        """
        Parameters:
            api_key (str): OpenRouter API key; defaults to ``$OPENROUTER_API_KEY``.
            token_encoding_name (str): Name of the tokenizer encoding, stored
                on the instance.
            max_attempts (int): Stored on the instance as ``max_attempts``.
            max_requests_per_minute (int): Passed to the rate limiter.
            max_tokens_per_minute (int): Passed to the rate limiter.
            ratelimiter: Rate-limiter class instantiated by the base class.
            status_tracker: Optional tracker; the base class creates one if None.
            queue: Optional async queue; the base class creates one if None.
        """
        # No trailing comma after this call: it previously wrapped the call in
        # a discarded one-element tuple expression.
        super().__init__(
            api_key=api_key or getenv(self._key_scheme),
            status_tracker=status_tracker,
            queue=queue,
            ratelimiter=ratelimiter,
            max_requests_per_minute=max_requests_per_minute,
            max_tokens_per_minute=max_tokens_per_minute,
        )
        self.token_encoding_name = token_encoding_name
        self.max_attempts = max_attempts

    async def serve(self, payload, endpoint_="chat/completions", method="post"):
        """
        Send ``payload`` to ``endpoint_`` via the base class.

        ``method`` mirrors ``OpenAIService.serve``; its default ("post")
        matches the base class default, so existing callers are unaffected.
        """
        return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
|
32
|
+
|
@@ -0,0 +1,128 @@
|
|
1
|
+
from pydantic import Field
|
2
|
+
from typing import Dict, Optional, Any
|
3
|
+
from ..schema.base_schema import BaseNode
|
4
|
+
|
5
|
+
|
6
|
+
class Relationship(BaseNode):
    """
    Relationship class represents a relationship between two nodes in a graph.

    Inherits from BaseNode and adds functionality to manage conditions and relationships
    between source and target nodes.

    Attributes:
        source_node_id (str): The identifier of the source node.
        target_node_id (str): The identifier of the target node.
        condition (Dict[str, Any]): A dictionary representing conditions for the relationship.
    """

    source_node_id: str
    target_node_id: str
    # default_factory gives each instance its own dict instead of relying on
    # pydantic to copy a shared mutable default.
    condition: dict = Field(default_factory=dict)

    def add_condition(self, condition: Dict[str, Any]) -> None:
        """
        Adds (or overwrites) conditions on the relationship.

        Parameters:
            condition (Dict[str, Any]): Key/value pairs merged into ``self.condition``.
        """
        self.condition.update(condition)

    def remove_condition(self, condition_key: str) -> Any:
        """
        Removes a condition from the relationship.

        Parameters:
            condition_key (str): The key of the condition to be removed.

        Returns:
            Any: The value of the removed condition.

        Raises:
            KeyError: If the condition key is not found.
        """
        if condition_key not in self.condition:
            raise KeyError(f'condition {condition_key} is not found')
        return self.condition.pop(condition_key)

    def condition_exists(self, condition_key: str) -> bool:
        """
        Checks if a condition exists in the relationship.

        Parameters:
            condition_key (str): The key of the condition to check.

        Returns:
            bool: True if the condition exists, False otherwise.
        """
        return condition_key in self.condition

    def get_condition(self, condition_key: Optional[str] = None) -> Any:
        """
        Retrieves a specific condition or all conditions of the relationship.

        Parameters:
            condition_key (Optional[str]): The key of the specific condition. Defaults to None.

        Returns:
            Any: The requested condition value, or the whole condition dict
            (not a copy) if no key is provided.

        Raises:
            ValueError: If the specified condition key does not exist.
        """
        if condition_key is None:
            return self.condition
        if self.condition_exists(condition_key=condition_key):
            return self.condition[condition_key]
        raise ValueError(f"Condition {condition_key} does not exist")

    def _source_existed(self, obj: Dict[str, Any]) -> bool:
        """
        Checks if the source node id is a key of ``obj``.

        Parameters:
            obj (Dict[str, Any]): Mapping keyed by node id.

        Returns:
            bool: True if the source node exists, False otherwise.
        """
        return self.source_node_id in obj

    def _target_existed(self, obj: Dict[str, Any]) -> bool:
        """
        Checks if the target node id is a key of ``obj``.

        Parameters:
            obj (Dict[str, Any]): Mapping keyed by node id.

        Returns:
            bool: True if the target node exists, False otherwise.
        """
        return self.target_node_id in obj

    def _is_in(self, obj: Dict[str, Any]) -> bool:
        """
        Validates the existence of both source and target nodes in a given object.

        Parameters:
            obj (Dict[str, Any]): Mapping keyed by node id.

        Returns:
            bool: True if both nodes exist.

        Raises:
            ValueError: If the target node is missing while the source exists,
                or if the source node is missing. Each message names the id of
                the node that is actually missing.
        """
        source_present = self._source_existed(obj)
        target_present = self._target_existed(obj)
        if source_present and target_present:
            return True
        if source_present:
            raise ValueError(f"Target node {self.target_node_id} does not exist")
        raise ValueError(f"Source node {self.source_node_id} does not exist")
|
128
|
+
|