zwarm 0.1.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zwarm/core/compact.py ADDED
@@ -0,0 +1,329 @@
+"""
+Message compaction for context window management.
+
+Safely prunes old messages while preserving:
+- System prompt and initial user task
+- Tool call/response pairs (never orphaned)
+- Recent conversation context
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+def _get_attr(obj: Any, key: str, default: Any = None) -> Any:
+    """Get attribute from dict or object (handles both Pydantic models and dicts)."""
+    if isinstance(obj, dict):
+        return obj.get(key, default)
+    return getattr(obj, key, default)
+
+
+@dataclass
+class CompactionResult:
+    """Result of a compaction operation."""
+
+    messages: list[dict[str, Any]]
+    removed_count: int
+    original_count: int
+    preserved_reason: str | None = None
+
+    @property
+    def was_compacted(self) -> bool:
+        return self.removed_count > 0
+
+
+def estimate_tokens(messages: list[Any]) -> int:
+    """
+    Rough token estimate for messages.
+
+    Uses ~4 chars per token as a simple heuristic.
+    This is intentionally conservative.
+    Handles both dict messages and Pydantic model messages.
+    """
+    total_chars = 0
+    for msg in messages:
+        content = _get_attr(msg, "content", "")
+        if isinstance(content, str):
+            total_chars += len(content)
+        elif isinstance(content, list):
+            # Anthropic-style content blocks
+            for block in content:
+                if isinstance(block, dict):
+                    total_chars += len(str(block.get("text", "")))
+                    total_chars += len(str(block.get("input", "")))
+                elif isinstance(block, str):
+                    total_chars += len(block)
+                else:
+                    # Pydantic model block
+                    total_chars += len(str(_get_attr(block, "text", "")))
+                    total_chars += len(str(_get_attr(block, "input", "")))
+
+        # Tool calls add tokens too
+        tool_calls = _get_attr(msg, "tool_calls", []) or []
+        for tc in tool_calls:
+            func = _get_attr(tc, "function", {}) or {}
+            args = _get_attr(func, "arguments", "") if isinstance(func, dict) else getattr(func, "arguments", "")
+            total_chars += len(str(args))
+
+    return total_chars // 4
+
+
+def find_tool_groups(messages: list[Any]) -> list[tuple[int, int]]:
+    """
+    Find message index ranges that form tool call groups.
+
+    A tool call group is:
+    - An assistant message with tool_calls
+    - All following tool/user response messages until the next assistant message
+
+    This handles both OpenAI format (role="tool") and Anthropic format
+    (role="user" with tool_result content).
+    Also handles Pydantic model messages.
+
+    Returns list of (start_idx, end_idx) tuples (inclusive).
+    """
+    groups = []
+    i = 0
+
+    while i < len(messages):
+        msg = messages[i]
+
+        # Check for tool calls in assistant message
+        has_tool_calls = False
+
+        # OpenAI format: tool_calls field
+        if _get_attr(msg, "role") == "assistant" and _get_attr(msg, "tool_calls"):
+            has_tool_calls = True
+
+        # Anthropic format: content blocks with type="tool_use"
+        if _get_attr(msg, "role") == "assistant":
+            content = _get_attr(msg, "content", [])
+            if isinstance(content, list):
+                for block in content:
+                    block_type = _get_attr(block, "type", None)
+                    if block_type == "tool_use":
+                        has_tool_calls = True
+                        break
+
+        if has_tool_calls:
+            start = i
+            j = i + 1
+
+            # Find all following tool responses
+            while j < len(messages):
+                next_msg = messages[j]
+                role = _get_attr(next_msg, "role", "")
+
+                # OpenAI format: tool role
+                if role == "tool":
+                    j += 1
+                    continue
+
+                # Anthropic format: user message with tool_result
+                if role == "user":
+                    content = _get_attr(next_msg, "content", [])
+                    if isinstance(content, list):
+                        has_tool_result = any(
+                            _get_attr(b, "type", None) == "tool_result"
+                            for b in content
+                        )
+                        if has_tool_result:
+                            j += 1
+                            continue
+
+                # Not a tool response, stop here
+                break
+
+            groups.append((start, j - 1))
+            i = j
+        else:
+            i += 1
+
+    return groups
+
+
+def compact_messages(
+    messages: list[Any],
+    keep_first_n: int = 2,
+    keep_last_n: int = 10,
+    max_tokens: int | None = None,
+    target_token_pct: float = 0.7,
+) -> CompactionResult:
+    """
+    Compact message history by removing old messages (LRU-style).
+
+    Preserves:
+    - First N messages (system prompt, user task)
+    - Last N messages (recent context)
+    - Tool call/response pairs are NEVER split
+
+    Args:
+        messages: The message list to compact
+        keep_first_n: Number of messages to always keep at the start
+        keep_last_n: Number of messages to always keep at the end
+        max_tokens: If set, compact when estimated tokens exceed this
+        target_token_pct: Target percentage of max_tokens after compaction
+
+    Returns:
+        CompactionResult with the compacted messages and stats
+    """
+    original_count = len(messages)
+
+    # Nothing to compact if we have few messages
+    if len(messages) <= keep_first_n + keep_last_n:
+        return CompactionResult(
+            messages=messages,
+            removed_count=0,
+            original_count=original_count,
+            preserved_reason="Too few messages to compact",
+        )
+
+    # Check if compaction is needed based on tokens
+    if max_tokens:
+        current_tokens = estimate_tokens(messages)
+        if current_tokens < max_tokens:
+            return CompactionResult(
+                messages=messages,
+                removed_count=0,
+                original_count=original_count,
+                preserved_reason=f"Under token limit ({current_tokens}/{max_tokens})",
+            )
+
+    # Find tool call groups (these must stay together)
+    tool_groups = find_tool_groups(messages)
+
+    # Build a set of "protected" indices (in tool groups)
+    protected_indices: set[int] = set()
+    for start, end in tool_groups:
+        for idx in range(start, end + 1):
+            protected_indices.add(idx)
+
+    # Determine which messages are in the "middle" (candidates for removal)
+    # Middle = not in first N, not in last N
+    middle_start = keep_first_n
+    middle_end = len(messages) - keep_last_n
+
+    if middle_start >= middle_end:
+        return CompactionResult(
+            messages=messages,
+            removed_count=0,
+            original_count=original_count,
+            preserved_reason="No middle messages to remove",
+        )
+
+    # Find removable message ranges in the middle
+    # We remove from the oldest (lowest index) first
+    removable_ranges: list[tuple[int, int]] = []
+    i = middle_start
+
+    while i < middle_end:
+        # Check if this index is in a tool group
+        in_group = False
+        for start, end in tool_groups:
+            if start <= i <= end:
+                # This message is part of a tool group
+                # Check if the ENTIRE group is in the middle
+                if start >= middle_start and end < middle_end:
+                    # Entire group is removable as a unit
+                    removable_ranges.append((start, end))
+                    i = end + 1
+                    in_group = True
+                    break
+                else:
+                    # Group spans protected region, skip it entirely
+                    i = end + 1
+                    in_group = True
+                    break
+
+        if not in_group:
+            # Single message, can be removed individually
+            removable_ranges.append((i, i))
+            i += 1
+
+    # Deduplicate and sort ranges
+    removable_ranges = sorted(set(removable_ranges), key=lambda x: x[0])
+
+    if not removable_ranges:
+        return CompactionResult(
+            messages=messages,
+            removed_count=0,
+            original_count=original_count,
+            preserved_reason="All middle messages are in protected tool groups",
+        )
+
+    # Determine how many to remove
+    # Start by removing the oldest half of removable ranges
+    if max_tokens:
+        # Token-based: remove until under target
+        target_tokens = int(max_tokens * target_token_pct)
+        indices_to_remove: set[int] = set()
+
+        for start, end in removable_ranges:
+            for idx in range(start, end + 1):
+                indices_to_remove.add(idx)
+
+            # Check if we've removed enough
+            remaining = [m for i, m in enumerate(messages) if i not in indices_to_remove]
+            if estimate_tokens(remaining) <= target_tokens:
+                break
+    else:
+        # Count-based: remove oldest half of middle
+        total_removable = sum(end - start + 1 for start, end in removable_ranges)
+        target_remove = total_removable // 2
+
+        indices_to_remove = set()
+        removed = 0
+
+        for start, end in removable_ranges:
+            if removed >= target_remove:
+                break
+            for idx in range(start, end + 1):
+                indices_to_remove.add(idx)
+                removed += 1
+
+    # Build new message list
+    new_messages = [m for i, m in enumerate(messages) if i not in indices_to_remove]
+
+    # Add a compaction marker so the model knows history was truncated
+    if indices_to_remove and len(new_messages) > keep_first_n:
+        # Insert marker after the preserved first messages
+        marker = {
+            "role": "system",
+            "content": (
+                f"[Context compacted: {len(indices_to_remove)} older messages removed "
+                f"to manage context window. Conversation continues below.]"
+            ),
+        }
+        new_messages.insert(keep_first_n, marker)
+
+    logger.info(
+        f"Compacted messages: {original_count} -> {len(new_messages)} "
+        f"(removed {len(indices_to_remove)})"
+    )
+
+    return CompactionResult(
+        messages=new_messages,
+        removed_count=len(indices_to_remove),
+        original_count=original_count,
+    )
+
+
+def should_compact(
+    messages: list[Any],
+    max_tokens: int,
+    threshold_pct: float = 0.85,
+) -> bool:
+    """
+    Check if messages should be compacted.
+
+    Returns True if estimated tokens exceed threshold percentage of max.
+    Handles both dict messages and Pydantic model messages.
+    """
+    current = estimate_tokens(messages)
+    threshold = int(max_tokens * threshold_pct)
+    return current >= threshold
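How these pieces fit together: `should_compact` is the cheap gate, and `compact_messages` prunes oldest-first while keeping tool call/response groups intact. A minimal sketch (the message list and the 1,000-token budget are illustrative, not values zwarm itself uses; the import path follows the new file's location):

```python
from zwarm.core.compact import compact_messages, should_compact

# Hypothetical OpenAI-style history: pinned system + task, then lots of chatter.
messages = [
    {"role": "system", "content": "You are the orchestrator."},
    {"role": "user", "content": "Refactor the billing module."},
    *[{"role": "assistant", "content": f"working on step {i}... " * 100} for i in range(30)],
    {"role": "user", "content": "Status?"},
]

if should_compact(messages, max_tokens=1_000, threshold_pct=0.85):
    result = compact_messages(messages, keep_first_n=2, keep_last_n=10, max_tokens=1_000)
    if result.was_compacted:
        # A system marker is inserted after the first keep_first_n messages.
        print(f"Removed {result.removed_count} of {result.original_count} messages")
        messages = result.messages
```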
zwarm/core/config.py CHANGED
@@ -38,6 +38,18 @@ class ExecutorConfig:
     timeout: int = 3600


+@dataclass
+class CompactionConfig:
+    """Configuration for context window compaction."""
+
+    enabled: bool = True
+    max_tokens: int = 100000  # Trigger compaction when estimated tokens exceed this
+    threshold_pct: float = 0.85  # Compact when at this % of max_tokens
+    target_pct: float = 0.7  # Target this % after compaction
+    keep_first_n: int = 2  # Always keep first N messages (system + task)
+    keep_last_n: int = 10  # Always keep last N messages (recent context)
+
+
 @dataclass
 class OrchestratorConfig:
     """Configuration for the orchestrator."""
@@ -48,6 +60,7 @@ class OrchestratorConfig:
     max_steps: int = 50
     parallel_delegations: int = 4
     sync_first: bool = True  # prefer sync mode by default
+    compaction: CompactionConfig = field(default_factory=CompactionConfig)


 @dataclass
@@ -88,19 +101,40 @@ class ZwarmConfig:
         orchestrator_data = data.get("orchestrator", {})
         watchers_data = data.get("watchers", {})

-        # Parse watchers config
-        watchers_config = WatchersConfig(
-            enabled=watchers_data.get("enabled", True),
-            watchers=[
-                WatcherConfigItem(**w) if isinstance(w, dict) else w
-                for w in watchers_data.get("watchers", [])
-            ] or WatchersConfig().watchers,
-        )
+        # Parse compaction config from orchestrator
+        compaction_data = orchestrator_data.pop("compaction", {}) if orchestrator_data else {}
+        compaction_config = CompactionConfig(**compaction_data) if compaction_data else CompactionConfig()
+
+        # Parse watchers config - handle both list shorthand and dict format
+        if isinstance(watchers_data, list):
+            # Shorthand: watchers: [progress, budget, scope]
+            watchers_config = WatchersConfig(
+                enabled=True,
+                watchers=[
+                    WatcherConfigItem(name=w) if isinstance(w, str) else WatcherConfigItem(**w)
+                    for w in watchers_data
+                ],
+            )
+        else:
+            # Full format: watchers: {enabled: true, watchers: [...]}
+            watchers_config = WatchersConfig(
+                enabled=watchers_data.get("enabled", True),
+                watchers=[
+                    WatcherConfigItem(name=w) if isinstance(w, str) else WatcherConfigItem(**w)
+                    for w in watchers_data.get("watchers", [])
+                ] or WatchersConfig().watchers,
+            )
+
+        # Build orchestrator config with nested compaction
+        if orchestrator_data:
+            orchestrator_config = OrchestratorConfig(**orchestrator_data, compaction=compaction_config)
+        else:
+            orchestrator_config = OrchestratorConfig(compaction=compaction_config)

         return cls(
             weave=WeaveConfig(**weave_data) if weave_data else WeaveConfig(),
             executor=ExecutorConfig(**executor_data) if executor_data else ExecutorConfig(),
-            orchestrator=OrchestratorConfig(**orchestrator_data) if orchestrator_data else OrchestratorConfig(),
+            orchestrator=orchestrator_config,
             watchers=watchers_config,
             state_dir=data.get("state_dir", ".zwarm"),
         )
@@ -125,6 +159,14 @@ class ZwarmConfig:
                 "max_steps": self.orchestrator.max_steps,
                 "parallel_delegations": self.orchestrator.parallel_delegations,
                 "sync_first": self.orchestrator.sync_first,
+                "compaction": {
+                    "enabled": self.orchestrator.compaction.enabled,
+                    "max_tokens": self.orchestrator.compaction.max_tokens,
+                    "threshold_pct": self.orchestrator.compaction.threshold_pct,
+                    "target_pct": self.orchestrator.compaction.target_pct,
+                    "keep_first_n": self.orchestrator.compaction.keep_first_n,
+                    "keep_last_n": self.orchestrator.compaction.keep_last_n,
+                },
             },
             "watchers": {
                 "enabled": self.watchers.enabled,
zwarm/core/environment.py CHANGED
@@ -4,14 +4,15 @@ OrchestratorEnv: A lean environment for the zwarm orchestrator.

 Unlike ChatEnv, this environment:
 - Has no notes/observations (we use StateManager instead)
 - Has no chat() tool (orchestrator communicates via output_handler)
-- Shows active sessions in observe() for context
+- Shows active sessions, step progress, and budget in observe()
 """

 from __future__ import annotations

 from pathlib import Path
-from typing import Any, Callable, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Callable

+from pydantic import PrivateAttr
 from wbal.environment import Environment

 if TYPE_CHECKING:
@@ -26,6 +27,8 @@ class OrchestratorEnv(Environment):
     - Task context
     - Working directory info
     - Active session visibility
+    - Step progress tracking
+    - Budget/resource monitoring
     - Output handler for messages
     """
@@ -34,50 +37,118 @@ class OrchestratorEnv(Environment):
     output_handler: Callable[[str], None] = lambda x: print(x)

     # Session tracking (set by orchestrator)
-    _sessions: dict[str, "ConversationSession"] | None = None
+    _sessions: dict[str, "ConversationSession"] | None = PrivateAttr(default=None)
+
+    # Progress tracking (updated by orchestrator each step)
+    _step_count: int = PrivateAttr(default=0)
+    _max_steps: int = PrivateAttr(default=50)
+    _total_tokens: int = PrivateAttr(default=0)
+    _executor_tokens: int = PrivateAttr(default=0)  # Executor token usage
+
+    # Budget config (set from config)
+    _budget_max_sessions: int | None = PrivateAttr(default=None)

     def set_sessions(self, sessions: dict[str, "ConversationSession"]) -> None:
         """Set the sessions dict for observe() visibility."""
         self._sessions = sessions

+    def update_progress(
+        self,
+        step_count: int,
+        max_steps: int,
+        total_tokens: int = 0,
+        executor_tokens: int = 0,
+    ) -> None:
+        """Update progress tracking (called by orchestrator each step)."""
+        self._step_count = step_count
+        self._max_steps = max_steps
+        self._total_tokens = total_tokens
+        self._executor_tokens = executor_tokens
+
+    def set_budget(self, max_sessions: int | None = None) -> None:
+        """Set budget limits from config."""
+        self._budget_max_sessions = max_sessions
+
     def observe(self) -> str:
         """
         Return observable state for the orchestrator.

         Shows:
-        - Current task
-        - Working directory
+        - Progress (steps, tokens)
+        - Session summary
         - Active sessions with their status
+        - Working directory
+
+        Note: Task is NOT included here as it's already in the user message.
         """
         parts = []

-        # Task
-        if self.task:
-            parts.append(f"## Current Task\n{self.task}")
-
-        # Working directory
-        parts.append(f"## Working Directory\n{self.working_dir.absolute()}")
-
-        # Active sessions
-        if self._sessions:
-            session_lines = []
-            for sid, session in self._sessions.items():
-                status_icon = {
-                    "active": "[ACTIVE]",
-                    "completed": "[DONE]",
-                    "failed": "[FAILED]",
-                }.get(session.status.value, "[?]")
-
-                mode_icon = "sync" if session.mode.value == "sync" else "async"
-                task_preview = session.task_description[:60] + "..." if len(session.task_description) > 60 else session.task_description
-
-                session_lines.append(
-                    f"  - {sid[:8]}... {status_icon} ({mode_icon}, {session.adapter}) {task_preview}"
-                )
-
-            if session_lines:
-                parts.append("## Sessions\n" + "\n".join(session_lines))
-            else:
-                parts.append("## Sessions\n  (none)")
+        # Progress bar and stats
+        progress_pct = (
+            (self._step_count / self._max_steps * 100) if self._max_steps > 0 else 0
+        )
+        bar_len = 20
+        filled = (
+            int(bar_len * self._step_count / self._max_steps)
+            if self._max_steps > 0
+            else 0
+        )
+        bar = "█" * filled + "░" * (bar_len - filled)
+
+        progress_lines = [
+            f"Steps: [{bar}] {self._step_count}/{self._max_steps} ({progress_pct:.0f}%)",
+        ]
+        if self._total_tokens > 0 or self._executor_tokens > 0:
+            token_parts = []
+            if self._total_tokens > 0:
+                token_parts.append(f"orchestrator: ~{self._total_tokens:,}")
+            if self._executor_tokens > 0:
+                token_parts.append(f"executors: ~{self._executor_tokens:,}")
+            progress_lines.append(f"Tokens: {', '.join(token_parts)}")
+
+        parts.append("## Progress\n" + "\n".join(progress_lines))
+
+        # Session summary
+        if self._sessions is not None:
+            active = sum(
+                1 for s in self._sessions.values() if s.status.value == "active"
+            )
+            completed = sum(
+                1 for s in self._sessions.values() if s.status.value == "completed"
+            )
+            failed = sum(
+                1 for s in self._sessions.values() if s.status.value == "failed"
+            )
+            total = len(self._sessions)
+
+            summary = f"Sessions: {active} active, {completed} done, {failed} failed ({total} total)"
+            if self._budget_max_sessions:
+                summary += f" [limit: {self._budget_max_sessions}]"
+
+            parts.append(f"## Resources\n{summary}")
+
+            # Active sessions detail
+            active_sessions = [
+                (sid, s)
+                for sid, s in self._sessions.items()
+                if s.status.value == "active"
+            ]
+            if active_sessions:
+                session_lines = []
+                for sid, session in active_sessions:
+                    mode_tag = "sync" if session.mode.value == "sync" else "async"
+                    turns = len([m for m in session.messages if m.role == "user"])
+                    task_preview = (
+                        session.task_description[:50] + "..."
+                        if len(session.task_description) > 50
+                        else session.task_description
+                    )
+                    session_lines.append(
+                        f"\n • {sid[:8]} ({session.adapter}, {mode_tag}, {turns} turns): {task_preview}"
+                    )
+                parts.append("## Active Sessions\n" + "\n".join(session_lines))

+        # Working directory (less prominent)
+        parts.append(f"## Context\nWorking dir: {self.working_dir.absolute()}")

         return "\n\n".join(parts)
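Put together, the new `observe()` renders something along these lines (all values and the session id are invented; the progress bar is 20 cells wide, so 10/50 steps fills 4 cells):

```
## Progress
Steps: [████░░░░░░░░░░░░░░░░] 10/50 (20%)
Tokens: orchestrator: ~42,000, executors: ~118,000

## Resources
Sessions: 1 active, 2 done, 0 failed (3 total) [limit: 4]

## Active Sessions

 • 3f9a12bc (codex, sync, 3 turns): Refactor the billing module to use the new le...

## Context
Working dir: /home/user/project
```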
zwarm/core/models.py CHANGED
@@ -92,6 +92,20 @@ class ConversationSession:
     model: str | None = None
     exit_message: str | None = None

+    # Token usage tracking for cost calculation
+    token_usage: dict[str, int] = field(default_factory=lambda: {
+        "input_tokens": 0,
+        "output_tokens": 0,
+        "total_tokens": 0,
+    })
+
+    def add_usage(self, usage: dict[str, int]) -> None:
+        """Add token usage from an interaction."""
+        if not usage:
+            return
+        for key in self.token_usage:
+            self.token_usage[key] += usage.get(key, 0)
+
     def add_message(self, role: Literal["user", "assistant", "system"], content: str) -> Message:
         """Add a message to the conversation."""
         msg = Message(role=role, content=content)
@@ -125,6 +139,7 @@ class ConversationSession:
             "task_description": self.task_description,
             "model": self.model,
             "exit_message": self.exit_message,
+            "token_usage": self.token_usage,
         }

     @classmethod
@@ -143,6 +158,7 @@ class ConversationSession:
             task_description=data.get("task_description", ""),
             model=data.get("model"),
            exit_message=data.get("exit_message"),
+            token_usage=data.get("token_usage", {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}),
         )

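The accumulation rule in `add_usage` is simple enough to restate standalone: empty input is a no-op, and only keys already present in `token_usage` are summed, so unfamiliar provider fields are silently dropped. A self-contained sketch of that behavior (not zwarm code, just the same logic lifted out of the class):

```python
# Mirrors ConversationSession.add_usage semantics.
token_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

def add_usage(usage: dict[str, int]) -> None:
    if not usage:
        return  # empty usage dict is a no-op
    for key in token_usage:
        token_usage[key] += usage.get(key, 0)  # keys not already tracked are ignored

add_usage({"input_tokens": 1200, "output_tokens": 350, "total_tokens": 1550})
add_usage({"input_tokens": 100, "cache_read_tokens": 999})  # extra key dropped

assert token_usage == {"input_tokens": 1300, "output_tokens": 350, "total_tokens": 1550}
```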