kash-shell 0.3.30__py3-none-any.whl → 0.3.33__py3-none-any.whl

@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+import threading
+from collections.abc import Callable
+
+from prettyfmt import fmt_lines, fmt_path
+
+from kash.config.logger import get_logger
+from kash.file_storage.store_filenames import join_suffix, parse_item_filename
+from kash.model.items_model import Item, ItemId
+from kash.model.paths_model import StorePath
+from kash.utils.common.uniquifier import Uniquifier
+from kash.utils.errors import InvalidFilename, SkippableError
+
+log = get_logger(__name__)
+
+
+class ItemIdIndex:
+    """
+    Index of item identities and historical filenames within a workspace.
+
+    - Tracks a mapping of `ItemId -> StorePath` for quick lookups
+    - Tracks historical slugs via `Uniquifier` to generate unique names consistently
+
+    TODO: Should add a file system watcher to keep this always consistent with disk state.
+    """
+
+    def __init__(self) -> None:
+        self._lock = threading.RLock()
+        self.uniquifier = Uniquifier()
+        self.id_map: dict[ItemId, StorePath] = {}
+
+    def reset(self) -> None:
+        """
+        Clear all index state.
+        """
+        with self._lock:
+            log.info("ItemIdIndex: reset")
+            self.uniquifier = Uniquifier()
+            self.id_map.clear()
+
+    def __len__(self) -> int:
+        """
+        Number of unique names tracked.
+        """
+        with self._lock:
+            return len(self.uniquifier)
+
+    def uniquify_slug(self, slug: str, full_suffix: str) -> tuple[str, list[str]]:
+        """
+        Return a unique slug and historic slugs for the given suffix.
+        """
+        with self._lock:
+            # This updates internal history as a side effect. Log for consistency.
+            log.info("ItemIdIndex: uniquify slug '%s' with suffix '%s'", slug, full_suffix)
+            return self.uniquifier.uniquify_historic(slug, full_suffix)
+
+    def index_item(
+        self, store_path: StorePath, load_item: Callable[[StorePath], Item]
+    ) -> StorePath | None:
+        """
+        Update the index with an item at `store_path`.
+        Returns the store path of any duplicate item with the same id, otherwise None.
+        """
+        name, item_type, _format, file_ext = parse_item_filename(store_path)
+        if not file_ext:
+            log.debug(
+                "Skipping file with unrecognized name or extension: %s",
+                fmt_path(store_path),
+            )
+            return None
+
+        with self._lock:
+            full_suffix = join_suffix(item_type.name, file_ext.name) if item_type else file_ext.name
+            # Track unique name history.
+            self.uniquifier.add(name, full_suffix)
+
+            log.info("ItemIdIndex: indexing %s", fmt_path(store_path))
+
+        # Load the item outside the lock to avoid holding it during potentially slow I/O.
+        try:
+            item = load_item(store_path)
+        except (ValueError, SkippableError) as e:
+            log.warning(
+                "ItemIdIndex: could not index file, skipping: %s: %s",
+                fmt_path(store_path),
+                e,
+            )
+            return None
+
+        dup_path: StorePath | None = None
+        with self._lock:
+            item_id = item.item_id()
+            if item_id:
+                old_path = self.id_map.get(item_id)
+                if old_path and old_path != store_path:
+                    dup_path = old_path
+                    log.info(
+                        "ItemIdIndex: duplicate id detected %s:\n%s",
+                        item_id,
+                        fmt_lines([old_path, store_path]),
+                    )
+                self.id_map[item_id] = store_path
+                log.info("ItemIdIndex: set id %s -> %s", item_id, fmt_path(store_path))
+
+        return dup_path
+
+    def unindex_item(self, store_path: StorePath, load_item: Callable[[StorePath], Item]) -> None:
+        """
+        Remove an item from the id index.
+        """
+        try:
+            # Load the item outside the lock to avoid holding it during potentially slow I/O.
+            item = load_item(store_path)
+            item_id = item.item_id()
+            if item_id:
+                with self._lock:
+                    # pop() with a default never raises, so no KeyError handling is needed.
+                    self.id_map.pop(item_id, None)
+                    log.info("ItemIdIndex: removed id %s for %s", item_id, fmt_path(store_path))
+        except (FileNotFoundError, InvalidFilename):
+            pass
+
+    def find_store_path_by_id(self, item_id: ItemId) -> StorePath | None:
+        with self._lock:
+            return self.id_map.get(item_id)
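
For orientation, a minimal sketch of how this new index might be used. The `load_from_store` helper and the store path are hypothetical stand-ins, not code from this package:

    # Hypothetical usage of ItemIdIndex; load_from_store and the path are illustrative.
    index = ItemIdIndex()

    def load_from_store(path: StorePath) -> Item:
        """Assumed helper that reads and parses the item file at `path`."""
        ...

    # If another file already maps to the same ItemId, index_item returns
    # that earlier path so callers can detect duplicates.
    dup = index.index_item(StorePath("notes/idea.doc.md"), load_from_store)
    if dup is not None:
        print(f"Duplicate of item already indexed at: {dup}")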
@@ -6,8 +6,10 @@ from dataclasses import dataclass
 
 from funlog import log_calls
 from mcp.server.lowlevel import Server
+from mcp.server.lowlevel.server import StructuredContent, UnstructuredContent
 from mcp.types import Prompt, Resource, TextContent, Tool
 from prettyfmt import fmt_path
+from pydantic import BaseModel
 from strif import AtomicVar
 
 from kash.config.capture_output import CapturedOutput, captured_output
@@ -20,6 +22,7 @@ from kash.model.actions_model import Action, ActionResult
 from kash.model.exec_model import ExecContext
 from kash.model.params_model import TypedParamValues
 from kash.model.paths_model import StorePath
+from kash.utils.common.url import Url
 
 log = get_logger(__name__)
 
@@ -109,6 +112,22 @@ def get_published_tools() -> list[Tool]:
     return []
 
 
+class StructuredActionResult(BaseModel):
+    """
+    Structured result from an MCP tool call, including any S3 paths created
+    and any error message.
+    """
+
+    s3_paths: list[Url] | None = None
+    """If the tool created an S3 item, the S3 paths of the created items."""
+
+    error: str | None = None
+    """If the tool had an error, the error message."""
+
+    # TODO: Include other metadata.
+    # metadata: dict[str, Any] | None = None
+    # """Metadata about the action result."""
+
+
 @dataclass(frozen=True)
 class ToolResult:
     """
@@ -119,6 +138,7 @@ class ToolResult:
     captured_output: CapturedOutput
     action_result: ActionResult
     result_store_paths: list[StorePath]
+    result_s3_paths: list[Url]
     error: Exception | None = None
 
     @property
@@ -168,12 +188,13 @@ class ToolResult:
         # TODO: Add more info on how to find the logs.
         return "Check kash logs for details."
 
-    def formatted_for_client(self) -> list[TextContent]:
+    def as_mcp_content(self) -> tuple[UnstructuredContent, StructuredContent]:
         """
-        Convert the tool result to content for the client LLM.
+        Convert the tool result to content for the MCP client.
         """
+        structured = StructuredActionResult()
         if self.error:
-            return [
+            unstructured = [
                 TextContent(
                     text=f"The tool `{self.action.name}` had an error: {self.error}.\n\n"
                     + self.check_logs_message,
@@ -194,7 +215,7 @@ class ToolResult:
         if not chat_result:
             chat_result = "No result. Check kash logs for details."
 
-        return [
+        unstructured = [
             TextContent(
                 text=f"{self.output_summary}\n\n"
                 f"{self.output_content}\n\n"
@@ -202,10 +223,15 @@ class ToolResult:
                 type="text",
             ),
         ]
+        structured = StructuredActionResult(s3_paths=self.result_s3_paths)
+
+        return unstructured, structured.model_dump()
 
 
 @log_calls(level="info")
-def run_mcp_tool(action_name: str, arguments: dict) -> list[TextContent]:
+def run_mcp_tool(
+    action_name: str, arguments: dict
+) -> tuple[UnstructuredContent, StructuredContent]:
     """
     Run the action as a tool.
     """
@@ -222,6 +248,7 @@ def run_mcp_tool(action_name: str, arguments: dict) -> list[TextContent]:
         refetch=False,  # Using the file caches.
         # Keeping all transient files for now, but maybe make transient?
         override_state=None,
+        sync_to_s3=True,  # Enable S3 syncing for MCP tools.
     ) as exec_settings:
         action_cls = look_up_action_class(action_name)
 
@@ -237,9 +264,9 @@ def run_mcp_tool(action_name: str, arguments: dict) -> list[TextContent]:
         context = ExecContext(action=action, settings=exec_settings)
         action_input = prepare_action_input(*input_items)
 
-        result, result_store_paths, _archived_store_paths = run_action_with_caching(
-            context, action_input
-        )
+        result_with_paths = run_action_with_caching(context, action_input)
+        result = result_with_paths.result
+        result_store_paths = result_with_paths.result_paths
 
         # Return final result, formatted for the LLM to understand.
         return ToolResult(
@@ -247,8 +274,9 @@ def run_mcp_tool(action_name: str, arguments: dict) -> list[TextContent]:
             captured_output=capture.output,
             action_result=result,
             result_store_paths=result_store_paths,
+            result_s3_paths=result_with_paths.s3_paths,
             error=None,
-        ).formatted_for_client()
+        ).as_mcp_content()
 
     except Exception as e:
         log.exception("Error running mcp tool")
@@ -258,7 +286,7 @@ def run_mcp_tool(action_name: str, arguments: dict) -> list[TextContent]:
                 + "Check kash logs for details.",
                 type="text",
             )
-        ]
+        ], StructuredActionResult(error=str(e)).model_dump()
 
 
 def create_base_server() -> Server:
@@ -288,7 +316,9 @@ def create_base_server() -> Server:
         return []
 
     @app.call_tool()
-    async def handle_tool(name: str, arguments: dict) -> list[TextContent]:
+    async def handle_tool(
+        name: str, arguments: dict
+    ) -> tuple[UnstructuredContent, StructuredContent]:
         try:
             if name not in _mcp_published_actions.copy():
                 log.error(f"Unknown tool requested: {name}")
@@ -303,6 +333,6 @@ def create_base_server() -> Server:
                     text=f"Error executing tool {name}: {e}",
                     type="text",
                 )
-            ]
+            ], StructuredActionResult(error=str(e)).model_dump()
 
     return app
@@ -246,7 +246,17 @@ class Action(ABC):
 
     output_type: ItemType = ItemType.doc
     """
-    The type of the output item(s), which for now are all assumed to be of the same type.
+    The type of the output item(s). If an action returns multiple output types,
+    this will be the output type of the first output.
+    This is mainly used in preassembly, for the cache check of whether an output already exists.
+    """
+
+    output_format: Format | None = None
+    """
+    The format of the output item(s). The default is to assume it is the same
+    format as the input. If an action returns multiple output formats,
+    this will be the format of the first output.
+    This is mainly used in preassembly, for the cache check of whether an output already exists.
     """
 
     expected_outputs: ArgCount = ONE_ARG
@@ -540,7 +550,7 @@ class Action(ABC):
         """
         can_preassemble = self.cacheable and self.expected_outputs == ONE_ARG
         log.info(
-            "Preassemble check for `%s` is %s (%s with cacheable=%s)",
+            "Preassemble check for `%s`: can_preassemble=%s (expected_outputs=%s, cacheable=%s)",
             self.name,
             can_preassemble,
             self.expected_outputs,
@@ -549,9 +559,10 @@ class Action(ABC):
         if can_preassemble:
             # Using first input to determine the output title.
             primary_input = context.action_input.items[0]
-            # In this case we only expect one output.
-            item = primary_input.derived_copy(context, 0)
-            return ActionResult([item])
+            # In this case we only expect one output, of the type specified by the action.
+            primary_output = primary_input.derived_copy(context, 0, type=context.action.output_type)
+            log.info("Preassembled output: source %s, %s", primary_output.source, primary_output)
+            return ActionResult([primary_output])
         else:
             # Caching disabled.
             return None
kash/model/exec_model.py CHANGED
@@ -43,6 +43,9 @@ class RuntimeSettings:
     no_format: bool = False
     """If True, will not normalize the output item's body text formatting (for Markdown)."""
 
+    sync_to_s3: bool = True
+    """If True, will sync output items to S3 if the input was from S3."""
+
     @property
     def workspace(self) -> FileStore:
         from kash.workspaces.workspaces import get_ws
kash/model/items_model.py CHANGED
@@ -203,6 +203,15 @@ class ItemId:
         # If we got here, the item has no identity.
         item_id = None
 
+        log.debug(
+            "item_id is %s for type=%s, format=%s, url=%s, title=%s, source=%s",
+            item_id,
+            item.type,
+            item.format,
+            item.url,
+            item.title,
+            item.source,
+        )
         return item_id
 
 
@@ -835,7 +844,9 @@ class Item:
         the type and the body.
 
         Same as `new_copy_with` but also updates the `derived_from` relation. If we also
-        have an action context, then use the `title_template` to derive a new title.
+        have an action context, then use that to fill in some fields, in particular
+        `title_template` to derive a new title, and `output_type` and `output_format`
+        to set the output type and format.
         """
 
         # Get derived_from relation if possible.
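
A hedged sketch of what this docstring describes; the item, context, and call site here are hypothetical:

    # Inside an action run: derive the output item from the input.
    # With an action context available, derived_copy fills in the title via
    # title_template, applies the action's declared output_type/output_format,
    # and records the derived_from relation and the operation in history.
    output_item = input_item.derived_copy(context, 0)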
@@ -869,20 +880,38 @@ class Item:
         if "external_path" not in updates:
             updates["external_path"] = None
 
+        action_context = action_context or self.context
+
+        if action_context:
+            # Default the output item type and format to the action's declared output_type
+            # and format if not explicitly set.
+            if "type" not in updates:
+                updates["type"] = action_context.action.output_type
+            # If we were not given a format override, we leave the output type the same.
+            elif action_context.action.output_format:
+                # Check an overridden format and then our own format.
+                new_output_format = updates.get("format", self.format)
+                if new_output_format and action_context.action.output_format != new_output_format:
+                    log.warning(
+                        "Output item format `%s` does not match declared output format `%s` for action `%s`",
+                        new_output_format,
+                        action_context.action.output_format,
+                        action_context.action.name,
+                    )
+
         new_item = self.new_copy_with(update_timestamp=True, **updates)
         if derived_from:
             new_item.update_relations(derived_from=derived_from)
 
-        action_context = action_context or self.context
-
         # Record the history.
         if action_context:
-            self.source = Source(
-                operation=action_context.operation,
-                output_num=output_num,
-                cacheable=action_context.action.cacheable,
+            new_item.update_source(
+                Source(
+                    operation=action_context.operation,
+                    output_num=output_num,
+                    cacheable=action_context.action.cacheable,
+                )
             )
-            self.add_to_history(self.source.operation.summary())
             action = action_context.action
         else:
             action = None
@@ -911,9 +940,10 @@ class Item:
             setattr(self.relations, key, list(value))
         return self.relations
 
-    def update_history(self, source: Source) -> None:
+    def update_source(self, source: Source) -> None:
         """
-        Update the history of the item with the given operation.
+        Update the source and the history of the item to indicate it was created
+        by the given operation. For convenience, this is idempotent.
         """
         self.source = source
         self.add_to_history(source.operation.summary())
@@ -945,6 +975,9 @@ class Item:
         return metadata_matches and body_matches
 
     def add_to_history(self, operation_summary: OperationSummary):
+        """
+        Add the operation to the item's history. For convenience, this is
+        idempotent: duplicate entries are not added.
+        """
         if not self.history:
             self.history = []
         # Don't add duplicates to the history.
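
A small sketch of the idempotence this documents; `op` is an assumed Operation value:

    # Calling update_source twice with the same operation sets the source
    # both times but adds its summary to the history only once.
    src = Source(operation=op, output_num=0, cacheable=True)
    item.update_source(src)
    item.update_source(src)  # No duplicate history entry.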
@@ -542,6 +542,8 @@ async def gather_limited_sync(
             # Mark as failed
             if status and task_id is not None:
                 await status.finish(task_id, TaskState.FAILED, str(e))
+
+            log.warning("Task failed: %s: %s", label, e, exc_info=True)
             raise
 
     return await _gather_with_interrupt_handling(
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, TypeVar
+
+from kash.config.logger import get_logger
+from kash.config.settings import global_settings
+from kash.shell.output.shell_output import multitask_status
+from kash.utils.api_utils.api_retries import RetrySettings
+from kash.utils.api_utils.gather_limited import FuncTask, Limit, gather_limited_sync
+
+T = TypeVar("T")
+
+log = get_logger(name=__name__)
+
+
+def _default_labeler(total: int) -> Callable[[int, Any], str]:
+    def labeler(i: int, _spec: Any) -> str:  # pyright: ignore[reportUnusedParameter]
+        return f"Task {i + 1}/{total}"
+
+    return labeler
+
+
+async def multitask_gather(
+    tasks: Iterable[FuncTask[T]] | Sequence[FuncTask[T]],
+    *,
+    labeler: Callable[[int, Any], str] | None = None,
+    limit: Limit | None = None,
+    bucket_limits: dict[str, Limit] | None = None,
+    retry_settings: RetrySettings | None = None,
+    show_progress: bool = True,
+) -> list[T]:
+    """
+    Run many `FuncTask`s concurrently with shared progress UI and rate limits.
+
+    This wraps the standard pattern of creating a status context, providing a labeler,
+    and calling `gather_limited_sync` with common options.
+
+    - `labeler` can be omitted; a simple "Task X/Y" label will be used.
+    - If `limit` is not provided, defaults are taken from `global_settings()`.
+    - If `show_progress` is False, tasks are run without the status context.
+    - By default, exceptions are returned as results rather than raised (return_exceptions=True).
+    """
+
+    # Normalize tasks to a list for length and stable iteration.
+    task_list: list[FuncTask[T]] = list(tasks)
+
+    # Provide a default labeler if none is supplied.
+    effective_labeler: Callable[[int, Any], str] = (
+        labeler if labeler is not None else _default_labeler(len(task_list))
+    )
+
+    # Provide sensible default rate limits if none are supplied.
+    effective_limit: Limit = (
+        limit
+        if limit is not None
+        else Limit(
+            rps=global_settings().limit_rps,
+            concurrency=global_settings().limit_concurrency,
+        )
+    )
+
+    if not show_progress:
+        log.warning("Running %d tasks (progress disabled)…", len(task_list))
+
+    async with multitask_status(enabled=show_progress) as status:
+        return await gather_limited_sync(
+            *task_list,
+            limit=effective_limit,
+            bucket_limits=bucket_limits,
+            status=status,
+            labeler=effective_labeler,
+            retry_settings=retry_settings,
+        )
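
A hypothetical usage sketch; `FuncTask` is assumed here to wrap a sync callable with its arguments, and `fetch_title` and the URLs are invented:

    import asyncio

    def fetch_title(url: str) -> str:
        ...  # Some blocking per-task work.

    async def main() -> None:
        urls = ["https://example.com/a", "https://example.com/b"]
        tasks = [FuncTask(fetch_title, url) for url in urls]
        # Uses default rate limits from global_settings() and "Task i/N"
        # labels, with the shared progress UI enabled.
        results = await multitask_gather(tasks)

    asyncio.run(main())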
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import shutil
+import subprocess
+from pathlib import Path
+
+from sidematter_format.sidematter_format import Sidematter
+
+from kash.utils.common.url import Url, is_s3_url, parse_s3_url
+
+
+def check_aws_cli() -> None:
+    """
+    Check that the AWS CLI is installed and available.
+    """
+    if shutil.which("aws") is None:
+        raise RuntimeError(
+            "AWS CLI not found in PATH. Please install 'awscli' and ensure 'aws' is available."
+        )
+
+
+def get_s3_parent_folder(url: Url) -> Url | None:
+    """
+    Get the parent folder of an S3 URL, or None if not an S3 URL.
+    """
+    if is_s3_url(url):
+        s3_bucket, s3_key = parse_s3_url(url)
+        s3_parent_folder = Path(s3_key).parent
+        return Url(f"s3://{s3_bucket}/{s3_parent_folder}")
+    else:
+        return None
+
+
+def s3_sync_to_folder(
+    src_path: str | Path,
+    s3_dest_parent: Url,
+    *,
+    include_sidematter: bool = False,
+) -> list[Url]:
+    """
+    Sync a local file or directory to an S3 "parent" folder using the AWS CLI.
+    Set `include_sidematter` to include sidematter files alongside the source files.
+
+    Returns a list of S3 URLs that were the top-level sync targets:
+    - For a single file: the file URL (and sidematter file/dir URLs if included).
+    - For a directory: the destination parent prefix URL (non-recursive reporting).
+    """
+    src_path = Path(src_path)
+    if not src_path.exists():
+        raise ValueError(f"Source path does not exist: {src_path}")
+    if not is_s3_url(s3_dest_parent):
+        raise ValueError(f"Destination must be an s3:// URL: {s3_dest_parent}")
+
+    check_aws_cli()
+
+    dest_prefix = str(s3_dest_parent).rstrip("/") + "/"
+    targets: list[Url] = []
+
+    if src_path.is_file():
+        # Build the list of paths to sync, using Sidematter's resolved path_list if requested.
+        sync_paths: list[Path]
+        if include_sidematter:
+            resolved = Sidematter(src_path).resolve(parse_meta=False, use_frontmatter=False)
+            sync_paths = resolved.path_list
+        else:
+            sync_paths = [src_path]
+
+        for p in sync_paths:
+            if p.is_file():
+                # Use sync with include/exclude to leverage default short-circuiting.
+                subprocess.run(
+                    [
+                        "aws",
+                        "s3",
+                        "sync",
+                        str(p.parent),
+                        dest_prefix,
+                        "--exclude",
+                        "*",
+                        "--include",
+                        p.name,
+                    ],
+                    check=True,
+                )
+                targets.append(Url(dest_prefix + p.name))
+            elif p.is_dir():
+                dest_dir = dest_prefix + p.name + "/"
+                subprocess.run(["aws", "s3", "sync", str(p), dest_dir], check=True)
+                targets.append(Url(dest_dir))
+
+        return targets
+    else:
+        # Directory mode: sync the whole directory.
+        subprocess.run(
+            [
+                "aws",
+                "s3",
+                "sync",
+                str(src_path),
+                dest_prefix,
+            ],
+            check=True,
+        )
+        targets.append(Url(dest_prefix))
+        return targets
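
A hypothetical invocation (bucket and paths invented). Since this shells out to `aws s3 sync`, the AWS CLI must be installed and configured:

    # Sync one exported file, plus its sidematter, to an S3 prefix.
    dest = Url("s3://my-bucket/exports")
    urls = s3_sync_to_folder("out/report.doc.md", dest, include_sidematter=True)
    # urls would contain e.g. Url("s3://my-bucket/exports/report.doc.md"),
    # plus URLs for any sidematter file or directory that was synced.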
kash/utils/common/url.py CHANGED
@@ -26,6 +26,7 @@ A string that may not be resolved to a URL or path.
 
 HTTP_ONLY = ["http", "https"]
 HTTP_OR_FILE = HTTP_ONLY + ["file"]
+HTTP_OR_FILE_OR_S3 = HTTP_OR_FILE + ["s3"]
 
 
 def check_if_url(
@@ -36,7 +37,8 @@ def check_if_url(
     the `urlparse.ParseResult`.
 
     Also returns false for Paths, so that it's easy to use local paths and URLs
-    (`Locator`s) interchangeably. Can provide `HTTP_ONLY` or `HTTP_OR_FILE` to
-    restrict to only certain schemes.
+    (`Locator`s) interchangeably. Can provide `HTTP_ONLY`, `HTTP_OR_FILE`,
+    or `HTTP_OR_FILE_OR_S3` to restrict to only certain schemes.
     """
     if isinstance(text, Path):
@@ -69,6 +71,13 @@ def is_file_url(url: str | Url) -> bool:
     return url.startswith("file://")
 
 
+def is_s3_url(url: str | Url) -> bool:
+    """
+    Is the URL an S3 URL?
+    """
+    return url.startswith("s3://")
+
+
 def parse_http_url(url: str | Url) -> ParseResult:
     """
     Parse an http/https URL and return the parsed result, raising ValueError if
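
Quick illustrative checks of the new helper and the widened scheme list:

    assert is_s3_url("s3://my-bucket/some/key.md")
    assert not is_s3_url("https://example.com/page")
    # Via HTTP_OR_FILE_OR_S3, normalize_url below now accepts s3:// URLs by default.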
@@ -118,7 +127,7 @@ def as_file_url(path: str | Path) -> Url:
 
 def normalize_url(
     url: Url,
-    check_schemes: list[str] | None = HTTP_OR_FILE,
+    check_schemes: list[str] | None = HTTP_OR_FILE_OR_S3,
     drop_fragment: bool = True,
     resolve_local_paths: bool = True,
 ) -> Url:
@@ -238,7 +247,10 @@ def test_normalize_url():
         normalize_url(url=Url("/not/a/URL"))
         raise AssertionError()
     except ValueError as e:
-        assert str(e) == "Scheme '' not in allowed schemes: ['http', 'https', 'file']: /not/a/URL"
+        assert (
+            str(e)
+            == "Scheme '' not in allowed schemes: ['http', 'https', 'file', 's3']: /not/a/URL"
+        )
 
     try:
         normalize_url(Url("ftp://example.com"))
@@ -246,7 +258,7 @@ def test_normalize_url():
     except ValueError as e:
         assert (
             str(e)
-            == "Scheme 'ftp' not in allowed schemes: ['http', 'https', 'file']: ftp://example.com"
+            == "Scheme 'ftp' not in allowed schemes: ['http', 'https', 'file', 's3']: ftp://example.com"
         )
 