genxai-framework 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +3 -0
- cli/commands/__init__.py +6 -0
- cli/commands/approval.py +85 -0
- cli/commands/audit.py +127 -0
- cli/commands/metrics.py +25 -0
- cli/commands/tool.py +389 -0
- cli/main.py +32 -0
- genxai/__init__.py +81 -0
- genxai/api/__init__.py +5 -0
- genxai/api/app.py +21 -0
- genxai/config/__init__.py +5 -0
- genxai/config/settings.py +37 -0
- genxai/connectors/__init__.py +19 -0
- genxai/connectors/base.py +122 -0
- genxai/connectors/kafka.py +92 -0
- genxai/connectors/postgres_cdc.py +95 -0
- genxai/connectors/registry.py +44 -0
- genxai/connectors/sqs.py +94 -0
- genxai/connectors/webhook.py +73 -0
- genxai/core/__init__.py +37 -0
- genxai/core/agent/__init__.py +32 -0
- genxai/core/agent/base.py +206 -0
- genxai/core/agent/config_io.py +59 -0
- genxai/core/agent/registry.py +98 -0
- genxai/core/agent/runtime.py +970 -0
- genxai/core/communication/__init__.py +6 -0
- genxai/core/communication/collaboration.py +44 -0
- genxai/core/communication/message_bus.py +192 -0
- genxai/core/communication/protocols.py +35 -0
- genxai/core/execution/__init__.py +22 -0
- genxai/core/execution/metadata.py +181 -0
- genxai/core/execution/queue.py +201 -0
- genxai/core/graph/__init__.py +30 -0
- genxai/core/graph/checkpoints.py +77 -0
- genxai/core/graph/edges.py +131 -0
- genxai/core/graph/engine.py +813 -0
- genxai/core/graph/executor.py +516 -0
- genxai/core/graph/nodes.py +161 -0
- genxai/core/graph/trigger_runner.py +40 -0
- genxai/core/memory/__init__.py +19 -0
- genxai/core/memory/base.py +72 -0
- genxai/core/memory/embedding.py +327 -0
- genxai/core/memory/episodic.py +448 -0
- genxai/core/memory/long_term.py +467 -0
- genxai/core/memory/manager.py +543 -0
- genxai/core/memory/persistence.py +297 -0
- genxai/core/memory/procedural.py +461 -0
- genxai/core/memory/semantic.py +526 -0
- genxai/core/memory/shared.py +62 -0
- genxai/core/memory/short_term.py +303 -0
- genxai/core/memory/vector_store.py +508 -0
- genxai/core/memory/working.py +211 -0
- genxai/core/state/__init__.py +6 -0
- genxai/core/state/manager.py +293 -0
- genxai/core/state/schema.py +115 -0
- genxai/llm/__init__.py +14 -0
- genxai/llm/base.py +150 -0
- genxai/llm/factory.py +329 -0
- genxai/llm/providers/__init__.py +1 -0
- genxai/llm/providers/anthropic.py +249 -0
- genxai/llm/providers/cohere.py +274 -0
- genxai/llm/providers/google.py +334 -0
- genxai/llm/providers/ollama.py +147 -0
- genxai/llm/providers/openai.py +257 -0
- genxai/llm/routing.py +83 -0
- genxai/observability/__init__.py +6 -0
- genxai/observability/logging.py +327 -0
- genxai/observability/metrics.py +494 -0
- genxai/observability/tracing.py +372 -0
- genxai/performance/__init__.py +39 -0
- genxai/performance/cache.py +256 -0
- genxai/performance/pooling.py +289 -0
- genxai/security/audit.py +304 -0
- genxai/security/auth.py +315 -0
- genxai/security/cost_control.py +528 -0
- genxai/security/default_policies.py +44 -0
- genxai/security/jwt.py +142 -0
- genxai/security/oauth.py +226 -0
- genxai/security/pii.py +366 -0
- genxai/security/policy_engine.py +82 -0
- genxai/security/rate_limit.py +341 -0
- genxai/security/rbac.py +247 -0
- genxai/security/validation.py +218 -0
- genxai/tools/__init__.py +21 -0
- genxai/tools/base.py +383 -0
- genxai/tools/builtin/__init__.py +131 -0
- genxai/tools/builtin/communication/__init__.py +15 -0
- genxai/tools/builtin/communication/email_sender.py +159 -0
- genxai/tools/builtin/communication/notification_manager.py +167 -0
- genxai/tools/builtin/communication/slack_notifier.py +118 -0
- genxai/tools/builtin/communication/sms_sender.py +118 -0
- genxai/tools/builtin/communication/webhook_caller.py +136 -0
- genxai/tools/builtin/computation/__init__.py +15 -0
- genxai/tools/builtin/computation/calculator.py +101 -0
- genxai/tools/builtin/computation/code_executor.py +183 -0
- genxai/tools/builtin/computation/data_validator.py +259 -0
- genxai/tools/builtin/computation/hash_generator.py +129 -0
- genxai/tools/builtin/computation/regex_matcher.py +201 -0
- genxai/tools/builtin/data/__init__.py +15 -0
- genxai/tools/builtin/data/csv_processor.py +213 -0
- genxai/tools/builtin/data/data_transformer.py +299 -0
- genxai/tools/builtin/data/json_processor.py +233 -0
- genxai/tools/builtin/data/text_analyzer.py +288 -0
- genxai/tools/builtin/data/xml_processor.py +175 -0
- genxai/tools/builtin/database/__init__.py +15 -0
- genxai/tools/builtin/database/database_inspector.py +157 -0
- genxai/tools/builtin/database/mongodb_query.py +196 -0
- genxai/tools/builtin/database/redis_cache.py +167 -0
- genxai/tools/builtin/database/sql_query.py +145 -0
- genxai/tools/builtin/database/vector_search.py +163 -0
- genxai/tools/builtin/file/__init__.py +17 -0
- genxai/tools/builtin/file/directory_scanner.py +214 -0
- genxai/tools/builtin/file/file_compressor.py +237 -0
- genxai/tools/builtin/file/file_reader.py +102 -0
- genxai/tools/builtin/file/file_writer.py +122 -0
- genxai/tools/builtin/file/image_processor.py +186 -0
- genxai/tools/builtin/file/pdf_parser.py +144 -0
- genxai/tools/builtin/test/__init__.py +15 -0
- genxai/tools/builtin/test/async_simulator.py +62 -0
- genxai/tools/builtin/test/data_transformer.py +99 -0
- genxai/tools/builtin/test/error_generator.py +82 -0
- genxai/tools/builtin/test/simple_math.py +94 -0
- genxai/tools/builtin/test/string_processor.py +72 -0
- genxai/tools/builtin/web/__init__.py +15 -0
- genxai/tools/builtin/web/api_caller.py +161 -0
- genxai/tools/builtin/web/html_parser.py +330 -0
- genxai/tools/builtin/web/http_client.py +187 -0
- genxai/tools/builtin/web/url_validator.py +162 -0
- genxai/tools/builtin/web/web_scraper.py +170 -0
- genxai/tools/custom/my_test_tool_2.py +9 -0
- genxai/tools/dynamic.py +105 -0
- genxai/tools/mcp_server.py +167 -0
- genxai/tools/persistence/__init__.py +6 -0
- genxai/tools/persistence/models.py +55 -0
- genxai/tools/persistence/service.py +322 -0
- genxai/tools/registry.py +227 -0
- genxai/tools/security/__init__.py +11 -0
- genxai/tools/security/limits.py +214 -0
- genxai/tools/security/policy.py +20 -0
- genxai/tools/security/sandbox.py +248 -0
- genxai/tools/templates.py +435 -0
- genxai/triggers/__init__.py +19 -0
- genxai/triggers/base.py +104 -0
- genxai/triggers/file_watcher.py +75 -0
- genxai/triggers/queue.py +68 -0
- genxai/triggers/registry.py +82 -0
- genxai/triggers/schedule.py +66 -0
- genxai/triggers/webhook.py +68 -0
- genxai/utils/__init__.py +1 -0
- genxai/utils/tokens.py +295 -0
- genxai_framework-0.1.0.dist-info/METADATA +495 -0
- genxai_framework-0.1.0.dist-info/RECORD +156 -0
- genxai_framework-0.1.0.dist-info/WHEEL +5 -0
- genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
- genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
- genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Collaboration protocol implementations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Dict, List
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class VotingResult:
    """Outcome of a voting round."""

    winner: Any  # the value with the most votes (None when no votes were cast)
    counts: Dict[Any, int]  # vote tally per candidate value


