litestar-vite 0.12.1__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of litestar-vite might be problematic. Click here for more details.

@@ -0,0 +1,99 @@
1
+ from enum import Enum
2
+ from typing import TYPE_CHECKING, Any, Callable
3
+
4
+ if TYPE_CHECKING:
5
+ from litestar_vite.inertia.types import InertiaHeaderType
6
+
7
+
8
class InertiaHeaders(str, Enum):
    """Names of the HTTP headers used by the Inertia integration.

    The ``str`` mixin makes every member compare equal to its header-name value,
    so members can be used directly wherever a plain header string is expected.
    """

    # Flags a request/response as taking part in the Inertia protocol.
    ENABLED = "X-Inertia"
    # Asset version header (compared against the Vite asset loader's version by the middleware).
    VERSION = "X-Inertia-Version"
    # Partial-reload prop selection.
    PARTIAL_DATA = "X-Inertia-Partial-Data"
    # Partial-reload component selection.
    PARTIAL_COMPONENT = "X-Inertia-Partial-Component"
    # Location header used for Inertia location redirects.
    LOCATION = "X-Inertia-Location"
    # Standard HTTP Referer header.
    REFERER = "Referer"
17
+
18
+
19
def get_enabled_header(enabled: bool = True) -> "dict[str, Any]":
    """Build the ``X-Inertia`` header mapping.

    Args:
        enabled: Whether the response should be flagged as an Inertia response.

    Returns:
        A single-entry dict mapping the ``X-Inertia`` header to ``"true"`` or ``"false"``.
    """
    flag = "true" if enabled else "false"
    header_name = InertiaHeaders.ENABLED.value
    return {header_name: flag}
30
+
31
+
32
def get_version_header(version: str) -> "dict[str, Any]":
    """Build the ``X-Inertia-Version`` header mapping.

    Args:
        version: The version string to advertise.

    Returns:
        A single-entry dict mapping ``X-Inertia-Version`` to ``version``.
    """
    header_name = InertiaHeaders.VERSION.value
    return {header_name: version}
42
+
43
+
44
def get_partial_data_header(partial: str) -> "dict[str, Any]":
    """Build the ``X-Inertia-Partial-Data`` header mapping.

    Args:
        partial: The raw header value describing the requested partial data.

    Returns:
        A single-entry dict mapping ``X-Inertia-Partial-Data`` to ``partial``.
    """
    header_name = InertiaHeaders.PARTIAL_DATA.value
    return {header_name: partial}
54
+
55
+
56
def get_partial_component_header(partial: str) -> "dict[str, Any]":
    """Build the ``X-Inertia-Partial-Component`` header mapping.

    Args:
        partial: The raw header value naming the partially reloaded component.

    Returns:
        A single-entry dict mapping ``X-Inertia-Partial-Component`` to ``partial``.
    """
    header_name = InertiaHeaders.PARTIAL_COMPONENT.value
    return {header_name: partial}
66
+
67
+
68
def get_headers(inertia_headers: "InertiaHeaderType") -> "dict[str, Any]":
    """Translate an Inertia header spec into concrete HTTP headers.

    Every entry whose value is not ``None`` is dispatched by key to the matching
    ``get_*_header`` builder, and the resulting single-header dicts are merged.

    Args:
        inertia_headers: Mapping keyed by ``"enabled"``, ``"partial_data"``,
            ``"partial_component"`` and/or ``"version"``.

    Raises:
        ValueError: If ``inertia_headers`` is ``None`` or otherwise falsy (e.g. empty).
        KeyError: If a key with a non-``None`` value has no registered builder.

    Returns:
        The merged header dict.
    """
    if not inertia_headers:
        msg = "Value for inertia_headers cannot be None."
        raise ValueError(msg)

    builders: "dict[str, Callable[..., dict[str, Any]]]" = {
        "enabled": get_enabled_header,
        "partial_data": get_partial_data_header,
        "partial_component": get_partial_component_header,
        "version": get_version_header,
    }

    merged: "dict[str, Any]" = {}
    for name, setting in inertia_headers.items():
        if setting is None:
            continue
        merged.update(builders[name](setting))
    return merged
