lionagi 0.0.112__py3-none-any.whl → 0.0.113__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- lionagi/__init__.py +3 -3
- lionagi/bridge/__init__.py +7 -0
- lionagi/bridge/langchain.py +131 -0
- lionagi/bridge/llama_index.py +157 -0
- lionagi/configs/__init__.py +7 -0
- lionagi/configs/oai_configs.py +49 -0
- lionagi/configs/openrouter_config.py +49 -0
- lionagi/core/__init__.py +8 -2
- lionagi/core/instruction_sets.py +1 -3
- lionagi/core/messages.py +2 -2
- lionagi/core/sessions.py +174 -27
- lionagi/datastore/__init__.py +1 -0
- lionagi/loader/__init__.py +9 -4
- lionagi/loader/chunker.py +157 -0
- lionagi/loader/reader.py +124 -0
- lionagi/objs/__init__.py +7 -0
- lionagi/objs/messenger.py +163 -0
- lionagi/objs/tool_registry.py +247 -0
- lionagi/schema/__init__.py +11 -0
- lionagi/schema/base_schema.py +239 -0
- lionagi/schema/base_tool.py +9 -0
- lionagi/schema/data_logger.py +94 -0
- lionagi/services/__init__.py +14 -0
- lionagi/{service_/oai.py → services/base_api_service.py} +49 -82
- lionagi/{endpoint/base_endpoint.py → services/chatcompletion.py} +19 -22
- lionagi/services/oai.py +34 -0
- lionagi/services/openrouter.py +32 -0
- lionagi/{service_/service_utils.py → services/service_objs.py} +0 -1
- lionagi/structure/__init__.py +7 -0
- lionagi/structure/relationship.py +128 -0
- lionagi/structure/structure.py +160 -0
- lionagi/tests/test_flatten_util.py +426 -0
- lionagi/tools/__init__.py +0 -5
- lionagi/tools/coder.py +1 -0
- lionagi/tools/scorer.py +1 -0
- lionagi/tools/validator.py +1 -0
- lionagi/utils/__init__.py +46 -20
- lionagi/utils/api_util.py +86 -0
- lionagi/utils/call_util.py +347 -0
- lionagi/utils/flat_util.py +540 -0
- lionagi/utils/io_util.py +102 -0
- lionagi/utils/load_utils.py +190 -0
- lionagi/utils/sys_util.py +191 -0
- lionagi/utils/tool_util.py +92 -0
- lionagi/utils/type_util.py +81 -0
- lionagi/version.py +1 -1
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/METADATA +37 -13
- lionagi-0.0.113.dist-info/RECORD +84 -0
- lionagi/endpoint/chat_completion.py +0 -20
- lionagi/endpoint/endpoint_utils.py +0 -0
- lionagi/llm_configs.py +0 -21
- lionagi/loader/load_utils.py +0 -161
- lionagi/schema.py +0 -275
- lionagi/service_/__init__.py +0 -6
- lionagi/service_/base_service.py +0 -48
- lionagi/service_/openrouter.py +0 -1
- lionagi/services.py +0 -1
- lionagi/tools/tool_utils.py +0 -75
- lionagi/utils/sys_utils.py +0 -799
- lionagi-0.0.112.dist-info/RECORD +0 -67
- /lionagi/{core/responses.py → datastore/chroma.py} +0 -0
- /lionagi/{endpoint/assistants.py → datastore/deeplake.py} +0 -0
- /lionagi/{endpoint/audio.py → datastore/elasticsearch.py} +0 -0
- /lionagi/{endpoint/embeddings.py → datastore/lantern.py} +0 -0
- /lionagi/{endpoint/files.py → datastore/pinecone.py} +0 -0
- /lionagi/{endpoint/fine_tuning.py → datastore/postgres.py} +0 -0
- /lionagi/{endpoint/images.py → datastore/qdrant.py} +0 -0
- /lionagi/{endpoint/messages.py → schema/base_condition.py} +0 -0
- /lionagi/{service_ → services}/anthropic.py +0 -0
- /lionagi/{service_ → services}/anyscale.py +0 -0
- /lionagi/{service_ → services}/azure.py +0 -0
- /lionagi/{service_ → services}/bedrock.py +0 -0
- /lionagi/{service_ → services}/everlyai.py +0 -0
- /lionagi/{service_ → services}/gemini.py +0 -0
- /lionagi/{service_ → services}/gpt4all.py +0 -0
- /lionagi/{service_ → services}/huggingface.py +0 -0
- /lionagi/{service_ → services}/litellm.py +0 -0
- /lionagi/{service_ → services}/localai.py +0 -0
- /lionagi/{service_ → services}/mistralai.py +0 -0
- /lionagi/{service_ → services}/ollama.py +0 -0
- /lionagi/{service_ → services}/openllm.py +0 -0
- /lionagi/{service_ → services}/perplexity.py +0 -0
- /lionagi/{service_ → services}/predibase.py +0 -0
- /lionagi/{service_ → services}/rungpt.py +0 -0
- /lionagi/{service_ → services}/vllm.py +0 -0
- /lionagi/{service_ → services}/xinference.py +0 -0
- /lionagi/{endpoint → tests}/__init__.py +0 -0
- /lionagi/{endpoint/models.py → tools/planner.py} +0 -0
- /lionagi/{endpoint/moderations.py → tools/prompter.py} +0 -0
- /lionagi/{endpoint/runs.py → tools/sandbox.py} +0 -0
- /lionagi/{endpoint/threads.py → tools/summarizer.py} +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/LICENSE +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/WHEEL +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,163 @@
|
|
1
|
+
# from typing import Optional, Any, Union, Dict, Tuple
|
2
|
+
|
3
|
+
# from lionagi.utils.call_util import lcall
|
4
|
+
# from lionagi.schema.data_logger import DataLogger
|
5
|
+
# from lionagi.core.messages import Message, Response, Instruction, System
|
6
|
+
|
7
|
+
|
8
|
+
# class Messenger:
|
9
|
+
# """
|
10
|
+
# Messenger handles the creation, logging, and exporting of messages.
|
11
|
+
|
12
|
+
# This class is responsible for creating various types of messages (system, instruction, response),
|
13
|
+
# logging them, and optionally exporting the log to a CSV file.
|
14
|
+
|
15
|
+
# Attributes:
|
16
|
+
# _logger (DataLogger): An instance of DataLogger to manage message logging.
|
17
|
+
|
18
|
+
# Methods:
|
19
|
+
# set_dir: Sets the directory for the DataLogger to save CSV files.
|
20
|
+
|
21
|
+
# set_log: Sets the log object for the DataLogger.
|
22
|
+
|
23
|
+
# log_message: Logs a message in JSON format.
|
24
|
+
|
25
|
+
# to_csv: Exports logged messages to a CSV file.
|
26
|
+
|
27
|
+
# _create_message: Internal method to create a specific type of message.
|
28
|
+
|
29
|
+
# create_message: Public interface to create messages, log them, and optionally return them in different formats.
|
30
|
+
# """
|
31
|
+
|
32
|
+
# def __init__(self) -> None:
|
33
|
+
# """
|
34
|
+
# Initializes the Messenger with a DataLogger instance.
|
35
|
+
# """
|
36
|
+
# self._logger = DataLogger()
|
37
|
+
|
38
|
+
# def set_dir(self, dir: str) -> None:
|
39
|
+
# """
|
40
|
+
# Sets the directory where the DataLogger will save CSV files.
|
41
|
+
|
42
|
+
# Parameters:
|
43
|
+
# dir (str): The directory path to set for the DataLogger.
|
44
|
+
# """
|
45
|
+
# self._logger.dir = dir
|
46
|
+
|
47
|
+
# def set_log(self, log) -> None:
|
48
|
+
# """
|
49
|
+
# Sets the log object for the DataLogger.
|
50
|
+
|
51
|
+
# Parameters:
|
52
|
+
# log: The log object to be used by the DataLogger.
|
53
|
+
# """
|
54
|
+
# self._logger.log = log
|
55
|
+
|
56
|
+
# def log_message(self, msg: Message) -> None:
|
57
|
+
# """
|
58
|
+
# Logs a message in JSON format using the DataLogger.
|
59
|
+
|
60
|
+
# Parameters:
|
61
|
+
# msg (Message): The message object to be logged.
|
62
|
+
# """
|
63
|
+
# self._logger(msg.to_json())
|
64
|
+
|
65
|
+
# def to_csv(self, **kwargs) -> None:
|
66
|
+
# """
|
67
|
+
# Exports the logged messages to a CSV file.
|
68
|
+
|
69
|
+
# Parameters:
|
70
|
+
# **kwargs: Additional keyword arguments to be passed to the DataLogger's to_csv method.
|
71
|
+
# """
|
72
|
+
# self._logger.to_csv(**kwargs)
|
73
|
+
|
74
|
+
# def _create_message(self,
|
75
|
+
# system: Optional[Any] = None,
|
76
|
+
# instruction: Optional[Any] = None,
|
77
|
+
# context: Optional[Any] = None,
|
78
|
+
# response: Optional[Any] = None,
|
79
|
+
# name: Optional[str] = None) -> Message:
|
80
|
+
# """
|
81
|
+
# Creates a specific type of message based on the provided parameters.
|
82
|
+
|
83
|
+
# Parameters:
|
84
|
+
# system (Optional[Any]): System message content.
|
85
|
+
|
86
|
+
# instruction (Optional[Any]): Instruction message content.
|
87
|
+
|
88
|
+
# context (Optional[Any]): Context for the instruction message.
|
89
|
+
|
90
|
+
# response (Optional[Any]): Response message content.
|
91
|
+
|
92
|
+
# name (Optional[str]): Name associated with the message.
|
93
|
+
|
94
|
+
# Returns:
|
95
|
+
# Message: The created message object of type Response, Instruction, or System.
|
96
|
+
|
97
|
+
# Raises:
|
98
|
+
# ValueError: If more than one or none of the message content parameters (system, instruction, response) are provided.
|
99
|
+
# """
|
100
|
+
|
101
|
+
# if sum(lcall([system, instruction, response], bool)) != 1:
|
102
|
+
# raise ValueError("Error: Message must have one and only one role.")
|
103
|
+
|
104
|
+
# else:
|
105
|
+
# msg = 0
|
106
|
+
|
107
|
+
# if response:
|
108
|
+
# msg = Response()
|
109
|
+
# msg.create_message(response=response,
|
110
|
+
# name=name,)
|
111
|
+
# elif instruction:
|
112
|
+
# msg = Instruction()
|
113
|
+
# msg.create_message(instruction=instruction,
|
114
|
+
# context=context,
|
115
|
+
# name=name,)
|
116
|
+
# elif system:
|
117
|
+
# msg = System()
|
118
|
+
# msg.create_message(system=system,
|
119
|
+
# name=name,)
|
120
|
+
# return msg
|
121
|
+
|
122
|
+
# def create_message(self,
|
123
|
+
# system: Optional[Any] = None,
|
124
|
+
# instruction: Optional[Any] = None,
|
125
|
+
# context: Optional[Any] = None,
|
126
|
+
# response: Optional[Any] = None,
|
127
|
+
# name: Optional[str] = None,
|
128
|
+
# obj: bool = False,
|
129
|
+
# log_: bool = True) -> Union[Message, Tuple[Message, Dict]]:
|
130
|
+
# """
|
131
|
+
# Creates and optionally logs a message, returning it in different formats based on parameters.
|
132
|
+
|
133
|
+
# Parameters:
|
134
|
+
# system (Optional[Any]): System message content.
|
135
|
+
|
136
|
+
# instruction (Optional[Any]): Instruction message content.
|
137
|
+
|
138
|
+
# context (Optional[Any]): Context for the instruction message.
|
139
|
+
|
140
|
+
# response (Optional[Any]): Response message content.
|
141
|
+
|
142
|
+
# name (Optional[str]): Name associated with the message.
|
143
|
+
|
144
|
+
# obj (bool): If True, returns the Message object and its dictionary representation. Defaults to False.
|
145
|
+
|
146
|
+
# log_ (bool): If True, logs the created message. Defaults to True.
|
147
|
+
|
148
|
+
# Returns:
|
149
|
+
# Union[Message, Tuple[Message, Dict]]: The created message in the specified format.
|
150
|
+
# """
|
151
|
+
|
152
|
+
# msg = self._create_message(system=system,
|
153
|
+
# instruction=instruction,
|
154
|
+
# context=context,
|
155
|
+
# response=response,
|
156
|
+
# name=name)
|
157
|
+
# if log_:
|
158
|
+
# self.log_message(msg)
|
159
|
+
# if obj:
|
160
|
+
# return (msg, msg._to_message())
|
161
|
+
# else:
|
162
|
+
# return msg._to_message()
|
163
|
+
|
@@ -0,0 +1,247 @@
|
|
1
|
+
import json
|
2
|
+
import asyncio
|
3
|
+
from typing import Dict
|
4
|
+
from ..utils import lcall, str_to_num
|
5
|
+
from ..schema import BaseNode
|
6
|
+
|
7
|
+
|
8
|
+
class ToolManager(BaseNode):
    """
    Registry of callable tools keyed by name, with async invocation.

    Each registered ``tool`` must expose a ``func`` attribute (the callable)
    and a ``parser`` attribute (an optional post-processor applied to the
    result, or a falsy value for none); both are read by the methods below.
    NOTE(review): the concrete tool type appears to be
    ``lionagi.schema.base_tool.Tool`` — confirm.

    Attributes:
        registry (Dict): Maps registered tool names to tool objects.
    """
    registry: Dict = {}

    def _name_existed(self, name: str) -> bool:
        """Return True if ``name`` is already a key in the registry."""
        return name in self.registry

    def _unique_name(self, name: str, prefix=None, postfix=None) -> str:
        """
        Derive a registry name from ``name`` that is not yet taken.

        The base candidate is ``f"{prefix}{name}{postfix}"`` (missing parts are
        treated as empty). While that candidate is still registered, an
        increasing integer (1, 2, ...) is appended to the base.
        """
        base = f"{prefix or ''}{name}{'' if postfix is None else postfix}"
        candidate, idx = base, 1
        while self._name_existed(candidate):
            candidate = f"{base}{idx}"
            idx += 1
        return candidate

    def _register_tool(self, tool, name: str = None, update=False, new=False, prefix=None, postfix=None):
        """
        Register a single tool in the registry.

        Args:
            tool: Tool object exposing ``func`` (and ``parser``).
            name (str, optional): Registry key. Defaults to ``tool.func.__name__``.
            update (bool): Allow replacing an existing entry. Existing entries
                are replaced whenever ``new`` is False, so this flag only
                matters for the conflict check with ``new``.
            new (bool): If ``name`` is already registered, register the tool
                under a fresh derived name (see ``_unique_name``) instead of
                replacing the existing entry.
            prefix (str, optional): Prepended to the derived name when ``new``.
            postfix (optional): Appended to the derived name when ``new``.

        Raises:
            ValueError: If ``name`` is already registered and both ``update``
                and ``new`` are True.
        """
        # Resolve the key before the existence check; otherwise a missing
        # name (the register_tools path) can never collide and new=True
        # would build a key out of the string "None".
        name = name or tool.func.__name__

        if self._name_existed(name):
            if update and new:
                raise ValueError(f"Cannot both update and create new registry for existing function {name} at the same time")
            if new:
                name = self._unique_name(name, prefix, postfix)

        # Without ``new``, an existing entry under ``name`` is overwritten.
        self.registry[name] = tool

    async def invoke(self, func_call):
        """
        Call a registered tool and return its (optionally parsed) result.

        Args:
            func_call: A ``(name, kwargs)`` pair; ``kwargs`` is passed to the
                tool's ``func`` as keyword arguments.

        Returns:
            The function result, passed through ``tool.parser`` when a parser
            is set. Coroutine functions and coroutine-returning parsers are
            awaited.

        Raises:
            ValueError: If ``name`` is not registered, or if the call or the
                parser raises (the original error is chained as the cause).
        """
        name, kwargs = func_call
        if not self._name_existed(name):
            raise ValueError(f"Function {name} is not registered.")

        tool = self.registry[name]
        func = tool.func
        parser = tool.parser
        try:
            # Await async tools here: the result, not a coroutine object, must
            # reach the parser / the caller.
            if asyncio.iscoroutinefunction(func):
                result = await func(**kwargs)
            else:
                result = func(**kwargs)

            if not parser:
                return result
            parsed = parser(result)
            return await parsed if asyncio.iscoroutine(parsed) else parsed
        except Exception as e:
            raise ValueError(f"Error when invoking function {name} with arguments {kwargs} with error message {e}") from e

    @staticmethod
    def _get_function_call(response):
        """
        Extract function name and arguments from a response JSON.

        Two response shapes are accepted:
          * ``{'function': ..., 'arguments': '<JSON string>'}``: the first 5
            characters of ``function`` are dropped (presumably a ``"func_"``
            prefix — confirm) and ``arguments`` is decoded with ``json.loads``.
          * ``{'recipient_name': 'a.b.name', 'parameters': {...}}``: the name is
            the last dot-separated segment; ``parameters`` is used as-is.

        Parameters:
            response (dict): The JSON response containing function information.

        Returns:
            Tuple[str, dict]: The function name and its arguments.

        Raises:
            ValueError: If ``response`` matches neither shape.
        """
        try:
            func = response['function'][5:]
            args = json.loads(response['arguments'])
            return (func, args)
        except (KeyError, TypeError, ValueError, AttributeError):
            # Not the first shape; fall through to the second one.
            pass

        try:
            func = response['recipient_name'].split('.')[-1]
            args = response['parameters']
            return (func, args)
        except (KeyError, TypeError, AttributeError) as e:
            raise ValueError('response is not a valid function call') from e

    def register_tools(self, tools, update=False, new=False, prefix=None, postfix=None):
        """
        Register one or more tools via ``_register_tool``.

        ``tools`` is handed to the project helper ``lcall``, which applies
        ``_register_tool`` to each tool with the given flags. Each tool is
        registered under its ``func.__name__``.
        """
        lcall(tools, self._register_tool, update=update, new=new, prefix=prefix, postfix=postfix)
