pygeai-orchestration 0.1.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pygeai_orchestration/__init__.py +99 -0
- pygeai_orchestration/cli/__init__.py +7 -0
- pygeai_orchestration/cli/__main__.py +11 -0
- pygeai_orchestration/cli/commands/__init__.py +13 -0
- pygeai_orchestration/cli/commands/base.py +192 -0
- pygeai_orchestration/cli/error_handler.py +123 -0
- pygeai_orchestration/cli/formatters.py +419 -0
- pygeai_orchestration/cli/geai_orch.py +270 -0
- pygeai_orchestration/cli/interactive.py +265 -0
- pygeai_orchestration/cli/texts/help.py +169 -0
- pygeai_orchestration/core/__init__.py +130 -0
- pygeai_orchestration/core/base/__init__.py +23 -0
- pygeai_orchestration/core/base/agent.py +121 -0
- pygeai_orchestration/core/base/geai_agent.py +144 -0
- pygeai_orchestration/core/base/geai_orchestrator.py +77 -0
- pygeai_orchestration/core/base/orchestrator.py +142 -0
- pygeai_orchestration/core/base/pattern.py +161 -0
- pygeai_orchestration/core/base/tool.py +149 -0
- pygeai_orchestration/core/common/__init__.py +18 -0
- pygeai_orchestration/core/common/context.py +140 -0
- pygeai_orchestration/core/common/memory.py +176 -0
- pygeai_orchestration/core/common/message.py +50 -0
- pygeai_orchestration/core/common/state.py +181 -0
- pygeai_orchestration/core/composition.py +190 -0
- pygeai_orchestration/core/config.py +356 -0
- pygeai_orchestration/core/exceptions.py +400 -0
- pygeai_orchestration/core/handlers.py +380 -0
- pygeai_orchestration/core/utils/__init__.py +37 -0
- pygeai_orchestration/core/utils/cache.py +138 -0
- pygeai_orchestration/core/utils/config.py +94 -0
- pygeai_orchestration/core/utils/logging.py +57 -0
- pygeai_orchestration/core/utils/metrics.py +184 -0
- pygeai_orchestration/core/utils/validators.py +140 -0
- pygeai_orchestration/dev/__init__.py +15 -0
- pygeai_orchestration/dev/debug.py +288 -0
- pygeai_orchestration/dev/templates.py +321 -0
- pygeai_orchestration/dev/testing.py +301 -0
- pygeai_orchestration/patterns/__init__.py +15 -0
- pygeai_orchestration/patterns/multi_agent.py +237 -0
- pygeai_orchestration/patterns/planning.py +219 -0
- pygeai_orchestration/patterns/react.py +221 -0
- pygeai_orchestration/patterns/reflection.py +134 -0
- pygeai_orchestration/patterns/tool_use.py +170 -0
- pygeai_orchestration/tests/__init__.py +1 -0
- pygeai_orchestration/tests/test_base_classes.py +187 -0
- pygeai_orchestration/tests/test_cache.py +184 -0
- pygeai_orchestration/tests/test_cli_formatters.py +232 -0
- pygeai_orchestration/tests/test_common.py +214 -0
- pygeai_orchestration/tests/test_composition.py +265 -0
- pygeai_orchestration/tests/test_config.py +301 -0
- pygeai_orchestration/tests/test_dev_utils.py +337 -0
- pygeai_orchestration/tests/test_exceptions.py +327 -0
- pygeai_orchestration/tests/test_handlers.py +307 -0
- pygeai_orchestration/tests/test_metrics.py +171 -0
- pygeai_orchestration/tests/test_patterns.py +165 -0
- pygeai_orchestration-0.1.0b2.dist-info/METADATA +290 -0
- pygeai_orchestration-0.1.0b2.dist-info/RECORD +61 -0
- pygeai_orchestration-0.1.0b2.dist-info/WHEEL +5 -0
- pygeai_orchestration-0.1.0b2.dist-info/entry_points.txt +2 -0
- pygeai_orchestration-0.1.0b2.dist-info/licenses/LICENSE +8 -0
- pygeai_orchestration-0.1.0b2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""Tests for CLI formatters."""
|
|
2
|
+
|
|
3
|
+
import unittest
|
|
4
|
+
from io import StringIO
|
|
5
|
+
from unittest.mock import patch
|
|
6
|
+
|
|
7
|
+
from pygeai_orchestration.cli.formatters import (
|
|
8
|
+
Color,
|
|
9
|
+
OutputFormatter,
|
|
10
|
+
ProgressBar,
|
|
11
|
+
Spinner,
|
|
12
|
+
Symbol,
|
|
13
|
+
format_error_details,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TestColor(unittest.TestCase):
    """Sanity checks for the ``Color`` constants."""

    def test_color_values(self):
        """RED, GREEN and RESET are each exposed as ``str`` values."""
        for constant_name in ("RED", "GREEN", "RESET"):
            self.assertIsInstance(getattr(Color, constant_name), str)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TestSymbol(unittest.TestCase):
    """Sanity checks for the ``Symbol`` constants."""

    def test_symbol_values(self):
        """SUCCESS, ERROR and WARNING are each exposed as ``str`` values."""
        for constant_name in ("SUCCESS", "ERROR", "WARNING"):
            self.assertIsInstance(getattr(Symbol, constant_name), str)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TestOutputFormatter(unittest.TestCase):
    """Exercise the ``OutputFormatter`` helpers.

    Most tests use the formatter built with ``use_color=False`` and check
    that the rendered text equals, or contains, the plain input.
    """

    def setUp(self):
        """Create one colour-enabled and one colour-disabled formatter."""
        self.formatter_color = OutputFormatter(use_color=True)
        self.formatter_no_color = OutputFormatter(use_color=False)

    def _assert_contains_all(self, rendered, *fragments):
        """Assert, in order, that every fragment occurs in ``rendered``."""
        for fragment in fragments:
            self.assertIn(fragment, rendered)

    # -- colorize ---------------------------------------------------------

    def test_colorize_with_color(self):
        """The original text survives inside the colourised output."""
        self.assertIn("test", self.formatter_color.colorize("test", Color.RED))

    def test_colorize_no_color(self):
        """With colour disabled, ``colorize`` returns the text unchanged."""
        plain = self.formatter_no_color.colorize("test", Color.RED)
        self.assertEqual(plain, "test")

    def test_colorize_bold(self):
        """The original text survives when ``bold=True`` is requested."""
        styled = self.formatter_color.colorize("test", Color.RED, bold=True)
        self.assertIn("test", styled)

    # -- status messages --------------------------------------------------

    def test_success(self):
        """``success`` output carries the success symbol and the message."""
        rendered = self.formatter_no_color.success("Great!")
        self._assert_contains_all(rendered, Symbol.SUCCESS, "Great!")

    def test_error(self):
        """``error`` output carries the error symbol and the message."""
        rendered = self.formatter_no_color.error("Oops!")
        self._assert_contains_all(rendered, Symbol.ERROR, "Oops!")

    def test_warning(self):
        """``warning`` output carries the warning symbol and the message."""
        rendered = self.formatter_no_color.warning("Careful!")
        self._assert_contains_all(rendered, Symbol.WARNING, "Careful!")

    def test_info(self):
        """``info`` output carries the info symbol and the message."""
        rendered = self.formatter_no_color.info("FYI")
        self._assert_contains_all(rendered, Symbol.INFO, "FYI")

    # -- headings and emphasis --------------------------------------------

    def test_heading_level_1(self):
        """A level-1 heading without colour is just the title text."""
        self.assertEqual(self.formatter_no_color.heading("Title", level=1), "Title")

    def test_heading_level_2(self):
        """A level-2 heading without colour is just the title text."""
        self.assertEqual(
            self.formatter_no_color.heading("Subtitle", level=2), "Subtitle"
        )

    def test_heading_level_3(self):
        """A level-3 heading without colour is just the title text."""
        self.assertEqual(self.formatter_no_color.heading("Minor", level=3), "Minor")

    def test_dim(self):
        """``dim`` is a no-op when colour is disabled."""
        self.assertEqual(self.formatter_no_color.dim("faded"), "faded")

    def test_bold(self):
        """``bold`` is a no-op when colour is disabled."""
        self.assertEqual(self.formatter_no_color.bold("strong"), "strong")

    # -- key/value pairs and lists ----------------------------------------

    def test_key_value(self):
        """Both the key and the value appear in the rendered pair."""
        rendered = self.formatter_no_color.key_value("name", "value")
        self._assert_contains_all(rendered, "name", "value")

    def test_key_value_indent(self):
        """An indented pair contains whitespace and the key."""
        rendered = self.formatter_no_color.key_value("name", "value", indent=2)
        self._assert_contains_all(rendered, " ", "name")

    def test_bullet_list(self):
        """Every item and the bullet symbol appear in the rendered list."""
        rendered = self.formatter_no_color.bullet_list(["item1", "item2", "item3"])
        self._assert_contains_all(rendered, "item1", "item2", "item3")
        self.assertIn(Symbol.BULLET, rendered)

    def test_bullet_list_indent(self):
        """An indented bullet list contains whitespace."""
        rendered = self.formatter_no_color.bullet_list(["item1"], indent=1)
        self.assertIn(" ", rendered)

    def test_section(self):
        """A section shows both its title and its content."""
        rendered = self.formatter_no_color.section("Section", "content")
        self._assert_contains_all(rendered, "Section", "content")

    # -- tables -----------------------------------------------------------

    def test_table_empty(self):
        """A table with headers but no rows renders as the empty string."""
        self.assertEqual(self.formatter_no_color.table(["H1", "H2"], []), "")

    def test_table_with_rows(self):
        """Headers and cell values all appear in the rendered table."""
        rendered = self.formatter_no_color.table(
            ["Name", "Age"],
            [["Alice", "30"], ["Bob", "25"]],
        )
        self._assert_contains_all(rendered, "Name", "Age", "Alice", "Bob")

    # -- json_tree --------------------------------------------------------

    def test_json_tree_simple(self):
        """A flat mapping shows its key and its value."""
        rendered = self.formatter_no_color.json_tree({"key": "value"})
        self._assert_contains_all(rendered, "key", "value")

    def test_json_tree_nested(self):
        """A nested mapping shows the outer key, inner key and leaf value."""
        rendered = self.formatter_no_color.json_tree({"outer": {"inner": "value"}})
        self._assert_contains_all(rendered, "outer", "inner", "value")

    def test_json_tree_list(self):
        """A mapping holding a list shows the key and each list element."""
        rendered = self.formatter_no_color.json_tree({"items": ["a", "b"]})
        self._assert_contains_all(rendered, "items", "a", "b")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class TestProgressBar(unittest.TestCase):
    """Check ``ProgressBar`` bookkeeping with terminal output neutralised."""

    @patch("sys.stdout", new_callable=StringIO)
    def test_progress_bar_creation(self, mock_stdout):
        """A new bar records ``total`` and ``width`` and starts at zero."""
        progress = ProgressBar(total=10, width=20)
        for attribute, expected in (("total", 10), ("width", 20), ("current", 0)):
            self.assertEqual(getattr(progress, attribute), expected)

    @patch("sys.stdout.isatty", return_value=False)
    def test_progress_bar_update_no_tty(self, mock_isatty):
        """``update(5)`` sets ``current`` to 5 when stdout is not a TTY."""
        progress = ProgressBar(total=10)
        progress.update(5)
        self.assertEqual(progress.current, 5)

    @patch("sys.stdout.isatty", return_value=False)
    def test_progress_bar_finish(self, mock_isatty):
        """``finish()`` leaves ``current`` equal to the bar's total."""
        progress = ProgressBar(total=10)
        progress.finish()
        self.assertEqual(progress.current, 10)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class TestSpinner(unittest.TestCase):
    """Check the ``Spinner`` lifecycle with ``sys.stdout.isatty`` forced False."""

    @patch("sys.stdout.isatty", return_value=False)
    def test_spinner_creation(self, mock_isatty):
        """A new spinner stores its message and is not running."""
        spinner = Spinner("Loading...")
        self.assertEqual(spinner.message, "Loading...")
        self.assertFalse(spinner.running)

    @patch("sys.stdout.isatty", return_value=False)
    def test_spinner_start(self, mock_isatty):
        """``start()`` flips ``running`` to True."""
        spinner = Spinner("Loading...")
        spinner.start()
        try:
            self.assertTrue(spinner.running)
        finally:
            # Always stop the spinner, even if the assertion fails, so a
            # started spinner is never left running after this test. The
            # stop happens while the isatty patch is still active.
            spinner.stop()

    @patch("sys.stdout.isatty", return_value=False)
    def test_spinner_update(self, mock_isatty):
        """``update()`` replaces the spinner's message."""
        spinner = Spinner("Loading...")
        spinner.update("New message")
        self.assertEqual(spinner.message, "New message")

    @patch("sys.stdout.isatty", return_value=False)
    def test_spinner_stop(self, mock_isatty):
        """``stop()`` after ``start()`` clears the ``running`` flag."""
        spinner = Spinner("Loading...")
        spinner.start()
        spinner.stop()
        self.assertFalse(spinner.running)

    @patch("sys.stdout.isatty", return_value=False)
    def test_spinner_stop_with_message(self, mock_isatty):
        """``stop()`` with a final message also clears ``running``."""
        spinner = Spinner("Loading...")
        spinner.start()
        spinner.stop("Done!")
        self.assertFalse(spinner.running)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class TestFormatErrorDetails(unittest.TestCase):
    """Check that ``format_error_details`` reports exception type and message."""

    def test_format_simple_error(self):
        """A built-in exception shows its class name and message."""
        details = format_error_details(ValueError("Something went wrong"))
        self.assertIn("ValueError", details)
        self.assertIn("Something went wrong", details)

    def test_format_error_with_attributes(self):
        """An exception carrying an extra attribute is still reported by name and message."""

        class CustomError(Exception):
            # Exception subclass with one extra attribute (``code``).
            def __init__(self, message: str, code: int):
                super().__init__(message)
                self.code = code

        details = format_error_details(CustomError("Error message", 42))
        self.assertIn("CustomError", details)
        self.assertIn("Error message", details)

    def test_format_error_no_color(self):
        """An explicit colour-disabled formatter can be supplied."""
        plain_formatter = OutputFormatter(use_color=False)
        failure = RuntimeError("Test error")
        details = format_error_details(failure, formatter=plain_formatter)
        self.assertIn("RuntimeError", details)
        self.assertIn("Test error", details)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
