kash-shell 0.3.34__py3-none-any.whl → 0.3.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,9 +22,6 @@ class KashEnv(EnvEnum):
22
22
  KASH_SYSTEM_CACHE_DIR = "KASH_SYSTEM_CACHE_DIR"
23
23
  """The directory for system cache (caches separate from workspace caches)."""
24
24
 
25
- KASH_MCP_WS = "KASH_MCP_WS"
26
- """The directory for the workspace for MCP servers."""
27
-
28
25
  KASH_SHOW_TRACEBACK = "KASH_SHOW_TRACEBACK"
29
26
  """Whether to show tracebacks on actions and commands in the shell."""
30
27
 
kash/config/logger.py CHANGED
@@ -281,6 +281,8 @@ def _do_logging_setup(log_settings: LogSettings):
281
281
  def prefix(line: str, emoji: str = "", warn_emoji: str = "") -> str:
282
282
  prefix = task_stack_prefix_str()
283
283
  emojis = f"{warn_emoji}{emoji}".strip()
284
+ if emojis:
285
+ emojis += " "
284
286
  return "".join(filter(None, [prefix, emojis, line]))
285
287
 
286
288
 
@@ -19,7 +19,16 @@ class SuppressedWarningsStreamHandler(logging.StreamHandler):
19
19
  def basic_file_handler(path: Path, level: LogLevel | LogLevelStr) -> logging.FileHandler:
20
20
  handler = logging.FileHandler(path)
21
21
  handler.setLevel(LogLevel.parse(level).value)
22
- handler.setFormatter(Formatter("%(asctime)s %(levelname).1s %(name)s - %(message)s"))
22
+
23
+ class ThreadIdFormatter(Formatter):
24
+ def format(self, record):
25
+ # Add shortened thread ID as an attribute
26
+ record.thread_short = str(record.thread)[-5:]
27
+ return super().format(record)
28
+
29
+ handler.setFormatter(
30
+ ThreadIdFormatter("%(asctime)s %(levelname).1s [T%(thread_short)s] %(name)s - %(message)s")
31
+ )
23
32
  return handler
24
33
 
25
34
 
kash/config/settings.py CHANGED
@@ -166,9 +166,6 @@ class Settings:
166
166
  system_cache_dir: Path
167
167
  """Default global and system cache directory (for global media, content, etc)."""
168
168
 
169
- mcp_ws_dir: Path | None
170
- """The directory for the MCP workspace, if set."""
171
-
172
169
  local_server_log_path: Path
173
170
  """The path to the local server log."""
174
171
 
@@ -245,14 +242,6 @@ def _get_system_cache_dir() -> Path:
245
242
  return KashEnv.KASH_SYSTEM_CACHE_DIR.read_path(default=_get_ws_root_dir() / "cache")
246
243
 
247
244
 
248
- def _get_mcp_ws_dir() -> Path | None:
249
- mcp_dir = KashEnv.KASH_MCP_WS.read_str(default=None)
250
- if mcp_dir:
251
- return Path(mcp_dir).expanduser().resolve()
252
- else:
253
- return None
254
-
255
-
256
245
  @cache
257
246
  def _get_local_server_log_path() -> Path:
258
247
  return resolve_and_create_dirs(get_system_logs_dir() / f"{LOCAL_SERVER_LOG_NAME}.log")
@@ -266,7 +255,6 @@ def _read_settings():
266
255
  system_config_dir=_get_system_config_dir(),
267
256
  system_logs_dir=get_system_logs_dir(),
268
257
  system_cache_dir=_get_system_cache_dir(),
269
- mcp_ws_dir=_get_mcp_ws_dir(),
270
258
  local_server_log_path=_get_local_server_log_path(),
271
259
  # These default to the global but can be overridden by workspace settings.
272
260
  media_cache_dir=_get_system_cache_dir() / MEDIA_CACHE_NAME,
