flowent 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {flowent-0.2.2 → flowent-0.2.3}/PKG-INFO +1 -1
  2. {flowent-0.2.2 → flowent-0.2.3}/pyproject.toml +1 -1
  3. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/agent.py +1 -0
  4. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/context.py +2 -0
  5. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/llm.py +9 -4
  6. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/main.py +227 -16
  7. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/permissions.py +5 -2
  8. flowent-0.2.3/src/flowent/shell.py +94 -0
  9. flowent-0.2.2/src/flowent/static/assets/index-Bz76A4EJ.js → flowent-0.2.3/src/flowent/static/assets/index-D7t9qNrC.js +8 -8
  10. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/static/index.html +1 -1
  11. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/tools.py +5 -2
  12. {flowent-0.2.2 → flowent-0.2.3}/README.md +0 -0
  13. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/__init__.py +0 -0
  14. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/_version.py +0 -0
  15. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/approval.py +0 -0
  16. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/channels.py +0 -0
  17. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/cli.py +0 -0
  18. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/compact.py +0 -0
  19. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/logging.py +0 -0
  20. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/mcp.py +0 -0
  21. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/mcp_import.py +0 -0
  22. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/patch.py +0 -0
  23. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/paths.py +0 -0
  24. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/sandbox.py +0 -0
  25. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/skills.py +0 -0
  26. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/static/assets/geist-cyrillic-wght-normal-CHSlOQsW.woff2 +0 -0
  27. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/static/assets/geist-latin-ext-wght-normal-DMtmJ5ZE.woff2 +0 -0
  28. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/static/assets/geist-latin-wght-normal-Dm3htQBi.woff2 +0 -0
  29. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/static/assets/index-DufpDl8x.css +0 -0
  30. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/static/flowent.png +0 -0
  31. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/storage.py +0 -0
  32. {flowent-0.2.2 → flowent-0.2.3}/src/flowent/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flowent
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: A workflow orchestration platform for multi-agent collaboration
5
5
  Keywords: agent,agents,ai,ai-agents,assistant,automation,code-generation,llm,mcp,orchestration,sandbox,web-application,workflow
6
6
  Author: ImFeH2
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "flowent"
3
- version = "0.2.2"
3
+ version = "0.2.3"
4
4
  description = "A workflow orchestration platform for multi-agent collaboration"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -240,6 +240,7 @@ async def run_agent_stream(
240
240
  round_number,
241
241
  tool_calls,
242
242
  )
243
+ yield AgentStreamEvent(event="output_done", data={"index": round_number})
243
244
  if not tool_calls:
244
245
  if not final_content and not final_thinking:
245
246
  raise RuntimeError(EMPTY_MODEL_RESPONSE_ERROR)
@@ -4,6 +4,7 @@ import os
4
4
  from pathlib import Path
5
5
 
6
6
  from flowent.llm import ChatMessage
7
+ from flowent.shell import shell_invocation_description
7
8
  from flowent.tools import tool_specs
8
9
 
9
10
  DEFAULT_PROJECT_INSTRUCTIONS_MAX_BYTES = 32768
