lionagi 0.0.114__py3-none-any.whl → 0.0.116__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- lionagi/__init__.py +7 -4
- lionagi/bridge/__init__.py +19 -4
- lionagi/bridge/langchain.py +23 -3
- lionagi/bridge/llama_index.py +5 -3
- lionagi/configs/__init__.py +1 -1
- lionagi/configs/oai_configs.py +88 -1
- lionagi/core/__init__.py +6 -9
- lionagi/core/conversations/__init__.py +5 -0
- lionagi/core/conversations/conversation.py +107 -0
- lionagi/core/flows/__init__.py +8 -0
- lionagi/core/flows/flow.py +8 -0
- lionagi/core/flows/flow_util.py +62 -0
- lionagi/core/instruction_set/__init__.py +5 -0
- lionagi/core/instruction_set/instruction_sets.py +7 -0
- lionagi/core/sessions/__init__.py +5 -0
- lionagi/core/sessions/sessions.py +187 -0
- lionagi/endpoints/__init__.py +5 -0
- lionagi/endpoints/assistants.py +0 -0
- lionagi/endpoints/audio.py +17 -0
- lionagi/endpoints/chatcompletion.py +54 -0
- lionagi/endpoints/embeddings.py +0 -0
- lionagi/endpoints/finetune.py +0 -0
- lionagi/endpoints/image.py +0 -0
- lionagi/endpoints/moderation.py +0 -0
- lionagi/endpoints/vision.py +0 -0
- lionagi/{loader → loaders}/__init__.py +7 -1
- lionagi/{loader → loaders}/chunker.py +6 -12
- lionagi/{utils/load_utils.py → loaders/load_util.py} +47 -6
- lionagi/{loader → loaders}/reader.py +4 -12
- lionagi/messages/__init__.py +11 -0
- lionagi/messages/instruction.py +15 -0
- lionagi/messages/message.py +110 -0
- lionagi/messages/response.py +33 -0
- lionagi/messages/system.py +12 -0
- lionagi/objs/__init__.py +10 -6
- lionagi/objs/abc_objs.py +39 -0
- lionagi/objs/async_queue.py +135 -0
- lionagi/objs/messenger.py +70 -148
- lionagi/objs/status_tracker.py +37 -0
- lionagi/objs/{tool_registry.py → tool_manager.py} +8 -6
- lionagi/schema/__init__.py +3 -3
- lionagi/schema/base_node.py +251 -0
- lionagi/schema/base_tool.py +8 -3
- lionagi/schema/data_logger.py +2 -3
- lionagi/schema/data_node.py +37 -0
- lionagi/services/__init__.py +1 -4
- lionagi/services/base_api_service.py +15 -5
- lionagi/services/oai.py +2 -2
- lionagi/services/openrouter.py +2 -3
- lionagi/structures/graph.py +96 -0
- lionagi/{structure → structures}/relationship.py +10 -2
- lionagi/structures/structure.py +102 -0
- lionagi/tests/test_api_util.py +46 -0
- lionagi/tests/test_call_util.py +115 -0
- lionagi/tests/test_convert_util.py +202 -0
- lionagi/tests/test_encrypt_util.py +33 -0
- lionagi/tests/{test_flatten_util.py → test_flat_util.py} +1 -1
- lionagi/tests/test_io_util.py +0 -0
- lionagi/tests/test_sys_util.py +0 -0
- lionagi/tools/__init__.py +5 -0
- lionagi/tools/tool_util.py +7 -0
- lionagi/utils/__init__.py +55 -35
- lionagi/utils/api_util.py +19 -17
- lionagi/utils/call_util.py +2 -1
- lionagi/utils/convert_util.py +229 -0
- lionagi/utils/encrypt_util.py +16 -0
- lionagi/utils/flat_util.py +38 -0
- lionagi/utils/io_util.py +2 -2
- lionagi/utils/sys_util.py +45 -10
- lionagi/version.py +1 -1
- {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/METADATA +2 -2
- lionagi-0.0.116.dist-info/RECORD +110 -0
- lionagi/core/conversations.py +0 -108
- lionagi/core/flows.py +0 -1
- lionagi/core/instruction_sets.py +0 -1
- lionagi/core/messages.py +0 -166
- lionagi/core/sessions.py +0 -297
- lionagi/schema/base_schema.py +0 -252
- lionagi/services/chatcompletion.py +0 -48
- lionagi/services/service_objs.py +0 -282
- lionagi/structure/structure.py +0 -160
- lionagi/tools/coder.py +0 -1
- lionagi/tools/sandbox.py +0 -1
- lionagi/utils/tool_util.py +0 -92
- lionagi/utils/type_util.py +0 -81
- lionagi-0.0.114.dist-info/RECORD +0 -84
- /lionagi/configs/{openrouter_config.py → openrouter_configs.py} +0 -0
- /lionagi/{datastore → datastores}/__init__.py +0 -0
- /lionagi/{datastore → datastores}/chroma.py +0 -0
- /lionagi/{datastore → datastores}/deeplake.py +0 -0
- /lionagi/{datastore → datastores}/elasticsearch.py +0 -0
- /lionagi/{datastore → datastores}/lantern.py +0 -0
- /lionagi/{datastore → datastores}/pinecone.py +0 -0
- /lionagi/{datastore → datastores}/postgres.py +0 -0
- /lionagi/{datastore → datastores}/qdrant.py +0 -0
- /lionagi/{structure → structures}/__init__.py +0 -0
- {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/LICENSE +0 -0
- {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/WHEEL +0 -0
- {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,9 @@ import tiktoken
|
|
5
5
|
import logging
|
6
6
|
import aiohttp
|
7
7
|
from typing import Generator, NoReturn, Dict, Any, Optional
|
8
|
-
from .
|
8
|
+
from ..objs.status_tracker import StatusTracker
|
9
|
+
from ..objs.abc_objs import BaseService, RateLimiter
|
10
|
+
from ..objs.async_queue import AsyncQueue
|
9
11
|
|
10
12
|
class BaseAPIRateLimiter(RateLimiter):
|
11
13
|
|
@@ -20,7 +22,7 @@ class BaseAPIRateLimiter(RateLimiter):
|
|
20
22
|
) -> None:
|
21
23
|
self = cls(max_requests_per_minute, max_tokens_per_minute)
|
22
24
|
if not os.getenv("env_readthedocs"):
|
23
|
-
self.rate_limit_replenisher_task =
|
25
|
+
self.rate_limit_replenisher_task = asyncio.create_task(
|
24
26
|
self.rate_limit_replenisher()
|
25
27
|
)
|
26
28
|
return self
|
@@ -43,6 +45,7 @@ class BaseAPIRateLimiter(RateLimiter):
|
|
43
45
|
self.available_request_capacity = self.max_requests_per_minute
|
44
46
|
self.available_token_capacity = self.max_tokens_per_minute
|
45
47
|
|
48
|
+
# credit to OpenAI for the following function
|
46
49
|
def calculate_num_token(
|
47
50
|
self,
|
48
51
|
payload: Dict[str, Any] = None,
|
@@ -130,13 +133,16 @@ class BaseAPIService(BaseService):
|
|
130
133
|
def __init__(self, api_key: str = None,
|
131
134
|
status_tracker = None,
|
132
135
|
queue = None, endpoint=None, schema=None,
|
133
|
-
ratelimiter=
|
136
|
+
ratelimiter=BaseAPIRateLimiter, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
|
134
137
|
self.api_key = api_key
|
135
138
|
self.status_tracker = status_tracker or StatusTracker()
|
136
139
|
self.queue = queue or AsyncQueue()
|
137
140
|
self.endpoint=endpoint
|
138
141
|
self.schema = schema
|
139
|
-
self.
|
142
|
+
self.max_requests_per_minute = max_requests_per_minute
|
143
|
+
self.max_tokens_per_minute = max_tokens_per_minute
|
144
|
+
self.rate_limiter_class = ratelimiter
|
145
|
+
self.rate_limiter = None
|
140
146
|
|
141
147
|
@staticmethod
|
142
148
|
def api_methods(http_session, method="post"):
|
@@ -168,6 +174,10 @@ class BaseAPIService(BaseService):
|
|
168
174
|
yield task_id
|
169
175
|
task_id += 1
|
170
176
|
|
177
|
+
async def _init(self):
    """Lazily build the rate limiter the first time the service needs it.

    Does nothing when ``self.rate_limiter`` is already set; otherwise awaits
    ``self.rate_limiter_class.create`` with the configured per-minute limits
    and stores the result on ``self.rate_limiter``.
    """
    if self.rate_limiter is not None:
        return
    limiter_factory = self.rate_limiter_class
    self.rate_limiter = await limiter_factory.create(
        self.max_requests_per_minute,
        self.max_tokens_per_minute,
    )
|
180
|
+
|
171
181
|
async def _call_api(self, http_session, endpoint_, method="post", payload: Dict[str, any] =None) -> Optional[Dict[str, any]]:
|
172
182
|
endpoint_ = self.api_endpoint_from_url("https://api.openai.com/v1/"+endpoint_)
|
173
183
|
|
@@ -224,4 +234,4 @@ class BaseAPIService(BaseService):
|
|
224
234
|
except Exception as e:
|
225
235
|
self.status_tracker.num_tasks_failed += 1
|
226
236
|
raise e
|
227
|
-
|
237
|
+
|
lionagi/services/oai.py
CHANGED
@@ -15,7 +15,7 @@ class OpenAIService(BaseAPIService):
|
|
15
15
|
max_attempts: int = 3,
|
16
16
|
max_requests_per_minute: int = 500,
|
17
17
|
max_tokens_per_minute: int = 150_000,
|
18
|
-
ratelimiter = BaseAPIRateLimiter
|
18
|
+
ratelimiter = BaseAPIRateLimiter,
|
19
19
|
status_tracker = None,
|
20
20
|
queue = None,
|
21
21
|
):
|
@@ -31,4 +31,4 @@ class OpenAIService(BaseAPIService):
|
|
31
31
|
|
32
32
|
async def serve(self, payload, endpoint_="chat/completions", method="post"):
|
33
33
|
return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
|
34
|
-
|
34
|
+
|
lionagi/services/openrouter.py
CHANGED
@@ -2,8 +2,7 @@ from os import getenv
|
|
2
2
|
from .base_api_service import BaseAPIService, BaseAPIRateLimiter
|
3
3
|
|
4
4
|
class OpenRouterService(BaseAPIService):
|
5
|
-
|
6
|
-
|
5
|
+
key_scheme = "OPENROUTER_API_KEY"
|
7
6
|
base_url = "https://openrouter.ai/api/v1/"
|
8
7
|
|
9
8
|
def __init__(
|
@@ -18,7 +17,7 @@ class OpenRouterService(BaseAPIService):
|
|
18
17
|
queue = None,
|
19
18
|
):
|
20
19
|
super().__init__(
|
21
|
-
api_key = api_key or getenv(self.
|
20
|
+
api_key = api_key or getenv(self.key_scheme),
|
22
21
|
status_tracker = status_tracker,
|
23
22
|
queue = queue,
|
24
23
|
ratelimiter=ratelimiter,
|
@@ -0,0 +1,96 @@
|
|
1
|
+
from pydantic import Field
|
2
|
+
|
3
|
+
from lionagi.schema.base_node import BaseNode
|
4
|
+
from .relationship import Relationship
|
5
|
+
from lionagi.utils.call_util import lcall
|
6
|
+
|
7
|
+
|
8
|
+
class Graph(BaseNode):
    """
    A directed graph of ``BaseNode`` objects linked by ``Relationship`` edges.

    Several relationships may connect the same pair of nodes, because the
    adjacency index is keyed by relationship id.

    Attributes:
        nodes: Maps node ``id_`` -> node object.
        relationships: Maps relationship ``id_`` -> relationship object.
        node_relationships: Per-node adjacency index of the form
            ``{node_id: {'in': {rel_id: source_id}, 'out': {rel_id: target_id}}}``.
    """
    # default_factory gives every Graph instance its own fresh containers.
    nodes: dict = Field(default_factory=dict)
    relationships: dict = Field(default_factory=dict)
    node_relationships: dict = Field(default_factory=dict)

    def add_node(self, node: BaseNode):
        """Add ``node``, or replace the stored object that has the same ``id_``.

        Re-adding an existing node keeps its adjacency index, so relationships
        already attached to it stay consistent with ``self.relationships``.
        """
        self.nodes[node.id_] = node
        self.node_relationships.setdefault(node.id_, {'in': {}, 'out': {}})

    def add_relationship(self, relationships: Relationship):
        """Add a relationship between two nodes already present in the graph.

        Args:
            relationships: The relationship to add (a single ``Relationship``).

        Raises:
            KeyError: If the source or target node has not been added.
        """
        for endpoint_id in (relationships.source_node_id, relationships.target_node_id):
            if endpoint_id not in self.node_relationships:
                raise KeyError(f'node {endpoint_id} is not found.')

        self.relationships[relationships.id_] = relationships
        self.node_relationships[relationships.source_node_id]['out'][relationships.id_] = relationships.target_node_id
        self.node_relationships[relationships.target_node_id]['in'][relationships.id_] = relationships.source_node_id

    def get_node_relationships(self, node: BaseNode = None, out_edge=True):
        """Return relationships, either all of them or those attached to ``node``.

        Args:
            node: If None, every relationship in the graph is returned.
            out_edge: With a node, True selects outgoing and False incoming relationships.

        Returns:
            A list of ``Relationship`` objects.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        if node is None:
            return list(self.relationships.values())

        if node.id_ not in self.nodes:
            raise KeyError(f'node {node.id_} is not found')

        direction = 'out' if out_edge else 'in'
        return [self.relationships[rel_id] for rel_id in self.node_relationships[node.id_][direction]]

    def remove_node(self, node: BaseNode):
        """Remove ``node`` together with every relationship touching it.

        Returns:
            The removed node object.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        if node.id_ not in self.nodes:
            raise KeyError(f'node {node.id_} is not found')

        edges = self.node_relationships[node.id_]
        # Unlink outgoing edges from their targets' 'in' maps.
        for relationship_id, target_id in edges['out'].items():
            self.node_relationships[target_id]['in'].pop(relationship_id)
            self.relationships.pop(relationship_id)

        # Unlink incoming edges from their sources' 'out' maps. A self-loop was
        # already removed from edges['in'] by the loop above.
        for relationship_id, source_id in edges['in'].items():
            self.node_relationships[source_id]['out'].pop(relationship_id)
            self.relationships.pop(relationship_id)

        self.node_relationships.pop(node.id_)
        return self.nodes.pop(node.id_)

    def remove_relationship(self, relationship: Relationship):
        """Remove ``relationship`` and unlink it from both endpoint nodes.

        Returns:
            The removed relationship, as stored in the graph.

        Raises:
            KeyError: If no relationship with ``relationship.id_`` is stored.
        """
        if relationship.id_ not in self.relationships:
            raise KeyError(f'relationship {relationship.id_} is not found')

        # Use the stored copy's endpoints. A stale or edited argument must not
        # cause the wrong adjacency entries to be unlinked.
        stored = self.relationships[relationship.id_]
        self.node_relationships[stored.source_node_id]['out'].pop(stored.id_)
        self.node_relationships[stored.target_node_id]['in'].pop(stored.id_)

        return self.relationships.pop(stored.id_)

    def node_exists(self, node: BaseNode):
        """Return True if a node with ``node.id_`` is in the graph."""
        return node.id_ in self.nodes

    def relationship_exists(self, relationship: Relationship):
        """Return True if a relationship with ``relationship.id_`` is in the graph."""
        return relationship.id_ in self.relationships

    def to_networkx(self, **kwargs):
        """Export the graph as a ``networkx.DiGraph``.

        ``kwargs`` are forwarded to the ``DiGraph`` constructor. ``networkx`` is
        imported lazily, so it is only required when this method is called.

        NOTE(review): ``DiGraph`` keeps at most one edge per (source, target)
        pair, so parallel relationships collapse into a single edge here.
        """
        import networkx as nx
        g = nx.DiGraph(**kwargs)
        for node_id, node in self.nodes.items():
            node_info = node.to_dict()
            # assumes BaseNode.to_dict() emits the id under 'node_id' — confirm
            node_info.pop('node_id')
            g.add_node(node_id, **node_info)

        for relationship in self.relationships.values():
            relationship_info = relationship.to_dict()
            relationship_info.pop('node_id')
            source_node_id = relationship_info.pop('source_node_id')
            target_node_id = relationship_info.pop('target_node_id')
            g.add_edge(source_node_id, target_node_id, **relationship_info)

        return g
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from pydantic import Field
|
2
2
|
from typing import Dict, Optional, Any
|
3
|
-
from ..schema.
|
3
|
+
from ..schema.base_node import BaseNode
|
4
4
|
|
5
5
|
|
6
6
|
class Relationship(BaseNode):
|
@@ -125,4 +125,12 @@ class Relationship(BaseNode):
|
|
125
125
|
raise ValueError(f"Target node {self.source_node_id} does not exist")
|
126
126
|
else :
|
127
127
|
raise ValueError(f"Source node {self.target_node_id} does not exist")
|
128
|
-
|
128
|
+
|
129
|
+
def __str__(self) -> str:
    """Return a compact, human-readable summary: id, endpoints and label."""
    src, dst = self.source_node_id, self.target_node_id
    return f"Relationship (id_={self.id_}, from={src}, to={dst}, label={self.label})"
|
132
|
+
|
133
|
+
def __repr__(self) -> str:
    """Return a verbose representation that also includes content and metadata."""
    head = f"Relationship(id_={self.id_}, from={self.source_node_id}, to={self.target_node_id}, "
    tail = f"content={self.content}, metadata={self.metadata}, label={self.label})"
    return head + tail
|
@@ -0,0 +1,102 @@
|
|
1
|
+
from typing import TypeVar
|
2
|
+
from .graph import Graph
|
3
|
+
from ..schema import BaseNode
|
4
|
+
from .relationship import Relationship
|
5
|
+
|
6
|
+
T = TypeVar('T', bound='BaseNode')
|
7
|
+
R = TypeVar('R', bound='Relationship')
|
8
|
+
|
9
|
+
|
10
|
+
class Structure(BaseNode):
    """
    Represents the structure of a graph consisting of nodes and relationships.

    Every method delegates to the wrapped ``graph``. Nodes and relationships
    are always identified through their ``id_`` attribute, so callers must
    pass objects, not raw ids.
    """
    graph: Graph

    def add_node(self, node: T) -> None:
        """
        Adds a node to the structure.

        Args:
            node (T): The node instance to be added.
        """
        self.graph.add_node(node)

    def add_relationship(self, relationship: R) -> None:
        """
        Adds a relationship to the structure.

        Args:
            relationship (R): The relationship instance to be added.

        Raises:
            KeyError: If either endpoint node has not been added to the graph.
        """
        self.graph.add_relationship(relationship)

    def get_relationships(self) -> list[R]:
        """Return every relationship in the structure."""
        return self.graph.get_node_relationships()

    def get_node_relationships(self, node: T, out_edge=True) -> list[R]:
        """
        Return the relationships attached to ``node``.

        Args:
            node (T): The node whose relationships are requested.
            out_edge (bool): True for outgoing relationships, False for incoming.

        Returns:
            list[R]: The matching relationships.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        return self.graph.get_node_relationships(node, out_edge)

    def node_exist(self, node: T) -> bool:
        """
        Checks if a node exists in the structure.

        Args:
            node (T): The node instance to check for existence.

        Returns:
            bool: True if the node exists, False otherwise.
        """
        # Graph exposes this check as ``node_exists``; it has no ``node_exist``.
        return self.graph.node_exists(node)

    def relationship_exist(self, relationship: R) -> bool:
        """
        Checks if a relationship exists in the structure.

        Args:
            relationship (R): The relationship instance to check for existence.

        Returns:
            bool: True if the relationship exists, False otherwise.
        """
        return self.graph.relationship_exists(relationship)

    def remove_node(self, node: T) -> T:
        """
        Removes a node and its associated relationships from the structure.

        Args:
            node (T): The node instance to be removed.

        Returns:
            T: The removed node.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        return self.graph.remove_node(node)

    def remove_relationship(self, relationship: R) -> R:
        """
        Removes a relationship from the structure.

        Args:
            relationship (R): The relationship instance to be removed.

        Returns:
            R: The removed relationship.

        Raises:
            KeyError: If ``relationship`` is not in the graph.
        """
        return self.graph.remove_relationship(relationship)
|
102
|
+
|
@@ -0,0 +1,46 @@
|
|
1
|
+
import unittest
|
2
|
+
from unittest.mock import MagicMock
|
3
|
+
|
4
|
+
# The functions under test are imported from lionagi.utils.api_util.
|
5
|
+
from lionagi.utils.api_util import *
|
6
|
+
|
7
|
+
class TestApiUtils(unittest.TestCase):
    """Unit tests for the helpers in ``lionagi.utils.api_util``."""

    def test_api_method_valid(self):
        # Each of these verbs is expected to resolve to a callable on the session.
        mock_session = MagicMock()
        for verb in ('post', 'delete', 'head', 'options', 'patch'):
            with self.subTest(method=verb):
                self.assertTrue(callable(api_method(mock_session, verb)))

    def test_api_method_invalid(self):
        # 'get' is outside the accepted verbs above and is expected to be rejected.
        mock_session = MagicMock()
        with self.assertRaises(ValueError):
            api_method(mock_session, 'get')

    def test_api_error(self):
        self.assertTrue(api_error({'error': 'Something went wrong'}))
        self.assertFalse(api_error({'result': 'Success'}))

    def test_api_rate_limit_error(self):
        rate_limited = {'error': {'message': 'Rate limit exceeded'}}
        other_failure = {'error': {'message': 'Another error'}}
        self.assertTrue(api_rate_limit_error(rate_limited))
        self.assertFalse(api_rate_limit_error(other_failure))

    def test_api_endpoint_from_url(self):
        # Only URLs with a versioned path segment yield an endpoint name.
        cases = (
            ("https://api.example.com/v1/users", 'users'),
            ("https://api.example.com/users", ''),
            ("Just a string", ''),
        )
        for url, expected in cases:
            self.assertEqual(api_endpoint_from_url(url), expected)


if __name__ == '__main__':
    unittest.main()
|
@@ -0,0 +1,115 @@
|
|
1
|
+
# import asyncio
|
2
|
+
# import unittest
|
3
|
+
# from unittest.mock import patch, MagicMock
|
4
|
+
# from lionagi.utils.sys_util import create_copy
|
5
|
+
# from lionagi.utils.flat_util import to_list
|
6
|
+
# from typing import Callable
|
7
|
+
|
8
|
+
# # Assuming the Python module with the above functions is named 'call_utils'
|
9
|
+
# from lionagi.utils.call_util import *
|
10
|
+
|
11
|
+
# class TestCallUtils(unittest.TestCase):
|
12
|
+
|
13
|
+
# def setUp(self):
|
14
|
+
# self.sample_input = [1, 2, 3]
|
15
|
+
# self.sample_func = MagicMock(return_value=42)
|
16
|
+
|
17
|
+
# def test_hcall(self):
|
18
|
+
# with patch('time.sleep', return_value=None):
|
19
|
+
# result = hcall(self.sample_input, self.sample_func)
|
20
|
+
# self.sample_func.assert_called_once_with(self.sample_input)
|
21
|
+
# self.assertEqual(result, 42)
|
22
|
+
|
23
|
+
# def test_ahcall(self):
|
24
|
+
# async def async_test():
|
25
|
+
# self.sample_func.reset_mock()
|
26
|
+
# with patch('asyncio.sleep', return_value=None):
|
27
|
+
# result = await ahcall(self.sample_input, self.sample_func)
|
28
|
+
# self.sample_func.assert_called_once_with(self.sample_input)
|
29
|
+
# self.assertEqual(result, 42)
|
30
|
+
|
31
|
+
# asyncio.run(async_test())
|
32
|
+
|
33
|
+
# def test_lcall(self):
|
34
|
+
# expected_result = [42, 42, 42]
|
35
|
+
# result = lcall(self.sample_input, self.sample_func)
|
36
|
+
# self.assertEqual(result, expected_result)
|
37
|
+
# calls = [unittest.mock.call(item) for item in self.sample_input]
|
38
|
+
# self.sample_func.assert_has_calls(calls, any_order=True)
|
39
|
+
|
40
|
+
# def test_alcall(self):
|
41
|
+
# async def async_test():
|
42
|
+
# expected_result = [42, 42, 42]
|
43
|
+
# self.sample_func.reset_mock()
|
44
|
+
# result = await alcall(self.sample_input, self.sample_func)
|
45
|
+
# self.assertEqual(result, expected_result)
|
46
|
+
# calls = [unittest.mock.call(item) for item in self.sample_input]
|
47
|
+
# self.sample_func.assert_has_calls(calls, any_order=True)
|
48
|
+
|
49
|
+
# asyncio.run(async_test())
|
50
|
+
|
51
|
+
|
52
|
+
# class TestCallUtils2(unittest.TestCase):
|
53
|
+
|
54
|
+
# def setUp(self):
|
55
|
+
# self.sample_input = [1, 2, 3]
|
56
|
+
# self.sample_func = MagicMock(side_effect=lambda x: x + 1)
|
57
|
+
# self.sample_async_func = MagicMock(side_effect=lambda x: x + 1)
|
58
|
+
|
59
|
+
# def test_mcall_single_function(self):
|
60
|
+
# result = mcall(self.sample_input, self.sample_func)
|
61
|
+
# self.sample_func.assert_has_calls([unittest.mock.call(i) for i in self.sample_input], any_order=True)
|
62
|
+
# self.assertEqual(result, [2, 3, 4])
|
63
|
+
|
64
|
+
# def test_mcall_multiple_functions(self):
|
65
|
+
# # Define multiple functions
|
66
|
+
# funcs = [
|
67
|
+
# MagicMock(side_effect=lambda x: x + 1),
|
68
|
+
# MagicMock(side_effect=lambda x: x * 2),
|
69
|
+
# MagicMock(side_effect=lambda x: x - 1)
|
70
|
+
# ]
|
71
|
+
# result = mcall(self.sample_input, funcs)
|
72
|
+
# for i, func in enumerate(funcs):
|
73
|
+
# func.assert_called_once_with(self.sample_input[i])
|
74
|
+
# self.assertEqual(result, [2, 4, 2])
|
75
|
+
|
76
|
+
# def test_amcall_single_function(self):
|
77
|
+
# async def async_test():
|
78
|
+
# result = await amcall(self.sample_input, self.sample_async_func)
|
79
|
+
# self.sample_async_func.assert_has_calls([unittest.mock.call(i) for i in self.sample_input], any_order=True)
|
80
|
+
# self.assertEqual(result, [2, 3, 4])
|
81
|
+
|
82
|
+
# asyncio.run(async_test())
|
83
|
+
|
84
|
+
# def test_amcall_multiple_functions(self):
|
85
|
+
# # Define multiple asynchronous functions
|
86
|
+
# async_funcs = [
|
87
|
+
# MagicMock(side_effect=lambda x: x + 1),
|
88
|
+
# MagicMock(side_effect=lambda x: x * 2),
|
89
|
+
# MagicMock(side_effect=lambda x: x - 1)
|
90
|
+
# ]
|
91
|
+
|
92
|
+
# async def async_test():
|
93
|
+
# result = await amcall(self.sample_input, async_funcs)
|
94
|
+
# for i, func in enumerate(async_funcs):
|
95
|
+
# func.assert_called_once_with(self.sample_input[i])
|
96
|
+
# self.assertEqual(result, [2, 4, 2])
|
97
|
+
|
98
|
+
# asyncio.run(async_test())
|
99
|
+
|
100
|
+
# def test_ecall(self):
|
101
|
+
# funcs = [MagicMock(side_effect=lambda x: x * x)]
|
102
|
+
# result = ecall(self.sample_input, funcs)
|
103
|
+
# self.assertEqual(result, [[1, 4, 9]])
|
104
|
+
|
105
|
+
# def test_aecall(self):
|
106
|
+
# async_funcs = [MagicMock(side_effect=lambda x: x * x)]
|
107
|
+
|
108
|
+
# async def async_test():
|
109
|
+
# result = await aecall(self.sample_input, async_funcs)
|
110
|
+
# self.assertEqual(result, [[1, 4, 9]])
|
111
|
+
|
112
|
+
# asyncio.run(async_test())
|
113
|
+
|
114
|
+
# if __name__ == '__main__':
|
115
|
+
# unittest.main()
|
@@ -0,0 +1,202 @@
|
|
1
|
+
import unittest
|
2
|
+
from lionagi.utils.convert_util import *
|
3
|
+
|
4
|
+
# Test cases for the function
|
5
|
+
class TestStrToNum(unittest.TestCase):
    """Unit tests for ``str_to_num``."""

    def test_valid_int(self):
        for text, expected in (("123", 123), ("-123", -123)):
            self.assertEqual(str_to_num(text), expected)

    def test_valid_float(self):
        for text, expected in (("123.45", 123.45), ("-123.45", -123.45)):
            self.assertEqual(str_to_num(text, num_type=float), expected)

    def test_precision(self):
        cases = (("123.456", 1, 123.5), ("123.444", 2, 123.44))
        for text, digits, expected in cases:
            self.assertEqual(str_to_num(text, num_type=float, precision=digits), expected)

    def test_bounds(self):
        self.assertEqual(str_to_num("10", lower_bound=5, upper_bound=15), 10)
        # Values outside either bound must raise.
        for text, bound in (("20", {'upper_bound': 15}), ("2", {'lower_bound': 5})):
            with self.assertRaises(ValueError):
                str_to_num(text, **bound)

    def test_invalid_input(self):
        with self.assertRaises(ValueError):
            str_to_num("abc")
        # num_type=str is expected to be rejected as a target type.
        with self.assertRaises(ValueError):
            str_to_num("123abc", num_type=str)

    def test_no_numeric_value(self):
        with self.assertRaises(ValueError):
            str_to_num("No numbers here")
|
34
|
+
|
35
|
+
|
36
|
+
|
37
|
+
# Functions to be tested
|
38
|
+
def dict_to_xml(data: Dict[str, Any], root_tag: str = 'node') -> str:
    """Serialize ``data`` into an XML string wrapped in a ``<root_tag>`` element.

    The tree is built by ``_build_xml`` and rendered as ``str`` (not bytes).
    """
    root_element = ET.Element(root_tag)
    _build_xml(root_element, data)
    xml_text = ET.tostring(root_element, encoding='unicode')
    return xml_text
|
42
|
+
|
43
|
+
def _build_xml(element: ET.Element, data: Any):
    """Recursively populate ``element`` from ``data``.

    - dict: one child element per key, named after the key.
    - list: one ``<item>`` child per entry.
    - anything else: ``str(data)`` becomes the element text.
    """
    if isinstance(data, dict):
        for tag, child_data in data.items():
            _build_xml(ET.SubElement(element, tag), child_data)
        return
    if isinstance(data, list):
        for entry in data:
            _build_xml(ET.SubElement(element, 'item'), entry)
        return
    element.text = str(data)
|
54
|
+
|
55
|
+
def xml_to_dict(element: ET.Element) -> Dict[str, Any]:
    """Convert the children of ``element`` into a (nested) dict.

    - A child that has sub-elements becomes a nested dict.
    - A leaf child maps to its text, which is a ``str`` or ``None``.
    - Sibling elements that share a tag, such as the ``<item>`` entries that
      ``_build_xml`` writes for a list, are collected into a list in document
      order instead of overwriting each other.
    """
    dict_data: Dict[str, Any] = {}
    for child in element:
        value = xml_to_dict(child) if len(child) else child.text
        if child.tag not in dict_data:
            dict_data[child.tag] = value
        elif isinstance(dict_data[child.tag], list):
            # Already promoted to a list by an earlier duplicate. Single values
            # are never lists, so this check is unambiguous.
            dict_data[child.tag].append(value)
        else:
            dict_data[child.tag] = [dict_data[child.tag], value]
    return dict_data
|
63
|
+
|
64
|
+
# Test cases for the functions
|
65
|
+
class TestDictXMLConversion(unittest.TestCase):
    """Tests for the local ``dict_to_xml`` / ``xml_to_dict`` helpers."""

    def setUp(self):
        children = [
            {'name': 'Alice', 'age': 5},
            {'name': 'Bob', 'age': 7},
        ]
        self.data = {'name': 'John', 'age': 30, 'children': children}
        self.root_tag = 'person'
        self.xml = dict_to_xml(self.data, self.root_tag)
        self.xml_element = ET.fromstring(self.xml)

    def test_dict_to_xml(self):
        expected_fragments = (
            '<name>John</name>',
            '<age>30</age>',
            '<children>',
            '<item>',
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, self.xml)

    # NOTE(review): the tests below are disabled. Presumably this is because
    # the round trip is lossy: leaf values come back as strings (e.g. '30',
    # not 30) and lists come back as <item> elements rather than Python lists.
    # def test_xml_to_dict(self):
    #     data_from_xml = xml_to_dict(self.xml_element)
    #     self.assertEqual(data_from_xml, self.data)

    # def test_xml_to_dict_to_xml(self):
    #     data_from_xml = xml_to_dict(self.xml_element)
    #     xml_from_dict = dict_to_xml(data_from_xml, self.root_tag)
    #     self.assertEqual(xml_from_dict, self.xml)

    # def test_invalid_input(self):
    #     with self.assertRaises(TypeError):
    #         dict_to_xml("not a dict", self.root_tag)
|
97
|
+
|
98
|
+
|
99
|
+
|
100
|
+
class TestDocstringExtraction(unittest.TestCase):
    """Tests for the Google/reST docstring parsers and the style dispatcher.

    The docstrings of the nested ``sample_func`` definitions are the parser
    input, so their text is part of the test data.
    """

    def test_google_style_extraction(self):
        def sample_func(arg1, arg2):
            """
            This is a sample function.

            Args:
                arg1 (int): The first argument.
                arg2 (str): The second argument.

            Returns:
                bool: The truth value.
            """
            return True

        # Expect the summary line plus a name -> description map; the type
        # annotations in parentheses are expected to be dropped.
        description, params = extract_docstring_details_google(sample_func)
        self.assertEqual(description, "This is a sample function.")
        self.assertEqual(params, {
            'arg1': 'The first argument.',
            'arg2': 'The second argument.'
        })

    def test_rest_style_extraction(self):
        def sample_func(arg1, arg2):
            """
            This is a sample function.

            :param int arg1: The first argument.
            :param str arg2: The second argument.
            :return: The truth value.
            :rtype: bool
            """
            return True

        # Same expected result as the Google-style case, parsed from :param: fields.
        description, params = extract_docstring_details_rest(sample_func)
        self.assertEqual(description, "This is a sample function.")
        self.assertEqual(params, {
            'arg1': 'The first argument.',
            'arg2': 'The second argument.'
        })

    def test_extract_docstring_details_with_invalid_style(self):
        def sample_func(arg1, arg2):
            return True

        # An unknown style name must be rejected with ValueError.
        with self.assertRaises(ValueError):
            extract_docstring_details(sample_func, style='unsupported')
|
147
|
+
|
148
|
+
class TestPythonToJsonTypeConversion(unittest.TestCase):
    """Checks the Python type-name -> JSON-schema type mapping."""

    def test_python_to_json_type_conversion(self):
        expected = {
            'str': 'string',
            'int': 'number',
            'float': 'number',
            'list': 'array',
            'tuple': 'array',
            'bool': 'boolean',
            'dict': 'object',
            # Unknown names are expected to fall back to 'object'.
            'nonexistent': 'object',
        }
        for type_name, json_type in expected.items():
            self.assertEqual(python_to_json_type(type_name), json_type)
|
158
|
+
|
159
|
+
class TestFunctionToSchema(unittest.TestCase):
    """Tests that ``func_to_schema`` builds an OpenAI-style function schema.

    The nested ``sample_func`` docstring is parser input, so its text is part
    of the test data.
    """

    def test_func_to_schema(self):
        def sample_func(arg1: int, arg2: str) -> bool:
            """
            This is a sample function.

            Args:
                arg1 (int): The first argument.
                arg2 (str): The second argument.

            Returns:
                bool: The truth value.
            """
            return True

        schema = func_to_schema(sample_func, style='google')
        # Parameter types follow python_to_json_type ('int' -> 'number',
        # 'str' -> 'string'), and every parameter is expected in "required".
        expected_schema = {
            "type": "function",
            "function": {
                "name": "sample_func",
                "description": "This is a sample function.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "arg1": {
                            "type": "number",
                            "description": "The first argument."
                        },
                        "arg2": {
                            "type": "string",
                            "description": "The second argument."
                        }
                    },
                    "required": ["arg1", "arg2"]
                }
            }
        }
        self.assertEqual(schema, expected_schema)

if __name__ == '__main__':
    unittest.main()
|
200
|
+
|
201
|
+
|
202
|
+
# --------------------------------------------------
|