kash/config/setup.py CHANGED
@@ -75,6 +75,21 @@ def kash_setup(
75
75
 
76
76
 
77
77
  def _lib_setup():
78
+ import logging
79
+
80
+ log = logging.getLogger(__name__)
81
+
82
+ # Trust store integration, for consistent TLS behavior.
83
+ try:
84
+ import truststore # type: ignore
85
+
86
+ truststore.inject_into_ssl()
87
+ log.info("truststore initialized: using system TLS trust store")
88
+ except Exception as exc:
89
+ # If not installed or fails, default TLS trust will be used.
90
+ log.warning("truststore not available at import time: %s", exc)
91
+
92
+ # Handle default YAML representers.
78
93
  from sidematter_format import register_default_yaml_representers
79
94
 
80
95
  register_default_yaml_representers()
@@ -262,7 +262,7 @@ PROMPT_ASSIST = "(assistant) ❯"
262
262
 
263
263
  EMOJI_HINT = "👉"
264
264
 
265
- EMOJI_MSG_INDENT = "⋮ "
265
+ EMOJI_MSG_INDENT = "⋮ "
266
266
 
267
267
  EMOJI_START = "[➤]"
268
268
 
@@ -0,0 +1,60 @@
1
+ from funlog import log_calls
2
+
3
+ from kash.config.logger import get_logger
4
+ from kash.utils.common.import_utils import warm_import_library
5
+
6
+ log = get_logger(__name__)
7
+
8
+
9
+ @log_calls(level="info", show_timing_only=True)
10
+ def warm_slow_imports(include_extras: bool = True):
11
+ """
12
+ Pre-import slow packages to avoid delays when they are first used.
13
+
14
+ Args:
15
+ include_extras: If True, warm import optional libraries like LLM packages,
16
+ scipy, torch, etc. Set to False for minimal/faster startup.
17
+ """
18
+ try:
19
+ # Loading actions also loads any kits that are discovered.
20
+ import kash.actions # noqa: F401
21
+ import kash.local_server # noqa: F401
22
+ import kash.local_server.local_server # noqa: F401
23
+ import kash.mcp.mcp_server_sse # noqa: F401
24
+
25
+ # Core libraries that should usually be present
26
+ for lib_name, max_depth in [("xonsh", 3), ("uvicorn", 3)]:
27
+ try:
28
+ warm_import_library(lib_name, max_depth=max_depth)
29
+ except Exception as e:
30
+ log.debug(f"Could not warm import {lib_name}: {e}")
31
+
32
+ if include_extras:
33
+ # Fully warm import larger libraries (only if they're installed)
34
+ # These are optional dependencies that may not be present
35
+ optional_libraries = [
36
+ ("pydantic", 5),
37
+ ("litellm", 5),
38
+ ("openai", 5),
39
+ ("torch", 3), # torch is huge, limit depth
40
+ ("scipy", 3), # scipy has test modules we want to skip
41
+ ("marker", 4),
42
+ ("pandas", 3),
43
+ ]
44
+
45
+ for lib_name, max_depth in optional_libraries:
46
+ try:
47
+ warm_import_library(lib_name, max_depth=max_depth)
48
+ except Exception as e:
49
+ log.debug(f"Could not warm import {lib_name}: {e}")
50
+
51
+ # Initialize litellm configuration if available
52
+ try:
53
+ from kash.llm_utils.init_litellm import init_litellm
54
+
55
+ init_litellm()
56
+ except ImportError:
57
+ pass # litellm not installed
58
+
59
+ except ImportError as e:
60
+ log.warning(f"Error pre-importing packages: {e}")
@@ -204,7 +204,7 @@ def kash_action(
204
204
  precondition: Precondition = Precondition.always,
205
205
  arg_type: ArgType = ArgType.Locator,
206
206
  expected_args: ArgCount = ONE_ARG,
207
- output_type: ItemType = ItemType.doc,
207
+ output_type: ItemType | None = None,
208
208
  output_format: Format | None = None,
209
209
  expected_outputs: ArgCount = ONE_ARG,
210
210
  params: ParamDeclarations = (),
@@ -349,7 +349,7 @@ def kash_action(
349
349
  fmt_lines(self.params),
350
350
  )
351
351
  log.info(
352
- "Action function param values:\n%s",
352
+ "Action function param values: %s",
353
353
  self.param_value_summary_str(),
354
354
  )
355
355
  else:
kash/exec/action_exec.py CHANGED
@@ -107,7 +107,7 @@ def log_action(action: Action, action_input: ActionInput, operation: Operation):
107
107
  log.message("%s Action: `%s`", EMOJI_START, action.name)
108
108
  log.info("Running: `%s`", operation.command_line(with_options=True))
109
109
  if len(action.param_value_summary()) > 0:
110
- log.message("Parameters:\n%s", action.param_value_summary_str())
110
+ log.message("Parameters: %s", action.param_value_summary_str())
111
111
  log.info("Operation is: %s", operation)
112
112
  log.info("Input items are:\n%s", fmt_lines(action_input.items))
113
113
 
@@ -144,15 +144,17 @@ def fetch_url_item_content(
144
144
  if save_content:
145
145
  assert page_data.saved_content
146
146
  assert page_data.format_info
147
+ if not page_data.format_info.format:
148
+ log.warning("No format detected for content, defaulting to HTML: %s", url)
147
149
  content_item = url_item.new_copy_with(
148
150
  external_path=str(page_data.saved_content),
149
151
  # Use the original filename, not the local cache filename (which has a hash suffix).
150
152
  original_filename=item.get_filename(),
151
- format=page_data.format_info.format,
153
+ format=page_data.format_info.format or Format.html,
152
154
  )
153
155
 
154
156
  if not url_item.title:
155
- log.warning("Failed to fetch page data: title is missing: %s", item.url)
157
+ log.info("Title is missing for url item: %s", item)
156
158
 
157
159
  # Now save the updated URL item and also the content item if we have one.
158
160
  ws.save(url_item, overwrite=overwrite)
kash/mcp/mcp_cli.py CHANGED
@@ -11,8 +11,14 @@ from pathlib import Path
11
11
 
12
12
  from clideps.utils.readable_argparse import ReadableColorFormatter
13
13
 
14
- from kash.config.settings import DEFAULT_MCP_SERVER_PORT, LogLevel, global_settings
14
+ from kash.config.settings import (
15
+ DEFAULT_MCP_SERVER_PORT,
16
+ LogLevel,
17
+ atomic_global_settings,
18
+ global_settings,
19
+ )
15
20
  from kash.config.setup import kash_setup
21
+ from kash.config.warm_slow_imports import warm_slow_imports
16
22
  from kash.shell.version import get_version
17
23
 
18
24
  __version__ = get_version()
@@ -26,8 +32,6 @@ log = logging.getLogger()
26
32
 
27
33
 
28
34
  def build_parser():
29
- from kash.workspaces.workspaces import global_ws_dir
30
-
31
35
  parser = argparse.ArgumentParser(description=__doc__, formatter_class=ReadableColorFormatter)
32
36
  parser.add_argument(
33
37
  "--version",
@@ -36,8 +40,8 @@ def build_parser():
36
40
  )
37
41
  parser.add_argument(
38
42
  "--workspace",
39
- default=global_ws_dir(),
40
- help=f"Set workspace directory. Defaults to kash global workspace directory: {global_ws_dir()}",
43
+ default=global_settings().global_ws_dir,
44
+ help=f"Set workspace directory. Defaults to kash global workspace directory: {global_settings().global_ws_dir}",
41
45
  )
42
46
  parser.add_argument(
43
47
  "--proxy",
@@ -95,6 +99,14 @@ def run_server(args: argparse.Namespace):
95
99
  log.warning("kash MCP CLI started, logging to: %s", MCP_CLI_LOG_PATH)
96
100
  log.warning("Current working directory: %s", Path(".").resolve())
97
101
 
102
+ # Eagerly import so the server is warmed up.
103
+ # This is important to save init time on fresh sandboxes like E2B!
104
+ warm_slow_imports(include_extras=True)
105
+
106
+ if args.workspace and args.workspace != global_settings().global_ws_dir:
107
+ with atomic_global_settings().updates() as settings:
108
+ settings.global_ws_dir = Path(args.workspace).absolute()
109
+
98
110
  ws: Workspace = get_ws(name_or_path=Path(args.workspace), auto_init=True)
99
111
  os.chdir(ws.base_dir)
100
112
  log.warning("Running in workspace: %s", ws.base_dir)
@@ -3,7 +3,9 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  import pprint
5
5
  from dataclasses import dataclass
6
+ from pathlib import Path
6
7
 
8
+ from clideps.env_vars.dotenv_utils import load_dotenv_paths
7
9
  from funlog import log_calls
8
10
  from mcp.server.lowlevel import Server
9
11
  from mcp.server.lowlevel.server import StructuredContent, UnstructuredContent
@@ -237,10 +239,10 @@ def run_mcp_tool(
237
239
  """
238
240
  try:
239
241
  with captured_output() as capture:
240
- # XXX For now, unless the user has overridden the MCP workspace, we use the
241
- # current workspace, which could be changed by the user by changing working
242
- # directories. Maybe confusing?
243
- explicit_mcp_ws = global_settings().mcp_ws_dir
242
+ dotenv_paths = load_dotenv_paths(True, True, Path("."))
243
+ log.warning("Loaded .env files: %s", dotenv_paths)
244
+ # Use the global workspace default
245
+ explicit_mcp_ws = global_settings().global_ws_dir
244
246
 
245
247
  with kash_runtime(
246
248
  workspace_dir=explicit_mcp_ws,
@@ -244,11 +244,12 @@ class Action(ABC):
244
244
  be ONE_ARG.
245
245
  """
246
246
 
247
- output_type: ItemType = ItemType.doc
247
+ output_type: ItemType | None = None
248
248
  """
249
249
  The type of the output item(s). If an action returns multiple output types,
250
250
  this will be the output type of the first output.
251
251
  This is mainly used for preassembly for the cache check if an output already exists.
252
+ None means to use the input type.
252
253
  """
253
254
 
254
255
  output_format: Format | None = None
@@ -451,7 +452,7 @@ class Action(ABC):
451
452
  return changed_params
452
453
 
453
454
  def param_value_summary_str(self) -> str:
454
- return fmt_lines(
455
+ return ", ".join(
455
456
  [format_key_value(name, value) for name, value in self.param_value_summary().items()]
456
457
  )
457
458
 
@@ -560,7 +561,14 @@ class Action(ABC):
560
561
  # Using first input to determine the output title.
561
562
  primary_input = context.action_input.items[0]
562
563
  # In this case we only expect one output, of the type specified by the action.
563
- primary_output = primary_input.derived_copy(context, 0, type=context.action.output_type)
564
+ output_type = context.action.output_type or primary_input.type
565
+ if not output_type:
566
+ log.warning(
567
+ "No output type specified for action `%s`, using `doc` for preassembly",
568
+ self.name,
569
+ )
570
+ output_type = ItemType.doc
571
+ primary_output = primary_input.derived_copy(context, 0, type=output_type)
564
572
  log.info("Preassembled output: source %s, %s", primary_output.source, primary_output)
565
573
  return ActionResult([primary_output])
566
574
  else:
kash/model/items_model.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import re
3
4
  from collections.abc import Sequence
4
5
  from copy import deepcopy
5
6
  from dataclasses import asdict, field, is_dataclass
@@ -636,8 +637,8 @@ class Item:
636
637
  pull_body_heading: bool = False,
637
638
  ) -> str:
638
639
  """
639
- Get or infer a title for this item, falling back to the filename, URL, description, or
640
- finally body text. Optionally, include the last operation as a parenthetical at the end
640
+ Get or infer a title for this item, falling back to the URL, description or
641
+ body text. Optionally, include the last operation as a parenthetical at the end
641
642
  of the title. Will use "Untitled" if all else fails.
642
643
  """
643
644
  # First special case: if we are pulling the title from the body header, check
@@ -651,12 +652,9 @@ class Item:
651
652
  if not self.title and self.url:
652
653
  return abbrev_str(self.url, max_len)
653
654
 
654
- filename_stem = self.filename_stem()
655
-
656
- # Use the title or the path if possible, falling back to description or even body text.
655
+ # Use semantic sources for titles. The original filename is preserved separately.
657
656
  base_title = (
658
657
  self.title
659
- or filename_stem
660
658
  or self.description
661
659
  or (not self.is_binary and self.abbrev_body(max_len))
662
660
  or UNTITLED
@@ -666,7 +664,11 @@ class Item:
666
664
  # indicating the last operation, if there was one. This makes filename slugs
667
665
  # more readable.
668
666
  suffix = ""
669
- if add_ops_suffix and self.type.allows_op_suffix:
667
+ if (
668
+ add_ops_suffix
669
+ and self.type.allows_op_suffix
670
+ and not re.search(r"step\d+", base_title) # Just in case, never add suffix twice.
671
+ ):
670
672
  last_op = self.history and self.history[-1].action_name
671
673
  if last_op:
672
674
  step_num = len(self.history) + 1 if self.history else 1
@@ -894,18 +896,19 @@ class Item:
894
896
  if action_context:
895
897
  # Default the output item type and format to the action's declared output_type
896
898
  # and format if not explicitly set.
897
- if "type" not in updates:
899
+ if "type" not in updates and action_context.action.output_type:
898
900
  updates["type"] = action_context.action.output_type
899
901
  # If we were not given a format override, we leave the output type the same.
900
902
  elif action_context.action.output_format:
901
903
  # Check an overridden format and then our own format.
902
- new_output_format = updates.get("format", self.format)
904
+ new_output_format = updates.get("format")
903
905
  if new_output_format and action_context.action.output_format != new_output_format:
904
906
  log.warning(
905
- "Output item format `%s` does not match declared output format `%s` for action `%s`",
907
+ "Output item format `%s` does not match declared output format `%s` for action `%s` on item: %s",
906
908
  new_output_format,
907
909
  action_context.action.output_format,
908
910
  action_context.action.name,
911
+ self,
909
912
  )
910
913
 
911
914
  new_item = self.new_copy_with(update_timestamp=True, **updates)
@@ -927,7 +930,9 @@ class Item:
927
930
 
928
931
  # Fall back to action title template if we have it and title wasn't explicitly set.
929
932
  if "title" not in updates:
930
- prev_title = self.title or (Path(self.store_path).stem if self.store_path else UNTITLED)
933
+ # Avoid using filenames as titles when deriving. Prefer existing semantic title
934
+ # or derive from body heading/URL.
935
+ prev_title = self.title or self.pick_title(pull_body_heading=True)
931
936
 
932
937
  if action:
933
938
  new_item.title = action.format_title(prev_title)
kash/shell/shell_main.py CHANGED
@@ -70,20 +70,9 @@ def build_parser() -> argparse.ArgumentParser:
70
70
 
71
71
 
72
72
  def _import_packages():
73
- try:
74
- # Slowest packages:
75
- import uvicorn.protocols # noqa: F401
76
- import uvicorn.protocols.http.h11_impl # noqa: F401
77
- import uvicorn.protocols.websockets.websockets_impl # noqa: F401
78
- import xonsh.completers.init # noqa: F401
79
- import xonsh.pyghooks # noqa: F401
80
-
81
- import kash.actions # noqa: F401
82
- import kash.local_server # noqa: F401
83
- import kash.local_server.local_server # noqa: F401
84
- import kash.mcp.mcp_server_sse # noqa: F401
85
- except ImportError as e:
86
- log.warning(f"Error pre-importing packages: {e}")
73
+ from kash.config.warm_slow_imports import warm_slow_imports
74
+
75
+ warm_slow_imports(include_extras=False)
87
76
 
88
77
  imports_done_event.set()
89
78
 
@@ -74,26 +74,108 @@ def import_recursive(
74
74
  return tallies
75
75
 
76
76
 
77
+ def _import_modules_from_package(
78
+ package: types.ModuleType,
79
+ package_name: str,
80
+ max_depth: int = 1,
81
+ include_private: bool = True,
82
+ current_depth: int = 0,
83
+ imported_modules: dict[str, types.ModuleType] | None = None,
84
+ ) -> dict[str, types.ModuleType]:
85
+ """
86
+ Internal helper to recursively import modules from a package.
87
+
88
+ Args:
89
+ package: The package module to import from
90
+ package_name: The fully qualified name of the package
91
+ max_depth: Maximum recursion depth (1 = direct children only)
92
+ include_private: Whether to import private modules (starting with _)
93
+ current_depth: Current recursion depth (internal use)
94
+ imported_modules: Dictionary to accumulate imported modules
95
+
96
+ Returns:
97
+ Dictionary mapping module names to their imported module objects
98
+ """
99
+ if imported_modules is None:
100
+ imported_modules = {}
101
+
102
+ if current_depth >= max_depth:
103
+ return imported_modules
104
+
105
+ # Get the module's __path__ if it's a package
106
+ if not hasattr(package, "__path__"):
107
+ return imported_modules
108
+
109
+ try:
110
+ for _finder, module_name, ispkg in pkgutil.iter_modules(
111
+ package.__path__, f"{package_name}."
112
+ ):
113
+ # Skip private modules unless requested
114
+ if not include_private and module_name.split(".")[-1].startswith("_"):
115
+ continue
116
+
117
+ # Skip test modules - they often have special import requirements
118
+ # and aren't needed for warming the import cache
119
+ module_parts = module_name.split(".")
120
+ if any(
121
+ part in ("tests", "test", "testing", "_test", "_tests") for part in module_parts
122
+ ):
123
+ continue
124
+
125
+ # Skip already imported modules
126
+ if module_name in imported_modules:
127
+ continue
128
+
129
+ try:
130
+ module = importlib.import_module(module_name)
131
+ imported_modules[module_name] = module
132
+
133
+ # Recursively import submodules if it's a package
134
+ if ispkg and current_depth + 1 < max_depth:
135
+ _import_modules_from_package(
136
+ module,
137
+ module_name,
138
+ max_depth=max_depth,
139
+ include_private=include_private,
140
+ current_depth=current_depth + 1,
141
+ imported_modules=imported_modules,
142
+ )
143
+
144
+ except Exception as e:
145
+ # Handle various import failures gracefully
146
+ # This includes ImportError, pytest.Skipped, and other exceptions
147
+ error_type = type(e).__name__
148
+ if error_type not in ("ImportError", "AttributeError", "TypeError"):
149
+ log.debug(f" Skipped {module_name}: {error_type}: {e}")
150
+ # Don't log common/expected import errors to reduce noise
151
+
152
+ except Exception as e:
153
+ log.warning(f"Error iterating modules in {package_name}: {e}")
154
+
155
+ return imported_modules
156
+
157
+
77
158
  def import_namespace_modules(namespace: str) -> dict[str, types.ModuleType]:
78
159
  """
79
160
  Find and import all modules or packages within a namespace package.
80
161
  Returns a dictionary mapping module names to their imported module objects.
81
162
  """
82
- importlib.import_module(namespace) # Propagate import errors
163
+ # Import the main module first
164
+ main_module = importlib.import_module(namespace) # Propagate import errors
83
165
 
84
166
  # Get the package to access its __path__
85
- package = sys.modules.get(namespace)
86
- if not package or not hasattr(package, "__path__"):
167
+ if not hasattr(main_module, "__path__"):
87
168
  raise ImportError(f"`{namespace}` is not a package or namespace package")
88
169
 
89
- log.info(f"Discovering modules in `{namespace}` namespace, searching: {package.__path__}")
170
+ log.info(f"Discovering modules in `{namespace}` namespace, searching: {main_module.__path__}")
171
+
172
+ # Use the common helper with depth=1 (no recursion) and include_private=True
173
+ modules = _import_modules_from_package(
174
+ main_module, namespace, max_depth=1, include_private=True
175
+ )
90
176
 
91
- # Iterate through all modules in the namespace package
92
- modules = {}
93
- for _finder, module_name, _ispkg in pkgutil.iter_modules(package.__path__, f"{namespace}."):
94
- module = importlib.import_module(module_name) # Propagate import errors
95
- log.info(f"Imported module: {module_name} from {module.__file__}")
96
- modules[module_name] = module
177
+ # Add the main module itself
178
+ modules[namespace] = main_module
97
179
 
98
180
  log.info(f"Imported {len(modules)} modules from namespace `{namespace}`")
99
181
  return modules
@@ -106,8 +188,13 @@ def recursive_reload(
106
188
  Recursively reload all modules in the given package that match the filter function.
107
189
  Returns a list of module names that were reloaded.
108
190
 
109
- :param filter_func: A function that takes a module name and returns True if the
110
- module should be reloaded.
191
+ Args:
192
+ package: The package to reload.
193
+ filter_func: A function that takes a module name and returns True if the
194
+ module should be reloaded.
195
+
196
+ Returns:
197
+ List of module names that were reloaded.
111
198
  """
112
199
  package_name = package.__name__
113
200
  modules = {
@@ -124,3 +211,40 @@ def recursive_reload(
124
211
  importlib.reload(modules[name])
125
212
 
126
213
  return module_names
214
+
215
+
216
+ def warm_import_library(
217
+ library_name: str, max_depth: int = 3, include_private: bool = False
218
+ ) -> dict[str, types.ModuleType]:
219
+ """
220
+ Recursively import all submodules of a library to warm the import cache.
221
+ This is useful for servers where you want to pay the import cost upfront
222
+ rather than during request handling.
223
+
224
+ Args:
225
+ library_name: Name of the library to import (e.g., 'litellm', 'openai')
226
+ max_depth: Maximum depth to recurse into submodules
227
+ include_private: Whether to import private modules (starting with _)
228
+
229
+ Returns:
230
+ Dictionary mapping module names to their imported module objects
231
+ """
232
+ try:
233
+ # Import the main module first
234
+ main_module = importlib.import_module(library_name)
235
+
236
+ # Use the common helper for recursive imports
237
+ imported_modules = _import_modules_from_package(
238
+ main_module, library_name, max_depth=max_depth, include_private=include_private
239
+ )
240
+
241
+ # Add the main module itself
242
+ imported_modules[library_name] = main_module
243
+
244
+ except ImportError as e:
245
+ log.warning(f"Could not import {library_name}: {e}")
246
+ return {}
247
+
248
+ log.info(f"Warmed {len(imported_modules)} modules from {library_name}")
249
+
250
+ return imported_modules