djhtmx 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,427 @@
1
+ import datetime
2
+ import enum
3
+ import inspect
4
+ import operator
5
+ import types
6
+ from collections import defaultdict
7
+ from collections.abc import Callable, Iterable, Sequence
8
+ from dataclasses import dataclass
9
+ from datetime import date
10
+ from functools import cache
11
+ from inspect import Parameter, _ParameterKind
12
+ from typing import (
13
+ Annotated,
14
+ Any,
15
+ Generic,
16
+ TypedDict,
17
+ TypeVar,
18
+ Union,
19
+ get_args,
20
+ get_origin,
21
+ get_type_hints,
22
+ is_typeddict,
23
+ )
24
+ from uuid import UUID
25
+
26
+ from django.apps import apps
27
+ from django.db import models
28
+ from django.db.models import Prefetch
29
+ from django.utils.datastructures import MultiValueDict
30
+ from pydantic import BeforeValidator, PlainSerializer, TypeAdapter
31
+
32
# Type variable standing for "some concrete Django model class"; used to parametrize the
# generic proxy/validator/serializer helpers below.
M = TypeVar("M", bound=models.Model)
33
+
34
+
35
+ @dataclass(slots=True)
36
+ class ModelRelatedField:
37
+ name: str
38
+ relation_name: str
39
+ related_model_name: str
40
+
41
+
42
# Per-model memo of the ForeignKey descriptions computed by `get_related_fields`.
MODEL_RELATED_FIELDS: dict[type[models.Model], tuple[ModelRelatedField, ...]] = dict()
43
+
44
+
45
@dataclass(slots=True, unsafe_hash=True)
class ModelConfig:
    """Annotation to configure fetching the models/querysets in pydantic models.

    For this configuration to take place the pydantic model has to call `annotate_model`. HTMX
    components require no extra steps.

    Instances must be hashable (they are keys of the ``@cache``d
    ``_ModelBeforeValidator.from_modelclass``), so list arguments are stored as tuples.
    """

    lazy: bool = False
    """If set to True, annotations of models.Model will return a _LazyModelProxy instead of the
    actual model instance.

    """

    select_related: list[str] | tuple[str, ...] | None = None
    """The arguments to `model.objects.select_related(*select_related)`."""

    prefetch_related: list[str | Prefetch] | tuple[str | Prefetch, ...] | None = None
    """The arguments to `model.objects.prefetch_related(*prefetch_related)`."""

    def __post_init__(self):
        # `unsafe_hash=True` hashes every field, and lists are unhashable: a config such as
        # ModelConfig(select_related=["author"]) used to raise TypeError as soon as it reached
        # the cached validator factory.  Tuples unpack identically with `*`.
        if isinstance(self.select_related, list):
            self.select_related = tuple(self.select_related)
        if isinstance(self.prefetch_related, list):
            self.prefetch_related = tuple(self.prefetch_related)
