litestar-vite 0.1.22__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of litestar-vite might be problematic. Click here for more details.

@@ -0,0 +1,345 @@
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ from collections import defaultdict
5
+ from functools import lru_cache
6
+ from mimetypes import guess_type
7
+ from pathlib import PurePath
8
+ from textwrap import dedent
9
+ from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, TypeVar, cast
10
+ from urllib.parse import quote
11
+
12
+ from litestar import Litestar, MediaType, Request, Response
13
+ from litestar.datastructures.cookie import Cookie
14
+ from litestar.exceptions import ImproperlyConfiguredException
15
+ from litestar.response import Redirect
16
+ from litestar.response.base import ASGIResponse
17
+ from litestar.serialization import get_serializer
18
+ from litestar.status_codes import HTTP_200_OK, HTTP_303_SEE_OTHER, HTTP_307_TEMPORARY_REDIRECT, HTTP_409_CONFLICT
19
+ from litestar.utils.deprecation import warn_deprecation
20
+ from litestar.utils.empty import value_or_default
21
+ from litestar.utils.helpers import get_enum_string_value
22
+ from litestar.utils.scope.state import ScopeState
23
+ from markupsafe import Markup
24
+
25
+ from litestar_vite.inertia._utils import get_headers
26
+ from litestar_vite.inertia.types import InertiaHeaderType, PageProps
27
+ from litestar_vite.plugin import VitePlugin
28
+
29
+ if TYPE_CHECKING:
30
+ from litestar.app import Litestar
31
+ from litestar.background_tasks import BackgroundTask, BackgroundTasks
32
+ from litestar.connection import ASGIConnection
33
+ from litestar.connection.base import AuthT, StateT, UserT
34
+ from litestar.types import ResponseCookies, ResponseHeaders, TypeEncodersMap
35
+
36
+ from litestar_vite.inertia.routes import Routes
37
+
38
+ from .plugin import InertiaPlugin
39
+
40
+ T = TypeVar("T")
41
+
42
+
43
+ def share(
44
+ connection: ASGIConnection[Any, Any, Any, Any],
45
+ key: str,
46
+ value: Any,
47
+ ) -> None:
48
+ connection.session.setdefault("_shared", {}).update({key: value})
49
+
50
+
51
+ def error(
52
+ connection: ASGIConnection[Any, Any, Any, Any],
53
+ key: str,
54
+ message: str,
55
+ ) -> None:
56
+ connection.session.setdefault("_errors", {}).update({key: message})
57
+
58
+
59
+ def get_shared_props(request: ASGIConnection[Any, Any, Any, Any]) -> Dict[str, Any]: # noqa: UP006
60
+ """Return shared session props for a request
61
+
62
+
63
+ Be sure to call this before `self.create_template_context` if you would like to include the `flash` message details.
64
+ """
65
+ error_bag = request.headers.get("X-Inertia-Error-Bag", None)
66
+ errors = request.session.pop("_errors", {})
67
+ props = request.session.pop("_shared", {})
68
+ flash: dict[str, list[str]] = defaultdict(list)
69
+ for message in request.session.pop("_messages", []):
70
+ flash[message["category"]].append(message["message"])
71
+
72
+ inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
73
+ props.update(inertia_plugin.config.extra_page_props)
74
+ props["csrf_token"] = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "")
75
+ props["flash"] = flash
76
+ props["errors"] = {error_bag: errors} if error_bag is not None else errors
77
+ return props
78
+
79
+
80
+ def js_routes_script(js_routes: Routes) -> Markup:
81
+ @lru_cache
82
+ def _markup_safe_json_dumps(js_routes: str) -> Markup:
83
+ js = js_routes.replace("<", "\\u003c").replace(">", "\\u003e").replace("&", "\\u0026").replace("'", "\\u0027")
84
+ return Markup(js)
85
+
86
+ return Markup(
87
+ dedent(f"""
88
+ <script type="module">
89
+ globalThis.routes = JSON.parse('{_markup_safe_json_dumps(js_routes.formatted_routes)}')
90
+ </script>
91
+ """),
92
+ )
93
+
94
+
95
+ class InertiaResponse(Response[T]):
96
+ """Inertia Response"""
97
+
98
+ def __init__(
99
+ self,
100
+ content: T,
101
+ *,
102
+ template_name: str | None = None,
103
+ template_str: str | None = None,
104
+ background: BackgroundTask | BackgroundTasks | None = None,
105
+ context: dict[str, Any] | None = None,
106
+ cookies: ResponseCookies | None = None,
107
+ encoding: str = "utf-8",
108
+ headers: ResponseHeaders | None = None,
109
+ media_type: MediaType | str | None = None,
110
+ status_code: int = HTTP_200_OK,
111
+ type_encoders: TypeEncodersMap | None = None,
112
+ ) -> None:
113
+ """Handle the rendering of a given template into a bytes string.
114
+
115
+ Args:
116
+ content: A value for the response body that will be rendered into bytes string.
117
+ template_name: Path-like name for the template to be rendered, e.g. ``index.html``.
118
+ template_str: A string representing the template, e.g. ``tmpl = "Hello <strong>World</strong>"``.
119
+ background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or
120
+ :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
121
+ Defaults to ``None``.
122
+ context: A dictionary of key/value pairs to be passed to the temple engine's render method.
123
+ cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response
124
+ ``Set-Cookie`` header.
125
+ encoding: Content encoding
126
+ headers: A string keyed dictionary of response headers. Header keys are insensitive.
127
+ media_type: A string or member of the :class:`MediaType <.enums.MediaType>` enum. If not set, try to infer
128
+ the media type based on the template name. If this fails, fall back to ``text/plain``.
129
+ status_code: A value for the response HTTP status code.
130
+ type_encoders: A mapping of types to callables that transform them into types supported for serialization.
131
+ """
132
+ if template_name and template_str:
133
+ msg = "Either template_name or template_str must be provided, not both."
134
+ raise ValueError(msg)
135
+ self.content = content
136
+ self.background = background
137
+ self.cookies: list[Cookie] = (
138
+ [Cookie(key=key, value=value) for key, value in cookies.items()]
139
+ if isinstance(cookies, Mapping)
140
+ else list(cookies or [])
141
+ )
142
+ self.encoding = encoding
143
+ self.headers: dict[str, Any] = (
144
+ dict(headers) if isinstance(headers, Mapping) else {h.name: h.value for h in headers or {}}
145
+ )
146
+ self.media_type = media_type
147
+ self.status_code = status_code
148
+ self.response_type_encoders = {**(self.type_encoders or {}), **(type_encoders or {})}
149
+ self.context = context or {}
150
+ self.template_name = template_name
151
+ self.template_str = template_str
152
+
153
+ def create_template_context(
154
+ self,
155
+ request: Request[UserT, AuthT, StateT],
156
+ page_props: PageProps[T],
157
+ type_encoders: TypeEncodersMap | None = None,
158
+ ) -> dict[str, Any]:
159
+ """Create a context object for the template.
160
+
161
+ Args:
162
+ request: A :class:`Request <.connection.Request>` instance.
163
+ page_props: A formatted object to return the inertia configuration.
164
+ type_encoders: A mapping of types to callables that transform them into types supported for serialization.
165
+
166
+ Returns:
167
+ A dictionary holding the template context
168
+ """
169
+ csrf_token = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "")
170
+ inertia_props = self.render(page_props, MediaType.JSON, get_serializer(type_encoders)).decode()
171
+ return {
172
+ **self.context,
173
+ "inertia": inertia_props,
174
+ "js_routes": js_routes_script(request.app.state.js_routes),
175
+ "request": request,
176
+ "csrf_input": f'<input type="hidden" name="_csrf_token" value="{csrf_token}" />',
177
+ }
178
+
179
+ def to_asgi_response(
180
+ self,
181
+ app: Litestar | None,
182
+ request: Request[UserT, AuthT, StateT],
183
+ *,
184
+ background: BackgroundTask | BackgroundTasks | None = None,
185
+ cookies: Iterable[Cookie] | None = None,
186
+ encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
187
+ headers: dict[str, str] | None = None,
188
+ is_head_response: bool = False,
189
+ media_type: MediaType | str | None = None,
190
+ status_code: int | None = None,
191
+ type_encoders: TypeEncodersMap | None = None,
192
+ ) -> ASGIResponse:
193
+ if app is not None:
194
+ warn_deprecation(
195
+ version="2.1",
196
+ deprecated_name="app",
197
+ kind="parameter",
198
+ removal_in="3.0.0",
199
+ alternative="request.app",
200
+ )
201
+ inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False)
202
+ is_inertia = getattr(request, "is_inertia", False)
203
+
204
+ headers = {**headers, **self.headers} if headers is not None else self.headers
205
+ cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)
206
+ type_encoders = (
207
+ {**type_encoders, **(self.response_type_encoders or {})} if type_encoders else self.response_type_encoders
208
+ )
209
+ if not inertia_enabled:
210
+ media_type = get_enum_string_value(self.media_type or media_type or MediaType.JSON)
211
+ return ASGIResponse(
212
+ background=self.background or background,
213
+ body=self.render(self.content, media_type, get_serializer(type_encoders)),
214
+ cookies=cookies,
215
+ encoded_headers=encoded_headers,
216
+ encoding=self.encoding,
217
+ headers=headers,
218
+ is_head_response=is_head_response,
219
+ media_type=media_type,
220
+ status_code=self.status_code or status_code,
221
+ )
222
+ vite_plugin = request.app.plugins.get(VitePlugin)
223
+ template_engine = vite_plugin.template_config.to_engine()
224
+ headers.update(
225
+ {"Vary": "Accept", **get_headers(InertiaHeaderType(enabled=True))},
226
+ )
227
+ shared_props = get_shared_props(request)
228
+ page_props = PageProps[T](
229
+ component=request.inertia.route_component, # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType,reportAttributeAccessIssue]
230
+ props={"content": self.content, **shared_props}, # pyright: ignore[reportArgumentType]
231
+ version=template_engine.asset_loader.version_id,
232
+ url=request.url.path,
233
+ )
234
+ if is_inertia:
235
+ media_type = get_enum_string_value(self.media_type or media_type or MediaType.JSON)
236
+ body = self.render(page_props, media_type, get_serializer(type_encoders))
237
+ return ASGIResponse(
238
+ background=self.background or background,
239
+ body=body,
240
+ cookies=cookies,
241
+ encoded_headers=encoded_headers,
242
+ encoding=self.encoding,
243
+ headers=headers,
244
+ is_head_response=is_head_response,
245
+ media_type=media_type,
246
+ status_code=self.status_code or status_code,
247
+ )
248
+
249
+ if not template_engine:
250
+ msg = "Template engine is not configured"
251
+ raise ImproperlyConfiguredException(msg)
252
+ # it should default to HTML at this point unless the user specified something
253
+ media_type = media_type or MediaType.HTML
254
+ if not media_type:
255
+ if self.template_name:
256
+ suffixes = PurePath(self.template_name).suffixes
257
+ for suffix in suffixes:
258
+ if _type := guess_type(f"name{suffix}")[0]:
259
+ media_type = _type
260
+ break
261
+ else:
262
+ media_type = MediaType.TEXT
263
+ else:
264
+ media_type = MediaType.HTML
265
+ context = self.create_template_context(request, page_props, type_encoders) # pyright: ignore[reportUnknownMemberType]
266
+ if self.template_str is not None:
267
+ body = template_engine.render_string(self.template_str, context).encode(self.encoding)
268
+ else:
269
+ inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
270
+ template_name = self.template_name or inertia_plugin.config.root_template
271
+ # cast to str b/c we know that either template_name cannot be None if template_str is None
272
+ template = template_engine.get_template(template_name)
273
+ body = template.render(**context).encode(self.encoding)
274
+
275
+ return ASGIResponse(
276
+ background=self.background or background,
277
+ body=body,
278
+ cookies=cookies,
279
+ encoded_headers=encoded_headers,
280
+ encoding=self.encoding,
281
+ headers=headers,
282
+ is_head_response=is_head_response,
283
+ media_type=media_type,
284
+ status_code=self.status_code or status_code,
285
+ )
286
+
287
+
288
+ class InertiaExternalRedirect(Response[Any]):
289
+ """Client side redirect."""
290
+
291
+ def __init__(
292
+ self,
293
+ request: Request[Any, Any, Any],
294
+ redirect_to: str,
295
+ **kwargs: Any,
296
+ ) -> None:
297
+ """Initialize external redirect, Set status code to 409 (required by Inertia),
298
+ and pass redirect url.
299
+ """
300
+ super().__init__(
301
+ content=kwargs.get("content", ""),
302
+ status_code=HTTP_409_CONFLICT,
303
+ headers={"X-Inertia": "true", "X-Inertia-Location": quote(redirect_to, safe="/#%[]=:;$&()+,!?*@'~")},
304
+ cookies=request.cookies,
305
+ **kwargs,
306
+ )
307
+
308
+
309
+ class InertiaRedirect(Redirect):
310
+ """Client side redirect."""
311
+
312
+ def __init__(
313
+ self,
314
+ request: Request[Any, Any, Any],
315
+ redirect_to: str,
316
+ **kwargs: Any,
317
+ ) -> None:
318
+ """Initialize external redirect, Set status code to 409 (required by Inertia),
319
+ and pass redirect url.
320
+ """
321
+ super().__init__(
322
+ path=redirect_to,
323
+ status_code=HTTP_307_TEMPORARY_REDIRECT if request.method == "GET" else HTTP_303_SEE_OTHER,
324
+ cookies=request.cookies,
325
+ **kwargs,
326
+ )
327
+
328
+
329
+ class InertiaBack(Redirect):
330
+ """Client side redirect."""
331
+
332
+ def __init__(
333
+ self,
334
+ request: Request[Any, Any, Any],
335
+ **kwargs: Any,
336
+ ) -> None:
337
+ """Initialize external redirect, Set status code to 409 (required by Inertia),
338
+ and pass redirect url.
339
+ """
340
+ super().__init__(
341
+ path=request.headers["Referer"],
342
+ status_code=HTTP_307_TEMPORARY_REDIRECT if request.method == "GET" else HTTP_303_SEE_OTHER,
343
+ cookies=request.cookies,
344
+ **kwargs,
345
+ )
@@ -0,0 +1,54 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from functools import cached_property
5
+ from typing import TYPE_CHECKING
6
+
7
+ from litestar.app import DEFAULT_OPENAPI_CONFIG
8
+ from litestar.cli._utils import (
9
+ remove_default_schema_routes,
10
+ remove_routes_with_patterns,
11
+ )
12
+ from litestar.routes import ASGIRoute, WebSocketRoute
13
+ from litestar.serialization import encode_json
14
+
15
+ if TYPE_CHECKING:
16
+ from litestar import Litestar
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class Routes:
21
+ routes: dict[str, str]
22
+
23
+ @cached_property
24
+ def formatted_routes(self) -> str:
25
+ return encode_json(self.routes).decode(encoding="utf-8")
26
+
27
+
28
+ EXCLUDED_METHODS = {"HEAD", "OPTIONS", "TRACE"}
29
+
30
+
31
+ def generate_js_routes(
32
+ app: Litestar,
33
+ exclude: tuple[str, ...] | None = None,
34
+ schema: bool = False,
35
+ ) -> Routes:
36
+ sorted_routes = sorted(app.routes, key=lambda r: r.path)
37
+ if not schema:
38
+ openapi_config = app.openapi_config or DEFAULT_OPENAPI_CONFIG
39
+ sorted_routes = remove_default_schema_routes(sorted_routes, openapi_config)
40
+ if exclude is not None:
41
+ sorted_routes = remove_routes_with_patterns(sorted_routes, exclude)
42
+ route_list: dict[str, str] = {}
43
+ for route in sorted_routes:
44
+ if isinstance(route, (ASGIRoute, WebSocketRoute)):
45
+ route_name = route.route_handler.name or route.route_handler.handler_name
46
+ if len(route.methods.difference(EXCLUDED_METHODS)) > 0:
47
+ route_list[route_name] = route.path
48
+ else:
49
+ for handler in route.route_handlers:
50
+ route_name = handler.name or handler.handler_name
51
+ if handler.http_methods.isdisjoint(EXCLUDED_METHODS):
52
+ route_list[route_name] = route.path
53
+
54
+ return Routes(routes=route_list)
@@ -0,0 +1,39 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Generic, TypedDict, TypeVar
5
+
6
# Public API of this module; ``InertiaProps`` is a public type as well and is exported.
__all__ = (
    "InertiaHeaderType",
    "InertiaProps",
    "PageProps",
)
+ )
10
+
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ @dataclass
16
+ class PageProps(Generic[T]):
17
+ """Inertia Page Props Type."""
18
+
19
+ component: str
20
+ url: str
21
+ version: str
22
+ props: dict[str, Any]
23
+
24
+
25
+ @dataclass
26
+ class InertiaProps(Generic[T]):
27
+ """Inertia Props Type."""
28
+
29
+ page: PageProps[T]
30
+
31
+
32
+ class InertiaHeaderType(TypedDict, total=False):
33
+ """Type for inertia_headers parameter in get_headers()."""
34
+
35
+ enabled: bool | None
36
+ version: str | None
37
+ location: str | None
38
+ partial_data: str | None
39
+ partial_component: str | None
litestar_vite/loader.py CHANGED
@@ -1,17 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import json
4
+ from functools import cached_property
4
5
  from pathlib import Path