|
76
|
+
|
77
|
+
|
78
|
+
|
79
|
+
# import asyncio
|
80
|
+
# import json
|
81
|
+
# from typing import Dict, Any, Optional, List, Tuple
|
82
|
+
|
83
|
+
# from lionagi.schema.base_tool import Tool
|
84
|
+
# from lionagi.utils.type_util import to_list
|
85
|
+
# from lionagi.utils.tool_util import func_to_tool
|
86
|
+
|
87
|
+
# class ToolRegistry:
|
88
|
+
# """
|
89
|
+
# ToolManager manages the registration and invocation of tools.
|
90
|
+
|
91
|
+
# This class provides functionalities to register tools, check for their existence,
|
92
|
+
# and invoke them dynamically with specified arguments.
|
93
|
+
|
94
|
+
# Attributes:
|
95
|
+
# registry (Dict[str, BaseTool]): A dictionary to store registered tools by name.
|
96
|
+
|
97
|
+
# Methods:
|
98
|
+
# _name_exists: Checks if a tool name already exists in the registry.
|
99
|
+
|
100
|
+
# _register_tool: Registers a tool in the registry.
|
101
|
+
|
102
|
+
# invoke: Dynamically invokes a registered tool with given arguments.
|
103
|
+
|
104
|
+
# register_tools: Registers multiple tools in the registry.
|
105
|
+
# """
|
106
|
+
|
107
|
+
# def __init__(self):
|
108
|
+
# """
|
109
|
+
# Initializes the ToolManager with an empty registry.
|
110
|
+
# """
|
111
|
+
# self.registry: Dict[str, Tool] = {}
|
112
|
+
|
113
|
+
# def _name_exists(self, name: str) -> bool:
|
114
|
+
# """
|
115
|
+
# Checks if a tool name already exists in the registry.
|
116
|
+
|
117
|
+
# Parameters:
|
118
|
+
# name (str): The name of the tool to check.
|
119
|
+
|
120
|
+
# Returns:
|
121
|
+
# bool: True if the name exists in the registry, False otherwise.
|
122
|
+
# """
|
123
|
+
# return name in self.registry
|
124
|
+
|
125
|
+
# def _register_tool(self, tool: Tool, name: Optional[str] = None, update: bool = False, new: bool = False, prefix: Optional[str] = None, postfix: Optional[int] = None):
|
126
|
+
# """
|
127
|
+
# Registers a tool in the registry.
|
128
|
+
|
129
|
+
# Parameters:
|
130
|
+
# tool (BaseTool): The tool to be registered.
|
131
|
+
|
132
|
+
# name (Optional[str]): The name to register the tool with. Defaults to the tool's function name.
|
133
|
+
|
134
|
+
# update (bool): If True, updates the existing tool. Defaults to False.
|
135
|
+
|
136
|
+
# new (bool): If True, creates a new registry entry. Defaults to False.
|
137
|
+
|
138
|
+
# prefix (Optional[str]): A prefix for the tool name.
|
139
|
+
|
140
|
+
# postfix (Optional[int]): A postfix for the tool name.
|
141
|
+
|
142
|
+
# Raises:
|
143
|
+
# ValueError: If both update and new are True for an existing function.
|
144
|
+
# """
|
145
|
+
# name = name or tool.func.__name__
|
146
|
+
# original_name = name
|
147
|
+
|
148
|
+
# if self._name_exists(name):
|
149
|
+
# if update and new:
|
150
|
+
# raise ValueError("Cannot both update and create new registry for existing function.")
|
151
|
+
# if new:
|
152
|
+
# idx = 1
|
153
|
+
# while self._name_exists(f"{prefix or ''}{name}{postfix or ''}{idx}"):
|
154
|
+
# idx += 1
|
155
|
+
# name = f"{prefix or ''}{name}{postfix or ''}{idx}"
|
156
|
+
# else:
|
157
|
+
# self.registry.pop(original_name, None)
|
158
|
+
|
159
|
+
# self.registry[name] = tool
|
160
|
+
|
161
|
+
# async def invoke(self, name_kwargs: Tuple) -> Any:
|
162
|
+
# """
|
163
|
+
# Dynamically invokes a registered tool with given arguments.
|
164
|
+
|
165
|
+
# Parameters:
|
166
|
+
# name (str): The name of the tool to invoke.
|
167
|
+
|
168
|
+
# kwargs (Dict[str, Any]): A dictionary of keyword arguments to pass to the tool.
|
169
|
+
|
170
|
+
# Returns:
|
171
|
+
# Any: The result of the tool invocation.
|
172
|
+
|
173
|
+
# Raises:
|
174
|
+
# ValueError: If the tool is not registered or if an error occurs during invocation.
|
175
|
+
# """
|
176
|
+
# name, kwargs = name_kwargs
|
177
|
+
# if not self._name_exists(name):
|
178
|
+
# raise ValueError(f"Function {name} is not registered.")
|
179
|
+
|
180
|
+
# tool = self.registry[name]
|
181
|
+
# func = tool.func
|
182
|
+
# parser = tool.parser
|
183
|
+
|
184
|
+
# try:
|
185
|
+
# result = await func(**kwargs) if asyncio.iscoroutinefunction(func) else func(**kwargs)
|
186
|
+
# return await parser(result) if parser and asyncio.iscoroutinefunction(parser) else parser(result) if parser else result
|
187
|
+
# except Exception as e:
|
188
|
+
# raise ValueError(f"Error invoking function {name}: {str(e)}")
|
189
|
+
|
190
|
+
# def register_tools(self, tools: List[Tool], update: bool = False, new: bool = False,
|
191
|
+
# prefix: Optional[str] = None, postfix: Optional[int] = None):
|
192
|
+
# """
|
193
|
+
# Registers multiple tools in the registry.
|
194
|
+
|
195
|
+
# Parameters:
|
196
|
+
# tools (List[BaseTool]): A list of tools to register.
|
197
|
+
|
198
|
+
# update (bool): If True, updates existing tools. Defaults to False.
|
199
|
+
|
200
|
+
# new (bool): If True, creates new registry entries. Defaults to False.
|
201
|
+
|
202
|
+
# prefix (Optional[str]): A prefix for the tool names.
|
203
|
+
|
204
|
+
# postfix (Optional[int]): A postfix for the tool names.
|
205
|
+
# """
|
206
|
+
# for tool in tools:
|
207
|
+
# self._register_tool(tool, update=update, new=new, prefix=prefix, postfix=postfix)
|
208
|
+
|
209
|
+
# def _register_func(self, func_, parser=None, **kwargs):
|
210
|
+
# # kwargs for _register_tool
|
211
|
+
|
212
|
+
# tool = func_to_tool(func_=func_, parser=parser)
|
213
|
+
# self._register_tool(tool=tool, **kwargs)
|
214
|
+
|
215
|
+
# def register_funcs(self, funcs, parsers=None, **kwargs):
|
216
|
+
# funcs, parsers = to_list(funcs), to_list(parsers)
|
217
|
+
# if parsers is not None and len(parsers) != len(funcs):
|
218
|
+
# raise ValueError("The number of funcs and tools should be the same")
|
219
|
+
# parsers = parsers or [None for _ in range(len(funcs))]
|
220
|
+
|
221
|
+
# for i, func in enumerate(funcs):
|
222
|
+
# self._register_func(func_=func, parser=parsers[i], **kwargs)
|
223
|
+
|
224
|
+
# @staticmethod
|
225
|
+
# def get_function_call(response):
|
226
|
+
# """
|
227
|
+
# Extract function name and arguments from a response JSON.
|
228
|
+
|
229
|
+
# Parameters:
|
230
|
+
# response (dict): The JSON response containing function information.
|
231
|
+
|
232
|
+
# Returns:
|
233
|
+
# Tuple[str, dict]: The function name and its arguments.
|
234
|
+
# """
|
235
|
+
# try:
|
236
|
+
# # out = json.loads(response)
|
237
|
+
# func = response['function'][5:]
|
238
|
+
# args = json.loads(response['arguments'])
|
239
|
+
# return (func, args)
|
240
|
+
# except:
|
241
|
+
# try:
|
242
|
+
# func = response['recipient_name'].split('.')[-1]
|
243
|
+
# args = response['parameters']
|
244
|
+
# return (func, args)
|
245
|
+
# except:
|
246
|
+
# raise ValueError('response is not a valid function call')
|
247
|
+
|
@@ -0,0 +1,239 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict, Optional, TypeVar, Type, List, Callable, Union
|
3
|
+
from pydantic import BaseModel, Field, AliasChoices
|
4
|
+
|
5
|
+
from ..utils.sys_util import create_id
|
6
|
+
|
7
|
+
# Type variable bound to BaseNode (given as a forward reference because the
# class is defined below); used to type from_json / from_dict / copy returns.
T = TypeVar("T", bound="BaseNode")
|
8
|
+
|
9
|
+
|
10
|
+
class BaseNode(BaseModel):
    """
    BaseNode: A foundational building block for representing a node in a graph-like structure.

    Attributes:
        id_ (str): Unique identifier for the node, aliased as 'node_id'.
        content (Optional[Any]): Content or value the node represents. Can also
            be populated from the keys 'text', 'page_content' or 'chunk_content'.
        metadata (Dict[str, Any]): A dictionary of metadata related to the node.
        label (Optional[str]): A label for categorizing or identifying the node.
        related_nodes (List[str]): A list of identifiers for nodes related to this node.
    """
    id_: str = Field(default_factory=lambda: str(create_id()), alias="node_id")
    content: Union[str, Dict[str, Any], None, Any] = Field(
        default=None,
        validation_alias=AliasChoices('text', 'page_content', 'chunk_content'),
    )
    metadata: Dict[str, Any] = Field(default_factory=dict)
    label: Optional[str] = None
    related_nodes: List[str] = Field(default_factory=list)

    class Config:
        extra = 'allow'
        populate_by_name = True
        validate_assignment = True
        str_strip_whitespace = True

    def to_json(self) -> str:
        """Converts the node instance into JSON string representation (using aliases)."""
        return self.model_dump_json(by_alias=True)

    @classmethod
    def from_json(cls: Type[T], json_str: str, **kwargs) -> T:
        """
        Creates a node instance from a JSON string.

        Args:
            json_str (str): The JSON string representing a node.
            **kwargs: Additional keyword arguments to pass to json.loads.

        Returns:
            An instance of BaseNode.

        Raises:
            ValueError: If the provided string is not valid JSON.
        """
        try:
            data = json.loads(json_str, **kwargs)
            return cls(**data)
        except json.JSONDecodeError as e:
            raise ValueError("Invalid JSON string provided for deserialization.") from e

    def to_dict(self) -> Dict[str, Any]:
        """Converts the node instance into a dictionary representation (using aliases)."""
        return self.model_dump(by_alias=True)

    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        """Creates a node instance from a dictionary."""
        return cls(**data)

    def copy(self, deep: bool = True, n: int = 1) -> Union[T, List[T]]:
        """
        Creates a copy of the node instance.

        Copies are made with pydantic's ``model_copy``, so every field,
        including ``id_``, is carried over unchanged.

        Args:
            deep (bool): Whether to make a deep copy.
            n (int): Number of copies to create.

        Returns:
            A single copy when ``n == 1``, otherwise a list of ``n`` copies.
        """
        # model_copy, not self.copy: calling this method from itself recursed
        # without end.
        copies = [self.model_copy(deep=deep) for _ in range(n)]
        return copies[0] if n == 1 else copies

    def merge_metadata(self, other_metadata: Dict[str, Any], overwrite: bool = True) -> None:
        """
        Merges another metadata dictionary into the node's metadata.

        Args:
            other_metadata (Dict[str, Any]): The metadata to merge in.
            overwrite (bool): Whether to overwrite existing keys in the metadata.
        """
        if not overwrite:
            other_metadata = {k: v for k, v in other_metadata.items() if k not in self.metadata}
        self.metadata.update(other_metadata)

    def set_meta(self, metadata_: Dict[str, Any]) -> None:
        """Replace the metadata dictionary."""
        self.metadata = metadata_

    def get_meta(self) -> Dict[str, Any]:
        """Return the metadata dictionary (not a copy)."""
        return self.metadata

    def set_content(self, content: Optional[Any]) -> None:
        """Replace the node content."""
        self.content = content

    def get_content(self) -> Optional[Any]:
        """Return the node content."""
        return self.content

    def set_id(self, id_: str) -> None:
        """Replace the node identifier."""
        self.id_ = id_

    def get_id(self) -> str:
        """Return the node identifier."""
        return self.id_

    def update_meta(self, **kwargs) -> None:
        """Update the metadata in place with the given keyword arguments."""
        self.metadata.update(kwargs)

    def add_related_node(self, node_id: str) -> None:
        """Append ``node_id`` to ``related_nodes`` unless it is already present."""
        if node_id not in self.related_nodes:
            self.related_nodes.append(node_id)

    def remove_related_node(self, node_id: str) -> None:
        """Remove every occurrence of ``node_id`` from ``related_nodes``."""
        self.related_nodes = [id_ for id_ in self.related_nodes if id_ != node_id]

    def __eq__(self, other: object) -> bool:
        """Two nodes are equal when their ``model_dump()`` outputs are equal."""
        # isinstance() does not accept a TypeVar (the previous ``T`` check
        # raised TypeError); compare against the concrete base class.
        if not isinstance(other, BaseNode):
            return NotImplemented
        return self.model_dump() == other.model_dump()

    # Utility Methods
    def is_empty(self) -> bool:
        """Return True when both content and metadata are falsy."""
        return not self.content and not self.metadata

    def has_label(self, label: str) -> bool:
        """Return True if the node's label equals ``label``."""
        return self.label == label

    def is_metadata_key_present(self, key: str) -> bool:
        """Return True if ``key`` is a metadata key."""
        return key in self.metadata
