kash-shell 0.3.33 → 0.3.35 (py3-none-any.whl)

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
```diff
@@ -74,26 +74,108 @@ def import_recursive(
     return tallies
 
 
+def _import_modules_from_package(
+    package: types.ModuleType,
+    package_name: str,
+    max_depth: int = 1,
+    include_private: bool = True,
+    current_depth: int = 0,
+    imported_modules: dict[str, types.ModuleType] | None = None,
+) -> dict[str, types.ModuleType]:
+    """
+    Internal helper to recursively import modules from a package.
+
+    Args:
+        package: The package module to import from
+        package_name: The fully qualified name of the package
+        max_depth: Maximum recursion depth (1 = direct children only)
+        include_private: Whether to import private modules (starting with _)
+        current_depth: Current recursion depth (internal use)
+        imported_modules: Dictionary to accumulate imported modules
+
+    Returns:
+        Dictionary mapping module names to their imported module objects
+    """
+    if imported_modules is None:
+        imported_modules = {}
+
+    if current_depth >= max_depth:
+        return imported_modules
+
+    # Get the module's __path__ if it's a package
+    if not hasattr(package, "__path__"):
+        return imported_modules
+
+    try:
+        for _finder, module_name, ispkg in pkgutil.iter_modules(
+            package.__path__, f"{package_name}."
+        ):
+            # Skip private modules unless requested
+            if not include_private and module_name.split(".")[-1].startswith("_"):
+                continue
+
+            # Skip test modules - they often have special import requirements
+            # and aren't needed for warming the import cache
+            module_parts = module_name.split(".")
+            if any(
+                part in ("tests", "test", "testing", "_test", "_tests") for part in module_parts
+            ):
+                continue
+
+            # Skip already imported modules
+            if module_name in imported_modules:
+                continue
+
+            try:
+                module = importlib.import_module(module_name)
+                imported_modules[module_name] = module
+
+                # Recursively import submodules if it's a package
+                if ispkg and current_depth + 1 < max_depth:
+                    _import_modules_from_package(
+                        module,
+                        module_name,
+                        max_depth=max_depth,
+                        include_private=include_private,
+                        current_depth=current_depth + 1,
+                        imported_modules=imported_modules,
+                    )
+
+            except Exception as e:
+                # Handle various import failures gracefully
+                # This includes ImportError, pytest.Skipped, and other exceptions
+                error_type = type(e).__name__
+                if error_type not in ("ImportError", "AttributeError", "TypeError"):
+                    log.debug(f" Skipped {module_name}: {error_type}: {e}")
+                # Don't log common/expected import errors to reduce noise
+
+    except Exception as e:
+        log.warning(f"Error iterating modules in {package_name}: {e}")
+
+    return imported_modules
+
+
 def import_namespace_modules(namespace: str) -> dict[str, types.ModuleType]:
     """
     Find and import all modules or packages within a namespace package.
     Returns a dictionary mapping module names to their imported module objects.
     """
-    importlib.import_module(namespace)  # Propagate import errors
+    # Import the main module first
+    main_module = importlib.import_module(namespace)  # Propagate import errors
 
     # Get the package to access its __path__
-    package = sys.modules.get(namespace)
-    if not package or not hasattr(package, "__path__"):
+    if not hasattr(main_module, "__path__"):
         raise ImportError(f"`{namespace}` is not a package or namespace package")
 
-    log.info(f"Discovering modules in `{namespace}` namespace, searching: {package.__path__}")
+    log.info(f"Discovering modules in `{namespace}` namespace, searching: {main_module.__path__}")
+
+    # Use the common helper with depth=1 (no recursion) and include_private=True
+    modules = _import_modules_from_package(
+        main_module, namespace, max_depth=1, include_private=True
+    )
 
-    # Iterate through all modules in the namespace package
-    modules = {}
-    for _finder, module_name, _ispkg in pkgutil.iter_modules(package.__path__, f"{namespace}."):
-        module = importlib.import_module(module_name)  # Propagate import errors
-        log.info(f"Imported module: {module_name} from {module.__file__}")
-        modules[module_name] = module
+    # Add the main module itself
+    modules[namespace] = main_module
 
     log.info(f"Imported {len(modules)} modules from namespace `{namespace}`")
     return modules
```
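To make the behavioral change concrete: `import_namespace_modules` now returns the namespace package itself alongside its direct children. A minimal usage sketch, not part of the diff; the helper's import path and the namespace name are assumptions for illustration.

```python
# Hypothetical sketch: the module path and namespace name below are assumed,
# not shown in this diff.
from kash.utils.common.import_utils import import_namespace_modules  # path assumed

modules = import_namespace_modules("kash.actions")  # namespace name illustrative
assert "kash.actions" in modules  # the main module itself is now included
for name, mod in sorted(modules.items()):
    print(name, getattr(mod, "__file__", "<namespace>"))
```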
```diff
@@ -106,8 +188,13 @@ def recursive_reload(
     Recursively reload all modules in the given package that match the filter function.
     Returns a list of module names that were reloaded.
 
-    :param filter_func: A function that takes a module name and returns True if the
-        module should be reloaded.
+    Args:
+        package: The package to reload.
+        filter_func: A function that takes a module name and returns True if the
+            module should be reloaded.
+
+    Returns:
+        List of module names that were reloaded.
     """
     package_name = package.__name__
     modules = {
```
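Since the docstring now documents both arguments, a hedged usage sketch may help; the helper's module path and the subtree name are assumptions for illustration.

```python
# Hypothetical sketch: reload a package's modules, skipping one subtree.
import kash  # the package object to reload (illustrative)

from kash.utils.common.import_utils import recursive_reload  # path assumed

reloaded = recursive_reload(
    kash,
    filter_func=lambda name: not name.startswith("kash.web_content"),  # subtree illustrative
)
print(f"Reloaded {len(reloaded)} modules")
```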
```diff
@@ -124,3 +211,40 @@ def recursive_reload(
         importlib.reload(modules[name])
 
     return module_names
+
+
+def warm_import_library(
+    library_name: str, max_depth: int = 3, include_private: bool = False
+) -> dict[str, types.ModuleType]:
+    """
+    Recursively import all submodules of a library to warm the import cache.
+    This is useful for servers where you want to pay the import cost upfront
+    rather than during request handling.
+
+    Args:
+        library_name: Name of the library to import (e.g., 'litellm', 'openai')
+        max_depth: Maximum depth to recurse into submodules
+        include_private: Whether to import private modules (starting with _)
+
+    Returns:
+        Dictionary mapping module names to their imported module objects
+    """
+    try:
+        # Import the main module first
+        main_module = importlib.import_module(library_name)
+
+        # Use the common helper for recursive imports
+        imported_modules = _import_modules_from_package(
+            main_module, library_name, max_depth=max_depth, include_private=include_private
+        )
+
+        # Add the main module itself
+        imported_modules[library_name] = main_module
+
+    except ImportError as e:
+        log.warning(f"Could not import {library_name}: {e}")
+        return {}
+
+    log.info(f"Warmed {len(imported_modules)} modules from {library_name}")
+
+    return imported_modules
```
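The warming helper is the headline addition in this file. A minimal sketch of how a server might call it at startup; the import path is assumed, and `'litellm'` comes from the docstring's own example.

```python
# Hypothetical server-startup sketch (module path assumed, not shown in this diff).
from kash.utils.common.import_utils import warm_import_library  # path assumed

# Pay import costs once at startup rather than on the first request.
warmed = warm_import_library("litellm", max_depth=3)
print(f"{len(warmed)} modules pre-imported")  # returns {} if litellm isn't installed
```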
```diff
@@ -1,13 +1,19 @@
 from __future__ import annotations
 
+import os
 import shutil
 import subprocess
+from logging import getLogger
 from pathlib import Path
 
+from dotenv import find_dotenv, load_dotenv
 from sidematter_format.sidematter_format import Sidematter
+from strif import abbrev_str
 
 from kash.utils.common.url import Url, is_s3_url, parse_s3_url
 
+log = getLogger(__name__)
+
 
 def check_aws_cli() -> None:
     """
@@ -19,6 +25,54 @@ def check_aws_cli() -> None:
         )
 
 
+def run_aws_command(cmd: list[str]) -> subprocess.CompletedProcess[str]:
+    """
+    Run an AWS CLI command and capture output.
+    Raises a RuntimeError with stdout/stderr on failure.
+    """
+    result = subprocess.run(
+        cmd,
+        capture_output=True,
+        text=True,
+        env=os.environ,
+    )
+
+    if result.returncode != 0:
+        # Build a detailed error message
+        error_parts = [f"AWS command failed with exit code {result.returncode}"]
+        error_parts.append(f"Command: {' '.join(cmd)}")
+
+        if result.stdout:
+            error_parts.append(f"stdout: {result.stdout}")
+        if result.stderr:
+            error_parts.append(f"stderr: {result.stderr}")
+
+        raise RuntimeError("\n".join(error_parts))
+
+    return result
+
+
+def reload_aws_env_vars() -> None:
+    """
+    Fresh reload of AWS env vars from .env.local.
+    """
+
+    def aws_creds() -> set[tuple[str, str]]:
+        return {(k, abbrev_str(v, 5)) for k, v in os.environ.items() if k.startswith("AWS_")}
+
+    if len(aws_creds()) == 0:
+        dotenv_path = find_dotenv(".env.local", usecwd=True) or find_dotenv(".env", usecwd=True)
+        load_dotenv(dotenv_path, override=True)
+        if len(aws_creds()) > 0:
+            log.info(
+                "Loaded %s, found AWS credentials: %s",
+                dotenv_path,
+                aws_creds(),
+            )
+        else:
+            log.warning("No AWS credentials found in env or .env files")
+
+
 def get_s3_parent_folder(url: Url) -> Url | None:
     """
     Get the parent folder of an S3 URL, or None if not an S3 URL.
```
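Compared with `subprocess.run(..., check=True)`, which raises `CalledProcessError` without surfacing the captured output, the new wrapper carries the command, stdout, and stderr in the exception message. A hedged sketch; the module path is assumed, and the bucket name is illustrative.

```python
# Sketch of the failure mode the wrapper improves (module path assumed).
from kash.utils.file_utils.s3_utils import run_aws_command  # path assumed

try:
    run_aws_command(["aws", "s3", "ls", "s3://nonexistent-bucket-example"])
except RuntimeError as e:
    print(e)  # includes exit code, command line, and any captured stdout/stderr
```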
```diff
@@ -47,6 +101,7 @@ def s3_sync_to_folder(
     - For a single file: the file URL (and sidematter file/dir URLs if included).
     - For a directory: the destination parent prefix URL (non-recursive reporting).
     """
+    reload_aws_env_vars()
 
     src_path = Path(src_path)
     if not src_path.exists():
@@ -71,7 +126,7 @@ def s3_sync_to_folder(
         for p in sync_paths:
             if p.is_file():
                 # Use sync with include/exclude to leverage default short-circuiting
-                subprocess.run(
+                run_aws_command(
                     [
                         "aws",
                         "s3",
@@ -82,27 +137,54 @@ def s3_sync_to_folder(
                         "*",
                         "--include",
                         p.name,
-                    ],
-                    check=True,
+                    ]
                 )
                 targets.append(Url(dest_prefix + p.name))
             elif p.is_dir():
                 dest_dir = dest_prefix + p.name + "/"
-                subprocess.run(["aws", "s3", "sync", str(p), dest_dir], check=True)
+                run_aws_command(["aws", "s3", "sync", str(p), dest_dir])
                 targets.append(Url(dest_dir))
 
         return targets
     else:
         # Directory mode: sync whole directory.
-        subprocess.run(
+        run_aws_command(
             [
                 "aws",
                 "s3",
                 "sync",
                 str(src_path),
                 dest_prefix,
-            ],
-            check=True,
+            ]
         )
         targets.append(Url(dest_prefix))
         return targets
+
+
+def s3_download_file(s3_url: Url, target_path: str | Path) -> None:
+    """
+    Download a file from S3 to a local path using the AWS CLI.
+
+    Args:
+        s3_url: The S3 URL to download from (s3://bucket/path/to/file)
+        target_path: The local path to save the file to
+    """
+    reload_aws_env_vars()
+
+    if not is_s3_url(s3_url):
+        raise ValueError(f"Source must be an s3:// URL: {s3_url}")
+
+    check_aws_cli()
+
+    target_path = Path(target_path)
+
+    # Use aws s3 cp to download the file
+    run_aws_command(
+        [
+            "aws",
+            "s3",
+            "cp",
+            str(s3_url),
+            str(target_path),
+        ]
+    )
```
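A minimal sketch of the new download helper in use; the module path, bucket, and key below are illustrative, not taken from this diff.

```python
# Hypothetical usage sketch (module path and S3 URL are illustrative).
from pathlib import Path

from kash.utils.common.url import Url
from kash.utils.file_utils.s3_utils import s3_download_file  # path assumed

s3_download_file(Url("s3://example-bucket/reports/summary.md"), Path("summary.md"))
```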
```diff
@@ -72,6 +72,8 @@ RUNNING_SYMBOL = ""
 DEFAULT_LABEL_WIDTH = 40
 DEFAULT_PROGRESS_WIDTH = 20
 
+MAX_DISPLAY_TASKS = 20
+
 
 # Calculate spinner width to maintain column alignment
 def _get_spinner_width(spinner_name: str) -> int:
@@ -101,6 +103,9 @@ class StatusSettings:
     transient: bool = True
     refresh_per_second: float = 10
     styles: StatusStyles = DEFAULT_STYLES
+    # Maximum number of tasks to keep visible in the live display.
+    # Older completed/skipped/failed tasks beyond this cap will be removed from the live view.
+    max_display_tasks: int = MAX_DISPLAY_TASKS
 
 
 class SpinnerStatusColumn(ProgressColumn):
@@ -298,6 +303,10 @@ class MultiTaskStatus(AbstractAsyncContextManager):
         self._task_info: dict[int, TaskInfo] = {}
         self._next_id: int = 1
         self._rich_task_ids: dict[int, TaskID] = {}  # Map our IDs to Rich Progress IDs
+        # Track order of tasks added to the Progress so we can prune oldest completed ones
+        self._displayed_task_order: list[int] = []
+        # Track tasks pruned from the live display so we don't re-add them later
+        self._pruned_task_ids: set[int] = set()
 
         # Unified live integration
         self._unified_live: Any | None = None  # Reference to the global unified live
@@ -442,6 +451,10 @@ class MultiTaskStatus(AbstractAsyncContextManager):
             progress_display=None,
         )
         self._rich_task_ids[task_id] = rich_task_id
+        self._displayed_task_order.append(task_id)
+
+        # Prune if too many tasks are visible (prefer removing completed ones)
+        self._prune_completed_tasks_if_needed()
 
     async def set_progress_display(self, task_id: int, display: RenderableType) -> None:
         """
@@ -536,18 +549,31 @@ class MultiTaskStatus(AbstractAsyncContextManager):
 
         # Complete the progress bar and stop spinner
         if rich_task_id is not None:
-            total = self._progress.tasks[rich_task_id].total or 1
+            # Safely find the Task by id; Progress.tasks is a list, not a dict
+            task_obj = next((t for t in self._progress.tasks if t.id == rich_task_id), None)
+            if task_obj is not None and task_obj.total is not None:
+                total = task_obj.total
+            else:
+                total = task_info.steps_total or 1
             self._progress.update(rich_task_id, completed=total, task_info=task_info)
         else:
-            # Task was never started, but we still need to add it to show completion
-            rich_task_id = self._progress.add_task(
-                "",
-                total=task_info.steps_total,
-                label=task_info.label,
-                completed=task_info.steps_total,
-                task_info=task_info,
-            )
-            self._rich_task_ids[task_id] = rich_task_id
+            # If this task was pruned from the live display, skip re-adding it
+            if task_id in self._pruned_task_ids:
+                pass
+            else:
+                # Task was never started; add a completed row so it appears once
+                rich_task_id = self._progress.add_task(
+                    "",
+                    total=task_info.steps_total,
+                    label=task_info.label,
+                    completed=task_info.steps_total,
+                    task_info=task_info,
+                )
+                self._rich_task_ids[task_id] = rich_task_id
+                self._displayed_task_order.append(task_id)
+
+        # After finishing, prune completed tasks to respect max visible cap
+        self._prune_completed_tasks_if_needed()
 
     def get_task_info(self, task_id: int) -> TaskInfo | None:
         """Get additional task information."""
```
```diff
@@ -567,6 +593,54 @@ class MultiTaskStatus(AbstractAsyncContextManager):
         """Get console instance for additional output above progress."""
         return self._progress.console
 
+    def _prune_completed_tasks_if_needed(self) -> None:
+        """
+        Ensure at most `max_display_tasks` tasks are visible by removing the oldest
+        completed/skipped/failed tasks first. Running or waiting tasks are never
+        removed by this method.
+        Note: This method assumes it's called under self._lock.
+        """
+        max_visible = self.settings.max_display_tasks
+
+        # Nothing to prune or unlimited
+        if max_visible <= 0:
+            return
+
+        # Count visible tasks (those with a Rich task id present)
+        visible_task_ids = [tid for tid in self._displayed_task_order if tid in self._rich_task_ids]
+        excess = len(visible_task_ids) - max_visible
+        if excess <= 0:
+            return
+
+        # Build list of terminal tasks that can be pruned (oldest first)
+        terminal_tasks = []
+        for tid in self._displayed_task_order:
+            if tid not in self._rich_task_ids:
+                continue
+            info = self._task_info.get(tid)
+            if info and info.state in (
+                TaskState.COMPLETED,
+                TaskState.FAILED,
+                TaskState.SKIPPED,
+            ):
+                terminal_tasks.append(tid)
+
+        # Remove the oldest terminal tasks up to the excess count
+        tasks_to_remove = terminal_tasks[:excess]
+
+        for tid in tasks_to_remove:
+            rich_tid = self._rich_task_ids.pop(tid, None)
+            if rich_tid is not None:
+                # Remove from Rich progress display
+                self._progress.remove_task(rich_tid)
+            # Mark as pruned so we don't re-add on finish
+            self._pruned_task_ids.add(tid)
+
+        # Efficiently rebuild the displayed task order without the removed tasks
+        self._displayed_task_order = [
+            tid for tid in self._displayed_task_order if tid not in tasks_to_remove
+        ]
+
 
 ## Tests
 
```
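A pure-Python sketch of the pruning policy implemented above, with simplified stand-ins for `TaskInfo`/`TaskState`: the oldest terminal tasks are dropped first, and running tasks are never candidates.

```python
# Simplified model of _prune_completed_tasks_if_needed with a cap of 3 visible rows.
order = [1, 2, 3, 4, 5]  # display order, oldest first
states = {1: "completed", 2: "running", 3: "failed", 4: "completed", 5: "running"}
max_visible = 3

excess = len(order) - max_visible  # 2 rows over the cap
terminal = [tid for tid in order if states[tid] in ("completed", "failed", "skipped")]
to_remove = terminal[:excess]  # [1, 3] — the oldest terminal tasks
order = [tid for tid in order if tid not in to_remove]

assert order == [2, 4, 5]  # running tasks 2 and 5 were never candidates
```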
```diff
@@ -1,48 +1,19 @@
 from __future__ import annotations
 
-import re
 from dataclasses import dataclass, field
 from typing import Any
 
-from flowmark import flowmark_markdown, line_wrap_by_sentence
 from marko import Markdown
+from marko.block import Document
 from marko.ext import footnote
 
-from kash.utils.text_handling.markdown_utils import comprehensive_transform_tree
-
-
-def _normalize_footnotes_in_markdown(content: str) -> str:
-    """
-    Ensure blank lines between consecutive footnote definitions.
-
-    Marko has a bug where consecutive footnotes without blank lines are parsed
-    as a single footnote. This adds blank lines where needed.
-    """
-    lines = content.split("\n")
-    result = []
-    i = 0
-
-    while i < len(lines):
-        line = lines[i]
-        result.append(line)
-
-        # Check if this is a footnote definition
-        if re.match(r"^\[\^[^\]]+\]:", line):
-            # Look ahead to see if the next non-empty line is also a footnote
-            j = i + 1
-            while j < len(lines) and not lines[j].strip():
-                result.append(lines[j])
-                j += 1
-
-            if j < len(lines) and re.match(r"^\[\^[^\]]+\]:", lines[j]):
-                # Next non-empty line is also a footnote, add blank line
-                result.append("")
-
-            i = j
-        else:
-            i += 1
-
-    return "\n".join(result)
+from kash.utils.text_handling.markdown_utils import (
+    MARKDOWN as DEFAULT_MARKDOWN,
+)
+from kash.utils.text_handling.markdown_utils import (
+    comprehensive_transform_tree,
+    normalize_footnotes_in_markdown,
+)
 
 
 @dataclass
```
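The normalization helper was moved to `markdown_utils` rather than deleted. Assuming the shared `normalize_footnotes_in_markdown` behaves like the removed `_normalize_footnotes_in_markdown` above, it inserts a blank line between consecutive footnote definitions so marko parses them separately:

```python
# Sketch of the marko workaround; behavior inferred from the removed helper above.
from kash.utils.text_handling.markdown_utils import normalize_footnotes_in_markdown

raw = "[^1]: First note.\n[^2]: Second note."
normalized = normalize_footnotes_in_markdown(raw)
# A blank line now separates the definitions, so marko sees two footnotes:
assert normalized == "[^1]: First note.\n\n[^2]: Second note."
```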
```diff
@@ -81,15 +52,17 @@ class MarkdownFootnotes:
             MarkdownFootnotes instance with all footnotes indexed by ID
         """
         if markdown_parser is None:
-            markdown_parser = flowmark_markdown(line_wrap_by_sentence(is_markdown=True))
+            markdown_parser = DEFAULT_MARKDOWN
 
         # Normalize to work around marko bug with consecutive footnotes
-        normalized_content = _normalize_footnotes_in_markdown(content)
+        normalized_content = normalize_footnotes_in_markdown(content)
         document = markdown_parser.parse(normalized_content)
         return MarkdownFootnotes.from_document(document, markdown_parser)
 
     @staticmethod
-    def from_document(document: Any, markdown_parser: Markdown | None = None) -> MarkdownFootnotes:
+    def from_document(
+        document: Document, markdown_parser: Markdown | None = None
+    ) -> MarkdownFootnotes:
         """
         Extract all footnotes from a parsed markdown document.
@@ -102,7 +75,7 @@ class MarkdownFootnotes:
             MarkdownFootnotes instance with all footnotes indexed by ID
         """
         if markdown_parser is None:
-            markdown_parser = flowmark_markdown(line_wrap_by_sentence(is_markdown=True))
+            markdown_parser = DEFAULT_MARKDOWN
 
         footnotes_dict: dict[str, FootnoteInfo] = {}
 
@@ -206,9 +179,9 @@ def extract_footnote_references(content: str, markdown_parser: Markdown | None =
         List of unique footnote IDs that are referenced (with the ^)
     """
     if markdown_parser is None:
-        markdown_parser = flowmark_markdown(line_wrap_by_sentence(is_markdown=True))
+        markdown_parser = DEFAULT_MARKDOWN
 
-    normalized_content = _normalize_footnotes_in_markdown(content)
+    normalized_content = normalize_footnotes_in_markdown(content)
    document = markdown_parser.parse(normalized_content)
    references: list[str] = []
    seen: set[str] = set()
```
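A usage sketch for `extract_footnote_references`; the module path is assumed, and the return format follows the docstring, which says IDs are returned with the `^`.

```python
# Hypothetical usage sketch (module path assumed, not shown in this diff).
from kash.utils.text_handling.markdown_footnotes import extract_footnote_references  # path assumed

text = "Claim one.[^a] Claim two.[^b]\n\n[^a]: Source A.\n[^b]: Source B."
refs = extract_footnote_references(text)
assert refs == ["^a", "^b"]  # unique IDs in order of appearance, per the docstring
```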