code-puppy 0.0.325__py3-none-any.whl → 0.0.341__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_puppy/agents/base_agent.py +110 -124
- code_puppy/claude_cache_client.py +208 -2
- code_puppy/cli_runner.py +152 -32
- code_puppy/command_line/add_model_menu.py +4 -0
- code_puppy/command_line/autosave_menu.py +23 -24
- code_puppy/command_line/clipboard.py +527 -0
- code_puppy/command_line/colors_menu.py +5 -0
- code_puppy/command_line/config_commands.py +24 -1
- code_puppy/command_line/core_commands.py +85 -0
- code_puppy/command_line/diff_menu.py +5 -0
- code_puppy/command_line/mcp/custom_server_form.py +4 -0
- code_puppy/command_line/mcp/install_menu.py +5 -1
- code_puppy/command_line/model_settings_menu.py +5 -0
- code_puppy/command_line/motd.py +13 -7
- code_puppy/command_line/onboarding_slides.py +180 -0
- code_puppy/command_line/onboarding_wizard.py +340 -0
- code_puppy/command_line/prompt_toolkit_completion.py +118 -0
- code_puppy/config.py +3 -2
- code_puppy/http_utils.py +201 -279
- code_puppy/keymap.py +10 -8
- code_puppy/mcp_/managed_server.py +7 -11
- code_puppy/messaging/messages.py +3 -0
- code_puppy/messaging/rich_renderer.py +114 -22
- code_puppy/model_factory.py +102 -15
- code_puppy/models.json +2 -2
- code_puppy/plugins/antigravity_oauth/__init__.py +10 -0
- code_puppy/plugins/antigravity_oauth/accounts.py +406 -0
- code_puppy/plugins/antigravity_oauth/antigravity_model.py +668 -0
- code_puppy/plugins/antigravity_oauth/config.py +42 -0
- code_puppy/plugins/antigravity_oauth/constants.py +136 -0
- code_puppy/plugins/antigravity_oauth/oauth.py +478 -0
- code_puppy/plugins/antigravity_oauth/register_callbacks.py +406 -0
- code_puppy/plugins/antigravity_oauth/storage.py +271 -0
- code_puppy/plugins/antigravity_oauth/test_plugin.py +319 -0
- code_puppy/plugins/antigravity_oauth/token.py +167 -0
- code_puppy/plugins/antigravity_oauth/transport.py +664 -0
- code_puppy/plugins/antigravity_oauth/utils.py +169 -0
- code_puppy/plugins/chatgpt_oauth/register_callbacks.py +2 -0
- code_puppy/plugins/claude_code_oauth/register_callbacks.py +2 -0
- code_puppy/plugins/claude_code_oauth/utils.py +126 -7
- code_puppy/reopenable_async_client.py +8 -8
- code_puppy/terminal_utils.py +295 -3
- code_puppy/tools/command_runner.py +43 -54
- code_puppy/tools/common.py +3 -9
- code_puppy/uvx_detection.py +242 -0
- {code_puppy-0.0.325.data → code_puppy-0.0.341.data}/data/code_puppy/models.json +2 -2
- {code_puppy-0.0.325.dist-info → code_puppy-0.0.341.dist-info}/METADATA +26 -49
- {code_puppy-0.0.325.dist-info → code_puppy-0.0.341.dist-info}/RECORD +52 -36
- {code_puppy-0.0.325.data → code_puppy-0.0.341.data}/data/code_puppy/models_dev_api.json +0 -0
- {code_puppy-0.0.325.dist-info → code_puppy-0.0.341.dist-info}/WHEEL +0 -0
- {code_puppy-0.0.325.dist-info → code_puppy-0.0.341.dist-info}/entry_points.txt +0 -0
- {code_puppy-0.0.325.dist-info → code_puppy-0.0.341.dist-info}/licenses/LICENSE +0 -0
code_puppy/agents/base_agent.py
CHANGED
|
@@ -4,7 +4,6 @@ import asyncio
|
|
|
4
4
|
import json
|
|
5
5
|
import math
|
|
6
6
|
import signal
|
|
7
|
-
import sys
|
|
8
7
|
import threading
|
|
9
8
|
import uuid
|
|
10
9
|
from abc import ABC, abstractmethod
|
|
@@ -914,6 +913,11 @@ class BaseAgent(ABC):
|
|
|
914
913
|
"""
|
|
915
914
|
Truncate message history to manage token usage.
|
|
916
915
|
|
|
916
|
+
Protects:
|
|
917
|
+
- The first message (system prompt) - always kept
|
|
918
|
+
- The second message if it contains a ThinkingPart (extended thinking context)
|
|
919
|
+
- The most recent messages up to protected_tokens
|
|
920
|
+
|
|
917
921
|
Args:
|
|
918
922
|
messages: List of messages to truncate
|
|
919
923
|
protected_tokens: Number of tokens to protect
|
|
@@ -925,12 +929,30 @@ class BaseAgent(ABC):
|
|
|
925
929
|
|
|
926
930
|
emit_info("Truncating message history to manage token usage")
|
|
927
931
|
result = [messages[0]] # Always keep the first message (system prompt)
|
|
932
|
+
|
|
933
|
+
# Check if second message exists and contains a ThinkingPart
|
|
934
|
+
# If so, protect it (extended thinking context shouldn't be lost)
|
|
935
|
+
skip_second = False
|
|
936
|
+
if len(messages) > 1:
|
|
937
|
+
second_msg = messages[1]
|
|
938
|
+
has_thinking = any(
|
|
939
|
+
isinstance(part, ThinkingPart) for part in second_msg.parts
|
|
940
|
+
)
|
|
941
|
+
if has_thinking:
|
|
942
|
+
result.append(second_msg)
|
|
943
|
+
skip_second = True
|
|
944
|
+
|
|
928
945
|
num_tokens = 0
|
|
929
946
|
stack = queue.LifoQueue()
|
|
930
947
|
|
|
948
|
+
# Determine which messages to consider for the recent-tokens window
|
|
949
|
+
# Skip first message (already added), and skip second if it has thinking
|
|
950
|
+
start_idx = 2 if skip_second else 1
|
|
951
|
+
messages_to_scan = messages[start_idx:]
|
|
952
|
+
|
|
931
953
|
# Put messages in reverse order (most recent first) into the stack
|
|
932
954
|
# but break when we exceed protected_tokens
|
|
933
|
-
for
|
|
955
|
+
for msg in reversed(messages_to_scan):
|
|
934
956
|
num_tokens += self.estimate_tokens_for_message(msg)
|
|
935
957
|
if num_tokens > protected_tokens:
|
|
936
958
|
break
|
|
@@ -1354,7 +1376,6 @@ class BaseAgent(ABC):
|
|
|
1354
1376
|
ToolCallPartDelta,
|
|
1355
1377
|
)
|
|
1356
1378
|
from rich.console import Console
|
|
1357
|
-
from rich.markdown import Markdown
|
|
1358
1379
|
from rich.markup import escape
|
|
1359
1380
|
|
|
1360
1381
|
from code_puppy.messaging.spinner import pause_all_spinners
|
|
@@ -1376,22 +1397,28 @@ class BaseAgent(ABC):
|
|
|
1376
1397
|
text_parts: set[int] = set() # Track which parts are text
|
|
1377
1398
|
tool_parts: set[int] = set() # Track which parts are tool calls
|
|
1378
1399
|
banner_printed: set[int] = set() # Track if banner was already printed
|
|
1379
|
-
text_buffer: dict[int, list[str]] = {} # Buffer text for final markdown render
|
|
1380
1400
|
token_count: dict[int, int] = {} # Track token count per text/tool part
|
|
1381
1401
|
did_stream_anything = False # Track if we streamed any content
|
|
1382
1402
|
|
|
1403
|
+
# Termflow streaming state for text parts
|
|
1404
|
+
from termflow import Parser as TermflowParser
|
|
1405
|
+
from termflow import Renderer as TermflowRenderer
|
|
1406
|
+
|
|
1407
|
+
termflow_parsers: dict[int, TermflowParser] = {}
|
|
1408
|
+
termflow_renderers: dict[int, TermflowRenderer] = {}
|
|
1409
|
+
termflow_line_buffers: dict[int, str] = {} # Buffer incomplete lines
|
|
1410
|
+
|
|
1383
1411
|
def _print_thinking_banner() -> None:
|
|
1384
1412
|
"""Print the THINKING banner with spinner pause and line clear."""
|
|
1385
1413
|
nonlocal did_stream_anything
|
|
1386
|
-
import sys
|
|
1387
1414
|
import time
|
|
1388
1415
|
|
|
1389
1416
|
from code_puppy.config import get_banner_color
|
|
1390
1417
|
|
|
1391
1418
|
pause_all_spinners()
|
|
1392
1419
|
time.sleep(0.1) # Delay to let spinner fully clear
|
|
1393
|
-
|
|
1394
|
-
|
|
1420
|
+
# Clear line and print newline before banner
|
|
1421
|
+
console.print(" " * 50, end="\r")
|
|
1395
1422
|
console.print() # Newline before banner
|
|
1396
1423
|
# Bold banner with configurable color and lightning bolt
|
|
1397
1424
|
thinking_color = get_banner_color("thinking")
|
|
@@ -1401,21 +1428,19 @@ class BaseAgent(ABC):
|
|
|
1401
1428
|
),
|
|
1402
1429
|
end="",
|
|
1403
1430
|
)
|
|
1404
|
-
sys.stdout.flush()
|
|
1405
1431
|
did_stream_anything = True
|
|
1406
1432
|
|
|
1407
1433
|
def _print_response_banner() -> None:
|
|
1408
1434
|
"""Print the AGENT RESPONSE banner with spinner pause and line clear."""
|
|
1409
1435
|
nonlocal did_stream_anything
|
|
1410
|
-
import sys
|
|
1411
1436
|
import time
|
|
1412
1437
|
|
|
1413
1438
|
from code_puppy.config import get_banner_color
|
|
1414
1439
|
|
|
1415
1440
|
pause_all_spinners()
|
|
1416
1441
|
time.sleep(0.1) # Delay to let spinner fully clear
|
|
1417
|
-
|
|
1418
|
-
|
|
1442
|
+
# Clear line and print newline before banner
|
|
1443
|
+
console.print(" " * 50, end="\r")
|
|
1419
1444
|
console.print() # Newline before banner
|
|
1420
1445
|
response_color = get_banner_color("agent_response")
|
|
1421
1446
|
console.print(
|
|
@@ -1423,7 +1448,6 @@ class BaseAgent(ABC):
|
|
|
1423
1448
|
f"[bold white on {response_color}] AGENT RESPONSE [/bold white on {response_color}]"
|
|
1424
1449
|
)
|
|
1425
1450
|
)
|
|
1426
|
-
sys.stdout.flush()
|
|
1427
1451
|
did_stream_anything = True
|
|
1428
1452
|
|
|
1429
1453
|
async for event in events:
|
|
@@ -1442,13 +1466,17 @@ class BaseAgent(ABC):
|
|
|
1442
1466
|
elif isinstance(part, TextPart):
|
|
1443
1467
|
streaming_parts.add(event.index)
|
|
1444
1468
|
text_parts.add(event.index)
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
|
|
1469
|
+
# Initialize termflow streaming for this text part
|
|
1470
|
+
termflow_parsers[event.index] = TermflowParser()
|
|
1471
|
+
termflow_renderers[event.index] = TermflowRenderer(
|
|
1472
|
+
output=console.file, width=console.width
|
|
1473
|
+
)
|
|
1474
|
+
termflow_line_buffers[event.index] = ""
|
|
1475
|
+
# Handle initial content if present
|
|
1448
1476
|
if part.content and part.content.strip():
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1477
|
+
_print_response_banner()
|
|
1478
|
+
banner_printed.add(event.index)
|
|
1479
|
+
termflow_line_buffers[event.index] = part.content
|
|
1452
1480
|
elif isinstance(part, ToolCallPart):
|
|
1453
1481
|
streaming_parts.add(event.index)
|
|
1454
1482
|
tool_parts.add(event.index)
|
|
@@ -1464,26 +1492,29 @@ class BaseAgent(ABC):
|
|
|
1464
1492
|
delta = event.delta
|
|
1465
1493
|
if isinstance(delta, (TextPartDelta, ThinkingPartDelta)):
|
|
1466
1494
|
if delta.content_delta:
|
|
1467
|
-
# For text parts,
|
|
1495
|
+
# For text parts, stream markdown with termflow
|
|
1468
1496
|
if event.index in text_parts:
|
|
1469
|
-
import sys
|
|
1470
|
-
|
|
1471
1497
|
# Print banner on first content
|
|
1472
1498
|
if event.index not in banner_printed:
|
|
1473
1499
|
_print_response_banner()
|
|
1474
1500
|
banner_printed.add(event.index)
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
len(delta.content_delta) // 3
|
|
1480
|
-
)
|
|
1481
|
-
# Update token counter in place (single line)
|
|
1482
|
-
count = token_count[event.index]
|
|
1483
|
-
sys.stdout.write(
|
|
1484
|
-
f"\r\x1b[K ⏳ Receiving... {count} tokens"
|
|
1501
|
+
|
|
1502
|
+
# Add content to line buffer
|
|
1503
|
+
termflow_line_buffers[event.index] += (
|
|
1504
|
+
delta.content_delta
|
|
1485
1505
|
)
|
|
1486
|
-
|
|
1506
|
+
|
|
1507
|
+
# Process complete lines
|
|
1508
|
+
parser = termflow_parsers[event.index]
|
|
1509
|
+
renderer = termflow_renderers[event.index]
|
|
1510
|
+
buffer = termflow_line_buffers[event.index]
|
|
1511
|
+
|
|
1512
|
+
while "\n" in buffer:
|
|
1513
|
+
line, buffer = buffer.split("\n", 1)
|
|
1514
|
+
events_to_render = parser.parse_line(line)
|
|
1515
|
+
renderer.render_all(events_to_render)
|
|
1516
|
+
|
|
1517
|
+
termflow_line_buffers[event.index] = buffer
|
|
1487
1518
|
else:
|
|
1488
1519
|
# For thinking parts, stream immediately (dim)
|
|
1489
1520
|
if event.index not in banner_printed:
|
|
@@ -1492,48 +1523,51 @@ class BaseAgent(ABC):
|
|
|
1492
1523
|
escaped = escape(delta.content_delta)
|
|
1493
1524
|
console.print(f"[dim]{escaped}[/dim]", end="")
|
|
1494
1525
|
elif isinstance(delta, ToolCallPartDelta):
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
# For tool calls, show token counter (use string repr for estimation)
|
|
1498
|
-
token_count[event.index] += len(str(delta)) // 3
|
|
1526
|
+
# For tool calls, count chunks received
|
|
1527
|
+
token_count[event.index] += 1
|
|
1499
1528
|
# Get tool name if available
|
|
1500
1529
|
tool_name = getattr(delta, "tool_name_delta", "")
|
|
1501
1530
|
count = token_count[event.index]
|
|
1502
1531
|
# Display with tool wrench icon and tool name
|
|
1503
1532
|
if tool_name:
|
|
1504
|
-
|
|
1505
|
-
f"
|
|
1533
|
+
console.print(
|
|
1534
|
+
f" 🔧 Calling {tool_name}... {count} chunks ",
|
|
1535
|
+
end="\r",
|
|
1506
1536
|
)
|
|
1507
1537
|
else:
|
|
1508
|
-
|
|
1509
|
-
f"
|
|
1538
|
+
console.print(
|
|
1539
|
+
f" 🔧 Calling tool... {count} chunks ",
|
|
1540
|
+
end="\r",
|
|
1510
1541
|
)
|
|
1511
|
-
sys.stdout.flush()
|
|
1512
1542
|
|
|
1513
1543
|
# PartEndEvent - finish the streaming with a newline
|
|
1514
1544
|
elif isinstance(event, PartEndEvent):
|
|
1515
1545
|
if event.index in streaming_parts:
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
# For text parts, clear counter line and render markdown
|
|
1546
|
+
# For text parts, finalize termflow rendering
|
|
1519
1547
|
if event.index in text_parts:
|
|
1520
|
-
#
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
|
|
1527
|
-
|
|
1528
|
-
|
|
1529
|
-
|
|
1530
|
-
|
|
1531
|
-
|
|
1532
|
-
|
|
1548
|
+
# Render any remaining buffered content
|
|
1549
|
+
if event.index in termflow_parsers:
|
|
1550
|
+
parser = termflow_parsers[event.index]
|
|
1551
|
+
renderer = termflow_renderers[event.index]
|
|
1552
|
+
remaining = termflow_line_buffers.get(event.index, "")
|
|
1553
|
+
|
|
1554
|
+
# Parse and render any remaining partial line
|
|
1555
|
+
if remaining.strip():
|
|
1556
|
+
events_to_render = parser.parse_line(remaining)
|
|
1557
|
+
renderer.render_all(events_to_render)
|
|
1558
|
+
|
|
1559
|
+
# Finalize the parser to close any open blocks
|
|
1560
|
+
final_events = parser.finalize()
|
|
1561
|
+
renderer.render_all(final_events)
|
|
1562
|
+
|
|
1563
|
+
# Clean up termflow state
|
|
1564
|
+
del termflow_parsers[event.index]
|
|
1565
|
+
del termflow_renderers[event.index]
|
|
1566
|
+
del termflow_line_buffers[event.index]
|
|
1567
|
+
# For tool parts, clear the chunk counter line
|
|
1533
1568
|
elif event.index in tool_parts:
|
|
1534
|
-
# Clear the
|
|
1535
|
-
|
|
1536
|
-
sys.stdout.flush()
|
|
1569
|
+
# Clear the chunk counter line by printing spaces and returning
|
|
1570
|
+
console.print(" " * 50, end="\r")
|
|
1537
1571
|
# For thinking parts, just print newline
|
|
1538
1572
|
elif event.index in banner_printed:
|
|
1539
1573
|
console.print() # Final newline after streaming
|
|
@@ -1952,74 +1986,35 @@ class BaseAgent(ABC):
|
|
|
1952
1986
|
def graceful_sigint_handler(_sig, _frame):
|
|
1953
1987
|
# When using keyboard-based cancel, SIGINT should be a no-op
|
|
1954
1988
|
# (just show a hint to user about the configured cancel key)
|
|
1955
|
-
|
|
1956
|
-
|
|
1989
|
+
# Also reset terminal to prevent bricking on Windows+uvx
|
|
1957
1990
|
from code_puppy.keymap import get_cancel_agent_display_name
|
|
1991
|
+
from code_puppy.terminal_utils import reset_windows_terminal_full
|
|
1992
|
+
|
|
1993
|
+
# Reset terminal state first to prevent bricking
|
|
1994
|
+
reset_windows_terminal_full()
|
|
1958
1995
|
|
|
1959
1996
|
cancel_key = get_cancel_agent_display_name()
|
|
1960
|
-
|
|
1961
|
-
# On Windows, we use keyboard listener, so SIGINT might still fire
|
|
1962
|
-
# but we handle cancellation via the key listener
|
|
1963
|
-
pass # Silent on Windows - the key listener handles it
|
|
1964
|
-
else:
|
|
1965
|
-
emit_info(f"Use {cancel_key} to cancel the agent task.")
|
|
1997
|
+
emit_info(f"Use {cancel_key} to cancel the agent task.")
|
|
1966
1998
|
|
|
1967
1999
|
original_handler = None
|
|
1968
2000
|
key_listener_stop_event = None
|
|
1969
2001
|
_key_listener_thread = None
|
|
1970
|
-
_windows_ctrl_handler = None # Store reference to prevent garbage collection
|
|
1971
2002
|
|
|
1972
2003
|
try:
|
|
1973
|
-
if
|
|
1974
|
-
#
|
|
1975
|
-
import ctypes
|
|
1976
|
-
|
|
1977
|
-
# Define the handler function type
|
|
1978
|
-
HANDLER_ROUTINE = ctypes.WINFUNCTYPE(ctypes.c_bool, ctypes.c_ulong)
|
|
1979
|
-
|
|
1980
|
-
def windows_ctrl_handler(ctrl_type):
|
|
1981
|
-
"""Handle Windows console control events."""
|
|
1982
|
-
CTRL_C_EVENT = 0
|
|
1983
|
-
CTRL_BREAK_EVENT = 1
|
|
1984
|
-
|
|
1985
|
-
if ctrl_type in (CTRL_C_EVENT, CTRL_BREAK_EVENT):
|
|
1986
|
-
# Check if we're awaiting user input
|
|
1987
|
-
if is_awaiting_user_input():
|
|
1988
|
-
return False # Let default handler run
|
|
1989
|
-
|
|
1990
|
-
# Schedule agent cancellation
|
|
1991
|
-
schedule_agent_cancel()
|
|
1992
|
-
return True # We handled it, don't terminate
|
|
1993
|
-
|
|
1994
|
-
return False # Let other handlers process it
|
|
1995
|
-
|
|
1996
|
-
# Create the callback - must keep reference alive!
|
|
1997
|
-
_windows_ctrl_handler = HANDLER_ROUTINE(windows_ctrl_handler)
|
|
1998
|
-
|
|
1999
|
-
# Register the handler
|
|
2000
|
-
kernel32 = ctypes.windll.kernel32
|
|
2001
|
-
if not kernel32.SetConsoleCtrlHandler(_windows_ctrl_handler, True):
|
|
2002
|
-
emit_warning("Failed to set Windows Ctrl+C handler")
|
|
2003
|
-
|
|
2004
|
-
# Also spawn keyboard listener for Ctrl+X (shell cancel) and other keys
|
|
2005
|
-
key_listener_stop_event = threading.Event()
|
|
2006
|
-
_key_listener_thread = self._spawn_ctrl_x_key_listener(
|
|
2007
|
-
key_listener_stop_event,
|
|
2008
|
-
on_escape=lambda: None, # Ctrl+X handled by command_runner
|
|
2009
|
-
on_cancel_agent=None, # Ctrl+C handled by SetConsoleCtrlHandler above
|
|
2010
|
-
)
|
|
2011
|
-
elif cancel_agent_uses_signal():
|
|
2012
|
-
# Unix with Ctrl+C: Use SIGINT-based cancellation
|
|
2004
|
+
if cancel_agent_uses_signal():
|
|
2005
|
+
# Use SIGINT-based cancellation (default Ctrl+C behavior)
|
|
2013
2006
|
original_handler = signal.signal(
|
|
2014
2007
|
signal.SIGINT, keyboard_interrupt_handler
|
|
2015
2008
|
)
|
|
2016
2009
|
else:
|
|
2017
|
-
#
|
|
2010
|
+
# Use keyboard listener for agent cancellation
|
|
2011
|
+
# Set a graceful SIGINT handler that shows a hint
|
|
2018
2012
|
original_handler = signal.signal(signal.SIGINT, graceful_sigint_handler)
|
|
2013
|
+
# Spawn keyboard listener with the cancel agent callback
|
|
2019
2014
|
key_listener_stop_event = threading.Event()
|
|
2020
2015
|
_key_listener_thread = self._spawn_ctrl_x_key_listener(
|
|
2021
2016
|
key_listener_stop_event,
|
|
2022
|
-
on_escape=lambda: None,
|
|
2017
|
+
on_escape=lambda: None, # Ctrl+X handled by command_runner
|
|
2023
2018
|
on_cancel_agent=schedule_agent_cancel,
|
|
2024
2019
|
)
|
|
2025
2020
|
|
|
@@ -2044,17 +2039,8 @@ class BaseAgent(ABC):
|
|
|
2044
2039
|
# Stop keyboard listener if it was started
|
|
2045
2040
|
if key_listener_stop_event is not None:
|
|
2046
2041
|
key_listener_stop_event.set()
|
|
2047
|
-
|
|
2048
|
-
|
|
2049
|
-
|
|
2050
|
-
|
|
2051
|
-
import ctypes
|
|
2052
|
-
|
|
2053
|
-
kernel32 = ctypes.windll.kernel32
|
|
2054
|
-
kernel32.SetConsoleCtrlHandler(_windows_ctrl_handler, False)
|
|
2055
|
-
except Exception:
|
|
2056
|
-
pass # Best effort cleanup
|
|
2057
|
-
|
|
2058
|
-
# Restore original signal handler (Unix)
|
|
2059
|
-
if original_handler is not None:
|
|
2042
|
+
# Restore original signal handler
|
|
2043
|
+
if (
|
|
2044
|
+
original_handler is not None
|
|
2045
|
+
): # Explicit None check - SIG_DFL can be 0/falsy!
|
|
2060
2046
|
signal.signal(signal.SIGINT, original_handler)
|
|
@@ -9,11 +9,19 @@ serialization, avoiding httpx/Pydantic internals.
|
|
|
9
9
|
|
|
10
10
|
from __future__ import annotations
|
|
11
11
|
|
|
12
|
+
import base64
|
|
12
13
|
import json
|
|
13
|
-
|
|
14
|
+
import logging
|
|
15
|
+
import time
|
|
16
|
+
from typing import Any, Callable, MutableMapping
|
|
14
17
|
|
|
15
18
|
import httpx
|
|
16
19
|
|
|
20
|
+
logger: logging.Logger = logging.getLogger(__name__)

# Bearer JWTs whose estimated age reaches this many seconds are refreshed
# proactively before a request is sent (60 * 60 s = one hour).
TOKEN_MAX_AGE_SECONDS = 60 * 60
|
|
24
|
+
|
|
17
25
|
try:
|
|
18
26
|
from anthropic import AsyncAnthropic
|
|
19
27
|
except ImportError: # pragma: no cover - optional dep
|
|
@@ -21,9 +29,108 @@ except ImportError: # pragma: no cover - optional dep
|
|
|
21
29
|
|
|
22
30
|
|
|
23
31
|
class ClaudeCacheAsyncClient(httpx.AsyncClient):
|
|
32
|
+
def _get_jwt_age_seconds(self, token: str | None) -> float | None:
|
|
33
|
+
"""Decode a JWT and return its age in seconds.
|
|
34
|
+
|
|
35
|
+
Returns None if the token can't be decoded or has no timestamp claims.
|
|
36
|
+
Uses 'iat' (issued at) if available, otherwise calculates from 'exp'.
|
|
37
|
+
"""
|
|
38
|
+
if not token:
|
|
39
|
+
return None
|
|
40
|
+
|
|
41
|
+
try:
|
|
42
|
+
# JWT format: header.payload.signature
|
|
43
|
+
# We only need the payload (second part)
|
|
44
|
+
parts = token.split(".")
|
|
45
|
+
if len(parts) != 3:
|
|
46
|
+
return None
|
|
47
|
+
|
|
48
|
+
# Decode the payload (base64url encoded)
|
|
49
|
+
payload_b64 = parts[1]
|
|
50
|
+
# Add padding if needed (base64url doesn't require padding)
|
|
51
|
+
padding = 4 - len(payload_b64) % 4
|
|
52
|
+
if padding != 4:
|
|
53
|
+
payload_b64 += "=" * padding
|
|
54
|
+
|
|
55
|
+
payload_bytes = base64.urlsafe_b64decode(payload_b64)
|
|
56
|
+
payload = json.loads(payload_bytes.decode("utf-8"))
|
|
57
|
+
|
|
58
|
+
now = time.time()
|
|
59
|
+
|
|
60
|
+
# Prefer 'iat' (issued at) claim if available
|
|
61
|
+
if "iat" in payload:
|
|
62
|
+
iat = float(payload["iat"])
|
|
63
|
+
age = now - iat
|
|
64
|
+
return age
|
|
65
|
+
|
|
66
|
+
# Fall back to calculating from 'exp' claim
|
|
67
|
+
# Assume tokens are typically valid for 1 hour
|
|
68
|
+
if "exp" in payload:
|
|
69
|
+
exp = float(payload["exp"])
|
|
70
|
+
# If exp is in the future, calculate how long until expiry
|
|
71
|
+
# and assume the token was issued 1 hour before expiry
|
|
72
|
+
time_until_exp = exp - now
|
|
73
|
+
# If token has less than 1 hour left, it's "old"
|
|
74
|
+
age = TOKEN_MAX_AGE_SECONDS - time_until_exp
|
|
75
|
+
return max(0, age)
|
|
76
|
+
|
|
77
|
+
return None
|
|
78
|
+
except Exception as exc:
|
|
79
|
+
logger.debug("Failed to decode JWT age: %s", exc)
|
|
80
|
+
return None
|
|
81
|
+
|
|
82
|
+
def _extract_bearer_token(self, request: httpx.Request) -> str | None:
|
|
83
|
+
"""Extract the bearer token from request headers."""
|
|
84
|
+
auth_header = request.headers.get("Authorization") or request.headers.get(
|
|
85
|
+
"authorization"
|
|
86
|
+
)
|
|
87
|
+
if auth_header and auth_header.lower().startswith("bearer "):
|
|
88
|
+
return auth_header[7:] # Strip "Bearer " prefix
|
|
89
|
+
return None
|
|
90
|
+
|
|
91
|
+
def _should_refresh_token(self, request: httpx.Request) -> bool:
    """Decide whether the request's bearer JWT is old enough to refresh.

    Returns ``True`` only when a bearer token is present, its age can be
    estimated by ``_get_jwt_age_seconds``, and that age is at least
    ``TOKEN_MAX_AGE_SECONDS``. An info log line is emitted in that case.
    A missing token or an undeterminable age yields ``False``.
    """
    bearer = self._extract_bearer_token(request)
    token_age = self._get_jwt_age_seconds(bearer) if bearer else None

    # `not (age >= limit)` (rather than `age < limit`) keeps NaN ages
    # classified as "do not refresh".
    if token_age is None or not token_age >= TOKEN_MAX_AGE_SECONDS:
        return False

    logger.info(
        "JWT token is %.1f seconds old (>= %d), will refresh proactively",
        token_age,
        TOKEN_MAX_AGE_SECONDS,
    )
    return True
|
|
109
|
+
|
|
24
110
|
async def send(
|
|
25
111
|
self, request: httpx.Request, *args: Any, **kwargs: Any
|
|
26
112
|
) -> httpx.Response: # type: ignore[override]
|
|
113
|
+
# Proactive token refresh: check JWT age before every request
|
|
114
|
+
if not request.extensions.get("claude_oauth_refresh_attempted"):
|
|
115
|
+
try:
|
|
116
|
+
if self._should_refresh_token(request):
|
|
117
|
+
refreshed_token = self._refresh_claude_oauth_token()
|
|
118
|
+
if refreshed_token:
|
|
119
|
+
logger.info("Proactively refreshed token before request")
|
|
120
|
+
# Rebuild request with new token
|
|
121
|
+
headers = dict(request.headers)
|
|
122
|
+
self._update_auth_headers(headers, refreshed_token)
|
|
123
|
+
body_bytes = self._extract_body_bytes(request)
|
|
124
|
+
request = self.build_request(
|
|
125
|
+
method=request.method,
|
|
126
|
+
url=request.url,
|
|
127
|
+
headers=headers,
|
|
128
|
+
content=body_bytes,
|
|
129
|
+
)
|
|
130
|
+
request.extensions["claude_oauth_refresh_attempted"] = True
|
|
131
|
+
except Exception as exc:
|
|
132
|
+
logger.debug("Error during proactive token refresh check: %s", exc)
|
|
133
|
+
|
|
27
134
|
try:
|
|
28
135
|
if request.url.path.endswith("/v1/messages"):
|
|
29
136
|
body_bytes = self._extract_body_bytes(request)
|
|
@@ -56,7 +163,47 @@ class ClaudeCacheAsyncClient(httpx.AsyncClient):
|
|
|
56
163
|
except Exception:
|
|
57
164
|
# Swallow wrapper errors; do not break real calls.
|
|
58
165
|
pass
|
|
59
|
-
|
|
166
|
+
response = await super().send(request, *args, **kwargs)
|
|
167
|
+
try:
|
|
168
|
+
# Check for both 401 and 400 - Anthropic/Cloudflare may return 400 for auth errors
|
|
169
|
+
# Also check if it's a Cloudflare HTML error response
|
|
170
|
+
if response.status_code in (400, 401) and not request.extensions.get(
|
|
171
|
+
"claude_oauth_refresh_attempted"
|
|
172
|
+
):
|
|
173
|
+
# Determine if this is an auth error (including Cloudflare HTML errors)
|
|
174
|
+
is_auth_error = response.status_code == 401
|
|
175
|
+
|
|
176
|
+
if response.status_code == 400:
|
|
177
|
+
# Check if this is a Cloudflare HTML error
|
|
178
|
+
is_auth_error = self._is_cloudflare_html_error(response)
|
|
179
|
+
if is_auth_error:
|
|
180
|
+
logger.info(
|
|
181
|
+
"Detected Cloudflare 400 error (likely auth-related), attempting token refresh"
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
if is_auth_error:
|
|
185
|
+
refreshed_token = self._refresh_claude_oauth_token()
|
|
186
|
+
if refreshed_token:
|
|
187
|
+
logger.info("Token refreshed successfully, retrying request")
|
|
188
|
+
await response.aclose()
|
|
189
|
+
body_bytes = self._extract_body_bytes(request)
|
|
190
|
+
headers = dict(request.headers)
|
|
191
|
+
self._update_auth_headers(headers, refreshed_token)
|
|
192
|
+
retry_request = self.build_request(
|
|
193
|
+
method=request.method,
|
|
194
|
+
url=request.url,
|
|
195
|
+
headers=headers,
|
|
196
|
+
content=body_bytes,
|
|
197
|
+
)
|
|
198
|
+
retry_request.extensions["claude_oauth_refresh_attempted"] = (
|
|
199
|
+
True
|
|
200
|
+
)
|
|
201
|
+
return await super().send(retry_request, *args, **kwargs)
|
|
202
|
+
else:
|
|
203
|
+
logger.warning("Token refresh failed, returning original error")
|
|
204
|
+
except Exception as exc:
|
|
205
|
+
logger.debug("Error during token refresh attempt: %s", exc)
|
|
206
|
+
return response
|
|
60
207
|
|
|
61
208
|
@staticmethod
|
|
62
209
|
def _extract_body_bytes(request: httpx.Request) -> bytes | None:
|
|
@@ -78,6 +225,65 @@ class ClaudeCacheAsyncClient(httpx.AsyncClient):
|
|
|
78
225
|
|
|
79
226
|
return None
|
|
80
227
|
|
|
228
|
+
@staticmethod
|
|
229
|
+
def _update_auth_headers(
|
|
230
|
+
headers: MutableMapping[str, str], access_token: str
|
|
231
|
+
) -> None:
|
|
232
|
+
bearer_value = f"Bearer {access_token}"
|
|
233
|
+
if "Authorization" in headers or "authorization" in headers:
|
|
234
|
+
headers["Authorization"] = bearer_value
|
|
235
|
+
elif "x-api-key" in headers or "X-API-Key" in headers:
|
|
236
|
+
headers["x-api-key"] = access_token
|
|
237
|
+
else:
|
|
238
|
+
headers["Authorization"] = bearer_value
|
|
239
|
+
|
|
240
|
+
@staticmethod
|
|
241
|
+
def _is_cloudflare_html_error(response: httpx.Response) -> bool:
|
|
242
|
+
"""Check if this is a Cloudflare HTML error response.
|
|
243
|
+
|
|
244
|
+
Cloudflare often returns HTML error pages with status 400 when
|
|
245
|
+
there are authentication issues.
|
|
246
|
+
"""
|
|
247
|
+
# Check content type
|
|
248
|
+
content_type = response.headers.get("content-type", "")
|
|
249
|
+
if "text/html" not in content_type.lower():
|
|
250
|
+
return False
|
|
251
|
+
|
|
252
|
+
# Check if body contains Cloudflare markers
|
|
253
|
+
try:
|
|
254
|
+
# Read response body if not already consumed
|
|
255
|
+
if hasattr(response, "_content") and response._content:
|
|
256
|
+
body = response._content.decode("utf-8", errors="ignore")
|
|
257
|
+
else:
|
|
258
|
+
# Try to read the text (this might be already consumed)
|
|
259
|
+
try:
|
|
260
|
+
body = response.text
|
|
261
|
+
except Exception:
|
|
262
|
+
return False
|
|
263
|
+
|
|
264
|
+
# Look for Cloudflare and 400 Bad Request markers
|
|
265
|
+
body_lower = body.lower()
|
|
266
|
+
return "cloudflare" in body_lower and "400 bad request" in body_lower
|
|
267
|
+
except Exception as exc:
|
|
268
|
+
logger.debug("Error checking for Cloudflare error: %s", exc)
|
|
269
|
+
return False
|
|
270
|
+
|
|
271
|
+
def _refresh_claude_oauth_token(self) -> str | None:
    """Force-refresh the Claude Code OAuth access token.

    Delegates to ``refresh_access_token(force=True)`` from the
    ``claude_code_oauth`` plugin, which is imported lazily. On success, the
    new token is also written into this client's default headers via
    ``_update_auth_headers``, so later requests pick it up.

    NOTE(review): this is a synchronous call made from the async ``send``
    path. If the plugin performs network I/O (presumably it does; confirm),
    it blocks the event loop for the duration of the refresh.

    Returns:
        The refreshed token. Returns the plugin's falsy result if it produced
        nothing, or ``None`` if anything raised. Errors are logged, never
        propagated.
    """
    try:
        from code_puppy.plugins.claude_code_oauth.utils import refresh_access_token

        logger.info("Attempting to refresh Claude Code OAuth token...")
        new_token = refresh_access_token(force=True)
        if not new_token:
            logger.warning("Token refresh returned None")
            return new_token

        self._update_auth_headers(self.headers, new_token)
        logger.info("Successfully refreshed Claude Code OAuth token")
        return new_token
    except Exception as exc:
        logger.error("Exception during token refresh: %s", exc)
        return None
|
|
286
|
+
|
|
81
287
|
@staticmethod
|
|
82
288
|
def _inject_cache_control(body: bytes) -> bytes | None:
|
|
83
289
|
try:
|