if __name__ == "__main__":
|
|
232
|
+
unittest.main()
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from pygeai_orchestration import (
|
|
3
|
+
Message,
|
|
4
|
+
MessageRole,
|
|
5
|
+
Conversation,
|
|
6
|
+
Context,
|
|
7
|
+
ContextManager,
|
|
8
|
+
State,
|
|
9
|
+
StateStatus,
|
|
10
|
+
StateManager,
|
|
11
|
+
Memory,
|
|
12
|
+
MemoryStore
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestMessage(unittest.TestCase):
    """Construction checks for ``Message``."""

    def test_creation(self):
        """Role and content are stored and a timestamp is populated."""
        message = Message(role=MessageRole.USER, content="Hello")

        self.assertEqual(message.role, MessageRole.USER)
        self.assertEqual(message.content, "Hello")
        self.assertIsNotNone(message.timestamp)

    def test_with_metadata(self):
        """Metadata passed at construction is readable by key."""
        message = Message(
            role=MessageRole.ASSISTANT,
            content="Response",
            metadata={"key": "value"},
        )

        self.assertEqual(message.metadata["key"], "value")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TestConversation(unittest.TestCase):
    """Behaviour checks for ``Conversation``."""

    def test_creation(self):
        """A new conversation keeps its id and has no messages."""
        conversation = Conversation(id="conv-1")

        self.assertEqual(conversation.id, "conv-1")
        self.assertEqual(len(conversation.messages), 0)

    def test_add_message(self):
        """Adding one message grows the message list to length one."""
        conversation = Conversation(id="conv-1")
        greeting = Message(role=MessageRole.USER, content="Hello")

        conversation.add_message(greeting)
        self.assertEqual(len(conversation.messages), 1)

    def test_get_messages_by_role(self):
        """Filtering by USER returns only the two user messages."""
        conversation = Conversation(id="conv-1")
        exchange = (
            (MessageRole.USER, "Q1"),
            (MessageRole.ASSISTANT, "A1"),
            (MessageRole.USER, "Q2"),
        )
        for role, text in exchange:
            conversation.add_message(Message(role=role, content=text))

        from_user = conversation.get_messages(MessageRole.USER)
        self.assertEqual(len(from_user), 2)

    def test_to_dict_list(self):
        """``to_dict_list`` yields one dict whose role is the string "user"."""
        conversation = Conversation(id="conv-1")
        conversation.add_message(Message(role=MessageRole.USER, content="Hello"))

        serialised = conversation.to_dict_list()
        self.assertEqual(len(serialised), 1)
        self.assertEqual(serialised[0]["role"], "user")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class TestContext(unittest.TestCase):
    """Key/value behaviour checks for ``Context``."""

    def test_creation(self):
        """A fresh context holds no data."""
        self.assertEqual(len(Context().data), 0)

    def test_set_get(self):
        """A value stored with ``set`` is visible through ``has`` and ``get``."""
        context = Context()
        context.set("key", "value")

        self.assertTrue(context.has("key"))
        self.assertEqual(context.get("key"), "value")

    def test_remove(self):
        """``remove`` makes a previously set key disappear."""
        context = Context()
        context.set("key", "value")
        context.remove("key")

        self.assertFalse(context.has("key"))

    def test_update(self):
        """``update`` stores every entry of the supplied mapping."""
        context = Context()
        context.update({"a": 1, "b": 2})

        for key, expected in (("a", 1), ("b", 2)):
            self.assertEqual(context.get(key), expected)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class TestContextManager(unittest.TestCase):
    """Create/get/delete checks for ``ContextManager``."""

    def test_create_context(self):
        """Creating a context returns it and lists its id."""
        registry = ContextManager()
        created = registry.create_context("ctx-1")

        self.assertIsNotNone(created)
        self.assertIn("ctx-1", registry.list_contexts())

    def test_get_context(self):
        """Initial data given at creation is readable from the fetched context."""
        registry = ContextManager()
        registry.create_context("ctx-1", {"test": "data"})

        fetched = registry.get_context("ctx-1")
        self.assertEqual(fetched.get("test"), "data")

    def test_delete_context(self):
        """A deleted context no longer appears in the listing."""
        registry = ContextManager()
        registry.create_context("ctx-1")
        registry.delete_context("ctx-1")

        self.assertNotIn("ctx-1", registry.list_contexts())
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class TestState(unittest.TestCase):
    """Status and checkpoint checks for ``State``."""

    def test_creation(self):
        """A new state starts as INITIALIZED."""
        self.assertEqual(State().status, StateStatus.INITIALIZED)

    def test_update_status(self):
        """``update_status`` replaces the current status."""
        tracked = State()
        tracked.update_status(StateStatus.RUNNING)

        self.assertEqual(tracked.status, StateStatus.RUNNING)

    def test_checkpoint(self):
        """Restoring a checkpoint rolls a later overwrite back."""
        tracked = State()
        tracked.set("key", "value")
        tracked.create_checkpoint()

        # Overwrite after the checkpoint and confirm the new value is live.
        tracked.set("key", "new_value")
        self.assertEqual(tracked.get("key"), "new_value")

        # Restoring must bring back the value captured by the checkpoint.
        tracked.restore_checkpoint()
        self.assertEqual(tracked.get("key"), "value")
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class TestStateManager(unittest.TestCase):
    """Create/update checks for ``StateManager``."""

    def test_create_state(self):
        """Creating a state returns it and lists its id."""
        registry = StateManager()
        created = registry.create_state("state-1")

        self.assertIsNotNone(created)
        self.assertIn("state-1", registry.list_states())

    def test_update_state(self):
        """``update_state`` reports success and the new status is persisted."""
        registry = StateManager()
        registry.create_state("state-1")

        updated = registry.update_state(
            "state-1", StateStatus.COMPLETED, {"result": "done"}
        )
        self.assertTrue(updated)

        fetched = registry.get_state("state-1")
        self.assertEqual(fetched.status, StateStatus.COMPLETED)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class TestMemory(unittest.TestCase):
    """Store/retrieve and capacity checks for ``Memory``."""

    def test_creation(self):
        """A new memory is empty."""
        self.assertEqual(Memory().size(), 0)

    def test_store_retrieve(self):
        """A stored entry is reported by ``has`` and returned by ``retrieve``."""
        memory = Memory()
        memory.store("fact1", "The sky is blue")

        self.assertTrue(memory.has("fact1"))
        self.assertEqual(memory.retrieve("fact1"), "The sky is blue")

    def test_max_size(self):
        """With ``max_size=2``, storing a third entry drops the first key."""
        memory = Memory(max_size=2)
        for index in (1, 2, 3):
            memory.store(f"k{index}", f"v{index}")

        self.assertEqual(memory.size(), 2)
        self.assertFalse(memory.has("k1"))

    def test_get_recent(self):
        """``get_recent(limit=2)`` returns two entries out of three stored."""
        memory = Memory()
        for index in (1, 2, 3):
            memory.store(f"k{index}", f"v{index}")

        latest = memory.get_recent(limit=2)
        self.assertEqual(len(latest), 2)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class TestMemoryStore(unittest.TestCase):
    """Create/get checks for ``MemoryStore``."""

    def test_create_memory(self):
        """Creating a memory returns it and lists its id."""
        registry = MemoryStore()
        created = registry.create_memory("mem-1")

        self.assertIsNotNone(created)
        self.assertIn("mem-1", registry.list_memories())

    def test_get_memory(self):
        """A created memory can be fetched back by id."""
        registry = MemoryStore()
        registry.create_memory("mem-1")

        self.assertIsNotNone(registry.get_memory("mem-1"))
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
if __name__ == '__main__':
|
|
214
|
+
unittest.main()
|
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
|
|
3
|
+
from pygeai_orchestration.core.base.pattern import BasePattern, PatternConfig, PatternResult, PatternType
|
|
4
|
+
from pygeai_orchestration.core.composition import (
|
|
5
|
+
CompositionConfig,
|
|
6
|
+
CompositionMode,
|
|
7
|
+
PatternComposer,
|
|
8
|
+
PatternPipeline,
|
|
9
|
+
)
|
|
10
|
+
from pygeai_orchestration.core.exceptions import PatternExecutionError
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MockPattern(BasePattern):
    """Test double for ``BasePattern`` with a scripted outcome.

    ``execute`` records that it ran, then either returns a successful
    ``PatternResult`` whose text is ``"<output>: <task>"`` or raises,
    depending on the ``success`` flag given at construction.
    """

    def __init__(self, name: str, output: str, success: bool = True):
        """Configure the mock.

        Args:
            name: Pattern name placed in the ``PatternConfig``.
            output: Prefix used when building the result text.
            success: When False, ``execute`` raises instead of returning.
        """
        super().__init__(
            PatternConfig(name=name, pattern_type=PatternType.REFLECTION)
        )
        self.output = output
        self.success = success
        self.executed = False  # set to True on the first execute() call

    async def execute(self, task: str, **kwargs) -> PatternResult:
        """Mark the pattern as executed and produce the scripted outcome.

        Extra keyword arguments are accepted and ignored.

        Raises:
            Exception: If the mock was built with ``success=False``.
        """
        self.executed = True
        # NOTE(review): relies on BasePattern exposing the config as
        # ``self.config`` -- confirm against the base class.
        if self.success:
            return PatternResult(
                success=True,
                result=f"{self.output}: {task}",
                metadata={"pattern_name": self.config.name},
            )
        raise Exception(f"Pattern {self.config.name} failed")

    async def _execute_impl(self, task: str, context=None) -> PatternResult:
        """Delegate to ``execute``; ``context`` is not used."""
        return await self.execute(task)

    async def step(self, task: str, context=None):
        """Delegate to ``execute``; ``context`` is not used."""
        return await self.execute(task)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TestCompositionConfig(unittest.TestCase):
    """Default and explicit values of ``CompositionConfig``."""

    def test_default_config(self):
        """Defaults are sequential, stop on error and pass output along."""
        defaults = CompositionConfig()
        self.assertEqual(defaults.mode, CompositionMode.SEQUENTIAL)
        self.assertTrue(defaults.stop_on_error)
        self.assertTrue(defaults.pass_output)

    def test_custom_config(self):
        """Explicitly supplied values override every default."""
        overrides = CompositionConfig(
            mode=CompositionMode.PARALLEL,
            stop_on_error=False,
            pass_output=False,
        )
        self.assertEqual(overrides.mode, CompositionMode.PARALLEL)
        self.assertFalse(overrides.stop_on_error)
        self.assertFalse(overrides.pass_output)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class TestPatternPipeline(unittest.IsolatedAsyncioTestCase):
    """Construction and execution checks for ``PatternPipeline``.

    Covers building a pipeline, sequential and parallel execution, error
    handling under both ``stop_on_error`` settings, and result metadata.
    """

    def setUp(self):
        """Create three succeeding mock patterns shared by the tests."""
        self.pattern1 = MockPattern("pattern1", "Result1")
        self.pattern2 = MockPattern("pattern2", "Result2")
        self.pattern3 = MockPattern("pattern3", "Result3")

    def _assert_all_executed(self, *patterns):
        """Assert that every given mock pattern recorded an execution."""
        for pattern in patterns:
            self.assertTrue(pattern.executed)

    # -- building a pipeline ----------------------------------------------

    async def test_pipeline_creation(self):
        """A new pipeline contains no patterns."""
        self.assertEqual(len(PatternPipeline()), 0)

    async def test_add_pattern(self):
        """``add_pattern`` grows the pipeline by one."""
        pipeline = PatternPipeline()
        pipeline.add_pattern(self.pattern1)
        self.assertEqual(len(pipeline), 1)

    async def test_add_patterns(self):
        """``add_patterns`` accepts several patterns in one call."""
        pipeline = PatternPipeline()
        pipeline.add_patterns(self.pattern1, self.pattern2, self.pattern3)
        self.assertEqual(len(pipeline), 3)

    async def test_add_invalid_pattern(self):
        """Adding something that is not a pattern raises ``TypeError``."""
        pipeline = PatternPipeline()
        with self.assertRaises(TypeError):
            pipeline.add_pattern("not a pattern")

    async def test_fluent_interface(self):
        """``add_pattern`` returns the pipeline itself so calls can chain."""
        pipeline = PatternPipeline()
        chained = pipeline.add_pattern(self.pattern1).add_pattern(self.pattern2)
        self.assertIs(chained, pipeline)
        self.assertEqual(len(pipeline), 2)

    # -- sequential execution ---------------------------------------------

    async def test_sequential_execution(self):
        """A default pipeline succeeds and runs every pattern."""
        pipeline = PatternPipeline()
        pipeline.add_patterns(self.pattern1, self.pattern2, self.pattern3)

        outcome = await pipeline.execute("input")

        self.assertTrue(outcome.success)
        self._assert_all_executed(self.pattern1, self.pattern2, self.pattern3)

    async def test_sequential_output_passing(self):
        """With ``pass_output=True`` the final result embeds the first pattern's output."""
        pipeline = PatternPipeline(CompositionConfig(pass_output=True))
        pipeline.add_patterns(self.pattern1, self.pattern2)

        outcome = await pipeline.execute("input")

        self.assertIn("Result2", outcome.result)
        self.assertIn("Result1: input", outcome.result)

    async def test_sequential_no_output_passing(self):
        """With ``pass_output=False`` the last pattern sees the original input."""
        pipeline = PatternPipeline(CompositionConfig(pass_output=False))
        pipeline.add_patterns(self.pattern1, self.pattern2)

        outcome = await pipeline.execute("input")

        self.assertIn("Result2: input", outcome.result)

    # -- error handling ---------------------------------------------------

    async def test_stop_on_error(self):
        """With ``stop_on_error=True`` a failure aborts before later patterns run."""
        broken = MockPattern("failing", "fail", success=False)
        pipeline = PatternPipeline(CompositionConfig(stop_on_error=True))
        pipeline.add_patterns(self.pattern1, broken, self.pattern3)

        with self.assertRaises(PatternExecutionError):
            await pipeline.execute("input")

        self.assertTrue(self.pattern1.executed)
        self.assertFalse(self.pattern3.executed)

    async def test_continue_on_error(self):
        """With ``stop_on_error=False`` patterns after a failure still run."""
        broken = MockPattern("failing", "fail", success=False)
        pipeline = PatternPipeline(CompositionConfig(stop_on_error=False))
        pipeline.add_patterns(self.pattern1, broken, self.pattern3)

        await pipeline.execute("input")

        self._assert_all_executed(self.pattern1, self.pattern3)

    async def test_empty_pipeline(self):
        """Executing an empty pipeline raises with a "No patterns" message."""
        pipeline = PatternPipeline()

        with self.assertRaises(PatternExecutionError) as raised:
            await pipeline.execute("input")

        self.assertIn("No patterns", str(raised.exception))

    # -- results and housekeeping -----------------------------------------

    async def test_get_results(self):
        """``get_results`` returns one ``PatternResult`` per executed pattern."""
        pipeline = PatternPipeline()
        pipeline.add_patterns(self.pattern1, self.pattern2)

        await pipeline.execute("input")
        collected = pipeline.get_results()

        self.assertEqual(len(collected), 2)
        for entry in collected:
            self.assertIsInstance(entry, PatternResult)

    async def test_clear_pipeline(self):
        """``clear`` removes both the patterns and the stored results."""
        pipeline = PatternPipeline()
        pipeline.add_patterns(self.pattern1, self.pattern2)
        await pipeline.execute("input")

        pipeline.clear()

        self.assertEqual(len(pipeline), 0)
        self.assertEqual(len(pipeline.get_results()), 0)

    async def test_pipeline_metadata(self):
        """The result metadata reports per-pattern results and counts."""
        pipeline = PatternPipeline()
        pipeline.add_patterns(self.pattern1, self.pattern2, self.pattern3)

        outcome = await pipeline.execute("input")

        self.assertIn("pipeline_results", outcome.metadata)
        self.assertEqual(outcome.metadata["total_patterns"], 3)
        self.assertEqual(outcome.metadata["successful_patterns"], 3)

    # -- parallel execution -----------------------------------------------

    async def test_parallel_execution(self):
        """A parallel pipeline succeeds and runs every pattern."""
        pipeline = PatternPipeline(CompositionConfig(mode=CompositionMode.PARALLEL))
        pipeline.add_patterns(self.pattern1, self.pattern2, self.pattern3)

        outcome = await pipeline.execute("input")

        self.assertTrue(outcome.success)
        self._assert_all_executed(self.pattern1, self.pattern2, self.pattern3)

    async def test_parallel_combined_output(self):
        """The parallel result contains the output of each pattern."""
        pipeline = PatternPipeline(CompositionConfig(mode=CompositionMode.PARALLEL))
        pipeline.add_patterns(self.pattern1, self.pattern2)

        outcome = await pipeline.execute("input")

        self.assertIn("Result1", outcome.result)
        self.assertIn("Result2", outcome.result)

    async def test_parallel_with_failure(self):
        """One failing pattern does not sink a parallel run; the others' output is kept."""
        broken = MockPattern("failing", "fail", success=False)
        pipeline = PatternPipeline(CompositionConfig(mode=CompositionMode.PARALLEL))
        pipeline.add_patterns(self.pattern1, broken, self.pattern2)

        outcome = await pipeline.execute("input")

        self.assertTrue(outcome.success)
        self.assertIn("Result1", outcome.result)
        self.assertIn("Result2", outcome.result)

    async def test_parallel_all_fail(self):
        """When every parallel pattern fails, execution raises "All patterns failed"."""
        broken_a = MockPattern("fail1", "fail", success=False)
        broken_b = MockPattern("fail2", "fail", success=False)
        pipeline = PatternPipeline(CompositionConfig(mode=CompositionMode.PARALLEL))
        pipeline.add_patterns(broken_a, broken_b)

        with self.assertRaises(PatternExecutionError) as raised:
            await pipeline.execute("input")

        self.assertIn("All patterns failed", str(raised.exception))
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class TestPatternComposer(unittest.IsolatedAsyncioTestCase):
    """Checks for the ``PatternComposer`` factory helpers."""

    def setUp(self):
        """Create two succeeding mock patterns shared by the tests."""
        self.pattern1 = MockPattern("pattern1", "Result1")
        self.pattern2 = MockPattern("pattern2", "Result2")

    async def test_sequential_composer(self):
        """``sequential`` builds a two-pattern pipeline in SEQUENTIAL mode."""
        built = PatternComposer.sequential(self.pattern1, self.pattern2)

        self.assertIsInstance(built, PatternPipeline)
        self.assertEqual(len(built), 2)
        self.assertEqual(built.config.mode, CompositionMode.SEQUENTIAL)

    async def test_parallel_composer(self):
        """``parallel`` builds a two-pattern pipeline in PARALLEL mode."""
        built = PatternComposer.parallel(self.pattern1, self.pattern2)

        self.assertIsInstance(built, PatternPipeline)
        self.assertEqual(len(built), 2)
        self.assertEqual(built.config.mode, CompositionMode.PARALLEL)

    async def test_custom_composer(self):
        """``custom`` forwards the supplied flags into the pipeline config."""
        built = PatternComposer.custom(
            mode=CompositionMode.SEQUENTIAL,
            stop_on_error=False,
            pass_output=False,
        )

        self.assertIsInstance(built, PatternPipeline)
        self.assertFalse(built.config.stop_on_error)
        self.assertFalse(built.config.pass_output)

    async def test_sequential_execution_via_composer(self):
        """A composer-built sequential pipeline succeeds and runs both patterns."""
        built = PatternComposer.sequential(self.pattern1, self.pattern2)
        outcome = await built.execute("input")

        self.assertTrue(outcome.success)
        for pattern in (self.pattern1, self.pattern2):
            self.assertTrue(pattern.executed)

    async def test_parallel_execution_via_composer(self):
        """A composer-built parallel pipeline succeeds and runs both patterns."""
        built = PatternComposer.parallel(self.pattern1, self.pattern2)
        outcome = await built.execute("input")

        self.assertTrue(outcome.success)
        for pattern in (self.pattern1, self.pattern2):
            self.assertTrue(pattern.executed)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
if __name__ == "__main__":
|
|
265
|
+
unittest.main()
|