zwarm 3.0.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zwarm/cli/pilot.py CHANGED
@@ -11,16 +11,133 @@ import copy
11
11
  import json
12
12
  import shlex
13
13
  import sys
14
- from dataclasses import dataclass, field
14
+ import threading
15
+ import time
16
+ from dataclasses import dataclass
15
17
  from pathlib import Path
16
18
  from typing import Any, Callable, Dict, List, Optional
17
19
  from uuid import uuid4
18
20
 
19
21
  from rich.console import Console
20
22
 
23
+ from zwarm.core.checkpoints import CheckpointManager
24
+ from zwarm.core.costs import estimate_session_cost, format_cost, get_pricing
25
+
21
26
  console = Console()
22
27
 
23
28
 
29
class ChoogingSpinner:
    """
    Terminal spinner that stretches a word by one extra 'o' per second.

    Chooching → Choooching → Chooooching → ...

    The animation runs on a daemon thread and writes straight to
    ``sys.stdout``, redrawing the current line with a carriage return and
    ANSI dim styling. Use it through ``start()``/``stop()`` or as a context
    manager.

    Args:
        base_word: Word to animate. It is split around its first run of
            lowercase 'o' characters, and that run is what grows. A word
            containing no 'o' falls back to the stock
            "Ch" + "oo" + "ching" pieces.
    """

    def __init__(self, base_word: str = "Chooching"):
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._base = base_word

        # Split the word around its first run of 'o's so the constructor
        # argument actually drives the animation.
        run_start = base_word.find("o")
        if run_start == -1:
            # Nothing to stretch: keep the stock animation.
            self._prefix = "Ch"
            self._suffix = "ching"
            self._min_o = 2
        else:
            run_end = run_start
            while run_end < len(base_word) and base_word[run_end] == "o":
                run_end += 1
            self._prefix = base_word[:run_start]
            self._suffix = base_word[run_end:]
            self._min_o = run_end - run_start

    def _frame(self, o_count: int) -> str:
        """Return the word rendered with ``o_count`` 'o' characters."""
        return f"{self._prefix}{'o' * o_count}{self._suffix}"

    def _spin(self):
        """Thread body: redraw the growing word until the stop event is set."""
        o_count = self._min_o
        while not self._stop_event.is_set():
            # Carriage return overwrites the previous frame; \033[2m = dim.
            sys.stdout.write(f"\r\033[2m{self._frame(o_count)}\033[0m")
            sys.stdout.flush()
            o_count += 1
            # Sleep up to one second, but wake immediately when stop() fires.
            self._stop_event.wait(1.0)

    def start(self):
        """Start the spinner in a background thread (no-op if already running)."""
        if self._thread is not None and self._thread.is_alive():
            return
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._spin, daemon=True)
        self._thread.start()

    def stop(self):
        """Stop the spinner and clear the line."""
        self._stop_event.set()
        thread, self._thread = self._thread, None
        if thread is not None:
            # Bounded join: never hang the caller on a stuck stdout write.
            thread.join(timeout=0.5)
        # Return to column 0 and erase to end of line.
        sys.stdout.write("\r\033[K")
        sys.stdout.flush()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()
81
+
82
+
83
# Context window sizes for different models (in tokens).
# Keys are lowercase model-name prefixes; "default" is the fallback entry and
# is never matched as a prefix.
MODEL_CONTEXT_WINDOWS = {
    "gpt-5.1-codex": 200_000,
    "gpt-5.1-codex-mini": 200_000,
    "gpt-5.1-codex-max": 400_000,
    "gpt-5": 200_000,
    "gpt-5-mini": 200_000,
    "claude-sonnet-4": 200_000,
    "claude-opus-4": 200_000,
    # Fallback
    "default": 128_000,
}