@@ -0,0 +1,27 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Optional
3
+
4
__all__ = ("InertiaConfig",)


@dataclass
class InertiaConfig:
    """Settings for the InertiaJS integration.

    Attributes:
        root_template: Root template name; must be resolvable through the Vite plugin's template config.
        component_opt_key: Route ``opt`` key that names the Inertia component to render.
        exclude_from_js_routes_key: Route ``opt`` key that excludes a route from the generated
            TypeScript routes file.
        redirect_unauthorized_to: Optional path to redirect unauthorized requests to.
        redirect_404: Optional path to redirect not-found requests to.
        extra_static_page_props: Values merged into the page props of every response.
        extra_session_page_props: Session keys whose values (when present) are automatically
            added to the page props of every response.
    """

    # Template entry point rendered for full page loads.
    root_template: str = "index.html"
    # Route option key holding the component name.
    component_opt_key: str = "component"
    # Route option key that hides a route from the generated routes file.
    exclude_from_js_routes_key: str = "exclude_from_routes"
    # Redirect target for unauthorized requests (disabled when None).
    redirect_unauthorized_to: "Optional[str]" = None
    # Redirect target for not-found requests (disabled when None).
    redirect_404: "Optional[str]" = None
    # Static props added to every page; a fresh dict per instance.
    extra_static_page_props: "dict[str, Any]" = field(default_factory=dict)
    # Session keys copied into every page's props when set; a fresh set per instance.
    extra_session_page_props: "set[str]" = field(default_factory=set)
@@ -0,0 +1,131 @@
1
+ import re
2
+ from typing import TYPE_CHECKING, Any, cast
3
+
4
+ from litestar import MediaType
5
+ from litestar.connection import Request
6
+ from litestar.connection.base import AuthT, StateT, UserT
7
+ from litestar.exceptions import (
8
+ HTTPException,
9
+ ImproperlyConfiguredException,
10
+ InternalServerException,
11
+ NotAuthorizedException,
12
+ NotFoundException,
13
+ PermissionDeniedException,
14
+ )
15
+ from litestar.exceptions.responses import (
16
+ create_debug_response, # pyright: ignore[reportUnknownVariableType]
17
+ create_exception_response, # pyright: ignore[reportUnknownVariableType]
18
+ )
19
+ from litestar.plugins.flash import flash
20
+ from litestar.repository.exceptions import (
21
+ ConflictError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
22
+ NotFoundError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
23
+ RepositoryError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
24
+ )
25
+ from litestar.response import Response
26
+ from litestar.status_codes import (
27
+ HTTP_400_BAD_REQUEST,
28
+ HTTP_401_UNAUTHORIZED,
29
+ HTTP_404_NOT_FOUND,
30
+ HTTP_405_METHOD_NOT_ALLOWED,
31
+ HTTP_409_CONFLICT,
32
+ HTTP_422_UNPROCESSABLE_ENTITY,
33
+ HTTP_500_INTERNAL_SERVER_ERROR,
34
+ )
35
+
36
+ from litestar_vite.inertia.helpers import error
37
+ from litestar_vite.inertia.response import InertiaBack, InertiaRedirect, InertiaResponse
38
+
39
+ if TYPE_CHECKING:
40
+ from litestar.connection import Request
41
+ from litestar.connection.base import AuthT, StateT, UserT
42
+ from litestar.response import Response
43
+
44
+ from litestar_vite.inertia.plugin import InertiaPlugin
45
+
46
# Pulls the field name out of validation messages that end in: field `<name>`
# (group 1 / named group "field").
FIELD_ERR_RE = re.compile(r"field `(?P<field>.+)`$")
47
+
48
+
49
class _HTTPConflictException(HTTPException):
    """HTTP 409: the request conflicts with the current state of the target resource.

    Raised in place of repository ``RepositoryError``/``ConflictError`` failures for
    non-Inertia requests.
    """

    status_code: int = HTTP_409_CONFLICT
