zwarm-2.3.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zwarm/__init__.py +38 -0
- zwarm/adapters/__init__.py +21 -0
- zwarm/adapters/base.py +109 -0
- zwarm/adapters/claude_code.py +357 -0
- zwarm/adapters/codex_mcp.py +1262 -0
- zwarm/adapters/registry.py +69 -0
- zwarm/adapters/test_codex_mcp.py +274 -0
- zwarm/adapters/test_registry.py +68 -0
- zwarm/cli/__init__.py +0 -0
- zwarm/cli/main.py +2503 -0
- zwarm/core/__init__.py +0 -0
- zwarm/core/compact.py +329 -0
- zwarm/core/config.py +344 -0
- zwarm/core/environment.py +173 -0
- zwarm/core/models.py +315 -0
- zwarm/core/state.py +355 -0
- zwarm/core/test_compact.py +312 -0
- zwarm/core/test_config.py +160 -0
- zwarm/core/test_models.py +265 -0
- zwarm/orchestrator.py +683 -0
- zwarm/prompts/__init__.py +10 -0
- zwarm/prompts/orchestrator.py +230 -0
- zwarm/sessions/__init__.py +26 -0
- zwarm/sessions/manager.py +792 -0
- zwarm/test_orchestrator_watchers.py +23 -0
- zwarm/tools/__init__.py +17 -0
- zwarm/tools/delegation.py +784 -0
- zwarm/watchers/__init__.py +31 -0
- zwarm/watchers/base.py +131 -0
- zwarm/watchers/builtin.py +518 -0
- zwarm/watchers/llm_watcher.py +319 -0
- zwarm/watchers/manager.py +181 -0
- zwarm/watchers/registry.py +57 -0
- zwarm/watchers/test_watchers.py +237 -0
- zwarm-2.3.5.dist-info/METADATA +309 -0
- zwarm-2.3.5.dist-info/RECORD +38 -0
- zwarm-2.3.5.dist-info/WHEEL +4 -0
- zwarm-2.3.5.dist-info/entry_points.txt +2 -0
zwarm/core/state.py
ADDED
@@ -0,0 +1,355 @@
"""
Flat-file state management for zwarm.

State structure (with instance isolation):
.zwarm/
├── instances.json          # Registry of all instances
└── instances/
    └── <instance-id>/
        ├── state.json      # Current state (sessions, tasks)
        ├── events.jsonl    # Append-only event log
        ├── sessions/
        │   └── <session-id>/
        │       ├── messages.json
        │       └── output.log
        └── orchestrator/
            └── messages.json  # Orchestrator's message history (for resume)

Legacy structure (single instance, for backwards compat):
.zwarm/
├── state.json
├── events.jsonl
└── ...
"""

from __future__ import annotations

import json
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import uuid4

from .models import ConversationSession, Event, Task


# --- Instance Registry ---

def get_instances_registry_path(base_dir: Path | str = ".zwarm") -> Path:
    """Get path to the instances registry file."""
    return Path(base_dir) / "instances.json"


def list_instances(base_dir: Path | str = ".zwarm") -> list[dict[str, Any]]:
    """List all registered instances."""
    registry_path = get_instances_registry_path(base_dir)
    if not registry_path.exists():
        return []
    try:
        return json.loads(registry_path.read_text()).get("instances", [])
    except (json.JSONDecodeError, KeyError):
        return []


def register_instance(
    instance_id: str,
    name: str | None = None,
    task: str | None = None,
    base_dir: Path | str = ".zwarm",
) -> None:
    """Register an instance in the global registry."""
    base = Path(base_dir)
    base.mkdir(parents=True, exist_ok=True)

    registry_path = get_instances_registry_path(base_dir)

    # Load existing registry
    if registry_path.exists():
        try:
            registry = json.loads(registry_path.read_text())
        except json.JSONDecodeError:
            registry = {"instances": []}
    else:
        registry = {"instances": []}

    # Check if instance already registered
    existing_ids = {inst["id"] for inst in registry["instances"]}
    if instance_id in existing_ids:
        # Update existing entry
        for inst in registry["instances"]:
            if inst["id"] == instance_id:
                inst["updated_at"] = datetime.now().isoformat()
                inst["status"] = "active"
                if name:
                    inst["name"] = name
                if task:
                    inst["task"] = task[:100]  # Truncate
                break
    else:
        # Add new entry
        registry["instances"].append({
            "id": instance_id,
            "name": name or instance_id[:8],
            "task": (task[:100] if task else None),
            "created_at": datetime.now().isoformat(),
            "updated_at": datetime.now().isoformat(),
            "status": "active",
        })

    registry_path.write_text(json.dumps(registry, indent=2))


def update_instance_status(
    instance_id: str,
    status: str,
    base_dir: Path | str = ".zwarm",
) -> None:
    """Update an instance's status in the registry."""
    registry_path = get_instances_registry_path(base_dir)
    if not registry_path.exists():
        return

    try:
        registry = json.loads(registry_path.read_text())
    except json.JSONDecodeError:
        return

    for inst in registry.get("instances", []):
        if inst["id"] == instance_id:
            inst["status"] = status
            inst["updated_at"] = datetime.now().isoformat()
            break

    registry_path.write_text(json.dumps(registry, indent=2))


def get_instance_state_dir(
    instance_id: str | None = None,
    base_dir: Path | str = ".zwarm",
) -> Path:
    """
    Get the state directory for an instance.

    If instance_id is None, returns the legacy path for backwards compat.
    """
    base = Path(base_dir)
    if instance_id is None:
        return base  # Legacy: .zwarm/
    return base / "instances" / instance_id


def _json_serializer(obj: Any) -> Any:
    """Custom JSON serializer for non-standard types."""
    # Handle pydantic models
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    # Handle objects with __dict__
    if hasattr(obj, "__dict__"):
        return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")}
    # Handle datetime
    if hasattr(obj, "isoformat"):
        return obj.isoformat()
    # Fallback to string representation
    return str(obj)


class StateManager:
    """
    Manages flat-file state for zwarm.

    All state is stored as JSON files in a directory.
    With instance isolation: .zwarm/instances/<instance-id>/
    Legacy (no instance): .zwarm/

    This enables:
    - Git-backed history
    - Easy debugging (just read the files)
    - Resume from previous state
    - Multiple concurrent orchestrators (with instance isolation)
    """

    def __init__(
        self,
        state_dir: Path | str = ".zwarm",
        instance_id: str | None = None,
    ):
        self.base_dir = Path(state_dir)
        self.instance_id = instance_id

        # Resolve actual state directory
        if instance_id:
            self.state_dir = get_instance_state_dir(instance_id, self.base_dir)
        else:
            self.state_dir = self.base_dir

        self._sessions: dict[str, ConversationSession] = {}
        self._tasks: dict[str, Task] = {}
        self._orchestrator_messages: list[dict[str, Any]] = []

    def init(self) -> None:
        """Initialize state directory structure."""
        self.state_dir.mkdir(parents=True, exist_ok=True)
        (self.state_dir / "sessions").mkdir(exist_ok=True)
        (self.state_dir / "orchestrator").mkdir(exist_ok=True)

        # Touch events.jsonl
        events_file = self.state_dir / "events.jsonl"
        if not events_file.exists():
            events_file.touch()

    # --- Sessions ---

    def add_session(self, session: ConversationSession) -> None:
        """Add a session and persist it."""
        self._sessions[session.id] = session
        self._save_session(session)
        self._save_state()

    def get_session(self, session_id: str) -> ConversationSession | None:
        """Get a session by ID."""
        return self._sessions.get(session_id)

    def update_session(self, session: ConversationSession) -> None:
        """Update a session and persist it."""
        self._sessions[session.id] = session
        self._save_session(session)
        self._save_state()

    def list_sessions(self, status: str | None = None) -> list[ConversationSession]:
        """List sessions, optionally filtered by status."""
        sessions = list(self._sessions.values())
        if status:
            sessions = [s for s in sessions if s.status.value == status]
        return sessions

    def _save_session(self, session: ConversationSession) -> None:
        """Save session to its own directory."""
        session_dir = self.state_dir / "sessions" / session.id
        session_dir.mkdir(parents=True, exist_ok=True)

        # Save messages
        messages_file = session_dir / "messages.json"
        messages_file.write_text(json.dumps([m.to_dict() for m in session.messages], indent=2))

    # --- Tasks ---

    def add_task(self, task: Task) -> None:
        """Add a task and persist it."""
        self._tasks[task.id] = task
        self._save_state()

    def get_task(self, task_id: str) -> Task | None:
        """Get a task by ID."""
        return self._tasks.get(task_id)

    def update_task(self, task: Task) -> None:
        """Update a task and persist it."""
        self._tasks[task.id] = task
        self._save_state()

    def list_tasks(self, status: str | None = None) -> list[Task]:
        """List tasks, optionally filtered by status."""
        tasks = list(self._tasks.values())
        if status:
            tasks = [t for t in tasks if t.status.value == status]
        return tasks

    # --- Events ---

    def log_event(self, event: Event) -> None:
        """Append an event to the log."""
        events_file = self.state_dir / "events.jsonl"
        with open(events_file, "a") as f:
            f.write(json.dumps(event.to_dict()) + "\n")

    def get_events(
        self,
        session_id: str | None = None,
        task_id: str | None = None,
        kind: str | None = None,
        limit: int | None = None,
    ) -> list[Event]:
        """Read events from the log, optionally filtered."""
        events_file = self.state_dir / "events.jsonl"
        if not events_file.exists():
            return []

        events = []
        with open(events_file) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                event = Event.from_dict(json.loads(line))
                if session_id and event.session_id != session_id:
                    continue
                if task_id and event.task_id != task_id:
                    continue
                if kind and event.kind != kind:
                    continue
                events.append(event)

        # Most recent first
        events.reverse()
        if limit:
            events = events[:limit]
        return events

    # --- Orchestrator State ---

    def save_orchestrator_messages(self, messages: list[dict[str, Any]]) -> None:
        """Save orchestrator's message history for resume."""
        self._orchestrator_messages = messages
        messages_file = self.state_dir / "orchestrator" / "messages.json"
        # Use custom encoder to handle non-serializable types
        messages_file.write_text(json.dumps(messages, indent=2, default=_json_serializer))

    def load_orchestrator_messages(self) -> list[dict[str, Any]]:
        """Load orchestrator's message history for resume."""
        messages_file = self.state_dir / "orchestrator" / "messages.json"
        if not messages_file.exists():
            return []
        return json.loads(messages_file.read_text())

    # --- State Persistence ---

    def _save_state(self) -> None:
        """Save current state to state.json."""
        state = {
            "updated_at": datetime.now().isoformat(),
            "sessions": {sid: s.to_dict() for sid, s in self._sessions.items()},
            "tasks": {tid: t.to_dict() for tid, t in self._tasks.items()},
        }
        state_file = self.state_dir / "state.json"
        state_file.write_text(json.dumps(state, indent=2))

    def load(self) -> None:
        """Load state from state.json."""
        state_file = self.state_dir / "state.json"
        if not state_file.exists():
            return

        state = json.loads(state_file.read_text())

        # Load sessions
        for sid, sdata in state.get("sessions", {}).items():
            self._sessions[sid] = ConversationSession.from_dict(sdata)

        # Load tasks
        for tid, tdata in state.get("tasks", {}).items():
            self._tasks[tid] = Task.from_dict(tdata)

    def clear(self) -> None:
        """Clear all state (for testing)."""
        self._sessions.clear()
        self._tasks.clear()
        self._orchestrator_messages.clear()

        # Clear files
        state_file = self.state_dir / "state.json"
        if state_file.exists():
            state_file.unlink()

        events_file = self.state_dir / "events.jsonl"
        if events_file.exists():
            events_file.write_text("")
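For orientation, here is a minimal usage sketch of the module above. It is an editorial addition, not part of the released package; the instance id, name, and task strings are invented for illustration, and only functions and methods visible in the diff are used.

from zwarm.core.state import StateManager, list_instances, register_instance

# Create an isolated state tree under .zwarm/instances/<instance-id>/
# ("demo-1234" is a hypothetical instance id chosen for this sketch)
manager = StateManager(state_dir=".zwarm", instance_id="demo-1234")
manager.init()
register_instance("demo-1234", name="demo", task="try out flat-file state")

# Orchestrator message history round-trips as plain dicts
manager.save_orchestrator_messages([{"role": "user", "content": "hello"}])
print(manager.load_orchestrator_messages())
print(list_instances())

Sessions, tasks, and events follow the same pattern via add_session, add_task, and log_event, but those take the ConversationSession, Task, and Event models from zwarm.core.models, which are not shown in this diff.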
zwarm/core/test_compact.py
ADDED
@@ -0,0 +1,312 @@
"""Tests for the compact module."""

import pytest

from zwarm.core.compact import (
    compact_messages,
    estimate_tokens,
    find_tool_groups,
    should_compact,
)


class TestEstimateTokens:
    def test_simple_messages(self):
        """Estimate tokens for simple text messages."""
        messages = [
            {"role": "user", "content": "Hello world"},  # 11 chars
            {"role": "assistant", "content": "Hi there!"},  # 9 chars
        ]
        # ~20 chars / 4 = ~5 tokens
        tokens = estimate_tokens(messages)
        assert tokens == 5

    def test_empty_messages(self):
        """Empty messages return 0 tokens."""
        assert estimate_tokens([]) == 0

    def test_messages_with_tool_calls(self):
        """Tool calls add to token count."""
        messages = [
            {
                "role": "assistant",
                "content": "Let me check",
                "tool_calls": [
                    {"function": {"name": "read", "arguments": '{"path": "/foo/bar"}'}}
                ],
            }
        ]
        tokens = estimate_tokens(messages)
        assert tokens > 0


class TestFindToolGroups:
    def test_no_tool_calls(self):
        """No tool groups in simple conversation."""
        messages = [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi!"},
        ]
        groups = find_tool_groups(messages)
        assert groups == []

    def test_openai_format_tool_call(self):
        """Detect OpenAI-style tool call groups."""
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "Read file"},
            {
                "role": "assistant",
                "content": "Reading...",
                "tool_calls": [{"id": "tc1", "function": {"name": "read"}}],
            },
            {"role": "tool", "tool_call_id": "tc1", "content": "file contents"},
            {"role": "assistant", "content": "Here's the file"},
        ]
        groups = find_tool_groups(messages)
        assert groups == [(2, 3)]  # Assistant with tool_calls + tool response

    def test_multiple_tool_responses(self):
        """Group includes all consecutive tool responses."""
        messages = [
            {"role": "user", "content": "Do things"},
            {
                "role": "assistant",
                "tool_calls": [
                    {"id": "tc1", "function": {"name": "a"}},
                    {"id": "tc2", "function": {"name": "b"}},
                ],
            },
            {"role": "tool", "tool_call_id": "tc1", "content": "result1"},
            {"role": "tool", "tool_call_id": "tc2", "content": "result2"},
            {"role": "assistant", "content": "Done"},
        ]
        groups = find_tool_groups(messages)
        assert groups == [(1, 3)]  # Indices 1, 2, 3 form one group

    def test_anthropic_format_tool_use(self):
        """Detect Anthropic-style tool_use content blocks."""
        messages = [
            {"role": "user", "content": "Read file"},
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "Reading..."},
                    {"type": "tool_use", "id": "tu1", "name": "read", "input": {}},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "tool_result", "tool_use_id": "tu1", "content": "data"},
                ],
            },
            {"role": "assistant", "content": "Got it"},
        ]
        groups = find_tool_groups(messages)
        assert groups == [(1, 2)]  # Assistant with tool_use + user with tool_result


class TestCompactMessages:
    def test_no_compaction_needed_few_messages(self):
        """Don't compact if we have fewer messages than keep thresholds."""
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "Task"},
            {"role": "assistant", "content": "Response"},
        ]
        result = compact_messages(messages, keep_first_n=2, keep_last_n=2)
        assert not result.was_compacted
        assert result.messages == messages
        assert "Too few" in result.preserved_reason

    def test_compacts_middle_messages(self):
        """Remove messages from the middle, keeping first and last."""
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "Task"},
            {"role": "assistant", "content": "Step 1"},
            {"role": "user", "content": "Continue"},
            {"role": "assistant", "content": "Step 2"},
            {"role": "user", "content": "More"},
            {"role": "assistant", "content": "Step 3"},
            {"role": "user", "content": "Final"},
            {"role": "assistant", "content": "Done"},
        ]
        result = compact_messages(messages, keep_first_n=2, keep_last_n=2)

        assert result.was_compacted
        assert result.removed_count > 0
        # First 2 and last 2 should be preserved
        assert result.messages[0]["content"] == "System"
        assert result.messages[1]["content"] == "Task"
        assert result.messages[-1]["content"] == "Done"
        assert result.messages[-2]["content"] == "Final"

    def test_preserves_tool_call_pairs(self):
        """Never split tool call from its response."""
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "Task"},
            {"role": "assistant", "content": "Old message 1"},
            {"role": "assistant", "content": "Old message 2"},
            {
                "role": "assistant",
                "content": "Calling tool",
                "tool_calls": [{"id": "tc1", "function": {"name": "test"}}],
            },
            {"role": "tool", "tool_call_id": "tc1", "content": "Tool result"},
            {"role": "assistant", "content": "Recent 1"},
            {"role": "user", "content": "Recent 2"},
        ]
        result = compact_messages(messages, keep_first_n=2, keep_last_n=2)

        # The tool call pair should either both be kept or both removed
        has_tool_call = any(m.get("tool_calls") for m in result.messages)
        has_tool_response = any(m.get("role") == "tool" for m in result.messages)

        # They should match - either both present or both absent
        assert has_tool_call == has_tool_response

    def test_adds_compaction_marker(self):
        """Add a marker message when compaction occurs."""
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "Task"},
        ] + [{"role": "assistant", "content": f"Msg {i}"} for i in range(20)]

        result = compact_messages(messages, keep_first_n=2, keep_last_n=3)

        if result.was_compacted:
            # Should have a system message about compaction
            marker_msgs = [
                m for m in result.messages
                if m.get("role") == "system" and "compacted" in m.get("content", "").lower()
            ]
            assert len(marker_msgs) == 1

    def test_token_based_compaction(self):
        """Compact based on token threshold."""
        # Create messages that exceed token limit
        messages = [
            {"role": "system", "content": "System prompt " * 100},
            {"role": "user", "content": "Task " * 100},
        ] + [
            {"role": "assistant", "content": f"Response {i} " * 50}
            for i in range(10)
        ]

        # Should not compact if under limit
        result_under = compact_messages(messages, max_tokens=100000)
        # Might or might not compact depending on estimate

        # Should compact if over limit
        result_over = compact_messages(messages, max_tokens=100, target_token_pct=0.5)
        # With such a low limit, should definitely try to compact
        assert result_over.original_count == len(messages)


class TestShouldCompact:
    def test_under_threshold(self):
        """Don't compact when under threshold."""
        messages = [{"role": "user", "content": "Hello"}]
        assert not should_compact(messages, max_tokens=1000, threshold_pct=0.85)

    def test_over_threshold(self):
        """Compact when over threshold."""
        messages = [{"role": "user", "content": "x" * 4000}]  # ~1000 tokens
        assert should_compact(messages, max_tokens=500, threshold_pct=0.85)


class TestEdgeCases:
    def test_all_tool_calls(self):
        """Handle conversation that's mostly tool calls."""
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "Task"},
        ]
        # Add many tool call pairs
        for i in range(5):
            messages.append({
                "role": "assistant",
                "tool_calls": [{"id": f"tc{i}", "function": {"name": "test"}}],
            })
            messages.append({"role": "tool", "tool_call_id": f"tc{i}", "content": f"result{i}"})

        messages.append({"role": "assistant", "content": "Final"})

        result = compact_messages(messages, keep_first_n=2, keep_last_n=1)

        # Should still produce valid output
        assert len(result.messages) > 0

        # Check no orphaned tool calls
        for i, msg in enumerate(result.messages):
            if msg.get("tool_calls"):
                # Next message should be a tool response
                if i + 1 < len(result.messages):
                    # Either next is tool response, or this is at the end
                    pass  # Structural validity checked by not raising

    def test_empty_messages(self):
        """Handle empty message list."""
        result = compact_messages([])
        assert result.messages == []
        assert not result.was_compacted

    def test_only_system_and_user(self):
        """Handle minimal conversation."""
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "Hello"},
        ]
        result = compact_messages(messages, keep_first_n=2, keep_last_n=2)
        assert not result.was_compacted
        assert result.messages == messages


class TestPydanticModelMessages:
    """Test handling of Pydantic model messages (not just dicts)."""

    def test_estimate_tokens_with_objects(self):
        """estimate_tokens should handle objects with attributes."""
        class MockMessage:
            def __init__(self, role, content):
                self.role = role
                self.content = content

        messages = [
            MockMessage("user", "Hello world"),
            MockMessage("assistant", "Hi there!"),
        ]
        tokens = estimate_tokens(messages)
        assert tokens > 0

    def test_should_compact_with_objects(self):
        """should_compact should handle objects with attributes."""
        class MockMessage:
            def __init__(self, role, content):
                self.role = role
                self.content = content

        messages = [MockMessage("user", "x" * 4000)]
        # Should not crash
        result = should_compact(messages, max_tokens=500, threshold_pct=0.85)
        assert result is True

    def test_find_tool_groups_with_objects(self):
        """find_tool_groups should handle objects with attributes."""
        class MockMessage:
            def __init__(self, role, content=None, tool_calls=None):
                self.role = role
                self.content = content
                self.tool_calls = tool_calls

        messages = [
            MockMessage("user", "Task"),
            MockMessage("assistant", "Done"),
        ]
        # Should not crash
        groups = find_tool_groups(messages)
        assert groups == []
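As an editorial note, assuming a source checkout of the package with pytest available, this test module would typically be run with a command along the lines of:

python -m pytest zwarm/core/test_compact.py -q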