ragnarbot-ai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragnarbot/__init__.py +6 -0
- ragnarbot/__main__.py +8 -0
- ragnarbot/agent/__init__.py +8 -0
- ragnarbot/agent/context.py +223 -0
- ragnarbot/agent/loop.py +365 -0
- ragnarbot/agent/memory.py +109 -0
- ragnarbot/agent/skills.py +228 -0
- ragnarbot/agent/subagent.py +241 -0
- ragnarbot/agent/tools/__init__.py +6 -0
- ragnarbot/agent/tools/base.py +102 -0
- ragnarbot/agent/tools/cron.py +114 -0
- ragnarbot/agent/tools/filesystem.py +191 -0
- ragnarbot/agent/tools/message.py +86 -0
- ragnarbot/agent/tools/registry.py +73 -0
- ragnarbot/agent/tools/shell.py +141 -0
- ragnarbot/agent/tools/spawn.py +65 -0
- ragnarbot/agent/tools/web.py +163 -0
- ragnarbot/bus/__init__.py +6 -0
- ragnarbot/bus/events.py +37 -0
- ragnarbot/bus/queue.py +81 -0
- ragnarbot/channels/__init__.py +6 -0
- ragnarbot/channels/base.py +121 -0
- ragnarbot/channels/manager.py +129 -0
- ragnarbot/channels/telegram.py +302 -0
- ragnarbot/cli/__init__.py +1 -0
- ragnarbot/cli/commands.py +568 -0
- ragnarbot/config/__init__.py +6 -0
- ragnarbot/config/loader.py +95 -0
- ragnarbot/config/schema.py +114 -0
- ragnarbot/cron/__init__.py +6 -0
- ragnarbot/cron/service.py +346 -0
- ragnarbot/cron/types.py +59 -0
- ragnarbot/heartbeat/__init__.py +5 -0
- ragnarbot/heartbeat/service.py +130 -0
- ragnarbot/providers/__init__.py +6 -0
- ragnarbot/providers/base.py +69 -0
- ragnarbot/providers/litellm_provider.py +135 -0
- ragnarbot/providers/transcription.py +67 -0
- ragnarbot/session/__init__.py +5 -0
- ragnarbot/session/manager.py +202 -0
- ragnarbot/skills/README.md +24 -0
- ragnarbot/skills/cron/SKILL.md +40 -0
- ragnarbot/skills/github/SKILL.md +48 -0
- ragnarbot/skills/skill-creator/SKILL.md +371 -0
- ragnarbot/skills/summarize/SKILL.md +67 -0
- ragnarbot/skills/tmux/SKILL.md +121 -0
- ragnarbot/skills/tmux/scripts/find-sessions.sh +112 -0
- ragnarbot/skills/tmux/scripts/wait-for-text.sh +83 -0
- ragnarbot/skills/weather/SKILL.md +49 -0
- ragnarbot/utils/__init__.py +5 -0
- ragnarbot/utils/helpers.py +91 -0
- ragnarbot_ai-0.1.0.dist-info/METADATA +28 -0
- ragnarbot_ai-0.1.0.dist-info/RECORD +56 -0
- ragnarbot_ai-0.1.0.dist-info/WHEEL +4 -0
- ragnarbot_ai-0.1.0.dist-info/entry_points.txt +2 -0
- ragnarbot_ai-0.1.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""Web tools: web_search and web_fetch."""
|
|
2
|
+
|
|
3
|
+
import html
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from typing import Any
|
|
8
|
+
from urllib.parse import urlparse
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from ragnarbot.agent.tools.base import Tool
|
|
13
|
+
|
|
14
|
+
# Constants shared by the web tools in this module.
# Browser-style User-Agent header value sent by WebFetchTool with every fetch.
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_7_2) AppleWebKit/537.36"
# Maximum number of HTTP redirects WebFetchTool follows (bounds redirect loops / DoS).
MAX_REDIRECTS = 5
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _strip_tags(text: str) -> str:
|
|
20
|
+
"""Remove HTML tags and decode entities."""
|
|
21
|
+
text = re.sub(r'<script[\s\S]*?</script>', '', text, flags=re.I)
|
|
22
|
+
text = re.sub(r'<style[\s\S]*?</style>', '', text, flags=re.I)
|
|
23
|
+
text = re.sub(r'<[^>]+>', '', text)
|
|
24
|
+
return html.unescape(text).strip()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _normalize(text: str) -> str:
|
|
28
|
+
"""Normalize whitespace."""
|
|
29
|
+
text = re.sub(r'[ \t]+', ' ', text)
|
|
30
|
+
return re.sub(r'\n{3,}', '\n\n', text).strip()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _validate_url(url: str) -> tuple[bool, str]:
|
|
34
|
+
"""Validate URL: must be http(s) with valid domain."""
|
|
35
|
+
try:
|
|
36
|
+
p = urlparse(url)
|
|
37
|
+
if p.scheme not in ('http', 'https'):
|
|
38
|
+
return False, f"Only http/https allowed, got '{p.scheme or 'none'}'"
|
|
39
|
+
if not p.netloc:
|
|
40
|
+
return False, "Missing domain"
|
|
41
|
+
return True, ""
|
|
42
|
+
except Exception as e:
|
|
43
|
+
return False, str(e)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class WebSearchTool(Tool):
    """Agent tool that queries the Brave Search web API and formats the hits as text."""

    name = "web_search"
    description = "Search the web. Returns titles, URLs, and snippets."
    parameters = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Search query"},
            "count": {"type": "integer", "description": "Results (1-10)", "minimum": 1, "maximum": 10}
        },
        "required": ["query"]
    }

    def __init__(self, api_key: str | None = None, max_results: int = 5):
        # An explicit key wins; otherwise use BRAVE_API_KEY from the environment ("" if unset).
        self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "")
        # Result count used when the caller does not pass `count`.
        self.max_results = max_results

    async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
        """Search for *query* and return a numbered list of title / URL / snippet.

        The number of results is ``count`` (or ``self.max_results``) clamped to
        1..10. Failures (no API key, HTTP errors, unexpected payloads) are
        returned as an ``"Error: ..."`` string instead of being raised.
        """
        if not self.api_key:
            return "Error: BRAVE_API_KEY not configured"

        try:
            requested = count or self.max_results
            limit = min(max(requested, 1), 10)
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.search.brave.com/res/v1/web/search",
                    params={"q": query, "count": limit},
                    headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
                    timeout=10.0,
                )
                response.raise_for_status()

            hits = response.json().get("web", {}).get("results", [])
            if not hits:
                return f"No results for: {query}"

            out = [f"Results for: {query}\n"]
            for position, hit in enumerate(hits[:limit], start=1):
                out.append(f"{position}. {hit.get('title', '')}\n {hit.get('url', '')}")
                snippet = hit.get("description")
                if snippet:
                    out.append(f" {snippet}")
            return "\n".join(out)
        except Exception as e:
            return f"Error: {e}"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class WebFetchTool(Tool):
    """Fetch a URL and extract its readable content (HTML via readability's Document)."""

    name = "web_fetch"
    description = "Fetch URL and extract readable content (HTML → markdown/text)."
    parameters = {
        "type": "object",
        "properties": {
            "url": {"type": "string", "description": "URL to fetch"},
            "extractMode": {"type": "string", "enum": ["markdown", "text"], "default": "markdown"},
            "maxChars": {"type": "integer", "minimum": 100}
        },
        "required": ["url"]
    }

    def __init__(self, max_chars: int = 50000):
        # Default cap on the returned text length when the caller gives no usable maxChars.
        self.max_chars = max_chars

    async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
        """Fetch *url* and return a JSON string describing the extracted content.

        On success the JSON object has the keys ``url``, ``finalUrl`` (after
        redirects), ``status``, ``extractor`` ("json", "readability" or "raw"),
        ``truncated``, ``length`` and ``text``. On failure it has ``error`` and
        ``url`` instead; fetch/parse errors are reported, not raised.
        """
        from readability import Document

        max_chars = self._resolve_max_chars(maxChars)

        # Validate URL before fetching
        is_valid, error_msg = _validate_url(url)
        if not is_valid:
            return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url})

        try:
            async with httpx.AsyncClient(
                follow_redirects=True,
                max_redirects=MAX_REDIRECTS,
                timeout=30.0
            ) as client:
                r = await client.get(url, headers={"User-Agent": USER_AGENT})
                r.raise_for_status()

            ctype = r.headers.get("content-type", "")

            if "application/json" in ctype:
                try:
                    text, extractor = json.dumps(r.json(), indent=2), "json"
                except ValueError:
                    # Declared JSON but the body does not parse: return it verbatim
                    # instead of failing the whole fetch.
                    text, extractor = r.text, "raw"
            elif "text/html" in ctype or self._looks_like_html(r.text):
                doc = Document(r.text)
                summary = doc.summary()
                content = self._to_markdown(summary) if extractMode == "markdown" else _strip_tags(summary)
                title = doc.title()
                text = f"# {title}\n\n{content}" if title else content
                extractor = "readability"
            else:
                text, extractor = r.text, "raw"

            truncated = len(text) > max_chars
            if truncated:
                text = text[:max_chars]

            return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
                               "extractor": extractor, "truncated": truncated, "length": len(text), "text": text})
        except Exception as e:
            return json.dumps({"error": str(e), "url": url})

    def _resolve_max_chars(self, requested: Any) -> int:
        """Turn the caller's maxChars into a positive int, falling back to self.max_chars.

        A negative limit would make ``text[:limit]`` drop characters from the END,
        and a non-int (e.g. a float) would break slicing, so both are normalized here.
        """
        try:
            limit = int(requested) if requested is not None else 0
        except (TypeError, ValueError):
            limit = 0
        return limit if limit > 0 else self.max_chars

    @staticmethod
    def _looks_like_html(body: str) -> bool:
        """Sniff an HTML document from the start of *body* (for a missing/wrong Content-Type).

        A leading BOM and whitespace are skipped so that ``"\\n<!DOCTYPE html>"`` is detected.
        """
        head = body[:1024].lstrip("\ufeff").lstrip().lower()
        return head.startswith(("<!doctype", "<html"))

    def _to_markdown(self, html: str) -> str:
        """Convert an HTML fragment to lightweight markdown.

        Links become ``[text](href)``, ``<h1>``-``<h6>`` become ``#`` headings,
        ``<li>`` items become ``- `` bullets, and closing p/div/section/article plus
        ``<br>``/``<hr>`` become line breaks. All other tags are then stripped and
        whitespace normalized. Relative hrefs are kept as-is.
        """
        # NOTE: the parameter name shadows the stdlib ``html`` module inside this
        # method only; _strip_tags resolves ``html`` at module scope and is unaffected.
        # Convert links, headings, lists before stripping tags
        text = re.sub(r'<a\s+[^>]*href=["\']([^"\']+)["\'][^>]*>([\s\S]*?)</a>',
                      lambda m: f'[{_strip_tags(m[2])}]({m[1]})', html, flags=re.I)
        text = re.sub(r'<h([1-6])[^>]*>([\s\S]*?)</h\1>',
                      lambda m: f'\n{"#" * int(m[1])} {_strip_tags(m[2])}\n', text, flags=re.I)
        text = re.sub(r'<li[^>]*>([\s\S]*?)</li>', lambda m: f'\n- {_strip_tags(m[1])}', text, flags=re.I)
        text = re.sub(r'</(p|div|section|article)>', '\n\n', text, flags=re.I)
        text = re.sub(r'<(br|hr)\s*/?>', '\n', text, flags=re.I)
        return _normalize(_strip_tags(text))
