djhtmx 1.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
djhtmx/repo.py ADDED
@@ -0,0 +1,585 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import random
5
+ from collections import defaultdict
6
+ from collections.abc import AsyncIterable, Generator, Iterable
7
+ from dataclasses import dataclass
8
+ from dataclasses import field as Field
9
+ from typing import Any
10
+
11
+ from django.core.signing import Signer
12
+ from django.http import HttpRequest, QueryDict
13
+ from django.utils.html import format_html
14
+ from django.utils.safestring import SafeString, mark_safe
15
+ from pydantic import ValidationError
16
+ from uuid6 import uuid7
17
+
18
+ from djhtmx.global_events import HtmxUnhandledError
19
+ from djhtmx.tracing import tracing_span
20
+
21
+ from . import json
22
+ from .command_queue import CommandQueue
23
+ from .commands import PushURL, ReplaceURL, SendHtml
24
+ from .component import (
25
+ LISTENERS,
26
+ REGISTRY,
27
+ BuildAndRender,
28
+ Command,
29
+ Destroy,
30
+ DispatchDOMEvent,
31
+ Emit,
32
+ Execute,
33
+ Focus,
34
+ HtmxComponent,
35
+ Open,
36
+ Redirect,
37
+ Render,
38
+ Signal,
39
+ SkipRender,
40
+ _get_query_patchers,
41
+ )
42
+ from .introspection import filter_parameters
43
+ from .settings import (
44
+ KEY_SIZE_ERROR_THRESHOLD,
45
+ KEY_SIZE_SAMPLE_PROB,
46
+ KEY_SIZE_WARN_THRESHOLD,
47
+ LOGIN_URL,
48
+ SESSION_TTL,
49
+ conn,
50
+ )
51
+ from .utils import db, get_params
52
+
53
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Django signer shared by this module: it signs session ids handed to the
# client and verifies the signed values that come back.
signer = Signer()

