bear-utils 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bear_utils/__init__.py +51 -0
- bear_utils/__main__.py +14 -0
- bear_utils/_internal/__init__.py +0 -0
- bear_utils/_internal/_version.py +1 -0
- bear_utils/_internal/cli.py +119 -0
- bear_utils/_internal/debug.py +174 -0
- bear_utils/ai/__init__.py +30 -0
- bear_utils/ai/ai_helpers/__init__.py +136 -0
- bear_utils/ai/ai_helpers/_common.py +19 -0
- bear_utils/ai/ai_helpers/_config.py +24 -0
- bear_utils/ai/ai_helpers/_parsers.py +194 -0
- bear_utils/ai/ai_helpers/_types.py +15 -0
- bear_utils/cache/__init__.py +131 -0
- bear_utils/cli/__init__.py +22 -0
- bear_utils/cli/_args.py +12 -0
- bear_utils/cli/_get_version.py +207 -0
- bear_utils/cli/commands.py +105 -0
- bear_utils/cli/prompt_helpers.py +186 -0
- bear_utils/cli/shell/__init__.py +1 -0
- bear_utils/cli/shell/_base_command.py +81 -0
- bear_utils/cli/shell/_base_shell.py +430 -0
- bear_utils/cli/shell/_common.py +19 -0
- bear_utils/cli/typer_bridge.py +90 -0
- bear_utils/config/__init__.py +13 -0
- bear_utils/config/config_manager.py +229 -0
- bear_utils/config/dir_manager.py +69 -0
- bear_utils/config/settings_manager.py +179 -0
- bear_utils/constants/__init__.py +90 -0
- bear_utils/constants/_exceptions.py +8 -0
- bear_utils/constants/_exit_code.py +60 -0
- bear_utils/constants/_http_status_code.py +37 -0
- bear_utils/constants/_lazy_typing.py +15 -0
- bear_utils/constants/_meta.py +196 -0
- bear_utils/constants/date_related.py +25 -0
- bear_utils/constants/time_related.py +24 -0
- bear_utils/database/__init__.py +8 -0
- bear_utils/database/_db_manager.py +98 -0
- bear_utils/events/__init__.py +18 -0
- bear_utils/events/events_class.py +52 -0
- bear_utils/events/events_module.py +74 -0
- bear_utils/extras/__init__.py +28 -0
- bear_utils/extras/_async_helpers.py +67 -0
- bear_utils/extras/_tools.py +185 -0
- bear_utils/extras/_zapper.py +399 -0
- bear_utils/extras/platform_utils.py +57 -0
- bear_utils/extras/responses/__init__.py +5 -0
- bear_utils/extras/responses/function_response.py +451 -0
- bear_utils/extras/wrappers/__init__.py +1 -0
- bear_utils/extras/wrappers/add_methods.py +100 -0
- bear_utils/extras/wrappers/string_io.py +46 -0
- bear_utils/files/__init__.py +6 -0
- bear_utils/files/file_handlers/__init__.py +5 -0
- bear_utils/files/file_handlers/_base_file_handler.py +107 -0
- bear_utils/files/file_handlers/file_handler_factory.py +280 -0
- bear_utils/files/file_handlers/json_file_handler.py +71 -0
- bear_utils/files/file_handlers/log_file_handler.py +40 -0
- bear_utils/files/file_handlers/toml_file_handler.py +76 -0
- bear_utils/files/file_handlers/txt_file_handler.py +76 -0
- bear_utils/files/file_handlers/yaml_file_handler.py +64 -0
- bear_utils/files/ignore_parser.py +293 -0
- bear_utils/graphics/__init__.py +6 -0
- bear_utils/graphics/bear_gradient.py +145 -0
- bear_utils/graphics/font/__init__.py +13 -0
- bear_utils/graphics/font/_raw_block_letters.py +463 -0
- bear_utils/graphics/font/_theme.py +31 -0
- bear_utils/graphics/font/_utils.py +220 -0
- bear_utils/graphics/font/block_font.py +192 -0
- bear_utils/graphics/font/glitch_font.py +63 -0
- bear_utils/graphics/image_helpers.py +45 -0
- bear_utils/gui/__init__.py +8 -0
- bear_utils/gui/gui_tools/__init__.py +10 -0
- bear_utils/gui/gui_tools/_settings.py +36 -0
- bear_utils/gui/gui_tools/_types.py +12 -0
- bear_utils/gui/gui_tools/qt_app.py +150 -0
- bear_utils/gui/gui_tools/qt_color_picker.py +130 -0
- bear_utils/gui/gui_tools/qt_file_handler.py +130 -0
- bear_utils/gui/gui_tools/qt_input_dialog.py +303 -0
- bear_utils/logger_manager/__init__.py +109 -0
- bear_utils/logger_manager/_common.py +63 -0
- bear_utils/logger_manager/_console_junk.py +135 -0
- bear_utils/logger_manager/_log_level.py +50 -0
- bear_utils/logger_manager/_styles.py +95 -0
- bear_utils/logger_manager/logger_protocol.py +42 -0
- bear_utils/logger_manager/loggers/__init__.py +1 -0
- bear_utils/logger_manager/loggers/_console.py +223 -0
- bear_utils/logger_manager/loggers/_level_sin.py +61 -0
- bear_utils/logger_manager/loggers/_logger.py +19 -0
- bear_utils/logger_manager/loggers/base_logger.py +244 -0
- bear_utils/logger_manager/loggers/base_logger.pyi +51 -0
- bear_utils/logger_manager/loggers/basic_logger/__init__.py +5 -0
- bear_utils/logger_manager/loggers/basic_logger/logger.py +80 -0
- bear_utils/logger_manager/loggers/basic_logger/logger.pyi +19 -0
- bear_utils/logger_manager/loggers/buffer_logger.py +57 -0
- bear_utils/logger_manager/loggers/console_logger.py +278 -0
- bear_utils/logger_manager/loggers/console_logger.pyi +50 -0
- bear_utils/logger_manager/loggers/fastapi_logger.py +333 -0
- bear_utils/logger_manager/loggers/file_logger.py +151 -0
- bear_utils/logger_manager/loggers/simple_logger.py +98 -0
- bear_utils/logger_manager/loggers/sub_logger.py +105 -0
- bear_utils/logger_manager/loggers/sub_logger.pyi +23 -0
- bear_utils/monitoring/__init__.py +13 -0
- bear_utils/monitoring/_common.py +28 -0
- bear_utils/monitoring/host_monitor.py +346 -0
- bear_utils/time/__init__.py +59 -0
- bear_utils-0.0.1.dist-info/METADATA +305 -0
- bear_utils-0.0.1.dist-info/RECORD +107 -0
- bear_utils-0.0.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,186 @@
|
|
1
|
+
"""Prompt Helpers Module for user input handling."""
|
2
|
+
|
3
|
+
from typing import Any, overload
|
4
|
+
|
5
|
+
from prompt_toolkit import prompt
|
6
|
+
from prompt_toolkit.completion import WordCompleter
|
7
|
+
from prompt_toolkit.validation import ValidationError, Validator
|
8
|
+
|
9
|
+
from bear_utils.constants._exceptions import UserCancelledError
|
10
|
+
from bear_utils.constants._lazy_typing import OptBool, OptFloat, OptInt, OptStr
|
11
|
+
from bear_utils.logger_manager import get_console
|
12
|
+
|
13
|
+
|
14
|
+
def _parse_exit(value: str) -> bool:
|
15
|
+
"""Parse a string into a boolean indicating if the user wants to exit."""
|
16
|
+
lower_value: str = value.lower().strip()
|
17
|
+
return lower_value in ("exit", "quit", "q")
|
18
|
+
|
19
|
+
|
20
|
+
def _parse_bool(value: str) -> bool:
|
21
|
+
"""Parse a string into a boolean value."""
|
22
|
+
lower_value: str = value.lower().strip()
|
23
|
+
if lower_value in ("true", "t", "yes", "y", "1"):
|
24
|
+
return True
|
25
|
+
if lower_value in ("false", "f", "no", "n", "0"):
|
26
|
+
return False
|
27
|
+
raise ValueError(f"Cannot convert '{value}' to boolean")
|
28
|
+
|
29
|
+
|
30
|
+
def _convert_value(value: str, target_type: type) -> str | int | float | bool:
|
31
|
+
"""Convert a string value to the target type."""
|
32
|
+
if target_type is str:
|
33
|
+
return value
|
34
|
+
if target_type is int:
|
35
|
+
return int(value)
|
36
|
+
if target_type is float:
|
37
|
+
return float(value)
|
38
|
+
if target_type is bool:
|
39
|
+
return _parse_bool(value)
|
40
|
+
raise ValueError(f"Unsupported type: {target_type}")
|
41
|
+
|
42
|
+
|
43
|
+
# Typed overloads so static checkers infer the return type from ``expected_type``.
# ``bool`` comes before ``int`` because ``type[bool]`` would also match ``type[int]``.
@overload
def ask_question(
    question: str,
    expected_type: type[bool],
    default: OptBool = None,
) -> bool: ...