@@ -108,6 +109,7 @@ def environment_context_message(cwd: Path) -> ChatMessage:
108
109
  content=(
109
110
  "<environment_context>\n"
110
111
  f" <cwd>{cwd.resolve(strict=False)}</cwd>\n"
112
+ f" <shell>{shell_invocation_description()}</shell>\n"
111
113
  " <filesystem>workspace-write</filesystem>\n"
112
114
  " <network>enabled</network>\n"
113
115
  " <tools>\n"
@@ -108,6 +108,7 @@ MODEL_PREFIXES: dict[ProviderFormat, str] = {
108
108
  ProviderFormat.ANTHROPIC: "anthropic",
109
109
  ProviderFormat.GEMINI: "gemini",
110
110
  }
111
# Extra model-path segment that provider_model_name inserts (and
# normalize_provider_model_name strips) for OpenAI Responses connections.
OPENAI_RESPONSES_MODEL_PREFIX = "responses/"
111
112
  _litellm_stream_error_patch_installed = False
112
113
 
113
114
  PROVIDER_API_VERSIONS: dict[ProviderFormat, str] = {
@@ -121,7 +122,10 @@ VERSION_PATH_SEGMENT = re.compile(r"^v\d+(?:[a-z]+)?$", re.IGNORECASE)
121
122
 
122
123
 
123
124
def provider_model_name(connection: ProviderConnection) -> str:
    """Build the fully-qualified LiteLLM model identifier for ``connection``.

    The configured model name is normalized first (dropping any provider
    prefix already present). OpenAI Responses connections then get the
    ``responses/`` segment. Finally the provider prefix from ``MODEL_PREFIXES``
    is prepended.
    """
    provider = connection.provider
    bare_model = normalize_provider_model_name(provider, connection.model)
    routed_model = (
        f"{OPENAI_RESPONSES_MODEL_PREFIX}{bare_model}"
        if provider == ProviderFormat.OPENAI_RESPONSES
        else bare_model
    )
    return f"{MODEL_PREFIXES[provider]}/{routed_model}"
125
129
 
126
130
 
127
131
  def provider_litellm_name(provider: ProviderFormat) -> str:
@@ -164,9 +168,10 @@ def normalize_provider_base_url(
164
168
 
165
169
def normalize_provider_model_name(provider: ProviderFormat, model: str) -> str:
    """Strip provider routing prefixes from a user-entered model name.

    A leading ``<litellm provider name>/`` is removed. For OpenAI Responses
    connections, a leading ``responses/`` left after that is also removed.
    Names without those prefixes are returned unchanged.
    """
    litellm_prefix = f"{provider_litellm_name(provider)}/"
    # removeprefix is a no-op when the prefix is absent, so no startswith
    # check is needed.
    bare_model = model.removeprefix(litellm_prefix)
    if provider != ProviderFormat.OPENAI_RESPONSES:
        return bare_model
    return bare_model.removeprefix(OPENAI_RESPONSES_MODEL_PREFIX)
170
175
 
171
176
 
172
177
  def stream_failure_message(chunk: Any) -> str:
@@ -1,12 +1,14 @@
1
1
  import asyncio
2
+ import copy
2
3
  import json
3
4
  import logging
4
5
  import os
5
- from collections.abc import AsyncIterator, Mapping, Sequence
6
+ import time
7
+ from collections.abc import AsyncIterator, Awaitable, Mapping, Sequence
6
8
  from contextlib import asynccontextmanager, suppress
7
9
  from dataclasses import dataclass, field
8
10
  from pathlib import Path
9
- from typing import Literal
11
+ from typing import Any, Literal
10
12
  from uuid import uuid4
11
13
 
12
14
  from fastapi import FastAPI, HTTPException, Query
@@ -89,6 +91,7 @@ DEFAULT_AUTO_COMPACT_CONTEXT_WINDOW_RATIO = 0.95
89
91
  AUTO_COMPACT_RETAINED_MESSAGE_TOKEN_BUDGET = 20_000
90
92
  APPROVAL_TRANSCRIPT_MESSAGE_LIMIT = 12
91
93
  APPROVAL_TRANSCRIPT_TEXT_LIMIT = 2_000
94
# Minimum spacing (seconds) between store writes for streamed text/thinking
# deltas; deltas arriving sooner only refresh the in-memory assistant message.
WORKSPACE_PROGRESS_FLUSH_INTERVAL_SECONDS = 0.5
92
95
 
93
96
 
94
97
  class ProviderModelsRequest(BaseModel):
@@ -180,6 +183,7 @@ class WritablePathListResponse(BaseModel):
180
183
  @dataclass
181
184
  class WorkspaceRun:
182
185
  condition: asyncio.Condition
186
+ active_output: Literal["text", "thinking"] | None = None
183
187
  discard_on_cancel: bool = False
184
188
  events: list[tuple[int, str, dict[str, object]]] = field(default_factory=list)
185
189
  generation: int = 0
@@ -200,8 +204,13 @@ def stream_event(
200
204
  return f"{id_line}event: {event}\ndata: {json.dumps(data)}\n\n"
201
205
 
202
206
 
203
def stream_message_data(
    message: StoredMessage, active_output: Literal["text", "thinking"] | None = None
) -> dict[str, object]:
    """Serialize ``message`` into the dict carried by stream events.

    ``status`` is taken from the message attribute, overriding any dumped
    value. ``active_output`` ("text" or "thinking") is included only when
    it is given.
    """
    payload: dict[str, object] = dict(message.model_dump())
    payload["status"] = message.status
    if active_output is None:
        return payload
    payload["active_output"] = active_output
    return payload
205
214
 
206
215
 
207
216
  def append_or_replace_message(
@@ -216,13 +225,131 @@ def append_or_replace_message(
216
225
def run_snapshot_data_at(
    run: WorkspaceRun, event_index: int
) -> dict[str, object] | None:
    """Rebuild the assistant message as it looked at ``event_index``.

    The base is the latest ``snapshot`` event at or before ``event_index``.
    If the first ``start`` event with a string ``id`` arrives before any
    snapshot has been seen, a blank running message built from it serves as
    the base until a snapshot appears. Every event recorded after the base
    and up to ``event_index`` is then replayed onto a copy of it via
    ``apply_stream_event_to_snapshot``.

    Returns ``None`` when neither a usable snapshot nor a usable ``start``
    event was recorded by ``event_index``.
    """
    base_event_index = 0
    base: dict[str, object] | None = None
    for current_event_index, event, data in run.events:
        if current_event_index > event_index:
            break
        if event == "snapshot":
            message = data.get("message")
            if isinstance(message, dict):
                base_event_index = current_event_index
                # Only remember the reference here. Deep-copying every
                # intermediate snapshot made each call cost
                # O(snapshots x message size).
                base = message
        elif event == "start" and base is None:
            assistant_id = data.get("id")
            if isinstance(assistant_id, str):
                base_event_index = current_event_index
                base = {
                    "author": "assistant",
                    "content": "",
                    "groups": [],
                    "id": assistant_id,
                    "status": "running",
                    "tools": [],
                }
    if base is None:
        return None
    # Copy exactly once so replaying deltas never mutates the recorded event payload.
    snapshot = copy.deepcopy(base)
    for current_event_index, event, data in run.events:
        if current_event_index <= base_event_index:
            continue
        if current_event_index > event_index:
            break
        apply_stream_event_to_snapshot(snapshot, event, data)
    return snapshot
260
+
261
+
262
+ def apply_stream_event_to_snapshot(
263
+ snapshot: dict[str, object], event: str, data: dict[str, object]
264
+ ) -> None:
265
+ if event == "output_start":
266
+ snapshot.pop("active_output", None)
267
+ index = data.get("index")
268
+ if isinstance(index, int):
269
+ append_snapshot_group(snapshot, index)
270
+ if event == "delta":
271
+ append_snapshot_text(snapshot, str(data.get("content") or ""))
272
+ if event == "thinking_delta":
273
+ append_snapshot_thinking(snapshot, str(data.get("content") or ""))
274
+ if event == "output_done":
275
+ snapshot.pop("active_output", None)
276
+
277
+
278
def snapshot_groups(snapshot: dict[str, object]) -> list[dict[str, object]]:
    """Return the snapshot's ``groups`` list, installing an empty one if needed.

    A missing or non-list ``groups`` value is replaced by a fresh list stored
    back into ``snapshot``. Appending to the returned list therefore always
    updates the snapshot.
    """
    existing = snapshot.get("groups")
    if isinstance(existing, list):
        return existing
    fresh: list[dict[str, object]] = []
    snapshot["groups"] = fresh
    return fresh
284
+
285
+
286
def append_snapshot_group(
    snapshot: dict[str, object], index: int | None = None
) -> None:
    """Open a new, empty output group unless the last group already has its id.

    Group ids look like ``<assistant id>-group-<n>``. ``n`` is ``index`` when
    given, otherwise one past the current number of groups. Only the last
    existing group is checked for a duplicate id.
    """
    groups = snapshot_groups(snapshot)
    owner = str(snapshot.get("id") or "assistant")
    number = len(groups) + 1 if index is None else index
    group_id = f"{owner}-group-{number}"
    last_group_id = groups[-1].get("id") if groups else None
    if last_group_id != group_id:
        groups.append({"id": group_id, "items": []})
296
+
297
+
298
def append_snapshot_text(snapshot: dict[str, object], content: str) -> None:
    """Append a text delta to the snapshot and mark text as the active output.

    The delta is added both to the top-level ``content`` string and to the
    latest text item of the last group. Empty deltas change nothing.
    """
    if content:
        snapshot["active_output"] = "text"
        previous = snapshot.get("content") or ""
        snapshot["content"] = f"{previous}{content}"
        append_snapshot_item_content(snapshot, content, "text")
304
+
305
+
306
def append_snapshot_thinking(snapshot: dict[str, object], content: str) -> None:
    """Append a thinking delta to the snapshot and mark thinking as active.

    The delta is added both to the top-level ``thinking`` string and to the
    latest thinking item of the last group. Empty deltas change nothing.
    """
    if content:
        snapshot["active_output"] = "thinking"
        previous = snapshot.get("thinking") or ""
        snapshot["thinking"] = f"{previous}{content}"
        append_snapshot_item_content(snapshot, content, "thinking")
312
+
313
+
314
def append_snapshot_item_content(
    snapshot: dict[str, object], content: str, item_type: Literal["text", "thinking"]
) -> None:
    """Append ``content`` to the most recent ``item_type`` item of the last group.

    A group is created first when the snapshot has none. When the last group
    holds no ``item_type`` item, a new one is appended. Its id is
    ``<assistant id>-<item_type>-<n>``, where ``n`` is one more than the
    number of ``item_type`` items across all groups.
    """
    groups = snapshot_groups(snapshot)
    if not groups:
        # Mutates the same list object that ``groups`` refers to.
        append_snapshot_group(snapshot)
    last_group = groups[-1]
    items = last_group.get("items")
    if not isinstance(items, list):
        items = []
        last_group["items"] = items

    def is_target(candidate: object) -> bool:
        return isinstance(candidate, dict) and candidate.get("type") == item_type

    target = None
    for candidate in reversed(items):
        if is_target(candidate):
            target = candidate
            break

    if target is None:
        owner = str(snapshot.get("id") or "assistant")
        existing = 0
        for group in groups:
            group_items = group.get("items")
            if isinstance(group_items, list):
                existing += sum(1 for entry in group_items if is_target(entry))
        target = {
            "content": "",
            "id": f"{owner}-{item_type}-{existing + 1}",
            "type": item_type,
        }
        items.append(target)

    target["content"] = f"{target.get('content') or ''}{content}"
226
353
 
227
354
 
228
355
  USER_VISIBLE_RUN_ERROR_TITLE = "Request failed"
@@ -1007,6 +1134,58 @@ def create_app(
1007
1134
  telegram_transport=telegram_transport,
1008
1135
  )
1009
1136
 
1137
async def gather_shutdown_tasks(
    label: str, tasks: Sequence[asyncio.Task[Any]]
) -> None:
    """Wait for ``tasks`` to finish and log the ones that raised.

    Results that are not exceptions, and ``asyncio.CancelledError``, count
    as a clean exit. Any other exception is logged with its traceback under
    ``label``.
    """
    if not tasks:
        return
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)
    failures = [
        outcome
        for outcome in outcomes
        if isinstance(outcome, BaseException)
        and not isinstance(outcome, asyncio.CancelledError)
    ]
    for failure in failures:
        logger.error(
            "%s cleanup task failed",
            label,
            exc_info=(type(failure), failure, failure.__traceback__),
        )
1152
+
1153
async def stop_workspace_runs_for_shutdown() -> None:
    """Cancel every unfinished workspace run task, then wait for all of them."""
    pending: list[asyncio.Task[None]] = []
    for workspace_run in workspace_runs.values():
        task = workspace_run.task
        if task is not None and not task.done():
            task.cancel()
            pending.append(task)
    await gather_shutdown_tasks("Workspace run", pending)
1161
+
1162
+ async def stop_workspace_compact_for_shutdown() -> None:
1163
+ nonlocal active_compact_task
1164
+ if active_compact_task is None:
1165
+ store.save_is_compacting(False)
1166
+ return
1167
+ task = active_compact_task.task
1168
+ active_compact_task = None
1169
+ if not task.done():
1170
+ task.cancel()
1171
+ await gather_shutdown_tasks("Workspace compact", [task])
1172
+ store.save_is_compacting(False)
1173
+
1174
async def run_shutdown_step(label: str, cleanup: Awaitable[object]) -> None:
    """Await one shutdown step, logging ordinary errors instead of raising.

    Only ``Exception`` subclasses are caught. ``BaseException`` subclasses
    such as ``asyncio.CancelledError`` still propagate.
    """
    try:
        await cleanup
    except Exception:
        # Log and return so the remaining shutdown steps still run.
        logger.exception("%s cleanup failed during shutdown", label)
        return
1179
+
1180
async def graceful_shutdown() -> None:
    """Stop subsystems in order: workspace runs, compaction, Telegram, MCP.

    Each step goes through ``run_shutdown_step``, so an ordinary failure in
    one step is logged and the later steps still run.
    """
    await run_shutdown_step(
        "Workspace run",
        stop_workspace_runs_for_shutdown(),
    )
    await run_shutdown_step(
        "Workspace compact",
        stop_workspace_compact_for_shutdown(),
    )
    telegram = telegram_bot_manager
    if telegram is not None:
        await run_shutdown_step("Telegram", telegram.stop_all())
    await run_shutdown_step("MCP", mcp_manager.stop_all())
+
1010
1189
  @asynccontextmanager
1011
1190
  async def lifespan(app: FastAPI) -> AsyncIterator[None]:
1012
1191
  app.state.mcp_manager = mcp_manager
@@ -1017,9 +1196,7 @@ def create_app(
1017
1196
  try:
1018
1197
  yield
1019
1198
  finally:
1020
- if telegram_bot_manager is not None:
1021
- await telegram_bot_manager.stop_all()
1022
- await mcp_manager.stop_all()
1199
+ await graceful_shutdown()
1023
1200
 
1024
1201
  app = FastAPI(title="Flowent", lifespan=lifespan)
1025
1202
  app.state.mcp_manager = mcp_manager
@@ -1204,7 +1381,7 @@ def create_app(
1204
1381
  await append_run_event(
1205
1382
  run,
1206
1383
  "snapshot",
1207
- {"message": stream_message_data(message)},
1384
+ {"message": stream_message_data(message, run.active_output)},
1208
1385
  )
1209
1386
 
1210
1387
  def active_workspace_run() -> WorkspaceRun | None:
@@ -1258,11 +1435,14 @@ def create_app(
1258
1435
  status="running",
1259
1436
  )
1260
1437
  assistant_output = AssistantOutputBuilder(assistant_message.id)
1438
+ last_progress_flush_at = 0.0
1261
1439
 
1262
1440
def is_current_generation() -> bool:
    """Report whether this run still matches the latest workspace generation."""
    same_generation = workspace_generation == run.generation
    return same_generation
1264
1442
 
1265
- def persist_assistant(status: str = "running") -> StoredMessage | None:
1443
+ def update_assistant_message(
1444
+ status: str = "running", *, persist: bool
1445
+ ) -> StoredMessage | None:
1266
1446
  nonlocal next_messages, assistant_message
1267
1447
  if not is_current_generation() or run.discard_on_cancel:
1268
1448
  return None
@@ -1279,9 +1459,33 @@ def create_app(
1279
1459
  next_messages = append_or_replace_message(
1280
1460
  next_messages, assistant_message
1281
1461
  )
1282
- store.upsert_message(assistant_message)
1462
+ if persist:
1463
+ store.upsert_message(assistant_message)
1283
1464
  return assistant_message
1284
1465
 
1466
+ def persist_assistant(status: str = "running") -> StoredMessage | None:
1467
+ nonlocal last_progress_flush_at
1468
+ message = update_assistant_message(status, persist=True)
1469
+ if status == "running" and message is not None:
1470
+ last_progress_flush_at = time.monotonic()
1471
+ return message
1472
+
1473
def refresh_assistant(status: str = "running") -> StoredMessage | None:
    """Update the in-memory assistant message without writing it to the store."""
    refreshed = update_assistant_message(status, persist=False)
    return refreshed
1475
+
1476
+ def persist_assistant_progress() -> StoredMessage | None:
1477
+ nonlocal last_progress_flush_at
1478
+ now = time.monotonic()
1479
+ if (
1480
+ last_progress_flush_at > 0
1481
+ and now - last_progress_flush_at
1482
+ < WORKSPACE_PROGRESS_FLUSH_INTERVAL_SECONDS
1483
+ ):
1484
+ refresh_assistant()
1485
+ return None
1486
+ last_progress_flush_at = now
1487
+ return update_assistant_message("running", persist=True)
1488
+
1285
1489
  try:
1286
1490
  current_tool_id: str | None = None
1287
1491
  turn_usage_info: TokenUsageInfo | None = None
@@ -1445,11 +1649,15 @@ def create_app(
1445
1649
  if event.event == "output_start":
1446
1650
  index = event.data.get("index")
1447
1651
  if isinstance(index, int):
1652
+ run.active_output = None
1448
1653
  assistant_output.start_group(index)
1449
1654
  snapshot_after_event = persist_assistant()
1655
+ if event.event == "output_done":
1656
+ run.active_output = None
1450
1657
  if event.event == "tool_start":
1451
1658
  tool = event.data.get("tool")
1452
1659
  if isinstance(tool, dict) and isinstance(tool.get("id"), str):
1660
+ run.active_output = None
1453
1661
  current_tool_id = tool["id"]
1454
1662
  assistant_output.start_tool(
1455
1663
  StoredToolItem.model_validate(tool)
@@ -1467,15 +1675,17 @@ def create_app(
1467
1675
  assistant_output.update_tool(tool_id, event.data)
1468
1676
  snapshot_after_event = persist_assistant()
1469
1677
  if event.event == "delta":
1678
+ run.active_output = "text"
1470
1679
  assistant_output.append_text(
1471
1680
  str(event.data.get("content") or "")
1472
1681
  )
1473
- snapshot_after_event = persist_assistant()
1682
+ snapshot_after_event = persist_assistant_progress()
1474
1683
  if event.event == "thinking_delta":
1684
+ run.active_output = "thinking"
1475
1685
  assistant_output.append_thinking(
1476
1686
  str(event.data.get("content") or "")
1477
1687
  )
1478
- snapshot_after_event = persist_assistant()
1688
+ snapshot_after_event = persist_assistant_progress()
1479
1689
  if event.event == "usage":
1480
1690
  usage_data = event.data.get("usage")
1481
1691
  if isinstance(usage_data, dict):
@@ -1503,6 +1713,7 @@ def create_app(
1503
1713
  if event.event == "done":
1504
1714
  message = event.data.get("message")
1505
1715
  if isinstance(message, dict):
1716
+ run.active_output = None
1506
1717
  assistant_output.apply_done_message(message)
1507
1718
  response_usage_info = store.read_usage_info()
1508
1719
  final_usage_info = turn_usage_info
@@ -12,6 +12,7 @@ from flowent.approval import (
12
12
  )
13
13
  from flowent.patch import affected_paths
14
14
  from flowent.sandbox import SandboxError, SandboxRunner, path_is_within
15
+ from flowent.shell import shell_invocation
15
16
  from flowent.tools import (
16
17
  ToolContext,
17
18
  ToolResult,
@@ -290,10 +291,11 @@ async def shell_command_with_writable_paths(
290
291
  ) -> ToolResult:
291
292
  command = str(arguments["command"])
292
293
  timeout_seconds = number_argument(arguments, "timeout_seconds", 30)
294
+ invocation = shell_invocation(command)
293
295
  result = await SandboxRunner(
294
296
  cwd=context.cwd,
295
297
  writable_roots=writable_paths,
296
- ).run_async(["/bin/sh", "-c", command], timeout_seconds=timeout_seconds)
298
+ ).run_async(invocation.args, env=invocation.env, timeout_seconds=timeout_seconds)
297
299
  ok = result.exit_code == 0
298
300
  content = result.stdout or result.stderr
299
301
  return ToolResult(
@@ -355,8 +357,9 @@ async def shell_command_without_sandbox(
355
357
  ) -> ToolResult:
356
358
  command = str(arguments["command"])
357
359
  timeout_seconds = number_argument(arguments, "timeout_seconds", 30)
360
+ invocation = shell_invocation(command)
358
361
  result = await SandboxRunner(cwd=context.cwd).run_unsandboxed_async(
359
- ["/bin/sh", "-c", command], timeout_seconds=timeout_seconds
362
+ invocation.args, env=invocation.env, timeout_seconds=timeout_seconds
360
363
  )
361
364
  ok = result.exit_code == 0
362
365
  content = result.stdout or result.stderr
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shutil
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+
8
+
9
@dataclass(frozen=True)
class ShellInvocation:
    """A ready-to-run shell command line plus its environment overrides."""

    # argv for the subprocess; shell_invocation builds [shell, "-c", command].
    args: list[str]
    # Environment variables handed to the runner alongside ``args``.
    env: dict[str, str]
    # The shell executable (absolute path, or bare "sh" as a last resort).
    shell: str
14
+
15
+
16
+ def executable_path(path: Path) -> str | None:
17
+ if path.is_file() and os.access(path, os.X_OK):
18
+ return str(path.resolve(strict=False))
19
+ return None
20
+
21
+
22
def executable_command_path(command: str) -> str | None:
    """Look ``command`` up with ``shutil.which`` and validate the hit.

    Returns ``None`` when it is not found or the match fails ``executable_path``.
    """
    found = shutil.which(command)
    return None if found is None else executable_path(Path(found))
27
+
28
+
29
def shell_path(raw_shell: str) -> str | None:
    """Turn a configured shell value into an executable path, or ``None``.

    Surrounding whitespace is ignored and blank values give ``None``. A value
    that is absolute after ``~`` expansion is checked directly. Anything else
    is searched for via ``executable_command_path``.
    """
    candidate = raw_shell.strip()
    if candidate:
        expanded = Path(candidate).expanduser()
        if not expanded.is_absolute():
            return executable_command_path(candidate)
        return executable_path(expanded)
    return None
37
+
38
+
39
def user_default_shell() -> str | None:
    """Return the current account's login shell from the password database.

    Returns ``None`` when any of these hold:
    - ``pwd`` is unavailable (non-POSIX platforms).
    - The current uid has no entry.
    - The entry names a placeholder shell that cannot run commands.

    Service accounts commonly use ``nologin`` or ``false`` (or occasionally
    ``true``) as their shell. These are executable files, so ``shell_path``
    alone would accept them, and every ``<shell> -c`` command would then
    fail or silently do nothing. Rejecting them lets the caller fall back to
    ``$SHELL`` or bash/sh.
    """
    try:
        import pwd
    except ImportError:
        return None

    try:
        shell = pwd.getpwuid(os.getuid()).pw_shell
    except (AttributeError, KeyError, OSError):
        return None
    if os.path.basename(shell.strip()) in {"nologin", "false", "true"}:
        return None
    return shell_path(shell)
50
+
51
+
52
def environment_shell() -> str | None:
    """Return the shell named by ``$SHELL`` if it maps to an executable."""
    configured = os.environ.get("SHELL", "")
    return shell_path(configured)
54
+
55
+
56
# Well-known install locations probed (in order) when a shell is not on PATH.
FALLBACK_SHELL_PATHS: dict[str, list[Path]] = {
    "bash": [Path("/bin/bash"), Path("/usr/bin/bash")],
    "sh": [Path("/bin/sh"), Path("/usr/bin/sh")],
}
60
+
61
+
62
def fallback_shell(command: str) -> str | None:
    """Locate ``command`` on PATH, else at its entries in ``FALLBACK_SHELL_PATHS``."""
    on_path = executable_command_path(command)
    if on_path is not None:
        return on_path
    # Lazily probe the well-known locations, stopping at the first executable.
    probed = (
        executable_path(location)
        for location in FALLBACK_SHELL_PATHS.get(command, [])
    )
    return next((hit for hit in probed if hit is not None), None)
71
+
72
+
73
def default_shell() -> str:
    """Choose the shell used to run commands.

    Preference order:
    1. The account's login shell.
    2. ``$SHELL``.
    3. ``bash`` (PATH, then well-known paths).
    4. ``sh`` (PATH, then well-known paths).
    5. The bare name ``"sh"``.

    Both configured sources are evaluated before choosing.
    """
    configured = (user_default_shell(), environment_shell())
    chosen = next((shell for shell in configured if shell is not None), None)
    if chosen is not None:
        return chosen
    for name in ("bash", "sh"):
        located = fallback_shell(name)
        if located is not None:
            return located
    return "sh"
82
+
83
+
84
def shell_invocation(command: str) -> ShellInvocation:
    """Build the argv and env needed to run ``command`` via ``<shell> -c``.

    ``env`` carries ``SHELL`` set to the chosen shell (presumably so the
    command sees a ``$SHELL`` consistent with the shell actually running it).
    """
    chosen = default_shell()
    argv = [chosen, "-c", command]
    overrides = {"SHELL": chosen}
    return ShellInvocation(args=argv, env=overrides, shell=chosen)
91
+
92
+
93
def shell_invocation_description() -> str:
    """Describe how commands are run, e.g. ``"/bin/bash -c"``, for prompt context."""
    return " ".join((default_shell(), "-c"))