65
+
66
+
67
# Configuration used when an annotation carries no explicit ModelConfig: eager loading, no
# select_related/prefetch_related.
_DEFAULT_MODEL_CONFIG = ModelConfig(lazy=False)
68
+
69
+
70
@dataclass(slots=True, init=False)
class _LazyModelProxy(Generic[M]):  # noqa
    """Deferred proxy for a Django model instance; only fetches from the database on access.

    Reading ``pk`` never queries.  Reading any other attribute loads the instance on first use
    (``model.objects.get(pk=...)`` with the configured ``select_related``/``prefetch_related``)
    and returns the attribute from it.
    """

    __model: type[M]
    __instance: M | None
    __pk: Any | None
    __select_related: Sequence[str] | None
    __prefetch_related: Sequence[str | Prefetch] | None

    def __init__(
        self,
        model: type[M],
        value: Any,
        model_annotation: ModelConfig | None = None,
    ):
        """Wrap *value*: either an instance of *model* (or None), or a primary key.

        *model_annotation* provides the related-object loading options; without it no
        ``select_related``/``prefetch_related`` is applied when the instance is loaded.
        """
        self.__model = model
        if value is None or isinstance(value, model):
            self.__instance = value
            self.__pk = getattr(value, "pk", None)
        else:
            self.__instance = None
            self.__pk = value
        if model_annotation:
            self.__select_related = model_annotation.select_related
            self.__prefetch_related = model_annotation.prefetch_related
        else:
            self.__select_related = None
            self.__prefetch_related = None

    def __getattr__(self, name: str) -> Any:
        # `__getattr__` only runs when normal lookup fails, which includes reading this class's
        # own (name-mangled) slots while they are still unset -- e.g. on an instance created
        # with `__new__` without running `__init__`, as copy/pickle do.  Forwarding those names
        # would re-enter `__getattr__` endlessly, so report them as missing instead.
        if name.startswith("_LazyModelProxy__"):
            raise AttributeError(name)
        if name == "pk":
            return self.__pk
        return getattr(self.__ensure_instance(), name)

    def __ensure_instance(self):
        """Return the wrapped instance, loading it by primary key the first time.

        Whatever ``.get()`` raises (e.g. when no row matches) propagates to the caller.
        """
        # `is None` (not truthiness) so the check agrees with the rest of the class.
        if self.__instance is None:
            manager = self.__model.objects
            if select_related := self.__select_related:
                manager = manager.select_related(*select_related)
            if prefetch_related := self.__prefetch_related:
                manager = manager.prefetch_related(*prefetch_related)
            self.__instance = manager.get(pk=self.__pk)
        return self.__instance

    def __repr__(self) -> str:
        return f"<_LazyModelProxy model={self.__model}, pk={self.__pk}, instance={self.__instance}>"
119
+
120
+
121
@dataclass(slots=True)
class _ModelBeforeValidator(Generic[M]):  # noqa
    """Pydantic before-validator that turns a pk, instance or lazy proxy into a model value.

    With ``model_config.lazy`` the result is a `_LazyModelProxy`; otherwise it is a fully
    loaded instance of ``model`` (or None).
    """

    model: type[M]
    model_config: ModelConfig

    def __call__(self, value):
        if self.model_config.lazy:
            return self._get_lazy_proxy(value)
        else:
            return self._get_instance(value)

    def _get_lazy_proxy(self, value):
        """Wrap *value* in a proxy that honours this validator's ModelConfig."""
        if isinstance(value, _LazyModelProxy):
            # Re-wrap an existing proxy, reusing its loaded instance (if any) to avoid a query.
            value = value._LazyModelProxy__instance or value._LazyModelProxy__pk
        # Pass the config along: without it the proxy's select_related/prefetch_related were
        # always None and the configured related loading was silently skipped.
        return _LazyModelProxy(self.model, value, model_annotation=self.model_config)

    def _get_instance(self, value):
        """Return a loaded instance for *value* (None and instances pass through)."""
        if value is None or isinstance(value, self.model):
            return value
        # If a component has a lazy model proxy, and passes it down to another component that
        # doesn't allow lazy proxies, we need to materialize it.
        elif isinstance(value, _LazyModelProxy):
            return value._LazyModelProxy__ensure_instance()
        else:
            manager = self.model.objects
            if select_related := self.model_config.select_related:
                manager = manager.select_related(*select_related)
            if prefetch_related := self.model_config.prefetch_related:
                manager = manager.prefetch_related(*prefetch_related)
            return manager.get(pk=value)

    @classmethod
    @cache
    def from_modelclass(cls, model: type[M], model_config: ModelConfig):
        # Cached: one validator per (model, config) pair.
        return cls(model, model_config=model_config)
158
+
159
+
160
@dataclass(slots=True)
class _ModelPlainSerializer(Generic[M]):  # noqa
    """Pydantic plain serializer that reduces a model value to its ``pk``."""

    model: type[M]

    def __call__(self, value):
        primary_key = value.pk
        return primary_key

    @classmethod
    @cache
    def from_modelclass(cls, model: type[M]):
        # Memoized, so every annotation for the same model shares one serializer object.
        return cls(model=model)