53
+
54
+
55
def exception_to_http_response(request: "Request[UserT, AuthT, StateT]", exc: "Exception") -> "Response[Any]":
    """Convert an exception into a response, using Inertia semantics when enabled.

    Inertia-enabled requests are delegated to ``create_inertia_exception_response``.
    For all other requests:

    * ``HTTPException`` instances are rendered as-is, keeping their status code and
      detail (in debug mode a 500 is shown with the debug page instead).
    * Repository ``NotFoundError`` maps to a 404, ``RepositoryError``/``ConflictError``
      to a 409, and anything else to a 500. In debug mode every mapped error except
      the 404 is rendered with the debug page.

    Args:
        request: The request object.
        exc: The exception to handle.

    Returns:
        The response object.
    """
    inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False)
    if inertia_enabled:
        return create_inertia_exception_response(request, exc)

    if isinstance(exc, HTTPException):
        # Already carries the intended status code and detail; re-wrapping it as an
        # InternalServerException would turn e.g. a 401/403/404 into a 500.
        if request.app.debug and exc.status_code == HTTP_500_INTERNAL_SERVER_ERROR:
            return cast("Response[Any]", create_debug_response(request, exc))
        return cast("Response[Any]", create_exception_response(request, exc))  # pyright: ignore[reportUnknownArgumentType]

    http_exc: "type[HTTPException]"
    if isinstance(exc, NotFoundError):
        http_exc = NotFoundException
    elif isinstance(exc, (RepositoryError, ConflictError)):
        http_exc = _HTTPConflictException
    else:
        http_exc = InternalServerException

    # Not-found errors keep their 404 even in debug mode. (The previous guard compared
    # against the repository ``NotFoundError`` class, which ``http_exc`` can never be.)
    if request.app.debug and http_exc is not NotFoundException:
        return cast("Response[Any]", create_debug_response(request, exc))

    # Prefer the chained cause for the detail, but fall back to the exception itself so
    # the detail is never the literal string "None".
    source = exc.__cause__ if exc.__cause__ is not None else exc
    return cast("Response[Any]", create_exception_response(request, http_exc(detail=str(source))))  # pyright: ignore[reportUnknownArgumentType]
78
+
79
+
80
def create_inertia_exception_response(request: "Request[UserT, AuthT, StateT]", exc: "Exception") -> "Response[Any]":
    """Create the Inertia-aware response for an exception.

    The handler works in four steps:

    1. Flash the exception detail under the ``error`` category.
    2. Record the first structured validation error (if any) via ``error``.
    3. Redirect back for 400/422 responses and permission errors.
    4. Redirect to the configured unauthorized / 404 pages when applicable.

    If none of the redirects apply, it renders an ``InertiaResponse`` describing the error.

    Args:
        request: The request object.
        exc: The exception to handle.

    Returns:
        The response object.
    """
    is_inertia = getattr(request, "is_inertia", False)
    status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
    preferred_type = MediaType.JSON if is_inertia else MediaType.HTML
    detail = getattr(exc, "detail", "")  # litestar exceptions
    extras = getattr(exc, "extra", "")  # msgspec exceptions
    content: "dict[str, Any]" = {"status_code": status_code, "message": detail}
    inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
    if extras:
        content.update({"extra": extras})
    try:
        flash(request, detail, category="error")
    except (AttributeError, ImproperlyConfiguredException):
        msg = "Unable to set `flash` session state. A valid session was not found for this request."
        request.logger.warning(msg)

    # ``extra`` is not guaranteed to be a list of dicts (it may be a dict or a string).
    # Check the shape before indexing/calling ``.get`` so we don't raise a KeyError or
    # AttributeError while already handling an exception.
    if isinstance(extras, (list, tuple)) and extras and isinstance(extras[0], dict):
        message = cast("dict[str, Any]", extras[0])
        key = message.get("key")
        default_field = f"root.{key}" if key is not None else "root"
        raw_message = message.get("message", detail)
        # Guard against a ``None`` message, which would make the regex search raise.
        error_detail = "" if raw_message is None else str(raw_message)
        match = FIELD_ERR_RE.search(error_detail)
        field = match.group(1) if match else default_field
        error(request, field, error_detail or detail)

    if status_code in {HTTP_422_UNPROCESSABLE_ENTITY, HTTP_400_BAD_REQUEST}:
        return InertiaBack(request)
    if isinstance(exc, PermissionDeniedException):
        return InertiaBack(request)
    if (status_code == HTTP_401_UNAUTHORIZED or isinstance(exc, NotAuthorizedException)) and (
        inertia_plugin.config.redirect_unauthorized_to is not None
        and request.url.path != inertia_plugin.config.redirect_unauthorized_to
    ):
        return InertiaRedirect(request, redirect_to=inertia_plugin.config.redirect_unauthorized_to)

    if status_code in {HTTP_404_NOT_FOUND, HTTP_405_METHOD_NOT_ALLOWED} and (
        inertia_plugin.config.redirect_404 is not None and request.url.path != inertia_plugin.config.redirect_404
    ):
        return InertiaRedirect(request, redirect_to=inertia_plugin.config.redirect_404)

    return InertiaResponse[Any](
        media_type=preferred_type,
        content=content,
        status_code=status_code,
    )