class VotingProtocol:
    """Simple majority voting protocol."""

    async def run(self, inputs: List[Any], metadata: Dict[str, Any]) -> VotingResult:
        """Tally *inputs* and return the majority winner.

        Args:
            inputs: Values (votes) contributed by the participants.
            metadata: Unused by this protocol; kept for interface parity.

        Returns:
            VotingResult with the winning value and the full tally.
            For empty input the winner is None and counts is empty
            (consistent with the other protocols, which return None).
        """
        counts: Dict[Any, int] = {}
        for value in inputs:
            counts[value] = counts.get(value, 0) + 1
        # Guard the empty ballot: max() over an empty dict would raise.
        if not counts:
            return VotingResult(winner=None, counts=counts)
        # Ties resolve to the first value reaching the max count
        # (dict preserves insertion order).
        winner = max(counts, key=counts.get)
        return VotingResult(winner=winner, counts=counts)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class NegotiationProtocol:
    """Simple negotiation protocol that returns consensus if all equal."""

    async def run(self, inputs: List[Any], metadata: Dict[str, Any]) -> Any:
        """Return the unanimous value, or the metadata "fallback" otherwise.

        Args:
            inputs: Proposals from the participants.
            metadata: May carry a "fallback" value used when no consensus.

        Returns:
            The agreed value, the fallback (defaults to None), or None
            when there are no proposals at all.
        """
        if not inputs:
            return None
        candidate = inputs[0]
        for proposal in inputs[1:]:
            if proposal != candidate:
                # Disagreement: fall back to the configured default.
                return metadata.get("fallback")
        return candidate
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class AuctionProtocol:
    """Simple auction protocol selecting max bid from inputs."""

    async def run(self, inputs: List[Any], metadata: Dict[str, Any]) -> Any:
        """Return the highest value among *inputs*, or None when empty.

        Args:
            inputs: Bids from the participants (must be mutually comparable).
            metadata: Unused by this protocol; kept for interface parity.
        """
        if inputs:
            return max(inputs)
        return None
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""Message bus for agent-to-agent communication."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional, Callable
|
|
4
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
import asyncio
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Message(BaseModel):
    """Message for agent communication."""

    # Allow non-pydantic-native types (e.g. arbitrary `content` payloads).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str  # overwritten by MessageBus with a sequential "msg_N" id on send
    sender: str
    recipient: Optional[str] = None  # None for broadcast
    content: Any
    message_type: str = "default"
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # NOTE(review): naive local time; fine for in-process ordering, but
    # confirm if messages ever cross machines/timezones.
    timestamp: datetime = Field(default_factory=datetime.now)
    reply_to: Optional[str] = None  # id of the message this one answers
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MessageBus:
    """Central message bus for agent communication.

    Delivers messages to subscribed async callbacks, keeps an in-memory
    history, and supports point-to-point, broadcast, and request/reply
    patterns. Not thread-safe; intended for a single event loop.
    """

    def __init__(self) -> None:
        """Initialize message bus."""
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        self._message_history: List[Message] = []
        self._message_count = 0

    async def send(self, message: Message) -> None:
        """Send a message to a specific recipient.

        Overwrites ``message.id`` with a sequential bus-assigned id.

        Args:
            message: Message to send
        """
        self._message_count += 1
        message.id = f"msg_{self._message_count}"
        self._message_history.append(message)

        logger.info(f"Message sent: {message.sender} -> {message.recipient}")

        # Deliver to recipient's subscribers; a failing callback is logged
        # and does not prevent delivery to the remaining callbacks.
        if message.recipient and message.recipient in self._subscribers:
            for callback in self._subscribers[message.recipient]:
                try:
                    await callback(message)
                except Exception as e:
                    logger.error(f"Error delivering message to {message.recipient}: {e}")

    async def broadcast(self, message: Message, group: Optional[str] = None) -> None:
        """Broadcast a message to all subscribers or a group.

        Args:
            message: Message to broadcast
            group: Optional group name to broadcast to (matched by the
                "{group}_" prefix of subscriber agent ids)
        """
        self._message_count += 1
        message.id = f"msg_{self._message_count}"
        message.recipient = None  # Broadcast has no specific recipient
        self._message_history.append(message)

        logger.info(f"Message broadcast from {message.sender} to group: {group or 'all'}")

        # Deliver to all subscribers (or group subscribers)
        for agent_id, callbacks in self._subscribers.items():
            if group and not agent_id.startswith(f"{group}_"):
                continue

            for callback in callbacks:
                try:
                    await callback(message)
                except Exception as e:
                    logger.error(f"Error broadcasting to {agent_id}: {e}")

    async def request_reply(
        self, message: Message, timeout: float = 30.0
    ) -> Optional[Message]:
        """Send a message and wait for reply.

        Args:
            message: Message to send
            timeout: Timeout in seconds

        Returns:
            Reply message or None if timeout
        """
        reply_event = asyncio.Event()
        reply_message: Optional[Message] = None

        async def reply_handler(msg: Message) -> None:
            nonlocal reply_message
            # `message.id` is the bus-assigned id set inside send() below;
            # replies are created after send(), so the comparison is valid.
            if msg.reply_to == message.id:
                reply_message = msg
                reply_event.set()

        # Subscribe BEFORE sending: a recipient callback may produce the
        # reply synchronously during send(), and subscribing afterwards
        # (the previous behavior) would miss it.
        if message.sender:
            self.subscribe(message.sender, reply_handler)

        try:
            await self.send(message)
            await asyncio.wait_for(reply_event.wait(), timeout=timeout)
            return reply_message
        except asyncio.TimeoutError:
            logger.warning(f"Request-reply timeout for message {message.id}")
            return None
        finally:
            if message.sender:
                self.unsubscribe(message.sender, reply_handler)

    def subscribe(self, agent_id: str, callback: Callable) -> None:
        """Subscribe an agent to receive messages.

        Args:
            agent_id: Agent identifier
            callback: Async callback function to handle messages
        """
        self._subscribers[agent_id].append(callback)
        logger.debug(f"Agent {agent_id} subscribed to message bus")

    def unsubscribe(self, agent_id: str, callback: Optional[Callable] = None) -> None:
        """Unsubscribe an agent from messages.

        Args:
            agent_id: Agent identifier
            callback: Specific callback to remove (None to remove all)
        """
        if agent_id in self._subscribers:
            if callback:
                self._subscribers[agent_id].remove(callback)
            else:
                del self._subscribers[agent_id]
            logger.debug(f"Agent {agent_id} unsubscribed from message bus")

    def get_history(
        self, agent_id: Optional[str] = None, limit: Optional[int] = None
    ) -> List[Message]:
        """Get message history.

        Args:
            agent_id: Filter by agent (sender or recipient)
            limit: Maximum number of messages to return

        Returns:
            List of messages
        """
        messages = self._message_history

        if agent_id:
            messages = [
                m
                for m in messages
                if m.sender == agent_id or m.recipient == agent_id
            ]

        if limit:
            messages = messages[-limit:]

        return messages

    def clear_history(self) -> None:
        """Clear message history."""
        self._message_history.clear()
        logger.info("Message history cleared")

    def get_stats(self) -> Dict[str, Any]:
        """Get message bus statistics.

        Returns:
            Statistics dictionary
        """
        return {
            "total_messages": self._message_count,
            "history_size": len(self._message_history),
            "subscribers": len(self._subscribers),
            "subscriber_list": list(self._subscribers.keys()),
        }

    def __repr__(self) -> str:
        """String representation."""
        return f"MessageBus(messages={self._message_count}, subscribers={len(self._subscribers)})"
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Communication protocols for agent interaction."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Protocol as TypingProtocol, Any, Dict, List
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CommunicationProtocol(str, Enum):
    """Communication protocols for agents.

    Inherits ``str`` so members compare and serialize as their values.
    """

    POINT_TO_POINT = "point_to_point"  # direct agent-to-agent delivery
    BROADCAST = "broadcast"  # fan-out to all (or a group of) agents
    REQUEST_REPLY = "request_reply"  # send and await a correlated response
    PUB_SUB = "pub_sub"  # topic-based publish/subscribe
    NEGOTIATION = "negotiation"  # converge on a consensus value
    VOTING = "voting"  # majority vote across participant inputs
    AUCTION = "auction"  # select the highest bid among inputs
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MessageHandler(TypingProtocol):
    """Protocol for message handlers.

    Structural (duck-typed) interface: any object with a matching async
    ``handle_message`` method satisfies it without explicit inheritance.
    """

    async def handle_message(self, message: Any) -> None:
        """Handle incoming message.

        Args:
            message: Message to handle
        """
        ...
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CollaborationProtocol(TypingProtocol):
    """Protocol for collaboration strategies."""

    async def run(self, inputs: List[Any], metadata: Dict[str, Any]) -> Any:
        """Combine participant *inputs* into a single outcome.

        Args:
            inputs: Values contributed by the participating agents.
            metadata: Strategy-specific options (e.g. a "fallback" value
                for negotiation).

        Returns:
            The aggregated result; exact semantics depend on the
            implementing strategy (voting, negotiation, auction, ...).
        """
        ...
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Distributed execution primitives for GenXAI."""
|
|
2
|
+
|
|
3
|
+
from genxai.core.execution.queue import (
|
|
4
|
+
QueueBackend,
|
|
5
|
+
QueueTask,
|
|
6
|
+
InMemoryQueueBackend,
|
|
7
|
+
WorkerQueueEngine,
|
|
8
|
+
RedisQueueBackend,
|
|
9
|
+
RQQueueBackend,
|
|
10
|
+
)
|
|
11
|
+
from genxai.core.execution.metadata import ExecutionRecord, ExecutionStore
|
|
12
|
+
|
|
13
|
+
# Public API of the execution package: queue primitives plus the
# execution-metadata record/store.
__all__ = [
    "QueueBackend",
    "QueueTask",
    "InMemoryQueueBackend",
    "WorkerQueueEngine",
    "RedisQueueBackend",
    "RQQueueBackend",
    "ExecutionRecord",
    "ExecutionStore",
]
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Execution metadata store for workflow runs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, Optional
|
|
9
|
+
import json
|
|
10
|
+
import uuid
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class ExecutionRecord:
    """Represents a workflow execution record."""

    run_id: str
    workflow: str
    status: str
    started_at: str
    completed_at: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None
    result: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this record to a plain dictionary (JSON/SQL friendly)."""
        snapshot = dict(
            run_id=self.run_id,
            workflow=self.workflow,
            status=self.status,
            started_at=self.started_at,
            completed_at=self.completed_at,
            metadata=self.metadata,
            error=self.error,
            result=self.result,
        )
        return snapshot
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ExecutionStore:
    """Execution store with JSON or SQL persistence support.

    Records live in an in-memory dict; optionally each record is mirrored
    to per-run JSON files and/or a SQL table ("genxai_executions").
    """

    def __init__(
        self,
        persistence_path: Optional[Path] = None,
        sql_url: Optional[str] = None,
    ) -> None:
        """Initialize the store.

        Args:
            persistence_path: Directory for per-run JSON snapshots (optional).
            sql_url: SQLAlchemy database URL for SQL persistence (optional).

        Raises:
            ImportError: If ``sql_url`` is given but sqlalchemy is missing.
        """
        self._records: Dict[str, ExecutionRecord] = {}
        self._persistence_path = persistence_path
        self._sql_url = sql_url
        self._engine = None
        self._table = None

        if sql_url:
            try:
                import sqlalchemy as sa
            except Exception as exc:
                raise ImportError(
                    "sqlalchemy is required for SQL persistence. Install with: pip install sqlalchemy"
                ) from exc
            self._engine = sa.create_engine(sql_url)
            metadata = sa.MetaData()
            # Timestamps are stored as ISO-8601 strings, matching
            # ExecutionRecord's string fields.
            self._table = sa.Table(
                "genxai_executions",
                metadata,
                sa.Column("run_id", sa.String, primary_key=True),
                sa.Column("workflow", sa.String, nullable=False),
                sa.Column("status", sa.String, nullable=False),
                sa.Column("started_at", sa.String, nullable=False),
                sa.Column("completed_at", sa.String),
                sa.Column("metadata", sa.JSON, nullable=True),
                sa.Column("error", sa.Text, nullable=True),
                sa.Column("result", sa.JSON, nullable=True),
            )
            metadata.create_all(self._engine)

    def generate_run_id(self) -> str:
        """Return a new unique run id."""
        return str(uuid.uuid4())

    def create(
        self,
        run_id: str,
        workflow: str,
        status: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ExecutionRecord:
        """Create (or return the existing) record for *run_id*.

        Idempotent: an already-known run_id returns the cached record
        untouched.
        """
        if run_id in self._records:
            return self._records[run_id]

        record = ExecutionRecord(
            run_id=run_id,
            workflow=workflow,
            status=status,
            started_at=datetime.now().isoformat(),
            metadata=metadata or {},
        )
        self._records[run_id] = record
        self._persist(record)
        return record

    def update(
        self,
        run_id: str,
        status: Optional[str] = None,
        error: Optional[str] = None,
        result: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        completed: bool = False,
    ) -> ExecutionRecord:
        """Update (and persist) the record for *run_id*.

        Only the fields passed as non-None are changed; ``metadata`` is
        merged into the existing dict. ``completed=True`` stamps
        ``completed_at`` with the current time.
        """
        # Use get() rather than the in-memory dict directly: a record that
        # exists only in SQL (e.g. after a restart) is loaded instead of
        # being recreated below with workflow="unknown", which would
        # clobber the persisted row.
        record = self.get(run_id)
        if record is None:
            record = self.create(run_id, workflow="unknown", status="unknown")
        if status is not None:
            record.status = status
        if error is not None:
            record.error = error
        if result is not None:
            record.result = result
        if metadata:
            record.metadata.update(metadata)
        if completed:
            record.completed_at = datetime.now().isoformat()
        self._persist(record)
        return record

    def get(self, run_id: str) -> Optional[ExecutionRecord]:
        """Return the record for *run_id*, loading from SQL on a cache miss.

        Returns None when the record is unknown both in memory and in SQL.
        """
        record = self._records.get(run_id)
        if record or not self._engine or not self._table:
            return record

        import sqlalchemy as sa

        with self._engine.begin() as conn:
            stmt = sa.select(self._table).where(self._table.c.run_id == run_id)
            row = conn.execute(stmt).mappings().first()
            if not row:
                return None
            record = ExecutionRecord(
                run_id=row["run_id"],
                workflow=row["workflow"],
                status=row["status"],
                started_at=row["started_at"],
                completed_at=row["completed_at"],
                metadata=row["metadata"] or {},
                error=row["error"],
                result=row["result"],
            )
            # Cache the hydrated record for subsequent lookups.
            self._records[run_id] = record
            return record

    def _persist(self, record: ExecutionRecord) -> None:
        """Mirror *record* to SQL and/or a JSON file, whichever is enabled."""
        if self._engine is not None and self._table is not None:
            import sqlalchemy as sa

            payload = record.to_dict()
            with self._engine.begin() as conn:
                # Portable upsert: probe for the row, then update or insert.
                stmt = sa.select(self._table.c.run_id).where(
                    self._table.c.run_id == record.run_id
                )
                exists = conn.execute(stmt).first()
                if exists:
                    conn.execute(
                        self._table.update()
                        .where(self._table.c.run_id == record.run_id)
                        .values(**payload)
                    )
                else:
                    conn.execute(self._table.insert().values(**payload))

        if not self._persistence_path:
            return
        self._persistence_path.mkdir(parents=True, exist_ok=True)
        path = self._persistence_path / f"execution_{record.run_id}.json"
        path.write_text(json.dumps(record.to_dict(), indent=2, default=str))

    def close(self) -> None:
        """Dispose of SQL resources if enabled."""
        if self._engine is not None:
            self._engine.dispose()

    def __del__(self) -> None:
        # Guard: __init__ may have raised before _engine was assigned, and
        # attribute access in __del__ would then raise a spurious error.
        if getattr(self, "_engine", None) is not None:
            self.close()
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""Async worker queue engine for distributed execution."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Awaitable, Callable, Optional, Protocol
|
|
8
|
+
import uuid
|
|
9
|
+
import logging
|
|
10
|
+
import json
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class QueueTask:
    """Represents a unit of work for the worker queue."""

    task_id: str  # unique id (a UUID, or the caller-supplied run id)
    payload: dict[str, Any]  # arguments passed to the handler coroutine
    # Coroutine function to execute. None for serialized/distributed tasks,
    # which are resolved by name from metadata["handler_name"] instead.
    handler: Optional[Callable[[dict[str, Any]], Awaitable[Any]]]
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class QueueBackend(Protocol):
    """Protocol for queue backends.

    Structural interface: any object providing these three methods can be
    used by WorkerQueueEngine.
    """

    async def put(self, task: QueueTask) -> None:
        """Enqueue *task* for processing."""
        ...

    async def get(self) -> QueueTask:
        """Wait until a task is available and return it."""
        ...

    def qsize(self) -> int:
        """Return the number of pending tasks (synchronous)."""
        ...
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class InMemoryQueueBackend:
    """In-memory asyncio queue backend.

    Single-process FIFO; suitable for local execution and tests.
    """

    def __init__(self) -> None:
        self._pending: asyncio.Queue[QueueTask] = asyncio.Queue()

    async def put(self, task: QueueTask) -> None:
        """Enqueue *task*."""
        await self._pending.put(task)

    async def get(self) -> QueueTask:
        """Wait for and return the next task."""
        return await self._pending.get()

    def qsize(self) -> int:
        """Number of tasks currently waiting."""
        return self._pending.qsize()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class WorkerQueueEngine:
    """Simple async worker engine for processing queued tasks.

    Spawns ``worker_count`` asyncio tasks that pull from the backend and
    execute each task's handler with linear-backoff retries. Handlers may
    be attached to tasks directly or resolved by name from the registry
    (required for distributed backends that cannot serialize callables).
    """

    def __init__(
        self,
        backend: Optional[QueueBackend] = None,
        worker_count: int = 2,
        max_retries: int = 3,
        backoff_seconds: float = 0.5,
        handler_registry: Optional[dict[str, Callable[[dict[str, Any]], Awaitable[Any]]]] = None,
    ) -> None:
        # Defaults to a process-local queue when no backend is supplied.
        self._backend = backend or InMemoryQueueBackend()
        self._worker_count = worker_count
        self._max_retries = max_retries
        self._backoff_seconds = backoff_seconds
        self._workers: list[asyncio.Task[None]] = []
        self._running = False
        self._handler_registry = handler_registry or {}

    def register_handler(
        self,
        name: str,
        handler: Callable[[dict[str, Any]], Awaitable[Any]],
    ) -> None:
        """Register a handler by name for distributed queue backends."""
        self._handler_registry[name] = handler

    async def start(self) -> None:
        """Start the worker tasks. Idempotent while already running."""
        if self._running:
            return
        self._running = True
        for idx in range(self._worker_count):
            worker = asyncio.create_task(self._worker_loop(idx))
            self._workers.append(worker)

    async def stop(self) -> None:
        """Cancel all workers and wait for them to finish. Idempotent."""
        if not self._running:
            return
        self._running = False
        for worker in self._workers:
            worker.cancel()
        # return_exceptions swallows the CancelledErrors raised by workers.
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()

    async def enqueue(
        self,
        payload: dict[str, Any],
        handler: Optional[Callable[[dict[str, Any]], Awaitable[Any]]] = None,
        metadata: Optional[dict[str, Any]] = None,
        run_id: Optional[str] = None,
        handler_name: Optional[str] = None,
    ) -> str:
        """Enqueue a task and return its task id.

        Args:
            payload: Arguments forwarded to the handler.
            handler: Coroutine function to execute (optional if
                ``handler_name`` is registered).
            metadata: Extra task metadata.
            run_id: Use this id instead of generating a UUID.
            handler_name: Name of a registered handler to resolve.

        Raises:
            ValueError: If neither a handler nor a resolvable name is given.
        """
        task_id = run_id or str(uuid.uuid4())
        if handler is None and handler_name:
            handler = self._handler_registry.get(handler_name)
        if handler is None:
            raise ValueError("Handler must be provided or registered via handler_name")
        task = QueueTask(
            task_id=task_id,
            payload=payload,
            handler=handler,
            # handler_name is kept in metadata so distributed workers (which
            # receive handler=None after serialization) can re-resolve it.
            metadata={**(metadata or {}), "handler_name": handler_name},
        )
        await self._backend.put(task)
        return task_id

    async def _worker_loop(self, worker_id: int) -> None:
        """Pull-and-execute loop run by each worker task until cancelled."""
        while self._running:
            try:
                task = await self._backend.get()
                await self._execute_with_retry(task)
                logger.debug(
                    "Worker %s processed task %s", worker_id, task.task_id
                )
            except asyncio.CancelledError:
                break
            except Exception as exc:
                # A task that exhausted its retries is logged here; the
                # worker itself keeps running.
                logger.error("Worker %s failed: %s", worker_id, exc)

    async def _execute_with_retry(self, task: QueueTask) -> None:
        """Run the task's handler, retrying with linear backoff.

        Raises:
            ValueError: If no handler is attached or registered.
            Exception: The last handler error once retries are exhausted.
        """
        handler = task.handler
        if handler is None:
            handler_name = task.metadata.get("handler_name")
            handler = self._handler_registry.get(handler_name) if handler_name else None
        if handler is None:
            raise ValueError(f"No handler registered for task {task.task_id}")
        attempts = 0
        while True:
            try:
                await handler(task.payload)
                return
            except Exception as exc:
                attempts += 1
                if attempts > self._max_retries:
                    raise exc
                # Linear backoff: 1x, 2x, 3x... of backoff_seconds.
                await asyncio.sleep(self._backoff_seconds * attempts)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class RedisQueueBackend:
    """Redis-backed queue backend for distributed execution.

    Stores serialized QueueTask payloads in a Redis list and uses BLPOP to
    retrieve work items.
    """

    def __init__(self, url: str, queue_name: str = "genxai:queue") -> None:
        """Connect to Redis.

        Args:
            url: Redis connection URL (e.g. ``redis://localhost:6379/0``).
            queue_name: Name of the Redis list used as the queue.

        Raises:
            ImportError: If the ``redis`` package is not installed.
        """
        try:
            import redis.asyncio as redis  # type: ignore
            import redis as redis_sync  # type: ignore
        except Exception as exc:
            raise ImportError(
                "redis package is required for RedisQueueBackend. Install with: pip install redis"
            ) from exc

        self._redis = redis.from_url(url)
        # qsize() is synchronous on the QueueBackend protocol, but the
        # asyncio client's llen() returns a coroutine, so int() on it would
        # raise TypeError. Keep a separate synchronous client solely for
        # queue-length queries.
        self._sync_redis = redis_sync.Redis.from_url(url)
        self._queue_name = queue_name

    async def put(self, task: QueueTask) -> None:
        """Serialize *task* (without its handler) and RPUSH it to the list."""
        payload = {
            "task_id": task.task_id,
            "payload": task.payload,
            "metadata": task.metadata,
        }
        await self._redis.rpush(self._queue_name, json.dumps(payload))

    async def get(self) -> QueueTask:
        """Block (BLPOP) until an item is available and rebuild the task.

        The returned task has ``handler=None``; workers resolve the real
        handler via ``metadata["handler_name"]``.
        """
        _, raw = await self._redis.blpop(self._queue_name)
        data = json.loads(raw)
        return QueueTask(
            task_id=data["task_id"],
            payload=data["payload"],
            handler=None,
            metadata=data.get("metadata", {}),
        )

    def qsize(self) -> int:
        """Return the number of pending items in the Redis list."""
        # Uses the synchronous client (see __init__) to honor the sync
        # protocol signature.
        return int(self._sync_redis.llen(self._queue_name))
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class RQQueueBackend:
    """Placeholder backend for Redis/RQ integration.

    This is a stub that documents the interface needed for an RQ backend.
    """

    def __init__(self) -> None:
        # Fail loudly on construction: this backend has no implementation.
        message = "RQQueueBackend is a stub. Implement with Redis + rq when ready."
        raise NotImplementedError(message)
|