@overload
def ask_question(
    question: str,
    expected_type: type[int],
    default: OptInt = None,
) -> int: ...


@overload
def ask_question(
    question: str,
    expected_type: type[float],
    default: OptFloat = None,
) -> float: ...


@overload
def ask_question(
    question: str,
    expected_type: type[str],
    default: OptStr = None,
) -> str: ...
|
57
|
+
|
58
|
+
|
59
|
+
def ask_question(question: str, expected_type: type, default: Any = None) -> Any:
    """Ask a question and return the answer, ensuring the entered type is correct.

    This function keeps asking until it gets a valid response or the user
    cancels with Ctrl+C. If the user cancels, a UserCancelledError is raised.

    Args:
        question: The prompt question to display
        expected_type: The expected type class (int, float, str, bool)
        default: Value returned when the user submits empty input; when
            ``None``, an answer is required

    Returns:
        The user's response converted to ``expected_type`` (or ``default``)

    Raises:
        UserCancelledError: If the user cancels input with Ctrl+C
        ValueError: If an unsupported type is specified
    """
    # Validate before prompting: _convert_value raises ValueError for an
    # unsupported type, and the retry loop below treats ValueError as bad user
    # input, so without this check the caller's mistake would re-prompt forever.
    if expected_type not in (str, int, float, bool):
        raise ValueError(f"Unsupported type: {expected_type}")

    console, _ = get_console("prompt_helpers.py")

    try:
        while True:
            console.print(question)
            response = prompt("> ").strip()

            if not response:
                if default is not None:
                    return default
                console.error("Input required. Please enter a value.")
                continue

            try:
                result: str | int | float | bool = _convert_value(response, expected_type)
            except ValueError as e:
                console.error(f"Invalid input: {e}. Please enter a valid {expected_type.__name__}.")
                continue

            console.verbose(f"{expected_type.__name__} detected")
            return result
    except KeyboardInterrupt:
        raise UserCancelledError("User cancelled input") from None