@@ -0,0 +1,325 @@
1
+ import inspect
2
+ from collections import defaultdict
3
+ from collections.abc import Coroutine, Generator, Iterable, Mapping
4
+ from contextlib import contextmanager
5
+ from functools import lru_cache
6
+ from textwrap import dedent
7
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union, cast, overload
8
+
9
+ from anyio.from_thread import BlockingPortal, start_blocking_portal
10
+ from litestar.exceptions import ImproperlyConfiguredException
11
+ from litestar.utils.empty import value_or_default
12
+ from litestar.utils.scope.state import ScopeState
13
+ from markupsafe import Markup
14
+ from typing_extensions import ParamSpec, TypeGuard
15
+
16
+ if TYPE_CHECKING:
17
+ from litestar.connection import ASGIConnection
18
+
19
+ from litestar_vite.inertia.plugin import InertiaPlugin
20
+ from litestar_vite.inertia.routes import Routes
21
+
22
+ T = TypeVar("T")
23
+ T_ParamSpec = ParamSpec("T_ParamSpec")
24
+ PropKeyT = TypeVar("PropKeyT", bound=str)
25
+ StaticT = TypeVar("StaticT", bound=object)
26
+
27
+
28
+ @overload
29
+ def lazy(key: str, value_or_callable: "None") -> "StaticProp[str, None]": ...
30
+
31
+
32
+ @overload
33
+ def lazy(key: str, value_or_callable: "T") -> "StaticProp[str, T]": ...
34
+
35
+
36
+ @overload
37
+ def lazy(key: str, value_or_callable: "Callable[..., None]" = ...) -> "DeferredProp[str, None]": ...
38
+
39
+
40
+ @overload
41
+ def lazy(
42
+ key: str, value_or_callable: "Callable[..., Coroutine[Any, Any, None]]" = ...
43
+ ) -> "DeferredProp[str, None]": ...
44
+
45
+
46
+ @overload
47
+ def lazy(
48
+ key: str,
49
+ value_or_callable: "Callable[..., Union[T, Coroutine[Any, Any, T]]]" = ..., # pyright: ignore[reportInvalidTypeVarUse]
50
+ ) -> "DeferredProp[str, T]": ...
51
+
52
+
53
+ def lazy(
54
+ key: str,
55
+ value_or_callable: "Optional[Union[T, Callable[..., Coroutine[Any, Any, None]], Callable[..., T], Callable[..., Union[T, Coroutine[Any, Any, T]]]]]" = None,
56
+ ) -> "Union[StaticProp[str, None], StaticProp[str, T], DeferredProp[str, T], DeferredProp[str, None]]":
57
+ """Wrap an async function to return a DeferredProp.
58
+
59
+ Args:
60
+ key: The key to store the value under.
61
+ value_or_callable: The value or callable to store.
62
+
63
+ Returns:
64
+ The wrapped value or callable.
65
+ """
66
+ if value_or_callable is None:
67
+ return StaticProp[str, None](key=key, value=None)
68
+
69
+ if not callable(value_or_callable):
70
+ return StaticProp[str, T](key=key, value=value_or_callable)
71
+
72
+ return DeferredProp[str, T](key=key, value=cast("Callable[..., T | Coroutine[Any, Any, T]]", value_or_callable))
73
+
74
+
75
+ class StaticProp(Generic[PropKeyT, StaticT]):
76
+ """A wrapper for static property evaluation."""
77
+
78
+ def __init__(self, key: "PropKeyT", value: "StaticT") -> None:
79
+ self._key = key
80
+ self._result = value
81
+
82
+ @property
83
+ def key(self) -> "PropKeyT":
84
+ return self._key
85
+
86
+ def render(self, portal: "Optional[BlockingPortal]" = None) -> "StaticT":
87
+ return self._result
88
+
89
+
90
+ class DeferredProp(Generic[PropKeyT, T]):
91
+ """A wrapper for deferred property evaluation."""
92
+
93
+ def __init__(
94
+ self, key: "PropKeyT", value: "Optional[Callable[..., Optional[Union[T, Coroutine[Any, Any, T]]]]]" = None
95
+ ) -> None:
96
+ self._key = key
97
+ self._value = value
98
+ self._evaluated = False
99
+ self._result: "Optional[T]" = None
100
+
101
+ @property
102
+ def key(self) -> "PropKeyT":
103
+ return self._key
104
+
105
+ @staticmethod
106
+ @contextmanager
107
+ def with_portal(portal: "Optional[BlockingPortal]" = None) -> "Generator[BlockingPortal, None, None]":
108
+ if portal is None:
109
+ with start_blocking_portal() as p:
110
+ yield p
111
+ else:
112
+ yield portal
113
+
114
+ @staticmethod
115
+ def _is_awaitable(
116
+ v: "Callable[..., Union[T, Coroutine[Any, Any, T]]]",
117
+ ) -> "TypeGuard[Coroutine[Any, Any, T]]":
118
+ return inspect.iscoroutinefunction(v)
119
+
120
+ def render(self, portal: "Optional[BlockingPortal]" = None) -> "Union[T, None]":
121
+ if self._evaluated:
122
+ return self._result
123
+ if self._value is None or not callable(self._value):
124
+ self._result = self._value
125
+ self._evaluated = True
126
+ return self._result
127
+ if not self._is_awaitable(cast("Callable[..., T]", self._value)):
128
+ self._result = cast("T", self._value())
129
+ self._evaluated = True
130
+ return self._result
131
+ with self.with_portal(portal) as p:
132
+ self._result = p.call(cast("Callable[..., T]", self._value))
133
+ self._evaluated = True
134
+ return self._result
135
+
136
+
137
def is_lazy_prop(value: "Any") -> "TypeGuard[Union[DeferredProp[Any, Any], StaticProp[Any, Any]]]":
    """Return whether ``value`` is a lazy prop wrapper.

    Args:
        value: Object to inspect.

    Returns:
        bool: ``True`` for ``DeferredProp`` and ``StaticProp`` instances, ``False`` otherwise.
    """
    lazy_types = (DeferredProp, StaticProp)
    return isinstance(value, lazy_types)
