squidbot 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
squidbot/session.py ADDED
@@ -0,0 +1,609 @@
1
+ """
2
+ Session Management - Channel-agnostic session handling with JSONL transcripts.
3
+
4
+ Supports multiple channels: Telegram, WhatsApp, Discord, TCP, etc.
5
+ Uses JSONL format for session transcripts (one JSON object per line).
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import time
11
+ import uuid
12
+ from dataclasses import dataclass, field
13
+ from datetime import datetime
14
+ from enum import Enum
15
+ from pathlib import Path
16
+ from typing import Any, Iterator
17
+
18
+ from .config import DATA_DIR
19
+ from .lanes import CommandLane
20
+
21
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Format version number written into the header line of new transcripts.
SESSION_VERSION = 1
25
+
26
+
27
class ChannelType(str, Enum):
    """Messaging channels a session can belong to.

    Each member is also a ``str`` whose value is the lowercase channel name.
    """

    TELEGRAM = "telegram"
    WHATSAPP = "whatsapp"
    DISCORD = "discord"
    SLACK = "slack"
    TCP = "tcp"  # Local TCP client
    WEB = "web"  # Web interface
    API = "api"  # Direct API access

    def __str__(self) -> str:
        """Return the plain channel name (the member's value)."""
        return self.value

    @property
    def supports_media(self) -> bool:
        """True when the channel can carry media messages."""
        media_channels = (
            ChannelType.TELEGRAM,
            ChannelType.WHATSAPP,
            ChannelType.DISCORD,
            ChannelType.SLACK,
            ChannelType.WEB,
        )
        return self in media_channels

    @property
    def supports_reactions(self) -> bool:
        """True when the channel can carry message reactions."""
        reaction_channels = (
            ChannelType.TELEGRAM,
            ChannelType.DISCORD,
            ChannelType.SLACK,
        )
        return self in reaction_channels

    @property
    def max_message_length(self) -> int:
        """Maximum message length for this channel; ``0`` means no limit."""
        if self is ChannelType.DISCORD:
            return 2000
        if self is ChannelType.SLACK:
            return 40000
        if self in (ChannelType.TCP, ChannelType.WEB, ChannelType.API):
            return 0  # No limit
        # Telegram and WhatsApp, and the fallback for anything else.
        return 4096
74
+
75
+
76
@dataclass
class DeliveryContext:
    """Routing details describing where a message should be delivered."""

    channel: ChannelType
    recipient_id: str  # Chat ID, user ID, etc.
    account_id: str | None = None  # Bot account if multiple
    thread_id: str | None = None  # Thread/topic ID if applicable
    guild_id: str | None = None  # Discord guild/server ID
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; the channel is stored as ``str(channel)``."""
        payload: dict[str, Any] = {"channel": str(self.channel)}
        # Remaining fields are copied as-is, in declaration order.
        for attr in ("recipient_id", "account_id", "thread_id", "guild_id", "metadata"):
            payload[attr] = getattr(self, attr)
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DeliveryContext":
        """Build a context from a dict shaped like the output of ``to_dict``.

        ``channel`` and ``recipient_id`` are required (``KeyError`` if absent);
        the ID fields default to ``None`` and ``metadata`` to an empty dict.
        """
        channel = ChannelType(data["channel"])
        recipient = data["recipient_id"]
        optional_ids = {
            attr: data.get(attr) for attr in ("account_id", "thread_id", "guild_id")
        }
        return cls(
            channel=channel,
            recipient_id=recipient,
            metadata=data.get("metadata", {}),
            **optional_ids,
        )
107
+
108
+
109
@dataclass
class SessionEntry:
    """Per-session metadata record kept in the ``sessions.json`` index."""

    session_id: str
    session_key: str
    channel: ChannelType
    recipient_id: str
    transcript_file: str  # Path to .jsonl transcript
    created_at: float
    updated_at: float
    last_lane: CommandLane = CommandLane.MAIN
    delivery_context: DeliveryContext | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    message_count: int = 0
    display_name: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict.

        The channel and lane are stored via ``str()``; the delivery context is
        stored as its own dict, or ``None`` when unset.
        """
        payload: dict[str, Any] = {
            "session_id": self.session_id,
            "session_key": self.session_key,
            "channel": str(self.channel),
            "recipient_id": self.recipient_id,
            "transcript_file": self.transcript_file,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "last_lane": str(self.last_lane),
        }
        ctx = self.delivery_context
        payload["delivery_context"] = ctx.to_dict() if ctx else None
        payload["metadata"] = self.metadata
        payload["message_count"] = self.message_count
        payload["display_name"] = self.display_name
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionEntry":
        """Rebuild an entry from a dict shaped like the output of ``to_dict``.

        The two IDs, ``channel``, ``recipient_id`` and ``transcript_file`` are
        required. Missing timestamps default to the current time, a missing
        lane to ``"main"``, and the remaining fields to their empty values.
        """
        raw_ctx = data.get("delivery_context")
        kwargs: dict[str, Any] = {
            "session_id": data["session_id"],
            "session_key": data["session_key"],
            "channel": ChannelType(data["channel"]),
            "recipient_id": data["recipient_id"],
            "transcript_file": data["transcript_file"],
            "created_at": data.get("created_at", time.time()),
            "updated_at": data.get("updated_at", time.time()),
            "last_lane": CommandLane(data.get("last_lane", "main")),
        }
        if raw_ctx:
            kwargs["delivery_context"] = DeliveryContext.from_dict(raw_ctx)
        else:
            kwargs["delivery_context"] = None
        kwargs["metadata"] = data.get("metadata", {})
        kwargs["message_count"] = data.get("message_count", 0)
        kwargs["display_name"] = data.get("display_name")
        return cls(**kwargs)
163
+
164
+
165
@dataclass
class TranscriptMessage:
    """One message record as written to a session transcript."""

    type: str  # "message", "tool_call", "tool_result", "system", etc.
    role: str  # "user", "assistant", "system"
    content: str
    timestamp: str
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a dict, keys in field-declaration order."""
        keys = ("type", "role", "content", "timestamp", "metadata")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TranscriptMessage":
        """Build a message from a dict.

        ``role`` and ``content`` are required (``KeyError`` if absent). A
        missing ``type`` becomes ``"message"``, a missing ``timestamp`` the
        current local time in ISO format, and missing ``metadata`` ``{}``.
        """
        fallback_timestamp = datetime.now().isoformat()
        msg_type = data.get("type", "message")
        role = data["role"]
        content = data["content"]
        return cls(
            msg_type,
            role,
            content,
            data.get("timestamp", fallback_timestamp),
            data.get("metadata", {}),
        )
193
+
194
+
195
+ class SessionTranscript:
196
+ """JSONL-based session transcript."""
197
+
198
+ def __init__(self, file_path: Path, session_id: str | None = None):
199
+ self.file_path = file_path
200
+ self.session_id = session_id or str(uuid.uuid4())
201
+ self._ensure_header()
202
+
203
+ def _ensure_header(self) -> None:
204
+ """Ensure transcript has a header line."""
205
+ if not self.file_path.exists():
206
+ self.file_path.parent.mkdir(parents=True, exist_ok=True)
207
+ header = {
208
+ "type": "session",
209
+ "version": SESSION_VERSION,
210
+ "id": self.session_id,
211
+ "timestamp": datetime.now().isoformat(),
212
+ }
213
+ self._append_line(header)
214
+
215
+ def _append_line(self, data: dict) -> None:
216
+ """Append a JSON line to the transcript."""
217
+ with open(self.file_path, "a", encoding="utf-8") as f:
218
+ f.write(json.dumps(data, ensure_ascii=False) + "\n")
219
+
220
+ def append_message(
221
+ self,
222
+ role: str,
223
+ content: str,
224
+ msg_type: str = "message",
225
+ metadata: dict | None = None,
226
+ ) -> None:
227
+ """Append a message to the transcript."""
228
+ msg = TranscriptMessage(
229
+ type=msg_type,
230
+ role=role,
231
+ content=content,
232
+ timestamp=datetime.now().isoformat(),
233
+ metadata=metadata or {},
234
+ )
235
+ self._append_line(msg.to_dict())
236
+
237
+ def append_user_message(self, content: str, metadata: dict | None = None) -> None:
238
+ """Append a user message."""
239
+ self.append_message("user", content, "message", metadata)
240
+
241
+ def append_assistant_message(
242
+ self, content: str, metadata: dict | None = None
243
+ ) -> None:
244
+ """Append an assistant message."""
245
+ self.append_message("assistant", content, "message", metadata)
246
+
247
+ def append_tool_call(
248
+ self, tool_name: str, tool_input: dict, metadata: dict | None = None
249
+ ) -> None:
250
+ """Append a tool call."""
251
+ self._append_line(
252
+ {
253
+ "type": "tool_call",
254
+ "tool": tool_name,
255
+ "input": tool_input,
256
+ "timestamp": datetime.now().isoformat(),
257
+ "metadata": metadata or {},
258
+ }
259
+ )
260
+
261
+ def append_tool_result(
262
+ self, tool_name: str, result: str, metadata: dict | None = None
263
+ ) -> None:
264
+ """Append a tool result."""
265
+ self._append_line(
266
+ {
267
+ "type": "tool_result",
268
+ "tool": tool_name,
269
+ "result": result,
270
+ "timestamp": datetime.now().isoformat(),
271
+ "metadata": metadata or {},
272
+ }
273
+ )
274
+
275
+ def read_messages(self) -> Iterator[dict]:
276
+ """Read all messages from transcript (excluding header)."""
277
+ if not self.file_path.exists():
278
+ return
279
+
280
+ with open(self.file_path, "r", encoding="utf-8") as f:
281
+ for line in f:
282
+ line = line.strip()
283
+ if not line:
284
+ continue
285
+ try:
286
+ data = json.loads(line)
287
+ if data.get("type") != "session": # Skip header
288
+ yield data
289
+ except json.JSONDecodeError:
290
+ logger.warning(f"Invalid JSON line in transcript: {line[:50]}...")
291
+
292
+ def get_history(self, limit: int | None = None) -> list[dict]:
293
+ """Get conversation history as list of message dicts."""
294
+ messages = []
295
+ for msg in self.read_messages():
296
+ if msg.get("type") == "message":
297
+ messages.append(
298
+ {
299
+ "role": msg["role"],
300
+ "content": msg["content"],
301
+ }
302
+ )
303
+ if limit:
304
+ messages = messages[-limit:]
305
+ return messages
306
+
307
+ def get_full_history(self) -> list[dict]:
308
+ """Get full transcript including tool calls."""
309
+ return list(self.read_messages())
310
+
311
+ def count_messages(self) -> int:
312
+ """Count messages in transcript."""
313
+ count = 0
314
+ for msg in self.read_messages():
315
+ if msg.get("type") == "message":
316
+ count += 1
317
+ return count
318
+
319
+ def clear(self) -> None:
320
+ """Clear transcript and write new header."""
321
+ if self.file_path.exists():
322
+ self.file_path.unlink()
323
+ self._ensure_header()
324
+
325
+
326
@dataclass
class Session:
    """A conversation context: an index entry paired with its transcript."""

    entry: SessionEntry
    transcript: SessionTranscript

    @property
    def session_key(self) -> str:
        """Session key taken from the entry."""
        return self.entry.session_key

    @property
    def session_id(self) -> str:
        """Session ID taken from the entry."""
        return self.entry.session_id

    @property
    def channel(self) -> ChannelType:
        """Channel taken from the entry."""
        return self.entry.channel

    @property
    def recipient_id(self) -> str:
        """Recipient ID taken from the entry."""
        return self.entry.recipient_id

    @property
    def history(self) -> list[dict]:
        """Conversation history as read from the transcript."""
        return self.transcript.get_history()

    @history.setter
    def history(self, value: list[dict]) -> None:
        """Replace the transcript with the messages in *value*.

        The transcript is cleared first, then each item's ``role`` and
        ``content`` are appended; the entry's message count is reset to match.
        """
        transcript = self.transcript
        transcript.clear()
        for item in value:
            transcript.append_message(item["role"], item["content"])
        self.entry.message_count = len(value)

    @property
    def delivery_context(self) -> DeliveryContext | None:
        """Delivery context stored on the entry, if any."""
        return self.entry.delivery_context

    @delivery_context.setter
    def delivery_context(self, value: DeliveryContext | None) -> None:
        self.entry.delivery_context = value

    @property
    def last_lane(self) -> CommandLane:
        """Lane stored on the entry."""
        return self.entry.last_lane

    @last_lane.setter
    def last_lane(self, value: CommandLane) -> None:
        self.entry.last_lane = value

    @classmethod
    def create_key(cls, channel: ChannelType, recipient_id: str) -> str:
        """Build the session key from the channel and recipient ID."""
        return f"{channel}:{recipient_id}"

    def touch(self) -> None:
        """Set the entry's ``updated_at`` to the current time."""
        self.entry.updated_at = time.time()

    def add_message(self, role: str, content: str) -> None:
        """Append a message to the transcript and bump count and timestamp."""
        self.transcript.append_message(role, content)
        self.entry.message_count = self.entry.message_count + 1
        self.touch()

    def clear_history(self) -> None:
        """Clear the transcript, zero the message count, refresh the timestamp."""
        self.transcript.clear()
        self.entry.message_count = 0
        self.touch()

    def to_dict(self) -> dict[str, Any]:
        """Return the entry's dict plus a ``"history"`` key (backwards compat)."""
        payload = self.entry.to_dict()
        payload["history"] = self.history
        return payload
408
+
409
+
410
class SessionManager:
    """Manages sessions across all channels using JSONL transcripts.

    Two in-memory maps are keyed by session key: ``_entries`` mirrors the
    ``sessions.json`` index on disk, and ``_sessions`` caches ``Session``
    objects (entry plus transcript) as they are first requested.
    """

    def __init__(self, store_path: Path | None = None):
        """Open (creating if needed) the session store and load its index.

        Args:
            store_path: Directory holding the index and transcripts; defaults
                to ``DATA_DIR / "sessions"``.
        """
        self.store_path = store_path or DATA_DIR / "sessions"
        self.store_path.mkdir(parents=True, exist_ok=True)
        self._index_file = self.store_path / "sessions.json"
        self._sessions: dict[str, Session] = {}
        self._entries: dict[str, SessionEntry] = {}
        self._load_index()

    def _read_index_text(self) -> str:
        """Read the index file as UTF-8, falling back to the platform default.

        The fallback covers an index that was written before the encoding was
        pinned, on a platform whose default encoding is not UTF-8.
        """
        try:
            return self._index_file.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            return self._index_file.read_text()

    def _load_index(self) -> None:
        """Load session index from sessions.json.

        Best-effort: an unreadable index is logged and leaves the manager
        empty; a single bad entry is logged and skipped.
        """
        if not self._index_file.exists():
            return

        try:
            data = json.loads(self._read_index_text())
            for key, entry_data in data.items():
                try:
                    self._entries[key] = SessionEntry.from_dict(entry_data)
                except Exception as e:
                    logger.warning(f"Failed to load session entry {key}: {e}")
        except Exception as e:
            logger.error(f"Failed to load session index: {e}")

    def _save_index(self) -> None:
        """Save session index to sessions.json (best-effort; errors are logged).

        The JSON is written with ``ensure_ascii=False``, so the file encoding
        is pinned to UTF-8 rather than left to the platform default.
        """
        try:
            data = {key: entry.to_dict() for key, entry in self._entries.items()}
            self._index_file.write_text(
                json.dumps(data, indent=2, ensure_ascii=False),
                encoding="utf-8",
            )
        except Exception as e:
            logger.error(f"Failed to save session index: {e}")

    def _transcript_path(self, session_id: str, thread_id: str | None = None) -> Path:
        """Return the transcript path for a session, optionally per thread.

        ``/`` and ``:`` in *thread_id* are replaced with ``_`` before it is
        used in the file name.
        """
        if thread_id:
            safe_thread = str(thread_id).replace("/", "_").replace(":", "_")
            filename = f"{session_id}-thread-{safe_thread}.jsonl"
        else:
            filename = f"{session_id}.jsonl"
        return self.store_path / filename

    def _get_or_create_session(
        self,
        channel: ChannelType,
        recipient_id: str,
        create_if_missing: bool = True,
    ) -> Session | None:
        """Look up a session in the cache, then the index, else create it.

        Returns:
            The session, or ``None`` when it is unknown and
            *create_if_missing* is false.
        """
        session_key = Session.create_key(channel, recipient_id)

        # Check cache first
        cached = self._sessions.get(session_key)
        if cached is not None:
            return cached

        # Check index
        if session_key in self._entries:
            entry = self._entries[session_key]
            transcript_path = Path(entry.transcript_file)
            if not transcript_path.is_absolute():
                # Index entries store paths relative to the store directory.
                transcript_path = self.store_path / transcript_path
            transcript = SessionTranscript(transcript_path, entry.session_id)
            session = Session(entry=entry, transcript=transcript)
            self._sessions[session_key] = session
            return session

        if not create_if_missing:
            return None

        # Create new
        session_id = str(uuid.uuid4())
        transcript_path = self._transcript_path(session_id)
        transcript = SessionTranscript(transcript_path, session_id)

        now = time.time()  # one timestamp so created_at == updated_at
        entry = SessionEntry(
            session_id=session_id,
            session_key=session_key,
            channel=channel,
            recipient_id=recipient_id,
            transcript_file=str(transcript_path.relative_to(self.store_path)),
            created_at=now,
            updated_at=now,
        )

        session = Session(entry=entry, transcript=transcript)
        self._sessions[session_key] = session
        self._entries[session_key] = entry
        self._save_index()
        return session

    def get(
        self,
        channel: ChannelType,
        recipient_id: str,
        create_if_missing: bool = True,
    ) -> Session | None:
        """Get or create a session for a channel/recipient pair."""
        return self._get_or_create_session(channel, recipient_id, create_if_missing)

    def get_by_key(self, session_key: str) -> Session | None:
        """Get a session by its key, or ``None`` if the key is unknown."""
        if session_key in self._sessions:
            return self._sessions[session_key]

        if session_key in self._entries:
            entry = self._entries[session_key]
            return self.get(entry.channel, entry.recipient_id, create_if_missing=False)

        return None

    def update(self, session: Session) -> None:
        """Refresh the session's timestamp, re-register it, and save the index."""
        session.touch()
        self._sessions[session.session_key] = session
        self._entries[session.session_key] = session.entry
        self._save_index()

    def delete(self, session_key: str) -> bool:
        """Delete a session's transcript file and index entry.

        Returns:
            ``True`` if the session existed, ``False`` otherwise.
        """
        if session_key not in self._entries:
            return False

        entry = self._entries[session_key]

        # Delete transcript file
        transcript_path = self.store_path / entry.transcript_file
        if transcript_path.exists():
            transcript_path.unlink()

        # Remove from caches
        self._sessions.pop(session_key, None)
        del self._entries[session_key]
        self._save_index()

        return True

    def list_sessions(
        self,
        channel: ChannelType | None = None,
    ) -> list[Session]:
        """List sessions, most recently updated first.

        Args:
            channel: When given, only sessions on this channel are returned.
        """
        sessions = []
        for key, entry in self._entries.items():
            if channel and entry.channel != channel:
                continue
            session = self.get_by_key(key)
            if session:
                sessions.append(session)
        return sorted(sessions, key=lambda s: s.entry.updated_at, reverse=True)

    def get_active_delivery_contexts(
        self,
        channel: ChannelType | None = None,
    ) -> list[DeliveryContext]:
        """Return one delivery context per listed session, for broadcasting.

        A session without a stored context gets a minimal one built from its
        channel and recipient ID.
        """
        return [
            session.delivery_context
            or DeliveryContext(
                channel=session.channel,
                recipient_id=session.recipient_id,
            )
            for session in self.list_sessions(channel)
        ]
579
+
580
+
581
# Process-wide SessionManager, created on first use by get_session_manager().
_session_manager: SessionManager | None = None


def get_session_manager() -> SessionManager:
    """Return the shared SessionManager, constructing it on the first call."""
    global _session_manager
    manager = _session_manager
    if manager is None:
        manager = _session_manager = SessionManager()
    return manager
591
+
592
+
593
def record_inbound_session(
    channel: ChannelType,
    recipient_id: str,
    lane: CommandLane = CommandLane.MAIN,
    delivery_context: DeliveryContext | None = None,
) -> Session:
    """Record that an inbound message arrived for a channel/recipient pair.

    Fetches (or creates) the session through the global manager, stores the
    lane and, when one is supplied, the delivery context, then has the
    manager update the session.

    Returns:
        The session obtained from the manager.
    """
    manager = get_session_manager()
    session = manager.get(channel, recipient_id, create_if_missing=True)

    if not session:
        return session

    session.last_lane = lane
    if delivery_context:
        session.delivery_context = delivery_context
    manager.update(session)
    return session