def get_context_window(model: str) -> int:
    """
    Get context window size for a model.

    The model name is matched case-insensitively against the prefixes in
    MODEL_CONTEXT_WINDOWS, and the longest matching prefix wins. This is what
    lets "gpt-5.1-codex-max" resolve to its own entry rather than to the
    shorter "gpt-5.1-codex" prefix it also starts with.

    Args:
        model: Model identifier (e.g. "gpt-5.1-codex-max"). Empty or None
            yields the default size.

    Returns:
        Context window size in tokens, or the "default" entry when no
        prefix matches.
    """
    model_lower = (model or "").lower()
    matching = (
        prefix
        for prefix in MODEL_CONTEXT_WINDOWS
        if prefix != "default" and model_lower.startswith(prefix)
    )
    best = max(matching, key=len, default=None)
    if best is None:
        return MODEL_CONTEXT_WINDOWS["default"]
    return MODEL_CONTEXT_WINDOWS[best]
104
+
105
+
106
def render_context_bar(used: int, total: int, width: int = 30) -> str:
    """
    Render a visual context window usage bar as Rich markup.

    Args:
        used: Tokens used. Values outside [0, total] are clamped.
        total: Total context window
        width: Bar width in characters (negative values are treated as 0)

    Returns:
        Markup string with a colored filled segment, a dim empty segment and
        a percentage, e.g. "[green]████[/][dim]░░░░░░[/] 40%".
        Returns "[dim]?[/]" when total is not positive.
    """
    if total <= 0:
        return "[dim]?[/]"

    # Clamp to [0, 1] so a bogus negative count cannot produce a negative
    # percentage or an empty segment longer than the bar.
    pct = max(0.0, min(used / total, 1.0))
    width = max(width, 0)
    filled = int(pct * width)
    empty = width - filled

    # Color based on usage
    if pct < 0.5:
        color = "green"
    elif pct < 0.75:
        color = "yellow"
    elif pct < 0.9:
        color = "red"
    else:
        color = "red bold"

    bar = f"[{color}]{'█' * filled}[/][dim]{'░' * empty}[/]"
    pct_str = f"{pct * 100:.0f}%"

    return f"{bar} {pct_str}"
139
+
140
+
24
141
  # =============================================================================
25
142
  # Build Pilot Orchestrator
26
143
  # =============================================================================
@@ -386,105 +503,6 @@ class EventRenderer:
386
503
  return self._show_reasoning
387
504
 
388
505
 
389
- # =============================================================================
390
- # Pilot Session State (checkpoints and time travel)
391
- # =============================================================================
392
-
393
-
394
- @dataclass
395
- class Checkpoint:
396
- """A snapshot of orchestrator state at a specific turn."""
397
- turn_id: int
398
- instruction: str # User instruction that led to this state
399
- messages: List[Dict[str, Any]]
400
- sessions_snapshot: Dict[str, Any] # Serialized session state
401
- step_count: int
402
-
403
-
404
- @dataclass
405
- class PilotSessionState:
406
- """
407
- Manages checkpoints and time travel for the pilot REPL.
408
-
409
- Each turn (user instruction + orchestrator response) creates a checkpoint
410
- that can be returned to later.
411
- """
412
-
413
- checkpoints: List[Checkpoint] = field(default_factory=list)
414
- current_index: int = -1 # Index into checkpoints, -1 = root
415
- next_turn_id: int = 1
416
-
417
- def record_turn(
418
- self,
419
- instruction: str,
420
- messages: List[Dict[str, Any]],
421
- sessions_snapshot: Dict[str, Any],
422
- step_count: int,
423
- ) -> Checkpoint:
424
- """Record a new checkpoint after a turn."""
425
- checkpoint = Checkpoint(
426
- turn_id=self.next_turn_id,
427
- instruction=instruction,
428
- messages=copy.deepcopy(messages),
429
- sessions_snapshot=copy.deepcopy(sessions_snapshot),
430
- step_count=step_count,
431
- )
432
-
433
- # If we're not at the end, we're branching - truncate future
434
- if self.current_index < len(self.checkpoints) - 1:
435
- self.checkpoints = self.checkpoints[:self.current_index + 1]
436
-
437
- self.checkpoints.append(checkpoint)
438
- self.current_index = len(self.checkpoints) - 1
439
- self.next_turn_id += 1
440
-
441
- return checkpoint
442
-
443
- def goto_turn(self, turn_id: int) -> Optional[Checkpoint]:
444
- """Jump to a specific turn. Returns the checkpoint or None if not found."""
445
- if turn_id == 0:
446
- # Root state - before any turns
447
- self.current_index = -1
448
- return None
449
-
450
- for i, cp in enumerate(self.checkpoints):
451
- if cp.turn_id == turn_id:
452
- self.current_index = i
453
- return cp
454
-
455
- return None # Not found
456
-
457
- def current_checkpoint(self) -> Optional[Checkpoint]:
458
- """Get the current checkpoint, or None if at root."""
459
- if self.current_index < 0 or self.current_index >= len(self.checkpoints):
460
- return None
461
- return self.checkpoints[self.current_index]
462
-
463
- def turn_label(self, turn_id: int) -> str:
464
- """Format turn ID as T1, T2, etc."""
465
- return f"T{turn_id}" if turn_id > 0 else "root"
466
-
467
- def history_entries(
468
- self,
469
- *,
470
- show_all: bool = False,
471
- limit: Optional[int] = None
472
- ) -> List[Dict[str, Any]]:
473
- """Get history entries for display."""
474
- entries = []
475
- for i, cp in enumerate(self.checkpoints):
476
- is_current = i == self.current_index
477
- entries.append({
478
- "checkpoint": cp,
479
- "is_current": is_current,
480
- })
481
-
482
- if not show_all and limit:
483
- entries = entries[-limit:]
484
-
485
- return entries
486
-
487
-
488
506
  # =============================================================================