147
+
148
+
149
def should_render(value: "Any", partial_data: "Optional[set[str]]" = None) -> "bool":
    """Decide whether ``value`` belongs in the rendered props.

    Ordinary values are always rendered; lazy props only when their key was requested.

    Args:
        value: Any value to check.
        partial_data: Optional set of requested prop keys.

    Returns:
        bool: True if value should be rendered.
    """
    if not is_lazy_prop(value):
        return True
    requested = partial_data or set()
    return value.key in requested
163
+
164
+
165
def is_or_contains_lazy_prop(value: "Any") -> "bool":
    """Recursively check whether ``value`` is, or nests, a lazy prop.

    Strings are treated as atoms; mappings are searched by value and other
    iterables element by element.

    Args:
        value: Any value to check.

    Returns:
        bool: True if value is or contains a lazy prop.
    """
    if is_lazy_prop(value):
        return True
    if isinstance(value, str):
        return False
    if isinstance(value, Mapping):
        children: "Iterable[Any]" = cast("Mapping[str, Any]", value).values()
    elif isinstance(value, Iterable):
        children = cast("Iterable[Any]", value)
    else:
        return False
    return any(is_or_contains_lazy_prop(child) for child in children)
183
+
184
+
185
def lazy_render(
    value: "T", partial_data: "Optional[set[str]]" = None, portal: "Optional[BlockingPortal]" = None
) -> "T":
    """Filter and render lazy props in ``value`` based on partial data.

    Mappings and lists/tuples are traversed recursively. Lazy props whose key is not in
    ``partial_data`` are dropped, and requested lazy props are rendered. Any other value
    is returned unchanged.

    Args:
        value: The value to filter.
        partial_data: Keys for partial rendering.
        portal: Optional portal to use for async rendering.

    Returns:
        The filtered value.
    """
    partial_data = partial_data or set()
    if isinstance(value, str):
        return cast("T", value)
    if isinstance(value, Mapping):
        return cast(
            "T",
            {
                k: lazy_render(v, partial_data, portal)
                for k, v in cast("Mapping[str, Any]", value).items()
                if should_render(v, partial_data)
            },
        )

    if isinstance(value, (list, tuple)):
        filtered = [
            lazy_render(v, partial_data, portal) for v in cast("Iterable[Any]", value) if should_render(v, partial_data)
        ]
        if isinstance(value, tuple) and hasattr(value, "_fields"):
            # namedtuple constructors take one positional argument per field, not a single
            # iterable; ``type(value)(filtered)`` would raise TypeError.
            return cast("T", type(value)(*filtered))  # pyright: ignore[reportUnknownArgumentType]
        return cast("T", type(value)(filtered))  # pyright: ignore[reportUnknownArgumentType]

    if is_lazy_prop(value) and should_render(value, partial_data):
        return cast("T", value.render(portal))

    return cast("T", value)