|
98
|
+
|
99
|
+
|
100
|
+
def ask_yes_no(question: str, default: bool | None = None) -> bool | None:
    """Prompt until the user gives a yes/no answer.

    Empty input returns *default* when one is given; otherwise the user is
    asked again. Typing exit/quit/q, or pressing Ctrl+C, ends the prompt.

    Args:
        question: Text shown above the input line
        default: Value returned for empty input; ``None`` means an answer is required

    Returns:
        True for yes, False for no, or None if the user exits
    """
    console, _ = get_console("prompt_helpers.py")

    while True:
        try:
            answer: str = prompt(f"{question}\n> ").strip().lower()
            if answer:
                if _parse_exit(answer):
                    return None
                try:
                    return _parse_bool(answer)
                except ValueError:
                    console.print("Invalid input. Please enter 'yes', 'no', or 'exit'.", style="red")
            elif default is not None:
                return default
            else:
                console.print("Please enter 'yes', 'no', or 'exit'.")
        except KeyboardInterrupt:
            console.print("KeyboardInterrupt: Exiting the prompt.", style="yellow")
            return None
|
129
|
+
|
130
|
+
|
131
|
+
def restricted_prompt(
    question: str,
    valid_options: list[str],
    exit_command: str = "exit",
    case_sensitive: bool = False,
) -> str | None:
    """Continuously prompt the user until they provide a valid response or exit.

    Surrounding whitespace in the input is ignored, both by the live validator
    and by the final match. When matching is case-insensitive, the option is
    returned in the spelling given in ``valid_options`` (the first one if
    several differ only by case). Completion follows the same case rule.

    Args:
        question: The prompt question to display
        valid_options: List of valid responses
        exit_command: Command to exit the prompt (default: "exit")
        case_sensitive: Whether options are case-sensitive (default: False)

    Returns:
        The matching entry of ``valid_options``, or None if the user typed
        ``exit_command`` or pressed Ctrl+C
    """
    console, _ = get_console("prompt_helpers.py")

    def normalize(text: str) -> str:
        # One normalisation shared by options, exit command, validator and final
        # lookup, so input the validator accepts always matches afterwards.
        text = text.strip()
        return text if case_sensitive else text.lower()

    # normalised spelling -> original option; setdefault keeps the first entry
    # when options collide after normalisation.
    lookup: dict[str, str] = {}
    for option in valid_options:
        lookup.setdefault(normalize(option), option)
    normalized_exit: str = normalize(exit_command)

    completer = WordCompleter([*valid_options, exit_command], ignore_case=not case_sensitive)

    class OptionValidator(Validator):
        def validate(self, document: Any) -> None:
            """Reject non-empty input that is neither a valid option nor the exit command."""
            text: str = normalize(document.text)
            if text and text != normalized_exit and text not in lookup:
                raise ValidationError(
                    message=f"Invalid option. Choose from: {', '.join(valid_options)} or '{exit_command}'",
                    cursor_position=len(document.text),
                )

    try:
        while True:
            response: str = prompt(
                f"{question}\n> ",
                completer=completer,
                validator=OptionValidator(),
                complete_while_typing=True,
            ).strip()
            if not response:
                console.print("Please enter a valid option or 'exit'.", style="red")
                continue
            key: str = normalize(response)
            if key == normalized_exit:
                return None
            if key in lookup:
                return lookup[key]
    except KeyboardInterrupt:
        console.print("KeyboardInterrupt: Exiting the prompt.", style="yellow")
        return None