# Union of the command types that `Repository` yields back to its caller once
# the command queue has been processed.
ProcessedCommand = (
    Destroy
    | Redirect
    | Open
    | Focus
    | DispatchDOMEvent
    | SendHtml
    | PushURL
    | ReplaceURL
)
61
+
62
+
63
+ class Repository:
64
+ """An in-memory (cheap) mapping of component IDs to its states.
65
+
66
+ When an HTMX request comes, all the state from all the components are
67
+ placed in a registry. This way we can instantiate components if/when
68
+ needed.
69
+
70
+ For instance, if a component is subscribed to an event and the event fires
71
+ during the request, that component is rendered.
72
+
73
+ """
74
+
75
+ @staticmethod
76
+ def new_session_id():
77
+ return f"djhtmx:{uuid7().hex}"
78
+
79
+ @classmethod
80
+ def from_request(
81
+ cls,
82
+ request: HttpRequest,
83
+ ) -> Repository:
84
+ """Get or build the Repository from the request.
85
+
86
+ If the request has already a Repository attached, return it without
87
+ further processing.
88
+
89
+ Otherwise, build the repository from the request's POST and attach it
90
+ to the request.
91
+
92
+ """
93
+ from django.contrib.auth.models import AnonymousUser
94
+
95
+ if (result := getattr(request, "htmx_repo", None)) is None:
96
+ if (signed_session := request.META.get("HTTP_HX_SESSION")) and not bool(
97
+ request.META.get("HTTP_HX_BOOSTED")
98
+ ):
99
+ session_id = signer.unsign(signed_session)
100
+ else:
101
+ session_id = cls.new_session_id()
102
+
103
+ session = Session(session_id)
104
+
105
+ result = cls(
106
+ user=getattr(request, "user", AnonymousUser()),
107
+ session=session,
108
+ params=get_params(request),
109
+ )
110
+ request.htmx_repo = result # type: ignore
111
+ return result
112
+
113
+ @classmethod
114
+ def from_websocket(cls, user):
115
+ return cls(
116
+ user=user,
117
+ session=Session(cls.new_session_id()), # TODO: take the session from the websocket url
118
+ params=get_params(None),
119
+ )
120
+
121
+ @staticmethod
122
+ def load_states_by_id(states: list[str]) -> dict[str, dict[str, Any]]:
123
+ return {
124
+ state["id"]: state for state in [json.loads(signer.unsign(state)) for state in states]
125
+ }
126
+
127
+ @staticmethod
128
+ def load_subscriptions(
129
+ states_by_id: dict[str, dict[str, Any]], subscriptions: dict[str, str]
130
+ ) -> dict[str, set[str]]:
131
+ subscriptions_to_ids: dict[str, set[str]] = defaultdict(set)
132
+ for component_id, component_subscriptions in subscriptions.items():
133
+ # Register query string subscriptions
134
+ component_name = states_by_id[component_id]["hx_name"]
135
+ for patcher in _get_query_patchers(component_name):
136
+ subscriptions_to_ids[patcher.signal_name].add(component_id)
137
+
138
+ # Register other subscriptions
139
+ for subscription in component_subscriptions.split(","):
140
+ subscriptions_to_ids[subscription].add(component_id)
141
+ return subscriptions_to_ids
142
+
143
+ def __init__(
144
+ self,
145
+ user,
146
+ session: Session,
147
+ params: QueryDict,
148
+ ):
149
+ self.user = user
150
+ self.session = session
151
+ self.session_signed_id = signer.sign(session.id)
152
+ self.params = params
153
+
154
+ # Component life cycle & management
155
+
156
+ def unregister_component(self, component_id: str):
157
+ # delete component state
158
+ self.session.unregister_component(component_id)
159
+
160
+ async def adispatch_event( # pragma: no cover
161
+ self,
162
+ component_id: str,
163
+ event_handler: str,
164
+ event_data: dict[str, Any],
165
+ ) -> AsyncIterable[ProcessedCommand]:
166
+ commands = CommandQueue([Execute(component_id, event_handler, event_data)])
167
+ # Command loop
168
+ try:
169
+ while commands:
170
+ processed_commands = self._run_command(commands)
171
+ while command := await db(next)(processed_commands, None):
172
+ yield command
173
+ except ValidationError as e:
174
+ # This is here to detect validation errors derived from an invalid User
175
+ # Meaning that the user type is not the right one so a login redirect has to happen
176
+ if any(
177
+ e
178
+ for error in e.errors()
179
+ if error["type"] == "is_instance_of" and error["loc"] == ("user",)
180
+ ):
181
+ yield Redirect(LOGIN_URL)
182
+ else:
183
+ raise
184
+
185
+ def dispatch_event(
186
+ self,
187
+ component_id: str,
188
+ event_handler: str,
189
+ event_data: dict[str, Any],
190
+ ) -> Iterable[ProcessedCommand]:
191
+ commands = CommandQueue([Execute(component_id, event_handler, event_data)])
192
+
193
+ # Command loop
194
+ try:
195
+ while commands:
196
+ yield from self._run_command(commands)
197
+ except ValidationError as e:
198
+ # This is here to detect validation errors derived from an invalid User
199
+ # Meaning that the user type is not the right one so a login redirect has to happen
200
+ if any(
201
+ e
202
+ for error in e.errors()
203
+ if error["type"] == "is_instance_of" and error["loc"] == ("user",)
204
+ ):
205
+ yield Redirect(LOGIN_URL)
206
+ else:
207
+ raise
208
+
209
+ def _run_command(self, commands: CommandQueue) -> Generator[ProcessedCommand, None, None]:
210
+ command = commands.pop()
211
+ logger.debug("COMMAND: %s", command)
212
+ commands_to_append: list[Command] = []
213
+ match command:
214
+ case Execute(component_id, event_handler, event_data):
215
+ commands.processing_component_id = component_id
216
+ match self.get_component_by_id(component_id):
217
+ case Destroy() as command:
218
+ yield command
219
+ case HtmxComponent() as component:
220
+ handler = getattr(component, event_handler)
221
+ handler_kwargs = filter_parameters(handler, event_data)
222
+ try:
223
+ emited_commands = handler(**handler_kwargs)
224
+ except Exception as error:
225
+ annotations = getattr(handler, "_htmx_annotations_", None)
226
+ logger.exception(
227
+ "HTMX unhandled exception in component %s",
228
+ component.__class__.__name__,
229
+ )
230
+ emited_commands = [
231
+ Emit(HtmxUnhandledError(error, handler_annotations=annotations))
232
+ ]
233
+ yield from self._process_emited_commands(
234
+ component, emited_commands, commands, during_execute=True
235
+ )
236
+
237
+ case SkipRender(component):
238
+ commands.processing_component_id = component.id
239
+ self.session.store(component)
240
+
241
+ case BuildAndRender(component_type, state, oob, parent_id):
242
+ commands.processing_component_id = state.get("id", "")
243
+ component = self.build(component_type.__name__, state)
244
+
245
+ # Automatically track parent-child relationship if parent_id is specified
246
+ child_id = component.id
247
+ self.session.register_child(parent_id, child_id)
248
+
249
+ commands_to_append.append(Render(component, oob=oob))
250
+
251
+ case Render(component, template, oob, lazy, context):
252
+ commands.processing_component_id = component.id
253
+ html = self.render_html(
254
+ component, oob=oob, template=template, lazy=lazy, context=context
255
+ )
256
+ yield SendHtml(html, debug_trace=f"{component.hx_name}({component.id})")
257
+
258
+ case Destroy(component_id) as command:
259
+ commands.processing_component_id = component_id
260
+ self.unregister_component(component_id)
261
+ yield command
262
+
263
+ case Emit(event):
264
+ for component in self.get_components_by_names(*LISTENERS[type(event)]):
265
+ commands.processing_component_id = component.id
266
+ logger.debug("< AWAKED: %s id=%s", component.hx_name, component.id)
267
+ try:
268
+ emited_commands = component._handle_event(event) # type: ignore
269
+ except Exception as error:
270
+ logger.exception(
271
+ "HTMX unhandled error in the event handler of %s",
272
+ component.__class__.__name__,
273
+ )
274
+ # Don't enter a spiral of death with HtmxUnhandledError
275
+ if not isinstance(event, HtmxUnhandledError):
276
+ emited_commands = [Emit(HtmxUnhandledError(error))]
277
+ else:
278
+ raise
279
+ yield from self._process_emited_commands(
280
+ component, emited_commands, commands, during_execute=False
281
+ )
282
+
283
+ case Signal(signals):
284
+ commands.processing_component_id = ""
285
+ for component_or_destroy in self.get_components_subscribed_to(signals):
286
+ match component_or_destroy:
287
+ case Destroy() as command:
288
+ yield command
289
+ case component:
290
+ logger.debug("< AWAKED: %s id=%s", component.hx_name, component.id)
291
+ commands_to_append.append(Render(component))
292
+
293
+ case (
294
+ Open()
295
+ | ReplaceURL()
296
+ | PushURL()
297
+ | Redirect()
298
+ | Focus()
299
+ | DispatchDOMEvent() as command
300
+ ):
301
+ commands.processing_component_id = ""
302
+ yield command
303
+
304
+ commands.extend(commands_to_append)
305
+ self.session.flush()
306
+
307
+ def _process_emited_commands(
308
+ self,
309
+ component: HtmxComponent,
310
+ emmited_commands: Iterable[Command] | None,
311
+ commands: CommandQueue,
312
+ during_execute: bool,
313
+ ) -> Iterable[ProcessedCommand]:
314
+ component_was_rendered = False
315
+ commands_to_add: list[Command] = []
316
+ for command in emmited_commands or []:
317
+ component_was_rendered = component_was_rendered or (
318
+ isinstance(command, SkipRender | Render) and command.component.id == component.id
319
+ )
320
+ if (
321
+ component_was_rendered
322
+ and during_execute
323
+ and isinstance(command, Render)
324
+ and command.lazy is None
325
+ ):
326
+ # make partial updates not lazy during_execute
327
+ command.lazy = False
328
+ commands_to_add.append(command)
329
+
330
+ if not component_was_rendered:
331
+ commands_to_add.append(
332
+ Render(component, lazy=False if during_execute else component.lazy)
333
+ )
334
+
335
+ if signals := self.update_params_from(component):
336
+ yield ReplaceURL.from_params(self.params)
337
+ commands_to_add.append(Signal({(signal, component.id) for signal in signals}))
338
+
339
+ commands.extend(commands_to_add)
340
+ self.session.store(component)
341
+
342
+ def get_components_subscribed_to(
343
+ self, signals: set[tuple[str, str]]
344
+ ) -> Iterable[HtmxComponent | Destroy]:
345
+ return (
346
+ self.get_component_by_id(c_id)
347
+ for c_id in sorted(self.session.get_component_ids_subscribed_to(signals))
348
+ )
349
+
350
+ def update_params_from(self, component: HtmxComponent) -> set[str]:
351
+ """Updates self.params based on the state of the component
352
+
353
+ Return the set of signals that should be triggered as the result of
354
+ the update.
355
+
356
+ """
357
+ updated_params: set[str] = set()
358
+ if patchers := _get_query_patchers(component.hx_name):
359
+ for patcher in patchers:
360
+ updated_params.update(
361
+ patcher.get_updates_for_params(
362
+ getattr(component, patcher.field_name, None),
363
+ self.params,
364
+ )
365
+ )
366
+ return updated_params
367
+
368
+ def get_component_by_id(self, component_id: str) -> Destroy | HtmxComponent:
369
+ """Return (possibly build) the component by its ID.
370
+
371
+ If the component was already built, get it unchanged, otherwise build
372
+ it from the request's payload and return it.
373
+
374
+ If the `component_id` cannot be found, raise a KeyError.
375
+
376
+ """
377
+ if state := self.session.get_state(component_id):
378
+ return self.build(state["hx_name"], state, retrieve_state=False)
379
+ else:
380
+ logger.error(
381
+ "Component with id %s not found in session %s", component_id, self.session.id
382
+ )
383
+ return Destroy(component_id)
384
+
385
+ def build(
386
+ self,
387
+ component_name: str,
388
+ state: dict[str, Any],
389
+ retrieve_state: bool = True,
390
+ parent_id: str | None = None,
391
+ ):
392
+ """Build (or update) a component's state."""
393
+ from django.contrib.auth.models import AnonymousUser
394
+
395
+ with tracing_span("Repository.build", component_name=component_name):
396
+ # Retrieve state from storage
397
+ if retrieve_state and (component_id := state.get("id")):
398
+ state = (self.session.get_state(component_id) or {}) | state
399
+
400
+ # Patch it with whatever is the the GET params if needed
401
+ for patcher in _get_query_patchers(component_name):
402
+ state |= patcher.get_update_for_state(self.params)
403
+
404
+ # Inject component name and user
405
+ kwargs = state | {
406
+ "hx_name": component_name,
407
+ "user": None if isinstance(self.user, AnonymousUser) else self.user,
408
+ }
409
+ component = REGISTRY[component_name](**kwargs)
410
+
411
+ # Automatically track parent-child relationship if parent_id is specified
412
+ self.session.register_child(parent_id, component.id)
413
+
414
+ return component
415
+
416
+ def get_components_by_names(self, *names: str) -> Iterable[HtmxComponent]:
417
+ # go over awaken components
418
+ for name in names:
419
+ for state in self.session.get_all_states():
420
+ if state["hx_name"] == name:
421
+ yield self.build(name, {"id": state["id"]})
422
+
423
+ def render_html(
424
+ self,
425
+ component: HtmxComponent,
426
+ oob: str | None = None,
427
+ template: str | None = None,
428
+ lazy: bool | None = None,
429
+ context: dict[str, Any] | None = None,
430
+ ) -> SafeString:
431
+ lazy = component.lazy if lazy is None else lazy
432
+ with tracing_span(
433
+ "Repository.render_html",
434
+ component_name=component.hx_name,
435
+ oob=str(oob),
436
+ template=str(template),
437
+ lazy=str(lazy),
438
+ ):
439
+ self.session.store(component)
440
+
441
+ final_context = {
442
+ "htmx_repo": self,
443
+ "hx_oob": oob == "true",
444
+ "this": component,
445
+ }
446
+
447
+ if lazy:
448
+ template = template or component._template_name_lazy
449
+ final_context |= {"hx_lazy": True} | component._get_lazy_context() | (context or {})
450
+ else:
451
+ final_context |= component._get_context() if context is None else context
452
+
453
+ html = mark_safe(component._get_template(template)(final_context).strip())
454
+
455
+ # if performing some kind of append, the component has to be wrapped
456
+ if oob and oob != "true":
457
+ html = mark_safe(
458
+ "".join([
459
+ format_html('<div hx-swap-oob="{oob}">', oob=oob),
460
+ html,
461
+ "</div>",
462
+ ])
463
+ )
464
+ return html
465
+
466
+
467
@dataclass(slots=True)
class Session:
    """Component states, subscriptions and parent/child links of one session.

    The data lives in the hash ``"<id>:states"`` of the ``conn`` store: one
    field per component id holding its JSON state, plus the special fields
    ``__subs__`` and ``__children__``.  It is loaded lazily on first read
    (`_ensure_read`) and written back by `flush` when something changed.

    """

    id: str

    # Whether the stored hash was already loaded into this object.
    read: bool = False
    # Whether there are changes that `flush` still has to write.
    is_dirty: bool = False

    # dict[component_id -> state]
    states: dict[str, str] = Field(default_factory=dict)

    # dict[component_id -> set[signals]]
    subscriptions: defaultdict[str, set[str]] = Field(default_factory=lambda: defaultdict(set))

    # dict[parent_id -> set[child_ids]]
    children: defaultdict[str, set[str]] = Field(default_factory=lambda: defaultdict(set))

    # set[component_id], removed from the stored hash on the next flush
    unregistered: set[str] = Field(default_factory=set)

    def store(self, component: HtmxComponent):
        """Record the component's JSON state and subscriptions.

        The session is marked dirty only if either of them changed.

        """
        state = component.model_dump_json()
        if self.states.get(component.id) != state:
            self.states[component.id] = state
            self.is_dirty = True

        subscriptions = component._get_all_subscriptions()
        if self.subscriptions[component.id] != subscriptions:
            self.subscriptions[component.id] = subscriptions
            self.is_dirty = True

    def unregister_component(self, component_id: str):
        """Forget a component and, recursively, all of its children."""
        # Detach this component's children mapping *before* recursing: a
        # cyclic parent/child mapping then terminates instead of recursing
        # forever, and the set being iterated is no longer reachable from
        # `self.children`, so the recursive calls cannot mutate it.
        child_ids = self.children.pop(component_id, None)
        for child_id in child_ids or ():
            self.unregister_component(child_id)

        # Remove it from the children set of every parent, so no stale
        # reference survives if it was registered under more than one parent.
        for siblings in self.children.values():
            siblings.discard(component_id)

        # Remove component state and subscriptions
        self.states.pop(component_id, None)
        self.subscriptions.pop(component_id, None)
        self.unregistered.add(component_id)
        self.is_dirty = True

    def register_child(self, parent_id: str | None, child_id: str):
        """Register a parent-child relationship between components.

        Does nothing without a `parent_id`, when parent and child are the
        same component, or when the relationship is already known.

        """
        if parent_id and parent_id != child_id and child_id not in self.children[parent_id]:
            self.children[parent_id].add(child_id)
            self.is_dirty = True

    def get_state(self, component_id: str) -> dict[str, Any] | None:
        """Return the decoded state of the component, or None if there is none."""
        self._ensure_read()
        if state := self.states.get(component_id):
            return json.loads(state)
        return None

    def get_component_ids_subscribed_to(self, signals: set[tuple[str, str]]) -> Iterable[str]:
        """Yield the ids of the components subscribed to any of `signals`.

        `signals` holds ``(signal name, emitter component id)`` pairs.

        """
        self._ensure_read()
        for component_id, subscribed_to in self.subscriptions.items():
            # here we ignore signals emitted by the component itself
            if subscribed_to.intersection(signal for signal, cid in signals if cid != component_id):
                yield component_id

    def get_all_states(self) -> Iterable[dict[str, Any]]:
        """Return the decoded states of all the components, as a list."""
        self._ensure_read()
        return [json.loads(state) for state in self.states.values()]

    def _ensure_read(self):
        """Load the stored hash into this object, the first time only.

        NOTE(review): loaded entries overwrite whatever is already in
        `states`/`subscriptions`/`children`, and the mutating methods do not
        call this first; presumably callers always read before writing to an
        existing session -- confirm.

        """
        if self.read:
            return

        for field_name, raw_value in conn.hgetall(f"{self.id}:states").items():  # type: ignore
            field_name = field_name.decode()
            if field_name == "__subs__":
                # dict[component_id -> list[signals]]
                for component_id, signals in json.loads(raw_value).items():
                    self.subscriptions[component_id] = set(signals)
            elif field_name == "__children__":
                # dict[parent_id -> list[child_ids]]
                for parent_id, child_ids in json.loads(raw_value).items():
                    self.children[parent_id] = set(child_ids)
            else:
                # Any other field is a component id holding its JSON state.
                self.states[field_name] = raw_value.decode()
        self.read = True

    def flush(self, ttl: int = SESSION_TTL):
        """Write the pending changes to the store; a no-op if nothing changed.

        Removes the unregistered components, writes all the states, the
        subscriptions and the children mapping, and resets the key's
        expiration to `ttl` seconds.

        """
        if not self.is_dirty:
            return

        key = f"{self.id}:states"
        if self.unregistered:
            conn.hdel(key, *self.unregistered)
            self.unregistered.clear()
        if self.states:
            conn.hset(key, mapping=self.states)
        conn.hset(key, "__subs__", json.dumps(self.subscriptions))
        conn.hset(key, "__children__", json.dumps(self.children))
        conn.expire(key, ttl)
        # The command MEMORY USAGE is considered slow:
        # https://redis.io/docs/latest/commands/memory-usage/
        #
        # So we perform a trivial sampling with some prob to test the memory usage of the state.
        probe = random.random() <= KEY_SIZE_SAMPLE_PROB
        if probe and isinstance(usage := conn.memory_usage(key), int):
            # A falsy threshold disables the corresponding check.
            if KEY_SIZE_ERROR_THRESHOLD and usage > KEY_SIZE_ERROR_THRESHOLD:
                logger.error(
                    "HTMX session's size (%s) exceeded the size threshold %s",
                    usage,
                    KEY_SIZE_ERROR_THRESHOLD,
                )
            elif KEY_SIZE_WARN_THRESHOLD and usage > KEY_SIZE_WARN_THRESHOLD:
                logger.warning(
                    "HTMX session's size (%s) exceeded the size threshold %s",
                    usage,
                    KEY_SIZE_WARN_THRESHOLD,
                )
        self.is_dirty = False