171
+
172
+
173
def _Model(model: type[models.Model], model_config: ModelConfig | None = None):
    """Build the pydantic ``Annotated`` type used for a field holding a *model* instance.

    Validation goes through `_ModelBeforeValidator` (pk -> instance or lazy proxy, per
    *model_config*); serialization emits the pk, typed with `guess_pk_type`.
    """
    assert issubclass_safe(model, models.Model)
    config = model_config or _DEFAULT_MODEL_CONFIG
    base_type = _LazyModelProxy[model] if config.lazy else model
    validator = BeforeValidator(_ModelBeforeValidator.from_modelclass(model, config))
    serializer = PlainSerializer(
        func=_ModelPlainSerializer.from_modelclass(model),
        return_type=guess_pk_type(model),
    )
    return Annotated[base_type, validator, serializer]
184
+
185
+
186
def _QuerySet(qs: type[models.QuerySet]):
    """Build the pydantic ``Annotated`` type for a field holding a *qs* queryset.

    The model is the unique installed model whose ``objects.all()`` is an instance of *qs*.
    Validation accepts a *qs* instance as-is, or an iterable of pks (-> ``filter(pk__in=...)``);
    serialization emits the list of pks.

    :raises ValueError: if not exactly one model matches *qs*.
    """
    candidates = [
        m
        for m in apps.get_models()
        # Models whose default manager isn't called `objects` have no such attribute; they
        # used to crash this lookup with AttributeError.
        if (manager := getattr(m, "objects", None)) is not None and isinstance(manager.all(), qs)
    ]
    if len(candidates) != 1:
        raise ValueError(
            f"Expected exactly one model whose `objects` returns {qs.__name__}, "
            f"found {len(candidates)}"
        )
    [model] = candidates
    pk_type = guess_pk_type(model)

    def validate(value):
        return value if isinstance(value, qs) else model.objects.filter(pk__in=value)

    def serialize(value):
        # Reuse already-fetched rows (even an empty result) instead of querying again.
        if value._result_cache is not None:
            return [instance.pk for instance in value]
        return list(value.values_list("pk", flat=True))

    return Annotated[
        qs,
        BeforeValidator(validate),
        # The serializer returns a list of pks, so declare it as such (it used to claim a
        # single pk, which mismatches the value actually returned).
        PlainSerializer(func=serialize, return_type=list[pk_type]),
    ]
200
+
201
+
202
+ def annotate_model(annotation, *, model_config: ModelConfig | None = None):
203
+ if issubclass_safe(annotation, models.Model):
204
+ return _Model(annotation, model_config)
205
+ elif issubclass_safe(annotation, models.QuerySet):
206
+ return _QuerySet(annotation)
207
+ elif is_typeddict(annotation):
208
+ return TypedDict(
209
+ annotation.__name__, # type: ignore
210
+ {
211
+ k: annotate_model(v) # type: ignore
212
+ for k, v in get_type_hints(annotation).items()
213
+ },
214
+ )
215
+ elif type_ := get_origin(annotation):
216
+ if type_ is types.UnionType or type_ is Union:
217
+ type_ = Union
218
+ match get_args(annotation):
219
+ case ():
220
+ return type_
221
+ case (param,):
222
+ return type_[annotate_model(param)] # type: ignore
223
+ case params:
224
+ model_annotation = next(
225
+ (p for p in params if isinstance(p, ModelConfig)),
226
+ None,
227
+ )
228
+ return type_[*(annotate_model(p, model_config=model_annotation) for p in params)] # type: ignore
229
+ else:
230
+ return annotation
231
+
232
+
233
+ def guess_pk_type(model: type[models.Model]):
234
+ match model._meta.pk:
235
+ case models.UUIDField():
236
+ return UUID
237
+ case models.IntegerField():
238
+ return int
239
+ case _:
240
+ return str
241
+
242
+
243
def isinstance_safe(o, types):
    """``isinstance(o, types)``, except that a TypeError (e.g. *types* is not a class) yields False."""
    try:
        outcome = isinstance(o, types)
    except TypeError:
        outcome = False
    return outcome