5
- from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
6
+ from textwrap import dedent
7
+ from typing import TYPE_CHECKING, Any, ClassVar
6
8
  from urllib.parse import urljoin
7
9
 
8
- from litestar.template import TemplateEngineProtocol
9
-
10
10
  if TYPE_CHECKING:
11
11
  from litestar_vite.config import ViteConfig
12
12
 
13
- T = TypeVar("T", bound=TemplateEngineProtocol)
14
-
15
13
 
16
14
  class ViteAssetLoader:
17
15
  """Vite manifest loader.
@@ -24,6 +22,7 @@ class ViteAssetLoader:
24
22
  def __init__(self, config: ViteConfig) -> None:
25
23
  self._config = config
26
24
  self._manifest: dict[str, Any] = {}
25
+ self._manifest_content: str = ""
27
26
  self._vite_base_path: str | None = None
28
27
 
29
28
  @classmethod
@@ -34,6 +33,12 @@ class ViteAssetLoader:
34
33
  cls._instance.parse_manifest()
35
34
  return cls._instance
36
35
 
36
@cached_property
def version_id(self) -> str:
    """Return a stable identifier for the loaded Vite manifest.

    The identifier is a SHA-256 hex digest of the manifest file content, so it is
    identical across interpreter processes. The builtin ``hash()`` of a ``str`` is
    randomised per process (``PYTHONHASHSEED``), which would make workers report
    different versions. Returns ``"1.0"`` when no manifest content was loaded.
    """
    import hashlib

    # ``__init__`` initialises ``_manifest_content`` but ``parse_manifest`` stores the
    # file text on ``manifest_content``; accept either so the version reflects the
    # manifest that was actually read (checking only ``_manifest_content`` always
    # produced "1.0").
    content = self._manifest_content or getattr(self, "manifest_content", "")
    if content:
        return hashlib.sha256(content.encode("utf-8")).hexdigest()
    return "1.0"
41
+
37
42
  def parse_manifest(self) -> None:
38
43
  """Read and parse the Vite manifest file.
