lionagi 0.0.112__py3-none-any.whl → 0.0.113__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +3 -3
- lionagi/bridge/__init__.py +7 -0
- lionagi/bridge/langchain.py +131 -0
- lionagi/bridge/llama_index.py +157 -0
- lionagi/configs/__init__.py +7 -0
- lionagi/configs/oai_configs.py +49 -0
- lionagi/configs/openrouter_config.py +49 -0
- lionagi/core/__init__.py +8 -2
- lionagi/core/instruction_sets.py +1 -3
- lionagi/core/messages.py +2 -2
- lionagi/core/sessions.py +174 -27
- lionagi/datastore/__init__.py +1 -0
- lionagi/loader/__init__.py +9 -4
- lionagi/loader/chunker.py +157 -0
- lionagi/loader/reader.py +124 -0
- lionagi/objs/__init__.py +7 -0
- lionagi/objs/messenger.py +163 -0
- lionagi/objs/tool_registry.py +247 -0
- lionagi/schema/__init__.py +11 -0
- lionagi/schema/base_schema.py +239 -0
- lionagi/schema/base_tool.py +9 -0
- lionagi/schema/data_logger.py +94 -0
- lionagi/services/__init__.py +14 -0
- lionagi/{service_/oai.py → services/base_api_service.py} +49 -82
- lionagi/{endpoint/base_endpoint.py → services/chatcompletion.py} +19 -22
- lionagi/services/oai.py +34 -0
- lionagi/services/openrouter.py +32 -0
- lionagi/{service_/service_utils.py → services/service_objs.py} +0 -1
- lionagi/structure/__init__.py +7 -0
- lionagi/structure/relationship.py +128 -0
- lionagi/structure/structure.py +160 -0
- lionagi/tests/test_flatten_util.py +426 -0
- lionagi/tools/__init__.py +0 -5
- lionagi/tools/coder.py +1 -0
- lionagi/tools/scorer.py +1 -0
- lionagi/tools/validator.py +1 -0
- lionagi/utils/__init__.py +46 -20
- lionagi/utils/api_util.py +86 -0
- lionagi/utils/call_util.py +347 -0
- lionagi/utils/flat_util.py +540 -0
- lionagi/utils/io_util.py +102 -0
- lionagi/utils/load_utils.py +190 -0
- lionagi/utils/sys_util.py +191 -0
- lionagi/utils/tool_util.py +92 -0
- lionagi/utils/type_util.py +81 -0
- lionagi/version.py +1 -1
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/METADATA +37 -13
- lionagi-0.0.113.dist-info/RECORD +84 -0
- lionagi/endpoint/chat_completion.py +0 -20
- lionagi/endpoint/endpoint_utils.py +0 -0
- lionagi/llm_configs.py +0 -21
- lionagi/loader/load_utils.py +0 -161
- lionagi/schema.py +0 -275
- lionagi/service_/__init__.py +0 -6
- lionagi/service_/base_service.py +0 -48
- lionagi/service_/openrouter.py +0 -1
- lionagi/services.py +0 -1
- lionagi/tools/tool_utils.py +0 -75
- lionagi/utils/sys_utils.py +0 -799
- lionagi-0.0.112.dist-info/RECORD +0 -67
- /lionagi/{core/responses.py → datastore/chroma.py} +0 -0
- /lionagi/{endpoint/assistants.py → datastore/deeplake.py} +0 -0
- /lionagi/{endpoint/audio.py → datastore/elasticsearch.py} +0 -0
- /lionagi/{endpoint/embeddings.py → datastore/lantern.py} +0 -0
- /lionagi/{endpoint/files.py → datastore/pinecone.py} +0 -0
- /lionagi/{endpoint/fine_tuning.py → datastore/postgres.py} +0 -0
- /lionagi/{endpoint/images.py → datastore/qdrant.py} +0 -0
- /lionagi/{endpoint/messages.py → schema/base_condition.py} +0 -0
- /lionagi/{service_ → services}/anthropic.py +0 -0
- /lionagi/{service_ → services}/anyscale.py +0 -0
- /lionagi/{service_ → services}/azure.py +0 -0
- /lionagi/{service_ → services}/bedrock.py +0 -0
- /lionagi/{service_ → services}/everlyai.py +0 -0
- /lionagi/{service_ → services}/gemini.py +0 -0
- /lionagi/{service_ → services}/gpt4all.py +0 -0
- /lionagi/{service_ → services}/huggingface.py +0 -0
- /lionagi/{service_ → services}/litellm.py +0 -0
- /lionagi/{service_ → services}/localai.py +0 -0
- /lionagi/{service_ → services}/mistralai.py +0 -0
- /lionagi/{service_ → services}/ollama.py +0 -0
- /lionagi/{service_ → services}/openllm.py +0 -0
- /lionagi/{service_ → services}/perplexity.py +0 -0
- /lionagi/{service_ → services}/predibase.py +0 -0
- /lionagi/{service_ → services}/rungpt.py +0 -0
- /lionagi/{service_ → services}/vllm.py +0 -0
- /lionagi/{service_ → services}/xinference.py +0 -0
- /lionagi/{endpoint → tests}/__init__.py +0 -0
- /lionagi/{endpoint/models.py → tools/planner.py} +0 -0
- /lionagi/{endpoint/moderations.py → tools/prompter.py} +0 -0
- /lionagi/{endpoint/runs.py → tools/sandbox.py} +0 -0
- /lionagi/{endpoint/threads.py → tools/summarizer.py} +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/LICENSE +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/WHEEL +0 -0
- {lionagi-0.0.112.dist-info → lionagi-0.0.113.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,163 @@
|
|
1
|
+
# from typing import Optional, Any, Union, Dict, Tuple
|
2
|
+
|
3
|
+
# from lionagi.utils.call_util import lcall
|
4
|
+
# from lionagi.schema.data_logger import DataLogger
|
5
|
+
# from lionagi.core.messages import Message, Response, Instruction, System
|
6
|
+
|
7
|
+
|
8
|
+
# class Messenger:
|
9
|
+
# """
|
10
|
+
# Messenger handles the creation, logging, and exporting of messages.
|
11
|
+
|
12
|
+
# This class is responsible for creating various types of messages (system, instruction, response),
|
13
|
+
# logging them, and optionally exporting the log to a CSV file.
|
14
|
+
|
15
|
+
# Attributes:
|
16
|
+
# _logger (DataLogger): An instance of DataLogger to manage message logging.
|
17
|
+
|
18
|
+
# Methods:
|
19
|
+
# set_dir: Sets the directory for the DataLogger to save CSV files.
|
20
|
+
|
21
|
+
# set_log: Sets the log object for the DataLogger.
|
22
|
+
|
23
|
+
# log_message: Logs a message in JSON format.
|
24
|
+
|
25
|
+
# to_csv: Exports logged messages to a CSV file.
|
26
|
+
|
27
|
+
# _create_message: Internal method to create a specific type of message.
|
28
|
+
|
29
|
+
# create_message: Public interface to create messages, log them, and optionally return them in different formats.
|
30
|
+
# """
|
31
|
+
|
32
|
+
# def __init__(self) -> None:
|
33
|
+
# """
|
34
|
+
# Initializes the Messenger with a DataLogger instance.
|
35
|
+
# """
|
36
|
+
# self._logger = DataLogger()
|
37
|
+
|
38
|
+
# def set_dir(self, dir: str) -> None:
|
39
|
+
# """
|
40
|
+
# Sets the directory where the DataLogger will save CSV files.
|
41
|
+
|
42
|
+
# Parameters:
|
43
|
+
# dir (str): The directory path to set for the DataLogger.
|
44
|
+
# """
|
45
|
+
# self._logger.dir = dir
|
46
|
+
|
47
|
+
# def set_log(self, log) -> None:
|
48
|
+
# """
|
49
|
+
# Sets the log object for the DataLogger.
|
50
|
+
|
51
|
+
# Parameters:
|
52
|
+
# log: The log object to be used by the DataLogger.
|
53
|
+
# """
|
54
|
+
# self._logger.log = log
|
55
|
+
|
56
|
+
# def log_message(self, msg: Message) -> None:
|
57
|
+
# """
|
58
|
+
# Logs a message in JSON format using the DataLogger.
|
59
|
+
|
60
|
+
# Parameters:
|
61
|
+
# msg (Message): The message object to be logged.
|
62
|
+
# """
|
63
|
+
# self._logger(msg.to_json())
|
64
|
+
|
65
|
+
# def to_csv(self, **kwargs) -> None:
|
66
|
+
# """
|
67
|
+
# Exports the logged messages to a CSV file.
|
68
|
+
|
69
|
+
# Parameters:
|
70
|
+
# **kwargs: Additional keyword arguments to be passed to the DataLogger's to_csv method.
|
71
|
+
# """
|
72
|
+
# self._logger.to_csv(**kwargs)
|
73
|
+
|
74
|
+
# def _create_message(self,
|
75
|
+
# system: Optional[Any] = None,
|
76
|
+
# instruction: Optional[Any] = None,
|
77
|
+
# context: Optional[Any] = None,
|
78
|
+
# response: Optional[Any] = None,
|
79
|
+
# name: Optional[str] = None) -> Message:
|
80
|
+
# """
|
81
|
+
# Creates a specific type of message based on the provided parameters.
|
82
|
+
|
83
|
+
# Parameters:
|
84
|
+
# system (Optional[Any]): System message content.
|
85
|
+
|
86
|
+
# instruction (Optional[Any]): Instruction message content.
|
87
|
+
|
88
|
+
# context (Optional[Any]): Context for the instruction message.
|
89
|
+
|
90
|
+
# response (Optional[Any]): Response message content.
|
91
|
+
|
92
|
+
# name (Optional[str]): Name associated with the message.
|
93
|
+
|
94
|
+
# Returns:
|
95
|
+
# Message: The created message object of type Response, Instruction, or System.
|
96
|
+
|
97
|
+
# Raises:
|
98
|
+
# ValueError: If more than one or none of the message content parameters (system, instruction, response) are provided.
|
99
|
+
# """
|
100
|
+
|
101
|
+
# if sum(lcall([system, instruction, response], bool)) != 1:
|
102
|
+
# raise ValueError("Error: Message must have one and only one role.")
|
103
|
+
|
104
|
+
# else:
|
105
|
+
# msg = 0
|
106
|
+
|
107
|
+
# if response:
|
108
|
+
# msg = Response()
|
109
|
+
# msg.create_message(response=response,
|
110
|
+
# name=name,)
|
111
|
+
# elif instruction:
|
112
|
+
# msg = Instruction()
|
113
|
+
# msg.create_message(instruction=instruction,
|
114
|
+
# context=context,
|
115
|
+
# name=name,)
|
116
|
+
# elif system:
|
117
|
+
# msg = System()
|
118
|
+
# msg.create_message(system=system,
|
119
|
+
# name=name,)
|
120
|
+
# return msg
|
121
|
+
|
122
|
+
# def create_message(self,
|
123
|
+
# system: Optional[Any] = None,
|
124
|
+
# instruction: Optional[Any] = None,
|
125
|
+
# context: Optional[Any] = None,
|
126
|
+
# response: Optional[Any] = None,
|
127
|
+
# name: Optional[str] = None,
|
128
|
+
# obj: bool = False,
|
129
|
+
# log_: bool = True) -> Union[Message, Tuple[Message, Dict]]:
|
130
|
+
# """
|
131
|
+
# Creates and optionally logs a message, returning it in different formats based on parameters.
|
132
|
+
|
133
|
+
# Parameters:
|
134
|
+
# system (Optional[Any]): System message content.
|
135
|
+
|
136
|
+
# instruction (Optional[Any]): Instruction message content.
|
137
|
+
|
138
|
+
# context (Optional[Any]): Context for the instruction message.
|
139
|
+
|
140
|
+
# response (Optional[Any]): Response message content.
|
141
|
+
|
142
|
+
# name (Optional[str]): Name associated with the message.
|
143
|
+
|
144
|
+
# obj (bool): If True, returns the Message object and its dictionary representation. Defaults to False.
|
145
|
+
|
146
|
+
# log_ (bool): If True, logs the created message. Defaults to True.
|
147
|
+
|
148
|
+
# Returns:
|
149
|
+
# Union[Message, Tuple[Message, Dict]]: The created message in the specified format.
|
150
|
+
# """
|
151
|
+
|
152
|
+
# msg = self._create_message(system=system,
|
153
|
+
# instruction=instruction,
|
154
|
+
# context=context,
|
155
|
+
# response=response,
|
156
|
+
# name=name)
|
157
|
+
# if log_:
|
158
|
+
# self.log_message(msg)
|
159
|
+
# if obj:
|
160
|
+
# return (msg, msg._to_message())
|
161
|
+
# else:
|
162
|
+
# return msg._to_message()
|
163
|
+
|
@@ -0,0 +1,247 @@
|
|
1
|
+
import json
|
2
|
+
import asyncio
|
3
|
+
from typing import Dict
|
4
|
+
from ..utils import lcall, str_to_num
|
5
|
+
from ..schema import BaseNode
|
6
|
+
|
7
|
+
|
8
|
+
class ToolManager(BaseNode):
    """
    Registry that maps names to tools and dispatches function calls to them.

    Each registered tool is expected to expose a ``func`` attribute (the
    callable to run) and a ``parser`` attribute (a callable applied to the
    function's result, or a falsy value for "no parsing").

    Attributes:
        registry (Dict): Mapping from registered name to tool object.
    """
    registry: Dict = {}

    def _name_existed(self, name: str) -> bool:
        """Return True if ``name`` is already a key in the registry."""
        return name in self.registry

    def _register_tool(self, tool, name: str = None, update=False, new=False, prefix=None, postfix=None):
        """
        Register ``tool`` in the registry.

        Parameters:
            tool: Object exposing ``func`` (a callable) and ``parser``.
            name (str, optional): Registration name. Defaults to
                ``tool.func.__name__``.
            update (bool): Overwrite intent. Combining it with ``new`` for an
                already-registered name is rejected as contradictory.
            new (bool): Register under a fresh name built as
                ``prefix + name + postfix``. If that name is already taken,
                an increasing integer suffix is appended, so existing entries
                are never overwritten.
            prefix (str, optional): Prepended to the name when ``new`` is True.
            postfix (optional): Appended to the name when ``new`` is True.

        Raises:
            ValueError: If ``name`` is already registered and both ``update``
                and ``new`` are True.
        """
        # register_tools never passes a name. Without this default, ``new=True``
        # produced keys such as "NoneNone" and the conflict check never fired.
        name = name or tool.func.__name__

        if self._name_existed(name) and update and new:
            raise ValueError(f"Cannot both update and create new registry for existing function {name} at the same time")

        if new:
            stem = f"{prefix or ''}{name}{'' if postfix is None else postfix}"
            candidate, idx = stem, 1
            # Probe for the first free key instead of overwriting an entry.
            while self._name_existed(candidate):
                candidate = f"{stem}{idx}"
                idx += 1
            name = candidate

        self.registry[name] = tool

    async def invoke(self, func_call):
        """
        Run a registered tool.

        Parameters:
            func_call: A ``(name, kwargs)`` pair, e.g. as produced by
                ``_get_function_call``.

        Returns:
            The function's result, passed through the tool's ``parser`` when
            one is set. Coroutine functions and coroutine-returning parsers
            are awaited.

        Raises:
            ValueError: If ``name`` is not registered, or if the function or
                parser raises (the original exception is chained).
        """
        name, kwargs = func_call
        if not self._name_existed(name):
            raise ValueError(f"Function {name} is not registered.")

        tool = self.registry[name]
        func = tool.func
        parser = tool.parser
        try:
            # Await async functions; previously an unawaited coroutine leaked out.
            if asyncio.iscoroutinefunction(func):
                result = await func(**kwargs)
            else:
                result = func(**kwargs)
            if not parser:
                return result
            parsed = parser(result)
            # Support both sync and async parsers.
            return await parsed if asyncio.iscoroutine(parsed) else parsed
        except Exception as e:
            raise ValueError(f"Error when invoking function {name} with arguments {kwargs} with error message {e}") from e

    @staticmethod
    def _get_function_call(response):
        """
        Extract function name and arguments from a response JSON.

        Two shapes are accepted:
          * ``{'function': str, 'arguments': <JSON string>}``. The first five
            characters of ``function`` are dropped (presumably a fixed prefix
            such as ``"func_"`` -- TODO confirm against the producer).
          * ``{'recipient_name': 'a.b.name', 'parameters': dict}``. The last
            dot-separated component is used as the name.

        Parameters:
            response (dict): The JSON response containing function information.

        Returns:
            Tuple[str, dict]: The function name and its arguments.

        Raises:
            ValueError: If ``response`` matches neither shape.
        """
        try:
            func = response['function'][5:]
            args = json.loads(response['arguments'])
            return (func, args)
        except (KeyError, TypeError, ValueError):
            pass
        try:
            func = response['recipient_name'].split('.')[-1]
            args = response['parameters']
            return (func, args)
        except (KeyError, TypeError, AttributeError) as e:
            raise ValueError('response is not a valid function call') from e

    def register_tools(self, tools, update=False, new=False, prefix=None, postfix=None):
        """
        Register several tools by applying ``_register_tool`` to each of them.

        ``tools`` is handed to ``lcall`` together with ``_register_tool``.
        Presumably ``lcall`` maps the callable over a list (or a single item)
        -- confirm against ``lionagi.utils``. The remaining flags are forwarded
        unchanged to every call.
        """
        lcall(tools, self._register_tool, update=update, new=new, prefix=prefix, postfix=postfix)
|
76
|
+
|
77
|
+
|
78
|
+
|
79
|
+
# import asyncio
|
80
|
+
# import json
|
81
|
+
# from typing import Dict, Any, Optional, List, Tuple
|
82
|
+
|
83
|
+
# from lionagi.schema.base_tool import Tool
|
84
|
+
# from lionagi.utils.type_util import to_list
|
85
|
+
# from lionagi.utils.tool_util import func_to_tool
|
86
|
+
|
87
|
+
# class ToolRegistry:
|
88
|
+
# """
|
89
|
+
# ToolRegistry manages the registration and invocation of tools.
|
90
|
+
|
91
|
+
# This class provides functionalities to register tools, check for their existence,
|
92
|
+
# and invoke them dynamically with specified arguments.
|
93
|
+
|
94
|
+
# Attributes:
|
95
|
+
# registry (Dict[str, BaseTool]): A dictionary to store registered tools by name.
|
96
|
+
|
97
|
+
# Methods:
|
98
|
+
# _name_exists: Checks if a tool name already exists in the registry.
|
99
|
+
|
100
|
+
# _register_tool: Registers a tool in the registry.
|
101
|
+
|
102
|
+
# invoke: Dynamically invokes a registered tool with given arguments.
|
103
|
+
|
104
|
+
# register_tools: Registers multiple tools in the registry.
|
105
|
+
# """
|
106
|
+
|
107
|
+
# def __init__(self):
|
108
|
+
# """
|
109
|
+
# Initializes the ToolManager with an empty registry.
|
110
|
+
# """
|
111
|
+
# self.registry: Dict[str, Tool] = {}
|
112
|
+
|
113
|
+
# def _name_exists(self, name: str) -> bool:
|
114
|
+
# """
|
115
|
+
# Checks if a tool name already exists in the registry.
|
116
|
+
|
117
|
+
# Parameters:
|
118
|
+
# name (str): The name of the tool to check.
|
119
|
+
|
120
|
+
# Returns:
|
121
|
+
# bool: True if the name exists in the registry, False otherwise.
|
122
|
+
# """
|
123
|
+
# return name in self.registry
|
124
|
+
|
125
|
+
# def _register_tool(self, tool: Tool, name: Optional[str] = None, update: bool = False, new: bool = False, prefix: Optional[str] = None, postfix: Optional[int] = None):
|
126
|
+
# """
|
127
|
+
# Registers a tool in the registry.
|
128
|
+
|
129
|
+
# Parameters:
|
130
|
+
# tool (BaseTool): The tool to be registered.
|
131
|
+
|
132
|
+
# name (Optional[str]): The name to register the tool with. Defaults to the tool's function name.
|
133
|
+
|
134
|
+
# update (bool): If True, updates the existing tool. Defaults to False.
|
135
|
+
|
136
|
+
# new (bool): If True, creates a new registry entry. Defaults to False.
|
137
|
+
|
138
|
+
# prefix (Optional[str]): A prefix for the tool name.
|
139
|
+
|
140
|
+
# postfix (Optional[int]): A postfix for the tool name.
|
141
|
+
|
142
|
+
# Raises:
|
143
|
+
# ValueError: If both update and new are True for an existing function.
|
144
|
+
# """
|
145
|
+
# name = name or tool.func.__name__
|
146
|
+
# original_name = name
|
147
|
+
|
148
|
+
# if self._name_exists(name):
|
149
|
+
# if update and new:
|
150
|
+
# raise ValueError("Cannot both update and create new registry for existing function.")
|
151
|
+
# if new:
|
152
|
+
# idx = 1
|
153
|
+
# while self._name_exists(f"{prefix or ''}{name}{postfix or ''}{idx}"):
|
154
|
+
# idx += 1
|
155
|
+
# name = f"{prefix or ''}{name}{postfix or ''}{idx}"
|
156
|
+
# else:
|
157
|
+
# self.registry.pop(original_name, None)
|
158
|
+
|
159
|
+
# self.registry[name] = tool
|
160
|
+
|
161
|
+
# async def invoke(self, name_kwargs: Tuple) -> Any:
|
162
|
+
# """
|
163
|
+
# Dynamically invokes a registered tool with given arguments.
|
164
|
+
|
165
|
+
# Parameters:
|
166
|
+
# name (str): The name of the tool to invoke.
|
167
|
+
|
168
|
+
# kwargs (Dict[str, Any]): A dictionary of keyword arguments to pass to the tool.
|
169
|
+
|
170
|
+
# Returns:
|
171
|
+
# Any: The result of the tool invocation.
|
172
|
+
|
173
|
+
# Raises:
|
174
|
+
# ValueError: If the tool is not registered or if an error occurs during invocation.
|
175
|
+
# """
|
176
|
+
# name, kwargs = name_kwargs
|
177
|
+
# if not self._name_exists(name):
|
178
|
+
# raise ValueError(f"Function {name} is not registered.")
|
179
|
+
|
180
|
+
# tool = self.registry[name]
|
181
|
+
# func = tool.func
|
182
|
+
# parser = tool.parser
|
183
|
+
|
184
|
+
# try:
|
185
|
+
# result = await func(**kwargs) if asyncio.iscoroutinefunction(func) else func(**kwargs)
|
186
|
+
# return await parser(result) if parser and asyncio.iscoroutinefunction(parser) else parser(result) if parser else result
|
187
|
+
# except Exception as e:
|
188
|
+
# raise ValueError(f"Error invoking function {name}: {str(e)}")
|
189
|
+
|
190
|
+
# def register_tools(self, tools: List[Tool], update: bool = False, new: bool = False,
|
191
|
+
# prefix: Optional[str] = None, postfix: Optional[int] = None):
|
192
|
+
# """
|
193
|
+
# Registers multiple tools in the registry.
|
194
|
+
|
195
|
+
# Parameters:
|
196
|
+
# tools (List[BaseTool]): A list of tools to register.
|
197
|
+
|
198
|
+
# update (bool): If True, updates existing tools. Defaults to False.
|
199
|
+
|
200
|
+
# new (bool): If True, creates new registry entries. Defaults to False.
|
201
|
+
|
202
|
+
# prefix (Optional[str]): A prefix for the tool names.
|
203
|
+
|
204
|
+
# postfix (Optional[int]): A postfix for the tool names.
|
205
|
+
# """
|
206
|
+
# for tool in tools:
|
207
|
+
# self._register_tool(tool, update=update, new=new, prefix=prefix, postfix=postfix)
|
208
|
+
|
209
|
+
# def _register_func(self, func_, parser=None, **kwargs):
|
210
|
+
# # kwargs for _register_tool
|
211
|
+
|
212
|
+
# tool = func_to_tool(func_=func_, parser=parser)
|
213
|
+
# self._register_tool(tool=tool, **kwargs)
|
214
|
+
|
215
|
+
# def register_funcs(self, funcs, parsers=None, **kwargs):
|
216
|
+
# funcs, parsers = to_list(funcs), to_list(parsers)
|
217
|
+
# if parsers is not None and len(parsers) != len(funcs):
|
218
|
+
# raise ValueError("The number of funcs and tools should be the same")
|
219
|
+
# parsers = parsers or [None for _ in range(len(funcs))]
|
220
|
+
|
221
|
+
# for i, func in enumerate(funcs):
|
222
|
+
# self._register_func(func_=func, parser=parsers[i], **kwargs)
|
223
|
+
|
224
|
+
# @staticmethod
|
225
|
+
# def get_function_call(response):
|
226
|
+
# """
|
227
|
+
# Extract function name and arguments from a response JSON.
|
228
|
+
|
229
|
+
# Parameters:
|
230
|
+
# response (dict): The JSON response containing function information.
|
231
|
+
|
232
|
+
# Returns:
|
233
|
+
# Tuple[str, dict]: The function name and its arguments.
|
234
|
+
# """
|
235
|
+
# try:
|
236
|
+
# # out = json.loads(response)
|
237
|
+
# func = response['function'][5:]
|
238
|
+
# args = json.loads(response['arguments'])
|
239
|
+
# return (func, args)
|
240
|
+
# except:
|
241
|
+
# try:
|
242
|
+
# func = response['recipient_name'].split('.')[-1]
|
243
|
+
# args = response['parameters']
|
244
|
+
# return (func, args)
|
245
|
+
# except:
|
246
|
+
# raise ValueError('response is not a valid function call')
|
247
|
+
|
@@ -0,0 +1,239 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict, Optional, TypeVar, Type, List, Callable, Union
|
3
|
+
from pydantic import BaseModel, Field, AliasChoices
|
4
|
+
|
5
|
+
from ..utils.sys_util import create_id
|
6
|
+
|
7
|
+
# Self-type variable bound to BaseNode, used in classmethod return annotations.
T = TypeVar("T", bound="BaseNode")
|
8
|
+
|
9
|
+
|
10
|
+
class BaseNode(BaseModel):
    """
    BaseNode: A foundational building block for representing a node in a graph-like structure.

    Attributes:
        id_ (str): Unique identifier for the node. Generated via ``create_id``
            and serialized under the alias 'node_id'.
        content (Optional[Any]): Content or value the node represents. On input
            it may also be supplied as 'text', 'page_content' or 'chunk_content'.
        metadata (Dict[str, Any]): A dictionary of metadata related to the node.
        label (Optional[str]): A label for categorizing or identifying the node.
        related_nodes (List[str]): Identifiers of nodes related to this node.
    """
    id_: str = Field(default_factory=lambda: str(create_id()), alias="node_id")
    content: Union[str, Dict[str, Any], None, Any] = Field(default=None,
        validation_alias=AliasChoices('text', 'page_content', 'chunk_content'))
    metadata: Dict[str, Any] = Field(default_factory=dict)
    label: Optional[str] = None
    related_nodes: List[str] = Field(default_factory=list)

    class Config:
        extra = 'allow'               # keep unknown fields instead of rejecting them
        populate_by_name = True       # accept field names as well as aliases
        validate_assignment = True    # re-validate on attribute assignment
        str_strip_whitespace = True   # strip surrounding whitespace from str values

    def to_json(self) -> str:
        """Converts the node instance into JSON string representation (aliased keys)."""
        return self.model_dump_json(by_alias=True)

    @classmethod
    def from_json(cls: Type[T], json_str: str, **kwargs) -> T:
        """
        Creates a node instance from a JSON string.

        Args:
            json_str (str): The JSON string representing a node.
            **kwargs: Additional keyword arguments to pass to json.loads.

        Returns:
            An instance of the calling class.

        Raises:
            ValueError: If the provided string is not valid JSON.
        """
        try:
            data = json.loads(json_str, **kwargs)
            return cls(**data)
        except json.JSONDecodeError as e:
            raise ValueError("Invalid JSON string provided for deserialization.") from e

    def to_dict(self) -> Dict[str, Any]:
        """Converts the node instance into a dictionary representation (aliased keys)."""
        return self.model_dump(by_alias=True)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> T:
        """Creates a node instance from a dictionary."""
        return cls(**data)

    def copy(self, deep: bool = True, n: int = 1) -> Union[T, List[T]]:
        """
        Creates a copy of the node instance.

        Copies carry over every field value, including ``id_``.

        Args:
            deep (bool): Whether to make a deep copy.
            n (int): Number of copies to create.

        Returns:
            A single copy when ``n == 1``, otherwise a list of ``n`` copies
            (an empty list when ``n < 1``).
        """
        # Use pydantic's model_copy. The previous self.copy(...) call
        # recursed into this method forever.
        copies = [self.model_copy(deep=deep) for _ in range(n)]
        return copies[0] if n == 1 else copies

    def merge_metadata(self, other_metadata: Dict[str, Any], overwrite: bool = True) -> None:
        """
        Merges another metadata dictionary into the node's metadata.

        Args:
            other_metadata (Dict[str, Any]): The metadata to merge in.
            overwrite (bool): Whether to overwrite existing keys in the metadata.
        """
        if not overwrite:
            other_metadata = {k: v for k, v in other_metadata.items() if k not in self.metadata}
        self.metadata.update(other_metadata)

    def set_meta(self, metadata_: Dict[str, Any]) -> None:
        """Replace the metadata dictionary."""
        self.metadata = metadata_

    def get_meta(self) -> Dict[str, Any]:
        """Return the metadata dictionary."""
        return self.metadata

    def set_content(self, content: Optional[Any]) -> None:
        """Replace the node content."""
        self.content = content

    def get_content(self) -> Optional[Any]:
        """Return the node content."""
        return self.content

    def set_id(self, id_: str) -> None:
        """Replace the node identifier."""
        self.id_ = id_

    def get_id(self) -> str:
        """Return the node identifier."""
        return self.id_

    def update_meta(self, **kwargs) -> None:
        """Update metadata in place with the given keyword arguments."""
        self.metadata.update(kwargs)

    def add_related_node(self, node_id: str) -> None:
        """Append ``node_id`` to related_nodes unless it is already present."""
        if node_id not in self.related_nodes:
            self.related_nodes.append(node_id)

    def remove_related_node(self, node_id: str) -> None:
        """Remove every occurrence of ``node_id`` from related_nodes."""
        self.related_nodes = [id_ for id_ in self.related_nodes if id_ != node_id]

    def __eq__(self, other: object) -> bool:
        """Nodes are equal when their full field dumps are equal."""
        # isinstance() cannot take a TypeVar (the old ``isinstance(other, T)``
        # raised TypeError), so check against the concrete base class.
        if not isinstance(other, BaseNode):
            return NotImplemented
        return self.model_dump() == other.model_dump()

    # Utility Methods
    def is_empty(self) -> bool:
        """True when both content and metadata are falsy."""
        return not self.content and not self.metadata

    def has_label(self, label: str) -> bool:
        """True when the node's label equals ``label``."""
        return self.label == label

    def is_metadata_key_present(self, key: str) -> bool:
        """True when ``key`` is a metadata key."""
        return key in self.metadata
|
144
|
+
|
145
|
+
|
146
|
+
class DataNode(BaseNode):
    """
    BaseNode subclass with converters to and from LlamaIndex and LangChain objects.

    The bridge modules are imported inside each method, so they are only
    loaded when a conversion is actually requested.
    """

    def to_llama_index(self, **kwargs):
        """Return this node as a LlamaIndex text node; kwargs go to the bridge converter."""
        from lionagi.bridge.llama_index import to_llama_index_textnode as _to_textnode
        return _to_textnode(self, **kwargs)

    def to_langchain(self, **kwargs):
        """Return this node as a LangChain document; kwargs go to the bridge converter."""
        from lionagi.bridge.langchain import to_langchain_document as _to_document
        return _to_document(self, **kwargs)

    @classmethod
    def from_llama_index(cls, llama_node: Any, **kwargs):
        """Build a node from ``llama_node.to_dict(**kwargs)``."""
        return cls.from_dict(llama_node.to_dict(**kwargs))

    @classmethod
    def from_langchain(cls, lc_doc: Any):
        """
        Build a node from a LangChain object's ``to_json()`` output.

        The serialized 'id' is stored as the extra field ``lc_id``. The
        entries of the serialized 'kwargs' mapping become constructor
        arguments and win over ``lc_id`` on a key clash.
        """
        serialized = lc_doc.to_json()
        fields = {'lc_id': serialized['id'], **serialized['kwargs']}
        return cls(**fields)
|
175
|
+
|
176
|
+
|
177
|
+
class File(DataNode):
    """DataNode subclass marking a node that represents a file; adds no fields."""
    pass
|
180
|
+
|
181
|
+
|
182
|
+
class Chunk(DataNode):
    """DataNode subclass marking a node that represents a chunk; adds no fields."""
    pass
|
185
|
+
|
186
|
+
|
187
|
+
class Message(BaseNode):
    """
    Message: A BaseNode specialization carrying a chat-style role and name.

    Attributes:
        role (Optional[str]): Role label of the message.
        name (Optional[str]): Name associated with the message; set to the
            role by ``_create_role_message`` when no name is given.
    """

    role: Optional[str] = None
    name: Optional[str] = None

    def _to_message(self):
        """
        Produce a ``{'role': ..., 'content': ...}`` dictionary for this message.

        Dictionary content is serialized with ``json.dumps``; any other
        content is passed through unchanged.
        """
        payload = self.content
        if isinstance(payload, dict):
            payload = json.dumps(payload)
        return {"role": self.role, "content": payload}

    def _create_role_message(self, role_: str,
                             content: Any,
                             content_key: str,
                             name: Optional[str] = None
                             ) -> None:
        """
        Populate this message in place with a role, wrapped content and a name.

        Args:
            role_ (str): Role to assign.
            content (Any): Value stored in ``self.content`` under ``content_key``.
            content_key (str): Key used to wrap ``content`` in a dictionary.
            name (Optional[str]): Name to assign; when falsy, ``role_`` is used.
        """
        wrapped = {content_key: content}
        self.role = role_
        self.content = wrapped
        self.name = role_ if not name else name
|
239
|
+
|