|
@@ -0,0 +1 @@
|
|
1
|
+
"""Shell utilities for bear_utils CLI."""
|
@@ -0,0 +1,81 @@
|
|
1
|
+
from subprocess import CompletedProcess
|
2
|
+
from typing import Any, ClassVar, Self
|
3
|
+
|
4
|
+
|
5
|
+
class BaseShellCommand[T: str]:
|
6
|
+
"""Base class for typed shell commands compatible with session systems"""
|
7
|
+
|
8
|
+
command_name: ClassVar[str] = ""
|
9
|
+
|
10
|
+
def __init__(self, *args, **kwargs) -> None:
|
11
|
+
self.sub_command: str = kwargs.get("sub_command", "")
|
12
|
+
self.args = args
|
13
|
+
self.kwargs: dict[str, Any] = kwargs
|
14
|
+
self.suffix = kwargs.get("suffix", "")
|
15
|
+
self.result: CompletedProcess[str] | None = None
|
16
|
+
|
17
|
+
def __str__(self) -> str:
|
18
|
+
"""String representation of the command"""
|
19
|
+
return self.cmd
|
20
|
+
|
21
|
+
def value(self, v: str) -> Self:
|
22
|
+
"""Add value to the export command"""
|
23
|
+
self.suffix: str = v
|
24
|
+
return self
|
25
|
+
|
26
|
+
@classmethod
|
27
|
+
def adhoc(cls, name: T, *args, **kwargs) -> "BaseShellCommand[T]":
|
28
|
+
"""Create an ad-hoc command class for a specific command
|
29
|
+
|
30
|
+
Args:
|
31
|
+
name (str): The name of the command to create
|
32
|
+
|
33
|
+
Returns:
|
34
|
+
BaseShellCommand: An instance of the ad-hoc command class.
|
35
|
+
"""
|
36
|
+
return type(
|
37
|
+
f"AdHoc{name.title()}Command",
|
38
|
+
(cls,),
|
39
|
+
{"command_name": name},
|
40
|
+
)(*args, **kwargs)
|
41
|
+
|
42
|
+
@classmethod
|
43
|
+
def sub(cls, s: str, *args, **kwargs) -> Self:
|
44
|
+
"""Set a sub-command for the shell command"""
|
45
|
+
return cls(s, *args, **kwargs)
|
46
|
+
|
47
|
+
@property
|
48
|
+
def cmd(self) -> str:
|
49
|
+
"""Return the full command as a string"""
|
50
|
+
cmd_parts: list[str] = [self.command_name, self.sub_command, *self.args]
|
51
|
+
cmd_parts: list[str] = [part for part in cmd_parts if part]
|
52
|
+
joined: str = " ".join(cmd_parts).strip()
|
53
|
+
if self.suffix:
|
54
|
+
return f"{joined} {self.suffix}"
|
55
|
+
return joined
|
56
|
+
|
57
|
+
def do(self, **kwargs) -> Self:
|
58
|
+
"""Run the command using subprocess"""
|
59
|
+
from ._base_shell import shell_session # noqa: PLC0415
|
60
|
+
|
61
|
+
with shell_session(**kwargs) as session:
|
62
|
+
result: CompletedProcess[str] = session.add(self.cmd).run()
|
63
|
+
if result is not None:
|
64
|
+
self.result = result
|
65
|
+
return self
|
66
|
+
|
67
|
+
def get_result(self) -> CompletedProcess[str]:
|
68
|
+
"""Get the result of the command execution"""
|
69
|
+
if self.result is None:
|
70
|
+
self.do()
|
71
|
+
if self.result is None:
|
72
|
+
raise RuntimeError("Command execution failed for some reason.")
|
73
|
+
return self.result
|
74
|
+
|
75
|
+
def get(self) -> str:
|
76
|
+
"""Get the result of the command execution"""
|
77
|
+
if self.result is None:
|
78
|
+
self.do()
|
79
|
+
if self.result is None:
|
80
|
+
raise RuntimeError("Command execution failed for some reason.")
|
81
|
+
return str(self.result.stdout).strip()
|
@@ -0,0 +1,430 @@
|
|
1
|
+
import asyncio
|
2
|
+
from asyncio.streams import StreamReader
|
3
|
+
from asyncio.subprocess import Process
|
4
|
+
from collections import deque
|
5
|
+
from collections.abc import AsyncGenerator, Callable, Generator
|
6
|
+
from contextlib import asynccontextmanager, contextmanager
|
7
|
+
from io import StringIO
|
8
|
+
from logging import INFO
|
9
|
+
import os
|
10
|
+
from pathlib import Path
|
11
|
+
import shlex
|
12
|
+
import subprocess
|
13
|
+
from subprocess import CompletedProcess
|
14
|
+
from typing import Self, override
|
15
|
+
|
16
|
+
from bear_utils.constants import ExitCode
|
17
|
+
from bear_utils.logger_manager import VERBOSE, BaseLogger, SubConsoleLogger
|
18
|
+
from bear_utils.logger_manager.logger_protocol import LoggerProtocol
|
19
|
+
|
20
|
+
from ._base_command import BaseShellCommand
|
21
|
+
from ._common import DEFAULT_SHELL
|
22
|
+
|
23
|
+
|
24
|
+
class FancyCompletedProcess(CompletedProcess[str]):
|
25
|
+
def __init__(self, args: list[str], returncode: int, stdout: str | None = None, stderr: str | None = None) -> None:
|
26
|
+
"""Initialize with custom attributes for better readability"""
|
27
|
+
super().__init__(args=args, returncode=returncode, stdout=stdout, stderr=stderr)
|
28
|
+
|
29
|
+
def __repr__(self) -> str:
|
30
|
+
"""Custom representation for better readability"""
|
31
|
+
args: list[str] = [
|
32
|
+
f"args={self.args!r}",
|
33
|
+
f"returncode={self.returncode!r}",
|
34
|
+
f"exit_message={self.exit_message!r}",
|
35
|
+
f"stdout={self.stdout!r}" if self.stdout is not None else "",
|
36
|
+
f"stderr={self.stderr!r}" if self.stderr is not None else "",
|
37
|
+
]
|
38
|
+
return f"{type(self).__name__}({', '.join(filter(None, args))})"
|
39
|
+
|
40
|
+
@property
|
41
|
+
def exit_message(self) -> str:
|
42
|
+
"""Get a human-readable message for the exit code"""
|
43
|
+
return ExitCode.from_int(self.returncode).text
|
44
|
+
|
45
|
+
|
46
|
+
class CommandList(deque[CompletedProcess[str]]):
|
47
|
+
"""A list to hold previous commands with their timestamps and results"""
|
48
|
+
|
49
|
+
def __init__(self, maxlen: int = 10, *args, **kwargs) -> None:
|
50
|
+
super().__init__(maxlen=maxlen, *args, **kwargs) # noqa: B026
|
51
|
+
|
52
|
+
def add(self, command: CompletedProcess[str]) -> None:
|
53
|
+
"""Add a command to the list"""
|
54
|
+
self.append(command)
|
55
|
+
|
56
|
+
def get(self, index: int) -> CompletedProcess[str] | None:
|
57
|
+
"""Get a command by index"""
|
58
|
+
return self[index] if 0 <= index < len(self) else None
|
59
|
+
|
60
|
+
def get_most_recent(self) -> CompletedProcess[str] | None:
|
61
|
+
"""Get the most recent command"""
|
62
|
+
return self[-1] if self else None
|
63
|
+
|
64
|
+
|
65
|
+
class SimpleShellSession:
    """Simple shell session using subprocess with command chaining.

    Commands accumulate in ``cmd_buffer`` via ``add``/``amp``/``piped`` and are
    executed synchronously by ``run``. Before each execution the previous
    result is pushed onto ``previous_commands``.
    """

    def __init__(
        self,
        env: dict | None = None,
        cwd: Path | str | None = None,
        shell: str = DEFAULT_SHELL,
        logger: LoggerProtocol | BaseLogger | None = None,
        verbose: bool = False,
        use_shell: bool = True,
    ) -> None:
        """Create a session.

        Args:
            env: Environment for executed commands (default: a copy of ``os.environ``).
            cwd: Working directory for commands (default: the current directory).
            shell: Shell that interprets the buffered chain when ``use_shell`` is True.
            logger: Logger to use; a ``shell_session`` sub-logger is created if omitted.
            verbose: Use VERBOSE instead of INFO level (only affects the default logger).
            use_shell: Run through a shell (True), or split with ``shlex`` and exec directly (False).
        """
        self.shell: str = shell
        self.cwd: Path = Path.cwd() if cwd is None else Path(cwd)
        # Copy a caller-supplied mapping so add_to_env() never mutates the caller's dict.
        self.env: dict[str, str] = os.environ.copy() if env is None else dict(env)
        self.cmd_buffer: StringIO = StringIO()
        self.previous_commands: CommandList = CommandList()
        self.result: CompletedProcess[str] | None = None
        self.verbose: bool = verbose
        self.use_shell: bool = use_shell
        # set_logger reads self.verbose, so it must be assigned first.
        self.logger: LoggerProtocol | BaseLogger | SubConsoleLogger[BaseLogger] = self.set_logger(logger)

    def set_logger(
        self, passed_logger: LoggerProtocol | BaseLogger | None = None
    ) -> LoggerProtocol | BaseLogger | SubConsoleLogger[BaseLogger]:
        """Return ``passed_logger`` if given, else a ``shell_session`` sub-logger of the BaseLogger instance."""
        if passed_logger is not None:
            return passed_logger

        base: BaseLogger = (
            BaseLogger.get_instance() if BaseLogger.has_instance() else BaseLogger.get_instance(init=True)
        )
        logger: SubConsoleLogger[BaseLogger] = base.get_sub_logger(namespace="shell_session")
        logger.set_sub_level(VERBOSE if self.verbose else INFO)
        return logger

    def add_to_env(self, env: dict[str, str], key: str | None = None, value: str | None = None) -> Self:
        """Merge variables into the session environment; returns self for chaining.

        Args:
            env: Mapping of variables to merge. If a ``str`` is passed instead,
                it is ignored and ``key``/``value`` are used (both must be given).
            key: Variable name, used only when ``env`` is a ``str``.
            value: Variable value, used only when ``env`` is a ``str``.
        """
        if isinstance(env, dict):
            self.env.update(env)
        elif isinstance(env, str) and key is not None and value is not None:
            # NOTE(review): the string passed as ``env`` itself is unused — confirm intent.
            self.env[key] = value
        return self

    def add(self, c: str | BaseShellCommand) -> Self:
        """Add a command to the current session, return self for chaining"""
        self.cmd_buffer.write(str(c))
        return self

    def amp(self, c: str | BaseShellCommand) -> Self:
        """Combine a command with the current session: &&, return self for chaining"""
        if self.empty_history:
            raise ValueError("No command to combine with")
        self.cmd_buffer.write(" && ")
        self.cmd_buffer.write(str(c))
        return self

    def piped(self, c: str | BaseShellCommand) -> Self:
        """Combine a command with the current session: |, return self for chaining"""
        if self.empty_history:
            raise ValueError("No command to pipe from")
        self.cmd_buffer.write(" | ")
        self.cmd_buffer.write(str(c))
        return self

    def _run(self, command: str) -> CompletedProcess[str]:
        """Execute ``command`` synchronously, then store and return the result.

        The buffer is reset even if launching fails (e.g. missing executable,
        bad cwd, unbalanced quotes for ``shlex``). Otherwise the stale buffer
        would make the next ``run(cmd)`` raise the "use `amp`" error.
        """
        self.logger.debug(f"Executing: {command}")
        self.next_cmd()

        try:
            # With use_shell the string goes to the system shell; otherwise it
            # is tokenised with shlex and executed directly.
            args: str | list[str] = command if self.use_shell else shlex.split(command)
            self.result = subprocess.run(
                check=False,
                args=args,
                shell=self.use_shell,
                cwd=self.cwd,
                env=self.env,
                capture_output=True,
                text=True,
            )
        finally:
            self.reset_buffer()

        if self.result.returncode != 0:
            self.logger.error(f"Command failed with return code {self.result.returncode} {self.result.stderr.strip()}")
        return self.result

    def run(self, cmd: str | BaseShellCommand | None = None, *args) -> CompletedProcess[str]:
        """Run the buffered command chain, or ``cmd`` (plus ``args``) when nothing is buffered.

        Raises:
            ValueError: If there is nothing to run, or if ``cmd`` is given while
                commands are already buffered (use ``amp`` to chain instead).
        """
        if self.empty_history and cmd is None:
            raise ValueError("No commands to run")

        if self.has_history and cmd is not None:
            raise ValueError(
                "If you want to add a command to a chain, use `amp` instead of `run`, `run` is for executing the full command history"
            )

        # Past the guards, either the buffer holds a chain (cmd is None) or it
        # is empty and cmd was supplied.
        if cmd is not None:
            self.cmd_buffer.write(f"{cmd} ")
            if args:
                self.cmd_buffer.write(" ".join(map(str, args)))
        result: CompletedProcess[str] = self._run(self.cmd)
        self.reset_buffer()
        return result

    @property
    def empty_history(self) -> bool:
        """True when the command buffer is empty."""
        return not self.cmd_buffer.getvalue()

    @property
    def has_history(self) -> bool:
        """True when the command buffer holds at least one command."""
        return not self.empty_history

    @property
    def cmd(self) -> str:
        """Return the buffered command as a string ready for ``_run``.

        Raises:
            ValueError: If the buffer is empty.
        """
        # A StringIO object is always truthy, so test its contents instead.
        if self.empty_history:
            raise ValueError("No commands have been run yet")

        buffered: str = self.cmd_buffer.getvalue()
        if self.use_shell:
            # Wrap the whole chain so it is interpreted by self.shell.
            return f"{self.shell} -c {shlex.quote(buffered)}"
        # Raw command line, later split with shlex.
        return buffered.strip()

    @property
    def returncode(self) -> bool:
        """Return True if the last command exited with status 0.

        Despite its name this is a success flag, not the numeric exit code.
        """
        if self.result is None:
            raise ValueError("No command has been run yet")
        return self.result.returncode == 0

    @property
    def stdout(self) -> str:
        """Return the standard output of the last command"""
        if self.result is None:
            raise ValueError("No command has been run yet")
        return self.result.stdout.strip() if self.result.stdout is not None else "None"

    @property
    def stderr(self) -> str:
        """Return the standard error of the last command"""
        if self.result is None:
            raise ValueError("No command has been run yet")
        return self.result.stderr.strip() if self.result.stderr is not None else "None"

    @property
    def pretty_result(self) -> str:
        """Return a formatted multi-line summary of the last result (missing output shown as empty)."""
        if self.result is None:
            raise ValueError("No command has been run yet")
        out_text: str = (self.result.stdout or "").strip()
        err_text: str = (self.result.stderr or "").strip()
        return (
            f"Command: {self.result.args}\n"
            f"Return Code: {self.result.returncode}\n"
            f"Standard Output: {out_text}\n"
            f"Standard Error: {err_text}\n"
        )

    def reset_buffer(self) -> None:
        """Empty the command buffer."""
        self.cmd_buffer.seek(0)
        self.cmd_buffer.truncate(0)

    def reset(self) -> None:
        """Clear the command history and the last result (the buffer is left untouched)."""
        self.previous_commands.clear()
        self.result = None

    def next_cmd(self) -> None:
        """Move the current result (if any) into the history before running a new command."""
        if self.result is not None:
            self.previous_commands.add(command=self.result)
            self.result = None

    def get_cmd(self, index: int | None = None) -> CompletedProcess[str] | None:
        """Get a previous command by index or the most recent one if index is None"""
        if index is None:
            return self.previous_commands.get_most_recent()
        return self.previous_commands.get(index)

    def __enter__(self) -> Self:
        """Enter the context manager"""
        return self

    def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
        """Exit the context manager, clearing history, result and buffer."""
        self.reset()
        self.reset_buffer()