39
44
 
@@ -76,11 +81,11 @@ class ViteAssetLoader:
76
81
  try:
77
82
  if manifest_path.exists():
78
83
  with manifest_path.open() as manifest_file:
79
- manifest_content = manifest_file.read()
80
- self._manifest = json.loads(manifest_content)
84
+ self.manifest_content = manifest_file.read()
85
+ self._manifest = json.loads(self.manifest_content)
81
86
  else:
82
87
  self._manifest = {}
83
- except Exception as exc: # noqa: BLE001
88
+ except Exception as exc:
84
89
  msg = "There was an issue reading the Vite manifest file at %s. Did you forget to build your assets?"
85
90
  raise RuntimeError(
86
91
  msg,
@@ -111,7 +116,7 @@ class ViteAssetLoader:
111
116
  str: The script tag or an empty string.
112
117
  """
113
118
  if self._config.is_react and self._config.hot_reload and self._config.dev_mode:
114
- return f"""
119
+ return dedent(f"""
115
120
  <script type="module">
116
121
  import RefreshRuntime from '{self._vite_server_url()}@react-refresh'
117
122
  RefreshRuntime.injectIntoGlobalHook(window)
@@ -119,7 +124,7 @@ class ViteAssetLoader:
119
124
  window.$RefreshSig$ = () => (type) => type
120
125
  window.__vite_plugin_react_preamble_installed__=true
121
126
  </script>
122
- """
127
+ """)
123
128
  return ""
124
129
 
125
130
  def generate_asset_tags(self, path: str | list[str], scripts_attrs: dict[str, str] | None = None) -> str:
@@ -133,7 +138,9 @@ class ViteAssetLoader:
133
138
  if self._config.hot_reload and self._config.dev_mode:
134
139
  return "".join(
135
140
  [
136
- self._script_tag(
141
+ self._style_tag(self._vite_server_url(p))
142
+ if p.endswith(".css")
143
+ else self._script_tag(
137
144
  self._vite_server_url(p),
138
145
  {"type": "module", "async": "", "defer": ""},
139
146
  )
@@ -142,7 +149,7 @@ class ViteAssetLoader:
142
149
  )
143
150
 
144
151
  if any(p for p in path if p not in self._manifest):
145
- msg = "Cannot find %s in Vite manifest at %s. Did you forget to build your assets?"
152
+ msg = "Cannot find %s in Vite manifest at %s. Did you forget to build your assets after an update?"
146
153
  raise RuntimeError(
147
154
  msg,
148
155
  path,
@@ -150,30 +157,33 @@ class ViteAssetLoader:
150
157
  )
151
158
 
152
159
  tags: list[str] = []
153
- for p in path:
154
- manifest_entry: dict = self._manifest[p]
160
+ manifest_entry: dict[str, Any] = {}
161
+ manifest_entry.update({p: self._manifest[p] for p in path})
155
162
  if not scripts_attrs:
156
163
  scripts_attrs = {"type": "module", "async": "", "defer": ""}
157
-
158
- # Add dependent CSS
159
- if "css" in manifest_entry:
160
- tags.extend(
161
- self._style_tag(urljoin(self._config.asset_url, css_path)) for css_path in manifest_entry.get("css", {})
162
- )
163
- # Add dependent "vendor"
164
- if "imports" in manifest_entry:
165
- tags.extend(
166
- self.generate_asset_tags(vendor_path, scripts_attrs=scripts_attrs)
167
- for vendor_path in manifest_entry.get("imports", {})
168
- )
169
- # Add the script by itself
170
- tags.append(
171
- self._script_tag(
172
- urljoin(self._config.asset_url, manifest_entry["file"]),
173
- attrs=scripts_attrs,
174
- ),
175
- )
176
-
164
+ for manifest in manifest_entry.values():
165
+ if "css" in manifest:
166
+ tags.extend(
167
+ self._style_tag(urljoin(self._config.asset_url, css_path)) for css_path in manifest.get("css", {})
168
+ )
169
+ # Add dependent "vendor"
170
+ if "imports" in manifest:
171
+ tags.extend(
172
+ self.generate_asset_tags(vendor_path, scripts_attrs=scripts_attrs)
173
+ for vendor_path in manifest.get("imports", {})
174
+ )
175
+ # Add the script by itself
176
+ if manifest.get("file").endswith(".css"):
177
+ tags.append(
178
+ self._style_tag(urljoin(self._config.asset_url, manifest["file"])),
179
+ )
180
+ else:
181
+ tags.append(
182
+ self._script_tag(
183
+ urljoin(self._config.asset_url, manifest["file"]),
184
+ attrs=scripts_attrs,
185
+ ),
186
+ )
177
187
  return "".join(tags)
178
188
 
179
189
  def _vite_server_url(self, path: str | None = None) -> str:
@@ -193,7 +203,9 @@ class ViteAssetLoader:
193
203
 
194
204
  def _script_tag(self, src: str, attrs: dict[str, str] | None = None) -> str:
195
205
  """Generate an HTML script tag."""
196
- attrs_str = " ".join([f'{key}="{value}"' for key, value in attrs.items()]) if attrs is not None else ""
206
+ if attrs is None:
207
+ attrs = {}
208
+ attrs_str = " ".join([f'{key}="{value}"' for key, value in attrs.items()])
197
209
  return f'<script {attrs_str} src="{src}"></script>'
198
210
 
199
211
  def _style_tag(self, href: str) -> str:
litestar_vite/plugin.py CHANGED
@@ -6,7 +6,9 @@ from pathlib import Path
6
6
  from typing import TYPE_CHECKING, Iterator, cast
7
7
 
8
8
  from litestar.plugins import CLIPlugin, InitPluginProtocol
9
- from litestar.static_files import create_static_files_router
9
+ from litestar.static_files import (
10
+ create_static_files_router, # pyright: ignore[reportUnknownVariableType]
11
+ )
10
12
 
11
13
  from litestar_vite.config import ViteConfig
12
14
 
@@ -15,6 +17,9 @@ if TYPE_CHECKING:
15
17
  from litestar import Litestar
16
18
  from litestar.config.app import AppConfig
17
19
 
20
+ from litestar_vite.config import ViteTemplateConfig
21
+ from litestar_vite.template_engine import ViteTemplateEngine
22
+
18
23
 
19
24
  def set_environment(config: ViteConfig) -> None:
20
25
  """Configure environment for easier integration"""
@@ -23,6 +28,7 @@ def set_environment(config: ViteConfig) -> None:
23
28
  os.environ.setdefault("VITE_PORT", str(config.port))
24
29
  os.environ.setdefault("VITE_HOST", config.host)
25
30
  os.environ.setdefault("VITE_PROTOCOL", config.protocol)
31
+ os.environ.setdefault("APP_URL", "http://localhost:8000")
26
32
  if config.dev_mode:
27
33
  os.environ.setdefault("VITE_DEV_MODE", str(config.dev_mode))
28
34
 
@@ -46,11 +52,21 @@ class VitePlugin(InitPluginProtocol, CLIPlugin):
46
52
  def config(self) -> ViteConfig:
47
53
  return self._config
48
54
 
55
@property
def template_config(self) -> ViteTemplateConfig[ViteTemplateEngine]:
    """Build a Vite-aware template configuration from this plugin's settings.

    A new configuration object is constructed on every access. The imports are
    local so these modules are only loaded when the property is used.
    """
    from litestar_vite.config import ViteTemplateConfig
    from litestar_vite.template_engine import ViteTemplateEngine

    vite_config = self._config
    return ViteTemplateConfig[ViteTemplateEngine](
        engine=ViteTemplateEngine,
        config=vite_config,
        directory=vite_config.template_dir,
    )
65
+
49
66
  def on_cli_init(self, cli: Group) -> None:
50
67
  from litestar_vite.cli import vite_group
51
68
 
52
69
  cli.add_command(vite_group)
53
- return super().on_cli_init(cli)
54
70
 
55
71
  def on_app_init(self, app_config: AppConfig) -> AppConfig:
56
72
  """Configure application for use with Vite.
@@ -59,15 +75,8 @@ class VitePlugin(InitPluginProtocol, CLIPlugin):
59
75
  app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
60
76
  """
61
77
 
62
- from litestar_vite.config import ViteTemplateConfig
63
- from litestar_vite.template_engine import ViteTemplateEngine
64
-
65
78
  if self._config.template_dir is not None:
66
- app_config.template_config = ViteTemplateConfig[ViteTemplateEngine]( # type: ignore[assignment]
67
- engine=ViteTemplateEngine,
68
- config=self._config,
69
- directory=self._config.template_dir,
70
- )
79
+ app_config.template_config = self.template_config
71
80
 
72
81
  if self._config.set_static_folders:
73
82
  static_dirs = [Path(self._config.bundle_dir), Path(self._config.resource_dir)]
@@ -115,7 +124,7 @@ class VitePlugin(InitPluginProtocol, CLIPlugin):
115
124
  yield
116
125
  finally:
117
126
  if vite_thread.is_alive():
118
- vite_thread.join()
127
+ vite_thread.join(timeout=5)
119
128
  console.print("[yellow]Vite process stopped.[/]")
120
129
 
121
130
  else: