lionagi 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (32) hide show
  1. lionagi/__init__.py +2 -1
  2. lionagi/core/generic/graph.py +10 -3
  3. lionagi/core/generic/node.py +5 -1
  4. lionagi/core/report/base.py +1 -0
  5. lionagi/core/report/form.py +1 -1
  6. lionagi/core/session/directive_mixin.py +1 -1
  7. lionagi/core/unit/template/plan.py +1 -1
  8. lionagi/core/work/work.py +4 -2
  9. lionagi/core/work/work_edge.py +96 -0
  10. lionagi/core/work/work_function.py +36 -4
  11. lionagi/core/work/work_function_node.py +44 -0
  12. lionagi/core/work/work_queue.py +50 -26
  13. lionagi/core/work/work_task.py +155 -0
  14. lionagi/core/work/worker.py +225 -37
  15. lionagi/core/work/worker_engine.py +179 -0
  16. lionagi/core/work/worklog.py +9 -11
  17. lionagi/tests/test_core/generic/test_structure.py +193 -0
  18. lionagi/tests/test_core/graph/__init__.py +0 -0
  19. lionagi/tests/test_core/graph/test_graph.py +70 -0
  20. lionagi/tests/test_core/graph/test_tree.py +75 -0
  21. lionagi/tests/test_core/mail/__init__.py +0 -0
  22. lionagi/tests/test_core/mail/test_mail.py +62 -0
  23. lionagi/tests/test_core/test_structure/__init__.py +0 -0
  24. lionagi/tests/test_core/test_structure/test_base_structure.py +196 -0
  25. lionagi/tests/test_core/test_structure/test_graph.py +54 -0
  26. lionagi/tests/test_core/test_structure/test_tree.py +48 -0
  27. lionagi/version.py +1 -1
  28. {lionagi-0.2.0.dist-info → lionagi-0.2.2.dist-info}/METADATA +5 -4
  29. {lionagi-0.2.0.dist-info → lionagi-0.2.2.dist-info}/RECORD +32 -18
  30. {lionagi-0.2.0.dist-info → lionagi-0.2.2.dist-info}/LICENSE +0 -0
  31. {lionagi-0.2.0.dist-info → lionagi-0.2.2.dist-info}/WHEEL +0 -0
  32. {lionagi-0.2.0.dist-info → lionagi-0.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,193 @@
1
+ import unittest
2
+ from lionagi.core.generic.structure import *
3
+
4
class TestCondition(Condition):
    """Condition stub for the tests in this module: it approves every node."""

    def check(self, node: Node) -> bool:
        """Return ``True`` regardless of which node is passed in."""
        return True
7
+
8
class TestBaseStructure(unittest.TestCase):
    """Exercises node and edge bookkeeping on a ``BaseStructure``."""

    def setUp(self):
        """Create an empty structure and three nodes that are not yet added."""
        self.structure = BaseStructure(id_="test_structure")
        self.node1 = Node(id_="node1", content="Node 1 content")
        self.node2 = Node(id_="node2", content="Node 2 content")
        self.node3 = Node(id_="node3", content="Node 3 content")

    # ------------------------------------------------------------------
    # Private helpers (names do not start with ``test`` so they are not
    # collected as test cases).
    # ------------------------------------------------------------------

    def _pair(self):
        """Return node1 and node2 as a fresh list."""
        return [self.node1, self.node2]

    def _pair_by_id(self):
        """Return node1 and node2 keyed by their own ids."""
        return {self.node1.id_: self.node1, self.node2.id_: self.node2}

    def _add_linked_pair(self, **relation):
        """Add node1 and node2, then relate node1 -> node2."""
        self.structure.add_node(self._pair())
        self.structure.relate_nodes(self.node1, self.node2, **relation)

    def _add_chain(self):
        """Add all three nodes and relate node1 -> node2 -> node3."""
        self.structure.add_node([self.node1, self.node2, self.node3])
        self.structure.relate_nodes(self.node1, self.node2)
        self.structure.relate_nodes(self.node2, self.node3)

    def _assert_stored(self, *nodes):
        """Assert that each node's id is in ``internal_nodes``."""
        for node in nodes:
            self.assertIn(node.id_, self.structure.internal_nodes)

    def _assert_not_stored(self, *nodes):
        """Assert that each node's id is absent from ``internal_nodes``."""
        for node in nodes:
            self.assertNotIn(node.id_, self.structure.internal_nodes)

    def _assert_edge_count(self, expected):
        """Assert the number of entries in ``internal_edges``."""
        self.assertEqual(len(self.structure.internal_edges), expected)

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    def test_internal_edges(self):
        """Relating two added nodes yields exactly one internal edge."""
        self._add_linked_pair()
        self._assert_edge_count(1)

    def test_is_empty(self):
        """``is_empty`` is true at first and false after one node is added."""
        self.assertTrue(self.structure.is_empty)
        self.structure.add_node(self.node1)
        self.assertFalse(self.structure.is_empty)

    # ------------------------------------------------------------------
    # add_node
    # ------------------------------------------------------------------

    def test_add_node_single(self):
        """A single node can be added directly."""
        self.structure.add_node(self.node1)
        self._assert_stored(self.node1)

    def test_add_node_list(self):
        """A list of nodes can be added in one call."""
        self.structure.add_node(self._pair())
        self._assert_stored(self.node1, self.node2)

    def test_add_node_dict(self):
        """An id -> node mapping can be added in one call."""
        self.structure.add_node(self._pair_by_id())
        self._assert_stored(self.node1, self.node2)

    def test_add_node_duplicate(self):
        """Adding the same node a second time raises ``ValueError``."""
        self.structure.add_node(self.node1)
        self.assertRaises(ValueError, self.structure.add_node, self.node1)

    # ------------------------------------------------------------------
    # get_node
    # ------------------------------------------------------------------

    def test_get_node_str(self):
        """Lookup by id string returns the node."""
        self.structure.add_node(self.node1)
        found = self.structure.get_node(self.node1.id_)
        self.assertEqual(found, self.node1)

    def test_get_node_node(self):
        """Lookup by node object returns the node."""
        self.structure.add_node(self.node1)
        found = self.structure.get_node(self.node1)
        self.assertEqual(found, self.node1)

    def test_get_node_list(self):
        """Lookup by a list of ids returns the nodes as a list."""
        self.structure.add_node(self._pair())
        found = self.structure.get_node([self.node1.id_, self.node2.id_])
        self.assertEqual(found, [self.node1, self.node2])

    def test_get_node_dict(self):
        """Lookup by an id -> node mapping returns the nodes as a list."""
        self.structure.add_node(self._pair())
        found = self.structure.get_node(self._pair_by_id())
        self.assertEqual(found, [self.node1, self.node2])

    def test_get_node_not_found(self):
        """Lookup of an unknown id without a default raises ``KeyError``."""
        self.assertRaises(KeyError, self.structure.get_node, "nonexistent_node")

    def test_get_node_default(self):
        """Lookup of an unknown id with ``default=None`` returns ``None``."""
        found = self.structure.get_node("nonexistent_node", default=None)
        self.assertIsNone(found)

    # ------------------------------------------------------------------
    # remove_node
    # ------------------------------------------------------------------

    def test_remove_node_node(self):
        """Removal accepts a node object."""
        self.structure.add_node(self.node1)
        self.structure.remove_node(self.node1)
        self._assert_not_stored(self.node1)

    def test_remove_node_str(self):
        """Removal accepts an id string."""
        self.structure.add_node(self.node1)
        self.structure.remove_node(self.node1.id_)
        self._assert_not_stored(self.node1)

    def test_remove_node_list(self):
        """Removal accepts a list of nodes."""
        self.structure.add_node(self._pair())
        self.structure.remove_node(self._pair())
        self._assert_not_stored(self.node1, self.node2)

    def test_remove_node_dict(self):
        """Removal accepts an id -> node mapping."""
        self.structure.add_node(self._pair())
        self.structure.remove_node(self._pair_by_id())
        self._assert_not_stored(self.node1, self.node2)

    # ------------------------------------------------------------------
    # pop_node
    # ------------------------------------------------------------------

    def test_pop_node_node(self):
        """Popping by node object returns it and removes it."""
        self.structure.add_node(self.node1)
        popped_node = self.structure.pop_node(self.node1)
        self.assertEqual(popped_node, self.node1)
        self._assert_not_stored(self.node1)

    def test_pop_node_str(self):
        """Popping by id string returns the node and removes it."""
        self.structure.add_node(self.node1)
        popped_node = self.structure.pop_node(self.node1.id_)
        self.assertEqual(popped_node, self.node1)
        self._assert_not_stored(self.node1)

    def test_pop_node_list(self):
        """Popping a list of nodes returns them as a list and removes them."""
        self.structure.add_node(self._pair())
        popped_nodes = self.structure.pop_node(self._pair())
        self.assertEqual(popped_nodes, [self.node1, self.node2])
        self._assert_not_stored(self.node1, self.node2)

    def test_pop_node_dict(self):
        """Popping an id -> node mapping returns a list and removes the nodes."""
        self.structure.add_node(self._pair())
        popped_nodes = self.structure.pop_node(self._pair_by_id())
        self.assertEqual(popped_nodes, [self.node1, self.node2])
        self._assert_not_stored(self.node1, self.node2)

    def test_pop_node_default(self):
        """Popping an unknown id with ``default=None`` returns ``None``."""
        popped = self.structure.pop_node("nonexistent_node", default=None)
        self.assertIsNone(popped)

    # ------------------------------------------------------------------
    # remove_edge
    # ------------------------------------------------------------------

    def test_remove_edge_edge(self):
        """Removal accepts an edge object."""
        self._add_linked_pair()
        edge = list(self.node1.edges.values())[0]
        self.structure.remove_edge(edge)
        self.assertNotIn(edge.id_, self.structure.internal_edges)

    def test_remove_edge_str(self):
        """Removal accepts an edge id string."""
        self._add_linked_pair()
        edge = list(self.node1.edges.values())[0]
        self.structure.remove_edge(edge.id_)
        self.assertNotIn(edge.id_, self.structure.internal_edges)

    def test_remove_edge_list(self):
        """Removal accepts a list of edges."""
        self._add_chain()
        edges = list(self.node2.edges.values())

        self.structure.remove_edge(edges)
        for position in (0, 1):
            self.assertNotIn(edges[position].id_, self.structure.internal_edges)

    def test_remove_edge_dict(self):
        """Removal accepts the mapping exposed by ``node.edges``."""
        self._add_chain()

        edge_dict = self.node2.edges
        # Snapshot the edges before removal so they can be checked afterwards.
        edge_list = list(edge_dict.values())

        self.structure.remove_edge(edge_dict)
        for position in (0, 1):
            self.assertNotIn(edge_list[position].id_, self.structure.internal_edges)

    def test_remove_edge_not_found(self):
        """Removing an unknown edge id raises ``ValueError``."""
        self.assertRaises(ValueError, self.structure.remove_edge, "nonexistent_edge")

    # ------------------------------------------------------------------
    # clear / get_node_edges / relate_nodes
    # ------------------------------------------------------------------

    def test_clear(self):
        """``clear`` leaves the structure empty."""
        self.structure.add_node(self._pair())
        self.structure.clear()
        self.assertTrue(self.structure.is_empty)

    def test_get_node_edges_head(self):
        """node1 has one edge when queried with ``node_as="head"``."""
        self._add_linked_pair()
        edges = self.structure.get_node_edges(self.node1, node_as="head")
        self.assertEqual(len(edges), 1)

    def test_get_node_edges_tail(self):
        """node2 has one edge when queried with ``node_as="tail"``."""
        self._add_linked_pair()
        edges = self.structure.get_node_edges(self.node2, node_as="tail")
        self.assertEqual(len(edges), 1)

    def test_get_node_edges_label(self):
        """Filtering by the label used in ``relate_nodes`` finds one edge."""
        self._add_linked_pair(label="test_label")
        edges = self.structure.get_node_edges(
            self.node1, node_as="head", label="test_label"
        )
        self.assertEqual(len(edges), 1)

    def test_add_edge(self):
        """Relating nodes that were never added stores both nodes and one edge."""
        self.structure.relate_nodes(self.node1, self.node2)
        self._assert_stored(self.node1, self.node2)
        self._assert_edge_count(1)

    def test_add_edge_with_label(self):
        """The label given to ``relate_nodes`` is kept on the created edge."""
        self.structure.relate_nodes(self.node1, self.node2, label="test_label")
        edge = list(self.structure.internal_edges.values())[0]
        self.assertEqual(edge.label, "test_label")
191
+
192
+ if __name__ == "__main__":
193
+ unittest.main()
File without changes
@@ -0,0 +1,70 @@
1
+ import unittest
2
+ from unittest.mock import MagicMock, patch
3
+ from lionagi.core.graph.graph import Graph
4
+ from lionagi.core.generic.node import Node
5
+ from lionagi.core.generic.edge import Edge
6
+
7
class TestGraph(unittest.TestCase):
    """Tests for ``Graph``: heads, cycle check, networkx export, node upkeep."""

    def setUp(self):
        """Start every test with a graph holding three unrelated nodes."""
        self.graph = Graph()
        self.node1 = Node(id_="node1", content="Node 1 content")
        self.node2 = Node(id_="node2", content="Node 2 content")
        self.node3 = Node(id_="node3", content="Node 3 content")
        for member in (self.node1, self.node2, self.node3):
            self.graph.add_node(member)

    def test_graph_heads(self):
        """After relating node1 -> node2, ``graph_heads`` equals ``["node1"]``."""
        self.graph.relate_nodes(self.node1, self.node2)
        self.assertEqual(["node1"], self.graph.graph_heads)

    def test_acyclic(self):
        """``acyclic`` is true for one edge and false once the reverse is added."""
        self.graph.relate_nodes(self.node1, self.node2)
        self.assertTrue(self.graph.acyclic)

        # The reverse edge closes the loop node1 -> node2 -> node1.
        self.graph.relate_nodes(self.node2, self.node1)
        self.assertFalse(self.graph.acyclic)

    @patch("lionagi.libs.SysUtil.check_import")
    def test_to_networkx_success(self, mock_check_import):
        """``to_networkx`` returns the object produced by ``networkx.DiGraph``."""
        mock_check_import.return_value = None
        with patch("networkx.DiGraph") as mock_digraph:
            mock_graph = MagicMock()
            mock_digraph.return_value = mock_graph
            exported = self.graph.to_networkx()
            self.assertEqual(exported, mock_graph)

    @patch("lionagi.libs.SysUtil.check_import")
    def test_to_networkx_empty_graph(self, mock_check_import):
        """With ``internal_nodes`` emptied, no node or edge is added on export."""
        mock_check_import.return_value = None
        with patch("networkx.DiGraph") as mock_digraph:
            mock_graph = MagicMock()
            mock_digraph.return_value = mock_graph

            # Drop the nodes that setUp added.
            self.graph.internal_nodes = {}
            exported = self.graph.to_networkx()

            self.assertEqual(exported, mock_graph)
            mock_check_import.assert_called_once_with("networkx")
            mock_digraph.assert_called_once()
            mock_graph.add_node.assert_not_called()
            mock_graph.add_edge.assert_not_called()

    def test_add_node(self):
        """A newly added node's id appears in ``internal_nodes``."""
        extra = Node(id_="node4", content="Node 4 content")
        self.graph.add_node(extra)
        self.assertIn("node4", self.graph.internal_nodes)

    def test_remove_node(self):
        """A removed node's id no longer appears in ``internal_nodes``."""
        self.graph.remove_node(self.node1)
        self.assertNotIn("node1", self.graph.internal_nodes)

    def test_clear(self):
        """``clear`` empties ``internal_nodes`` and sets ``is_empty``."""
        self.graph.clear()
        remaining = len(self.graph.internal_nodes)
        self.assertEqual(remaining, 0)
        self.assertTrue(self.graph.is_empty)
68
+
69
+ if __name__ == "__main__":
70
+ unittest.main()
@@ -0,0 +1,75 @@
1
+ import unittest
2
+ from lionagi.core.graph.tree import TreeNode, Tree
3
+
4
class TestTreeNode(unittest.TestCase):
    """Tests for the parent/child linking methods of ``TreeNode``."""

    def setUp(self):
        """Create one parent and two children with no links between them."""
        self.parent_node = TreeNode(id_="parent", content="Parent Node")
        self.child_node1 = TreeNode(id_="child1", content="Child Node 1")
        self.child_node2 = TreeNode(id_="child2", content="Child Node 2")

    def test_relate_child(self):
        """``relate_child`` with one node links it in both directions."""
        parent, child = self.parent_node, self.child_node1
        parent.relate_child(child)
        self.assertIn("child1", parent.children)
        self.assertEqual(child.parent, parent)

    def test_relate_children(self):
        """``relate_child`` with a list links every listed child."""
        parent = self.parent_node
        siblings = [self.child_node1, self.child_node2]
        parent.relate_child(siblings)
        for child_id in ("child1", "child2"):
            self.assertIn(child_id, parent.children)
        for child in siblings:
            self.assertEqual(child.parent, parent)

    def test_relate_parent(self):
        """``relate_parent`` called on the child links it in both directions."""
        parent, child = self.parent_node, self.child_node1
        child.relate_parent(parent)
        self.assertIn("child1", parent.children)
        self.assertEqual(child.parent, parent)

    def test_unrelate_child(self):
        """``unrelate_child`` removes the child and clears its parent."""
        parent, child = self.parent_node, self.child_node1
        parent.relate_child(child)
        parent.unrelate_child(child)
        self.assertNotIn("child1", parent.children)
        self.assertIsNone(child.parent)

    def test_unrelate_parent(self):
        """``unrelate_parent`` removes the child and clears its parent."""
        parent, child = self.parent_node, self.child_node1
        child.relate_parent(parent)
        child.unrelate_parent()
        self.assertNotIn("child1", parent.children)
        self.assertIsNone(child.parent)
38
+
39
class TestTree(unittest.TestCase):
    """Tests for ``Tree``: adding nodes and relating parents to children."""

    def setUp(self):
        """Create an empty tree and four tree nodes that are not yet linked."""
        self.tree = Tree()
        self.root = TreeNode(id_="root", content="Root Node")
        self.child_node1 = TreeNode(id_="child1", content="Child Node 1")
        self.child_node2 = TreeNode(id_="child2", content="Child Node 2")
        self.grandchild_node = TreeNode(id_="grandchild", content="Grandchild Node")

    def _attach_children_to_root(self):
        """Relate both child nodes to the root in a single call."""
        self.tree.relate_parent_child(
            self.root, [self.child_node1, self.child_node2]
        )

    def test_add_node(self):
        """An added node's id appears in ``internal_nodes``."""
        self.tree.add_node(self.root)
        self.assertIn("root", self.tree.internal_nodes)

    def test_relate_parent_child(self):
        """Relating two children to the root also sets ``tree.root``."""
        self._attach_children_to_root()
        for child_id in ("child1", "child2"):
            self.assertIn(child_id, self.root.children)
        self.assertEqual(self.tree.root, self.root)

    def test_tree_structure(self):
        """A three-level tree keeps parent and child links at every level."""
        # Assemble root -> (child1, child2) and child1 -> grandchild.
        self._attach_children_to_root()
        self.tree.relate_parent_child(self.child_node1, self.grandchild_node)

        # Inspect the links from the bottom of the tree upwards.
        self.assertIn("grandchild", self.child_node1.children)
        self.assertEqual(self.grandchild_node.parent, self.child_node1)
        self.assertEqual(self.child_node1.parent, self.root)

    def test_acyclic(self):
        """A chain root -> child1 -> child2 reports ``acyclic`` as true."""
        # Assemble a straight chain of three nodes.
        links = (
            (self.root, self.child_node1),
            (self.child_node1, self.child_node2),
        )
        for parent, child in links:
            self.tree.relate_parent_child(parent, child)
        # A tree is expected never to contain a cycle.
        self.assertTrue(self.tree.acyclic)
73
+
74
+ if __name__ == "__main__":
75
+ unittest.main()
File without changes
@@ -0,0 +1,62 @@
1
+ from lionagi.core.generic.mail import *
2
+
3
+ import unittest
4
+
5
+
6
class TestMail(unittest.TestCase):
    """Tests for constructing a ``Mail`` and rendering it as a string."""

    def setUp(self):
        """Create one mail from node1 to node2 carrying a text package."""
        self.mail = Mail(
            sender="node1",
            recipient="node2",
            category=MailPackageCategory.MESSAGES,
            package="Hello, World!",
        )

    def test_mail_initialization(self):
        """The mail is a ``BaseComponent`` and keeps all constructor values."""
        mail = self.mail
        self.assertIsInstance(mail, BaseComponent)
        expected_fields = (
            ("sender", "node1"),
            ("recipient", "node2"),
            ("category", MailPackageCategory.MESSAGES),
            ("package", "Hello, World!"),
        )
        for field_name, expected_value in expected_fields:
            self.assertEqual(getattr(mail, field_name), expected_value)

    def test_mail_str(self):
        """``str(mail)`` names sender, recipient, category and package."""
        rendered = str(self.mail)
        self.assertEqual(
            rendered,
            "Mail from node1 to node2 with category messages and package Hello, World!",
        )
27
+
28
class TestMailBox(unittest.TestCase):
    """Tests for the pending mail mappings of ``MailBox`` and its string form."""

    def setUp(self):
        """Create an empty mailbox and two mails addressed to node3."""
        self.mailbox = MailBox()
        self.mail1 = Mail(
            sender="node1",
            recipient="node3",
            category="model",
            package={"model": "Random Forest"},
        )
        self.mail2 = Mail(
            sender="node2",
            recipient="node3",
            category=MailPackageCategory.SERVICE,
            package={"service": "Prediction"},
        )

    def _file_both_mails(self):
        """Store mail1 as pending-in under node1 and mail2 as pending-out under node3."""
        self.mailbox.pending_ins["node1"] = self.mail1
        self.mailbox.pending_outs["node3"] = self.mail2

    def test_adding_mails(self):
        """Mails stored in the pending mappings can be read back by key."""
        self._file_both_mails()
        box = self.mailbox

        self.assertIn("node1", box.pending_ins)
        self.assertIn("node3", box.pending_outs)
        self.assertEqual(box.pending_ins["node1"], self.mail1)
        self.assertEqual(box.pending_outs["node3"], self.mail2)

    def test_mailbox_str(self):
        """``str(mailbox)`` reports the number of pending mails each way."""
        self._file_both_mails()
        rendered = str(self.mailbox)
        self.assertEqual(
            rendered,
            "MailBox with 1 pending incoming mails and 1 pending outgoing mails.",
        )
60
+
61
+ if __name__ == "__main__":
62
+ unittest.main()
File without changes
@@ -0,0 +1,196 @@
1
+ import unittest
2
+ from unittest.mock import MagicMock, patch
3
+ from lionagi.core.tool.structure import *
4
+
5
+
6
class TestBaseStructure(unittest.TestCase):
    """Tests for ``BaseStructure`` node and edge management.

    Most tests seed ``structure_nodes`` / ``structure_edges`` and the nodes'
    ``in_relations`` / ``out_relations`` directly, then check what a single
    method call returns or changes.
    """

    def setUp(self):
        """Create an empty structure, three nodes and two edges (1->2, 2->3)."""
        self.structure = BaseStructure()
        self.node1 = BaseNode(id_="node1")
        self.node2 = BaseNode(id_="node2")
        self.node3 = BaseNode(id_="node3")
        self.edge1 = Edge(id_="edge1", source_node_id="node1", target_node_id="node2")
        self.edge2 = Edge(id_="edge2", source_node_id="node2", target_node_id="node3")

    def test_node_edges_property(self):
        """``node_edges`` maps each node id to its "in" and "out" relations."""
        self.node1.in_relations = {"edge1": self.edge1}
        self.node1.out_relations = {"edge2": self.edge2}
        self.structure.structure_nodes = {"node1": self.node1}
        expected_result = {
            "node1": {"in": {"edge1": self.edge1}, "out": {"edge2": self.edge2}}
        }
        self.assertEqual(self.structure.node_edges, expected_result)

    def test_get_node_edges_with_node(self):
        """``get_node_edges`` filters a node's edges by ``direction``."""
        self.node1.out_relations = {"edge1": self.edge1}
        self.node1.in_relations = {"edge2": self.edge2}
        self.structure.structure_nodes = {"node1": self.node1}
        self.assertEqual(
            self.structure.get_node_edges(self.node1, direction="out"), [self.edge1]
        )
        self.assertEqual(
            self.structure.get_node_edges(self.node1, direction="in"), [self.edge2]
        )
        self.assertEqual(
            self.structure.get_node_edges(self.node1, direction="all"),
            [self.edge2, self.edge1],
        )

    def test_get_node_edges_without_node(self):
        """Without a node argument, every structure edge is returned."""
        self.structure.structure_edges = {"edge1": self.edge1, "edge2": self.edge2}
        self.assertEqual(self.structure.get_node_edges(), [self.edge1, self.edge2])

    def test_get_node_edges_node_not_found(self):
        """Asking for the edges of an unknown node raises ``KeyError``.

        ``assertRaises`` makes the test fail when nothing is raised; a bare
        ``try/except`` would pass silently in that case. The message is read
        from ``args[0]`` because ``str()`` of a one-argument ``KeyError`` is
        the ``repr`` of that argument (it includes quotes) and so never
        equals the raw message.
        """
        invalid_node = BaseNode(id_="invalid_node")
        with self.assertRaises(KeyError) as raised:
            self.structure.get_node_edges(invalid_node)
        self.assertEqual(
            raised.exception.args[0], f"node {invalid_node.id_} is not found"
        )

    def test_has_structure_edge_with_edge_object(self):
        """``has_structure_edge`` accepts an edge object."""
        self.structure.structure_edges = {"edge1": self.edge1}
        self.assertTrue(self.structure.has_structure_edge(self.edge1))

    def test_has_structure_edge_with_edge_id(self):
        """``has_structure_edge`` accepts an edge id."""
        self.structure.structure_edges = {"edge1": self.edge1}
        self.assertTrue(self.structure.has_structure_edge("edge1"))

    def test_get_structure_edge_with_edge_object(self):
        """``get_structure_edge`` accepts an edge object."""
        self.structure.structure_edges = {"edge1": self.edge1}
        self.assertEqual(self.structure.get_structure_edge(self.edge1), self.edge1)

    def test_get_structure_edge_with_edge_id(self):
        """``get_structure_edge`` accepts an edge id."""
        self.structure.structure_edges = {"edge1": self.edge1}
        self.assertEqual(self.structure.get_structure_edge("edge1"), self.edge1)

    def test_add_structure_edge_success(self):
        """Adding an edge registers it and updates both endpoint nodes."""
        self.structure.structure_nodes = {"node1": self.node1, "node2": self.node2}
        self.structure.add_structure_edge(self.edge1)
        self.assertEqual(self.structure.structure_edges, {"edge1": self.edge1})
        self.assertEqual(self.node1.out_relations, {"node2": self.edge1})
        self.assertEqual(self.node2.in_relations, {"node1": self.edge1})

    def test_add_structure_edge_failure(self):
        """Adding an edge whose nodes are not in the structure raises ``ValueError``."""
        with self.assertRaises(ValueError):
            self.structure.add_structure_edge(self.edge1)

    def test_remove_structure_edge_with_edge_object(self):
        """Removing by edge object clears the edge and both nodes' relations."""
        self.structure.structure_edges = {"edge1": self.edge1}
        self.structure.structure_nodes = {"node1": self.node1, "node2": self.node2}
        self.node1.out_relations = {"node2": self.edge1}
        self.node2.in_relations = {"node1": self.edge1}
        self.structure.remove_structure_edge(self.edge1)
        self.assertEqual(self.structure.structure_edges, {})
        self.assertEqual(self.node1.out_relations, {})
        self.assertEqual(self.node2.in_relations, {})

    def test_remove_structure_edge_with_edge_id(self):
        """Removing by edge id clears the edge and both nodes' relations."""
        self.structure.structure_edges = {"edge1": self.edge1}
        self.structure.structure_nodes = {"node1": self.node1, "node2": self.node2}
        self.node1.out_relations = {"node2": self.edge1}
        self.node2.in_relations = {"node1": self.edge1}
        self.structure.remove_structure_edge("edge1")
        self.assertEqual(self.structure.structure_edges, {})
        self.assertEqual(self.node1.out_relations, {})
        self.assertEqual(self.node2.in_relations, {})

    def test_remove_structure_node_success(self):
        """Removing a node also removes the edge attached to it."""
        self.structure.structure_nodes = {"node1": self.node1, "node2": self.node2}
        self.structure.structure_edges = {"edge1": self.edge1}
        self.node1.out_relations = {"node2": self.edge1}
        self.node2.in_relations = {"node1": self.edge1}
        self.structure.remove_structure_node(self.node1)
        self.assertEqual(self.structure.structure_nodes, {"node2": self.node2})
        self.assertEqual(self.structure.structure_edges, {})

    def test_remove_structure_node_failure(self):
        """Removing a node that is not in the structure raises ``ValueError``."""
        with self.assertRaises(ValueError):
            self.structure.remove_structure_node(self.node1)

    def test_add_structure_node_with_base_node(self):
        """``add_structure_node`` accepts a single node."""
        self.structure.add_structure_node(self.node1)
        self.assertEqual(self.structure.structure_nodes, {"node1": self.node1})

    def test_add_structure_node_with_list(self):
        """``add_structure_node`` accepts a list of nodes."""
        self.structure.add_structure_node([self.node1, self.node2])
        self.assertEqual(
            self.structure.structure_nodes, {"node1": self.node1, "node2": self.node2}
        )

    def test_add_structure_node_with_dict(self):
        """``add_structure_node`` accepts an id -> node mapping."""
        self.structure.add_structure_node({"node1": self.node1, "node2": self.node2})
        self.assertEqual(
            self.structure.structure_nodes, {"node1": self.node1, "node2": self.node2}
        )

    def test_add_structure_node_unsupported_type(self):
        """Passing an int to ``add_structure_node`` raises ``NotImplementedError``."""
        with self.assertRaises(NotImplementedError):
            self.structure.add_structure_node(1)

    def test_get_structure_node_with_node_id(self):
        """``get_structure_node`` accepts a node id."""
        self.structure.structure_nodes = {"node1": self.node1}
        self.assertEqual(self.structure.get_structure_node("node1"), self.node1)

    def test_get_structure_node_with_base_node(self):
        """``get_structure_node`` accepts a node object."""
        self.structure.structure_nodes = {"node1": self.node1}
        self.assertEqual(self.structure.get_structure_node(self.node1), self.node1)

    def test_get_structure_node_with_list(self):
        """``get_structure_node`` with a list of ids returns a list of nodes."""
        self.structure.structure_nodes = {"node1": self.node1, "node2": self.node2}
        self.assertEqual(
            self.structure.get_structure_node(["node1", "node2"]),
            [self.node1, self.node2],
        )

    def test_pop_structure_node_with_node_id(self):
        """Popping by id returns the node and removes it."""
        self.structure.structure_nodes = {"node1": self.node1}
        self.assertEqual(self.structure.pop_structure_node("node1"), self.node1)
        self.assertEqual(self.structure.structure_nodes, {})

    def test_pop_structure_node_with_base_node(self):
        """Popping by node object returns the node and removes it."""
        self.structure.structure_nodes = {"node1": self.node1}
        self.assertEqual(self.structure.pop_structure_node(self.node1), self.node1)
        self.assertEqual(self.structure.structure_nodes, {})

    def test_pop_structure_node_with_list(self):
        """Popping a list of ids returns the nodes and removes them all."""
        self.structure.structure_nodes = {"node1": self.node1, "node2": self.node2}
        self.assertEqual(
            self.structure.pop_structure_node(["node1", "node2"]),
            [self.node1, self.node2],
        )
        self.assertEqual(self.structure.structure_nodes, {})

    def test_pop_structure_node_unsupported_type(self):
        """Passing an int to ``pop_structure_node`` raises ``NotImplementedError``."""
        with self.assertRaises(NotImplementedError):
            self.structure.pop_structure_node(1)

    def test_has_structure_node_with_node_id(self):
        """``has_structure_node`` answers by node id."""
        self.structure.structure_nodes = {"node1": self.node1}
        self.assertTrue(self.structure.has_structure_node("node1"))
        self.assertFalse(self.structure.has_structure_node("node2"))

    def test_has_structure_node_with_base_node(self):
        """``has_structure_node`` answers by node object."""
        self.structure.structure_nodes = {"node1": self.node1}
        self.assertTrue(self.structure.has_structure_node(self.node1))
        self.assertFalse(self.structure.has_structure_node(self.node2))

    def test_has_structure_node_with_list(self):
        """With a list, ``has_structure_node`` is false if any id is missing."""
        self.structure.structure_nodes = {"node1": self.node1, "node2": self.node2}
        self.assertTrue(self.structure.has_structure_node(["node1", "node2"]))
        self.assertFalse(self.structure.has_structure_node(["node1", "node3"]))

    def test_is_empty_property(self):
        """``is_empty`` is true at first and false once a node is present."""
        self.assertTrue(self.structure.is_empty)
        self.structure.structure_nodes = {"node1": self.node1}
        self.assertFalse(self.structure.is_empty)

    def test_clear_method(self):
        """``clear`` empties both the node and the edge mappings."""
        self.structure.structure_nodes = {"node1": self.node1}
        self.structure.structure_edges = {"edge1": self.edge1}
        self.structure.clear()
        self.assertEqual(self.structure.structure_nodes, {})
        self.assertEqual(self.structure.structure_edges, {})
193
+
194
+
195
+ if __name__ == "__main__":
196
+ unittest.main()
@@ -0,0 +1,54 @@
1
+ import unittest
2
+ from unittest.mock import MagicMock, patch
3
+ from lionagi.core.generic import BaseNode, Edge
4
+ from lionagi.new.schema.todo.graph import Graph
5
+
6
+
7
class TestGraph(unittest.TestCase):
    """Tests for ``Graph.to_networkx`` with the networkx import mocked."""

    def setUp(self):
        """Seed a graph with two nodes and one edge from node1 to node2."""
        self.graph = Graph()
        self.node1 = BaseNode(id_="node1", content="Node 1")
        self.node2 = BaseNode(id_="node2", content="Node 2")
        self.edge1 = Edge(
            id_="edge1",
            source_node_id="node1",
            target_node_id="node2",
            label="Edge 1",
        )
        self.graph.structure_nodes = {"node1": self.node1, "node2": self.node2}
        self.graph.structure_edges = {"edge1": self.edge1}

    @patch("lionagi.libs.SysUtil.check_import")
    def test_to_networkx_success(self, mock_check_import):
        """``to_networkx`` returns the object produced by ``networkx.DiGraph``."""
        mock_check_import.return_value = None
        with patch("networkx.DiGraph") as mock_digraph:
            mock_graph = MagicMock()
            mock_digraph.return_value = mock_graph

            exported = self.graph.to_networkx()

            self.assertEqual(exported, mock_graph)

    @patch("lionagi.libs.SysUtil.check_import")
    def test_to_networkx_empty_graph(self, mock_check_import):
        """With nodes and edges emptied, nothing is added to the exported graph."""
        mock_check_import.return_value = None
        with patch("networkx.DiGraph") as mock_digraph:
            mock_graph = MagicMock()
            mock_digraph.return_value = mock_graph

            # Discard what setUp seeded.
            self.graph.structure_nodes = {}
            self.graph.structure_edges = {}
            exported = self.graph.to_networkx()

            self.assertEqual(exported, mock_graph)
            mock_check_import.assert_called_once_with("networkx")
            mock_digraph.assert_called_once()
            mock_graph.add_node.assert_not_called()
            mock_graph.add_edge.assert_not_called()

    @patch("lionagi.libs.SysUtil.check_import", side_effect=ImportError)
    def test_to_networkx_import_error(self, mock_check_import):
        """An ``ImportError`` from the import check propagates to the caller."""
        self.assertRaises(ImportError, self.graph.to_networkx)
        mock_check_import.assert_called_once_with("networkx")
51
+
52
+
53
+ if __name__ == "__main__":
54
+ unittest.main()