kimi-cli 0.42__py3-none-any.whl → 0.44__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kimi-cli might be problematic. Click here for more details.

kimi_cli/soul/kimisoul.py CHANGED
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
6
6
  import kosong
7
7
  import tenacity
8
8
  from kosong import StepResult
9
- from kosong.base.message import Message
9
+ from kosong.base.message import ContentPart, ImageURLPart, Message
10
10
  from kosong.chat_provider import (
11
11
  APIConnectionError,
12
12
  APIStatusError,
@@ -16,7 +16,14 @@ from kosong.chat_provider import (
16
16
  from kosong.tooling import ToolResult
17
17
  from tenacity import RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential_jitter
18
18
 
19
- from kimi_cli.soul import LLMNotSet, MaxStepsReached, Soul, StatusSnapshot, wire_send
19
+ from kimi_cli.soul import (
20
+ LLMNotSet,
21
+ LLMNotSupported,
22
+ MaxStepsReached,
23
+ Soul,
24
+ StatusSnapshot,
25
+ wire_send,
26
+ )
20
27
  from kimi_cli.soul.agent import Agent
21
28
  from kimi_cli.soul.compaction import SimpleCompaction
22
29
  from kimi_cli.soul.context import Context
@@ -53,7 +60,6 @@ class KimiSoul(Soul):
53
60
  agent (Agent): The agent to run.
54
61
  runtime (Runtime): Runtime parameters and states.
55
62
  context (Context): The context of the agent.
56
- loop_control (LoopControl): The control parameters for the agent loop.
57
63
  """
58
64
  self._agent = agent
59
65
  self._runtime = runtime
@@ -85,6 +91,10 @@ class KimiSoul(Soul):
85
91
  def status(self) -> StatusSnapshot:
86
92
  return StatusSnapshot(context_usage=self._context_usage)
87
93
 
94
+ @property
95
+ def context(self) -> Context:
96
+ return self._context
97
+
88
98
  @property
89
99
  def _context_usage(self) -> float:
90
100
  if self._runtime.llm is not None:
@@ -94,10 +104,17 @@ class KimiSoul(Soul):
94
104
  async def _checkpoint(self):
95
105
  await self._context.checkpoint(self._checkpoint_with_user_message)
96
106
 
97
- async def run(self, user_input: str):
107
+ async def run(self, user_input: str | list[ContentPart]):
98
108
  if self._runtime.llm is None:
99
109
  raise LLMNotSet()
100
110
 
111
+ if (
112
+ isinstance(user_input, list)
113
+ and any(isinstance(part, ImageURLPart) for part in user_input)
114
+ and not self._runtime.llm.supports_image_in
115
+ ):
116
+ raise LLMNotSupported(self._runtime.llm, ["image_in"])
117
+
101
118
  await self._checkpoint() # this creates the checkpoint 0 on first run
102
119
  await self._context.append_message(Message(role="user", content=user_input))
103
120
  logger.debug("Appended user message to context")
kimi_cli/soul/message.py CHANGED
@@ -14,7 +14,7 @@ def tool_result_to_messages(tool_result: ToolResult) -> list[Message]:
14
14
  message = tool_result.result.message
15
15
  if isinstance(tool_result.result, ToolRuntimeError):
16
16
  message += "\nThis is an unexpected error and the tool is probably not working."
17
- content: list[ContentPart] = [system(message)]
17
+ content: list[ContentPart] = [system(f"ERROR: {message}")]
18
18
  if tool_result.result.output:
19
19
  content.append(TextPart(text=tool_result.result.output))
20
20
  return [
kimi_cli/soul/runtime.py CHANGED
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import subprocess
2
3
  import sys
3
4
  from datetime import datetime
@@ -59,7 +60,7 @@ def _list_work_dir(work_dir: Path) -> str:
59
60
 
60
61
 
61
62
  class Runtime(NamedTuple):
62
- """Agent globals."""
63
+ """Agent runtime."""
63
64
 
64
65
  config: Config
65
66
  llm: LLM | None
@@ -75,9 +76,10 @@ class Runtime(NamedTuple):
75
76
  session: Session,
76
77
  yolo: bool,
77
78
  ) -> "Runtime":
78
- # FIXME: do these asynchronously
79
- ls_output = _list_work_dir(session.work_dir)
80
- agents_md = load_agents_md(session.work_dir) or ""
79
+ ls_output, agents_md = await asyncio.gather(
80
+ asyncio.to_thread(_list_work_dir, session.work_dir),
81
+ asyncio.to_thread(load_agents_md, session.work_dir),
82
+ )
81
83
 
82
84
  return Runtime(
83
85
  config=config,
@@ -87,7 +89,7 @@ class Runtime(NamedTuple):
87
89
  KIMI_NOW=datetime.now().astimezone().isoformat(),
88
90
  KIMI_WORK_DIR=session.work_dir,
89
91
  KIMI_WORK_DIR_LS=ls_output,
90
- KIMI_AGENTS_MD=agents_md,
92
+ KIMI_AGENTS_MD=agents_md or "",
91
93
  ),
92
94
  denwa_renji=DenwaRenji(),
93
95
  approval=Approval(yolo=yolo),
@@ -68,7 +68,8 @@ class Task(CallableTool2[Params]):
68
68
  self._subagents: dict[str, Agent] = {}
69
69
 
70
70
  try:
71
- self._load_task = asyncio.create_task(self._load_subagents(agent_spec.subagents))
71
+ loop = asyncio.get_running_loop()
72
+ self._load_task = loop.create_task(self._load_subagents(agent_spec.subagents))
72
73
  except RuntimeError:
73
74
  # In case there's no running event loop, e.g., during synchronous tests
74
75
  self._load_task = None
@@ -44,7 +44,7 @@ class SearchWeb(CallableTool2[Params]):
44
44
  if config.services.moonshot_search is not None:
45
45
  self._base_url = config.services.moonshot_search.base_url
46
46
  self._api_key = config.services.moonshot_search.api_key.get_secret_value()
47
- self._custom_headers = config.services.moonshot_search.custom_headers
47
+ self._custom_headers = config.services.moonshot_search.custom_headers or {}
48
48
  else:
49
49
  self._base_url = ""
50
50
  self._api_key = ""
@@ -172,33 +172,29 @@ class ACPAgent:
172
172
  self.run_state.cancel_event.set()
173
173
 
174
174
  async def _stream_events(self, wire: WireUISide):
175
- try:
176
- # expect a StepBegin
177
- assert isinstance(await wire.receive(), StepBegin)
178
-
179
- while True:
180
- msg = await wire.receive()
181
-
182
- if isinstance(msg, TextPart):
183
- await self._send_text(msg.text)
184
- elif isinstance(msg, ContentPart):
185
- logger.warning("Unsupported content part: {part}", part=msg)
186
- await self._send_text(f"[{msg.__class__.__name__}]")
187
- elif isinstance(msg, ToolCall):
188
- await self._send_tool_call(msg)
189
- elif isinstance(msg, ToolCallPart):
190
- await self._send_tool_call_part(msg)
191
- elif isinstance(msg, ToolResult):
192
- await self._send_tool_result(msg)
193
- elif isinstance(msg, ApprovalRequest):
194
- await self._handle_approval_request(msg)
195
- elif isinstance(msg, StatusUpdate):
196
- # TODO: stream status if needed
197
- pass
198
- elif isinstance(msg, StepInterrupted):
199
- break
200
- except asyncio.QueueShutDown:
201
- logger.debug("Event stream loop shutting down")
175
+ assert isinstance(await wire.receive(), StepBegin)
176
+
177
+ while True:
178
+ msg = await wire.receive()
179
+
180
+ if isinstance(msg, TextPart):
181
+ await self._send_text(msg.text)
182
+ elif isinstance(msg, ContentPart):
183
+ logger.warning("Unsupported content part: {part}", part=msg)
184
+ await self._send_text(f"[{msg.__class__.__name__}]")
185
+ elif isinstance(msg, ToolCall):
186
+ await self._send_tool_call(msg)
187
+ elif isinstance(msg, ToolCallPart):
188
+ await self._send_tool_call_part(msg)
189
+ elif isinstance(msg, ToolResult):
190
+ await self._send_tool_result(msg)
191
+ elif isinstance(msg, ApprovalRequest):
192
+ await self._handle_approval_request(msg)
193
+ elif isinstance(msg, StatusUpdate):
194
+ # TODO: stream status if needed
195
+ pass
196
+ elif isinstance(msg, StepInterrupted):
197
+ break
202
198
 
203
199
  async def _send_text(self, text: str):
204
200
  """Send text chunk to client."""
@@ -321,7 +317,7 @@ class ACPAgent:
321
317
  # Create permission request with options
322
318
  permission_request = acp.RequestPermissionRequest(
323
319
  sessionId=self.session_id,
324
- toolCall=acp.schema.ToolCallUpdate(
320
+ toolCall=acp.schema.ToolCall(
325
321
  toolCallId=state.acp_tool_call_id,
326
322
  content=[
327
323
  acp.schema.ContentToolCallContent(
@@ -1,24 +1,22 @@
1
1
  import asyncio
2
2
  import json
3
- import signal
4
3
  import sys
5
4
  from functools import partial
6
5
  from pathlib import Path
7
- from typing import Literal
8
6
 
9
7
  import aiofiles
10
8
  from kosong.base.message import Message
11
9
  from kosong.chat_provider import ChatProviderError
10
+ from rich import print
12
11
 
12
+ from kimi_cli.cli import InputFormat, OutputFormat
13
13
  from kimi_cli.soul import LLMNotSet, MaxStepsReached, RunCancelled, Soul, run_soul
14
14
  from kimi_cli.utils.logging import logger
15
15
  from kimi_cli.utils.message import message_extract_text
16
+ from kimi_cli.utils.signals import install_sigint_handler
16
17
  from kimi_cli.wire import WireUISide
17
18
  from kimi_cli.wire.message import StepInterrupted
18
19
 
19
- InputFormat = Literal["text", "stream-json"]
20
- OutputFormat = Literal["text", "stream-json"]
21
-
22
20
 
23
21
  class PrintApp:
24
22
  """
@@ -51,7 +49,7 @@ class PrintApp:
51
49
  cancel_event.set()
52
50
 
53
51
  loop = asyncio.get_running_loop()
54
- loop.add_signal_handler(signal.SIGINT, _handler)
52
+ remove_sigint = install_sigint_handler(loop, _handler)
55
53
 
56
54
  if command is None and not sys.stdin.isatty() and self.input_format == "text":
57
55
  command = sys.stdin.read().strip()
@@ -98,7 +96,7 @@ class PrintApp:
98
96
  print(f"Unknown error: {e}")
99
97
  raise
100
98
  finally:
101
- loop.remove_signal_handler(signal.SIGINT)
99
+ remove_sigint()
102
100
  return False
103
101
 
104
102
  def _read_next_command(self) -> str | None:
@@ -127,35 +125,29 @@ class PrintApp:
127
125
  logger.warning("Ignoring invalid user message: {json_line}", json_line=json_line)
128
126
 
129
127
  async def _visualize_text(self, wire: WireUISide):
130
- try:
131
- while True:
132
- msg = await wire.receive()
133
- print(msg)
134
- if isinstance(msg, StepInterrupted):
135
- break
136
- except asyncio.QueueShutDown:
137
- logger.debug("Visualization loop shutting down")
128
+ while True:
129
+ msg = await wire.receive()
130
+ print(msg)
131
+ if isinstance(msg, StepInterrupted):
132
+ break
138
133
 
139
134
  async def _visualize_stream_json(self, wire: WireUISide, start_position: int):
140
135
  # TODO: be aware of context compaction
141
136
  # FIXME: this is only a temporary impl, may miss the last lines of the context file
142
137
  if not self.context_file.exists():
143
138
  self.context_file.touch()
144
- try:
145
- async with aiofiles.open(self.context_file, encoding="utf-8") as f:
146
- await f.seek(start_position)
147
- while True:
148
- should_end = False
149
- while (msg := wire.receive_nowait()) is not None:
150
- if isinstance(msg, StepInterrupted):
151
- should_end = True
152
-
153
- line = await f.readline()
154
- if not line:
155
- if should_end:
156
- break
157
- await asyncio.sleep(0.1)
158
- continue
159
- print(line, end="")
160
- except asyncio.QueueShutDown:
161
- logger.debug("Visualization loop shutting down")
139
+ async with aiofiles.open(self.context_file, encoding="utf-8") as f:
140
+ await f.seek(start_position)
141
+ while True:
142
+ should_end = False
143
+ while (msg := wire.receive_nowait()) is not None:
144
+ if isinstance(msg, StepInterrupted):
145
+ should_end = True
146
+
147
+ line = await f.readline()
148
+ if not line:
149
+ if should_end:
150
+ break
151
+ await asyncio.sleep(0.1)
152
+ continue
153
+ print(line, end="")
@@ -1,28 +1,32 @@
1
1
  import asyncio
2
- import signal
3
2
  from collections.abc import Awaitable, Coroutine
3
+ from dataclasses import dataclass
4
+ from enum import Enum
4
5
  from typing import Any
5
6
 
7
+ from kosong.base.message import ContentPart
6
8
  from kosong.chat_provider import APIStatusError, ChatProviderError
7
9
  from rich.console import Group, RenderableType
8
10
  from rich.panel import Panel
9
11
  from rich.table import Table
10
12
  from rich.text import Text
11
13
 
12
- from kimi_cli.soul import LLMNotSet, MaxStepsReached, RunCancelled, Soul, run_soul
14
+ from kimi_cli.soul import LLMNotSet, LLMNotSupported, MaxStepsReached, RunCancelled, Soul, run_soul
13
15
  from kimi_cli.soul.kimisoul import KimiSoul
14
16
  from kimi_cli.ui.shell.console import console
15
17
  from kimi_cli.ui.shell.metacmd import get_meta_command
16
- from kimi_cli.ui.shell.prompt import CustomPromptSession, PromptMode, toast
18
+ from kimi_cli.ui.shell.prompt import CustomPromptSession, PromptMode, ensure_new_line, toast
19
+ from kimi_cli.ui.shell.replay import replay_recent_history
17
20
  from kimi_cli.ui.shell.update import LATEST_VERSION_FILE, UpdateResult, do_update, semver_tuple
18
21
  from kimi_cli.ui.shell.visualize import visualize
19
22
  from kimi_cli.utils.logging import logger
23
+ from kimi_cli.utils.signals import install_sigint_handler
20
24
 
21
25
 
22
26
  class ShellApp:
23
- def __init__(self, soul: Soul, welcome_info: dict[str, str] | None = None):
27
+ def __init__(self, soul: Soul, welcome_info: list["WelcomeInfoItem"] | None = None):
24
28
  self.soul = soul
25
- self.welcome_info = welcome_info or {}
29
+ self._welcome_info = list(welcome_info or [])
26
30
  self._background_tasks: set[asyncio.Task[Any]] = set()
27
31
 
28
32
  async def run(self, command: str | None = None) -> bool:
@@ -33,11 +37,15 @@ class ShellApp:
33
37
 
34
38
  self._start_background_task(self._auto_update())
35
39
 
36
- _print_welcome_info(self.soul.name or "Kimi CLI", self.soul.model, self.welcome_info)
40
+ _print_welcome_info(self.soul.name or "Kimi CLI", self._welcome_info)
41
+
42
+ if isinstance(self.soul, KimiSoul):
43
+ await replay_recent_history(self.soul.context.history)
37
44
 
38
45
  with CustomPromptSession(lambda: self.soul.status) as prompt_session:
39
46
  while True:
40
47
  try:
48
+ ensure_new_line()
41
49
  user_input = await prompt_session.prompt()
42
50
  except KeyboardInterrupt:
43
51
  logger.debug("Exiting by KeyboardInterrupt")
@@ -62,14 +70,13 @@ class ShellApp:
62
70
  await self._run_shell_command(user_input.command)
63
71
  continue
64
72
 
65
- command = user_input.command
66
- if command.startswith("/"):
67
- logger.debug("Running meta command: {command}", command=command)
68
- await self._run_meta_command(command[1:])
73
+ if user_input.command.startswith("/"):
74
+ logger.debug("Running meta command: {command}", command=user_input.command)
75
+ await self._run_meta_command(user_input.command[1:])
69
76
  continue
70
77
 
71
- logger.info("Running agent command: {command}", command=command)
72
- await self._run_soul_command(command)
78
+ logger.info("Running agent command: {command}", command=user_input.content)
79
+ await self._run_soul_command(user_input.content)
73
80
 
74
81
  return True
75
82
 
@@ -79,24 +86,26 @@ class ShellApp:
79
86
  return
80
87
 
81
88
  logger.info("Running shell command: {cmd}", cmd=command)
89
+
90
+ proc: asyncio.subprocess.Process | None = None
91
+
92
+ def _handler():
93
+ logger.debug("SIGINT received.")
94
+ if proc:
95
+ proc.terminate()
96
+
82
97
  loop = asyncio.get_running_loop()
98
+ remove_sigint = install_sigint_handler(loop, _handler)
83
99
  try:
84
100
  # TODO: For the sake of simplicity, we now use `create_subprocess_shell`.
85
101
  # Later we should consider making this behave like a real shell.
86
102
  proc = await asyncio.create_subprocess_shell(command)
87
-
88
- def _handler():
89
- logger.debug("SIGINT received.")
90
- proc.terminate()
91
-
92
- loop.add_signal_handler(signal.SIGINT, _handler)
93
-
94
103
  await proc.wait()
95
104
  except Exception as e:
96
105
  logger.exception("Failed to run shell command:")
97
106
  console.print(f"[red]Failed to run shell command: {e}[/red]")
98
107
  finally:
99
- loop.remove_signal_handler(signal.SIGINT)
108
+ remove_sigint()
100
109
 
101
110
  async def _run_meta_command(self, command_str: str):
102
111
  from kimi_cli.cli import Reload
@@ -137,7 +146,7 @@ class ShellApp:
137
146
  console.print(f"[red]Unknown error: {e}[/red]")
138
147
  raise # re-raise unknown error
139
148
 
140
- async def _run_soul_command(self, command: str) -> bool:
149
+ async def _run_soul_command(self, user_input: str | list[ContentPart]) -> bool:
141
150
  """
142
151
  Run the soul and handle any known exceptions.
143
152
 
@@ -151,13 +160,13 @@ class ShellApp:
151
160
  cancel_event.set()
152
161
 
153
162
  loop = asyncio.get_running_loop()
154
- loop.add_signal_handler(signal.SIGINT, _handler)
163
+ remove_sigint = install_sigint_handler(loop, _handler)
155
164
 
156
165
  try:
157
166
  # Use lambda to pass cancel_event via closure
158
167
  await run_soul(
159
168
  self.soul,
160
- command,
169
+ user_input,
161
170
  lambda wire: visualize(
162
171
  wire, initial_status=self.soul.status, cancel_event=cancel_event
163
172
  ),
@@ -167,6 +176,13 @@ class ShellApp:
167
176
  except LLMNotSet:
168
177
  logger.error("LLM not set")
169
178
  console.print("[red]LLM not set, send /setup to configure[/red]")
179
+ except LLMNotSupported as e:
180
+ logger.error(
181
+ "LLM model '{model_name}' does not support required capabilities: {capabilities}",
182
+ model_name=e.llm.model_name,
183
+ capabilities=", ".join(e.capabilities),
184
+ )
185
+ console.print(f"[red]{e}[/red]")
170
186
  except ChatProviderError as e:
171
187
  logger.exception("LLM provider error:")
172
188
  if isinstance(e, APIStatusError) and e.status_code == 401:
@@ -188,7 +204,7 @@ class ShellApp:
188
204
  console.print(f"[red]Unknown error: {e}[/red]")
189
205
  raise # re-raise unknown error
190
206
  finally:
191
- loop.remove_signal_handler(signal.SIGINT)
207
+ remove_sigint()
192
208
  return False
193
209
 
194
210
  async def _auto_update(self) -> None:
@@ -227,7 +243,19 @@ _LOGO = f"""\
227
243
  """
228
244
 
229
245
 
230
- def _print_welcome_info(name: str, model: str, info_items: dict[str, str]) -> None:
246
+ @dataclass(slots=True)
247
+ class WelcomeInfoItem:
248
+ class Level(Enum):
249
+ INFO = "grey50"
250
+ WARN = "yellow"
251
+ ERROR = "red"
252
+
253
+ name: str
254
+ value: str
255
+ level: Level = Level.INFO
256
+
257
+
258
+ def _print_welcome_info(name: str, info_items: list[WelcomeInfoItem]) -> None:
231
259
  head = Text.from_markup(f"[bold]Welcome to {name}![/bold]")
232
260
  help_text = Text.from_markup("[grey50]Send /help for help information.[/grey50]")
233
261
 
@@ -241,17 +269,8 @@ def _print_welcome_info(name: str, model: str, info_items: dict[str, str]) -> No
241
269
  rows: list[RenderableType] = [table]
242
270
 
243
271
  rows.append(Text("")) # Empty line
244
- rows.extend(
245
- Text.from_markup(f"[grey50]{key}: {value}[/grey50]") for key, value in info_items.items()
246
- )
247
- if model:
248
- rows.append(Text.from_markup(f"[grey50]Model: {model}[/grey50]"))
249
- else:
250
- rows.append(
251
- Text.from_markup(
252
- "[grey50]Model:[/grey50] [yellow]not set, send /setup to configure[/yellow]"
253
- )
254
- )
272
+ for item in info_items:
273
+ rows.append(Text(f"{item.name}: {item.value}", style=item.level.value))
255
274
 
256
275
  if LATEST_VERSION_FILE.exists():
257
276
  from kimi_cli.constant import VERSION as current_version
@@ -1,6 +1,5 @@
1
1
  import asyncio
2
2
  import sys
3
- import termios
4
3
  import threading
5
4
  import time
6
5
  from collections.abc import AsyncGenerator, Callable
@@ -47,6 +46,21 @@ def _listen_for_keyboard_thread(
47
46
  cancel: threading.Event,
48
47
  emit: Callable[[KeyEvent], None],
49
48
  ) -> None:
49
+ if sys.platform == "win32":
50
+ _listen_for_keyboard_windows(cancel, emit)
51
+ else:
52
+ _listen_for_keyboard_unix(cancel, emit)
53
+
54
+
55
+ def _listen_for_keyboard_unix(
56
+ cancel: threading.Event,
57
+ emit: Callable[[KeyEvent], None],
58
+ ) -> None:
59
+ if sys.platform == "win32":
60
+ raise RuntimeError("Unix keyboard listener requires a non-Windows platform")
61
+
62
+ import termios
63
+
50
64
  # make stdin raw and non-blocking
51
65
  fd = sys.stdin.fileno()
52
66
  oldterm = termios.tcgetattr(fd)
@@ -59,9 +73,9 @@ def _listen_for_keyboard_thread(
59
73
  try:
60
74
  while not cancel.is_set():
61
75
  try:
62
- c = sys.stdin.read(1)
76
+ c = sys.stdin.buffer.read(1)
63
77
  except (OSError, ValueError):
64
- c = ""
78
+ c = b""
65
79
 
66
80
  if not c:
67
81
  if cancel.is_set():
@@ -69,15 +83,15 @@ def _listen_for_keyboard_thread(
69
83
  time.sleep(0.01)
70
84
  continue
71
85
 
72
- if c == "\x1b":
86
+ if c == b"\x1b":
73
87
  sequence = c
74
88
  for _ in range(2):
75
89
  if cancel.is_set():
76
90
  break
77
91
  try:
78
- fragment = sys.stdin.read(1)
92
+ fragment = sys.stdin.buffer.read(1)
79
93
  except (OSError, ValueError):
80
- fragment = ""
94
+ fragment = b""
81
95
  if not fragment:
82
96
  break
83
97
  sequence += fragment
@@ -87,22 +101,76 @@ def _listen_for_keyboard_thread(
87
101
  event = _ARROW_KEY_MAP.get(sequence)
88
102
  if event is not None:
89
103
  emit(event)
90
- elif sequence == "\x1b":
104
+ elif sequence == b"\x1b":
91
105
  emit(KeyEvent.ESCAPE)
92
- elif c in ("\r", "\n"):
106
+ elif c in (b"\r", b"\n"):
93
107
  emit(KeyEvent.ENTER)
94
- elif c == "\t":
108
+ elif c == b"\t":
95
109
  emit(KeyEvent.TAB)
96
110
  finally:
97
111
  # restore the terminal settings
98
112
  termios.tcsetattr(fd, termios.TCSAFLUSH, oldterm)
99
113
 
100
114
 
101
- _ARROW_KEY_MAP: dict[str, KeyEvent] = {
102
- "\x1b[A": KeyEvent.UP,
103
- "\x1b[B": KeyEvent.DOWN,
104
- "\x1b[C": KeyEvent.RIGHT,
105
- "\x1b[D": KeyEvent.LEFT,
115
+ def _listen_for_keyboard_windows(
116
+ cancel: threading.Event,
117
+ emit: Callable[[KeyEvent], None],
118
+ ) -> None:
119
+ if sys.platform != "win32":
120
+ raise RuntimeError("Windows keyboard listener requires a Windows platform")
121
+
122
+ import msvcrt
123
+
124
+ while not cancel.is_set():
125
+ if msvcrt.kbhit():
126
+ c = msvcrt.getch()
127
+
128
+ # Handle special keys (arrow keys, etc.)
129
+ if c in (b"\x00", b"\xe0"):
130
+ # Extended key, read the next byte
131
+ extended = msvcrt.getch()
132
+ event = _WINDOWS_KEY_MAP.get(extended)
133
+ if event is not None:
134
+ emit(event)
135
+ elif c == b"\x1b":
136
+ sequence = c
137
+ for _ in range(2):
138
+ if cancel.is_set():
139
+ break
140
+ fragment = msvcrt.getch() if msvcrt.kbhit() else b""
141
+ if not fragment:
142
+ break
143
+ sequence += fragment
144
+ if sequence in _ARROW_KEY_MAP:
145
+ break
146
+
147
+ event = _ARROW_KEY_MAP.get(sequence)
148
+ if event is not None:
149
+ emit(event)
150
+ elif sequence == b"\x1b":
151
+ emit(KeyEvent.ESCAPE)
152
+ elif c in (b"\r", b"\n"):
153
+ emit(KeyEvent.ENTER)
154
+ elif c == b"\t":
155
+ emit(KeyEvent.TAB)
156
+ else:
157
+ if cancel.is_set():
158
+ break
159
+ time.sleep(0.01)
160
+
161
+
162
+ _ARROW_KEY_MAP: dict[bytes, KeyEvent] = {
163
+ b"\x1b[A": KeyEvent.UP,
164
+ b"\x1b[B": KeyEvent.DOWN,
165
+ b"\x1b[C": KeyEvent.RIGHT,
166
+ b"\x1b[D": KeyEvent.LEFT,
167
+ }
168
+
169
+ _WINDOWS_KEY_MAP: dict[bytes, KeyEvent] = {
170
+ b"H": KeyEvent.UP, # Up arrow
171
+ b"P": KeyEvent.DOWN, # Down arrow
172
+ b"M": KeyEvent.RIGHT, # Right arrow
173
+ b"K": KeyEvent.LEFT, # Left arrow
106
174
  }
107
175
 
108
176