lionagi 0.0.114__py3-none-any.whl → 0.0.116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +7 -4
- lionagi/bridge/__init__.py +19 -4
- lionagi/bridge/langchain.py +23 -3
- lionagi/bridge/llama_index.py +5 -3
- lionagi/configs/__init__.py +1 -1
- lionagi/configs/oai_configs.py +88 -1
- lionagi/core/__init__.py +6 -9
- lionagi/core/conversations/__init__.py +5 -0
- lionagi/core/conversations/conversation.py +107 -0
- lionagi/core/flows/__init__.py +8 -0
- lionagi/core/flows/flow.py +8 -0
- lionagi/core/flows/flow_util.py +62 -0
- lionagi/core/instruction_set/__init__.py +5 -0
- lionagi/core/instruction_set/instruction_sets.py +7 -0
- lionagi/core/sessions/__init__.py +5 -0
- lionagi/core/sessions/sessions.py +187 -0
- lionagi/endpoints/__init__.py +5 -0
- lionagi/endpoints/assistants.py +0 -0
- lionagi/endpoints/audio.py +17 -0
- lionagi/endpoints/chatcompletion.py +54 -0
- lionagi/endpoints/embeddings.py +0 -0
- lionagi/endpoints/finetune.py +0 -0
- lionagi/endpoints/image.py +0 -0
- lionagi/endpoints/moderation.py +0 -0
- lionagi/endpoints/vision.py +0 -0
- lionagi/{loader → loaders}/__init__.py +7 -1
- lionagi/{loader → loaders}/chunker.py +6 -12
- lionagi/{utils/load_utils.py → loaders/load_util.py} +47 -6
- lionagi/{loader → loaders}/reader.py +4 -12
- lionagi/messages/__init__.py +11 -0
- lionagi/messages/instruction.py +15 -0
- lionagi/messages/message.py +110 -0
- lionagi/messages/response.py +33 -0
- lionagi/messages/system.py +12 -0
- lionagi/objs/__init__.py +10 -6
- lionagi/objs/abc_objs.py +39 -0
- lionagi/objs/async_queue.py +135 -0
- lionagi/objs/messenger.py +70 -148
- lionagi/objs/status_tracker.py +37 -0
- lionagi/objs/{tool_registry.py → tool_manager.py} +8 -6
- lionagi/schema/__init__.py +3 -3
- lionagi/schema/base_node.py +251 -0
- lionagi/schema/base_tool.py +8 -3
- lionagi/schema/data_logger.py +2 -3
- lionagi/schema/data_node.py +37 -0
- lionagi/services/__init__.py +1 -4
- lionagi/services/base_api_service.py +15 -5
- lionagi/services/oai.py +2 -2
- lionagi/services/openrouter.py +2 -3
- lionagi/structures/graph.py +96 -0
- lionagi/{structure → structures}/relationship.py +10 -2
- lionagi/structures/structure.py +102 -0
- lionagi/tests/test_api_util.py +46 -0
- lionagi/tests/test_call_util.py +115 -0
- lionagi/tests/test_convert_util.py +202 -0
- lionagi/tests/test_encrypt_util.py +33 -0
- lionagi/tests/{test_flatten_util.py → test_flat_util.py} +1 -1
- lionagi/tests/test_io_util.py +0 -0
- lionagi/tests/test_sys_util.py +0 -0
- lionagi/tools/__init__.py +5 -0
- lionagi/tools/tool_util.py +7 -0
- lionagi/utils/__init__.py +55 -35
- lionagi/utils/api_util.py +19 -17
- lionagi/utils/call_util.py +2 -1
- lionagi/utils/convert_util.py +229 -0
- lionagi/utils/encrypt_util.py +16 -0
- lionagi/utils/flat_util.py +38 -0
- lionagi/utils/io_util.py +2 -2
- lionagi/utils/sys_util.py +45 -10
- lionagi/version.py +1 -1
- {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/METADATA +2 -2
- lionagi-0.0.116.dist-info/RECORD +110 -0
- lionagi/core/conversations.py +0 -108
- lionagi/core/flows.py +0 -1
- lionagi/core/instruction_sets.py +0 -1
- lionagi/core/messages.py +0 -166
- lionagi/core/sessions.py +0 -297
- lionagi/schema/base_schema.py +0 -252
- lionagi/services/chatcompletion.py +0 -48
- lionagi/services/service_objs.py +0 -282
- lionagi/structure/structure.py +0 -160
- lionagi/tools/coder.py +0 -1
- lionagi/tools/sandbox.py +0 -1
- lionagi/utils/tool_util.py +0 -92
- lionagi/utils/type_util.py +0 -81
- lionagi-0.0.114.dist-info/RECORD +0 -84
- /lionagi/configs/{openrouter_config.py → openrouter_configs.py} +0 -0
- /lionagi/{datastore → datastores}/__init__.py +0 -0
- /lionagi/{datastore → datastores}/chroma.py +0 -0
- /lionagi/{datastore → datastores}/deeplake.py +0 -0
- /lionagi/{datastore → datastores}/elasticsearch.py +0 -0
- /lionagi/{datastore → datastores}/lantern.py +0 -0
- /lionagi/{datastore → datastores}/pinecone.py +0 -0
- /lionagi/{datastore → datastores}/postgres.py +0 -0
- /lionagi/{datastore → datastores}/qdrant.py +0 -0
- /lionagi/{structure → structures}/__init__.py +0 -0
- {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/LICENSE +0 -0
- {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/WHEEL +0 -0
- {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,9 @@ import tiktoken
|
|
5
5
|
import logging
|
6
6
|
import aiohttp
|
7
7
|
from typing import Generator, NoReturn, Dict, Any, Optional
|
8
|
-
from .
|
8
|
+
from ..objs.status_tracker import StatusTracker
|
9
|
+
from ..objs.abc_objs import BaseService, RateLimiter
|
10
|
+
from ..objs.async_queue import AsyncQueue
|
9
11
|
|
10
12
|
class BaseAPIRateLimiter(RateLimiter):
|
11
13
|
|
@@ -20,7 +22,7 @@ class BaseAPIRateLimiter(RateLimiter):
|
|
20
22
|
) -> None:
|
21
23
|
self = cls(max_requests_per_minute, max_tokens_per_minute)
|
22
24
|
if not os.getenv("env_readthedocs"):
|
23
|
-
self.rate_limit_replenisher_task =
|
25
|
+
self.rate_limit_replenisher_task = asyncio.create_task(
|
24
26
|
self.rate_limit_replenisher()
|
25
27
|
)
|
26
28
|
return self
|
@@ -43,6 +45,7 @@ class BaseAPIRateLimiter(RateLimiter):
|
|
43
45
|
self.available_request_capacity = self.max_requests_per_minute
|
44
46
|
self.available_token_capacity = self.max_tokens_per_minute
|
45
47
|
|
48
|
+
# credit to OpenAI for the following function
|
46
49
|
def calculate_num_token(
|
47
50
|
self,
|
48
51
|
payload: Dict[str, Any] = None,
|
@@ -130,13 +133,16 @@ class BaseAPIService(BaseService):
|
|
130
133
|
def __init__(self, api_key: str = None,
|
131
134
|
status_tracker = None,
|
132
135
|
queue = None, endpoint=None, schema=None,
|
133
|
-
ratelimiter=
|
136
|
+
ratelimiter=BaseAPIRateLimiter, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
|
134
137
|
self.api_key = api_key
|
135
138
|
self.status_tracker = status_tracker or StatusTracker()
|
136
139
|
self.queue = queue or AsyncQueue()
|
137
140
|
self.endpoint=endpoint
|
138
141
|
self.schema = schema
|
139
|
-
self.
|
142
|
+
self.max_requests_per_minute = max_requests_per_minute
|
143
|
+
self.max_tokens_per_minute = max_tokens_per_minute
|
144
|
+
self.rate_limiter_class = ratelimiter
|
145
|
+
self.rate_limiter = None
|
140
146
|
|
141
147
|
@staticmethod
|
142
148
|
def api_methods(http_session, method="post"):
|
@@ -168,6 +174,10 @@ class BaseAPIService(BaseService):
|
|
168
174
|
yield task_id
|
169
175
|
task_id += 1
|
170
176
|
|
177
|
+
async def _init(self):
|
178
|
+
if self.rate_limiter is None:
|
179
|
+
self.rate_limiter = await self.rate_limiter_class.create(self.max_requests_per_minute, self.max_tokens_per_minute)
|
180
|
+
|
171
181
|
async def _call_api(self, http_session, endpoint_, method="post", payload: Dict[str, any] =None) -> Optional[Dict[str, any]]:
|
172
182
|
endpoint_ = self.api_endpoint_from_url("https://api.openai.com/v1/"+endpoint_)
|
173
183
|
|
@@ -224,4 +234,4 @@ class BaseAPIService(BaseService):
|
|
224
234
|
except Exception as e:
|
225
235
|
self.status_tracker.num_tasks_failed += 1
|
226
236
|
raise e
|
227
|
-
|
237
|
+
|
lionagi/services/oai.py
CHANGED
@@ -15,7 +15,7 @@ class OpenAIService(BaseAPIService):
|
|
15
15
|
max_attempts: int = 3,
|
16
16
|
max_requests_per_minute: int = 500,
|
17
17
|
max_tokens_per_minute: int = 150_000,
|
18
|
-
ratelimiter = BaseAPIRateLimiter
|
18
|
+
ratelimiter = BaseAPIRateLimiter,
|
19
19
|
status_tracker = None,
|
20
20
|
queue = None,
|
21
21
|
):
|
@@ -31,4 +31,4 @@ class OpenAIService(BaseAPIService):
|
|
31
31
|
|
32
32
|
async def serve(self, payload, endpoint_="chat/completions", method="post"):
|
33
33
|
return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
|
34
|
-
|
34
|
+
|
lionagi/services/openrouter.py
CHANGED
@@ -2,8 +2,7 @@ from os import getenv
|
|
2
2
|
from .base_api_service import BaseAPIService, BaseAPIRateLimiter
|
3
3
|
|
4
4
|
class OpenRouterService(BaseAPIService):
|
5
|
-
|
6
|
-
|
5
|
+
key_scheme = "OPENROUTER_API_KEY"
|
7
6
|
base_url = "https://openrouter.ai/api/v1/"
|
8
7
|
|
9
8
|
def __init__(
|
@@ -18,7 +17,7 @@ class OpenRouterService(BaseAPIService):
|
|
18
17
|
queue = None,
|
19
18
|
):
|
20
19
|
super().__init__(
|
21
|
-
api_key = api_key or getenv(self.
|
20
|
+
api_key = api_key or getenv(self.key_scheme),
|
22
21
|
status_tracker = status_tracker,
|
23
22
|
queue = queue,
|
24
23
|
ratelimiter=ratelimiter,
|
@@ -0,0 +1,96 @@
|
|
1
|
+
from pydantic import Field
|
2
|
+
|
3
|
+
from lionagi.schema.base_node import BaseNode
|
4
|
+
from .relationship import Relationship
|
5
|
+
from lionagi.utils.call_util import lcall
|
6
|
+
|
7
|
+
|
8
|
+
class Graph(BaseNode):
|
9
|
+
nodes: dict = Field(default={})
|
10
|
+
relationships: dict = Field(default={})
|
11
|
+
node_relationships: dict = Field(default={})
|
12
|
+
|
13
|
+
def add_node(self, node: BaseNode):
|
14
|
+
self.nodes[node.id_] = node
|
15
|
+
self.node_relationships[node.id_] = {'in': {}, 'out': {}}
|
16
|
+
|
17
|
+
def add_relationship(self, relationships: Relationship):
|
18
|
+
if relationships.source_node_id not in self.node_relationships.keys():
|
19
|
+
raise KeyError(f'node {relationships.source_node_id} is not found.')
|
20
|
+
if relationships.target_node_id not in self.node_relationships.keys():
|
21
|
+
raise KeyError(f'node {relationships.target_node_id} is not found.')
|
22
|
+
|
23
|
+
self.relationships[relationships.id_] = relationships
|
24
|
+
self.node_relationships[relationships.source_node_id]['out'][relationships.id_] = relationships.target_node_id
|
25
|
+
self.node_relationships[relationships.target_node_id]['in'][relationships.id_] = relationships.source_node_id
|
26
|
+
|
27
|
+
def get_node_relationships(self, node: BaseNode = None, out_edge=True):
|
28
|
+
if node is None:
|
29
|
+
return list(self.relationships.values())
|
30
|
+
|
31
|
+
if node.id_ not in self.nodes.keys():
|
32
|
+
raise KeyError(f'node {node.id_} is not found')
|
33
|
+
|
34
|
+
if out_edge:
|
35
|
+
relationship_ids = list(self.node_relationships[node.id_]['out'].keys())
|
36
|
+
relationships = lcall(relationship_ids, lambda i: self.relationships[i])
|
37
|
+
return relationships
|
38
|
+
else:
|
39
|
+
relationship_ids = list(self.node_relationships[node.id_]['in'].keys())
|
40
|
+
relationships = lcall(relationship_ids, lambda i: self.relationships[i])
|
41
|
+
return relationships
|
42
|
+
|
43
|
+
def remove_node(self, node: BaseNode):
|
44
|
+
if node.id_ not in self.nodes.keys():
|
45
|
+
raise KeyError(f'node {node.id_} is not found')
|
46
|
+
|
47
|
+
out_relationship = self.node_relationships[node.id_]['out']
|
48
|
+
for relationship_id, node_id in out_relationship.items():
|
49
|
+
self.node_relationships[node_id]['in'].pop(relationship_id)
|
50
|
+
self.relationships.pop(relationship_id)
|
51
|
+
|
52
|
+
in_relationship = self.node_relationships[node.id_]['in']
|
53
|
+
for relationship_id, node_id in in_relationship.items():
|
54
|
+
self.node_relationships[node_id]['out'].pop(relationship_id)
|
55
|
+
self.relationships.pop(relationship_id)
|
56
|
+
|
57
|
+
self.node_relationships.pop(node.id_)
|
58
|
+
return self.nodes.pop(node.id_)
|
59
|
+
|
60
|
+
def remove_relationship(self, relationship: Relationship):
|
61
|
+
if relationship.id_ not in self.relationships.keys():
|
62
|
+
raise KeyError(f'relationship {relationship.id_} is not found')
|
63
|
+
|
64
|
+
self.node_relationships[relationship.source_node_id]['out'].pop(relationship.id_)
|
65
|
+
self.node_relationships[relationship.target_node_id]['in'].pop(relationship.id_)
|
66
|
+
|
67
|
+
return self.relationships.pop(relationship.id_)
|
68
|
+
|
69
|
+
def node_exists(self, node: BaseNode):
|
70
|
+
if node.id_ in self.nodes.keys():
|
71
|
+
return True
|
72
|
+
else:
|
73
|
+
return False
|
74
|
+
|
75
|
+
def relationship_exists(self, relationship: Relationship):
|
76
|
+
if relationship.id_ in self.relationships.keys():
|
77
|
+
return True
|
78
|
+
else:
|
79
|
+
return False
|
80
|
+
|
81
|
+
def to_networkx(self, **kwargs):
|
82
|
+
import networkx as nx
|
83
|
+
g = nx.DiGraph(**kwargs)
|
84
|
+
for node_id, node in self.nodes.items():
|
85
|
+
node_info = node.to_dict()
|
86
|
+
node_info.pop('node_id')
|
87
|
+
g.add_node(node_id, **node_info)
|
88
|
+
|
89
|
+
for relationship_id, relationship in self.relationships.items():
|
90
|
+
relationship_info = relationship.to_dict()
|
91
|
+
relationship_info.pop('node_id')
|
92
|
+
source_node_id = relationship_info.pop('source_node_id')
|
93
|
+
target_node_id = relationship_info.pop('target_node_id')
|
94
|
+
g.add_edge(source_node_id, target_node_id, **relationship_info)
|
95
|
+
|
96
|
+
return g
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from pydantic import Field
|
2
2
|
from typing import Dict, Optional, Any
|
3
|
-
from ..schema.
|
3
|
+
from ..schema.base_node import BaseNode
|
4
4
|
|
5
5
|
|
6
6
|
class Relationship(BaseNode):
|
@@ -125,4 +125,12 @@ class Relationship(BaseNode):
|
|
125
125
|
raise ValueError(f"Target node {self.source_node_id} does not exist")
|
126
126
|
else :
|
127
127
|
raise ValueError(f"Source node {self.target_node_id} does not exist")
|
128
|
-
|
128
|
+
|
129
|
+
def __str__(self) -> str:
|
130
|
+
"""Returns a simple string representation of the Relationship."""
|
131
|
+
return f"Relationship (id_={self.id_}, from={self.source_node_id}, to={self.target_node_id}, label={self.label})"
|
132
|
+
|
133
|
+
def __repr__(self) -> str:
|
134
|
+
"""Returns a detailed string representation of the Relationship."""
|
135
|
+
return f"Relationship(id_={self.id_}, from={self.source_node_id}, to={self.target_node_id}, content={self.content}, " \
|
136
|
+
f"metadata={self.metadata}, label={self.label})"
|
@@ -0,0 +1,102 @@
|
|
1
|
+
from typing import TypeVar
|
2
|
+
from .graph import Graph
|
3
|
+
from ..schema import BaseNode
|
4
|
+
from .relationship import Relationship
|
5
|
+
|
6
|
+
T = TypeVar('T', bound='BaseNode')
|
7
|
+
R = TypeVar('R', bound='Relationship')
|
8
|
+
|
9
|
+
|
10
|
+
class Structure(BaseNode):
|
11
|
+
"""
|
12
|
+
Represents the structure of a graph consisting of nodes and relationships.
|
13
|
+
"""
|
14
|
+
graph: Graph
|
15
|
+
|
16
|
+
def add_node(self, node: T) -> None:
|
17
|
+
"""
|
18
|
+
Adds a node to the structure.
|
19
|
+
|
20
|
+
Args:
|
21
|
+
node (T): The node instance to be added.
|
22
|
+
"""
|
23
|
+
self.graph.add_node(node)
|
24
|
+
|
25
|
+
def add_relationship(self, relationship: R) -> None:
|
26
|
+
"""
|
27
|
+
Adds a relationship to the structure.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
relationship (R): The relationship instance to be added.
|
31
|
+
"""
|
32
|
+
self.graph.add_relationship(relationship)
|
33
|
+
|
34
|
+
# type can be dict or list
|
35
|
+
# @staticmethod
|
36
|
+
# def _typed_return(type: Type[Union[Dict, List]],
|
37
|
+
# obj: Optional[Dict[str, Any]] = None
|
38
|
+
# ) -> Union[Dict[str, Any], List[Any]]:
|
39
|
+
# """
|
40
|
+
# Returns the object in the specified type format.
|
41
|
+
#
|
42
|
+
# Args:
|
43
|
+
# type (Type[Union[Dict, List]]): The type to return the object as (dict or list).
|
44
|
+
#
|
45
|
+
# obj (Optional[Dict[str, Any]]): The object to be converted.
|
46
|
+
#
|
47
|
+
# Returns:
|
48
|
+
# Union[Dict[str, Any], List[Any]]: The object in the specified type format.
|
49
|
+
# """
|
50
|
+
# if type is list:
|
51
|
+
# return list(obj.values())
|
52
|
+
# return obj
|
53
|
+
|
54
|
+
def get_relationships(self) -> list[R]:
|
55
|
+
return self.graph.get_node_relationships()
|
56
|
+
|
57
|
+
def get_node_relationships(self, node: T, out_edge=True) -> R:
|
58
|
+
return self.graph.get_node_relationships(node, out_edge)
|
59
|
+
|
60
|
+
def node_exist(self, node: T) -> bool:
|
61
|
+
"""
|
62
|
+
Checks if a node exists in the structure.
|
63
|
+
|
64
|
+
Args:
|
65
|
+
node (T): The node instance or node ID to check for existence.
|
66
|
+
|
67
|
+
Returns:
|
68
|
+
bool: True if the node exists, False otherwise.
|
69
|
+
"""
|
70
|
+
|
71
|
+
return self.graph.node_exist(node)
|
72
|
+
|
73
|
+
def relationship_exist(self, relationship: R) -> bool:
|
74
|
+
"""
|
75
|
+
Checks if a relationship exists in the structure.
|
76
|
+
|
77
|
+
Args:
|
78
|
+
relationship (R): The relationship instance to check for existence.
|
79
|
+
|
80
|
+
Returns:
|
81
|
+
bool: True if the relationship exists, False otherwise.
|
82
|
+
"""
|
83
|
+
return self.graph.relationship_exists(relationship)
|
84
|
+
|
85
|
+
def remove_node(self, node: T) -> T:
|
86
|
+
"""
|
87
|
+
Removes a node and its associated relationships from the structure.
|
88
|
+
|
89
|
+
Args:
|
90
|
+
node (T): The node instance or node ID to be removed.
|
91
|
+
"""
|
92
|
+
return self.graph.remove_node(node)
|
93
|
+
|
94
|
+
def remove_relationship(self, relationship: R) -> R:
|
95
|
+
"""
|
96
|
+
Removes a relationship from the structure.
|
97
|
+
|
98
|
+
Args:
|
99
|
+
relationship (R): The relationship instance to be removed.
|
100
|
+
"""
|
101
|
+
return self.graph.remove_relationship(relationship)
|
102
|
+
|
@@ -0,0 +1,46 @@
|
|
1
|
+
import unittest
|
2
|
+
from unittest.mock import MagicMock
|
3
|
+
|
4
|
+
# Assuming the Python module with the above functions is named 'api_utils'
|
5
|
+
from lionagi.utils.api_util import *
|
6
|
+
|
7
|
+
class TestApiUtils(unittest.TestCase):
|
8
|
+
|
9
|
+
def test_api_method_valid(self):
|
10
|
+
session = MagicMock()
|
11
|
+
methods = ['post', 'delete', 'head', 'options', 'patch']
|
12
|
+
for method in methods:
|
13
|
+
with self.subTest(method=method):
|
14
|
+
func = api_method(session, method)
|
15
|
+
self.assertTrue(callable(func))
|
16
|
+
|
17
|
+
def test_api_method_invalid(self):
|
18
|
+
session = MagicMock()
|
19
|
+
with self.assertRaises(ValueError):
|
20
|
+
api_method(session, 'get')
|
21
|
+
|
22
|
+
def test_api_error(self):
|
23
|
+
response_with_error = {'error': 'Something went wrong'}
|
24
|
+
response_without_error = {'result': 'Success'}
|
25
|
+
|
26
|
+
self.assertTrue(api_error(response_with_error))
|
27
|
+
self.assertFalse(api_error(response_without_error))
|
28
|
+
|
29
|
+
def test_api_rate_limit_error(self):
|
30
|
+
response_with_rate_limit_error = {'error': {'message': 'Rate limit exceeded'}}
|
31
|
+
response_without_rate_limit_error = {'error': {'message': 'Another error'}}
|
32
|
+
|
33
|
+
self.assertTrue(api_rate_limit_error(response_with_rate_limit_error))
|
34
|
+
self.assertFalse(api_rate_limit_error(response_without_rate_limit_error))
|
35
|
+
|
36
|
+
def test_api_endpoint_from_url(self):
|
37
|
+
url_with_endpoint = "https://api.example.com/v1/users"
|
38
|
+
url_without_endpoint = "https://api.example.com/users"
|
39
|
+
url_invalid = "Just a string"
|
40
|
+
|
41
|
+
self.assertEqual(api_endpoint_from_url(url_with_endpoint), 'users')
|
42
|
+
self.assertEqual(api_endpoint_from_url(url_without_endpoint), '')
|
43
|
+
self.assertEqual(api_endpoint_from_url(url_invalid), '')
|
44
|
+
|
45
|
+
if __name__ == '__main__':
|
46
|
+
unittest.main()
|
@@ -0,0 +1,115 @@
|
|
1
|
+
# import asyncio
|
2
|
+
# import unittest
|
3
|
+
# from unittest.mock import patch, MagicMock
|
4
|
+
# from lionagi.utils.sys_util import create_copy
|
5
|
+
# from lionagi.utils.flat_util import to_list
|
6
|
+
# from typing import Callable
|
7
|
+
|
8
|
+
# # Assuming the Python module with the above functions is named 'call_utils'
|
9
|
+
# from lionagi.utils.call_util import *
|
10
|
+
|
11
|
+
# class TestCallUtils(unittest.TestCase):
|
12
|
+
|
13
|
+
# def setUp(self):
|
14
|
+
# self.sample_input = [1, 2, 3]
|
15
|
+
# self.sample_func = MagicMock(return_value=42)
|
16
|
+
|
17
|
+
# def test_hcall(self):
|
18
|
+
# with patch('time.sleep', return_value=None):
|
19
|
+
# result = hcall(self.sample_input, self.sample_func)
|
20
|
+
# self.sample_func.assert_called_once_with(self.sample_input)
|
21
|
+
# self.assertEqual(result, 42)
|
22
|
+
|
23
|
+
# def test_ahcall(self):
|
24
|
+
# async def async_test():
|
25
|
+
# self.sample_func.reset_mock()
|
26
|
+
# with patch('asyncio.sleep', return_value=None):
|
27
|
+
# result = await ahcall(self.sample_input, self.sample_func)
|
28
|
+
# self.sample_func.assert_called_once_with(self.sample_input)
|
29
|
+
# self.assertEqual(result, 42)
|
30
|
+
|
31
|
+
# asyncio.run(async_test())
|
32
|
+
|
33
|
+
# def test_lcall(self):
|
34
|
+
# expected_result = [42, 42, 42]
|
35
|
+
# result = lcall(self.sample_input, self.sample_func)
|
36
|
+
# self.assertEqual(result, expected_result)
|
37
|
+
# calls = [unittest.mock.call(item) for item in self.sample_input]
|
38
|
+
# self.sample_func.assert_has_calls(calls, any_order=True)
|
39
|
+
|
40
|
+
# def test_alcall(self):
|
41
|
+
# async def async_test():
|
42
|
+
# expected_result = [42, 42, 42]
|
43
|
+
# self.sample_func.reset_mock()
|
44
|
+
# result = await alcall(self.sample_input, self.sample_func)
|
45
|
+
# self.assertEqual(result, expected_result)
|
46
|
+
# calls = [unittest.mock.call(item) for item in self.sample_input]
|
47
|
+
# self.sample_func.assert_has_calls(calls, any_order=True)
|
48
|
+
|
49
|
+
# asyncio.run(async_test())
|
50
|
+
|
51
|
+
|
52
|
+
# class TestCallUtils2(unittest.TestCase):
|
53
|
+
|
54
|
+
# def setUp(self):
|
55
|
+
# self.sample_input = [1, 2, 3]
|
56
|
+
# self.sample_func = MagicMock(side_effect=lambda x: x + 1)
|
57
|
+
# self.sample_async_func = MagicMock(side_effect=lambda x: x + 1)
|
58
|
+
|
59
|
+
# def test_mcall_single_function(self):
|
60
|
+
# result = mcall(self.sample_input, self.sample_func)
|
61
|
+
# self.sample_func.assert_has_calls([unittest.mock.call(i) for i in self.sample_input], any_order=True)
|
62
|
+
# self.assertEqual(result, [2, 3, 4])
|
63
|
+
|
64
|
+
# def test_mcall_multiple_functions(self):
|
65
|
+
# # Define multiple functions
|
66
|
+
# funcs = [
|
67
|
+
# MagicMock(side_effect=lambda x: x + 1),
|
68
|
+
# MagicMock(side_effect=lambda x: x * 2),
|
69
|
+
# MagicMock(side_effect=lambda x: x - 1)
|
70
|
+
# ]
|
71
|
+
# result = mcall(self.sample_input, funcs)
|
72
|
+
# for i, func in enumerate(funcs):
|
73
|
+
# func.assert_called_once_with(self.sample_input[i])
|
74
|
+
# self.assertEqual(result, [2, 4, 2])
|
75
|
+
|
76
|
+
# def test_amcall_single_function(self):
|
77
|
+
# async def async_test():
|
78
|
+
# result = await amcall(self.sample_input, self.sample_async_func)
|
79
|
+
# self.sample_async_func.assert_has_calls([unittest.mock.call(i) for i in self.sample_input], any_order=True)
|
80
|
+
# self.assertEqual(result, [2, 3, 4])
|
81
|
+
|
82
|
+
# asyncio.run(async_test())
|
83
|
+
|
84
|
+
# def test_amcall_multiple_functions(self):
|
85
|
+
# # Define multiple asynchronous functions
|
86
|
+
# async_funcs = [
|
87
|
+
# MagicMock(side_effect=lambda x: x + 1),
|
88
|
+
# MagicMock(side_effect=lambda x: x * 2),
|
89
|
+
# MagicMock(side_effect=lambda x: x - 1)
|
90
|
+
# ]
|
91
|
+
|
92
|
+
# async def async_test():
|
93
|
+
# result = await amcall(self.sample_input, async_funcs)
|
94
|
+
# for i, func in enumerate(async_funcs):
|
95
|
+
# func.assert_called_once_with(self.sample_input[i])
|
96
|
+
# self.assertEqual(result, [2, 4, 2])
|
97
|
+
|
98
|
+
# asyncio.run(async_test())
|
99
|
+
|
100
|
+
# def test_ecall(self):
|
101
|
+
# funcs = [MagicMock(side_effect=lambda x: x * x)]
|
102
|
+
# result = ecall(self.sample_input, funcs)
|
103
|
+
# self.assertEqual(result, [[1, 4, 9]])
|
104
|
+
|
105
|
+
# def test_aecall(self):
|
106
|
+
# async_funcs = [MagicMock(side_effect=lambda x: x * x)]
|
107
|
+
|
108
|
+
# async def async_test():
|
109
|
+
# result = await aecall(self.sample_input, async_funcs)
|
110
|
+
# self.assertEqual(result, [[1, 4, 9]])
|
111
|
+
|
112
|
+
# asyncio.run(async_test())
|
113
|
+
|
114
|
+
# if __name__ == '__main__':
|
115
|
+
# unittest.main()
|
@@ -0,0 +1,202 @@
|
|
1
|
+
import unittest
|
2
|
+
from lionagi.utils.convert_util import *
|
3
|
+
|
4
|
+
# Test cases for the function
|
5
|
+
class TestStrToNum(unittest.TestCase):
|
6
|
+
def test_valid_int(self):
|
7
|
+
self.assertEqual(str_to_num("123"), 123)
|
8
|
+
self.assertEqual(str_to_num("-123"), -123)
|
9
|
+
|
10
|
+
def test_valid_float(self):
|
11
|
+
self.assertEqual(str_to_num("123.45", num_type=float), 123.45)
|
12
|
+
self.assertEqual(str_to_num("-123.45", num_type=float), -123.45)
|
13
|
+
|
14
|
+
def test_precision(self):
|
15
|
+
self.assertEqual(str_to_num("123.456", num_type=float, precision=1), 123.5)
|
16
|
+
self.assertEqual(str_to_num("123.444", num_type=float, precision=2), 123.44)
|
17
|
+
|
18
|
+
def test_bounds(self):
|
19
|
+
self.assertEqual(str_to_num("10", lower_bound=5, upper_bound=15), 10)
|
20
|
+
with self.assertRaises(ValueError):
|
21
|
+
str_to_num("20", upper_bound=15)
|
22
|
+
with self.assertRaises(ValueError):
|
23
|
+
str_to_num("2", lower_bound=5)
|
24
|
+
|
25
|
+
def test_invalid_input(self):
|
26
|
+
with self.assertRaises(ValueError):
|
27
|
+
str_to_num("abc")
|
28
|
+
with self.assertRaises(ValueError):
|
29
|
+
str_to_num("123abc", num_type=str)
|
30
|
+
|
31
|
+
def test_no_numeric_value(self):
|
32
|
+
with self.assertRaises(ValueError):
|
33
|
+
str_to_num("No numbers here")
|
34
|
+
|
35
|
+
|
36
|
+
|
37
|
+
# Functions to be tested
|
38
|
+
def dict_to_xml(data: Dict[str, Any], root_tag: str = 'node') -> str:
|
39
|
+
root = ET.Element(root_tag)
|
40
|
+
_build_xml(root, data)
|
41
|
+
return ET.tostring(root, encoding='unicode')
|
42
|
+
|
43
|
+
def _build_xml(element: ET.Element, data: Any):
|
44
|
+
if isinstance(data, dict):
|
45
|
+
for key, value in data.items():
|
46
|
+
sub_element = ET.SubElement(element, key)
|
47
|
+
_build_xml(sub_element, value)
|
48
|
+
elif isinstance(data, list):
|
49
|
+
for item in data:
|
50
|
+
item_element = ET.SubElement(element, 'item')
|
51
|
+
_build_xml(item_element, item)
|
52
|
+
else:
|
53
|
+
element.text = str(data)
|
54
|
+
|
55
|
+
def xml_to_dict(element: ET.Element) -> Dict[str, Any]:
|
56
|
+
dict_data = {}
|
57
|
+
for child in element:
|
58
|
+
if list(child):
|
59
|
+
dict_data[child.tag] = xml_to_dict(child)
|
60
|
+
else:
|
61
|
+
dict_data[child.tag] = child.text
|
62
|
+
return dict_data
|
63
|
+
|
64
|
+
# Test cases for the functions
|
65
|
+
class TestDictXMLConversion(unittest.TestCase):
|
66
|
+
def setUp(self):
|
67
|
+
self.data = {
|
68
|
+
'name': 'John',
|
69
|
+
'age': 30,
|
70
|
+
'children': [
|
71
|
+
{'name': 'Alice', 'age': 5},
|
72
|
+
{'name': 'Bob', 'age': 7}
|
73
|
+
]
|
74
|
+
}
|
75
|
+
self.root_tag = 'person'
|
76
|
+
self.xml = dict_to_xml(self.data, self.root_tag)
|
77
|
+
self.xml_element = ET.fromstring(self.xml)
|
78
|
+
|
79
|
+
def test_dict_to_xml(self):
|
80
|
+
self.assertIn('<name>John</name>', self.xml)
|
81
|
+
self.assertIn('<age>30</age>', self.xml)
|
82
|
+
self.assertIn('<children>', self.xml)
|
83
|
+
self.assertIn('<item>', self.xml)
|
84
|
+
|
85
|
+
# def test_xml_to_dict(self):
|
86
|
+
# data_from_xml = xml_to_dict(self.xml_element)
|
87
|
+
# self.assertEqual(data_from_xml, self.data)
|
88
|
+
|
89
|
+
# def test_xml_to_dict_to_xml(self):
|
90
|
+
# data_from_xml = xml_to_dict(self.xml_element)
|
91
|
+
# xml_from_dict = dict_to_xml(data_from_xml, self.root_tag)
|
92
|
+
# self.assertEqual(xml_from_dict, self.xml)
|
93
|
+
|
94
|
+
# def test_invalid_input(self):
|
95
|
+
# with self.assertRaises(TypeError):
|
96
|
+
# dict_to_xml("not a dict", self.root_tag)
|
97
|
+
|
98
|
+
|
99
|
+
|
100
|
+
class TestDocstringExtraction(unittest.TestCase):
|
101
|
+
def test_google_style_extraction(self):
|
102
|
+
def sample_func(arg1, arg2):
|
103
|
+
"""
|
104
|
+
This is a sample function.
|
105
|
+
|
106
|
+
Args:
|
107
|
+
arg1 (int): The first argument.
|
108
|
+
arg2 (str): The second argument.
|
109
|
+
|
110
|
+
Returns:
|
111
|
+
bool: The truth value.
|
112
|
+
"""
|
113
|
+
return True
|
114
|
+
|
115
|
+
description, params = extract_docstring_details_google(sample_func)
|
116
|
+
self.assertEqual(description, "This is a sample function.")
|
117
|
+
self.assertEqual(params, {
|
118
|
+
'arg1': 'The first argument.',
|
119
|
+
'arg2': 'The second argument.'
|
120
|
+
})
|
121
|
+
|
122
|
+
def test_rest_style_extraction(self):
|
123
|
+
def sample_func(arg1, arg2):
|
124
|
+
"""
|
125
|
+
This is a sample function.
|
126
|
+
|
127
|
+
:param int arg1: The first argument.
|
128
|
+
:param str arg2: The second argument.
|
129
|
+
:return: The truth value.
|
130
|
+
:rtype: bool
|
131
|
+
"""
|
132
|
+
return True
|
133
|
+
|
134
|
+
description, params = extract_docstring_details_rest(sample_func)
|
135
|
+
self.assertEqual(description, "This is a sample function.")
|
136
|
+
self.assertEqual(params, {
|
137
|
+
'arg1': 'The first argument.',
|
138
|
+
'arg2': 'The second argument.'
|
139
|
+
})
|
140
|
+
|
141
|
+
def test_extract_docstring_details_with_invalid_style(self):
|
142
|
+
def sample_func(arg1, arg2):
|
143
|
+
return True
|
144
|
+
|
145
|
+
with self.assertRaises(ValueError):
|
146
|
+
extract_docstring_details(sample_func, style='unsupported')
|
147
|
+
|
148
|
+
class TestPythonToJsonTypeConversion(unittest.TestCase):
|
149
|
+
def test_python_to_json_type_conversion(self):
|
150
|
+
self.assertEqual(python_to_json_type('str'), 'string')
|
151
|
+
self.assertEqual(python_to_json_type('int'), 'number')
|
152
|
+
self.assertEqual(python_to_json_type('float'), 'number')
|
153
|
+
self.assertEqual(python_to_json_type('list'), 'array')
|
154
|
+
self.assertEqual(python_to_json_type('tuple'), 'array')
|
155
|
+
self.assertEqual(python_to_json_type('bool'), 'boolean')
|
156
|
+
self.assertEqual(python_to_json_type('dict'), 'object')
|
157
|
+
self.assertEqual(python_to_json_type('nonexistent'), 'object')
|
158
|
+
|
159
|
+
class TestFunctionToSchema(unittest.TestCase):
|
160
|
+
def test_func_to_schema(self):
|
161
|
+
def sample_func(arg1: int, arg2: str) -> bool:
|
162
|
+
"""
|
163
|
+
This is a sample function.
|
164
|
+
|
165
|
+
Args:
|
166
|
+
arg1 (int): The first argument.
|
167
|
+
arg2 (str): The second argument.
|
168
|
+
|
169
|
+
Returns:
|
170
|
+
bool: The truth value.
|
171
|
+
"""
|
172
|
+
return True
|
173
|
+
|
174
|
+
schema = func_to_schema(sample_func, style='google')
|
175
|
+
expected_schema = {
|
176
|
+
"type": "function",
|
177
|
+
"function": {
|
178
|
+
"name": "sample_func",
|
179
|
+
"description": "This is a sample function.",
|
180
|
+
"parameters": {
|
181
|
+
"type": "object",
|
182
|
+
"properties": {
|
183
|
+
"arg1": {
|
184
|
+
"type": "number",
|
185
|
+
"description": "The first argument."
|
186
|
+
},
|
187
|
+
"arg2": {
|
188
|
+
"type": "string",
|
189
|
+
"description": "The second argument."
|
190
|
+
}
|
191
|
+
},
|
192
|
+
"required": ["arg1", "arg2"]
|
193
|
+
}
|
194
|
+
}
|
195
|
+
}
|
196
|
+
self.assertEqual(schema, expected_schema)
|
197
|
+
|
198
|
+
if __name__ == '__main__':
|
199
|
+
unittest.main()
|
200
|
+
|
201
|
+
|
202
|
+
# --------------------------------------------------
|