lionagi 0.0.114__py3-none-any.whl → 0.0.116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. lionagi/__init__.py +7 -4
  2. lionagi/bridge/__init__.py +19 -4
  3. lionagi/bridge/langchain.py +23 -3
  4. lionagi/bridge/llama_index.py +5 -3
  5. lionagi/configs/__init__.py +1 -1
  6. lionagi/configs/oai_configs.py +88 -1
  7. lionagi/core/__init__.py +6 -9
  8. lionagi/core/conversations/__init__.py +5 -0
  9. lionagi/core/conversations/conversation.py +107 -0
  10. lionagi/core/flows/__init__.py +8 -0
  11. lionagi/core/flows/flow.py +8 -0
  12. lionagi/core/flows/flow_util.py +62 -0
  13. lionagi/core/instruction_set/__init__.py +5 -0
  14. lionagi/core/instruction_set/instruction_sets.py +7 -0
  15. lionagi/core/sessions/__init__.py +5 -0
  16. lionagi/core/sessions/sessions.py +187 -0
  17. lionagi/endpoints/__init__.py +5 -0
  18. lionagi/endpoints/assistants.py +0 -0
  19. lionagi/endpoints/audio.py +17 -0
  20. lionagi/endpoints/chatcompletion.py +54 -0
  21. lionagi/endpoints/embeddings.py +0 -0
  22. lionagi/endpoints/finetune.py +0 -0
  23. lionagi/endpoints/image.py +0 -0
  24. lionagi/endpoints/moderation.py +0 -0
  25. lionagi/endpoints/vision.py +0 -0
  26. lionagi/{loader → loaders}/__init__.py +7 -1
  27. lionagi/{loader → loaders}/chunker.py +6 -12
  28. lionagi/{utils/load_utils.py → loaders/load_util.py} +47 -6
  29. lionagi/{loader → loaders}/reader.py +4 -12
  30. lionagi/messages/__init__.py +11 -0
  31. lionagi/messages/instruction.py +15 -0
  32. lionagi/messages/message.py +110 -0
  33. lionagi/messages/response.py +33 -0
  34. lionagi/messages/system.py +12 -0
  35. lionagi/objs/__init__.py +10 -6
  36. lionagi/objs/abc_objs.py +39 -0
  37. lionagi/objs/async_queue.py +135 -0
  38. lionagi/objs/messenger.py +70 -148
  39. lionagi/objs/status_tracker.py +37 -0
  40. lionagi/objs/{tool_registry.py → tool_manager.py} +8 -6
  41. lionagi/schema/__init__.py +3 -3
  42. lionagi/schema/base_node.py +251 -0
  43. lionagi/schema/base_tool.py +8 -3
  44. lionagi/schema/data_logger.py +2 -3
  45. lionagi/schema/data_node.py +37 -0
  46. lionagi/services/__init__.py +1 -4
  47. lionagi/services/base_api_service.py +15 -5
  48. lionagi/services/oai.py +2 -2
  49. lionagi/services/openrouter.py +2 -3
  50. lionagi/structures/graph.py +96 -0
  51. lionagi/{structure → structures}/relationship.py +10 -2
  52. lionagi/structures/structure.py +102 -0
  53. lionagi/tests/test_api_util.py +46 -0
  54. lionagi/tests/test_call_util.py +115 -0
  55. lionagi/tests/test_convert_util.py +202 -0
  56. lionagi/tests/test_encrypt_util.py +33 -0
  57. lionagi/tests/{test_flatten_util.py → test_flat_util.py} +1 -1
  58. lionagi/tests/test_io_util.py +0 -0
  59. lionagi/tests/test_sys_util.py +0 -0
  60. lionagi/tools/__init__.py +5 -0
  61. lionagi/tools/tool_util.py +7 -0
  62. lionagi/utils/__init__.py +55 -35
  63. lionagi/utils/api_util.py +19 -17
  64. lionagi/utils/call_util.py +2 -1
  65. lionagi/utils/convert_util.py +229 -0
  66. lionagi/utils/encrypt_util.py +16 -0
  67. lionagi/utils/flat_util.py +38 -0
  68. lionagi/utils/io_util.py +2 -2
  69. lionagi/utils/sys_util.py +45 -10
  70. lionagi/version.py +1 -1
  71. {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/METADATA +2 -2
  72. lionagi-0.0.116.dist-info/RECORD +110 -0
  73. lionagi/core/conversations.py +0 -108
  74. lionagi/core/flows.py +0 -1
  75. lionagi/core/instruction_sets.py +0 -1
  76. lionagi/core/messages.py +0 -166
  77. lionagi/core/sessions.py +0 -297
  78. lionagi/schema/base_schema.py +0 -252
  79. lionagi/services/chatcompletion.py +0 -48
  80. lionagi/services/service_objs.py +0 -282
  81. lionagi/structure/structure.py +0 -160
  82. lionagi/tools/coder.py +0 -1
  83. lionagi/tools/sandbox.py +0 -1
  84. lionagi/utils/tool_util.py +0 -92
  85. lionagi/utils/type_util.py +0 -81
  86. lionagi-0.0.114.dist-info/RECORD +0 -84
  87. /lionagi/configs/{openrouter_config.py → openrouter_configs.py} +0 -0
  88. /lionagi/{datastore → datastores}/__init__.py +0 -0
  89. /lionagi/{datastore → datastores}/chroma.py +0 -0
  90. /lionagi/{datastore → datastores}/deeplake.py +0 -0
  91. /lionagi/{datastore → datastores}/elasticsearch.py +0 -0
  92. /lionagi/{datastore → datastores}/lantern.py +0 -0
  93. /lionagi/{datastore → datastores}/pinecone.py +0 -0
  94. /lionagi/{datastore → datastores}/postgres.py +0 -0
  95. /lionagi/{datastore → datastores}/qdrant.py +0 -0
  96. /lionagi/{structure → structures}/__init__.py +0 -0
  97. {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/LICENSE +0 -0
  98. {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/WHEEL +0 -0
  99. {lionagi-0.0.114.dist-info → lionagi-0.0.116.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,9 @@ import tiktoken
5
5
  import logging
6
6
  import aiohttp
7
7
  from typing import Generator, NoReturn, Dict, Any, Optional
8
- from .service_objs import BaseService, RateLimiter, StatusTracker, AsyncQueue
8
+ from ..objs.status_tracker import StatusTracker
9
+ from ..objs.abc_objs import BaseService, RateLimiter
10
+ from ..objs.async_queue import AsyncQueue
9
11
 
10
12
  class BaseAPIRateLimiter(RateLimiter):
11
13
 
@@ -20,7 +22,7 @@ class BaseAPIRateLimiter(RateLimiter):
20
22
  ) -> None:
21
23
  self = cls(max_requests_per_minute, max_tokens_per_minute)
22
24
  if not os.getenv("env_readthedocs"):
23
- self.rate_limit_replenisher_task = await asyncio.create_task(
25
+ self.rate_limit_replenisher_task = asyncio.create_task(
24
26
  self.rate_limit_replenisher()
25
27
  )
26
28
  return self
@@ -43,6 +45,7 @@ class BaseAPIRateLimiter(RateLimiter):
43
45
  self.available_request_capacity = self.max_requests_per_minute
44
46
  self.available_token_capacity = self.max_tokens_per_minute
45
47
 
48
+ # credit to OpenAI for the following function
46
49
  def calculate_num_token(
47
50
  self,
48
51
  payload: Dict[str, Any] = None,
@@ -130,13 +133,16 @@ class BaseAPIService(BaseService):
130
133
  def __init__(self, api_key: str = None,
131
134
  status_tracker = None,
132
135
  queue = None, endpoint=None, schema=None,
133
- ratelimiter=None, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
136
+ ratelimiter=BaseAPIRateLimiter, max_requests_per_minute=None, max_tokens_per_minute=None) -> None:
134
137
  self.api_key = api_key
135
138
  self.status_tracker = status_tracker or StatusTracker()
136
139
  self.queue = queue or AsyncQueue()
137
140
  self.endpoint=endpoint
138
141
  self.schema = schema
139
- self.rate_limiter = ratelimiter(max_requests_per_minute, max_tokens_per_minute)
142
+ self.max_requests_per_minute = max_requests_per_minute
143
+ self.max_tokens_per_minute = max_tokens_per_minute
144
+ self.rate_limiter_class = ratelimiter
145
+ self.rate_limiter = None
140
146
 
141
147
  @staticmethod
142
148
  def api_methods(http_session, method="post"):
@@ -168,6 +174,10 @@ class BaseAPIService(BaseService):
168
174
  yield task_id
169
175
  task_id += 1
170
176
 
177
+ async def _init(self):
178
+ if self.rate_limiter is None:
179
+ self.rate_limiter = await self.rate_limiter_class.create(self.max_requests_per_minute, self.max_tokens_per_minute)
180
+
171
181
  async def _call_api(self, http_session, endpoint_, method="post", payload: Dict[str, any] =None) -> Optional[Dict[str, any]]:
172
182
  endpoint_ = self.api_endpoint_from_url("https://api.openai.com/v1/"+endpoint_)
173
183
 
@@ -224,4 +234,4 @@ class BaseAPIService(BaseService):
224
234
  except Exception as e:
225
235
  self.status_tracker.num_tasks_failed += 1
226
236
  raise e
227
-
237
+
lionagi/services/oai.py CHANGED
@@ -15,7 +15,7 @@ class OpenAIService(BaseAPIService):
15
15
  max_attempts: int = 3,
16
16
  max_requests_per_minute: int = 500,
17
17
  max_tokens_per_minute: int = 150_000,
18
- ratelimiter = BaseAPIRateLimiter ,
18
+ ratelimiter = BaseAPIRateLimiter,
19
19
  status_tracker = None,
20
20
  queue = None,
21
21
  ):
@@ -31,4 +31,4 @@ class OpenAIService(BaseAPIService):
31
31
 
32
32
  async def serve(self, payload, endpoint_="chat/completions", method="post"):
33
33
  return await self._serve(payload=payload, endpoint_=endpoint_, method=method)
34
-
34
+
@@ -2,8 +2,7 @@ from os import getenv
2
2
  from .base_api_service import BaseAPIService, BaseAPIRateLimiter
3
3
 
4
4
  class OpenRouterService(BaseAPIService):
5
- _key_scheme = "OPENROUTER_API_KEY"
6
-
5
+ key_scheme = "OPENROUTER_API_KEY"
7
6
  base_url = "https://openrouter.ai/api/v1/"
8
7
 
9
8
  def __init__(
@@ -18,7 +17,7 @@ class OpenRouterService(BaseAPIService):
18
17
  queue = None,
19
18
  ):
20
19
  super().__init__(
21
- api_key = api_key or getenv(self._key_scheme),
20
+ api_key = api_key or getenv(self.key_scheme),
22
21
  status_tracker = status_tracker,
23
22
  queue = queue,
24
23
  ratelimiter=ratelimiter,
@@ -0,0 +1,96 @@
1
+ from pydantic import Field
2
+
3
+ from lionagi.schema.base_node import BaseNode
4
+ from .relationship import Relationship
5
+ from lionagi.utils.call_util import lcall
6
+
7
+
8
class Graph(BaseNode):
    """
    Directed graph of nodes joined by relationships.

    Three dictionaries are kept in sync:

    * ``nodes`` maps a node id to the node object.
    * ``relationships`` maps a relationship id to the relationship object.
    * ``node_relationships`` maps a node id to ``{'in': {...}, 'out': {...}}``;
      each inner dict maps a relationship id to the id of the node at the
      other end of that edge.
    """
    nodes: dict = Field(default={})
    relationships: dict = Field(default={})
    node_relationships: dict = Field(default={})

    def add_node(self, node: BaseNode):
        """
        Register a node in the graph.

        Adding a node whose id is already present replaces the stored node
        object but keeps the edge bookkeeping that already exists for that id.

        Args:
            node (BaseNode): The node to store, keyed by ``node.id_``.
        """
        self.nodes[node.id_] = node
        # setdefault instead of plain assignment: resetting the adjacency of
        # an already-registered node would orphan the relationships still
        # recorded in ``self.relationships`` and in the neighbours' maps, and
        # later removals would then fail with KeyError.
        self.node_relationships.setdefault(node.id_, {'in': {}, 'out': {}})

    def add_relationship(self, relationships: Relationship):
        """
        Register a relationship between two nodes already in the graph.

        Args:
            relationships (Relationship): The edge to add.

        Raises:
            KeyError: If the source or the target node is not in the graph.
        """
        source_id = relationships.source_node_id
        target_id = relationships.target_node_id
        if source_id not in self.node_relationships:
            raise KeyError(f'node {relationships.source_node_id} is not found.')
        if target_id not in self.node_relationships:
            raise KeyError(f'node {relationships.target_node_id} is not found.')

        self.relationships[relationships.id_] = relationships
        self.node_relationships[source_id]['out'][relationships.id_] = target_id
        self.node_relationships[target_id]['in'][relationships.id_] = source_id

    def get_node_relationships(self, node: BaseNode = None, out_edge=True):
        """
        Return relationships of the whole graph or of a single node.

        Args:
            node (BaseNode, optional): When omitted, every relationship in the
                graph is returned.
            out_edge (bool): With a node given, select its outgoing edges when
                True and its incoming edges when False.

        Returns:
            list: The matching relationship objects.

        Raises:
            KeyError: If ``node`` is given but is not in the graph.
        """
        if node is None:
            return list(self.relationships.values())

        if node.id_ not in self.nodes:
            raise KeyError(f'node {node.id_} is not found')

        direction = 'out' if out_edge else 'in'
        relationship_ids = list(self.node_relationships[node.id_][direction].keys())
        return lcall(relationship_ids, lambda i: self.relationships[i])

    def remove_node(self, node: BaseNode):
        """
        Remove a node together with every relationship attached to it.

        Args:
            node (BaseNode): The node to remove.

        Returns:
            The node object that was stored under ``node.id_``.

        Raises:
            KeyError: If the node is not in the graph.
        """
        if node.id_ not in self.nodes:
            raise KeyError(f'node {node.id_} is not found')

        # Outgoing edges: drop the mirror entry held by each target node.
        out_relationship = self.node_relationships[node.id_]['out']
        for relationship_id, node_id in out_relationship.items():
            self.node_relationships[node_id]['in'].pop(relationship_id)
            self.relationships.pop(relationship_id)

        # Incoming edges: drop the mirror entry held by each source node.
        in_relationship = self.node_relationships[node.id_]['in']
        for relationship_id, node_id in in_relationship.items():
            self.node_relationships[node_id]['out'].pop(relationship_id)
            self.relationships.pop(relationship_id)

        self.node_relationships.pop(node.id_)
        return self.nodes.pop(node.id_)

    def remove_relationship(self, relationship: Relationship):
        """
        Remove a single relationship from the graph.

        Args:
            relationship (Relationship): The edge to remove.

        Returns:
            The relationship object that was stored under ``relationship.id_``.

        Raises:
            KeyError: If the relationship is not in the graph.
        """
        if relationship.id_ not in self.relationships:
            raise KeyError(f'relationship {relationship.id_} is not found')

        self.node_relationships[relationship.source_node_id]['out'].pop(relationship.id_)
        self.node_relationships[relationship.target_node_id]['in'].pop(relationship.id_)

        return self.relationships.pop(relationship.id_)

    def node_exists(self, node: BaseNode):
        """Return True when ``node.id_`` is a key of ``nodes``."""
        return node.id_ in self.nodes

    def relationship_exists(self, relationship: Relationship):
        """Return True when ``relationship.id_`` is a key of ``relationships``."""
        return relationship.id_ in self.relationships

    def to_networkx(self, **kwargs):
        """
        Build a ``networkx.DiGraph`` mirroring this graph.

        Node attributes come from ``node.to_dict()`` with the ``'node_id'``
        entry removed. Edge attributes come from ``relationship.to_dict()``
        with ``'node_id'``, ``'source_node_id'`` and ``'target_node_id'``
        removed; the last two are used as the edge endpoints.

        Args:
            **kwargs: Forwarded to the ``networkx.DiGraph`` constructor.

        Returns:
            networkx.DiGraph: The populated directed graph.
        """
        # Imported lazily so networkx is only needed when this method is used.
        import networkx as nx
        g = nx.DiGraph(**kwargs)
        for node_id, node in self.nodes.items():
            node_info = node.to_dict()
            node_info.pop('node_id')
            g.add_node(node_id, **node_info)

        for relationship in self.relationships.values():
            relationship_info = relationship.to_dict()
            relationship_info.pop('node_id')
            source_node_id = relationship_info.pop('source_node_id')
            target_node_id = relationship_info.pop('target_node_id')
            g.add_edge(source_node_id, target_node_id, **relationship_info)

        return g
@@ -1,6 +1,6 @@
1
1
  from pydantic import Field
2
2
  from typing import Dict, Optional, Any
3
- from ..schema.base_schema import BaseNode
3
+ from ..schema.base_node import BaseNode
4
4
 
5
5
 
6
6
  class Relationship(BaseNode):
@@ -125,4 +125,12 @@ class Relationship(BaseNode):
125
125
  raise ValueError(f"Target node {self.source_node_id} does not exist")
126
126
  else :
127
127
  raise ValueError(f"Source node {self.target_node_id} does not exist")
128
-
128
+
129
+ def __str__(self) -> str:
130
+ """Returns a simple string representation of the Relationship."""
131
+ return f"Relationship (id_={self.id_}, from={self.source_node_id}, to={self.target_node_id}, label={self.label})"
132
+
133
+ def __repr__(self) -> str:
134
+ """Returns a detailed string representation of the Relationship."""
135
+ return f"Relationship(id_={self.id_}, from={self.source_node_id}, to={self.target_node_id}, content={self.content}, " \
136
+ f"metadata={self.metadata}, label={self.label})"
@@ -0,0 +1,102 @@
1
+ from typing import TypeVar
2
+ from .graph import Graph
3
+ from ..schema import BaseNode
4
+ from .relationship import Relationship
5
+
6
+ T = TypeVar('T', bound='BaseNode')
7
+ R = TypeVar('R', bound='Relationship')
8
+
9
+
10
class Structure(BaseNode):
    """
    Represents the structure of a graph consisting of nodes and relationships.

    Every method delegates to the wrapped ``graph``.
    """
    graph: Graph

    def add_node(self, node: T) -> None:
        """
        Adds a node to the structure.

        Args:
            node (T): The node instance to be added.
        """
        self.graph.add_node(node)

    def add_relationship(self, relationship: R) -> None:
        """
        Adds a relationship to the structure.

        Args:
            relationship (R): The relationship instance to be added.
        """
        self.graph.add_relationship(relationship)

    def get_relationships(self) -> list[R]:
        """
        Returns every relationship held by the underlying graph.

        Returns:
            list[R]: All relationship instances in the structure.
        """
        return self.graph.get_node_relationships()

    def get_node_relationships(self, node: T, out_edge=True) -> list[R]:
        """
        Returns the relationships attached to one node.

        Args:
            node (T): The node instance whose relationships are requested.
            out_edge (bool): True for outgoing relationships, False for
                incoming ones.

        Returns:
            list[R]: The matching relationship instances.
        """
        return self.graph.get_node_relationships(node, out_edge)

    def node_exist(self, node: T) -> bool:
        """
        Checks if a node exists in the structure.

        Args:
            node (T): The node instance to check for existence.

        Returns:
            bool: True if the node exists, False otherwise.
        """
        # Graph exposes this check as ``node_exists``; the previous call to
        # ``node_exist`` referenced a method Graph does not define.
        return self.graph.node_exists(node)

    def relationship_exist(self, relationship: R) -> bool:
        """
        Checks if a relationship exists in the structure.

        Args:
            relationship (R): The relationship instance to check for existence.

        Returns:
            bool: True if the relationship exists, False otherwise.
        """
        return self.graph.relationship_exists(relationship)

    def remove_node(self, node: T) -> T:
        """
        Removes a node and its associated relationships from the structure.

        Args:
            node (T): The node instance to be removed.

        Returns:
            T: The node that was removed.
        """
        return self.graph.remove_node(node)

    def remove_relationship(self, relationship: R) -> R:
        """
        Removes a relationship from the structure.

        Args:
            relationship (R): The relationship instance to be removed.

        Returns:
            R: The relationship that was removed.
        """
        return self.graph.remove_relationship(relationship)
102
+
@@ -0,0 +1,46 @@
1
+ import unittest
2
+ from unittest.mock import MagicMock
3
+
4
+ # The functions under test are imported from lionagi.utils.api_util.
5
+ from lionagi.utils.api_util import *
6
+
7
class TestApiUtils(unittest.TestCase):
    """Unit tests for the helpers imported from ``lionagi.utils.api_util``."""

    def test_api_method_valid(self):
        """Every supported HTTP verb resolves to a callable."""
        mock_session = MagicMock()
        for verb in ('post', 'delete', 'head', 'options', 'patch'):
            with self.subTest(method=verb):
                self.assertTrue(callable(api_method(mock_session, verb)))

    def test_api_method_invalid(self):
        """The 'get' verb is rejected with ValueError."""
        self.assertRaises(ValueError, api_method, MagicMock(), 'get')

    def test_api_error(self):
        """A payload is flagged as an error only when it has an 'error' key."""
        failing = {'error': 'Something went wrong'}
        succeeding = {'result': 'Success'}

        self.assertTrue(api_error(failing))
        self.assertFalse(api_error(succeeding))

    def test_api_rate_limit_error(self):
        """Only an error message mentioning the rate limit is flagged."""
        rate_limited = {'error': {'message': 'Rate limit exceeded'}}
        other_failure = {'error': {'message': 'Another error'}}

        self.assertTrue(api_rate_limit_error(rate_limited))
        self.assertFalse(api_rate_limit_error(other_failure))

    def test_api_endpoint_from_url(self):
        """The endpoint is extracted from versioned URLs, else it is empty."""
        expectations = {
            "https://api.example.com/v1/users": 'users',
            "https://api.example.com/users": '',
            "Just a string": '',
        }
        for url, endpoint in expectations.items():
            self.assertEqual(api_endpoint_from_url(url), endpoint)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,115 @@
1
+ # import asyncio
2
+ # import unittest
3
+ # from unittest.mock import patch, MagicMock
4
+ # from lionagi.utils.sys_util import create_copy
5
+ # from lionagi.utils.flat_util import to_list
6
+ # from typing import Callable
7
+
8
+ # # Assuming the Python module with the above functions is named 'call_utils'
9
+ # from lionagi.utils.call_util import *
10
+
11
+ # class TestCallUtils(unittest.TestCase):
12
+
13
+ # def setUp(self):
14
+ # self.sample_input = [1, 2, 3]
15
+ # self.sample_func = MagicMock(return_value=42)
16
+
17
+ # def test_hcall(self):
18
+ # with patch('time.sleep', return_value=None):
19
+ # result = hcall(self.sample_input, self.sample_func)
20
+ # self.sample_func.assert_called_once_with(self.sample_input)
21
+ # self.assertEqual(result, 42)
22
+
23
+ # def test_ahcall(self):
24
+ # async def async_test():
25
+ # self.sample_func.reset_mock()
26
+ # with patch('asyncio.sleep', return_value=None):
27
+ # result = await ahcall(self.sample_input, self.sample_func)
28
+ # self.sample_func.assert_called_once_with(self.sample_input)
29
+ # self.assertEqual(result, 42)
30
+
31
+ # asyncio.run(async_test())
32
+
33
+ # def test_lcall(self):
34
+ # expected_result = [42, 42, 42]
35
+ # result = lcall(self.sample_input, self.sample_func)
36
+ # self.assertEqual(result, expected_result)
37
+ # calls = [unittest.mock.call(item) for item in self.sample_input]
38
+ # self.sample_func.assert_has_calls(calls, any_order=True)
39
+
40
+ # def test_alcall(self):
41
+ # async def async_test():
42
+ # expected_result = [42, 42, 42]
43
+ # self.sample_func.reset_mock()
44
+ # result = await alcall(self.sample_input, self.sample_func)
45
+ # self.assertEqual(result, expected_result)
46
+ # calls = [unittest.mock.call(item) for item in self.sample_input]
47
+ # self.sample_func.assert_has_calls(calls, any_order=True)
48
+
49
+ # asyncio.run(async_test())
50
+
51
+
52
+ # class TestCallUtils2(unittest.TestCase):
53
+
54
+ # def setUp(self):
55
+ # self.sample_input = [1, 2, 3]
56
+ # self.sample_func = MagicMock(side_effect=lambda x: x + 1)
57
+ # self.sample_async_func = MagicMock(side_effect=lambda x: x + 1)
58
+
59
+ # def test_mcall_single_function(self):
60
+ # result = mcall(self.sample_input, self.sample_func)
61
+ # self.sample_func.assert_has_calls([unittest.mock.call(i) for i in self.sample_input], any_order=True)
62
+ # self.assertEqual(result, [2, 3, 4])
63
+
64
+ # def test_mcall_multiple_functions(self):
65
+ # # Define multiple functions
66
+ # funcs = [
67
+ # MagicMock(side_effect=lambda x: x + 1),
68
+ # MagicMock(side_effect=lambda x: x * 2),
69
+ # MagicMock(side_effect=lambda x: x - 1)
70
+ # ]
71
+ # result = mcall(self.sample_input, funcs)
72
+ # for i, func in enumerate(funcs):
73
+ # func.assert_called_once_with(self.sample_input[i])
74
+ # self.assertEqual(result, [2, 4, 2])
75
+
76
+ # def test_amcall_single_function(self):
77
+ # async def async_test():
78
+ # result = await amcall(self.sample_input, self.sample_async_func)
79
+ # self.sample_async_func.assert_has_calls([unittest.mock.call(i) for i in self.sample_input], any_order=True)
80
+ # self.assertEqual(result, [2, 3, 4])
81
+
82
+ # asyncio.run(async_test())
83
+
84
+ # def test_amcall_multiple_functions(self):
85
+ # # Define multiple asynchronous functions
86
+ # async_funcs = [
87
+ # MagicMock(side_effect=lambda x: x + 1),
88
+ # MagicMock(side_effect=lambda x: x * 2),
89
+ # MagicMock(side_effect=lambda x: x - 1)
90
+ # ]
91
+
92
+ # async def async_test():
93
+ # result = await amcall(self.sample_input, async_funcs)
94
+ # for i, func in enumerate(async_funcs):
95
+ # func.assert_called_once_with(self.sample_input[i])
96
+ # self.assertEqual(result, [2, 4, 2])
97
+
98
+ # asyncio.run(async_test())
99
+
100
+ # def test_ecall(self):
101
+ # funcs = [MagicMock(side_effect=lambda x: x * x)]
102
+ # result = ecall(self.sample_input, funcs)
103
+ # self.assertEqual(result, [[1, 4, 9]])
104
+
105
+ # def test_aecall(self):
106
+ # async_funcs = [MagicMock(side_effect=lambda x: x * x)]
107
+
108
+ # async def async_test():
109
+ # result = await aecall(self.sample_input, async_funcs)
110
+ # self.assertEqual(result, [[1, 4, 9]])
111
+
112
+ # asyncio.run(async_test())
113
+
114
+ # if __name__ == '__main__':
115
+ # unittest.main()
@@ -0,0 +1,202 @@
1
+ import unittest
2
+ from lionagi.utils.convert_util import *
3
+
4
+ # Test cases for the function
5
class TestStrToNum(unittest.TestCase):
    """Checks ``str_to_num`` parsing, rounding, bounds and error handling."""

    def test_valid_int(self):
        """Signed integer strings parse to ints with the default settings."""
        for text, expected in (("123", 123), ("-123", -123)):
            self.assertEqual(str_to_num(text), expected)

    def test_valid_float(self):
        """Signed decimal strings parse to floats when ``num_type=float``."""
        for text, expected in (("123.45", 123.45), ("-123.45", -123.45)):
            self.assertEqual(str_to_num(text, num_type=float), expected)

    def test_precision(self):
        """``precision`` rounds the parsed float to that many decimals."""
        cases = (
            ("123.456", 1, 123.5),
            ("123.444", 2, 123.44),
        )
        for text, digits, expected in cases:
            self.assertEqual(
                str_to_num(text, num_type=float, precision=digits), expected
            )

    def test_bounds(self):
        """In-range values pass; out-of-range values raise ValueError."""
        self.assertEqual(str_to_num("10", lower_bound=5, upper_bound=15), 10)
        self.assertRaises(ValueError, str_to_num, "20", upper_bound=15)
        self.assertRaises(ValueError, str_to_num, "2", lower_bound=5)

    def test_invalid_input(self):
        """Non-numeric text and an unsupported ``num_type`` raise ValueError."""
        self.assertRaises(ValueError, str_to_num, "abc")
        self.assertRaises(ValueError, str_to_num, "123abc", num_type=str)

    def test_no_numeric_value(self):
        """A sentence with no digits raises ValueError."""
        self.assertRaises(ValueError, str_to_num, "No numbers here")
34
+
35
+
36
+
37
+ # Functions to be tested
38
def dict_to_xml(data: Dict[str, Any], root_tag: str = 'node') -> str:
    """
    Serialise ``data`` as an XML string wrapped in a ``root_tag`` element.

    Args:
        data: The value to serialise; dicts and lists are expanded
            recursively by ``_build_xml``.
        root_tag: Tag name of the outermost element.

    Returns:
        The XML document as a unicode string.
    """
    top = ET.Element(root_tag)
    _build_xml(top, data)
    return ET.tostring(top, encoding='unicode')


def _build_xml(element: ET.Element, data: Any):
    """
    Recursively attach ``data`` beneath ``element``.

    A dict yields one child per key, a list yields one ``item`` child per
    entry, and any other value becomes the element's text via ``str``.
    """
    if isinstance(data, dict):
        for tag, payload in data.items():
            _build_xml(ET.SubElement(element, tag), payload)
        return

    if isinstance(data, list):
        for entry in data:
            _build_xml(ET.SubElement(element, 'item'), entry)
        return

    element.text = str(data)
54
+
55
def xml_to_dict(element: ET.Element) -> Dict[str, Any]:
    """
    Convert the children of ``element`` into a nested dictionary.

    A child that has children of its own is converted recursively; a leaf
    child contributes its text (``None`` when it has no text). When several
    siblings share a tag, their values are collected into a list in document
    order instead of the last one overwriting the others.

    Args:
        element: The XML element whose children are converted.

    Returns:
        A dict keyed by child tag. The tag and text of ``element`` itself
        are not included.
    """
    dict_data: Dict[str, Any] = {}
    for child in element:
        value = xml_to_dict(child) if len(child) else child.text
        if child.tag not in dict_data:
            dict_data[child.tag] = value
            continue

        # Repeated tag: values produced above are only str, None or dict, so
        # an existing list can only be the accumulator built on this path.
        existing = dict_data[child.tag]
        if isinstance(existing, list):
            existing.append(value)
        else:
            dict_data[child.tag] = [existing, value]
    return dict_data
63
+
64
+ # Test cases for the functions
65
class TestDictXMLConversion(unittest.TestCase):
    """Checks the XML produced by ``dict_to_xml`` for a nested sample record."""

    def setUp(self):
        """Serialise a sample record once and keep both string and parsed forms."""
        kids = [
            {'name': 'Alice', 'age': 5},
            {'name': 'Bob', 'age': 7},
        ]
        self.data = {'name': 'John', 'age': 30, 'children': kids}
        self.root_tag = 'person'
        self.xml = dict_to_xml(self.data, self.root_tag)
        self.xml_element = ET.fromstring(self.xml)

    def test_dict_to_xml(self):
        """Scalar fields, the list container and list items all appear."""
        expected_fragments = (
            '<name>John</name>',
            '<age>30</age>',
            '<children>',
            '<item>',
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, self.xml)

    # NOTE(review): the tests below were already disabled; the reason is not
    # recorded here -- confirm before re-enabling.
    #
    # def test_xml_to_dict(self):
    #     data_from_xml = xml_to_dict(self.xml_element)
    #     self.assertEqual(data_from_xml, self.data)
    #
    # def test_xml_to_dict_to_xml(self):
    #     data_from_xml = xml_to_dict(self.xml_element)
    #     xml_from_dict = dict_to_xml(data_from_xml, self.root_tag)
    #     self.assertEqual(xml_from_dict, self.xml)
    #
    # def test_invalid_input(self):
    #     with self.assertRaises(TypeError):
    #         dict_to_xml("not a dict", self.root_tag)
97
+
98
+
99
+
100
class TestDocstringExtraction(unittest.TestCase):
    """Checks docstring parsing for Google style, reST style and bad styles."""

    # Summary and parameter descriptions shared by both sample docstrings.
    _SUMMARY = "This is a sample function."
    _ARG_DOCS = {
        'arg1': 'The first argument.',
        'arg2': 'The second argument.'
    }

    def test_google_style_extraction(self):
        """A Google-style docstring yields its summary and per-arg text."""
        def sample_func(arg1, arg2):
            """
            This is a sample function.

            Args:
                arg1 (int): The first argument.
                arg2 (str): The second argument.

            Returns:
                bool: The truth value.
            """
            return True

        summary, arg_docs = extract_docstring_details_google(sample_func)
        self.assertEqual(summary, self._SUMMARY)
        self.assertEqual(arg_docs, self._ARG_DOCS)

    def test_rest_style_extraction(self):
        """A reST-style docstring yields its summary and per-arg text."""
        def sample_func(arg1, arg2):
            """
            This is a sample function.

            :param int arg1: The first argument.
            :param str arg2: The second argument.
            :return: The truth value.
            :rtype: bool
            """
            return True

        summary, arg_docs = extract_docstring_details_rest(sample_func)
        self.assertEqual(summary, self._SUMMARY)
        self.assertEqual(arg_docs, self._ARG_DOCS)

    def test_extract_docstring_details_with_invalid_style(self):
        """An unknown style name raises ValueError."""
        def sample_func(arg1, arg2):
            return True

        self.assertRaises(
            ValueError, extract_docstring_details, sample_func, style='unsupported'
        )
147
+
148
class TestPythonToJsonTypeConversion(unittest.TestCase):
    """Checks the Python type name to JSON type name mapping."""

    def test_python_to_json_type_conversion(self):
        """Known names map to their JSON type; unknown names map to 'object'."""
        expected_by_name = {
            'str': 'string',
            'int': 'number',
            'float': 'number',
            'list': 'array',
            'tuple': 'array',
            'bool': 'boolean',
            'dict': 'object',
            'nonexistent': 'object',
        }
        for python_name, json_name in expected_by_name.items():
            self.assertEqual(python_to_json_type(python_name), json_name)
158
+
159
class TestFunctionToSchema(unittest.TestCase):
    """Checks the function-call schema built by ``func_to_schema``."""

    def test_func_to_schema(self):
        """An annotated, Google-documented function yields the full schema."""
        def sample_func(arg1: int, arg2: str) -> bool:
            """
            This is a sample function.

            Args:
                arg1 (int): The first argument.
                arg2 (str): The second argument.

            Returns:
                bool: The truth value.
            """
            return True

        # Assemble the expected schema from the inside out.
        properties = {
            "arg1": {
                "type": "number",
                "description": "The first argument."
            },
            "arg2": {
                "type": "string",
                "description": "The second argument."
            }
        }
        parameters = {
            "type": "object",
            "properties": properties,
            "required": ["arg1", "arg2"]
        }
        expected_schema = {
            "type": "function",
            "function": {
                "name": "sample_func",
                "description": "This is a sample function.",
                "parameters": parameters
            }
        }

        schema = func_to_schema(sample_func, style='google')
        self.assertEqual(schema, expected_schema)


if __name__ == '__main__':
    unittest.main()
200
+
201
+
202
+ # --------------------------------------------------