221
+
222
+
223
def get_shared_props(
    request: "ASGIConnection[Any, Any, Any, Any]",
    partial_data: "Optional[set[str]]" = None,
) -> "dict[str, Any]":
    """Collect the props shared with every page for this request.

    Pops the ``_errors``, ``_shared`` and ``_messages`` session entries, renders requested
    lazy shared props, and merges the plugin's static and session-backed extra props. If the
    session (or plugin) is unavailable, a warning is logged and only the defaults are returned.

    Args:
        request: The ASGI connection.
        partial_data: Optional set of requested prop keys; lazy props with other keys are skipped.

    Returns:
        Dict[str, Any]: The shared props, always including ``flash``, ``errors`` and ``csrf_token``.

    Note:
        Be sure to call this before `self.create_template_context` if you would like to include the `flash` message details.
    """
    props: "dict[str, Any]" = {}
    flash: "dict[str, list[str]]" = defaultdict(list)
    errors: "dict[str, Any]" = {}
    error_bag = request.headers.get("X-Inertia-Error-Bag", None)

    try:
        errors = request.session.pop("_errors", {})
        shared_props = cast("dict[str,Any]", request.session.pop("_shared", {}))
        inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))

        for prop_name, prop_value in shared_props.items():
            if not should_render(prop_value, partial_data):
                continue
            # Lazy props are rendered through the plugin's portal; plain values are used as-is.
            props[prop_name] = prop_value.render(inertia_plugin.portal) if is_lazy_prop(prop_value) else prop_value

        for message in cast("list[dict[str,Any]]", request.session.pop("_messages", [])):
            flash[message["category"]].append(message["message"])

        props.update(inertia_plugin.config.extra_static_page_props)
        for session_key in inertia_plugin.config.extra_session_page_props:
            if session_key not in props and session_key in request.session:
                props[session_key] = request.session.get(session_key)

    except (AttributeError, ImproperlyConfiguredException):
        msg = "Unable to generate all shared props. A valid session was not found for this request."
        request.logger.warning(msg)

    props["flash"] = flash
    props["errors"] = errors if error_bag is None else {error_bag: errors}
    props["csrf_token"] = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "")
    return props
273
+
274
+
275
def share(
    connection: "ASGIConnection[Any, Any, Any, Any]",
    key: "str",
    value: "Any",
) -> "None":
    """Store ``value`` under ``key`` in the session's ``_shared`` mapping.

    A missing or unusable session is logged as a warning rather than raised.

    Args:
        connection: The ASGI connection.
        key: The key to store the value under.
        value: The value to store.
    """
    try:
        shared_bucket = connection.session.setdefault("_shared", {})
        shared_bucket.update({key: value})
    except (AttributeError, ImproperlyConfiguredException):
        connection.logger.warning(
            "Unable to set `share` session state. A valid session was not found for this request."
        )
