fast-cache-middleware 0.0.3-py3-none-any.whl → 0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
+ from fastapi import FastAPI, routing
+
+ from .depends import CacheConfig
+
+
+ def set_cache_age_in_openapi_schema(app: FastAPI) -> None:
+     openapi_schema = app.openapi()
+
+     for route in app.routes:
+         if isinstance(route, routing.APIRoute):
+             path = route.path
+             methods = route.methods
+
+             for dependency in route.dependencies:
+                 dep = dependency.dependency
+                 if isinstance(dep, CacheConfig):
+                     max_age = dep.max_age
+
+                     for method in methods:
+                         method = method.lower()
+                         try:
+                             operation = openapi_schema["paths"][path][method]
+                             operation.setdefault("x-cache-age", max_age)
+                         except KeyError:
+                             continue
+
+     app.openapi_schema = openapi_schema
+     return None
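
The helper above walks each APIRoute, finds a CacheConfig dependency, and records its max_age on the matching OpenAPI operation as x-cache-age. A minimal usage sketch, assuming the package re-exports CacheConfig and FastCacheMiddleware from its top level:

from fastapi import FastAPI

from fast_cache_middleware import CacheConfig, FastCacheMiddleware  # assumed top-level exports

app = FastAPI()
app.add_middleware(FastCacheMiddleware)

# CacheConfig attached via `dependencies=[...]` is what the helper inspects.
@app.get("/items", dependencies=[CacheConfig(max_age=60)])
async def list_items() -> list[str]:
    return ["a", "b"]

# Once the middleware has called set_cache_age_in_openapi_schema(app) (on lifespan
# or the first HTTP request), the schema carries the hint:
# app.openapi()["paths"]["/items"]["get"]["x-cache-age"] == 60
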
@@ -1,12 +1,13 @@
  import http
  import logging
- import typing as tp
+ import re
  from hashlib import blake2b
+ from typing import Optional

  from starlette.requests import Request
  from starlette.responses import Response

- from .depends import CacheConfig, CacheDropConfig
+ from .schemas import CacheConfiguration
  from .storages import BaseStorage

  logger = logging.getLogger(__name__)
@@ -136,7 +137,7 @@ class Controller:
          return True

      async def generate_cache_key(
-         self, request: Request, cache_config: CacheConfig
+         self, request: Request, cache_configuration: CacheConfiguration
      ) -> str:
          """Generates cache key for request.

@@ -148,8 +149,8 @@ class Controller:
              str: Cache key
          """
          # Use custom key generation function if available
-         if cache_config.key_func:
-             return cache_config.key_func(request)
+         if cache_configuration.key_func:
+             return cache_configuration.key_func(request)

          # Use standard function
          return generate_key(request)
@@ -160,7 +161,7 @@ class Controller:
          request: Request,
          response: Response,
          storage: BaseStorage,
-         ttl: tp.Optional[int] = None,
+         ttl: Optional[int] = None,
      ) -> None:
          """Saves response to cache.

@@ -180,7 +181,7 @@ class Controller:

      async def get_cached_response(
          self, cache_key: str, storage: BaseStorage
-     ) -> tp.Optional[Response]:
+     ) -> Optional[Response]:
          """Gets cached response if it exists and is valid.

          Args:
@@ -198,13 +199,13 @@ class Controller:

      async def invalidate_cache(
          self,
-         cache_drop_config: CacheDropConfig,
+         invalidate_paths: list[re.Pattern],
          storage: BaseStorage,
      ) -> None:
          """Invalidates cache by configuration.

          Args:
-             cache_drop_config: Cache invalidation configuration
+             invalidate_paths: List of regex patterns for cache invalidation
              storage: Cache storage

          TODO: Comments on improvements:
@@ -226,6 +227,6 @@ class Controller:
          5. Add tag support for grouping related caches
             and their joint invalidation
          """
-         for path in cache_drop_config.paths:
+         for path in invalidate_paths:
              await storage.remove(path)
              logger.info("Invalidated cache for pattern: %s", path.pattern)
@@ -1,5 +1,5 @@
  import re
- import typing as tp
+ from typing import Callable, Optional

  from fastapi import params
  from starlette.requests import Request
@@ -15,23 +15,7 @@ class BaseCacheConfigDepends(params.Depends):
      use_cache: bool = True

      def __call__(self, request: Request) -> None:
-         """Saves configuration in ASGI scope extensions.
-
-         Args:
-             request: HTTP request
-         """
-         # Use standard ASGI extensions mechanism
-         if "extensions" not in request.scope:
-             request.scope["extensions"] = {}
-
-         if "fast_cache" not in request.scope["extensions"]:
-             request.scope["extensions"]["fast_cache"] = {}
-
-         request.scope["extensions"]["fast_cache"]["config"] = self
-
-     @property
-     def dependency(self) -> params.Depends:
-         return self
+         pass


  class CacheConfig(BaseCacheConfigDepends):
@@ -45,11 +29,13 @@ class CacheConfig(BaseCacheConfigDepends):
      def __init__(
          self,
          max_age: int = 5 * 60,
-         key_func: tp.Optional[tp.Callable[[Request], str]] = None,
+         key_func: Optional[Callable[[Request], str]] = None,
      ) -> None:
          self.max_age = max_age
          self.key_func = key_func

+         self.dependency = self
+


  class CacheDropConfig(BaseCacheConfigDepends):
      """Cache invalidation configuration for route.
@@ -64,3 +50,5 @@ class CacheDropConfig(BaseCacheConfigDepends):
          self.paths: list[re.Pattern] = [
              p if isinstance(p, re.Pattern) else re.compile(f"^{p}") for p in paths
          ]
+
+         self.dependency = self
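
The removal of the `dependency` property in favor of `self.dependency = self` keeps these config objects usable directly in FastAPI's `dependencies=[...]` list: FastAPI calls the no-op `__call__`, while the middleware and the OpenAPI helper recover the config via `dependency.dependency`. A small sketch of that mechanism (the route setup and top-level export are assumptions, not part of the diff):

from fastapi import APIRouter

from fast_cache_middleware import CacheConfig  # assumed top-level export

router = APIRouter()

@router.get("/ping", dependencies=[CacheConfig(max_age=10)])
async def ping() -> dict[str, bool]:
    return {"ok": True}

dep = router.routes[-1].dependencies[0]  # the CacheConfig instance itself
assert dep.dependency is dep             # what the middleware and _helpers rely on
assert dep.dependency.max_age == 10
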
@@ -1,23 +1,142 @@
  import copy
- import inspect
  import logging
  import typing as tp
- import cachetools

  from fastapi import FastAPI, routing
  from starlette.requests import Request
  from starlette.responses import Response
- from starlette.routing import Mount, get_route_path
+ from starlette.routing import Match, Mount
  from starlette.types import ASGIApp, Receive, Scope, Send

+ from ._helpers import set_cache_age_in_openapi_schema
  from .controller import Controller
  from .depends import BaseCacheConfigDepends, CacheConfig, CacheDropConfig
- from .schemas import RouteInfo
+ from .schemas import CacheConfiguration, RouteInfo
  from .storages import BaseStorage, InMemoryStorage

  logger = logging.getLogger(__name__)


+ class BaseMiddleware:
+     def __init__(
+         self,
+         app: ASGIApp,
+     ) -> None:
+         self.app = app
+
+         self.executors_map = {
+             "lifespan": self.on_lifespan,
+             "http": self.on_http,
+         }
+
+     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+         scope_type = scope["type"]
+         try:
+             is_request_processed = await self.executors_map[scope_type](
+                 scope, receive, send
+             )
+         except KeyError:
+             logger.debug("Not supported scope type: %s", scope_type)
+             is_request_processed = False
+
+         if not is_request_processed:
+             await self.app(scope, receive, send)
+
+     async def on_lifespan(
+         self, scope: Scope, receive: Receive, send: Send
+     ) -> bool | None:
+         pass
+
+     async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool | None:
+         pass
+
+
+ class BaseSendWrapper:
+     def __init__(self, app: ASGIApp, scope: Scope, receive: Receive, send: Send):
+         self.app = app
+         self.scope = scope
+         self.receive = receive
+         self.send = send
+
+         self._response_status: int = 200
+         self._response_headers: dict[str, str] = dict()
+         self._response_body: bytes = b""
+
+         self.executors_map = {
+             "http.response.start": self.on_response_start,
+             "http.response.body": self.on_response_body,
+         }
+
+     async def __call__(self) -> None:
+         return await self.app(self.scope, self.receive, self._message_processor)
+
+     async def _message_processor(self, message: tp.MutableMapping[str, tp.Any]) -> None:
+         try:
+             executor = self.executors_map[message["type"]]
+         except KeyError:
+             logger.error("Not found executor for %s message type", message["type"])
+         else:
+             await executor(message)
+
+         await self.send(message)
+
+     async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
+         self._response_status = message["status"]
+         self._response_headers = {
+             k.decode(): v.decode() for k, v in message.get("headers", [])
+         }
+
+     async def on_response_body(self, message: tp.MutableMapping[str, tp.Any]) -> None:
+         self._response_body += message.get("body", b"")
+
+         # this is the last chunk
+         if not message.get("more_body", False):
+             response = Response(
+                 content=self._response_body,
+                 status_code=self._response_status,
+                 headers=self._response_headers,
+             )
+             await self.on_response_ready(response)
+
+     async def on_response_ready(self, response: Response) -> None:
+         pass
+
+
+ class CacheSendWrapper(BaseSendWrapper):
+     def __init__(
+         self,
+         controller: Controller,
+         storage: BaseStorage,
+         request: Request,
+         cache_key: str,
+         ttl: int,
+         app: ASGIApp,
+         scope: Scope,
+         receive: Receive,
+         send: Send,
+     ) -> None:
+         super().__init__(app, scope, receive, send)
+
+         self.controller = controller
+         self.storage = storage
+         self.request = request
+         self.cache_key = cache_key
+         self.ttl = ttl
+
+     async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
+         message.get("headers", []).append(("X-Cache-Status".encode(), "MISS".encode()))
+         return await super().on_response_start(message)
+
+     async def on_response_ready(self, response: Response) -> None:
+         await self.controller.cache_response(
+             cache_key=self.cache_key,
+             request=self.request,
+             response=response,
+             storage=self.storage,
+             ttl=self.ttl,
+         )
+
+
  def get_app_routes(app: FastAPI) -> tp.List[routing.APIRoute]:
      """Gets all routes from FastAPI application.

@@ -68,57 +187,7 @@ def get_routes(router: routing.APIRouter) -> list[routing.APIRoute]:
      return routes


- async def send_with_callbacks(
-     app: ASGIApp,
-     scope: Scope,
-     receive: Receive,
-     send: Send,
-     on_response_ready: tp.Callable[[Response], tp.Awaitable[None]] | None = None,
- ) -> None:
-     response_holder: tp.Dict[str, tp.Any] = {}
-
-     async def response_builder(message: tp.Dict[str, tp.Any]) -> None:
-         """Wrapper for intercepting and saving response."""
-         if message["type"] == "http.response.start":
-             response_holder["status"] = message["status"]
-
-             message.get("headers", []).append(
-                 ("X-Cache-Status".encode(), "MISS".encode())
-             )
-             response_holder["headers"] = [
-                 (k.decode(), v.decode()) for k, v in message.get("headers", [])
-             ]
-
-             response_holder["body"] = b""
-         elif message["type"] == "http.response.body":
-             body = message.get("body", b"")
-             response_holder["body"] += body
-
-             # If this is the last chunk, cache the response
-             if not message.get("more_body", False):
-                 response = Response(
-                     content=response_holder["body"],
-                     status_code=response_holder["status"],
-                     headers=dict(response_holder["headers"]),
-                 )
-
-                 # Call callback with ready response
-                 if on_response_ready:
-                     await on_response_ready(response)
-
-         # Pass event further
-         await send(message)
-
-     await app(scope, receive, response_builder)
-
-
- def _build_scope_hash_key(scope: Scope) -> str:
-     path = get_route_path(scope)
-     method = scope["method"].upper()
-     return f"{path}/{method}"
-
-
- class FastCacheMiddleware:
+ class FastCacheMiddleware(BaseMiddleware):
      """Middleware for caching responses in ASGI applications.

      Route resolution approach:
@@ -145,12 +214,79 @@ class FastCacheMiddleware:
          storage: tp.Optional[BaseStorage] = None,
          controller: tp.Optional[Controller] = None,
      ) -> None:
-         self.app = app
+         super().__init__(app)
+
          self.storage = storage or InMemoryStorage()
          self.controller = controller or Controller()
+         self._openapi_initialized = False

          self._routes_info: list[RouteInfo] = []

+         current_app: tp.Any = app
+         while current_app := getattr(current_app, "app", None):
+             if isinstance(current_app, routing.APIRouter):
+                 _routes = get_routes(current_app)
+                 self._routes_info = self._extract_routes_info(_routes)
+                 break
+
+     async def on_lifespan(self, scope: Scope, _: Receive, __: Send) -> bool | None:
+         app_routes = get_app_routes(scope["app"])
+         set_cache_age_in_openapi_schema(scope["app"])
+         self._routes_info = self._extract_routes_info(app_routes)
+         return None
+
+     async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool | None:
+         request = Request(scope, receive)
+
+         if not self._openapi_initialized:
+             set_cache_age_in_openapi_schema(scope["app"])
+             self._openapi_initialized = True
+
+         # Find matching route
+         route_info = self._find_matching_route(request, self._routes_info)
+         if not route_info:
+             return None
+
+         cache_configuration = route_info.cache_config
+
+         # Handle invalidation if specified
+         if cache_configuration.invalidate_paths:
+             await self.controller.invalidate_cache(
+                 cache_configuration.invalidate_paths, storage=self.storage
+             )
+
+         if not cache_configuration.max_age:
+             return None
+
+         if not await self.controller.is_cachable_request(request):
+             return None
+
+         cache_key = await self.controller.generate_cache_key(
+             request, cache_configuration=cache_configuration
+         )
+
+         cached_response = await self.controller.get_cached_response(
+             cache_key, self.storage
+         )
+         if cached_response is not None:
+             logger.debug("Returning cached response for key: %s", cache_key)
+             await cached_response(scope, receive, send)
+             return True
+
+         # Cache not found - execute request and cache result
+         await CacheSendWrapper(
+             app=self.app,
+             scope=scope,
+             receive=receive,
+             send=send,
+             controller=self.controller,
+             storage=self.storage,
+             request=request,
+             cache_key=cache_key,
+             ttl=cache_configuration.max_age,
+         )()
+         return True
+
      def _extract_routes_info(self, routes: list[routing.APIRoute]) -> list[RouteInfo]:
          """Recursively extracts route information and their dependencies.

@@ -165,10 +301,17 @@ class FastCacheMiddleware:
              ) = self._extract_cache_configs_from_route(route)

              if cache_config or cache_drop_config:
+                 cache_configuration = CacheConfiguration(
+                     max_age=cache_config.max_age if cache_config else None,
+                     key_func=cache_config.key_func if cache_config else None,
+                     invalidate_paths=(
+                         cache_drop_config.paths if cache_drop_config else None
+                     ),
+                 )
+
                  route_info = RouteInfo(
                      route=route,
-                     cache_config=cache_config,
-                     cache_drop_config=cache_drop_config,
+                     cache_config=cache_configuration,
                  )
                  routes_info.append(route_info)

@@ -205,10 +348,6 @@ class FastCacheMiddleware:

          return cache_config, cache_drop_config

-     @cachetools.cached(
-         cache=cachetools.LRUCache(maxsize=10**3),
-         key=lambda _, request, __: _build_scope_hash_key(request.scope),
-     )
      def _find_matching_route(
          self, request: Request, routes_info: list[RouteInfo]
      ) -> tp.Optional[RouteInfo]:
@@ -224,83 +363,7 @@ class FastCacheMiddleware:
              if request.method not in route_info.methods:
                  continue
              match_mode, _ = route_info.route.matches(request.scope)
-             if match_mode == routing.Match.FULL:
+             if match_mode == Match.FULL:
                  return route_info

-         return
-
-     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-         if scope["type"] != "http":
-             await self.app(scope, receive, send)
-             return
-
-         if not self._routes_info:
-             app_routes = get_app_routes(scope["app"])
-             self._routes_info = self._extract_routes_info(app_routes)
-
-         request = Request(scope, receive)
-
-         # Find matching route
-         route_info = self._find_matching_route(request, self._routes_info)
-         if not route_info:
-             await self.app(scope, receive, send)
-             return
-
-         # Handle invalidation if specified
-         if cc := route_info.cache_drop_config:
-             await self.controller.invalidate_cache(cc, storage=self.storage)
-
-         # Handle caching if config exists
-         if route_info.cache_config:
-             await self._handle_cache_request(route_info, request, scope, receive, send)
-             return
-
-         # Execute original request
-         await self.app(scope, receive, send)
-
-     async def _handle_cache_request(
-         self,
-         route_info: RouteInfo,
-         request: Request,
-         scope: Scope,
-         receive: Receive,
-         send: Send,
-     ) -> None:
-         """Handles request with caching.
-
-         Args:
-             route_info: Route information
-             request: HTTP request
-             scope: ASGI scope
-             receive: ASGI receive callable
-             send: ASGI send callable
-         """
-         cache_config = route_info.cache_config
-         if not cache_config:
-             await self.app(scope, receive, send)
-             return
-
-         if not await self.controller.is_cachable_request(request):
-             await self.app(scope, receive, send)
-             return
-
-         cache_key = await self.controller.generate_cache_key(request, cache_config)
-
-         cached_response = await self.controller.get_cached_response(
-             cache_key, self.storage
-         )
-         if cached_response is not None:
-             logger.debug("Returning cached response for key: %s", cache_key)
-             await cached_response(scope, receive, send)
-             return
-
-         # Cache not found - execute request and cache result
-         await send_with_callbacks(
-             self.app,
-             scope,
-             receive,
-             send,
-             lambda response: self.controller.cache_response(
-                 cache_key, request, response, self.storage, cache_config.max_age
-             ),
-         )
+         return None
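
With 0.0.5 the request flow lives in on_http: invalidation runs first when the matched route carries invalidate_paths, then the cache is consulted, and a True return tells BaseMiddleware the response was already sent. A hedged end-to-end sketch (the top-level exports and the CacheDropConfig keyword argument are assumptions, not confirmed by the diff):

from fastapi import FastAPI

from fast_cache_middleware import CacheConfig, CacheDropConfig, FastCacheMiddleware

app = FastAPI()
app.add_middleware(FastCacheMiddleware)

@app.get("/users", dependencies=[CacheConfig(max_age=30)])
async def list_users() -> list[str]:
    return ["alice", "bob"]

@app.post("/users", dependencies=[CacheDropConfig(paths=["/users"])])  # constructor signature assumed
async def create_user() -> dict[str, str]:
    # on_http sees invalidate_paths for this route and calls
    # Controller.invalidate_cache(...) before the handler runs.
    return {"status": "created"}
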
@@ -1,21 +1,78 @@
- import typing as tp
+ import re
+ from typing import Any, Callable

+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+     computed_field,
+     field_validator,
+     model_validator,
+ )
+ from starlette.requests import Request
  from starlette.routing import Route

- from .depends import BaseCacheConfigDepends
+ from .depends import CacheConfig, CacheDropConfig


- class RouteInfo:
+ class CacheConfiguration(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
+
+     max_age: int | None = Field(
+         default=None,
+         description="Cache lifetime in seconds. If None, caching is disabled.",
+     )
+     key_func: Callable[[Request], str] | None = Field(
+         default=None,
+         description="Custom cache key generation function. If None, default key generation is used.",
+     )
+     invalidate_paths: list[re.Pattern] | None = Field(
+         default=None,
+         description="Paths for cache invalidation (strings or regex patterns). No invalidation if None.",
+     )
+
+     @model_validator(mode="after")
+     def one_of_field_is_set(self) -> "CacheConfiguration":
+         if (
+             self.max_age is None
+             and self.key_func is None
+             and self.invalidate_paths is None
+         ):
+             raise ValueError(
+                 "At least one of max_age, key_func, or invalidate_paths must be set."
+             )
+         return self
+
+     @field_validator("invalidate_paths")
+     @classmethod
+     def compile_paths(cls, item: Any) -> Any:
+         if item is None:
+             return None
+         if isinstance(item, str):
+             return re.compile(f"^{item}")
+         if isinstance(item, re.Pattern):
+             return item
+         if isinstance(item, list):
+             return [cls.compile_paths(i) for i in item]
+         raise ValueError(
+             "invalidate_paths must be a string, regex pattern, or list of them."
+         )
+
+
+ class RouteInfo(BaseModel):
      """Route information with cache configuration."""

-     def __init__(
-         self,
-         route: Route,
-         cache_config: tp.Optional[BaseCacheConfigDepends] = None,
-         cache_drop_config: tp.Optional[BaseCacheConfigDepends] = None,
-     ):
-         self.route = route
-         self.cache_config = cache_config
-         self.cache_drop_config = cache_drop_config
-         self.path: str = getattr(route, "path")
-         self.methods: tp.Set[str] = getattr(route, "methods", set())
+     model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
+
+     route: Route
+     cache_config: CacheConfiguration
+
+     @computed_field # type: ignore[prop-decorator]
+     @property
+     def path(self) -> str:
+         return getattr(self.route, "path", "")
+
+     @computed_field # type: ignore[prop-decorator]
+     @property
+     def methods(self) -> set[str]:
+         return getattr(self.route, "methods", set())
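
CacheConfiguration is the internal model the middleware now builds from a route's CacheConfig/CacheDropConfig dependencies: at least one field must be set, and invalidate_paths normally arrives as already-compiled patterns taken from CacheDropConfig.paths. A small sketch of constructing it directly, assuming the module path fast_cache_middleware.schemas:

import re

from fast_cache_middleware.schemas import CacheConfiguration  # import path assumed

# The middleware passes CacheDropConfig.paths, which are already compiled patterns.
cfg = CacheConfiguration(max_age=60, invalidate_paths=[re.compile("^/users")])
assert cfg.max_age == 60
assert cfg.invalidate_paths is not None and cfg.invalidate_paths[0].pattern == "^/users"

# An empty configuration is rejected by the model_validator:
# CacheConfiguration()  ->  ValidationError("At least one of max_age, key_func,
#                           or invalidate_paths must be set.")
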