zwarm 2.3.5__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zwarm/core/__init__.py CHANGED
@@ -0,0 +1,20 @@
1
+ """Core primitives for zwarm."""
2
+
3
+ from .checkpoints import Checkpoint, CheckpointManager
4
+ from .costs import (
5
+ estimate_cost,
6
+ estimate_session_cost,
7
+ format_cost,
8
+ get_pricing,
9
+ ModelPricing,
10
+ )
11
+
12
+ __all__ = [
13
+ "Checkpoint",
14
+ "CheckpointManager",
15
+ "estimate_cost",
16
+ "estimate_session_cost",
17
+ "format_cost",
18
+ "get_pricing",
19
+ "ModelPricing",
20
+ ]
zwarm/core/checkpoints.py ADDED
@@ -0,0 +1,216 @@
1
+ """
2
+ Checkpoint primitives for state management.
3
+
4
+ Provides time-travel capability by recording snapshots of state at key points.
5
+ Used by pilot for turn-by-turn checkpointing, and potentially by other
6
+ interfaces that need state restoration.
7
+
8
+ Topology reminder:
9
+ orchestrator → pilot → interactive → CodexSessionManager
10
+
11
+ These primitives sit at the core layer, usable by any interface above.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import copy
17
+ from dataclasses import dataclass, field
18
+ from datetime import datetime
19
+ from typing import Any
20
+
21
+
22
@dataclass
class Checkpoint:
    """
    A snapshot of state at a specific point in time.

    Attributes:
        checkpoint_id: Unique identifier (e.g., turn number)
        label: Human-readable label (e.g., "T1", "T2")
        description: What action led to this state
        state: The actual state snapshot (deep-copied by CheckpointManager.record)
        timestamp: ISO-8601 string of the local time the checkpoint was created
        metadata: Optional extra data
    """

    checkpoint_id: int
    label: str
    description: str
    state: dict[str, Any]
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class CheckpointManager:
    """
    Manages checkpoints and time travel.

    Maintains a list of checkpoints and a current position. Supports:
    - Recording new checkpoints
    - Jumping to any previous checkpoint
    - Branching (going back and continuing creates new timeline)
    - History inspection

    Attributes:
        checkpoints: Recorded checkpoints, oldest first
        current_index: Index into ``checkpoints`` of the current position;
            -1 means root (before any checkpoints)
        next_id: ID assigned to the next recorded checkpoint. It only ever
            increases, so IDs are never reused after a branch truncates history
        label_prefix: Prefix for generated labels (labels will be T1, T2, etc.)

    Usage:
        mgr = CheckpointManager()

        # Record state after each action
        mgr.record(description="Added auth", state={"messages": [...], ...})
        mgr.record(description="Fixed bug", state={"messages": [...], ...})

        # Jump back
        cp = mgr.goto(1)  # Go to first checkpoint
        restored_state = copy.deepcopy(cp.state)  # cp is the stored object; copy before mutating

        # Continue from there (branches off)
        mgr.record(description="Different path", state={...})
    """

    checkpoints: list[Checkpoint] = field(default_factory=list)
    current_index: int = -1  # -1 = root (before any checkpoints)
    next_id: int = 1
    label_prefix: str = "T"  # Labels will be T1, T2, etc.

    def record(
        self,
        description: str,
        state: dict[str, Any],
        metadata: dict[str, Any] | None = None,
    ) -> Checkpoint:
        """
        Record a new checkpoint and make it the current one.

        If not at the end of history (i.e., we've gone back), this creates
        a branch - future checkpoints are discarded.

        Args:
            description: What action led to this state
            state: State to snapshot (will be deep-copied)
            metadata: Optional extra data (top-level dict is copied so later
                changes to the caller's dict do not alter the checkpoint)

        Returns:
            The created checkpoint
        """
        checkpoint = Checkpoint(
            checkpoint_id=self.next_id,
            label=f"{self.label_prefix}{self.next_id}",
            description=description,
            state=copy.deepcopy(state),
            # Copy rather than alias: callers often reuse one metadata dict
            # across several record() calls.
            metadata=dict(metadata) if metadata else {},
        )

        # If we're not at the end, we're branching - truncate future
        if self.current_index < len(self.checkpoints) - 1:
            self.checkpoints = self.checkpoints[: self.current_index + 1]

        self.checkpoints.append(checkpoint)
        self.current_index = len(self.checkpoints) - 1
        self.next_id += 1

        return checkpoint

    def goto(self, checkpoint_id: int) -> Checkpoint | None:
        """
        Jump to a specific checkpoint.

        The current position is left unchanged when the ID is not found.

        Args:
            checkpoint_id: The checkpoint ID to jump to (0 = root)

        Returns:
            The stored checkpoint object (not a copy), or None if not found
            (or root). Use is_at_root() to tell the two None cases apart.
        """
        if checkpoint_id == 0:
            # Root state - before any checkpoints
            self.current_index = -1
            return None

        for i, cp in enumerate(self.checkpoints):
            if cp.checkpoint_id == checkpoint_id:
                self.current_index = i
                return cp

        return None  # Not found

    def goto_label(self, label: str) -> Checkpoint | None:
        """
        Jump to a checkpoint by label (e.g., "T1", "root").

        "root" is matched case-insensitively; checkpoint labels are matched
        exactly. The current position is left unchanged when nothing matches.

        Args:
            label: The label to find

        Returns:
            The stored checkpoint object (not a copy), or None if not found
            (or root)
        """
        if label.lower() == "root":
            self.current_index = -1
            return None

        for i, cp in enumerate(self.checkpoints):
            if cp.label == label:
                self.current_index = i
                return cp

        return None

    def current(self) -> Checkpoint | None:
        """Get the current checkpoint, or None if at root."""
        if self.current_index < 0 or self.current_index >= len(self.checkpoints):
            return None
        return self.checkpoints[self.current_index]

    def current_state(self) -> dict[str, Any] | None:
        """Get a deep copy of the current state, or None if at root."""
        cp = self.current()
        return copy.deepcopy(cp.state) if cp else None

    def history(
        self,
        limit: int | None = None,
        include_state: bool = False,
    ) -> list[dict[str, Any]]:
        """
        Get history entries for display.

        Returned entries never alias the stored checkpoints: metadata is
        copied and, when requested, state is deep-copied, so callers may
        freely modify what they get back.

        Args:
            limit: Max entries to return (most recent). None or a
                non-positive value means no limit.
            include_state: Whether to include full state in entries

        Returns:
            List of history entries with checkpoint info, oldest first
        """
        total = len(self.checkpoints)
        # Pick the window first so states outside it are never copied.
        start = max(total - limit, 0) if limit is not None and limit > 0 else 0

        entries: list[dict[str, Any]] = []
        for i, cp in enumerate(self.checkpoints[start:], start=start):
            entry: dict[str, Any] = {
                "checkpoint_id": cp.checkpoint_id,
                "label": cp.label,
                "description": cp.description,
                "timestamp": cp.timestamp,
                "is_current": i == self.current_index,
                "metadata": dict(cp.metadata),
            }
            if include_state:
                entry["state"] = copy.deepcopy(cp.state)
            entries.append(entry)

        return entries

    def label_for(self, checkpoint_id: int) -> str:
        """Get the label for a checkpoint ID (does not check that it exists)."""
        if checkpoint_id == 0:
            return "root"
        return f"{self.label_prefix}{checkpoint_id}"

    def __len__(self) -> int:
        """Number of checkpoints."""
        return len(self.checkpoints)

    def is_at_root(self) -> bool:
        """Whether we're at root (before any checkpoints)."""
        return self.current_index < 0

    def is_at_end(self) -> bool:
        """Whether we're at the most recent checkpoint."""
        return self.current_index == len(self.checkpoints) - 1
zwarm/core/config.py CHANGED
@@ -37,6 +37,7 @@ class ExecutorConfig:
37
37
  sandbox: str = "workspace-write" # read-only | workspace-write | danger-full-access
38
38
  timeout: int = 3600
39
39
  reasoning_effort: str | None = "high" # low | medium | high (default to high for compatibility)
40
+ # Note: web_search is always enabled via .codex/config.toml (set up by `zwarm init`)
40
41
 
41
42
 
42
43
  @dataclass
@@ -59,8 +60,8 @@ class OrchestratorConfig:
59
60
  prompt: str | None = None # path to prompt yaml
60
61
  tools: list[str] = field(default_factory=lambda: ["delegate", "converse", "check_session", "end_session", "bash"])
61
62
  max_steps: int = 50
63
+ max_steps_per_turn: int = 60 # Max tool-call steps before returning to user (pilot mode)
62
64
  parallel_delegations: int = 4
63
- sync_first: bool = True # prefer sync mode by default
64
65
  compaction: CompactionConfig = field(default_factory=CompactionConfig)
65
66
 
66
67
  # Directory restrictions for agent delegations
@@ -172,8 +173,8 @@ class ZwarmConfig:
172
173
  "prompt": self.orchestrator.prompt,
173
174
  "tools": self.orchestrator.tools,
174
175
  "max_steps": self.orchestrator.max_steps,
176
+ "max_steps_per_turn": self.orchestrator.max_steps_per_turn,
175
177
  "parallel_delegations": self.orchestrator.parallel_delegations,
176
- "sync_first": self.orchestrator.sync_first,
177
178
  "compaction": {
178
179
  "enabled": self.orchestrator.compaction.enabled,
179
180
  "max_tokens": self.orchestrator.compaction.max_tokens,
@@ -195,15 +196,16 @@ class ZwarmConfig:
195
196
  }
196
197
 
197
198
 
198
def load_env(path: Path | None = None, base_dir: Path | None = None) -> None:
    """
    Load a .env file if it exists.

    Args:
        path: Explicit path to the .env file. When omitted, ``.env`` inside
            ``base_dir`` is used.
        base_dir: Directory to look in when ``path`` is not given
            (defaults to cwd)
    """
    env_file = path if path is not None else (base_dir or Path.cwd()) / ".env"
    if not env_file.exists():
        # A missing file is not an error; there is simply nothing to load.
        return
    load_dotenv(env_file)
204
206
 
205
207
 
206
- def load_toml_config(path: Path | None = None) -> dict[str, Any]:
208
+ def load_toml_config(path: Path | None = None, base_dir: Path | None = None) -> dict[str, Any]:
207
209
  """
208
210
  Load config.toml file.
209
211
 
@@ -211,11 +213,16 @@ def load_toml_config(path: Path | None = None) -> dict[str, Any]:
211
213
  1. Explicit path (if provided)
212
214
  2. .zwarm/config.toml (new standard location)
213
215
  3. config.toml (legacy location for backwards compat)
216
+
217
+ Args:
218
+ path: Explicit path to config.toml
219
+ base_dir: Base directory to search in (defaults to cwd)
214
220
  """
215
221
  if path is None:
222
+ base = base_dir or Path.cwd()
216
223
  # Try new location first
217
- new_path = Path.cwd() / ".zwarm" / "config.toml"
218
- legacy_path = Path.cwd() / "config.toml"
224
+ new_path = base / ".zwarm" / "config.toml"
225
+ legacy_path = base / "config.toml"
219
226
  if new_path.exists():
220
227
  path = new_path
221
228
  elif legacy_path.exists():
@@ -306,6 +313,7 @@ def load_config(
306
313
  toml_path: Path | None = None,
307
314
  env_path: Path | None = None,
308
315
  overrides: list[str] | None = None,
316
+ working_dir: Path | None = None,
309
317
  ) -> ZwarmConfig:
310
318
  """
311
319
  Load configuration with full precedence chain:
@@ -314,15 +322,24 @@ def load_config(
314
322
  3. YAML config file (if provided)
315
323
  4. CLI overrides (--set key=value)
316
324
  5. Environment variables (for secrets)
325
+
326
+ Args:
327
+ config_path: Path to YAML config file
328
+ toml_path: Explicit path to config.toml
329
+ env_path: Explicit path to .env file
330
+ overrides: CLI overrides (--set key=value)
331
+ working_dir: Working directory to search for config files (defaults to cwd).
332
+ This is important when using --working-dir flag to ensure
333
+ config is loaded from the project directory, not invoke directory.
317
334
  """
318
335
  # Load .env first (for secrets)
319
- load_env(env_path)
336
+ load_env(env_path, base_dir=working_dir)
320
337
 
321
338
  # Start with defaults
322
339
  config_dict: dict[str, Any] = {}
323
340
 
324
341
  # Layer in config.toml
325
- toml_config = load_toml_config(toml_path)
342
+ toml_config = load_toml_config(toml_path, base_dir=working_dir)
326
343
  if toml_config:
327
344
  config_dict = deep_merge(config_dict, toml_config)
328
345
 
zwarm/core/costs.py ADDED
@@ -0,0 +1,71 @@
1
+ """
2
+ Token cost estimation for LLM models.
3
+
4
+ This module re-exports from the centralized model registry.
5
+ For adding new models, edit: zwarm/core/registry.py
6
+
7
+ Backwards-compatible API preserved for existing code.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ # Re-export everything from registry for backwards compatibility
13
+ from zwarm.core.registry import (
14
+ ModelInfo,
15
+ MODELS,
16
+ resolve_model,
17
+ get_adapter_for_model,
18
+ get_default_model,
19
+ list_models,
20
+ list_adapters,
21
+ get_models_help_text,
22
+ get_models_table_data,
23
+ estimate_cost,
24
+ format_cost,
25
+ estimate_session_cost,
26
+ )
27
+
28
# Backwards compatibility alias
ModelPricing = ModelInfo

# Legacy lookup tables for backwards compatibility, derived from the registry.
# Built with comprehensions so no loop variables are left behind as
# module-level names.
# canonical model name -> ModelInfo
MODEL_PRICING = {model.canonical: model for model in MODELS}
# alias -> canonical model name (a later model wins if an alias repeats)
MODEL_ALIASES = {
    alias: model.canonical
    for model in MODELS
    for alias in model.aliases
}
37
+
38
+
39
def get_pricing(model: str) -> ModelInfo | None:
    """
    Look up pricing information for a model.

    Legacy entry point kept for existing callers; the lookup itself is
    delegated to ``resolve_model``.

    Args:
        model: Model name or alias

    Returns:
        ModelInfo or None if unknown
    """
    info = resolve_model(model)
    return info
50
+
51
+
52
+ __all__ = [
53
+ # New API
54
+ "ModelInfo",
55
+ "MODELS",
56
+ "resolve_model",
57
+ "get_adapter_for_model",
58
+ "get_default_model",
59
+ "list_models",
60
+ "list_adapters",
61
+ "get_models_help_text",
62
+ "get_models_table_data",
63
+ "estimate_cost",
64
+ "format_cost",
65
+ "estimate_session_cost",
66
+ # Legacy API
67
+ "MODEL_PRICING",
68
+ "MODEL_ALIASES",
69
+ "ModelPricing",
70
+ "get_pricing",
71
+ ]