292
+
293
+
294
def error(
    connection: "ASGIConnection[Any, Any, Any, Any]",
    key: "str",
    message: "str",
) -> "None":
    """Record an error message under ``key`` in the session's ``_errors`` mapping.

    A missing or unusable session is logged as a warning rather than raised.

    Args:
        connection: The ASGI connection.
        key: The key to store the error under.
        message: The error message.
    """
    try:
        error_bucket = connection.session.setdefault("_errors", {})
        error_bucket.update({key: message})
    except (AttributeError, ImproperlyConfiguredException):
        connection.logger.warning(
            "Unable to set `error` session state. A valid session was not found for this request."
        )
311
+
312
+
313
+ def js_routes_script(js_routes: "Routes") -> "Markup":
314
+ @lru_cache
315
+ def _markup_safe_json_dumps(js_routes: "str") -> "Markup":
316
+ js = js_routes.replace("<", "\\u003c").replace(">", "\\u003e").replace("&", "\\u0026").replace("'", "\\u0027")
317
+ return Markup(js)
318
+
319
+ return Markup(
320
+ dedent(f"""
321
+ <script type="module">
322
+ globalThis.routes = JSON.parse('{_markup_safe_json_dumps(js_routes.formatted_routes)}')
323
+ </script>
324
+ """),
325
+ )
@@ -0,0 +1,49 @@
1
+ from typing import TYPE_CHECKING, Any, Optional
2
+
3
+ from litestar import Request
4
+ from litestar.middleware import AbstractMiddleware
5
+ from litestar.types import Receive, Scope, Send
6
+
7
+ from litestar_vite.inertia.response import InertiaRedirect
8
+ from litestar_vite.plugin import VitePlugin
9
+
10
+ if TYPE_CHECKING:
11
+ from litestar.connection.base import (
12
+ AuthT,
13
+ StateT,
14
+ UserT,
15
+ )
16
+ from litestar.types import ASGIApp, Receive, Scope, Send
17
+
18
+
19
def redirect_on_asset_version_mismatch(request: "Request[UserT, AuthT, StateT]") -> "Optional[InertiaRedirect]":
    """Redirect an Inertia GET request whose asset version differs from the current one.

    Args:
        request: The incoming request. A plain Litestar ``Request`` (as built by
            ``InertiaMiddleware``) has no ``is_inertia`` attribute, so the raw
            ``X-Inertia`` header is used in that case.

    Returns:
        An ``InertiaRedirect`` to the same URL when the client's ``X-Inertia-Version`` header
        does not match the Vite asset loader's version. Returns ``None`` otherwise: for
        non-Inertia requests, non-GET requests, or when no version header was sent.
    """
    is_inertia = getattr(request, "is_inertia", None)
    if is_inertia is None:
        # Without this fallback the middleware's plain Request would never be checked.
        is_inertia = request.headers.get("X-Inertia", "").lower() == "true"
    # The Inertia protocol only performs the asset-version check on GET requests.
    if not is_inertia or request.method != "GET":
        return None
    inertia_version = request.headers.get("X-Inertia-Version")
    if inertia_version is None:
        return None

    vite_plugin = request.app.plugins.get(VitePlugin)
    if inertia_version == vite_plugin.asset_loader.version_id:
        return None
    return InertiaRedirect(request, redirect_to=str(request.url))
30
+
31
+
32
class InertiaMiddleware(AbstractMiddleware):
    """ASGI middleware that answers asset-version mismatches with an Inertia redirect.

    Every other request is passed through to the wrapped application unchanged.
    """

    def __init__(self, app: "ASGIApp") -> None:
        super().__init__(app)
        self.app = app

    async def __call__(
        self,
        scope: "Scope",
        receive: "Receive",
        send: "Send",
    ) -> None:
        request = Request[Any, Any, Any](scope=scope)
        redirect = redirect_on_asset_version_mismatch(request)
        if redirect is None:
            await self.app(scope, receive, send)
            return
        asgi_response = redirect.to_asgi_response(app=None, request=request)  # pyright: ignore[reportUnknownMemberType]
        await asgi_response(scope, receive, send)