|
144
|
+
|
145
|
+
|
146
|
+
class DataNode(BaseNode):
    """
    BaseNode subclass with converters to and from llama_index / langchain objects.

    The ``lionagi.bridge`` converter modules are imported inside each method,
    so they are only loaded when a conversion is actually requested.
    """

    def to_llama_index(self, **kwargs):
        """Convert this node with ``lionagi.bridge.llama_index.to_llama_index_textnode``."""
        from lionagi.bridge.llama_index import to_llama_index_textnode as _convert
        return _convert(self, **kwargs)

    def to_langchain(self, **kwargs):
        """Convert this node with ``lionagi.bridge.langchain.to_langchain_document``."""
        from lionagi.bridge.langchain import to_langchain_document as _convert
        return _convert(self, **kwargs)

    @classmethod
    def from_llama_index(cls, llama_node: Any, **kwargs):
        """Build a node from ``llama_node.to_dict(**kwargs)`` via ``from_dict``."""
        return cls.from_dict(llama_node.to_dict(**kwargs))

    @classmethod
    def from_langchain(cls, lc_doc: Any):
        """
        Build a node from a langchain object's ``to_json()`` output.

        The serialized ``'id'`` is stored as ``lc_id`` and every entry of the
        serialized ``'kwargs'`` becomes a constructor argument (a ``lc_id`` key
        in ``'kwargs'`` would win). NOTE(review): assumes ``to_json()`` returns
        a dict containing both ``'id'`` and ``'kwargs'`` — confirm.
        """
        serialized = lc_doc.to_json()
        fields = {'lc_id': serialized['id'], **serialized['kwargs']}
        return cls(**fields)