489
507
  # Command Parsing
490
508
  # =============================================================================
@@ -655,8 +673,20 @@ def run_until_response(
655
673
  @weave.op(name="pilot_turn")
656
674
  def _run_turn():
657
675
  all_results = []
676
+ spinner = ChoogingSpinner()
677
+
658
678
  for step in range(max_steps):
659
- results, had_message = execute_step_with_events(orchestrator, renderer)
679
+ # Show spinner only for the first step (initial LLM call after user message)
680
+ # Subsequent steps have visible tool activity so no spinner needed
681
+ if step == 0:
682
+ spinner.start()
683
+
684
+ try:
685
+ results, had_message = execute_step_with_events(orchestrator, renderer)
686
+ finally:
687
+ if step == 0:
688
+ spinner.stop()
689
+
660
690
  all_results.extend(results)
661
691
 
662
692
  # Stop if agent produced a message
@@ -687,13 +717,20 @@ def print_help(renderer: EventRenderer) -> None:
687
717
  "",
688
718
  "Commands:",
689
719
  " :help Show this help",
720
+ " :status Show pilot status (tokens, cost, context)",
690
721
  " :history [N|all] Show turn checkpoints",
691
722
  " :goto <turn|root> Jump to a prior turn (e.g., :goto T1)",
692
- " :state Show current orchestrator state",
693
- " :sessions Show active sessions",
723
+ " :sessions Show executor sessions",
694
724
  " :reasoning [on|off] Toggle reasoning display",
695
725
  " :quit / :exit Exit the pilot",
696
726
  "",
727
+ "Multiline input:",
728
+ ' Start with """ and end with """ to enter multiple lines.',
729
+ ' Example: """',
730
+ " paste your",
731
+ " content here",
732
+ ' """',
733
+ "",
697
734
  ]
698
735
  for line in lines:
699
736
  renderer.status(line)
@@ -710,6 +747,8 @@ def get_sessions_snapshot(orchestrator: Any) -> Dict[str, Any]:
710
747
  "status": s.status.value,
711
748
  "task": s.task[:100] if s.task else "",
712
749
  "turns": s.turn,
750
+ "tokens": s.token_usage.get("total_tokens", 0),
751
+ "model": s.model,
713
752
  }
714
753
  for s in sessions
715
754
  ]