djhtmx/settings.py ADDED
@@ -0,0 +1,49 @@
1
+ from datetime import timedelta
2
+
3
+ import redis
4
+ from django.conf import settings
5
+
6
# Version segment used to build the htmx script path in SCRIPT_URLS.
VERSION = "2.0.4"
DEBUG = settings.DEBUG
# Django's setting uses the `request.META` spelling (e.g. "HTTP_X_CSRFTOKEN");
# turn it into the header spelling ("X-CSRFTOKEN").  `removeprefix` only
# strips an actual "HTTP_" prefix instead of blindly dropping 5 characters.
CSRF_HEADER_NAME = settings.CSRF_HEADER_NAME.removeprefix("HTTP_").replace("_", "-")
LOGIN_URL = settings.LOGIN_URL

# The minified htmx build is used unless DEBUG is on.
SCRIPT_URLS = [
    f"htmx/{VERSION}/htmx{'' if DEBUG else '.min'}.js",
    "htmx/django.js",
]

DEFAULT_LAZY_TEMPLATE = getattr(settings, "DJHTMX_DEFAULT_LAZY_TEMPLATE", "htmx/lazy.html")
# Redis client created at import time; from_url() does not validate the URL's reachability here.
conn = redis.from_url(getattr(settings, "DJHTMX_REDIS_URL", "redis://localhost/0"))
# Session time-to-live in seconds; a timedelta setting is converted to seconds.
SESSION_TTL = getattr(settings, "DJHTMX_SESSION_TTL", 3600)
if isinstance(SESSION_TTL, timedelta):
    SESSION_TTL = int(SESSION_TTL.total_seconds())