248
+
249
+
250
def issubclass_safe(o, types):
    """``issubclass(o, types)``, except that a TypeError (e.g. *o* is not a class) yields False."""
    try:
        outcome = issubclass(o, types)
    except TypeError:
        outcome = False
    return outcome
255
+
256
+
257
+ # for state of old components
258
+
259
+
260
def get_function_parameters(
    function: Callable,
    exclude_kinds: tuple[_ParameterKind, ...] = (),
) -> frozenset[str]:
    """Return the names of *function*'s parameters, without ``self`` and without any parameter
    whose kind is listed in *exclude_kinds*."""
    signature = inspect.signature(function)
    kept = (
        parameter.name
        for parameter in signature.parameters.values()
        if parameter.name != "self" and parameter.kind not in exclude_kinds
    )
    return frozenset(kept)
269
+
270
+
271
@cache
def get_related_fields(model):
    """Return a `ModelRelatedField` for every ForeignKey of *model* that has a reverse query name.

    Relations whose ``related_query_name()`` is empty or ``"+"`` are skipped.  Results are
    memoized both by ``@cache`` and in ``MODEL_RELATED_FIELDS``.
    """
    known = MODEL_RELATED_FIELDS.get(model)
    if known is not None:
        return known

    collected = []
    for field in model._meta.get_fields():
        if not isinstance(field, models.ForeignKey):
            continue
        relation_name = field.related_query_name()
        if not relation_name or relation_name == "+":
            continue
        target_meta = field.related_model._meta  # type: ignore
        collected.append(
            ModelRelatedField(
                name=field.attname,
                relation_name=relation_name,
                related_model_name=f"{target_meta.app_label}.{target_meta.model_name}",
            )
        )
    result = MODEL_RELATED_FIELDS[model] = tuple(collected)
    return result
292
+
293
+
294
+ # filtering
295
+
296
+
297
def filter_parameters(f: Callable, kwargs: dict[str, Any]):
    """Return the part of *kwargs* that *f* accepts as named parameters.

    If *f* takes ``**kwargs`` every item is acceptable and *kwargs* itself is returned;
    otherwise a new dict restricted to *f*'s parameter names is built.
    """
    # Inspect once: the signature used to be recomputed for every key in *kwargs*.
    parameters = inspect.signature(f).parameters
    if any(param.kind == Parameter.VAR_KEYWORD for param in parameters.values()):
        return kwargs
    return {name: value for name, value in kwargs.items() if name in parameters}
309
+
310
+
311
+ # Decoder for client requests
312
+
313
+
314
def parse_request_data(data: MultiValueDict[str, Any] | dict[str, Any]):
    """Decode flat request data (dotted / ``name[i]`` keys) into a nested structure.

    A plain dict is first converted to a MultiValueDict, wrapping non-list values in a list.
    """
    if isinstance(data, MultiValueDict):
        multi = data
    else:
        multi = MultiValueDict({
            key: (value if isinstance(value, list) else [value]) for key, value in data.items()
        })
    pairs = _extract_data(multi)
    return _parse_obj(pairs)
320
+
321
+
322
def _extract_data(data: MultiValueDict[str, Any]):
    """Yield ``(path, value)`` for every key of *data*.

    *path* is the key split on ``"."``.  Keys ending in ``"[]"`` yield all their values as a
    list (under the key without the suffix); other keys yield ``data.get(key)``.
    """
    # Walk the keys in their insertion order.  Iterating a `set` made the order -- and with it
    # which of two conflicting keys won in `_parse_obj` -- depend on per-process string hashing.
    for key in list(data):
        if key.endswith("[]"):
            value = data.getlist(key)
            key = key.removesuffix("[]")
        else:
            value = data.get(key)
        yield key.split("."), value
