ptn 0.3.2__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. porterminal/_version.py +2 -2
  2. porterminal/application/services/management_service.py +28 -52
  3. porterminal/application/services/session_service.py +3 -11
  4. porterminal/application/services/terminal_service.py +97 -56
  5. porterminal/cli/args.py +39 -31
  6. porterminal/cli/display.py +18 -16
  7. porterminal/cli/script_discovery.py +112 -0
  8. porterminal/composition.py +2 -7
  9. porterminal/config.py +4 -2
  10. porterminal/domain/__init__.py +0 -9
  11. porterminal/domain/entities/output_buffer.py +56 -1
  12. porterminal/domain/entities/tab.py +11 -10
  13. porterminal/domain/services/__init__.py +0 -2
  14. porterminal/domain/values/__init__.py +0 -4
  15. porterminal/domain/values/environment_rules.py +3 -0
  16. porterminal/infrastructure/cloudflared.py +13 -11
  17. porterminal/infrastructure/config/shell_detector.py +86 -20
  18. porterminal/infrastructure/repositories/in_memory_session.py +1 -4
  19. porterminal/infrastructure/repositories/in_memory_tab.py +2 -10
  20. porterminal/pty/env.py +16 -78
  21. porterminal/pty/manager.py +6 -4
  22. porterminal/static/assets/app-DlWNJWFE.js +87 -0
  23. porterminal/static/assets/app-xPAM7YhQ.css +1 -0
  24. porterminal/static/index.html +2 -2
  25. porterminal/updater.py +13 -5
  26. {ptn-0.3.2.dist-info → ptn-0.4.2.dist-info}/METADATA +54 -16
  27. {ptn-0.3.2.dist-info → ptn-0.4.2.dist-info}/RECORD +30 -35
  28. porterminal/static/assets/app-BkHv5qu0.css +0 -32
  29. porterminal/static/assets/app-CaIGfw7i.js +0 -72
  30. porterminal/static/assets/app-D9ELFbEO.js +0 -72
  31. porterminal/static/assets/app-DF3nl_io.js +0 -72
  32. porterminal/static/assets/app-DQePboVd.css +0 -32
  33. porterminal/static/assets/app-DoBiVkTD.js +0 -72
  34. porterminal/static/assets/app-azbHOsRw.css +0 -32
  35. porterminal/static/assets/app-nMNFwMa6.css +0 -32
  36. {ptn-0.3.2.dist-info → ptn-0.4.2.dist-info}/WHEEL +0 -0
  37. {ptn-0.3.2.dist-info → ptn-0.4.2.dist-info}/entry_points.txt +0 -0
  38. {ptn-0.3.2.dist-info → ptn-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,112 @@
1
+ """Auto-discover project scripts for config initialization."""
2
+
3
+ import json
4
+ import re
5
+ import tomllib
6
+ from pathlib import Path
7
+
8
# Pattern for safe script names (alphanumeric, hyphens, underscores only).
# Always apply it with ``fullmatch``: with ``re.match`` the ``$`` anchor also
# matches just before a trailing newline, so "build\n" would be accepted.
_SAFE_NAME = re.compile(r"^[a-zA-Z0-9_-]+$")


def _is_safe_name(name: str) -> bool:
    """Check if script name contains only safe characters.

    A name is safe when it is 1-50 characters long and consists solely of
    ASCII letters, digits, hyphens and underscores, so it can be typed into
    a shell verbatim.

    Args:
        name: Candidate script name (e.g. a key from ``[project.scripts]``).

    Returns:
        True if the name is safe to send to the terminal, False otherwise.
    """
    return len(name) <= 50 and _SAFE_NAME.fullmatch(name) is not None
+
16
+
17
def discover_scripts(cwd: Path | None = None) -> list[dict]:
    """Collect runnable project scripts as toolbar button configs.

    Looks in ``cwd`` (defaulting to the process working directory) for
    ``package.json``, ``pyproject.toml`` and ``Makefile``. It turns the
    scripts/targets each one explicitly defines into button dicts such as
    ``{"label": "build", "send": "npm run build\\r", "row": 2}``.

    When several sources yield the same label, the first one wins, checked
    in the order npm, Python, Make.
    """
    root = cwd or Path.cwd()
    discoverers = (
        _discover_npm_scripts,
        _discover_python_scripts,
        _discover_makefile_targets,
    )

    by_label: dict[str, dict] = {}
    for discover in discoverers:
        for button in discover(root):
            # Keep only the first button seen for each label.
            if button["label"] not in by_label:
                by_label[button["label"]] = button
    return list(by_label.values())
36
+
37
+
38
def _discover_npm_scripts(base: Path) -> list[dict]:
    """Extract well-known scripts from package.json.

    Only scripts from a fixed priority list (build, dev, start, test, lint,
    format, watch) are considered, in that order. At most six buttons are
    returned, each sending ``npm run <name>`` followed by Enter.

    Discovery is best-effort. A missing, unreadable or malformed file, or a
    ``"scripts"`` value that is not a JSON object, yields an empty list
    instead of an error.

    Args:
        base: Directory expected to contain ``package.json``.

    Returns:
        List of button config dicts (``label``, ``send``, ``row``).
    """
    pkg_file = base / "package.json"
    if not pkg_file.exists():
        return []

    try:
        data = json.loads(pkg_file.read_text(encoding="utf-8"))
    except (OSError, ValueError, RecursionError):
        # Unreadable file, bad encoding, invalid or pathologically nested JSON.
        return []

    # "scripts" must be an object. A string would otherwise make ``in`` do
    # substring matching ("build" in "build and test" is True).
    scripts = data.get("scripts") if isinstance(data, dict) else None
    if not isinstance(scripts, dict):
        return []

    # Common useful scripts to include (if defined)
    priority = ["build", "dev", "start", "test", "lint", "format", "watch"]
    buttons = [
        {"label": name, "send": f"npm run {name}\r", "row": 2}
        for name in priority
        if name in scripts
    ]
    return buttons[:6]  # Limit to 6 buttons
59
+
60
+
61
def _discover_python_scripts(base: Path) -> list[dict]:
    """Extract console-script entry points from pyproject.toml.

    Reads ``[project.scripts]`` (PEP 621) and then ``[tool.poetry.scripts]``.
    From each table, up to four names that pass ``_is_safe_name`` and were
    not already taken are used. At most six buttons are returned, each
    typing the script name followed by Enter.

    Discovery is best-effort. A missing, unreadable or malformed file yields
    an empty list. A table of the wrong type is ignored without discarding
    the other table.

    Args:
        base: Directory expected to contain ``pyproject.toml``.

    Returns:
        List of button config dicts (``label``, ``send``, ``row``).
    """
    toml_file = base / "pyproject.toml"
    if not toml_file.exists():
        return []

    try:
        data = tomllib.loads(toml_file.read_text(encoding="utf-8"))
    except (OSError, ValueError, RecursionError):
        # TOMLDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return []

    def _table(node: object, *keys: str) -> dict:
        """Walk nested tables; return {} if any level is missing or not a table."""
        for key in keys:
            if not isinstance(node, dict):
                return {}
            node = node.get(key, {})
        return node if isinstance(node, dict) else {}

    labels: list[str] = []
    for table in (
        _table(data, "project", "scripts"),
        _table(data, "tool", "poetry", "scripts"),
    ):
        # Filter before limiting so unsafe or duplicate names don't consume
        # the per-table quota of four.
        candidates = [name for name in table if _is_safe_name(name) and name not in labels]
        labels.extend(candidates[:4])

    return [{"label": name, "send": f"{name}\r", "row": 2} for name in labels[:6]]
86
+
87
+
88
def _discover_makefile_targets(base: Path) -> list[dict]:
    """Extract targets from Makefile.

    Finds simple single-target rules (``name:`` or ``name::``) that start at
    the beginning of a line. Well-known targets (build, test, run, clean,
    install, dev, lint, all) come first in that order, then the rest in file
    order. At most six buttons are returned, each sending ``make <name>``
    followed by Enter.

    Discovery is best-effort: a missing or unreadable file yields an empty list.

    Args:
        base: Directory expected to contain ``Makefile``.

    Returns:
        List of button config dicts (``label``, ``send``, ``row``).
    """
    makefile = base / "Makefile"
    if not makefile.exists():
        return []

    try:
        content = makefile.read_text(encoding="utf-8")
    except (OSError, ValueError):
        return []

    # Target name at line start, then optional blanks (not newlines) and ':'.
    # - The first character class excludes internal targets like .PHONY.
    # - The negative lookahead rejects GNU make assignment operators
    #   (":=", "::=", ":::="), which are variables, not targets.
    pattern = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_-]*)[ \t]*:(?!:*=)", re.MULTILINE)

    # dict.fromkeys dedupes while keeping first-seen order. A target that
    # appears twice (e.g. a rule plus a target-specific variable) must not
    # take two of the six slots.
    targets = list(dict.fromkeys(pattern.findall(content)))

    priority = ["build", "test", "run", "clean", "install", "dev", "lint", "all"]
    priority_set = set(priority)
    present = set(targets)

    # Priority targets first, then remaining targets in file order
    ordered = [t for t in priority if t in present]
    ordered.extend(t for t in targets if t not in priority_set)

    return [{"label": name, "send": f"make {name}\r", "row": 2} for name in ordered[:6]]