ENABLE_SENTRY_TRACING = getattr(settings, "DJHTMX_ENABLE_SENTRY_TRACING", True)
ENABLE_LOGFIRE_TRACING = getattr(settings, "DJHTMX_ENABLE_LOGFIRE_TRACING", False)


STRICT_EVENT_HANDLER_CONSISTENCY_CHECK = getattr(
    settings,
    "DJHTMX_STRICT_EVENT_HANDLER_CONSISTENCY_CHECK",
    False,
)

# Thresholds for the stored session key's memory usage (compared against the
# result of `conn.memory_usage`).  A falsy value (the default for the error
# threshold) disables that check.
KEY_SIZE_ERROR_THRESHOLD = getattr(
    settings,
    "DJHTMX_KEY_SIZE_ERROR_THRESHOLD",
    0,
)
KEY_SIZE_WARN_THRESHOLD = getattr(
    settings,
    "DJHTMX_KEY_SIZE_WARN_THRESHOLD",
    50 * 1024,  # 50kb
)
# Probability (0..1) with which a session flush samples the key's size.
KEY_SIZE_SAMPLE_PROB = getattr(
    settings,
    "DJHTMX_KEY_SIZE_SAMPLE_PROB",
    0.1,
)

STRICT_PUBLIC_BASE = getattr(settings, "DJHTMX_STRICT_PUBLIC_BASE", False)