330
+
331
+
332
+ def _parse_obj(data: Iterable[tuple[list[str], Any]], output=None) -> dict[str, Any] | Any:
333
+ output = output or {}
334
+ arrays = defaultdict(lambda: defaultdict(dict)) # field -> index -> value
335
+ for key, value in data:
336
+ fragment, *tail = key
337
+ if "[" in fragment:
338
+ field_name = fragment[: fragment.index("[")]
339
+ index = int(fragment[fragment.index("[") + 1 : -1])
340
+ arrays[field_name][index] = (
341
+ _parse_obj([(tail, value)], arrays[field_name][index]) if tail else value
342
+ )
343
+ else:
344
+ output[fragment] = _parse_obj([(tail, value)]) if tail else value
345
+
346
+ for field, items in arrays.items():
347
+ output[field] = [v for _, v in sorted(items.items(), key=operator.itemgetter(0))]
348
+ return output
349
+
350
+
351
+ def get_event_handler_event_types(f: Callable[..., Any]) -> set[type]:
352
+ "Extract the types of the annotations of parameter 'event'."
353
+ event = get_type_hints(f)["event"]
354
+ origin = get_origin(event)
355
+ if origin is types.UnionType or origin is Union:
356
+ return {
357
+ arg for arg in get_args(event) if isinstance(arg, type) and arg is not types.NoneType
358
+ }
359
+ elif isinstance(event, type):
360
+ return {event}
361
+ else:
362
+ return set()
363
+
364
+
365
def get_annotation_adapter(annotation):
    """Return a pydantic TypeAdapter for *annotation*.

    ``bool`` gets the shared `infallible_bool_adapter`; everything else gets a new adapter that
    allows arbitrary types.
    """
    return (
        infallible_bool_adapter
        if annotation is bool
        else TypeAdapter(annotation, config={"arbitrary_types_allowed": True})
    )
371
+
372
+
373
# Lenient boolean adapter: validation maps the string "t" to True and every other input to
# False (including "f", "" and even the Python value True, since `True == "t"` is False), so
# it never rejects a value.  Serialization emits "t" / "f".
infallible_bool_adapter = TypeAdapter(
    Annotated[
        bool,
        BeforeValidator(lambda raw: raw == "t"),
        PlainSerializer(lambda flag: "t" if flag else "f"),
    ]
)
382
+
383
+
384
def is_basic_type(ann):
    """Return True when *ann* is a "basic" annotation.

    Basic annotations are:

    - the simple types listed in ``_SIMPLE_TYPES``;
    - ``Annotated[...]`` types whose ``__origin__`` is a Django model (the model is represented
      by its PK);
    - subclasses of IntEnum or StrEnum;
    - the collection classes accepted by `is_collection_annotation` (dict, tuple, list, set).
    """
    if ann in _SIMPLE_TYPES:
        return True
    # For 'Annotated[model, BeforeValidator(...), PlainSerializer(...)]', __origin__ is the model.
    if issubclass_safe(getattr(ann, "__origin__", None), models.Model):
        return True
    if issubclass_safe(ann, (enum.IntEnum, enum.StrEnum)):
        return True
    return is_collection_annotation(ann)
406
+
407
+
408
def is_union_of_basic(ann):
    """Return True if *ann* is a union (``X | Y`` or ``typing.Union``) made only of basic types,
    as judged by `is_basic_type`."""
    origin = get_origin(ann)
    if origin is not types.UnionType and origin is not Union:
        return False
    return all(is_basic_type(member) for member in get_args(ann))
414
+
415
+
416
def is_simple_annotation(ann):
    """Return True if *ann* is a basic type or a union whose members are all basic types."""
    if is_basic_type(ann):
        return True
    return is_union_of_basic(ann)