|
169
|
+
|
170
|
+
# def __repr__(self) -> str:
|
171
|
+
# return f"DataNode(id={self.id_}, content={self.content}, metadata={self.metadata}, label={self.label})"
|
172
|
+
|
173
|
+
# def __str__(self) -> str:
|
174
|
+
# return f"DataNode(id={self.id_}, label={self.label})"
|
175
|
+
|
176
|
+
|
177
|
+
class File(DataNode):
    """
    DataNode subclass that currently adds no fields or behavior.

    NOTE(review): presumably represents a whole loaded file (as opposed to a
    Chunk) — confirm against the loader modules.
    """
|
180
|
+
|
181
|
+
|
182
|
+
class Chunk(DataNode):
    """
    DataNode subclass that currently adds no fields or behavior.

    NOTE(review): presumably represents a piece of a larger document produced
    by the chunker — confirm against the loader modules.
    """
|
185
|
+
|
186
|
+
|
187
|
+
class Message(BaseNode):
    """
    BaseNode specialization for chat-style messages.

    Adds ``role`` and ``name`` fields plus helpers that populate them and render
    the node as a ``{"role": ..., "content": ...}`` dictionary.

    Attributes:
        role (Optional[str]): Role of the message. NOTE(review): the concrete
            role strings are supplied by callers/subclasses not visible here.
        name (Optional[str]): Name associated with the message; set to the role
            by ``_create_role_message`` when no name is given.
    """

    role: Optional[str] = None
    name: Optional[str] = None

    def _to_message(self):
        """
        Render the node as a ``{"role", "content"}`` dictionary.

        Dictionary content is serialized with ``json.dumps``; any other content
        is returned unchanged.

        Returns:
            dict: ``{"role": self.role, "content": <content>}``.
        """
        payload = self.content
        if isinstance(payload, dict):
            payload = json.dumps(payload)
        return {"role": self.role, "content": payload}

    def _create_role_message(self, role_: str,
                             content: Any,
                             content_key: str,
                             name: Optional[str] = None
                             ) -> None:
        """
        Populate this message in place with a role, wrapped content and a name.

        Args:
            role_ (str): Value stored in ``role``.
            content (Any): Stored as ``{content_key: content}`` in ``content``.
            content_key (str): Key under which ``content`` is wrapped.
            name (Optional[str]): Value stored in ``name``; falls back to
                ``role_`` when falsy.
        """
        self.role = role_
        self.content = {content_key: content}
        self.name = role_ if not name else name
|
239
|
+
|