kash-shell 0.3.20__py3-none-any.whl → 0.3.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kash/config/colors.py CHANGED
@@ -139,14 +139,15 @@ web_light_translucent = SimpleNamespace(
139
139
  bg_header=hsl_to_hex("hsla(188, 42%, 70%, 0.2)"),
140
140
  bg_alt=hsl_to_hex("hsla(39, 24%, 90%, 0.3)"),
141
141
  bg_alt_solid=hsl_to_hex("hsla(39, 24%, 97%, 1)"),
142
- bg_selected=hsl_to_hex("hsla(188, 44%, 94%, 0.95)"),
142
+ bg_meta_solid=hsl_to_hex("hsla(39, 24%, 94%, 1)"),
143
+ bg_selected=hsl_to_hex("hsla(188, 21%, 94%, 0.9)"),
143
144
  text=hsl_to_hex("hsl(188, 39%, 11%)"),
144
145
  code=hsl_to_hex("hsl(44, 38%, 23%)"),
145
146
  border=hsl_to_hex("hsl(188, 8%, 50%)"),
146
147
  border_hint=hsl_to_hex("hsla(188, 8%, 72%, 0.3)"),
147
148
  border_accent=hsl_to_hex("hsla(305, 18%, 65%, 0.85)"),
148
149
  hover=hsl_to_hex("hsl(188, 12%, 84%)"),
149
- hover_bg=hsl_to_hex("hsla(188, 44%, 94%, 1)"),
150
+ hover_bg=hsl_to_hex("hsla(188, 18%, 97%, 1)"),
150
151
  hint=hsl_to_hex("hsl(188, 11%, 65%)"),
151
152
  hint_strong=hsl_to_hex("hsl(188, 11%, 46%)"),
152
153
  hint_gentle=hsl_to_hex("hsla(188, 11%, 65%, 0.2)"),
@@ -165,14 +166,15 @@ web_light_translucent = SimpleNamespace(
165
166
  web_dark_translucent = SimpleNamespace(
166
167
  primary=hsl_to_hex("hsl(188, 40%, 62%)"),
167
168
  primary_light=hsl_to_hex("hsl(188, 50%, 72%)"),
168
- secondary=hsl_to_hex("hsl(188, 12%, 65%)"),
169
- tertiary=hsl_to_hex("hsl(188, 7%, 40%)"),
169
+ secondary=hsl_to_hex("hsl(188, 12%, 70%)"),
170
+ tertiary=hsl_to_hex("hsl(188, 7%, 45%)"),
170
171
  bg=hsl_to_hex("hsla(220, 14%, 7%, 0.95)"),
171
172
  bg_solid=hsl_to_hex("hsl(220, 14%, 7%)"),
172
173
  bg_header=hsl_to_hex("hsla(188, 42%, 20%, 0.3)"),
173
174
  bg_alt=hsl_to_hex("hsla(220, 14%, 12%, 0.5)"),
174
- bg_alt_solid=hsl_to_hex("hsl(220, 14%, 12%)"),
175
- bg_selected=hsl_to_hex("hsla(188, 12%, 50%, 0.95)"),
175
+ bg_alt_solid=hsl_to_hex("hsl(220, 15%, 16%)"),
176
+ bg_meta_solid=hsl_to_hex("hsl(220, 14%, 25%)"),
177
+ bg_selected=hsl_to_hex("hsla(188, 13%, 33%, 0.95)"),
176
178
  text=hsl_to_hex("hsl(188, 10%, 90%)"),
177
179
  code=hsl_to_hex("hsl(44, 38%, 72%)"),
178
180
  border=hsl_to_hex("hsl(188, 8%, 25%)"),
@@ -277,6 +277,8 @@ EMOJI_HELP = "?"
277
277
 
278
278
  EMOJI_ACTION = "⛭"
279
279
 
280
+ EMOJI_TASK = "⚒"
281
+
280
282
  EMOJI_COMMAND = "⧁" # More ideas: ⦿⧁⧀⦿⦾⟐⦊⟡
281
283
 
282
284
  EMOJI_SHELL = "⦊"
@@ -165,48 +165,37 @@ works on readable text such as Markdown.
165
165
  This catches errors and allows you to find actions that might apply to a given selected
166
166
  set of items using `suggest_actions`.
167
167
 
168
- ### Programmatic Use
169
-
170
- Since commands and actions are really just Python functions.
171
-
172
- ### Useful Features
173
-
174
- Kash makes a few kinds of messy text manipulations easier:
175
-
176
- - Reusable LLM actions: A common kind of action is to invoke an LLM (like GPT-4o or o1)
177
- on a text item, with a given system and user prompt template.
178
- New LLM actions can be added with a few lines of Python by subclassing an action base
179
- class, typically `Action`, `CachedItemAction` (for any action that doesn't need to be
180
- rerun if it has the same single output), `CachedLLMAction` (if it also is performing
181
- an LLM-based transform), or `ChunkedLLMAction` (if it will be processing a document
182
- broken into <div class="chunk"> elements).
183
-
184
- - Sliding window transformations: LLMs can have trouble processing large inputs, not
185
- just because of context window and because they may make more mistakes when making
186
- lots of changes at once.
187
- Kash supports running actions in a sliding window across the document, then stitching
188
- the results back together when done.
189
-
190
- - Checking and enforcing changes: LLMs do not reliably do what they are asked to do.
191
- So a key part of making them useful is to save outputs at each step of the way and
192
- have a way to review their outputs or provide guardrails on what they can do with
193
- content.
194
-
195
- - Fine-grainded diffs with word tokens: Documents can be represented at the word level,
196
- using “word tokens” to represent words and normalized whitespace (word, sentence, and
197
- paragraph breaks, but not line breaks).
198
- This allows diffs of similar documents regardless of formatting.
199
- For example, it is possible to ask an LLM only to add paragraph breaks, then drop any
200
- other changes it makes to other words.
201
- You can use this intelligent matching of words to “backfill” specific content from one
202
- doc into an edited document, such as pulling timestamps from a full transcript back
203
- into an edited transcript or summary.
204
-
205
- - Paragraph and sentence operations: A lot of operations within actions should be done
206
- in chunks at the paragraph or sentence level.
207
- Kash offers simple tools to subdivide documents into paragraphs and sentences and
208
- these can be used together with sliding windows to process large documents.
209
-
210
- In addition, there are built-in kash commands that are part of the kash tool itself.
211
- These allow you to list items in the workspace, see or change the current selection,
212
- archive items, view logs, etc.
168
+ ### Programmatic Usage
169
+
170
+ Kash can be used entirely programmatically, so that actions are called just like
171
+ functions from Python, but the additional functionality of the items model, saving files
172
+ to a workspace, and so on, are all automatic.
173
+
174
+ This means you can use Kash to build your own CLI apps much more quickly.
175
+
176
+ For an example of this, see [textpress](https://github.com/jlevy/textpress), which wraps
177
+ quite a few kash actions to allow clean publishing of docx or PDF files on
178
+ [textpress.md](https://textpress.md/).
179
+
180
+ ### Utilities and Supporting Libraries
181
+
182
+ Kash includes a number of utility libraries to help with common tasks, either in the
183
+ base `kash-shell` package or in smaller dependencies:
184
+
185
+ - See [frontmatter-format](https://github.com/jlevy/frontmatter-format) for the spec and
186
+ implementation of the frontmatter YAML format that we use.
187
+
188
+ - See
189
+ [utils/file_utils](https://github.com/jlevy/kash/tree/main/src/kash/utils/file_utils)
190
+ for file format detection, conversion, filename handling, etc.
191
+
192
+ - See [chopdiff](https://github.com/jlevy/chopdiff) for a simple text doc data model
193
+ that includes sentences and paragraphs and fairly advanced diffing, filtered diffing,
194
+ and windowed transformations of text via LLM calls.
195
+
196
+ - See [clideps](https://github.com/jlevy/clideps) for utilities for helping with dot-env
197
+ files, API key setup, and dependency checks.
198
+
199
+ - See [utils/common](https://github.com/jlevy/kash/tree/main/src/kash/utils/common) and the
200
+ rest of [utils/](https://github.com/jlevy/kash/tree/main/src/kash/utils) for a variety
201
+ of other general utilities.
@@ -105,7 +105,10 @@ def kash_action_class(cls: type[A]) -> type[A]:
105
105
 
106
106
 
107
107
  def _register_dynamic_action(
108
- action_cls: type[A], action_name: str, action_description: str, source_path: Path | None
108
+ action_cls: type[A],
109
+ action_name: str,
110
+ action_description: str,
111
+ source_path: Path | None,
109
112
  ) -> type[A]:
110
113
  # Set class fields for name and description for convenience.
111
114
  action_cls.name = action_name
@@ -206,6 +209,7 @@ def kash_action(
206
209
  run_per_item: bool | None = None,
207
210
  uses_selection: bool = True,
208
211
  interactive_input: bool = False,
212
+ live_output: bool = False,
209
213
  mcp_tool: bool = False,
210
214
  title_template: TitleTemplate = TitleTemplate("{title}"),
211
215
  llm_options: LLMOptions = LLMOptions(),
@@ -235,13 +239,17 @@ def kash_action(
235
239
  def decorator(orig_func: AF) -> AF:
236
240
  if hasattr(orig_func, "__action_class__"):
237
241
  log.warning(
238
- "Function `%s` is already decorated with `@kash_action`", orig_func.__name__
242
+ "Function `%s` is already decorated with `@kash_action`",
243
+ orig_func.__name__,
239
244
  )
240
245
  return orig_func
241
246
 
242
247
  # Inspect and sanity check the formal params.
243
248
  func_params = inspect_function_params(orig_func)
244
- if len(func_params) == 0 or func_params[0].effective_type not in (ActionInput, Item):
249
+ if len(func_params) == 0 or func_params[0].effective_type not in (
250
+ ActionInput,
251
+ Item,
252
+ ):
245
253
  raise InvalidDefinition(
246
254
  f"Decorator `@kash_action` requires exactly one positional parameter, "
247
255
  f"`input` of type `ActionInput` or `Item` on function `{orig_func.__name__}` but "
@@ -311,6 +319,7 @@ def kash_action(
311
319
  self.uses_selection = uses_selection
312
320
  self.output_type = output_type
313
321
  self.interactive_input = interactive_input
322
+ self.live_output = live_output
314
323
  self.mcp_tool = mcp_tool
315
324
  self.title_template = title_template
316
325
  self.llm_options = llm_options
@@ -332,8 +341,14 @@ def kash_action(
332
341
  kw_args[fp.name] = self.get_param(fp.name)
333
342
 
334
343
  if self.params:
335
- log.info("Action function param declarations:\n%s", fmt_lines(self.params))
336
- log.info("Action function param values:\n%s", self.param_value_summary_str())
344
+ log.info(
345
+ "Action function param declarations:\n%s",
346
+ fmt_lines(self.params),
347
+ )
348
+ log.info(
349
+ "Action function param values:\n%s",
350
+ self.param_value_summary_str(),
351
+ )
337
352
  else:
338
353
  log.info("Action function has no declared params")
339
354
 
@@ -68,7 +68,7 @@ def llm_transform_str(options: LLMOptions, input_str: str, check_no_results: boo
68
68
  diff_filter=options.diff_filter,
69
69
  ).reassemble()
70
70
  else:
71
- log.message(
71
+ log.info(
72
72
  "Running simple LLM transform action %s with model %s",
73
73
  options.op_name,
74
74
  options.model.litellm_name,
@@ -56,7 +56,7 @@ class ShellCallableAction:
56
56
 
57
57
  log.info("Action shell args: %s", shell_args)
58
58
  explicit_values = RawParamValues(shell_args.options)
59
- if not action.interactive_input:
59
+ if not action.interactive_input and not action.live_output:
60
60
  with get_console().status(f"Running action {action.name}…", spinner=SPINNER):
61
61
  result = run_action_with_shell_context(
62
62
  action_cls,
@@ -173,7 +173,7 @@ def llm_template_completion(
173
173
  )
174
174
 
175
175
  if check_no_results and is_no_results(result.content):
176
- log.message("No results for LLM transform, will ignore: %r", result.content)
176
+ log.info("No results for LLM transform, will ignore: %r", result.content)
177
177
  result.content = ""
178
178
 
179
179
  return result
@@ -270,6 +270,12 @@ class Action(ABC):
270
270
  Does this action ask for input interactively?
271
271
  """
272
272
 
273
+ live_output: bool = False
274
+ """
275
+ Does this action have live output (e.g., progress bars, spinners)?
276
+ If True, the shell should not show its own status spinner.
277
+ """
278
+
273
279
  mcp_tool: bool = False
274
280
  """
275
281
  If True, this action is published as an MCP tool.
@@ -28,6 +28,7 @@ from kash.config.text_styles import (
28
28
  STYLE_HINT,
29
29
  )
30
30
  from kash.shell.output.kmarkdown import KMarkdown
31
+ from kash.utils.rich_custom.multitask_status import MultiTaskStatus, StatusSettings
31
32
  from kash.utils.rich_custom.rich_indent import Indent
32
33
  from kash.utils.rich_custom.rich_markdown_fork import Markdown
33
34
 
@@ -80,6 +81,20 @@ def console_pager(use_pager: bool = True):
80
81
  PrintHooks.after_pager()
81
82
 
82
83
 
84
def multitask_status(
    settings: StatusSettings | None = None, *, auto_summary: bool = True
) -> MultiTaskStatus:
    """
    Build a `MultiTaskStatus` (used as a context manager to show progress for
    several tasks at once) that renders on the global shell console.

    Args:
        settings: Optional `StatusSettings`, passed through unchanged.
        auto_summary: Passed through unchanged to `MultiTaskStatus`.
    """
    shell_console = get_console()
    status = MultiTaskStatus(
        console=shell_console,
        settings=settings,
        auto_summary=auto_summary,
    )
    return status
96
+
97
+
83
98
  null_style = rich.style.Style.null()
84
99
 
85
100
 
@@ -0,0 +1,305 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from collections.abc import Callable
5
+ from dataclasses import dataclass
6
+
7
+
8
class RetryException(RuntimeError):
    """
    Root of the exception hierarchy for retry-related failures.
    """
12
+
13
+
14
class RetryExhaustedException(RetryException):
    """
    Raised when the retry budget has been used up. This error is itself not retriable.

    Attributes:
        original_exception: The last underlying error that was seen.
        max_retries: The retry limit that was reached.
        total_time: Total elapsed time in seconds, as reported by the caller.
    """

    def __init__(self, original_exception: Exception, max_retries: int, total_time: float):
        error_kind = type(original_exception).__name__
        message = (
            f"Max retries ({max_retries}) exhausted after {total_time:.1f}s. "
            f"Final error: {error_kind}: {original_exception}"
        )
        super().__init__(message)

        # Keep the details available to callers that want to inspect the failure.
        self.original_exception = original_exception
        self.max_retries = max_retries
        self.total_time = total_time
28
+
29
+
30
def default_is_retriable(exception: Exception) -> bool:
    """
    Decide whether an exception looks like a transient/rate-limit failure.

    LiteLLM is treated as a soft dependency: when it is importable, its
    `RateLimitError` and `APIError` types are recognized directly. Otherwise
    (or if the exception is of another type) the lowercased exception message
    is scanned for common rate-limit phrases.

    Args:
        exception: The exception to check

    Returns:
        True if the exception should be retried with backoff
    """
    try:
        import litellm.exceptions
    except ImportError:
        # LiteLLM not available: only the message-based check below applies.
        litellm_retriable_types: tuple[type, ...] = ()
    else:
        litellm_retriable_types = (
            litellm.exceptions.RateLimitError,
            litellm.exceptions.APIError,
        )

    if isinstance(exception, litellm_retriable_types):
        return True

    # Generic fallback: look for well-known rate-limit wording in the message.
    message = str(exception).lower()
    for marker in (
        "rate limit",
        "too many requests",
        "try again later",
        "429",
        "quota exceeded",
        "throttled",
        "rate_limit_error",
        "ratelimiterror",
    ):
        if marker in message:
            return True
    return False
71
+
72
+
73
@dataclass(frozen=True)
class RetrySettings:
    """
    Immutable bundle of retry limits and backoff parameters used when handling
    concurrent requests.
    """

    max_task_retries: int
    """Retry limit for a single task (0 disables retries)"""

    max_total_retries: int | None = None
    """Retry limit summed over all tasks (None means no global limit)"""

    initial_backoff: float = 1.0
    """Starting backoff, in seconds"""

    max_backoff: float = 128.0
    """Upper bound on the backoff, in seconds"""

    backoff_factor: float = 2.0
    """Multiplier applied for exponential backoff"""

    is_retriable: Callable[[Exception], bool] = default_is_retriable
    """Predicate deciding whether a given exception should be retried"""
96
+
97
+
98
def _never_retriable(_exception: Exception) -> bool:
    """Retriability predicate that rejects every exception."""
    return False


DEFAULT_RETRIES = RetrySettings(
    # Limits: per task and across all tasks.
    max_task_retries=10,
    max_total_retries=100,
    # Backoff schedule, in seconds.
    initial_backoff=1.0,
    max_backoff=128.0,
    backoff_factor=2.0,
    is_retriable=default_is_retriable,
)
"""Sensible defaults that cap retries both per task and globally."""


NO_RETRIES = RetrySettings(
    max_task_retries=0,
    max_total_retries=0,
    initial_backoff=0.0,
    max_backoff=0.0,
    backoff_factor=1.0,
    is_retriable=_never_retriable,
)
"""Settings that turn retries off entirely."""
118
+
119
+
120
def _parse_delay_seconds(raw) -> float | None:
    """Convert `raw` to a finite, non-negative number of seconds, or None if not possible."""
    try:
        delay = float(raw)
    except (ValueError, TypeError):
        return None
    # The chained comparison is False for NaN, negatives, and infinity, none of
    # which is a usable sleep duration.
    if 0 <= delay < float("inf"):
        return delay
    return None


def extract_retry_after(exception: Exception) -> float | None:
    """
    Try to extract retry-after time from exception headers or attributes.

    Two sources are consulted, in order:

    1. A `Retry-After` header on `exception.response.headers` (matched
       case-insensitively, so plain dicts with any capitalization work).
    2. A `retry_after` attribute on the exception itself.

    Values that are not numeric, or that are negative, NaN, or infinite, are
    ignored, and the next source is tried.

    Args:
        exception: The exception to extract retry-after from

    Returns:
        Retry-after time in seconds, or None if not found
    """
    # Compare against None explicitly rather than relying on truthiness: some
    # response objects (e.g. `requests.Response`) evaluate as False for error
    # statuses such as 429, which is exactly when this header matters.
    response = getattr(exception, "response", None)
    headers = getattr(response, "headers", None) if response is not None else None
    if headers is not None:
        try:
            header_items = list(headers.items())
        except (AttributeError, TypeError):
            # Not a mapping-like headers object; nothing we can look up.
            header_items = []
        for name, value in header_items:
            if str(name).lower() == "retry-after":
                delay = _parse_delay_seconds(value)
                if delay is not None:
                    return delay

    # Fall back to a `retry_after` attribute set directly on the exception.
    retry_after = getattr(exception, "retry_after", None)
    if retry_after is not None:
        return _parse_delay_seconds(retry_after)

    return None
149
+
150
+
151
def calculate_backoff(
    attempt: int,
    exception: Exception,
    *,
    initial_backoff: float,
    max_backoff: float,
    backoff_factor: float,
) -> float:
    """
    Calculate backoff time using exponential backoff with jitter.

    If the exception carries a retry-after hint (see `extract_retry_after`),
    that value is used, capped at `max_backoff`. Otherwise the delay is
    `initial_backoff * backoff_factor ** attempt`, randomized by ±50%, plus a
    random base delay of up to 50% of `initial_backoff`, capped at
    `max_backoff`.

    Args:
        attempt: Current attempt number (0-based)
        exception: The exception that triggered the backoff
        initial_backoff: Base backoff time in seconds
        max_backoff: Maximum backoff time in seconds
        backoff_factor: Exponential backoff multiplier

    Returns:
        Backoff time in seconds, never more than `max_backoff`
    """
    # Prefer an explicit hint from the failing call over our own schedule.
    retry_after = extract_retry_after(exception)
    if retry_after is not None:
        return min(retry_after, max_backoff)

    # Exponential backoff: initial_backoff * (backoff_factor ^ attempt)
    try:
        exponential_backoff = initial_backoff * (backoff_factor**attempt)
    except OverflowError:
        # The power is too large for a float (very high attempt counts); the
        # result would be capped at max_backoff anyway.
        return max_backoff

    # Add significant jitter (±50% randomization) to prevent thundering herd
    jitter_factor = 1 + (random.random() - 0.5) * 1.0
    backoff_with_jitter = exponential_backoff * jitter_factor
    # Add a small random base delay (0 to 50% of initial_backoff) to further spread out retries
    base_delay = random.random() * (initial_backoff * 0.5)
    total_backoff = backoff_with_jitter + base_delay

    return min(total_backoff, max_backoff)
188
+
189
+
190
+ ## Tests
191
+
192
+
193
def test_default_is_retriable():
    """Check the message-based rate limit detection on plain exceptions."""
    retriable_messages = [
        "Rate limit exceeded",
        "Too many requests",
        "HTTP 429 error",
        "Quota exceeded",
        "throttled",
        "RateLimitError",
    ]
    for message in retriable_messages:
        assert default_is_retriable(Exception(message))

    non_retriable_messages = [
        "Authentication failed",
        "Invalid API key",
        "Network error",
    ]
    for message in non_retriable_messages:
        assert not default_is_retriable(Exception(message))
207
+
208
+
209
def test_default_is_retriable_litellm():
    """Check detection of LiteLLM exception types; a no-op when LiteLLM is not installed."""
    try:
        import litellm.exceptions

        errors = litellm.exceptions
        common = {"model": "test", "llm_provider": "test"}

        # Rate limit and generic API errors are expected to be retriable.
        retriable = [
            errors.RateLimitError(message="Rate limit", **common),
            errors.APIError(message="API error", status_code=500, **common),
        ]
        for error in retriable:
            assert default_is_retriable(error)

        # Authentication failures are expected not to be retriable.
        auth_failure = errors.AuthenticationError(message="Auth failed", **common)
        assert not default_is_retriable(auth_failure)

    except ImportError:
        # LiteLLM not available, skip
        pass
233
+
234
+
235
def test_extract_retry_after():
    """Check extraction of retry-after from a response header and from an attribute."""

    class FakeResponse:
        def __init__(self, headers):
            self.headers = headers

    class FakeError(Exception):
        def __init__(self, response=None, retry_after=None):
            self.response = response
            # Only define the attribute when a value is supplied.
            if retry_after is not None:
                self.retry_after = retry_after
            super().__init__()

    # Header on the attached response.
    with_header = FakeError(response=FakeResponse({"retry-after": "30"}))
    assert extract_retry_after(with_header) == 30.0

    # Attribute on the exception itself.
    with_attribute = FakeError(retry_after=45.0)
    assert extract_retry_after(with_attribute) == 45.0

    # Neither source present.
    assert extract_retry_after(FakeError()) is None

    # A header that is not numeric yields nothing.
    with_bad_header = FakeError(response=FakeResponse({"retry-after": "invalid"}))
    assert extract_retry_after(with_bad_header) is None
262
+
263
+
264
def test_calculate_backoff():
    """Check the retry-after shortcut, the jittered range, and the max_backoff cap."""

    class FakeError(Exception):
        def __init__(self, retry_after=None):
            self.retry_after = retry_after
            super().__init__()

    schedule = {"initial_backoff": 1.0, "max_backoff": 60.0, "backoff_factor": 2.0}

    # An explicit retry_after wins over the computed schedule.
    hinted = FakeError(retry_after=30.0)
    assert calculate_backoff(attempt=1, exception=hinted, **schedule) == 30.0

    # Without a hint: 2.0 * (0.5..1.5 jitter) + (0..0.5 base delay) => 1.0..3.5
    plain = FakeError()
    delay = calculate_backoff(attempt=1, exception=plain, **schedule)
    assert 1.0 <= delay <= 3.5

    # Large attempts are capped at max_backoff.
    capped = calculate_backoff(
        attempt=10,
        exception=plain,
        initial_backoff=1.0,
        max_backoff=5.0,
        backoff_factor=2.0,
    )
    assert capped <= 5.0
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+ from urllib.parse import urlencode
6
+
7
+ import requests
8
+ from pyrate_limiter import Duration, Limiter, Rate
9
+ from pyrate_limiter.buckets import InMemoryBucket
10
+ from typing_extensions import override
11
+
12
+ from kash.config.logger import get_logger
13
+ from kash.web_content.file_cache_utils import cache_file
14
+ from kash.web_content.local_file_cache import Loadable
15
+
16
+ log = get_logger(__name__)
17
+
18
+
19
class CachingSession(requests.Session):
    """
    A `requests.Session` that adds local file caching and optional rate limiting (if
    `limit` and `limit_interval_secs` are provided). A bit of a hack but enables
    hot patching libraries that use `requests` without other code changes.

    Only `get()` is overridden. Responses are cached by URL plus query parameters;
    other request options (headers, etc.) are not part of the cache key.
    """

    def __init__(
        self,
        *,
        limit: int | None = None,
        limit_interval_secs: int | None = None,
        max_wait_secs: int = 60 * 5,
    ):
        """
        Args:
            limit: Maximum number of uncached requests per interval.
            limit_interval_secs: Length of the rate limit interval, in seconds.
            max_wait_secs: Longest time to block waiting for the rate limiter.

        Rate limiting is enabled only when both `limit` and `limit_interval_secs`
        are set (and non-zero).
        """
        super().__init__()
        self._limiter: Limiter | None = None
        if limit and limit_interval_secs:
            rate = Rate(limit, Duration.SECOND * limit_interval_secs)
            bucket = InMemoryBucket([rate])
            # Explicitly set raise_when_fail=False and max_delay to enable blocking.
            self._limiter = Limiter(
                bucket, raise_when_fail=False, max_delay=Duration.SECOND * max_wait_secs
            )
            log.info(
                "CachingSession: rate limiting requests with limit=%d, interval=%d, max_wait=%d",
                limit,
                limit_interval_secs,
                max_wait_secs,
            )

    @staticmethod
    def _cache_key(url: str | bytes, params: Any) -> str:
        """
        Build the cache key (also used as the response URL) from the URL and params.

        Handles params given as a mapping, a sequence of pairs, or an already
        encoded str/bytes query, and joins with `&` when the URL already has a
        query string.
        """
        url_str = url.decode() if isinstance(url, bytes) else str(url)
        if not params:
            return url_str

        if isinstance(params, bytes):
            query_string = params.decode()
        elif isinstance(params, str):
            query_string = params
        else:
            # doseq=True expands list values into repeated keys (a=1&a=2).
            query_string = urlencode(params, doseq=True)
        if not query_string:
            return url_str

        if "?" not in url_str:
            separator = "?"
        elif url_str.endswith(("?", "&")):
            separator = ""
        else:
            separator = "&"
        return f"{url_str}{separator}{query_string}"

    @override
    def get(self, url: str | bytes, **kwargs: Any) -> Any:
        """
        Fetch `url` through the local file cache.

        On a cache miss the request is performed (after acquiring the rate
        limiter, if configured) and the body is written to the cache. The
        returned object is a synthetic `requests.Response` with status 200,
        UTF-8 encoding, and the cached bytes as content, so headers and other
        details of the real response are not preserved.

        Raises:
            RuntimeError: If the rate limiter could not be acquired in time.
                (The fetch also calls `raise_for_status()`; presumably both
                errors propagate through `cache_file`. TODO confirm.)
        """
        # We need a unique key for the cache, so we use the URL and params.
        url_key = self._cache_key(url, kwargs.get("params"))

        def save(path: Path):
            if self._limiter:
                acquired = self._limiter.try_acquire("caching_session_get")
                if not acquired:
                    # Generally shouldn't happen.
                    raise RuntimeError("Rate limiter failed to acquire after maximum delay")

            # Zero-argument super() is not usable inside this nested function.
            response = super(CachingSession, self).get(url, **kwargs)
            response.raise_for_status()
            content = response.content
            with open(path, "wb") as f:
                f.write(content)

        cache_result = cache_file(Loadable(url_key, save))

        if not cache_result.was_cached:
            log.debug("Cache miss, fetched: %s", url_key)
        else:
            log.debug("Cache hit: %s", url_key)

        # A simple hack to make sure response.json() works (e.g. as wikipediaapi needs).
        # TODO: Wrap more carefully to ensure other methods work.
        response = requests.Response()
        response.status_code = 200
        response.encoding = "utf-8"
        response._content = cache_result.content.path.read_bytes()
        response.url = url_key
        return response