419
+
420
+
421
def is_collection_annotation(ann):
    """Return True if *ann* is a class deriving from one of ``_COLLECTION_TYPES``.

    Anything ``issubclass`` rejects (non-classes) counts as False.
    """
    collection_types = _COLLECTION_TYPES
    return issubclass_safe(ann, collection_types)
423
+
424
+
425
+ Unset = object()
426
+ _SIMPLE_TYPES = (int, str, float, UUID, types.NoneType, date, datetime, bool)
427
+ _COLLECTION_TYPES = (dict, tuple, list, set)
djhtmx/json.py ADDED
@@ -0,0 +1,56 @@
1
+ import dataclasses
2
+ import enum
3
+ import json
4
+ from collections.abc import Generator
5
+
6
+ import orjson
7
+ from django.core.serializers import deserialize, serialize
8
+ from django.core.serializers.base import DeserializedObject
9
+ from django.core.serializers.json import DjangoJSONEncoder
10
+ from django.db import models
11
+ from pydantic import BaseModel
12
+
13
# JSON decoding is delegated directly to orjson.
loads = orjson.loads
14
+
15
+
16
def dumps(obj):
    """Serialize *obj* to a JSON ``str`` with orjson, using `default` for unsupported objects."""
    encoded = orjson.dumps(obj, default)
    return encoded.decode()
18
+
19
+
20
def encode(instance: models.Model) -> str:
    """Serialize a single model *instance* with Django's "json" serializer, passing it as a
    one-element list and using `HtmxEncoder` for the values."""
    batch = [instance]
    return serialize("json", batch, cls=HtmxEncoder)
22
+
23
+
24
def decode(instance: str) -> models.Model:
    """Deserialize the first object of a Django "json" serializer string and return the model.

    The returned model's ``save`` is replaced by the DeserializedObject's ``save``.
    """
    deserialized: DeserializedObject = next(iter(deserialize("json", instance)))
    model_instance = deserialized.object
    # Route saving through the DeserializedObject rather than the plain model method.
    model_instance.save = deserialized.save  # type: ignore
    return model_instance
