lionagi 0.0.114__py3-none-any.whl → 0.0.116__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. lionagi/__init__.py +7 -4
  2. lionagi/bridge/__init__.py +19 -4
  3. lionagi/bridge/langchain.py +23 -3
  4. lionagi/bridge/llama_index.py +5 -3
  5. lionagi/configs/__init__.py +1 -1
  6. lionagi/configs/oai_configs.py +88 -1
  7. lionagi/core/__init__.py +6 -9
  8. lionagi/core/conversations/__init__.py +5 -0
  9. lionagi/core/conversations/conversation.py +107 -0
  10. lionagi/core/flows/__init__.py +8 -0
  11. lionagi/core/flows/flow.py +8 -0
  12. lionagi/core/flows/flow_util.py +62 -0
  13. lionagi/core/instruction_set/__init__.py +5 -0
  14. lionagi/core/instruction_set/instruction_sets.py +7 -0
  15. lionagi/core/sessions/__init__.py +5 -0
  16. lionagi/core/sessions/sessions.py +187 -0
  17. lionagi/endpoints/__init__.py +5 -0
  18. lionagi/endpoints/assistants.py +0 -0
  19. lionagi/endpoints/audio.py +17 -0
  20. lionagi/endpoints/chatcompletion.py +54 -0
  21. lionagi/endpoints/embeddings.py +0 -0
  22. lionagi/endpoints/finetune.py +0 -0
  23. lionagi/endpoints/image.py +0 -0
  24. lionagi/endpoints/moderation.py +0 -0
  25. lionagi/endpoints/vision.py +0 -0
  26. lionagi/{loader → loaders}/__init__.py +7 -1
  27. lionagi/{loader → loaders}/chunker.py +6 -12
  28. lionagi/{utils/load_utils.py → loaders/load_util.py} +47 -6
  29. lionagi/{loader → loaders}/reader.py +4 -12
  30. lionagi/messages/__init__.py +11 -0
  31. lionagi/messages/instruction.py +15 -0
  32. lionagi/messages/message.py +110 -0
  33. lionagi/messages/response.py +33 -0
  34. lionagi/messages/system.py +12 -0
  35. lionagi/objs/__init__.py +10 -6
  36. lionagi/objs/abc_objs.py +39 -0
  37. lionagi/objs/async_queue.py +135 -0
  38. lionagi/objs/messenger.py +70 -148
  39. lionagi/objs/status_tracker.py +37 -0
  40. lionagi/objs/{tool_registry.py → tool_manager.py} +8 -6
  41. lionagi/schema/__init__.py +3 -3
  42. lionagi/schema/base_node.py +251 -0
  43. lionagi/schema/base_tool.py +8 -3
  44. lionagi/schema/data_logger.py +2 -3
  45. lionagi/schema/data_node.py +37 -0
  46. lionagi/services/__init__.py +1 -4
  47. lionagi/services/base_api_service.py +15 -5
  48. lionagi/services/oai.py +2 -2
  49. lionagi/services/openrouter.py +2 -3
  50. lionagi/structures/graph.py +96 -0
  51. lionagi/{structure → structures}/relationship.py +10 -2
  52. lionagi/structures/structure.py +102 -0
  53. lionagi/tests/test_api_util.py +46 -0
  54. lionagi/tests/test_call_util.py +115 -0
  55. lionagi/tests/test_convert_util.py +202 -0
  56. lionagi/tests/test_encrypt_util.py +33 -0
  57. lionagi/tests/{test_flatten_util.py → test_flat_util.py} +1 -1
  58. lionagi/tests/test_io_util.py +0 -0
  59. lionagi/tests/test_sys_util.py +0 -0
  60. lionagi/tools/__init__.py +5 -0
  61. lionagi/tools/tool_util.py +7 -0
  62. lionagi/utils/__init__.py +55 -35
  63. lionagi/utils/api_util.py +19 -17
  64. lionagi/utils/call_util.py +2 -1
  65. lionagi/utils/convert_util.py +229 -0
  66. lionagi/utils/encrypt_util.py +16 -0
  67. lionagi/utils/flat_util.py +38 -0
  68. lionagi/utils/io_util.py +2 -2
  69. lionagi/utils/sys_util.py +45 -10
  70. lionagi/version.py +1 -1
  71. {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/METADATA +2 -2
  72. lionagi-0.0.116.dist-info/RECORD +110 -0
  73. lionagi/core/conversations.py +0 -108
  74. lionagi/core/flows.py +0 -1
  75. lionagi/core/instruction_sets.py +0 -1
  76. lionagi/core/messages.py +0 -166
  77. lionagi/core/sessions.py +0 -297
  78. lionagi/schema/base_schema.py +0 -252
  79. lionagi/services/chatcompletion.py +0 -48
  80. lionagi/services/service_objs.py +0 -282
  81. lionagi/structure/structure.py +0 -160
  82. lionagi/tools/coder.py +0 -1
  83. lionagi/tools/sandbox.py +0 -1
  84. lionagi/utils/tool_util.py +0 -92
  85. lionagi/utils/type_util.py +0 -81
  86. lionagi-0.0.114.dist-info/RECORD +0 -84
  87. /lionagi/configs/{openrouter_config.py → openrouter_configs.py} +0 -0
  88. /lionagi/{datastore → datastores}/__init__.py +0 -0
  89. /lionagi/{datastore → datastores}/chroma.py +0 -0
  90. /lionagi/{datastore → datastores}/deeplake.py +0 -0
  91. /lionagi/{datastore → datastores}/elasticsearch.py +0 -0
  92. /lionagi/{datastore → datastores}/lantern.py +0 -0
  93. /lionagi/{datastore → datastores}/pinecone.py +0 -0
  94. /lionagi/{datastore → datastores}/postgres.py +0 -0
  95. /lionagi/{datastore → datastores}/qdrant.py +0 -0
  96. /lionagi/{structure → structures}/__init__.py +0 -0
  97. {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/LICENSE +0 -0
  98. {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/WHEEL +0 -0
  99. {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,9 @@ import tiktoken
5
5
  import logging
6
6
  import aiohttp
7
7
  from typing import Generator, NoReturn, Dict, Any, Optional
8
- from .service_objs import BaseService, RateLimiter, StatusTracker, AsyncQueue
8
+ from ..objs.status_tracker import StatusTracker
9
+ from ..objs.abc_objs import BaseService, RateLimiter
10
+ from ..objs.async_queue import AsyncQueue
9
11
 
10
12
  class BaseAPIRateLimiter(RateLimiter):
11
13
 
@@ -20,7 +22,7 @@ class BaseAPIRateLimiter(RateLimiter):
20
22
  ) -> None:
21
23
  self = cls(max_requests_per_minute, max_tokens_per_minute)
22
24
  if not os.getenv("env_readthedocs"):
23
- self.rate_limit_replenisher_task = await asyncio.create_task(
25
+ self.rate_limit_replenisher_task = asyncio.create_task(
24
26
  self.rate_limit_replenisher()
25
27
  )
26
28
  return self
@@ -43,6 +45,7 @@ class BaseAPIRateLimiter(RateLimiter):
43
45
  self.available_request_capacity = self.max_requests_per_minute
44
46
  self.available_token_capacity = self.max_tokens_per_minute
45
47
 
48
+ # credit to OpenAI for the following function
46
49
  def calculate_num_token(
47
50
  self,
48
51
  payload: Dict[str, Any] = None,
@@ -130,13 +133,16 @@ class BaseAPIService(BaseService):
130
133
  def __init__(self, api_key: str = None,
131
134
  status_tracker = None,
132
135
  queue = None, endpoint=None, schema=None,
133
- ratelimiter=None, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
136
+ ratelimiter=BaseAPIRateLimiter, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
134
137
  self.api_key = api_key
135
138
  self.status_tracker = status_tracker or StatusTracker()
136
139
  self.queue = queue or AsyncQueue()
137
140
  self.endpoint=endpoint
138
141
  self.schema = schema
139
- self.rate_limiter = ratelimiter(max_requests_per_minute, max_tokens_per_minute)
142
+ self.max_requests_per_minute = max_requests_per_minute
143
+ self.max_tokens_per_minute = max_tokens_per_minute
144
+ self.rate_limiter_class = ratelimiter
145
+ self.rate_limiter = None
140
146
 
141
147
  @staticmethod
142
148
  def api_methods(http_session, method="post"):
@@ -168,6 +174,10 @@ class BaseAPIService(BaseService):
168
174
  yield task_id
169
175
  task_id += 1
170
176
 
177
+ async def _init(self):
178
+ if self.rate_limiter is None:
179
+ self.rate_limiter = await self.rate_limiter_class.create(self.max_requests_per_minute, self.max_tokens_per_minute)
180
+
171
181
  async def _call_api(self, http_session, endpoint_, method="post", payload: Dict[str, any] =None) -> Optional[Dict[str, any]]:
172
182
  endpoint_ = self.api_endpoint_from_url("https://api.openai.com/v1/"+endpoint_)
173
183
 
@@ -224,4 +234,4 @@ class BaseAPIService(BaseService):
224
234
  except Exception as e:
225
235
  self.status_tracker.num_tasks_failed += 1
226
236
  raise e
227
-
237
+
lionagi/services/oai.py CHANGED
@@ -15,7 +15,7 @@ class OpenAIService(BaseAPIService):
15
15
  max_attempts: int = 3,
16
16
  max_requests_per_minute: int = 500,
17
17
  max_tokens_per_minute: int = 150_000,
18
- ratelimiter = BaseAPIRateLimiter ,
18
+ ratelimiter = BaseAPIRateLimiter,
19
19
  status_tracker = None,
20
20
  queue = None,
21
21
  ):
@@ -31,4 +31,4 @@ class OpenAIService(BaseAPIService):
31
31
 
32
32
  async def serve(self, payload, endpoint_="chat/completions", method="post"):
33
33
  return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
34
-
34
+
@@ -2,8 +2,7 @@ from os import getenv
2
2
  from .base_api_service import BaseAPIService, BaseAPIRateLimiter
3
3
 
4
4
  class OpenRouterService(BaseAPIService):
5
- _key_scheme = "OPENROUTER_API_KEY"
6
-
5
+ key_scheme = "OPENROUTER_API_KEY"
7
6
  base_url = "https://openrouter.ai/api/v1/"
8
7
 
9
8
  def __init__(
@@ -18,7 +17,7 @@ class OpenRouterService(BaseAPIService):
18
17
  queue = None,
19
18
  ):
20
19
  super().__init__(
21
- api_key = api_key or getenv(self._key_scheme),
20
+ api_key = api_key or getenv(self.key_scheme),
22
21
  status_tracker = status_tracker,
23
22
  queue = queue,
24
23
  ratelimiter=ratelimiter,
@@ -0,0 +1,96 @@
1
+ from pydantic import Field
2
+
3
+ from lionagi.schema.base_node import BaseNode
4
+ from .relationship import Relationship
5
+ from lionagi.utils.call_util import lcall
6
+
7
+
8
class Graph(BaseNode):
    """
    A directed graph of nodes connected by relationships (edges).

    Attributes:
        nodes: Maps node id -> node object.
        relationships: Maps relationship id -> relationship object.
        node_relationships: Adjacency index keyed by node id, of the form
            {'in': {relationship_id: source_node_id},
             'out': {relationship_id: target_node_id}}.
    """
    # default_factory gives each Graph its own fresh dicts rather than
    # relying on a literal mutable default.
    nodes: dict = Field(default_factory=dict)
    relationships: dict = Field(default_factory=dict)
    node_relationships: dict = Field(default_factory=dict)

    def add_node(self, node: BaseNode) -> None:
        """
        Add a node to the graph, or replace the stored object for an existing id.

        Args:
            node: The node to store under ``node.id_``.
        """
        self.nodes[node.id_] = node
        # Only create the adjacency entry for a new node: resetting it on a
        # re-add would orphan edges still held in `relationships` and in the
        # neighbouring nodes' 'in'/'out' maps.
        self.node_relationships.setdefault(node.id_, {'in': {}, 'out': {}})

    def add_relationship(self, relationships: Relationship) -> None:
        """
        Add a directed relationship between two nodes already in the graph.

        Args:
            relationships: The relationship to add; its source and target
                node ids must already be registered via ``add_node``.

        Raises:
            KeyError: If the source or target node is not in the graph.
        """
        source_id = relationships.source_node_id
        target_id = relationships.target_node_id
        if source_id not in self.node_relationships:
            raise KeyError(f'node {source_id} is not found.')
        if target_id not in self.node_relationships:
            raise KeyError(f'node {target_id} is not found.')

        self.relationships[relationships.id_] = relationships
        self.node_relationships[source_id]['out'][relationships.id_] = target_id
        self.node_relationships[target_id]['in'][relationships.id_] = source_id

    def get_node_relationships(self, node: BaseNode = None, out_edge=True):
        """
        Return relationships, optionally restricted to one node's edges.

        Args:
            node: If None, every relationship in the graph is returned and
                ``out_edge`` is ignored.
            out_edge: With a node given, True selects its outgoing edges and
                False its incoming edges.

        Returns:
            A list of relationship objects.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        if node is None:
            return list(self.relationships.values())

        if node.id_ not in self.nodes:
            raise KeyError(f'node {node.id_} is not found')

        direction = 'out' if out_edge else 'in'
        relationship_ids = list(self.node_relationships[node.id_][direction].keys())
        return lcall(relationship_ids, lambda i: self.relationships[i])

    def remove_node(self, node: BaseNode):
        """
        Remove a node together with every relationship touching it.

        Returns:
            The removed node object.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        if node.id_ not in self.nodes:
            raise KeyError(f'node {node.id_} is not found')

        adjacency = self.node_relationships[node.id_]
        # Detach outgoing edges from their targets' incoming maps.
        for relationship_id, target_id in adjacency['out'].items():
            self.node_relationships[target_id]['in'].pop(relationship_id)
            self.relationships.pop(relationship_id)

        # Detach incoming edges from their sources' outgoing maps.
        for relationship_id, source_id in adjacency['in'].items():
            self.node_relationships[source_id]['out'].pop(relationship_id)
            self.relationships.pop(relationship_id)

        self.node_relationships.pop(node.id_)
        return self.nodes.pop(node.id_)

    def remove_relationship(self, relationship: Relationship):
        """
        Remove a single relationship and its adjacency entries.

        Returns:
            The removed relationship object.

        Raises:
            KeyError: If ``relationship`` is not in the graph.
        """
        if relationship.id_ not in self.relationships:
            raise KeyError(f'relationship {relationship.id_} is not found')

        self.node_relationships[relationship.source_node_id]['out'].pop(relationship.id_)
        self.node_relationships[relationship.target_node_id]['in'].pop(relationship.id_)

        return self.relationships.pop(relationship.id_)

    def node_exists(self, node: BaseNode) -> bool:
        """Return True if a node with ``node.id_`` is in the graph."""
        return node.id_ in self.nodes

    def relationship_exists(self, relationship: Relationship) -> bool:
        """Return True if a relationship with ``relationship.id_`` is in the graph."""
        return relationship.id_ in self.relationships

    def to_networkx(self, **kwargs):
        """
        Export the graph as a ``networkx.DiGraph``.

        Args:
            **kwargs: Forwarded to the ``networkx.DiGraph`` constructor.

        Note:
            networkx is imported lazily here. Each ``to_dict()`` result is
            expected to contain a 'node_id' key, which is popped and raises
            KeyError otherwise. TODO confirm against BaseNode.to_dict.
        """
        import networkx as nx
        g = nx.DiGraph(**kwargs)
        for node_id, node in self.nodes.items():
            node_info = node.to_dict()
            node_info.pop('node_id')
            g.add_node(node_id, **node_info)

        for relationship_id, relationship in self.relationships.items():
            relationship_info = relationship.to_dict()
            relationship_info.pop('node_id')
            source_node_id = relationship_info.pop('source_node_id')
            target_node_id = relationship_info.pop('target_node_id')
            g.add_edge(source_node_id, target_node_id, **relationship_info)

        return g
@@ -1,6 +1,6 @@
1
1
  from pydantic import Field
2
2
  from typing import Dict, Optional, Any
3
- from ..schema.base_schema import BaseNode
3
+ from ..schema.base_node import BaseNode
4
4
 
5
5
 
6
6
  class Relationship(BaseNode):
@@ -125,4 +125,12 @@ class Relationship(BaseNode):
125
125
  raise ValueError(f"Target node {self.source_node_id} does not exist")
126
126
  else :
127
127
  raise ValueError(f"Source node {self.target_node_id} does not exist")
128
-
128
+
129
+ def __str__(self) -> str:
130
+ """Returns a simple string representation of the Relationship."""
131
+ return f"Relationship (id_={self.id_}, from={self.source_node_id}, to={self.target_node_id}, label={self.label})"
132
+
133
+ def __repr__(self) -> str:
134
+ """Returns a detailed string representation of the Relationship."""
135
+ return f"Relationship(id_={self.id_}, from={self.source_node_id}, to={self.target_node_id}, content={self.content}, " \
136
+ f"metadata={self.metadata}, label={self.label})"
@@ -0,0 +1,102 @@
1
+ from typing import TypeVar
2
+ from .graph import Graph
3
+ from ..schema import BaseNode
4
+ from .relationship import Relationship
5
+
6
+ T = TypeVar('T', bound='BaseNode')
7
+ R = TypeVar('R', bound='Relationship')
8
+
9
+
10
class Structure(BaseNode):
    """
    Represents the structure of a graph consisting of nodes and relationships.

    Every operation delegates to the wrapped ``Graph`` instance.
    """
    graph: Graph

    def add_node(self, node: T) -> None:
        """
        Adds a node to the structure.

        Args:
            node (T): The node instance to be added.
        """
        self.graph.add_node(node)

    def add_relationship(self, relationship: R) -> None:
        """
        Adds a relationship to the structure.

        Args:
            relationship (R): The relationship instance to be added. Its
                source and target nodes must already be in the structure.

        Raises:
            KeyError: If the source or target node is missing.
        """
        self.graph.add_relationship(relationship)

    def get_relationships(self) -> list[R]:
        """Returns every relationship in the structure as a list."""
        return self.graph.get_node_relationships()

    def get_node_relationships(self, node: T, out_edge=True) -> list[R]:
        """
        Returns the relationships attached to a node.

        Args:
            node (T): The node whose relationships are requested.
            out_edge (bool): True for outgoing relationships, False for
                incoming ones.

        Raises:
            KeyError: If the node is not in the structure.
        """
        return self.graph.get_node_relationships(node, out_edge)

    def node_exist(self, node: T) -> bool:
        """
        Checks if a node exists in the structure.

        Args:
            node (T): The node instance to check for existence (matched by
                its ``id_``).

        Returns:
            bool: True if the node exists, False otherwise.
        """
        # Graph names this check `node_exists`; calling `node_exist` on the
        # graph raised AttributeError.
        return self.graph.node_exists(node)

    def relationship_exist(self, relationship: R) -> bool:
        """
        Checks if a relationship exists in the structure.

        Args:
            relationship (R): The relationship instance to check for existence.

        Returns:
            bool: True if the relationship exists, False otherwise.
        """
        return self.graph.relationship_exists(relationship)

    def remove_node(self, node: T) -> T:
        """
        Removes a node and its associated relationships from the structure.

        Args:
            node (T): The node instance to be removed (matched by its ``id_``).

        Returns:
            T: The removed node.

        Raises:
            KeyError: If the node is not in the structure.
        """
        return self.graph.remove_node(node)

    def remove_relationship(self, relationship: R) -> R:
        """
        Removes a relationship from the structure.

        Args:
            relationship (R): The relationship instance to be removed.

        Returns:
            R: The removed relationship.

        Raises:
            KeyError: If the relationship is not in the structure.
        """
        return self.graph.remove_relationship(relationship)
102
+
@@ -0,0 +1,46 @@
1
+ import unittest
2
+ from unittest.mock import MagicMock
3
+
4
+ # The helpers under test are star-imported from lionagi.utils.api_util.
5
+ from lionagi.utils.api_util import *
6
+
7
class TestApiUtils(unittest.TestCase):
    """Unit tests for the request/response helpers in lionagi.utils.api_util."""

    def test_api_method_valid(self):
        session = MagicMock()
        for verb in ('post', 'delete', 'head', 'options', 'patch'):
            with self.subTest(method=verb):
                self.assertTrue(callable(api_method(session, verb)))

    def test_api_method_invalid(self):
        session = MagicMock()
        self.assertRaises(ValueError, api_method, session, 'get')

    def test_api_error(self):
        failing = {'error': 'Something went wrong'}
        succeeding = {'result': 'Success'}
        self.assertTrue(api_error(failing))
        self.assertFalse(api_error(succeeding))

    def test_api_rate_limit_error(self):
        limited = {'error': {'message': 'Rate limit exceeded'}}
        other_failure = {'error': {'message': 'Another error'}}
        self.assertTrue(api_rate_limit_error(limited))
        self.assertFalse(api_rate_limit_error(other_failure))

    def test_api_endpoint_from_url(self):
        cases = (
            ("https://api.example.com/v1/users", 'users'),
            ("https://api.example.com/users", ''),
            ("Just a string", ''),
        )
        for url, expected in cases:
            with self.subTest(url=url):
                self.assertEqual(api_endpoint_from_url(url), expected)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,115 @@
1
+ # import asyncio
2
+ # import unittest
3
+ # from unittest.mock import patch, MagicMock
4
+ # from lionagi.utils.sys_util import create_copy
5
+ # from lionagi.utils.flat_util import to_list
6
+ # from typing import Callable
7
+
8
+ # # Assuming the Python module with the above functions is named 'call_utils'
9
+ # from lionagi.utils.call_util import *
10
+
11
+ # class TestCallUtils(unittest.TestCase):
12
+
13
+ # def setUp(self):
14
+ # self.sample_input = [1, 2, 3]
15
+ # self.sample_func = MagicMock(return_value=42)
16
+
17
+ # def test_hcall(self):
18
+ # with patch('time.sleep', return_value=None):
19
+ # result = hcall(self.sample_input, self.sample_func)
20
+ # self.sample_func.assert_called_once_with(self.sample_input)
21
+ # self.assertEqual(result, 42)
22
+
23
+ # def test_ahcall(self):
24
+ # async def async_test():
25
+ # self.sample_func.reset_mock()
26
+ # with patch('asyncio.sleep', return_value=None):
27
+ # result = await ahcall(self.sample_input, self.sample_func)
28
+ # self.sample_func.assert_called_once_with(self.sample_input)
29
+ # self.assertEqual(result, 42)
30
+
31
+ # asyncio.run(async_test())
32
+
33
+ # def test_lcall(self):
34
+ # expected_result = [42, 42, 42]
35
+ # result = lcall(self.sample_input, self.sample_func)
36
+ # self.assertEqual(result, expected_result)
37
+ # calls = [unittest.mock.call(item) for item in self.sample_input]
38
+ # self.sample_func.assert_has_calls(calls, any_order=True)
39
+
40
+ # def test_alcall(self):
41
+ # async def async_test():
42
+ # expected_result = [42, 42, 42]
43
+ # self.sample_func.reset_mock()
44
+ # result = await alcall(self.sample_input, self.sample_func)
45
+ # self.assertEqual(result, expected_result)
46
+ # calls = [unittest.mock.call(item) for item in self.sample_input]
47
+ # self.sample_func.assert_has_calls(calls, any_order=True)
48
+
49
+ # asyncio.run(async_test())
50
+
51
+
52
+ # class TestCallUtils2(unittest.TestCase):
53
+
54
+ # def setUp(self):
55
+ # self.sample_input = [1, 2, 3]
56
+ # self.sample_func = MagicMock(side_effect=lambda x: x + 1)
57
+ # self.sample_async_func = MagicMock(side_effect=lambda x: x + 1)
58
+
59
+ # def test_mcall_single_function(self):
60
+ # result = mcall(self.sample_input, self.sample_func)
61
+ # self.sample_func.assert_has_calls([unittest.mock.call(i) for i in self.sample_input], any_order=True)
62
+ # self.assertEqual(result, [2, 3, 4])
63
+
64
+ # def test_mcall_multiple_functions(self):
65
+ # # Define multiple functions
66
+ # funcs = [
67
+ # MagicMock(side_effect=lambda x: x + 1),
68
+ # MagicMock(side_effect=lambda x: x * 2),
69
+ # MagicMock(side_effect=lambda x: x - 1)
70
+ # ]
71
+ # result = mcall(self.sample_input, funcs)
72
+ # for i, func in enumerate(funcs):
73
+ # func.assert_called_once_with(self.sample_input[i])
74
+ # self.assertEqual(result, [2, 4, 2])
75
+
76
+ # def test_amcall_single_function(self):
77
+ # async def async_test():
78
+ # result = await amcall(self.sample_input, self.sample_async_func)
79
+ # self.sample_async_func.assert_has_calls([unittest.mock.call(i) for i in self.sample_input], any_order=True)
80
+ # self.assertEqual(result, [2, 3, 4])
81
+
82
+ # asyncio.run(async_test())
83
+
84
+ # def test_amcall_multiple_functions(self):
85
+ # # Define multiple asynchronous functions
86
+ # async_funcs = [
87
+ # MagicMock(side_effect=lambda x: x + 1),
88
+ # MagicMock(side_effect=lambda x: x * 2),
89
+ # MagicMock(side_effect=lambda x: x - 1)
90
+ # ]
91
+
92
+ # async def async_test():
93
+ # result = await amcall(self.sample_input, async_funcs)
94
+ # for i, func in enumerate(async_funcs):
95
+ # func.assert_called_once_with(self.sample_input[i])
96
+ # self.assertEqual(result, [2, 4, 2])
97
+
98
+ # asyncio.run(async_test())
99
+
100
+ # def test_ecall(self):
101
+ # funcs = [MagicMock(side_effect=lambda x: x * x)]
102
+ # result = ecall(self.sample_input, funcs)
103
+ # self.assertEqual(result, [[1, 4, 9]])
104
+
105
+ # def test_aecall(self):
106
+ # async_funcs = [MagicMock(side_effect=lambda x: x * x)]
107
+
108
+ # async def async_test():
109
+ # result = await aecall(self.sample_input, async_funcs)
110
+ # self.assertEqual(result, [[1, 4, 9]])
111
+
112
+ # asyncio.run(async_test())
113
+
114
+ # if __name__ == '__main__':
115
+ # unittest.main()
@@ -0,0 +1,202 @@
1
+ import unittest
2
+ from lionagi.utils.convert_util import *
3
+
4
+ # Test cases for str_to_num
5
class TestStrToNum(unittest.TestCase):
    """Tests for convert_util.str_to_num parsing, rounding and bounds checks."""

    def test_valid_int(self):
        for text, expected in (("123", 123), ("-123", -123)):
            with self.subTest(text=text):
                self.assertEqual(str_to_num(text), expected)

    def test_valid_float(self):
        for text, expected in (("123.45", 123.45), ("-123.45", -123.45)):
            with self.subTest(text=text):
                self.assertEqual(str_to_num(text, num_type=float), expected)

    def test_precision(self):
        for text, digits, expected in (("123.456", 1, 123.5), ("123.444", 2, 123.44)):
            with self.subTest(text=text, precision=digits):
                self.assertEqual(str_to_num(text, num_type=float, precision=digits), expected)

    def test_bounds(self):
        self.assertEqual(str_to_num("10", lower_bound=5, upper_bound=15), 10)
        self.assertRaises(ValueError, str_to_num, "20", upper_bound=15)
        self.assertRaises(ValueError, str_to_num, "2", lower_bound=5)

    def test_invalid_input(self):
        self.assertRaises(ValueError, str_to_num, "abc")
        self.assertRaises(ValueError, str_to_num, "123abc", num_type=str)

    def test_no_numeric_value(self):
        self.assertRaises(ValueError, str_to_num, "No numbers here")
34
+
35
+
36
+
37
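+ # Local XML helpers used by TestDictXMLConversion (defined after the star import, so they shadow any same-named functions from convert_util)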
38
def dict_to_xml(data: Dict[str, Any], root_tag: str = 'node') -> str:
    """
    Serialize ``data`` into an XML string whose root element is ``root_tag``.

    Dict keys become child tags, each list entry becomes an ``<item>`` child,
    and any other value is written as element text via ``str()``.
    """
    root_element = ET.Element(root_tag)
    _build_xml(root_element, data)
    return ET.tostring(root_element, encoding='unicode')


def _build_xml(element: ET.Element, data: Any):
    """Recursively attach ``data`` beneath ``element`` (helper for dict_to_xml)."""
    if isinstance(data, dict):
        children = data.items()
    elif isinstance(data, list):
        children = (('item', entry) for entry in data)
    else:
        element.text = str(data)
        return
    for tag, value in children:
        _build_xml(ET.SubElement(element, tag), value)
54
+
55
def xml_to_dict(element: ET.Element) -> Dict[str, Any]:
    """
    Convert the children of ``element`` into a nested dict.

    Children that have sub-elements are converted recursively. Leaf children
    map to their text (a str, or None when empty). Sibling elements sharing a
    tag, such as the repeated ``<item>`` entries dict_to_xml writes for a
    list, are collected into a list instead of overwriting each other.

    Args:
        element: The parent element whose children are converted.

    Returns:
        A dict keyed by child tag.
    """
    dict_data: Dict[str, Any] = {}
    for child in element:
        value = xml_to_dict(child) if len(child) else child.text
        if child.tag not in dict_data:
            dict_data[child.tag] = value
            continue
        # A repeated tag: previously the later sibling silently replaced the
        # earlier one, losing data. Values here are only dicts/str/None, so a
        # list can only be one we built for repeated tags.
        existing = dict_data[child.tag]
        if isinstance(existing, list):
            existing.append(value)
        else:
            dict_data[child.tag] = [existing, value]
    return dict_data
63
+
64
+ # Test cases for the local dict_to_xml helper
65
class TestDictXMLConversion(unittest.TestCase):
    """Tests for converting a nested dict to XML with dict_to_xml."""

    def setUp(self):
        self.data = {
            'name': 'John',
            'age': 30,
            'children': [
                {'name': 'Alice', 'age': 5},
                {'name': 'Bob', 'age': 7}
            ]
        }
        self.root_tag = 'person'
        self.xml = dict_to_xml(self.data, self.root_tag)
        self.xml_element = ET.fromstring(self.xml)

    def test_dict_to_xml(self):
        self.assertIn('<name>John</name>', self.xml)
        self.assertIn('<age>30</age>', self.xml)
        self.assertIn('<children>', self.xml)
        self.assertIn('<item>', self.xml)

    def test_dict_to_xml_structure(self):
        """Parsing the output recovers the root tag, leaf text and one <item> per list entry."""
        self.assertEqual(self.xml_element.tag, self.root_tag)
        self.assertEqual(self.xml_element.findtext('name'), 'John')
        self.assertEqual(self.xml_element.findtext('age'), '30')
        items = self.xml_element.findall('children/item')
        self.assertEqual([item.findtext('name') for item in items], ['Alice', 'Bob'])
        self.assertEqual([item.findtext('age') for item in items], ['5', '7'])

    # A full dict -> XML -> dict round trip is not asserted. dict_to_xml
    # writes every leaf with str(), so values such as age=30 come back as
    # the string '30'.
97
+
98
+
99
+
100
class TestDocstringExtraction(unittest.TestCase):
    """Tests for the Google- and reST-style docstring parsers in convert_util."""

    EXPECTED_DESCRIPTION = "This is a sample function."
    EXPECTED_PARAMS = {
        'arg1': 'The first argument.',
        'arg2': 'The second argument.'
    }

    def test_google_style_extraction(self):
        def sample_func(arg1, arg2):
            """
            This is a sample function.

            Args:
                arg1 (int): The first argument.
                arg2 (str): The second argument.

            Returns:
                bool: The truth value.
            """
            return True

        description, params = extract_docstring_details_google(sample_func)
        self.assertEqual(description, self.EXPECTED_DESCRIPTION)
        self.assertEqual(params, self.EXPECTED_PARAMS)

    def test_rest_style_extraction(self):
        def sample_func(arg1, arg2):
            """
            This is a sample function.

            :param int arg1: The first argument.
            :param str arg2: The second argument.
            :return: The truth value.
            :rtype: bool
            """
            return True

        description, params = extract_docstring_details_rest(sample_func)
        self.assertEqual(description, self.EXPECTED_DESCRIPTION)
        self.assertEqual(params, self.EXPECTED_PARAMS)

    def test_extract_docstring_details_with_invalid_style(self):
        def sample_func(arg1, arg2):
            return True

        self.assertRaises(ValueError, extract_docstring_details, sample_func, style='unsupported')
147
+
148
class TestPythonToJsonTypeConversion(unittest.TestCase):
    """Tests for mapping Python type names to JSON-schema type names."""

    def test_python_to_json_type_conversion(self):
        cases = [
            ('str', 'string'),
            ('int', 'number'),
            ('float', 'number'),
            ('list', 'array'),
            ('tuple', 'array'),
            ('bool', 'boolean'),
            ('dict', 'object'),
            # Unrecognised type names are expected to fall back to 'object'.
            ('nonexistent', 'object'),
        ]
        for py_type, json_type in cases:
            with self.subTest(py_type=py_type):
                self.assertEqual(python_to_json_type(py_type), json_type)
158
+
159
class TestFunctionToSchema(unittest.TestCase):
    """Tests for building an OpenAI-style function schema from a Python function."""

    def test_func_to_schema(self):
        def sample_func(arg1: int, arg2: str) -> bool:
            """
            This is a sample function.

            Args:
                arg1 (int): The first argument.
                arg2 (str): The second argument.

            Returns:
                bool: The truth value.
            """
            return True

        properties = {
            "arg1": {"type": "number", "description": "The first argument."},
            "arg2": {"type": "string", "description": "The second argument."},
        }
        expected_schema = {
            "type": "function",
            "function": {
                "name": "sample_func",
                "description": "This is a sample function.",
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": ["arg1", "arg2"],
                },
            },
        }
        actual_schema = func_to_schema(sample_func, style='google')
        self.assertEqual(actual_schema, expected_schema)


if __name__ == '__main__':
    unittest.main()
200
+
201
+
202
+ # --------------------------------------------------