|
ragnarbot/bus/events.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Event types for the message bus."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class InboundMessage:
    """A chat message that arrived from a channel and is headed for the agent."""

    channel: str  # channel name, e.g. "telegram"
    sender_id: str  # identifier of the user who sent it
    chat_id: str  # identifier of the chat/channel it arrived in
    content: str  # message text
    timestamp: datetime = field(default_factory=datetime.now)  # creation time (naive local datetime.now())
    media: list[str] = field(default_factory=list)  # media URLs; a fresh list per instance
    metadata: dict[str, Any] = field(default_factory=dict)  # channel-specific extras; a fresh dict per instance

    @property
    def session_key(self) -> str:
        """Conversation key of the form ``"<channel>:<chat_id>"``."""
        return "{}:{}".format(self.channel, self.chat_id)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
|
|
27
|
+
class OutboundMessage:
|
|
28
|
+
"""Message to send to a chat channel."""
|
|
29
|
+
|
|
30
|
+
channel: str
|
|
31
|
+
chat_id: str
|
|
32
|
+
content: str
|
|
33
|
+
reply_to: str | None = None
|
|
34
|
+
media: list[str] = field(default_factory=list)
|
|
35
|
+
metadata: dict[str, Any] = field(default_factory=dict)
|
|
36
|
+
|
|
37
|
+
|
ragnarbot/bus/queue.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Async message queue for decoupled channel-agent communication."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Callable, Awaitable
|
|
5
|
+
|
|
6
|
+
from loguru import logger
|
|
7
|
+
|
|
8
|
+
from ragnarbot.bus.events import InboundMessage, OutboundMessage
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MessageBus:
    """
    In-process async bus between chat channels and the agent core.

    Channels enqueue user messages on ``inbound``; the agent consumes them and
    enqueues replies on ``outbound``. Replies can be drained directly with
    consume_outbound() or fanned out to per-channel callbacks by
    dispatch_outbound().
    """

    def __init__(self):
        self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
        self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
        # channel name -> callbacks that dispatch_outbound() awaits for that channel
        self._outbound_subscribers: dict[str, list[Callable[[OutboundMessage], Awaitable[None]]]] = {}
        # Loop flag for dispatch_outbound(); cleared by stop().
        self._running = False

    async def publish_inbound(self, msg: InboundMessage) -> None:
        """Enqueue a message received by a channel for the agent."""
        await self.inbound.put(msg)

    async def consume_inbound(self) -> InboundMessage:
        """Wait for and return the next message destined for the agent."""
        return await self.inbound.get()

    async def publish_outbound(self, msg: OutboundMessage) -> None:
        """Enqueue an agent reply for delivery to a channel."""
        await self.outbound.put(msg)

    async def consume_outbound(self) -> OutboundMessage:
        """Wait for and return the next agent reply."""
        return await self.outbound.get()

    def subscribe_outbound(
        self,
        channel: str,
        callback: Callable[[OutboundMessage], Awaitable[None]]
    ) -> None:
        """Register *callback* to receive outbound messages whose ``channel`` matches."""
        self._outbound_subscribers.setdefault(channel, []).append(callback)

    async def dispatch_outbound(self) -> None:
        """
        Deliver outbound messages to the callbacks subscribed for their channel.

        Intended to run as a background task until stop() is called. Messages
        for channels without subscribers are dropped. A failing callback is
        logged and does not affect the other callbacks.
        NOTE(review): this drains the same ``outbound`` queue that consume_outbound()
        reads, so running it alongside another consumer splits messages between them.
        """
        self._running = True
        while self._running:
            try:
                msg = await asyncio.wait_for(self.outbound.get(), timeout=1.0)
            except asyncio.TimeoutError:
                # Poll timeout: loop back so a stop() request is noticed.
                continue
            for callback in self._outbound_subscribers.get(msg.channel, []):
                try:
                    await callback(msg)
                except Exception as e:
                    logger.error(f"Error dispatching to {msg.channel}: {e}")

    def stop(self) -> None:
        """Ask dispatch_outbound() to exit once its current wait/iteration finishes."""
        self._running = False

    @property
    def inbound_size(self) -> int:
        """How many inbound messages are waiting to be consumed."""
        return self.inbound.qsize()

    @property
    def outbound_size(self) -> int:
        """How many outbound messages are waiting to be delivered."""
        return self.outbound.qsize()
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Base channel interface for chat platforms."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from ragnarbot.bus.events import InboundMessage, OutboundMessage
|
|
7
|
+
from ragnarbot.bus.queue import MessageBus
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseChannel(ABC):
    """
    Abstract base class for chat channel implementations.

    Each channel should implement this interface
    to integrate with the ragnarbot message bus.
    """

    # Channel identifier stamped on InboundMessage.channel; subclasses are expected to override it.
    name: str = "base"

    def __init__(self, config: Any, bus: MessageBus):
        """
        Initialize the channel.

        Args:
            config: Channel-specific configuration. Its optional ``allow_from``
                attribute is used by is_allowed().
            bus: The message bus for communication.
        """
        self.config = config
        self.bus = bus
        # Reported by is_running; presumably toggled by subclasses in start()/stop().
        self._running = False

    @abstractmethod
    async def start(self) -> None:
        """
        Start the channel and begin listening for messages.

        This should be a long-running async task that:
        1. Connects to the chat platform
        2. Listens for incoming messages
        3. Forwards messages to the bus via _handle_message()
        """
        pass

    @abstractmethod
    async def stop(self) -> None:
        """Stop the channel and clean up resources."""
        pass

    @abstractmethod
    async def send(self, msg: OutboundMessage) -> None:
        """
        Send a message through this channel.

        Args:
            msg: The message to send.
        """
        pass

    def is_allowed(self, sender_id: str) -> bool:
        """
        Check if a sender is allowed to use this bot.

        An empty or missing ``config.allow_from`` allows everyone. Otherwise
        the sender matches if its string form is listed, or, for composite ids
        of the form ``"a|b"``, if any non-empty component is listed. Entries
        are compared as strings, so numeric entries in the list still match.

        Args:
            sender_id: The sender's identifier.

        Returns:
            True if allowed, False otherwise.
        """
        allow_list = getattr(self.config, "allow_from", [])

        # If no allow list, allow everyone
        if not allow_list:
            return True

        # Normalize once: str() so numeric entries match the stringified sender,
        # and a set for O(1) membership checks.
        allowed = {str(entry) for entry in allow_list}

        sender_str = str(sender_id)
        if sender_str in allowed:
            return True
        # For ids without "|" this re-checks sender_str itself, which already failed above.
        return any(part in allowed for part in sender_str.split("|") if part)

    async def _handle_message(
        self,
        sender_id: str,
        chat_id: str,
        content: str,
        media: list[str] | None = None,
        metadata: dict[str, Any] | None = None
    ) -> None:
        """
        Handle an incoming message from the chat platform.

        Messages from senders rejected by is_allowed() are dropped silently;
        all others are wrapped in an InboundMessage and published to the bus.

        Args:
            sender_id: The sender's identifier.
            chat_id: The chat/channel identifier.
            content: Message text content.
            media: Optional list of media URLs.
            metadata: Optional channel-specific metadata.
        """
        if not self.is_allowed(sender_id):
            return

        msg = InboundMessage(
            channel=self.name,
            sender_id=str(sender_id),
            chat_id=str(chat_id),
            content=content,
            media=media or [],
            metadata=metadata or {}
        )

        await self.bus.publish_inbound(msg)

    @property
    def is_running(self) -> bool:
        """Check if the channel is running."""
        return self._running
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Channel manager for coordinating chat channels."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from loguru import logger
|
|
7
|
+
|
|
8
|
+
from ragnarbot.bus.events import OutboundMessage
|
|
9
|
+
from ragnarbot.bus.queue import MessageBus
|
|
10
|
+
from ragnarbot.channels.base import BaseChannel
|
|
11
|
+
from ragnarbot.config.schema import Config
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ChannelManager:
    """
    Manages chat channels and coordinates message routing.

    Responsibilities:
    - Initialize enabled channels (Telegram)
    - Start/stop channels
    - Route outbound messages from the bus to the matching channel
    """

    def __init__(self, config: Config, bus: MessageBus):
        self.config = config
        self.bus = bus
        self.channels: dict[str, BaseChannel] = {}
        # Background task running _dispatch_outbound(); created by start_all().
        self._dispatch_task: asyncio.Task | None = None

        self._init_channels()

    def _init_channels(self) -> None:
        """Instantiate every channel enabled in the config.

        An ImportError while importing or constructing a channel (e.g. a missing
        optional dependency) disables only that channel and logs a warning.
        """
        # Telegram channel (imported lazily, only when enabled)
        if self.config.channels.telegram.enabled:
            try:
                from ragnarbot.channels.telegram import TelegramChannel
                self.channels["telegram"] = TelegramChannel(
                    self.config.channels.telegram,
                    self.bus,
                    groq_api_key=self.config.transcription.api_key,
                )
                logger.info("Telegram channel enabled")
            except ImportError as e:
                logger.warning(f"Telegram channel not available: {e}")

    async def start_all(self) -> None:
        """Start the outbound dispatcher and all channels, then wait for the channels.

        Returns only after every channel's start() has finished. A channel that
        fails is logged with its error rather than being silently dropped.
        """
        if not self.channels:
            logger.warning("No channels enabled")
            return

        # Start outbound dispatcher
        self._dispatch_task = asyncio.create_task(self._dispatch_outbound())

        # Start channels, remembering which task belongs to which channel.
        names: list[str] = []
        tasks = []
        for name, channel in self.channels.items():
            logger.info(f"Starting {name} channel...")
            names.append(name)
            tasks.append(asyncio.create_task(channel.start()))

        # Wait for all to complete (they should run forever)
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # With return_exceptions=True failures come back as values; surface them
        # instead of letting a crashed channel disappear without a trace.
        for name, result in zip(names, results):
            if isinstance(result, asyncio.CancelledError):
                logger.info(f"{name} channel task was cancelled")
            elif isinstance(result, BaseException):
                logger.error(f"{name} channel stopped with error: {result!r}")

    async def stop_all(self) -> None:
        """Cancel the outbound dispatcher, then stop every channel.

        Errors from an individual channel's stop() are logged so the remaining
        channels are still stopped.
        """
        logger.info("Stopping all channels...")

        # Stop dispatcher
        if self._dispatch_task:
            self._dispatch_task.cancel()
            try:
                await self._dispatch_task
            except asyncio.CancelledError:
                pass
            # Drop the finished task so a later start_all() begins from a clean state.
            self._dispatch_task = None

        # Stop all channels
        for name, channel in self.channels.items():
            try:
                await channel.stop()
                logger.info(f"Stopped {name} channel")
            except Exception as e:
                logger.error(f"Error stopping {name}: {e}")

    async def _dispatch_outbound(self) -> None:
        """Deliver outbound messages from the bus to the channel named in each message.

        Runs until cancelled. Send failures and unknown channel names are logged,
        and the message is dropped.
        """
        logger.info("Outbound dispatcher started")

        while True:
            try:
                msg = await asyncio.wait_for(
                    self.bus.consume_outbound(),
                    timeout=1.0
                )

                channel = self.channels.get(msg.channel)
                if channel:
                    try:
                        await channel.send(msg)
                    except Exception as e:
                        logger.error(f"Error sending to {msg.channel}: {e}")
                else:
                    logger.warning(f"Unknown channel: {msg.channel}")

            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                # Cancellation from stop_all() ends the loop normally.
                break

    def get_channel(self, name: str) -> BaseChannel | None:
        """Get a channel by name."""
        return self.channels.get(name)

    def get_status(self) -> dict[str, Any]:
        """Get status of all channels as ``{name: {"enabled": True, "running": bool}}``."""
        return {
            name: {
                "enabled": True,
                "running": channel.is_running
            }
            for name, channel in self.channels.items()
        }

    @property
    def enabled_channels(self) -> list[str]:
        """Get list of enabled channel names."""
        return list(self.channels.keys())
|