@@ -747,7 +786,7 @@ def _run_pilot_repl(
747
786
  The actual REPL implementation.
748
787
  """
749
788
  renderer = EventRenderer(show_reasoning=True)
750
- state = PilotSessionState()
789
+ state = CheckpointManager()
751
790
 
752
791
  # Silence the default output_handler - we render events directly in execute_step_with_events
753
792
  # (Otherwise messages would be rendered twice)
@@ -776,22 +815,28 @@ def _run_pilot_repl(
776
815
  results = run_until_response(orchestrator, renderer)
777
816
 
778
817
  # Record checkpoint
779
- state.record_turn(
780
- instruction=initial_task,
781
- messages=orchestrator.messages,
782
- sessions_snapshot=get_sessions_snapshot(orchestrator),
783
- step_count=orchestrator._step_count,
818
+ state.record(
819
+ description=initial_task,
820
+ state={
821
+ "messages": orchestrator.messages,
822
+ "sessions_snapshot": get_sessions_snapshot(orchestrator),
823
+ "step_count": orchestrator._step_count,
824
+ },
825
+ metadata={
826
+ "step_count": orchestrator._step_count,
827
+ "message_count": len(orchestrator.messages),
828
+ },
784
829
  )
785
830
 
786
- cp = state.current_checkpoint()
831
+ cp = state.current()
787
832
  if cp:
788
833
  renderer.status("")
789
834
  renderer.status(
790
- f"[{state.turn_label(cp.turn_id)}] "
791
- f"step={cp.step_count} "
792
- f"messages={len(cp.messages)}"
835
+ f"[{cp.label}] "
836
+ f"step={cp.state['step_count']} "
837
+ f"messages={len(cp.state['messages'])}"
793
838
  )
794
- renderer.status(f":goto {state.turn_label(cp.turn_id)} to return here")
839
+ renderer.status(f":goto {cp.label} to return here")
795
840
 
796
841
  # Main REPL loop
797
842
  while True:
@@ -808,6 +853,38 @@ def _run_pilot_repl(
808
853
  if not user_input:
809
854
  continue
810
855
 
856
+ # Multiline input: if starts with """, collect until closing """
857
+ if user_input.startswith('"""'):
858
+ # Check if closing """ is on the same line (e.g., """hello""")
859
+ rest = user_input[3:]
860
+ if '"""' in rest:
861
+ # Single line with both opening and closing
862
+ user_input = rest[: rest.index('"""')]
863
+ else:
864
+ # Multiline mode - collect until we see """
865
+ lines = [rest] if rest else []
866
+ try:
867
+ while True:
868
+ line = input("... ")
869
+ if '"""' in line:
870
+ # Found closing quotes
871
+ idx = line.index('"""')
872
+ if idx > 0:
873
+ lines.append(line[:idx])
874
+ break
875
+ lines.append(line)
876
+ except EOFError:
877
+ renderer.error("Multiline input interrupted (EOF)")
878
+ continue
879
+ except KeyboardInterrupt:
880
+ sys.stdout.write("\n")
881
+ renderer.status("(Multiline cancelled)")
882
+ continue
883
+ user_input = "\n".join(lines)
884
+
885
+ if not user_input:
886
+ continue
887
+
811
888
  # Parse command
812
889
  cmd_parts = parse_command(user_input)
813
890
  if cmd_parts:
@@ -827,28 +904,29 @@ def _run_pilot_repl(
827
904
  # :history
828
905
  if cmd == "history":
829
906
  limit = None
830
- show_all = False
831
907
  if args:
832
908
  token = args[0].lower()
833
909
  if token == "all":
834
- show_all = True
910
+ limit = None # Show all
835
911
  elif token.isdigit():
836
912
  limit = int(token)
913
+ else:
914
+ limit = 10
837
915
 
838
- entries = state.history_entries(show_all=show_all, limit=limit or 10)
916
+ entries = state.history(limit=limit)
839
917
  if not entries:
840
918
  renderer.status("No checkpoints yet.")
841
919
  else:
842
920
  renderer.status("")
843
921
  for entry in entries:
844
- cp = entry["checkpoint"]
845
922
  marker = "*" if entry["is_current"] else " "
846
- instruction_preview = cp.instruction[:60] + "..." if len(cp.instruction) > 60 else cp.instruction
923
+ desc = entry["description"]
924
+ desc_preview = desc[:60] + "..." if len(desc) > 60 else desc
847
925
  renderer.status(
848
- f"{marker}[{state.turn_label(cp.turn_id)}] "
849
- f"step={cp.step_count} "
850
- f"msgs={len(cp.messages)} "
851
- f"| {instruction_preview}"
926
+ f"{marker}[{entry['label']}] "
927
+ f"step={entry['metadata'].get('step_count', '?')} "
928
+ f"msgs={entry['metadata'].get('message_count', '?')} "
929
+ f"| {desc_preview}"
852
930
  )
853
931
  renderer.status("")
854
932
  continue
@@ -862,7 +940,7 @@ def _run_pilot_repl(
862
940
  token = args[0]
863
941
  if token.lower() == "root":
864
942
  # Go to root (before any turns)
865
- state.goto_turn(0)
943
+ state.goto(0)
866
944
  # Reset orchestrator to initial state
867
945
  if hasattr(orchestrator, "messages"):
868
946
  # Keep only system messages
@@ -885,31 +963,89 @@ def _run_pilot_repl(
885
963
  renderer.error(f"Invalid turn: {token}")
886
964
  continue
887
965
 
888
- cp = state.goto_turn(turn_id)
966
+ cp = state.goto(turn_id)
889
967
  if cp is None:
890
968
  renderer.error(f"Turn T{turn_id} not found.")
891
969
  continue
892
970
 
893
971
  # Restore orchestrator state
894
- orchestrator.messages = copy.deepcopy(cp.messages)
895
- orchestrator._step_count = cp.step_count
896
- renderer.status(f"Switched to {state.turn_label(turn_id)}.")
897
- renderer.status(f" instruction: {cp.instruction[:60]}...")
898
- renderer.status(f" messages: {len(cp.messages)}")
972
+ orchestrator.messages = copy.deepcopy(cp.state["messages"])
973
+ orchestrator._step_count = cp.state["step_count"]
974
+ renderer.status(f"Switched to {cp.label}.")
975
+ renderer.status(f" instruction: {cp.description[:60]}...")
976
+ renderer.status(f" messages: {len(cp.state['messages'])}")
899
977
  continue
900
978
 
901
- # :state
902
- if cmd == "state":
979
+ # :state / :status
980
+ if cmd in ("state", "status"):
903
981
  renderer.status("")
904
- renderer.status(f"Step count: {orchestrator._step_count}")
905
- renderer.status(f"Messages: {len(orchestrator.messages)}")
906
- if hasattr(orchestrator, "_total_tokens"):
907
- renderer.status(f"Total tokens: {orchestrator._total_tokens}")
908
- cp = state.current_checkpoint()
909
- if cp:
910
- renderer.status(f"Current turn: {state.turn_label(cp.turn_id)}")
911
- else:
912
- renderer.status("Current turn: root")
982
+ renderer.status("[bold]Pilot Status[/]")
983
+ renderer.status("")
984
+
985
+ # Basic stats
986
+ step_count = getattr(orchestrator, "_step_count", 0)
987
+ msg_count = len(orchestrator.messages)
988
+ total_tokens = getattr(orchestrator, "_total_tokens", 0)
989
+
990
+ renderer.status(f" Steps: {step_count}")
991
+ renderer.status(f" Messages: {msg_count}")
992
+
993
+ # Checkpoint
994
+ cp = state.current()
995
+ turn_label = cp.label if cp else "root"
996
+ renderer.status(f" Turn: {turn_label}")
997
+
998
+ # Token usage and context
999
+ renderer.status("")
1000
+ renderer.status("[bold]Token Usage[/]")
1001
+ renderer.status("")
1002
+
1003
+ # Get model from orchestrator if available
1004
+ model = "gpt-5.1-codex" # Default
1005
+ if hasattr(orchestrator, "lm") and hasattr(orchestrator.lm, "model"):
1006
+ model = orchestrator.lm.model
1007
+ elif hasattr(orchestrator, "config"):
1008
+ model = getattr(orchestrator.config, "model", model)
1009
+
1010
+ context_window = get_context_window(model)
1011
+ context_bar = render_context_bar(total_tokens, context_window)
1012
+
1013
+ renderer.status(f" Model: {model}")
1014
+ renderer.status(f" Tokens: {total_tokens:,} / {context_window:,}")
1015
+ renderer.status(f" Context: {context_bar}")
1016
+
1017
+ # Cost estimate for orchestrator
1018
+ pricing = get_pricing(model)
1019
+ if pricing and total_tokens > 0:
1020
+ # Estimate assuming 30% input, 70% output (typical for agentic)
1021
+ est_input = int(total_tokens * 0.3)
1022
+ est_output = total_tokens - est_input
1023
+ cost = pricing.estimate_cost(est_input, est_output)
1024
+ renderer.status(f" Est Cost: [green]{format_cost(cost)}[/] (pilot LLM)")
1025
+
1026
+ # Executor sessions summary
1027
+ snapshot = get_sessions_snapshot(orchestrator)
1028
+ sessions = snapshot.get("sessions", [])
1029
+ if sessions:
1030
+ renderer.status("")
1031
+ renderer.status("[bold]Executor Sessions[/]")
1032
+ renderer.status("")
1033
+
1034
+ exec_tokens = 0
1035
+ exec_cost = 0.0
1036
+ running = 0
1037
+ completed = 0
1038
+
1039
+ for s in sessions:
1040
+ exec_tokens += s.get("tokens", 0)
1041
+ if s.get("status") == "running":
1042
+ running += 1
1043
+ elif s.get("status") == "completed":
1044
+ completed += 1
1045
+
1046
+ renderer.status(f" Sessions: {len(sessions)} ({running} running, {completed} done)")
1047
+ renderer.status(f" Tokens: {exec_tokens:,}")
1048
+
913
1049
  renderer.status("")
914
1050
  continue
915
1051
 
@@ -975,23 +1111,29 @@ def _run_pilot_repl(
975
1111
  continue
976
1112
 
977
1113
  # Record checkpoint
978
- state.record_turn(
979
- instruction=user_input,
980
- messages=orchestrator.messages,
981
- sessions_snapshot=get_sessions_snapshot(orchestrator),
982
- step_count=orchestrator._step_count,
1114
+ state.record(
1115
+ description=user_input,
1116
+ state={
1117
+ "messages": orchestrator.messages,
1118
+ "sessions_snapshot": get_sessions_snapshot(orchestrator),
1119
+ "step_count": orchestrator._step_count,
1120
+ },
1121
+ metadata={
1122
+ "step_count": orchestrator._step_count,
1123
+ "message_count": len(orchestrator.messages),
1124
+ },
983
1125
  )
984
1126
 
985
1127
  # Show turn info
986
- cp = state.current_checkpoint()
1128
+ cp = state.current()
987
1129
  if cp:
988
1130
  renderer.status("")
989
1131
  renderer.status(
990
- f"[{state.turn_label(cp.turn_id)}] "
991
- f"step={cp.step_count} "
992
- f"messages={len(cp.messages)}"
1132
+ f"[{cp.label}] "
1133
+ f"step={cp.state['step_count']} "
1134
+ f"messages={len(cp.state['messages'])}"
993
1135
  )
994
- renderer.status(f":goto {state.turn_label(cp.turn_id)} to return here, :history for timeline")
1136
+ renderer.status(f":goto {cp.label} to return here, :history for timeline")
995
1137
 
996
1138
  # Check stop condition
997
1139
  if hasattr(orchestrator, "stopCondition") and orchestrator.stopCondition:
zwarm/core/__init__.py CHANGED
@@ -0,0 +1,20 @@
1
+ """Core primitives for zwarm."""
2
+
3
+ from .checkpoints import Checkpoint, CheckpointManager
4
+ from .costs import (
5
+ estimate_cost,
6
+ estimate_session_cost,
7
+ format_cost,
8
+ get_pricing,
9
+ ModelPricing,
10
+ )
11
+
12
+ __all__ = [
13
+ "Checkpoint",
14
+ "CheckpointManager",
15
+ "estimate_cost",
16
+ "estimate_session_cost",
17
+ "format_cost",
18
+ "get_pricing",
19
+ "ModelPricing",
20
+ ]