@@ -14,8 +14,6 @@ from porterminal.application.services import (
14
14
  from porterminal.config import find_config_file
15
15
  from porterminal.container import Container
16
16
  from porterminal.domain import (
17
- EnvironmentRules,
18
- EnvironmentSanitizer,
19
17
  PTYPort,
20
18
  SessionLimitChecker,
21
19
  ShellCommand,
@@ -29,7 +27,7 @@ from porterminal.infrastructure.repositories import InMemorySessionRepository, I
29
27
 
30
28
  def create_pty_factory(
31
29
  cwd: str | None = None,
32
- ) -> Callable[[ShellCommand, TerminalDimensions, dict[str, str], str | None], PTYPort]:
30
+ ) -> Callable[[ShellCommand, TerminalDimensions, str | None], PTYPort]:
33
31
  """Create a PTY factory function.
34
32
 
35
33
  This bridges the domain PTYPort interface with the existing
@@ -40,7 +38,6 @@ def create_pty_factory(
40
38
  def factory(
41
39
  shell: ShellCommand,
42
40
  dimensions: TerminalDimensions,
43
- environment: dict[str, str],
44
41
  working_directory: str | None = None,
45
42
  ) -> PTYPort:
46
43
  # Use provided cwd or factory default
@@ -60,6 +57,7 @@ def create_pty_factory(
60
57
  )
61
58
 
62
59
  # Create manager (which implements PTY operations)
60
+ # Environment sanitization is handled internally by SecurePTYManager
63
61
  manager = SecurePTYManager(
64
62
  backend=backend,
65
63
  shell_config=legacy_shell,
@@ -68,8 +66,6 @@ def create_pty_factory(
68
66
  cwd=effective_cwd,
69
67
  )
70
68
 
71
- # Spawn with environment (manager handles sanitization internally,
72
- # but we pass our sanitized env to be safe)
73
69
  manager.spawn()
74
70
 
75
71
  return PTYManagerAdapter(manager, dimensions)
@@ -173,7 +169,6 @@ def create_container(
173
169
  repository=session_repository,
174
170
  pty_factory=pty_factory,
175
171
  limit_checker=SessionLimitChecker(),
176
- environment_sanitizer=EnvironmentSanitizer(EnvironmentRules()),
177
172
  working_directory=cwd,
178
173
  )
179
174
 
porterminal/config.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
7
  import yaml
8
8
  from pydantic import BaseModel, Field, field_validator
9
9
 
10
+ from porterminal.domain.values import MAX_COLS, MAX_ROWS, MIN_COLS, MIN_ROWS
10
11
  from porterminal.infrastructure.config import ShellDetector
11
12
 
12
13
 
@@ -43,8 +44,8 @@ class TerminalConfig(BaseModel):
43
44
  """Terminal configuration."""
44
45
 
45
46
  default_shell: str = ""
46
- cols: int = Field(default=120, ge=40, le=500)
47
- rows: int = Field(default=30, ge=10, le=200)
47
+ cols: int = Field(default=120, ge=MIN_COLS, le=MAX_COLS)
48
+ rows: int = Field(default=30, ge=MIN_ROWS, le=MAX_ROWS)
48
49
  shells: list[ShellConfig] = Field(default_factory=list)
49
50
 
50
51
  def get_shell(self, shell_id: str) -> ShellConfig | None:
@@ -60,6 +61,7 @@ class ButtonConfig(BaseModel):
60
61
 
61
62
  label: str
62
63
  send: str | list[str | int] = "" # string or list of strings/ints (ints = wait ms)
64
+ row: int = Field(default=1, ge=1, le=10) # toolbar row (1-10)
63
65
 
64
66
 
65
67
  class CloudflareConfig(BaseModel):
@@ -1,6 +1,5 @@
1
1
  """Pure domain layer - no infrastructure dependencies."""
2
2
 
3
- # Value Objects
4
3
  # Entities
5
4
  from .entities import (
6
5
  CLEAR_SCREEN_SEQUENCE,
@@ -23,7 +22,6 @@ from .ports import (
23
22
  # Services
24
23
  from .services import (
25
24
  Clock,
26
- EnvironmentSanitizer,
27
25
  SessionLimitChecker,
28
26
  SessionLimitConfig,
29
27
  SessionLimitResult,
@@ -33,13 +31,10 @@ from .services import (
33
31
  TokenBucketRateLimiter,
34
32
  )
35
33
  from .values import (
36
- DEFAULT_BLOCKED_VARS,
37
- DEFAULT_SAFE_VARS,
38
34
  MAX_COLS,
39
35
  MAX_ROWS,
40
36
  MIN_COLS,
41
37
  MIN_ROWS,
42
- EnvironmentRules,
43
38
  RateLimitConfig,
44
39
  SessionId,
45
40
  ShellCommand,
@@ -60,9 +55,6 @@ __all__ = [
60
55
  "TabId",
61
56
  "ShellCommand",
62
57
  "RateLimitConfig",
63
- "EnvironmentRules",
64
- "DEFAULT_SAFE_VARS",
65
- "DEFAULT_BLOCKED_VARS",
66
58
  # Entities
67
59
  "Session",
68
60
  "MAX_SESSIONS_PER_USER",
@@ -75,7 +67,6 @@ __all__ = [
75
67
  # Services
76
68
  "TokenBucketRateLimiter",
77
69
  "Clock",
78
- "EnvironmentSanitizer",
79
70
  "SessionLimitChecker",
80
71
  "SessionLimitConfig",
81
72
  "SessionLimitResult",
@@ -9,6 +9,11 @@ OUTPUT_BUFFER_MAX_BYTES = 1_000_000 # 1MB
9
9
  # Terminal escape sequence for clear screen (ED2)
10
10
  CLEAR_SCREEN_SEQUENCE = b"\x1b[2J"
11
11
 
12
+ # Alternate screen buffer sequences (DEC Private Mode)
13
+ # Used by vim, htop, less, tmux, etc.
14
+ ALT_SCREEN_ENTER = (b"\x1b[?47h", b"\x1b[?1047h", b"\x1b[?1049h")
15
+ ALT_SCREEN_EXIT = (b"\x1b[?47l", b"\x1b[?1047l", b"\x1b[?1049l")
16
+
12
17
 
13
18
  @dataclass
14
19
  class OutputBuffer:
@@ -16,12 +21,21 @@ class OutputBuffer:
16
21
 
17
22
  Pure domain logic for buffering terminal output.
18
23
  No async, no WebSocket - just data management.
24
+
25
+ Handles alternate screen buffer (used by vim, htop, less, etc.):
26
+ - On alt-screen enter: snapshots normal buffer, clears for alt content
27
+ - On alt-screen exit: restores normal buffer, discards alt content
19
28
  """
20
29
 
21
30
  max_bytes: int = OUTPUT_BUFFER_MAX_BYTES
22
31
  _buffer: deque[bytes] = field(default_factory=deque)
23
32
  _size: int = 0
24
33
 
34
+ # Alt-screen state
35
+ _in_alt_screen: bool = False
36
+ _normal_snapshot: deque[bytes] | None = None
37
+ _normal_snapshot_size: int = 0
38
+
25
39
  @property
26
40
  def size(self) -> int:
27
41
  """Current buffer size in bytes."""
@@ -32,12 +46,53 @@ class OutputBuffer:
32
46
  """Check if buffer is empty."""
33
47
  return self._size == 0
34
48
 
49
@property
def in_alt_screen(self) -> bool:
    """True while the alternate screen buffer (vim, htop, less, ...) is active."""
    active = self._in_alt_screen
    return active
53
+
54
def _enter_alt_screen(self) -> None:
    """Switch into alt-screen mode, stashing a copy of the normal buffer.

    Only the first entry takes a snapshot; repeated entries while already
    in alt-screen mode are ignored. The live buffer is then emptied so it
    only collects alt-screen output.
    """
    if not self._in_alt_screen:
        self._in_alt_screen = True
        self._normal_snapshot, self._normal_snapshot_size = self._buffer.copy(), self._size
        self._clear_buffer()
62
+
63
def _exit_alt_screen(self) -> None:
    """Leave alt-screen mode and bring back the stashed normal buffer.

    Does nothing when not in alt-screen mode. When a snapshot exists, it
    replaces the current (alt-screen) contents and the snapshot slot is reset.
    """
    if self._in_alt_screen:
        self._in_alt_screen = False
        snapshot = self._normal_snapshot
        if snapshot is not None:
            self._buffer, self._size = snapshot, self._normal_snapshot_size
            self._normal_snapshot, self._normal_snapshot_size = None, 0
73
+
74
def _clear_buffer(self) -> None:
    """Drop all buffered chunks and reset the byte count (alt-screen state untouched)."""
    self._size = 0
    self._buffer.clear()
78
+
35
79
  def add(self, data: bytes) -> None:
36
80
  """Add data to the buffer.
37
81
 
38
- Handles clear screen detection and size limits.
82
+ Handles alt-screen transitions, clear screen detection, and size limits.
39
83
  When clear screen is detected, only keep content AFTER the last clear sequence.
40
84
  """
85
+ # Check alt-screen transitions FIRST
86
+ for pattern in ALT_SCREEN_EXIT:
87
+ if pattern in data:
88
+ self._exit_alt_screen()
89
+ break
90
+ else:
91
+ for pattern in ALT_SCREEN_ENTER:
92
+ if pattern in data:
93
+ self._enter_alt_screen()
94
+ return # Don't buffer alt-screen enter data
95
+
41
96
  # Check for clear screen sequence
42
97
  if CLEAR_SCREEN_SEQUENCE in data:
43
98
  # Clear old buffer
@@ -13,6 +13,15 @@ TAB_NAME_MIN_LENGTH = 1
13
13
  TAB_NAME_MAX_LENGTH = 50
14
14
 
15
15
 
16
def _validate_tab_name(name: str) -> None:
    """Raise ValueError unless ``name`` has an allowed length.

    Bounds are TAB_NAME_MIN_LENGTH and TAB_NAME_MAX_LENGTH, both inclusive.
    """
    length = len(name)
    if length < TAB_NAME_MIN_LENGTH or length > TAB_NAME_MAX_LENGTH:
        raise ValueError(
            f"Tab name must be {TAB_NAME_MIN_LENGTH}-{TAB_NAME_MAX_LENGTH} "
            f"characters, got {length}"
        )
+
16
25
  @dataclass
17
26
  class Tab:
18
27
  """Terminal tab entity.
@@ -35,11 +44,7 @@ class Tab:
35
44
  last_accessed: datetime
36
45
 
37
46
  def __post_init__(self) -> None:
38
- if not (TAB_NAME_MIN_LENGTH <= len(self.name) <= TAB_NAME_MAX_LENGTH):
39
- raise ValueError(
40
- f"Tab name must be {TAB_NAME_MIN_LENGTH}-{TAB_NAME_MAX_LENGTH} "
41
- f"characters, got {len(self.name)}"
42
- )
47
+ _validate_tab_name(self.name)
43
48
 
44
49
  @property
45
50
  def tab_id(self) -> str:
@@ -52,11 +57,7 @@ class Tab:
52
57
 
53
58
  def rename(self, new_name: str) -> None:
54
59
  """Rename the tab with validation."""
55
- if not (TAB_NAME_MIN_LENGTH <= len(new_name) <= TAB_NAME_MAX_LENGTH):
56
- raise ValueError(
57
- f"Tab name must be {TAB_NAME_MIN_LENGTH}-{TAB_NAME_MAX_LENGTH} "
58
- f"characters, got {len(new_name)}"
59
- )
60
+ _validate_tab_name(new_name)
60
61
  self.name = new_name
61
62
 
62
63
  def to_dict(self) -> dict:
@@ -1,6 +1,5 @@
1
1
  """Domain services - pure business logic operations."""
2
2
 
3
- from .environment_sanitizer import EnvironmentSanitizer
4
3
  from .rate_limiter import Clock, TokenBucketRateLimiter
5
4
  from .session_limits import SessionLimitChecker, SessionLimitConfig, SessionLimitResult
6
5
  from .tab_limits import TabLimitChecker, TabLimitConfig, TabLimitResult
@@ -8,7 +7,6 @@ from .tab_limits import TabLimitChecker, TabLimitConfig, TabLimitResult
8
7
  __all__ = [
9
8
  "TokenBucketRateLimiter",
10
9
  "Clock",
11
- "EnvironmentSanitizer",
12
10
  "SessionLimitChecker",
13
11
  "SessionLimitConfig",
14
12
  "SessionLimitResult",
@@ -1,6 +1,5 @@
1
1
  """Domain value objects - immutable data structures."""
2
2
 
3
- from .environment_rules import DEFAULT_BLOCKED_VARS, DEFAULT_SAFE_VARS, EnvironmentRules
4
3
  from .rate_limit_config import RateLimitConfig
5
4
  from .session_id import SessionId
6
5
  from .shell_command import ShellCommand
@@ -19,7 +18,4 @@ __all__ = [
19
18
  "TabId",
20
19
  "ShellCommand",
21
20
  "RateLimitConfig",
22
- "EnvironmentRules",
23
- "DEFAULT_SAFE_VARS",
24
- "DEFAULT_BLOCKED_VARS",
25
21
  ]
@@ -22,12 +22,15 @@ DEFAULT_SAFE_VARS: frozenset[str] = frozenset(
22
22
  "HOMEPATH",
23
23
  "LOCALAPPDATA",
24
24
  "APPDATA",
25
+ "PROGRAMDATA",
25
26
  "PROGRAMFILES",
26
27
  "PROGRAMFILES(X86)",
27
28
  "COMMONPROGRAMFILES",
28
29
  # System info
29
30
  "COMPUTERNAME",
30
31
  "USERNAME",
32
+ "USER",
33
+ "LOGNAME",
31
34
  "USERDOMAIN",
32
35
  "OS",
33
36
  "PROCESSOR_ARCHITECTURE",
@@ -17,6 +17,13 @@ console = Console()
17
17
  class CloudflaredInstaller:
18
18
  """Platform-specific cloudflared installer."""
19
19
 
20
@staticmethod
def _add_to_path(path: str | Path) -> None:
    """Prepend a directory to PATH for the current process.

    - An existing occurrence of the same directory is removed first, so
      repeated installs don't keep growing PATH and the directory ends up
      with top priority.
    - When PATH is unset or empty, no separator is added: a trailing empty
      entry ("dir:") means "current directory" on POSIX systems.

    Args:
        path: Directory to put at the front of PATH.
    """
    path_str = str(path)
    current = os.environ.get("PATH", "")
    if current:
        # Keep every other entry (including deliberate empty ones) in order.
        entries = [entry for entry in current.split(os.pathsep) if entry != path_str]
    else:
        entries = []
    os.environ["PATH"] = os.pathsep.join([path_str, *entries])
    console.print(f"[dim]Added to PATH: {path_str}[/dim]")
+
20
27
  @staticmethod
21
28
  def is_installed() -> bool:
22
29
  """Check if cloudflared is installed."""
@@ -91,8 +98,7 @@ class CloudflaredInstaller:
91
98
  # Try to find and add to PATH for current session
92
99
  install_path = CloudflaredInstaller._find_cloudflared_windows()
93
100
  if install_path:
94
- os.environ["PATH"] = install_path + os.pathsep + os.environ.get("PATH", "")
95
- console.print(f"[dim]Added to PATH: {install_path}[/dim]")
101
+ CloudflaredInstaller._add_to_path(install_path)
96
102
  # Return True regardless - winget succeeded, may just need shell restart
97
103
  return True
98
104
  except (subprocess.TimeoutExpired, OSError) as e:
@@ -119,7 +125,7 @@ class CloudflaredInstaller:
119
125
 
120
126
  exe_path = install_dir / "cloudflared.exe"
121
127
  if exe_path.exists():
122
- os.environ["PATH"] = str(install_dir) + os.pathsep + os.environ.get("PATH", "")
128
+ CloudflaredInstaller._add_to_path(install_dir)
123
129
  console.print(f"[green]✓ Installed to {install_dir}[/green]")
124
130
  return True
125
131
 
@@ -203,10 +209,7 @@ class CloudflaredInstaller:
203
209
  # Try to find and add to PATH
204
210
  install_path = CloudflaredInstaller._find_cloudflared_unix()
205
211
  if install_path:
206
- os.environ["PATH"] = (
207
- install_path + os.pathsep + os.environ.get("PATH", "")
208
- )
209
- console.print(f"[dim]Added to PATH: {install_path}[/dim]")
212
+ CloudflaredInstaller._add_to_path(install_path)
210
213
  # Return True regardless - package manager succeeded
211
214
  return True
212
215
  except (subprocess.TimeoutExpired, OSError) as e:
@@ -226,7 +229,7 @@ class CloudflaredInstaller:
226
229
  os.chmod(bin_path, 0o755)
227
230
 
228
231
  # Add to PATH for this session
229
- os.environ["PATH"] = str(install_dir) + os.pathsep + os.environ.get("PATH", "")
232
+ CloudflaredInstaller._add_to_path(install_dir)
230
233
  console.print(f"[green]✓ Installed to {bin_path}[/green]")
231
234
  return True
232
235
 
@@ -257,8 +260,7 @@ class CloudflaredInstaller:
257
260
  # Try to find and add to PATH
258
261
  install_path = CloudflaredInstaller._find_cloudflared_unix()
259
262
  if install_path:
260
- os.environ["PATH"] = install_path + os.pathsep + os.environ.get("PATH", "")
261
- console.print(f"[dim]Added to PATH: {install_path}[/dim]")
263
+ CloudflaredInstaller._add_to_path(install_path)
262
264
  # Return True regardless - Homebrew succeeded
263
265
  return True
264
266
  except (subprocess.TimeoutExpired, OSError) as e:
@@ -289,7 +291,7 @@ class CloudflaredInstaller:
289
291
  bin_path = install_dir / "cloudflared"
290
292
  if bin_path.exists():
291
293
  os.chmod(bin_path, 0o755)
292
- os.environ["PATH"] = str(install_dir) + os.pathsep + os.environ.get("PATH", "")
294
+ CloudflaredInstaller._add_to_path(install_dir)
293
295
  console.print(f"[green]✓ Installed to {bin_path}[/green]")
294
296
  return True
295
297
 
@@ -21,6 +21,9 @@ class ShellDetector:
21
21
  def detect_shells(self) -> list[ShellCommand]:
22
22
  """Auto-detect available shells.
23
23
 
24
+ Detects platform-specific shells and includes the user's $SHELL
25
+ if it's not already in the list (supports any shell).
26
+
24
27
  Returns:
25
28
  List of detected shell configurations.
26
29
  """
@@ -39,6 +42,12 @@ class ShellDetector:
39
42
  )
40
43
  )
41
44
 
45
+ # Include user's $SHELL if not already detected (supports unknown shells)
46
+ user_shell = self._create_shell_from_env()
47
+ if user_shell and not any(s.id == user_shell.id for s in shells):
48
+ # Insert at beginning so user's preferred shell is first
49
+ shells.insert(0, user_shell)
50
+
42
51
  return shells
43
52
 
44
53
  def get_default_shell_id(self) -> str:
@@ -149,25 +158,20 @@ class ShellDetector:
149
158
 
150
159
  try:
151
160
  result = subprocess.run(
152
- [
153
- str(vswhere),
154
- "-all",
155
- "-prerelease",
156
- "-property",
157
- "installationPath",
158
- ],
161
+ [str(vswhere), "-all", "-prerelease", "-format", "json"],
159
162
  capture_output=True,
160
163
  text=True,
161
164
  timeout=5,
162
165
  )
163
- vs_paths = [p.strip() for p in result.stdout.strip().split("\n") if p.strip()]
164
- except (subprocess.TimeoutExpired, OSError) as e:
166
+ vs_installs = json.loads(result.stdout) if result.stdout.strip() else []
167
+ except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
165
168
  logger.warning("Failed to run vswhere: %s", e)
166
169
  return []
167
170
 
168
171
  shells = []
169
- for vs_path in vs_paths:
170
- vs_path = Path(vs_path)
172
+ for vs_info in vs_installs:
173
+ vs_path = Path(vs_info.get("installationPath", ""))
174
+ instance_id = vs_info.get("instanceId", "")
171
175
  # Extract VS version and edition from path
172
176
  # e.g., "C:\Program Files\Microsoft Visual Studio\2022\Community"
173
177
  edition = vs_path.name # Community, Professional, Enterprise
@@ -188,18 +192,32 @@ class ShellDetector:
188
192
  )
189
193
  )
190
194
 
191
- # Developer PowerShell
192
- launch_ps = vs_path / "Common7" / "Tools" / "Launch-VsDevShell.ps1"
193
- if launch_ps.exists():
195
+ # Developer PowerShell - find DevShell.dll (location varies by VS version)
196
+ devshell_dll = None
197
+ for dll_path in [
198
+ vs_path / "Common7" / "Tools" / "Microsoft.VisualStudio.DevShell.dll",
199
+ vs_path
200
+ / "Common7"
201
+ / "Tools"
202
+ / "vsdevshell"
203
+ / "Microsoft.VisualStudio.DevShell.dll",
204
+ ]:
205
+ if dll_path.exists():
206
+ devshell_dll = dll_path
207
+ break
208
+
209
+ if devshell_dll and instance_id:
194
210
  name = f"Dev PS {year}"
195
211
  shell_id = f"devps-{year}-{edition.lower()}"
196
- # powershell.exe -NoExit -Command "& 'path\to\Launch-VsDevShell.ps1'"
212
+ # Use forward slashes to avoid backslash escape issues (PowerShell accepts both)
213
+ dll_str = str(devshell_dll).replace("\\", "/")
214
+ cmd = f"Import-Module '{dll_str}'; Enter-VsDevShell {instance_id} -SkipAutomaticLocation"
197
215
  shells.append(
198
216
  (
199
217
  name,
200
218
  shell_id,
201
219
  "powershell.exe",
202
- ["-NoExit", "-Command", f"& '{launch_ps}'"],
220
+ ["-NoExit", "-Command", cmd],
203
221
  )
204
222
  )
205
223
 
@@ -405,16 +423,23 @@ class ShellDetector:
405
423
  """Get shell ID from user's $SHELL environment variable.
406
424
 
407
425
  Returns:
408
- Shell ID if $SHELL is set and matches a known shell, None otherwise.
426
+ Shell ID if $SHELL is set and valid, None otherwise.
427
+ For unknown shells, returns the executable name as the ID.
409
428
  """
410
429
  shell_path = os.environ.get("SHELL", "")
411
430
  if not shell_path:
412
431
  return None
413
432
 
433
+ path = Path(shell_path)
434
+
435
+ # Validate shell exists
436
+ if not path.exists() and not shutil.which(shell_path):
437
+ return None
438
+
414
439
  # Extract shell name from path (e.g., /usr/bin/fish -> fish)
415
- shell_name = Path(shell_path).name.lower()
440
+ shell_name = path.name.lower()
416
441
 
417
- # Map common shell names to IDs
442
+ # Map common shell names to canonical IDs (for consistency)
418
443
  shell_map = {
419
444
  "bash": "bash",
420
445
  "zsh": "zsh",
@@ -422,4 +447,45 @@ class ShellDetector:
422
447
  "sh": "sh",
423
448
  }
424
449
 
425
- return shell_map.get(shell_name)
450
+ # Return known ID or use executable name for unknown shells
451
+ return shell_map.get(shell_name, shell_name)
452
+
453
def _create_shell_from_env(self) -> ShellCommand | None:
    """Build a ShellCommand for the shell named by the $SHELL variable.

    Returns None when $SHELL is unset or empty. It also returns None when
    the value neither exists on disk nor resolves via ``shutil.which``.

    Bash and Zsh get ``--login``; fish and sh get no arguments. Any other
    shell gets no arguments, and its capitalised executable name is used as
    the display name.
    """
    shell_path = os.environ.get("SHELL", "")
    if not shell_path:
        return None

    executable = Path(shell_path)
    if not (executable.exists() or shutil.which(shell_path)):
        return None

    shell_name = executable.name.lower()
    display_name, args = {
        "bash": ("Bash", ["--login"]),
        "zsh": ("Zsh", ["--login"]),
        "fish": ("Fish", []),
        "sh": ("Sh", []),
    }.get(shell_name, (shell_name.capitalize(), []))

    return ShellCommand(
        id=shell_name,
        name=display_name,
        command=shell_path,
        args=tuple(args),
    )
@@ -38,10 +38,7 @@ class InMemorySessionRepository(SessionRepository[PTYHandle]):
38
38
  user_id = str(session.user_id)
39
39
 
40
40
  self._sessions[session_id] = session
41
-
42
- if user_id not in self._user_sessions:
43
- self._user_sessions[user_id] = set()
44
- self._user_sessions[user_id].add(session_id)
41
+ self._user_sessions.setdefault(user_id, set()).add(session_id)
45
42
 
46
43
  def remove(self, session_id: SessionId) -> Session[PTYHandle] | None:
47
44
  """Remove and return a session."""