|
280
|
+
|
281
|
+
|
282
|
+
class AsyncShellSession(SimpleShellSession):
    """Shell session backed by an asyncio subprocess.

    ``run`` starts the process without waiting for it. Use ``stream_stdout`` /
    ``stream_stderr`` to consume output incrementally, or ``communicate`` to
    wait for completion and build a ``FancyCompletedProcess``.
    """

    def __init__(
        self,
        env: dict[str, str] | None = None,
        cwd: Path | str | None = None,
        shell: str = DEFAULT_SHELL,
        logger: LoggerProtocol | BaseLogger | None = None,
        verbose: bool = False,
        use_shell: bool = True,
    ) -> None:
        super().__init__(env=env, cwd=cwd, shell=shell, logger=logger, verbose=verbose, use_shell=use_shell)
        self.process: Process | None = None
        # Invoked by communicate() with the result; cleared by after_process().
        self._callbacks: list[Callable[[CompletedProcess], None]] = []

    @override
    async def _run(self, command: str, **kwargs) -> Process:  # type: ignore[override]
        """Start ``command`` as an asyncio subprocess with stdout/stderr piped.

        Does not wait for completion. Extra ``kwargs`` are forwarded to
        ``asyncio.create_subprocess_shell`` / ``create_subprocess_exec``.
        """
        self.logger.debug(f"Executing: {command}")
        self.next_cmd()

        if self.use_shell:
            self.process = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=self.cwd,
                env=self.env,
                **kwargs,
            )
        else:
            command_args: list[str] = shlex.split(command)
            self.process = await asyncio.create_subprocess_exec(
                *command_args,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=self.cwd,
                env=self.env,
                **kwargs,
            )
        return self.process

    @override
    async def run(self, cmd: str | BaseShellCommand | None = None, *args, **kwargs) -> Process:  # type: ignore[override]
        """Start the buffered chain, or ``cmd`` plus ``args`` when nothing is buffered.

        Returns the running ``Process`` for streaming. ``kwargs`` are forwarded to ``_run``.

        Raises:
            ValueError: If there is nothing to run, or ``cmd`` is given while
                commands are already buffered (use ``amp`` to chain instead).
        """
        if self.empty_history and cmd is None:
            raise ValueError("No commands to run")

        if self.has_history and cmd is not None:
            raise ValueError("Use `amp` to chain commands, not `run`")

        if cmd is not None:
            # Space-separate cmd and args so run("echo", "hi") yields "echo hi".
            self.cmd_buffer.write(" ".join(map(str, (cmd, *args))))
        process: Process = await self._run(self.cmd, **kwargs)
        return process

    async def communicate(self, stdin: str = "") -> CompletedProcess[str]:
        """Send ``stdin``, wait for the process to finish, and return its result.

        Registered callbacks are invoked with the result, then ``after_process``
        clears the process, callbacks and buffer.
        """
        if self.process is None:
            raise ValueError("No process has been started yet")
        # NOTE: per the asyncio docs, input is only delivered if the process was
        # started with stdin=asyncio.subprocess.PIPE (pass it through run()).
        bytes_stdin: bytes = stdin.encode("utf-8") if isinstance(stdin, str) else stdin

        raw_out, raw_err = await self.process.communicate(input=bytes_stdin)
        return_code: int = await self.process.wait()
        # errors="replace" so undecodable output doesn't discard the whole result.
        out_text: str = raw_out.decode("utf-8", errors="replace") if raw_out else ""
        err_text: str = raw_err.decode("utf-8", errors="replace") if raw_err else ""

        self.result = FancyCompletedProcess(
            args=self.cmd,  # type: ignore FIXME: should be a list[str] not str?
            returncode=return_code,
            stdout=out_text,
            stderr=err_text,
        )
        if return_code != 0:
            self.logger.error(f"Command failed with return code {return_code} {err_text.strip()}")
        for callback in self._callbacks:
            callback(self.result)
        await self.after_process()
        return self.result

    @staticmethod
    async def read_stream(stream: StreamReader) -> AsyncGenerator[str]:
        """Yield decoded lines from ``stream`` (trailing newline removed) until EOF."""
        while True:
            try:
                line: bytes = await stream.readline()
            except Exception:
                # Best-effort streaming: stop on read errors (e.g. the ValueError
                # StreamReader raises when a line exceeds its buffer limit).
                break
            if not line:  # EOF
                break
            # errors="replace" so a non-UTF-8 byte doesn't end the stream early.
            yield line.decode("utf-8", errors="replace").rstrip("\n")

    async def stream_stdout(self) -> AsyncGenerator[str]:
        """Stream output line by line as it comes"""
        if self.process is None:
            raise ValueError("No process has been started yet")
        if not self.process.stdout:
            raise ValueError("Process has no stdout")

        async for line in self.read_stream(self.process.stdout):
            yield line

    async def stream_stderr(self) -> AsyncGenerator[str]:
        """Stream error output line by line as it comes"""
        if self.process is None:
            raise ValueError("No process has been started yet")
        if not self.process.stderr:
            raise ValueError("Process has no stderr")
        async for line in self.read_stream(self.process.stderr):
            yield line

    async def after_process(self) -> None:
        """Drop the finished process, clear callbacks and reset the buffer (override to customise)."""
        self.process = None
        self._callbacks.clear()
        self.reset_buffer()

    def on_completion(self, callback: Callable[[CompletedProcess[str]], None]) -> None:
        """Register ``callback`` to receive the result on the next ``communicate``."""
        self._callbacks.append(callback)

    @property
    def is_running(self) -> bool:
        """True if a process has been started and has not yet reported a return code."""
        return self.process is not None and self.process.returncode is None
|
411
|
+
|
412
|
+
|
413
|
+
@contextmanager
def shell_session(shell: str = DEFAULT_SHELL, **kwargs) -> Generator[SimpleShellSession]:
    """Yield a new ``SimpleShellSession``; nothing is cleaned up when the block exits.

    Args:
        shell: Shell handed to the session.
        **kwargs: Forwarded to ``SimpleShellSession``.
    """
    yield SimpleShellSession(shell=shell, **kwargs)
|
421
|
+
|
422
|
+
|
423
|
+
@asynccontextmanager
async def async_shell_session(shell: str = DEFAULT_SHELL, **kwargs) -> AsyncGenerator[AsyncShellSession]:
    """Yield a new ``AsyncShellSession``; nothing is cleaned up when the block exits.

    Args:
        shell: Shell handed to the session.
        **kwargs: Forwarded to ``AsyncShellSession``.
    """
    yield AsyncShellSession(shell=shell, **kwargs)
|