28
+
29
+
30
class HtmxEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that hands unsupported objects to the module-level `default`."""

    def default(self, o):
        converted = default(o)
        return converted
33
+
34
+
35
def default(o):
    """Fallback converter for objects the JSON encoders cannot handle.

    DjangoJSONEncoder is tried first.  If it raises TypeError, the fallbacks below are tried in
    order: ``__json__()``, model -> pk, generator/set/frozenset -> list, pydantic model ->
    ``model_dump()``, dataclass -> ``asdict``, enum -> value.  If none applies, the original
    TypeError is re-raised.
    """
    try:
        return DjangoJSONEncoder().default(o)
    except TypeError:
        if hasattr(o, "__json__"):
            converted = o.__json__()
        elif isinstance(o, models.Model):
            converted = o.pk
        elif isinstance(o, Generator | set | frozenset):
            converted = list(o)
        elif BaseModel and isinstance(o, BaseModel):
            converted = o.model_dump()
        elif dataclasses.is_dataclass(o):
            converted = dataclasses.asdict(o)  # type: ignore
        elif isinstance(o, enum.Enum):
            converted = o.value
        else:
            # Still inside the `except`: re-raise DjangoJSONEncoder's TypeError.
            raise
        return converted
@@ -0,0 +1,71 @@
1
+ import sys
2
+ from collections import defaultdict
3
+
4
+ import djclick as click
5
+ from xotl.tools.future.itertools import delete_duplicates
6
+ from xotl.tools.objects import get_branch_subclasses as get_final_subclasses
7
+
8
+ from djhtmx.component import REGISTRY, HtmxComponent
9
+
10
+
11
def bold(msg):
    """Return ``str(msg)`` styled bold for terminal output."""
    text = str(msg)
    return click.style(text, bold=True)
13
+
14
+
15
def yellow(msg):
    """Return ``str(msg)`` styled with a yellow foreground for terminal output."""
    text = str(msg)
    return click.style(text, fg="yellow")
17
+
18
+
19
@click.group()
def htmx():
    # The group has no behaviour of its own; it only hosts the subcommands registered with
    # `@htmx.command(...)`.
    ...
22
+
23
+
24
@htmx.command("check-missing")  # type: ignore
@click.argument("fname", type=click.File())
def check_missing(fname):
    r"""Check if there are any missing HTMX components.

    Expected usage:

    find -type f -name '*.html' | while read f; do grep -P '{% htmx .(\w+)' -o $f \
    | awk '{print $3}' | cut -b2-; done | sort -u \
    | python manage.py htmx check-missing -

    """
    # One component name per line.  Drop blank lines: they used to be reported as a missing
    # component named "" (e.g. for an empty input or a stray newline in the pipeline).
    names = {line.strip() for line in fname} - {""}
    missing = sorted(names - set(REGISTRY))
    for name in missing:
        click.echo(
            f"Missing component detected {bold(yellow(name))}",
            file=sys.stderr,
        )
    if missing:
        sys.exit(1)
47
+
48
+
49
@htmx.command("check-shadowing")  # type: ignore
def check_shadowing():
    "Checks if there are components that might shadow one another."
    # Component name -> classes involved in a clash (the subclass and the registered one).
    clashes = defaultdict(list)
    subclasses = get_final_subclasses(
        HtmxComponent,  # type: ignore
        without_duplicates=True,
    )
    for component in subclasses:
        name = component.__name__
        registered = REGISTRY.get(name)
        if registered is None or registered is component:
            continue
        clashes[name].extend((component, registered))

    if not clashes:
        return
    for name, candidates in clashes.items():
        unique = delete_duplicates(candidates)
        if not unique:
            continue
        click.echo(f"HtmxComponent {bold(name)} might be shadowed by:")
        for shadow in unique:
            click.echo(f" - {bold(shadow.__module__)}.{bold(shadow.__name__)}")
    sys.exit(1)
djhtmx/middleware.py ADDED
@@ -0,0 +1,36 @@
1
+ import asyncio
2
+ from collections.abc import Awaitable, Callable
3
+
4
+ from asgiref.sync import sync_to_async
5
+ from django.http import HttpRequest, HttpResponse
6
+
7
+
8
def middleware(
    get_response: Callable[[HttpRequest], HttpResponse]
    | Callable[[HttpRequest], Awaitable[HttpResponse]],
):
    """
    Middleware function that wraps get_response and ensures the HTMX repository
    is flushed and removed from the request after handling each request. It can handle
    both sync and async get_response automatically.

    After the response is produced, if the request carries an ``htmx_repo`` attribute, its
    ``session.flush()`` is called (via ``sync_to_async`` on the async path) and the attribute
    is deleted.
    """
    # asgiref's helper is the detector Django documents for middleware; unlike
    # `asyncio.iscoroutinefunction` it is not deprecated on newer Pythons.
    from asgiref.sync import iscoroutinefunction

    if iscoroutinefunction(get_response):

        async def async_middleware(request: HttpRequest) -> HttpResponse:  # type: ignore
            response = await get_response(request)
            if repo := getattr(request, "htmx_repo", None):
                await sync_to_async(repo.session.flush)()
                delattr(request, "htmx_repo")
            return response

        return async_middleware

    def sync_middleware(request: HttpRequest) -> HttpResponse:
        response = get_response(request)
        if repo := getattr(request, "htmx_repo", None):
            repo.session.flush()
            delattr(request, "htmx_repo")
        return response  # type: ignore

    return sync_middleware


# Django only passes an async `get_response` to middleware factories flagged as async-capable
# (the flags `django.utils.decorators.sync_and_async_middleware` sets).  Without them the async
# branch above was never used.
middleware.sync_capable = True  # type: ignore[attr-defined]
middleware.async_capable = True  # type: ignore[attr-defined]