kash-shell 0.3.23__py3-none-any.whl → 0.3.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. kash/actions/core/combine_docs.py +52 -0
  2. kash/actions/core/concat_docs.py +47 -0
  3. kash/commands/workspace/workspace_commands.py +2 -2
  4. kash/config/logger.py +3 -2
  5. kash/config/settings.py +8 -0
  6. kash/docs/markdown/topics/a2_installation.md +2 -2
  7. kash/embeddings/embeddings.py +4 -6
  8. kash/embeddings/text_similarity.py +2 -5
  9. kash/exec/action_exec.py +1 -1
  10. kash/exec/fetch_url_items.py +36 -8
  11. kash/help/help_embeddings.py +3 -0
  12. kash/llm_utils/llm_completion.py +1 -1
  13. kash/llm_utils/llm_features.py +1 -1
  14. kash/llm_utils/llms.py +5 -7
  15. kash/mcp/mcp_cli.py +2 -2
  16. kash/model/params_model.py +1 -1
  17. kash/utils/api_utils/api_retries.py +84 -76
  18. kash/utils/api_utils/gather_limited.py +227 -89
  19. kash/utils/api_utils/http_utils.py +46 -0
  20. kash/utils/api_utils/progress_protocol.py +49 -56
  21. kash/utils/rich_custom/multitask_status.py +70 -21
  22. kash/utils/text_handling/doc_normalization.py +2 -0
  23. kash/utils/text_handling/markdown_utils.py +14 -3
  24. kash/web_content/web_extract.py +12 -8
  25. kash/web_content/web_fetch.py +289 -60
  26. kash/web_content/web_page_model.py +5 -0
  27. kash/web_gen/templates/base_styles.css.jinja +8 -1
  28. {kash_shell-0.3.23.dist-info → kash_shell-0.3.25.dist-info}/METADATA +6 -4
  29. {kash_shell-0.3.23.dist-info → kash_shell-0.3.25.dist-info}/RECORD +32 -29
  30. {kash_shell-0.3.23.dist-info → kash_shell-0.3.25.dist-info}/WHEEL +0 -0
  31. {kash_shell-0.3.23.dist-info → kash_shell-0.3.25.dist-info}/entry_points.txt +0 -0
  32. {kash_shell-0.3.23.dist-info → kash_shell-0.3.25.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,52 @@
1
+ from chopdiff.html.html_in_md import div_wrapper
2
+
3
+ from kash.config.logger import get_logger
4
+ from kash.exec import kash_action
5
+ from kash.model import ONE_OR_MORE_ARGS, ActionInput, ActionResult, Param
6
+ from kash.utils.errors import InvalidInput
7
+
8
+ log = get_logger(__name__)
9
+
10
+
11
@kash_action(
    expected_args=ONE_OR_MORE_ARGS,
    params=(
        Param(
            "class_name",
            "CSS class name to use for wrapping each document in a div.",
            type=str,
            default_value="doc",
        ),
    ),
)
def combine_docs(input: ActionInput, class_name: str = "doc") -> ActionResult:
    """
    Combine multiple text items into a single document, wrapping each piece
    in a div with the specified CSS class.

    Wrapped bodies are joined with blank lines (`"\\n\\n"`). The result item is
    produced via `derived_copy` of the first input item, titled
    "Combined (N doc)" / "Combined (N docs)", with `original_filename` cleared.

    :param input: Action input whose `items` are combined, in order.
    :param class_name: CSS class for each wrapping div. Defaults to "doc" so a
        direct Python call behaves the same as the declared `class_name` param
        default (previously the signature said "page", which disagreed).
    :return: An `ActionResult` holding the single combined item.
    :raises InvalidInput: If there are no items, or any item has an empty body.
    """
    items = input.items
    if not items:
        raise InvalidInput("No items provided for combination")

    wrapper = div_wrapper(class_name=class_name)

    # Validate and wrap every body before building anything, so the first
    # empty item aborts the whole action.
    wrapped_bodies = []
    for item in items:
        if not item.body:
            raise InvalidInput(f"Item has no body: {item.store_path}")
        wrapped_bodies.append(wrapper(item.body))

    combined_body = "\n\n".join(wrapped_bodies)

    count = len(items)
    title = f"Combined ({count} doc{'s' if count != 1 else ''})"

    # Base the result on the first item (presumably carrying over its other
    # metadata via derived_copy), but clear original_filename: the combined
    # document does not correspond to any single source file.
    result_item = items[0].derived_copy(body=combined_body, title=title, original_filename=None)

    return ActionResult([result_item])
@@ -0,0 +1,47 @@
1
+ from kash.config.logger import get_logger
2
+ from kash.exec import kash_action
3
+ from kash.model import ONE_OR_MORE_ARGS, ActionInput, ActionResult, Param
4
+ from kash.utils.errors import InvalidInput
5
+
6
+ log = get_logger(__name__)
7
+
8
+
9
@kash_action(
    expected_args=ONE_OR_MORE_ARGS,
    params=(
        Param(
            "separator",
            "String to use between concatenated items.",
            type=str,
            default_value="\n\n",
        ),
    ),
)
def concat_docs(input: ActionInput, separator: str = "\n\n") -> ActionResult:
    """
    Join the bodies of several text items into one document, placing
    `separator` between consecutive bodies.

    The result is a `derived_copy` of the first item, titled
    "Concat (N doc)" / "Concat (N docs)", with `original_filename` cleared.

    :param input: Action input whose `items` are concatenated, in order.
    :param separator: Text inserted between adjacent bodies.
    :return: An `ActionResult` holding the single concatenated item.
    :raises InvalidInput: If there are no items, or any item has an empty body.
    """
    docs = input.items
    if not docs:
        raise InvalidInput("No items provided for concatenation")

    def body_of(doc):
        # Reject empty bodies up front rather than silently joining "".
        if not doc.body:
            raise InvalidInput(f"Item has no body: {doc.store_path}")
        return doc.body

    # Every body is checked (in input order) before the join happens.
    texts = [body_of(doc) for doc in docs]
    concat_body = separator.join(texts)

    n_docs = len(docs)
    plural = "" if n_docs == 1 else "s"
    title = f"Concat ({n_docs} doc{plural})"

    merged = docs[0].derived_copy(body=concat_body, title=title, original_filename=None)
    return ActionResult([merged])
@@ -474,8 +474,8 @@ def fetch_url(*files_or_urls: str, refetch: bool = False) -> ShellResult:
474
474
  store_paths = []
475
475
  for locator in locators:
476
476
  try:
477
- fetched_item = fetch_url_item(locator, refetch=refetch)
478
- store_paths.append(fetched_item.store_path)
477
+ fetch_result = fetch_url_item(locator, refetch=refetch)
478
+ store_paths.append(fetch_result.item.store_path)
479
479
  except InvalidInput as e:
480
480
  log.warning(
481
481
  "Not a URL or URL resource, will not fetch metadata: %s: %s", fmt_loc(locator), e
kash/config/logger.py CHANGED
@@ -254,9 +254,10 @@ def _do_logging_setup(log_settings: LogSettings):
254
254
  _console_handler = basic_stderr_handler(log_settings.log_console_level)
255
255
 
256
256
  # Manually adjust logging for a few packages, removing previous verbose default handlers.
257
-
257
+ # Set root logger to most permissive level so handlers can do the filtering
258
+ root_level = min(log_settings.log_console_level.value, log_settings.log_file_level.value)
258
259
  log_levels = {
259
- None: INFO,
260
+ None: root_level,
260
261
  "LiteLLM": INFO,
261
262
  "LiteLLM Router": INFO,
262
263
  "LiteLLM Proxy": INFO,
kash/config/settings.py CHANGED
@@ -210,6 +210,12 @@ class Settings:
210
210
  use_nerd_icons: bool
211
211
  """If true, use Nerd Icons in file listings. Requires a compatible font."""
212
212
 
213
+ limit_rps: float
214
+ """Default rate limit for API calls."""
215
+
216
+ limit_concurrency: int
217
+ """Default concurrency limit for API calls."""
218
+
213
219
 
214
220
  ws_root_dir = Path("~/Kash").expanduser()
215
221
 
@@ -276,6 +282,8 @@ def _read_settings():
276
282
  mcp_server_port=DEFAULT_MCP_SERVER_PORT,
277
283
  use_kerm_codes=False,
278
284
  use_nerd_icons=True,
285
+ limit_rps=5.0,
286
+ limit_concurrency=10,
279
287
  )
280
288
 
281
289
 
@@ -124,7 +124,7 @@ These are for `kash-media` but you can use a `kash-shell` for a more basic setup
124
124
 
125
125
  You can use kash from your MCP client (such as Anthropic Desktop or Cursor).
126
126
 
127
- You do this by running the the `kash_mcp` binary to make kash actions available as MCP
127
+ You do this by running the `kash-mcp` binary to make kash actions available as MCP
128
128
  tools.
129
129
 
130
130
  For Claude Desktop, my config looks like this:
@@ -133,7 +133,7 @@ For Claude Desktop, my config looks like this:
133
133
  {
134
134
  "mcpServers": {
135
135
  "kash": {
136
- "command": "/Users/levy/.local/bin/kash_mcp",
136
+ "command": "/Users/levy/.local/bin/kash-mcp",
137
137
  "args": ["--proxy"]
138
138
  }
139
139
  }
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import ast
4
4
  from collections.abc import Iterable
5
5
  from pathlib import Path
6
- from typing import TYPE_CHECKING, TypeAlias, cast
6
+ from typing import TYPE_CHECKING, TypeAlias
7
7
 
8
8
  from pydantic.dataclasses import dataclass
9
9
  from strif import abbrev_list
@@ -65,12 +65,11 @@ class Embeddings:
65
65
  @classmethod
66
66
  def embed(cls, keyvals: list[KeyVal], model=DEFAULT_EMBEDDING_MODEL) -> Embeddings:
67
67
  from litellm import embedding
68
- from litellm.types.utils import EmbeddingResponse
69
68
 
70
69
  init_litellm()
71
70
 
72
71
  data = {}
73
- log.message(
72
+ log.info(
74
73
  "Embedding %d texts (model %s, batch size %s)…",
75
74
  len(keyvals),
76
75
  model.litellm_name,
@@ -82,9 +81,8 @@ class Embeddings:
82
81
  keys = [kv[0] for kv in batch]
83
82
  texts = [kv[1] for kv in batch]
84
83
 
85
- response: EmbeddingResponse = cast(
86
- EmbeddingResponse, embedding(model=model.litellm_name, input=texts)
87
- )
84
+ response = embedding(model=model.litellm_name, input=texts)
85
+
88
86
  if not response.data:
89
87
  raise ValueError("No embedding response data")
90
88
 
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, cast
3
+ from typing import TYPE_CHECKING
4
4
 
5
5
  from funlog import log_calls
6
6
 
@@ -24,12 +24,9 @@ def cosine_relatedness(x: ArrayLike, y: ArrayLike) -> float:
24
24
  def embed_query(model: EmbeddingModel, query: str) -> EmbeddingResponse:
25
25
  import litellm
26
26
  from litellm import embedding
27
- from litellm.types.utils import EmbeddingResponse
28
27
 
29
28
  try:
30
- response: EmbeddingResponse = cast(
31
- EmbeddingResponse, embedding(model=model.litellm_name, input=[query])
32
- )
29
+ response = embedding(model=model.litellm_name, input=[query])
33
30
  except litellm.exceptions.APIError as e:
34
31
  log.info("API error embedding query: %s", e)
35
32
  raise ApiResultError(str(e))
kash/exec/action_exec.py CHANGED
@@ -55,7 +55,7 @@ def prepare_action_input(*input_args: CommandArg, refetch: bool = False) -> Acti
55
55
  if input_items:
56
56
  log.message("Assembling metadata for input items:\n%s", fmt_lines(input_items))
57
57
  input_items = [
58
- fetch_url_item_content(item, refetch=refetch) if is_url_resource(item) else item
58
+ fetch_url_item_content(item, refetch=refetch).item if is_url_resource(item) else item
59
59
  for item in input_items
60
60
  ]
61
61
 
@@ -1,23 +1,42 @@
1
+ from dataclasses import dataclass
2
+
1
3
  from kash.config.logger import get_logger
2
4
  from kash.exec.preconditions import is_url_resource
3
- from kash.media_base.media_services import get_media_metadata
4
5
  from kash.model.items_model import Item, ItemType
5
6
  from kash.model.paths_model import StorePath
6
7
  from kash.utils.common.format_utils import fmt_loc
7
8
  from kash.utils.common.url import Url, is_url
8
9
  from kash.utils.common.url_slice import add_slice_to_url, parse_url_slice
9
10
  from kash.utils.errors import InvalidInput
11
+ from kash.web_content.web_page_model import WebPageData
10
12
 
11
13
  log = get_logger(__name__)
12
14
 
13
15
 
16
@dataclass(frozen=True)
class FetchItemResult:
    """
    Immutable outcome of fetching a URL item: the resulting item, whether
    cached or already-present data was used, and any web page data obtained.
    """

    # The item produced by the fetch.
    item: Item

    # True if this item was already present in the cache, or if the fetch was
    # skipped because the data was already available.
    was_cached: bool

    # When the item was fetched from a URL via the web content cache, the page
    # data (including its cache metadata), whether or not it was a cache hit.
    # None otherwise.
    page_data: WebPageData | None = None
31
+
32
+
14
33
  def fetch_url_item(
15
34
  locator: Url | StorePath,
16
35
  *,
17
36
  save_content: bool = True,
18
37
  refetch: bool = False,
19
38
  cache: bool = True,
20
- ) -> Item:
39
+ ) -> FetchItemResult:
21
40
  from kash.workspaces import current_ws
22
41
 
23
42
  ws = current_ws()
@@ -37,7 +56,7 @@ def fetch_url_item(
37
56
 
38
57
  def fetch_url_item_content(
39
58
  item: Item, *, save_content: bool = True, refetch: bool = False, cache: bool = True
40
- ) -> Item:
59
+ ) -> FetchItemResult:
41
60
  """
42
61
  Fetch content and metadata for a URL using a media service if we
43
62
  recognize the URL as a known media service. Otherwise, fetch and extract the
@@ -51,6 +70,7 @@ def fetch_url_item_content(
51
70
  The content item is returned if content was saved. Otherwise, the updated
52
71
  URL item is returned.
53
72
  """
73
+ from kash.media_base.media_services import get_media_metadata
54
74
  from kash.web_content.canon_url import canonicalize_url
55
75
  from kash.web_content.web_extract import fetch_page_content
56
76
  from kash.workspaces import current_ws
@@ -61,7 +81,7 @@ def fetch_url_item_content(
61
81
  "Already have title, description, and body, will not fetch: %s",
62
82
  item.fmt_loc(),
63
83
  )
64
- return item
84
+ return FetchItemResult(item, was_cached=True)
65
85
 
66
86
  if not item.url:
67
87
  raise InvalidInput(f"No URL for item: {item.fmt_loc()}")
@@ -74,6 +94,8 @@ def fetch_url_item_content(
74
94
  media_metadata = get_media_metadata(url)
75
95
  url_item: Item | None = None
76
96
  content_item: Item | None = None
97
+ page_data: WebPageData | None = None
98
+
77
99
  if media_metadata:
78
100
  url_item = Item.from_media_metadata(media_metadata)
79
101
  # Preserve and canonicalize any slice suffix on the URL.
@@ -101,7 +123,6 @@ def fetch_url_item_content(
101
123
  original_filename=item.get_filename(),
102
124
  format=page_data.format_info.format,
103
125
  )
104
- ws.save(content_item)
105
126
 
106
127
  if not url_item.title:
107
128
  log.warning("Failed to fetch page data: title is missing: %s", item.url)
@@ -112,8 +133,15 @@ def fetch_url_item_content(
112
133
  if content_item:
113
134
  ws.save(content_item)
114
135
  assert content_item.store_path
115
- log.info("Saved content item: %s", content_item.fmt_loc())
136
+ log.info(
137
+ "Saved both URL and content item: %s, %s",
138
+ url_item.fmt_loc(),
139
+ content_item.fmt_loc(),
140
+ )
116
141
  else:
117
- log.info("Saved URL item: %s", url_item.fmt_loc())
142
+ log.info("Saved URL item (no content): %s", url_item.fmt_loc())
118
143
 
119
- return content_item or url_item
144
+ was_cached = bool(
145
+ not page_data or (page_data.cache_result and page_data.cache_result.was_cached)
146
+ )
147
+ return FetchItemResult(content_item or url_item, was_cached=was_cached, page_data=page_data)
@@ -3,6 +3,8 @@ from __future__ import annotations
3
3
  from dataclasses import dataclass, field
4
4
  from pathlib import Path
5
5
 
6
+ from typing_extensions import override
7
+
6
8
  from kash.config.logger import get_logger
7
9
  from kash.embeddings.embeddings import Embeddings
8
10
  from kash.embeddings.text_similarity import rank_by_relatedness
@@ -19,6 +21,7 @@ class DocKey:
19
21
  doc_type: HelpDocType
20
22
  index: int
21
23
 
24
+ @override
22
25
  def __str__(self) -> str:
23
26
  return f"{self.doc_type.value}:{self.index}"
24
27
 
@@ -107,7 +107,7 @@ def llm_completion(
107
107
 
108
108
  total_input_len = sum(len(m["content"]) for m in messages)
109
109
  speed = len(content) / elapsed
110
- log.message(
110
+ log.info(
111
111
  f"{EMOJI_TIMING} LLM completion from {model.litellm_name} in {format_duration(elapsed)}: "
112
112
  f"input {total_input_len} chars in {len(messages)} messages, output {len(content)} chars "
113
113
  f"({speed:.0f} char/s)"
@@ -68,5 +68,5 @@ preferred_llms: list[LLMName] = [
68
68
  LLM.claude_4_opus,
69
69
  LLM.claude_3_7_sonnet,
70
70
  LLM.claude_3_5_haiku,
71
- LLM.gemini_2_5_pro_preview_05_06,
71
+ LLM.gemini_2_5_pro,
72
72
  ]
kash/llm_utils/llms.py CHANGED
@@ -15,6 +15,7 @@ class LLM(LLMName, Enum):
15
15
  # https://platform.openai.com/docs/models
16
16
  o4_mini = LLMName("o4-mini")
17
17
  o3 = LLMName("o3")
18
+ o3_pro = LLMName("o3-pro")
18
19
  o3_mini = LLMName("o3-mini")
19
20
  o1 = LLMName("o1")
20
21
  o1_mini = LLMName("o1-mini")
@@ -35,13 +36,9 @@ class LLM(LLMName, Enum):
35
36
  claude_3_5_haiku = LLMName("claude-3-5-haiku-latest")
36
37
 
37
38
  # https://ai.google.dev/gemini-api/docs/models
38
- gemini_2_5_pro_preview_06_05 = LLMName("gemini/gemini-2.5-pro-preview-06-05")
39
- gemini_2_5_pro_preview_05_06 = LLMName("gemini/gemini-2.5-pro-preview-05-06")
40
- gemini_2_5_pro_preview_03_25 = LLMName("gemini/gemini-2.5-pro-preview-03-25")
41
- gemini_2_5_flash_preview = LLMName("gemini-2.5-flash-preview-05-20")
42
- gemini_2_0_flash = LLMName("gemini/gemini-2_0-flash")
43
- gemini_2_0_flash_lite = LLMName("gemini/gemini-2.0-flash-lite")
44
- gemini_2_0_pro_exp_02_05 = LLMName("gemini/gemini-2.0-pro-exp-02-05")
39
+ gemini_2_5_pro = LLMName("gemini/gemini-2.5-pro")
40
+ gemini_2_5_flash = LLMName("gemini/gemini-2.5-flash")
41
+ gemini_2_5_flash_lite = LLMName("gemini-2.5-flash-lite-preview-06-17")
45
42
 
46
43
  # https://docs.x.ai/docs/models
47
44
  xai_grok_3 = LLMName("xai/grok-3")
@@ -56,6 +53,7 @@ class LLM(LLMName, Enum):
56
53
  deepseek_reasoner = LLMName("deepseek/deepseek-reasoner")
57
54
 
58
55
  # https://console.groq.com/docs/models
56
+ groq_gemma2_9b_it = LLMName("groq/gemma2-9b-it")
59
57
  groq_llama_3_1_8b_instant = LLMName("groq/llama-3.1-8b-instant")
60
58
  groq_llama_3_3_70b_versatile = LLMName("groq/llama-3.3-70b-versatile")
61
59
  groq_deepseek_r1_distill_llama_70b = LLMName("groq/deepseek-r1-distill-llama-70b")
kash/mcp/mcp_cli.py CHANGED
@@ -114,10 +114,10 @@ def main():
114
114
  args = build_parser().parse_args()
115
115
 
116
116
  if args.list_tools or args.tool_help:
117
- kash_setup(rich_logging=True, level=LogLevel.warning)
117
+ kash_setup(rich_logging=True, log_level=LogLevel.warning)
118
118
  show_tool_info(args.tool_help)
119
119
  else:
120
- kash_setup(rich_logging=False, level=LogLevel.info)
120
+ kash_setup(rich_logging=False, log_level=LogLevel.info)
121
121
  run_server(args)
122
122
 
123
123
 
@@ -209,7 +209,7 @@ A list of parameter declarations, possibly with default values.
209
209
  DEFAULT_CAREFUL_LLM = LLM.o3
210
210
  DEFAULT_STRUCTURED_LLM = LLM.gpt_4o
211
211
  DEFAULT_STANDARD_LLM = LLM.claude_4_sonnet
212
- DEFAULT_FAST_LLM = LLM.o1_mini
212
+ DEFAULT_FAST_LLM = LLM.gpt_4o
213
213
 
214
214
 
215
215
  # Parameters set globally such as in the workspace.
@@ -5,6 +5,8 @@ from collections.abc import Callable
5
5
  from dataclasses import dataclass
6
6
  from enum import Enum
7
7
 
8
+ from kash.utils.api_utils.http_utils import extract_http_status_code
9
+
8
10
 
9
11
  class HTTPRetryBehavior(Enum):
10
12
  """HTTP status code retry behavior classification."""
@@ -62,51 +64,6 @@ class RetryExhaustedException(RetryException):
62
64
  )
63
65
 
64
66
 
65
- def extract_http_status_code(exception: Exception) -> int | None:
66
- """
67
- Extract HTTP status code from various exception types.
68
-
69
- Args:
70
- exception: The exception to extract status code from
71
-
72
- Returns:
73
- HTTP status code or None if not found
74
- """
75
- # Check for httpx.HTTPStatusError and requests.HTTPError
76
- if hasattr(exception, "response"):
77
- response = getattr(exception, "response", None)
78
- if response and hasattr(response, "status_code"):
79
- return getattr(response, "status_code", None)
80
-
81
- # Check for aiohttp errors
82
- if hasattr(exception, "status"):
83
- return getattr(exception, "status", None)
84
-
85
- # Parse from exception message as fallback
86
- exception_str = str(exception)
87
-
88
- # Try to find status code patterns in the message
89
- import re
90
-
91
- # Pattern for "403 Forbidden", "HTTP 429", etc.
92
- status_patterns = [
93
- r"\b(\d{3})\s+(?:Forbidden|Unauthorized|Not Found|Too Many Requests|Internal Server Error|Bad Gateway|Service Unavailable|Gateway Timeout)\b",
94
- r"\bHTTP\s+(\d{3})\b",
95
- r"\b(\d{3})\s+error\b",
96
- r"status\s*(?:code)?:\s*(\d{3})\b",
97
- ]
98
-
99
- for pattern in status_patterns:
100
- match = re.search(pattern, exception_str, re.IGNORECASE)
101
- if match:
102
- try:
103
- return int(match.group(1))
104
- except (ValueError, IndexError):
105
- continue
106
-
107
- return None
108
-
109
-
110
67
  def default_is_retriable(exception: Exception) -> bool:
111
68
  """
112
69
  Default retriable exception checker with HTTP status code awareness.
@@ -204,22 +161,22 @@ def default_is_retriable(exception: Exception) -> bool:
204
161
 
205
162
  def is_http_status_retriable(
206
163
  status_code: int,
207
- retry_map: dict[int, HTTPRetryBehavior] | None = None,
164
+ retry_policy: dict[int, HTTPRetryBehavior] | None = None,
208
165
  ) -> bool:
209
166
  """
210
167
  Determine if an HTTP status code should be retried.
211
168
 
212
169
  Args:
213
170
  status_code: HTTP status code
214
- retry_map: Custom retry behavior map (uses default if None)
171
+ retry_policy: Custom retry behavior policy (uses default if None)
215
172
 
216
173
  Returns:
217
174
  True if the status code should be retried
218
175
  """
219
- if retry_map is None:
220
- retry_map = DEFAULT_HTTP_RETRY_MAP
176
+ if retry_policy is None:
177
+ retry_policy = DEFAULT_HTTP_RETRY_MAP
221
178
 
222
- behavior = retry_map.get(status_code)
179
+ behavior = retry_policy.get(status_code)
223
180
 
224
181
  if behavior == HTTPRetryBehavior.FULL:
225
182
  return True
@@ -265,36 +222,46 @@ class RetrySettings:
265
222
  """Exponential backoff multiplier"""
266
223
 
267
224
  is_retriable: Callable[[Exception], bool] = default_is_retriable
268
- """Function to determine if an exception should be retried"""
269
-
270
- http_retry_map: dict[int, HTTPRetryBehavior] | None = None
271
- """Custom HTTP status code retry behavior (None = use defaults)"""
225
+ """Function to determine if non-HTTP exceptions should be retried (network errors, timeouts, etc.)"""
226
+
227
+ http_retry_policy: dict[int, HTTPRetryBehavior] | None = None
228
+ """Custom HTTP status code retry behavior policy (None = use defaults)"""
229
+
230
def should_retry(self, exception: Exception) -> bool:
    """
    Decide whether `exception` warrants another attempt.

    If an HTTP status code can be extracted from the exception, the decision
    is made by `is_http_status_retriable` using `http_retry_policy`, or
    `DEFAULT_HTTP_RETRY_MAP` when no custom policy is set. Otherwise (network
    errors, timeouts, connection issues, etc.) the decision is delegated to
    the `is_retriable` callable.
    """
    status_code = extract_http_status_code(exception)
    if not status_code:
        # No usable HTTP status: fall back to the generic exception check.
        return self.is_retriable(exception)

    if self.http_retry_policy is None:
        policy = DEFAULT_HTTP_RETRY_MAP
    else:
        policy = self.http_retry_policy
    return is_http_status_retriable(status_code, policy)
272
251
 
273
252
 
274
253
  DEFAULT_RETRIES = RetrySettings(
275
- max_task_retries=10,
276
- max_total_retries=100,
254
+ max_task_retries=15,
255
+ max_total_retries=1000,
277
256
  initial_backoff=1.0,
278
- max_backoff=128.0,
279
- backoff_factor=2.0,
257
+ max_backoff=60.0,
258
+ backoff_factor=1.5,
280
259
  is_retriable=default_is_retriable,
281
260
  )
282
261
  """Reasonable default retry settings with both per-task and global limits."""
283
262
 
284
-
285
- # Preset configurations for different use cases
286
- AGGRESSIVE_RETRIES = RetrySettings(
287
- max_task_retries=15,
288
- max_total_retries=200,
289
- initial_backoff=0.5,
290
- max_backoff=64.0,
291
- backoff_factor=1.8,
292
- )
293
- """Aggressive retry settings - retry more often with shorter initial backoff."""
294
-
295
-
296
- # Conservative retry settings use a custom retry map that excludes conservative retries
297
- _CONSERVATIVE_HTTP_RETRY_MAP = {
263
+ # Conservative retry settings use a custom retry policy that excludes conservative retries
264
+ _CONSERVATIVE_HTTP_RETRY_POLICY = {
298
265
  # Fully retriable: server errors and explicit rate limiting
299
266
  429: HTTPRetryBehavior.FULL,
300
267
  500: HTTPRetryBehavior.FULL,
@@ -319,7 +286,7 @@ CONSERVATIVE_RETRIES = RetrySettings(
319
286
  initial_backoff=2.0,
320
287
  max_backoff=60.0,
321
288
  backoff_factor=2.5,
322
- http_retry_map=_CONSERVATIVE_HTTP_RETRY_MAP,
289
+ http_retry_policy=_CONSERVATIVE_HTTP_RETRY_POLICY,
323
290
  )
324
291
  """Conservative retry settings - fewer retries, longer backoff, no conservative HTTP retries."""
325
292
 
@@ -455,9 +422,9 @@ def test_is_http_status_retriable():
455
422
  assert is_http_status_retriable(403) # Forbidden
456
423
  assert is_http_status_retriable(408) # Request Timeout
457
424
 
458
- # Conservative retriable with custom conservative map (disabled)
459
- assert not is_http_status_retriable(403, _CONSERVATIVE_HTTP_RETRY_MAP)
460
- assert not is_http_status_retriable(408, _CONSERVATIVE_HTTP_RETRY_MAP)
425
+ # Conservative retriable with custom conservative policy (disabled)
426
+ assert not is_http_status_retriable(403, _CONSERVATIVE_HTTP_RETRY_POLICY)
427
+ assert not is_http_status_retriable(408, _CONSERVATIVE_HTTP_RETRY_POLICY)
461
428
 
462
429
  # Never retriable
463
430
  assert not is_http_status_retriable(400) # Bad Request
@@ -629,3 +596,44 @@ def test_calculate_backoff():
629
596
  backoff_factor=2.0,
630
597
  )
631
598
  assert high_backoff <= 5.0
599
+
600
+
601
def test_retry_settings_should_retry():
    """Check RetrySettings.should_retry under the default and conservative HTTP retry policies."""

    class MockHTTPXResponse:
        def __init__(self, status_code):
            self.status_code = status_code

    class MockHTTPXException(Exception):
        def __init__(self, status_code):
            self.response = MockHTTPXResponse(status_code)
            super().__init__(f"HTTP {status_code} error")

    # Default policy: rate limiting (429), server errors (500) and the
    # "conservative" 403 are retried; 404 is not.
    default_settings = RetrySettings(max_task_retries=3)
    for code in (429, 500, 403):
        assert default_settings.should_retry(MockHTTPXException(code)), code
    assert not default_settings.should_retry(MockHTTPXException(404))

    # Conservative policy: 429 and 500 are still retried, but 403 no longer is.
    conservative_settings = CONSERVATIVE_RETRIES
    for code in (429, 500):
        assert conservative_settings.should_retry(MockHTTPXException(code)), code
    for code in (403, 404):
        assert not conservative_settings.should_retry(MockHTTPXException(code)), code

    # Exceptions without an HTTP status are classified by is_retriable.
    assert default_settings.should_retry(Exception("Network error"))
    assert not default_settings.should_retry(Exception("Authentication failed"))