djhtmx 1.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,439 @@
1
+ import datetime
2
+ import enum
3
+ import inspect
4
+ import operator
5
+ import types
6
+ from collections import defaultdict
7
+ from collections.abc import Callable, Iterable, Sequence
8
+ from dataclasses import dataclass
9
+ from datetime import date
10
+ from functools import cache
11
+ from inspect import Parameter, _ParameterKind
12
+ from typing import (
13
+ Annotated,
14
+ Any,
15
+ Generic,
16
+ Literal,
17
+ TypedDict,
18
+ TypeVar,
19
+ Union,
20
+ get_args,
21
+ get_origin,
22
+ get_type_hints,
23
+ is_typeddict,
24
+ )
25
+ from uuid import UUID
26
+
27
+ from django.apps import apps
28
+ from django.db import models
29
+ from django.db.models import Prefetch
30
+ from django.utils.datastructures import MultiValueDict
31
+ from pydantic import BeforeValidator, PlainSerializer, TypeAdapter
32
+
33
+ M = TypeVar("M", bound=models.Model)
34
+
35
+
36
+ @dataclass(slots=True)
37
+ class ModelRelatedField:
38
+ name: str
39
+ relation_name: str
40
+ related_model_name: str
41
+
42
+
43
# Per-model registry of foreign-key descriptions; filled on demand by `get_related_fields`.
MODEL_RELATED_FIELDS: dict[type[models.Model], tuple[ModelRelatedField, ...]] = {}
44
+
45
+
46
+ @dataclass(slots=True, unsafe_hash=True)
47
+ class ModelConfig:
48
+ """Annotation to configure fetching the models/querysets in pydantic models.
49
+
50
+ For this configuration to take place the pydantic model has to call `annotate_model`. HTMX
51
+ components require no extra steps.
52
+
53
+ """
54
+
55
+ lazy: bool = False
56
+ """If set to True, annotations of models.Model will return a _LazyModelProxy instead of the
57
+ actual model instance.
58
+
59
+ """
60
+
61
+ select_related: list[str] | tuple[str, ...] | None = None
62
+ """The arguments to `model.objects.select_related(*select_related)`."""
63
+
64
+ prefetch_related: list[str | Prefetch] | tuple[str | Prefetch, ...] | None = None
65
+ """The arguments to `model.objects.prefetch_related(*prefetch_related)`."""
66
+
67
+
68
# Shared configuration used by `_Model` when the caller supplies none (all defaults).
_DEFAULT_MODEL_CONFIG = ModelConfig()
69
+
70
+
71
@dataclass(slots=True, init=False)
class _LazyModelProxy(Generic[M]):  # noqa
    """Deferred proxy for a Django model instance; only fetches from the database on access.

    The proxy keeps the model class together with either an already loaded instance or just
    its primary key.  ``pk`` is answered from the stored key; any other attribute loads the
    instance with ``model.objects.get(pk=...)`` (applying the configured ``select_related``
    and ``prefetch_related``) and is then delegated to it.

    """

    __model: type[M]
    __instance: M | None
    __pk: Any | None
    __select_related: Sequence[str] | None
    __prefetch_related: Sequence[str | Prefetch] | None

    def __init__(
        self,
        model: type[M],
        value: Any,
        model_annotation: ModelConfig | None = None,
    ):
        """Create the proxy.

        :param model: the model class to fetch from.
        :param value: ``None``, an instance of `model`, or a primary key.
        :param model_annotation: optional configuration providing the ``select_related`` and
            ``prefetch_related`` arguments used when the instance is fetched.

        """
        self.__model = model
        if value is None or isinstance(value, model):
            self.__instance = value
            self.__pk = getattr(value, "pk", None)
        else:
            self.__instance = None
            self.__pk = value
        if model_annotation:
            self.__select_related = model_annotation.select_related
            self.__prefetch_related = model_annotation.prefetch_related
        else:
            self.__select_related = None
            self.__prefetch_related = None

    def __getattr__(self, name: str) -> Any:
        # `__getattr__` only runs when normal lookup failed.  If the missing name is one of
        # the proxy's own (name-mangled) slots, the proxy has not been initialized yet (e.g.
        # while it is being copied or unpickled).  Delegating would read another unset slot
        # and recurse until RecursionError, so report the attribute as missing instead.
        if name.startswith("_LazyModelProxy__"):
            raise AttributeError(name)
        if name == "pk":
            return self.__pk
        if self.__instance is None:
            self.__ensure_instance()
        return getattr(self.__instance, name)

    def __ensure_instance(self):
        """Fetch the instance if it is not loaded yet and return it."""
        # Compare with None (as `__getattr__` does): a loaded instance that happens to be
        # falsy must not be fetched again.
        if self.__instance is None:
            manager = self.__model.objects
            if select_related := self.__select_related:
                manager = manager.select_related(*select_related)
            if prefetch_related := self.__prefetch_related:
                manager = manager.prefetch_related(*prefetch_related)
            self.__instance = manager.get(pk=self.__pk)
        return self.__instance

    def __repr__(self) -> str:
        return f"<_LazyModelProxy model={self.__model}, pk={self.__pk}, instance={self.__instance}>"
120
+
121
+
122
@dataclass(slots=True)
class _ModelBeforeValidator(Generic[M]):  # noqa
    """Pydantic "before" validator turning a primary key (or proxy) into a model value.

    Depending on ``model_config.lazy`` the result is either a `_LazyModelProxy` or the model
    instance itself.

    """

    model: type[M]
    model_config: ModelConfig

    def __call__(self, value):
        if self.model_config.lazy:
            return self._get_lazy_proxy(value)
        else:
            return self._get_instance(value)

    def _get_lazy_proxy(self, value):
        """Wrap `value` in a new `_LazyModelProxy` for `self.model`."""
        if isinstance(value, _LazyModelProxy):
            # Re-wrap: reuse the loaded instance when there is one, otherwise the key.
            value = value._LazyModelProxy__instance or value._LazyModelProxy__pk
        # Pass the configuration along; otherwise the proxy would fetch the instance
        # without the configured select_related/prefetch_related.
        return _LazyModelProxy(self.model, value, self.model_config)

    def _get_instance(self, value):
        """Return the model instance for `value` (``None`` and instances pass through)."""
        if value is None or isinstance(value, self.model):
            return value
        # If a component has a lazy model proxy, and passes it down to another component that
        # doesn't allow lazy proxies, we need to materialize it.
        elif isinstance(value, _LazyModelProxy):
            return value._LazyModelProxy__ensure_instance()
        else:
            manager = self.model.objects
            if select_related := self.model_config.select_related:
                manager = manager.select_related(*select_related)
            if prefetch_related := self.model_config.prefetch_related:
                manager = manager.prefetch_related(*prefetch_related)
            return manager.get(pk=value)

    @classmethod
    @cache
    def from_modelclass(cls, model: type[M], model_config: ModelConfig):
        """Return the validator for ``(model, model_config)``; memoized, so both arguments
        must be hashable."""
        return cls(model, model_config=model_config)
159
+
160
+
161
@dataclass(slots=True)
class _ModelPlainSerializer(Generic[M]):  # noqa
    """Pydantic serializer callable that reduces a model value to its primary key."""

    # Model class the serializer was created for; `__call__` does not consult it.
    model: type[M]

    def __call__(self, value):
        """Return ``value.pk``."""
        primary_key = value.pk
        return primary_key

    @classmethod
    @cache
    def from_modelclass(cls, model: type[M]):
        """Return the (memoized) serializer for `model`."""
        return cls(model=model)
172
+
173
+
174
def _Model(model: type[models.Model], model_config: ModelConfig | None = None):
    """Build the ``Annotated`` type used by pydantic for fields holding a `model`.

    The annotated type is the model itself, or ``_LazyModelProxy[model]`` when the
    configuration is lazy.  Values are validated with `_ModelBeforeValidator` and serialized
    to the primary key with `_ModelPlainSerializer`.

    """
    assert issubclass_safe(model, models.Model)
    config = model_config if model_config else _DEFAULT_MODEL_CONFIG
    annotated_type = _LazyModelProxy[model] if config.lazy else model
    validator = BeforeValidator(_ModelBeforeValidator.from_modelclass(model, config))
    serializer = PlainSerializer(
        func=_ModelPlainSerializer.from_modelclass(model),
        return_type=guess_pk_type(model),
    )
    return Annotated[annotated_type, validator, serializer]
185
+
186
+
187
def _QuerySet(qs: type[models.QuerySet]):
    """Build the ``Annotated`` type used by pydantic for fields holding a `qs` queryset.

    Validation accepts an instance of `qs` as is, and turns anything else (an iterable of
    primary keys) into ``model.objects.filter(pk__in=...)``.  Serialization yields the list
    of primary keys.

    :raises ValueError: when not exactly one installed model uses `qs` for its default
        manager (the single-element unpacking below fails).

    """
    [model] = [m for m in apps.get_models() if isinstance(m.objects.all(), qs)]

    def validate(value):
        return value if isinstance(value, qs) else model.objects.filter(pk__in=value)

    def serialize(queryset):
        # `_result_cache` is None until the queryset is evaluated.  Test against None so an
        # evaluated but empty queryset is served from the cache instead of hitting the
        # database again.
        if queryset._result_cache is not None:
            return [instance.pk for instance in queryset]
        return list(queryset.values_list("pk", flat=True))

    return Annotated[
        qs,
        BeforeValidator(validate),
        PlainSerializer(
            func=serialize,
            # The serializer returns a list of primary keys, not a single key.
            return_type=list[guess_pk_type(model)],
        ),
    ]
201
+
202
+
203
+ def annotate_model(annotation, *, model_config: ModelConfig | None = None):
204
+ if issubclass_safe(annotation, models.Model):
205
+ return _Model(annotation, model_config)
206
+ elif issubclass_safe(annotation, models.QuerySet):
207
+ return _QuerySet(annotation)
208
+ elif is_typeddict(annotation):
209
+ return TypedDict(
210
+ annotation.__name__, # type: ignore
211
+ {
212
+ k: annotate_model(v) # type: ignore
213
+ for k, v in get_type_hints(annotation).items()
214
+ },
215
+ )
216
+ elif type_ := get_origin(annotation):
217
+ if type_ is types.UnionType or type_ is Union:
218
+ type_ = Union
219
+ match get_args(annotation):
220
+ case ():
221
+ return type_
222
+ case (param,):
223
+ return type_[annotate_model(param)] # type: ignore
224
+ case params:
225
+ model_annotation = next(
226
+ (p for p in params if isinstance(p, ModelConfig)),
227
+ None,
228
+ )
229
+ return type_[*(annotate_model(p, model_config=model_annotation) for p in params)] # type: ignore
230
+ else:
231
+ return annotation
232
+
233
+
234
def guess_pk_type(model: type[models.Model]):
    """Return the Python type matching the primary key field of `model`.

    ``UUID`` for UUID fields, ``int`` for integer fields, and ``str`` for anything else.

    """
    pk_field = model._meta.pk
    if isinstance(pk_field, models.UUIDField):
        return UUID
    if isinstance(pk_field, models.IntegerField):
        return int
    return str
242
+
243
+
244
def isinstance_safe(o, types):
    """Like ``isinstance(o, types)``, but answer False when `isinstance` raises TypeError."""
    try:
        matches = isinstance(o, types)
    except TypeError:
        return False
    else:
        return matches
249
+
250
+
251
def issubclass_safe(o, types):
    """Like ``issubclass(o, types)``, but answer False when `issubclass` raises TypeError."""
    try:
        matches = issubclass(o, types)
    except TypeError:
        return False
    else:
        return matches
256
+
257
+
258
+ # for state of old components
259
+
260
+
261
def get_function_parameters(
    function: Callable,
    exclude_kinds: tuple[_ParameterKind, ...] = (),
) -> frozenset[str]:
    """Return the parameter names of `function`.

    A parameter named ``self`` is always left out, and so is every parameter whose kind is
    listed in `exclude_kinds`.

    """
    names = []
    for parameter in inspect.signature(function).parameters.values():
        if parameter.name == "self" or parameter.kind in exclude_kinds:
            continue
        names.append(parameter.name)
    return frozenset(names)
270
+
271
+
272
@cache
def get_related_fields(model):
    """Return the `ModelRelatedField` descriptions of the foreign keys of `model`.

    Foreign keys whose related query name is empty or ``"+"`` are left out.  The result is
    stored in `MODEL_RELATED_FIELDS` and reused on later calls.

    """
    known = MODEL_RELATED_FIELDS.get(model)
    if known is not None:
        return known

    collected = []
    for field in model._meta.get_fields():
        if not isinstance(field, models.ForeignKey):
            continue
        relation_name = field.related_query_name()
        if not relation_name or relation_name == "+":
            continue
        rel_meta = field.related_model._meta  # type: ignore
        collected.append(
            ModelRelatedField(
                name=field.attname,
                relation_name=relation_name,
                related_model_name=(f"{rel_meta.app_label}.{rel_meta.model_name}"),
            )
        )

    result = tuple(collected)
    MODEL_RELATED_FIELDS[model] = result
    return result
293
+
294
+
295
+ # filtering
296
+
297
+
298
def filter_parameters(f: Callable, kwargs: dict[str, Any]):
    """Return the entries of `kwargs` that `f` can receive.

    When `f` declares a ``**kwargs`` parameter, everything is accepted and `kwargs` itself is
    returned; otherwise a new dict with only the keys naming a parameter of `f` is returned.

    """
    # Inspect the signature once instead of once per keyword argument.
    parameters = inspect.signature(f).parameters
    if any(param.kind == Parameter.VAR_KEYWORD for param in parameters.values()):
        return kwargs
    return {name: value for name, value in kwargs.items() if name in parameters}
310
+
311
+
312
+ # Decoder for client requests
313
+
314
+
315
def parse_request_data(data: MultiValueDict[str, Any] | dict[str, Any]):
    """Decode flat request data into nested dicts and lists.

    A plain dict is first converted to a `MultiValueDict` (values that are not lists are
    wrapped in a one-element list); the keys are then expanded by `_extract_data` and
    `_parse_obj`.

    """
    if isinstance(data, MultiValueDict):
        multi = data
    else:
        multi = MultiValueDict({
            key: (value if isinstance(value, list) else [value]) for key, value in data.items()
        })
    return _parse_obj(_extract_data(multi))
321
+
322
+
323
def _extract_data(data: MultiValueDict[str, Any]):
    """Yield ``(path, value)`` pairs for every distinct key of `data`.

    `path` is the key split on ``"."``.  Keys ending in ``"[]"`` lose that suffix and carry
    the whole list of values (``getlist``); all other keys carry ``data.get(key)``.

    """
    for raw_key in set(data):
        if raw_key.endswith("[]"):
            name, value = raw_key.removesuffix("[]"), data.getlist(raw_key)
        else:
            name, value = raw_key, data.get(raw_key)
        yield name.split("."), value
331
+
332
+
333
+ def _parse_obj(data: Iterable[tuple[list[str], Any]], output=None) -> dict[str, Any] | Any:
334
+ output = output or {}
335
+ arrays = defaultdict(lambda: defaultdict(dict)) # field -> index -> value
336
+ for key, value in data:
337
+ fragment, *tail = key
338
+ if "[" in fragment:
339
+ field_name = fragment[: fragment.index("[")]
340
+ index = int(fragment[fragment.index("[") + 1 : -1])
341
+ arrays[field_name][index] = (
342
+ _parse_obj([(tail, value)], arrays[field_name][index]) if tail else value
343
+ )
344
+ else:
345
+ output[fragment] = _parse_obj([(tail, value)]) if tail else value
346
+
347
+ for field, items in arrays.items():
348
+ output[field] = [v for _, v in sorted(items.items(), key=operator.itemgetter(0))]
349
+ return output
350
+
351
+
352
+ def get_event_handler_event_types(f: Callable[..., Any]) -> set[type]:
353
+ "Extract the types of the annotations of parameter 'event'."
354
+ event = get_type_hints(f)["event"]
355
+ origin = get_origin(event)
356
+ if origin is types.UnionType or origin is Union:
357
+ return {
358
+ arg for arg in get_args(event) if isinstance(arg, type) and arg is not types.NoneType
359
+ }
360
+ elif isinstance(event, type):
361
+ return {event}
362
+ else:
363
+ return set()
364
+
365
+
366
def get_annotation_adapter(annotation):
    """Return a TypeAdapter for the annotation.

    ``bool`` gets the shared `infallible_bool_adapter`; every other annotation gets a new
    adapter that allows arbitrary types.

    """
    if annotation is not bool:
        return TypeAdapter(annotation, config={"arbitrary_types_allowed": True})
    return infallible_bool_adapter
372
+
373
+
374
# Adapter for booleans that never fails to validate: the input "t" becomes True and every
# other input becomes False.  Booleans are serialized back as "t" / "f".
infallible_bool_adapter = TypeAdapter(
    Annotated[
        bool,
        BeforeValidator(lambda v: v == "t"),
        PlainSerializer(lambda v: "t" if v else "f"),
    ]
)
383
+
384
+
385
def is_literal_annotation(ann):
    """Tell whether `ann` is a ``Literal[...]`` whose values all have a type in `_SIMPLE_TYPES`."""
    if get_origin(ann) is not Literal:
        return False
    return all(type(value) in _SIMPLE_TYPES for value in get_args(ann))
388
+
389
+
390
def is_basic_type(ann):
    """Tell whether the annotation is a simple type.

    An annotation is simple when it is any of:

    - one of the types listed in `_SIMPLE_TYPES`;

    - an annotation whose ``__origin__`` is a Django model class, i.e. the result of
      annotating a model (the PK is used as a proxy for the instance);

    - a subclass of IntEnum or StrEnum;

    - a collection annotation (see `is_collection_annotation`);

    - a Literal with simple values (see `is_literal_annotation`).

    """
    if ann in _SIMPLE_TYPES:
        return True
    # __origin__ -> model in 'Annotated[model, BeforeValidator(...), PlainSerializer(...)]'
    origin = getattr(ann, "__origin__", None)
    if issubclass_safe(origin, models.Model):
        return True
    if issubclass_safe(ann, (enum.IntEnum, enum.StrEnum)):
        return True
    return is_collection_annotation(ann) or is_literal_annotation(ann)
415
+
416
+
417
def is_union_of_basic(ann):
    """Tell whether `ann` is a union whose members are all basic types (see `is_basic_type`)."""
    origin = get_origin(ann)
    if origin is not types.UnionType and origin is not Union:
        return False
    return all(is_basic_type(member) for member in get_args(ann))
423
+
424
+
425
def is_simple_annotation(ann):
    "Tell whether the annotation is a basic type or a union of basic types."
    if is_basic_type(ann):
        return True
    return is_union_of_basic(ann)
428
+
429
+
430
def is_collection_annotation(ann):
    """Tell whether `ann` is one of `_COLLECTION_TYPES` (or a subclass), bare or parameterized.

    For a `types.GenericAlias` such as ``list[int]`` the check is made on its origin.

    """
    candidate = ann.__origin__ if isinstance(ann, types.GenericAlias) else ann
    return issubclass_safe(candidate, _COLLECTION_TYPES)
435
+
436
+
437
+ Unset = object()
438
+ _SIMPLE_TYPES = (int, str, float, UUID, types.NoneType, date, datetime, bool)
439
+ _COLLECTION_TYPES = (dict, tuple, list, set, defaultdict)
djhtmx/json.py ADDED
@@ -0,0 +1,56 @@
1
+ import dataclasses
2
+ import enum
3
+ import json
4
+ from collections.abc import Generator
5
+
6
+ import orjson
7
+ from django.core.serializers import deserialize, serialize
8
+ from django.core.serializers.base import DeserializedObject
9
+ from django.core.serializers.json import DjangoJSONEncoder
10
+ from django.db import models
11
+ from pydantic import BaseModel
12
+
13
# JSON decoding is delegated to orjson unchanged.
loads = orjson.loads
14
+
15
+
16
def dumps(obj):
    """Serialize `obj` to a JSON ``str`` with orjson, using `default` for unsupported types."""
    encoded = orjson.dumps(obj, default)
    return encoded.decode()
18
+
19
+
20
def encode(instance: models.Model) -> str:
    """Serialize one model instance with Django's "json" serializer and `HtmxEncoder`."""
    payload = [instance]
    return serialize("json", payload, cls=HtmxEncoder)
22
+
23
+
24
def decode(instance: str) -> models.Model:
    """Rebuild the first model instance found in a Django "json" serialization.

    The instance's ``save`` is replaced by the ``save`` of the deserialized wrapper.

    """
    deserialized: DeserializedObject = next(iter(deserialize("json", instance)))
    model_instance = deserialized.object
    model_instance.save = deserialized.save  # type: ignore
    return model_instance
28
+
29
+
30
class HtmxEncoder(json.JSONEncoder):
    """JSON encoder that hands unsupported objects to the module-level `default`."""

    def default(self, o):
        # Inside the method, `default` resolves to the module-level function.
        converted = default(o)
        return converted
33
+
34
+
35
def default(o):
    """Convert `o` into something JSON can represent.

    Django's `DjangoJSONEncoder` is tried first.  When it raises TypeError, these conversions
    are tried in order: ``o.__json__()``, the primary key of a model instance, a list for
    generators and (frozen)sets, ``model_dump()`` for pydantic models, ``asdict`` for
    dataclasses, and the value of an enum member.

    :raises TypeError: the error of the Django encoder, when no conversion applies.

    """
    try:
        return DjangoJSONEncoder().default(o)
    except TypeError:
        if hasattr(o, "__json__"):
            return o.__json__()
        elif isinstance(o, models.Model):
            return o.pk
        elif isinstance(o, (Generator, set, frozenset)):
            return list(o)
        elif BaseModel and isinstance(o, BaseModel):
            return o.model_dump()
        elif dataclasses.is_dataclass(o):
            return dataclasses.asdict(o)  # type: ignore
        elif isinstance(o, enum.Enum):
            return o.value
        # Nothing matched: propagate the TypeError raised by the Django encoder.
        raise
@@ -0,0 +1,123 @@
1
+ import sys
2
+ from collections import defaultdict
3
+
4
+ import djclick as click
5
+ from xotl.tools.future.itertools import delete_duplicates
6
+ from xotl.tools.objects import get_branch_subclasses as get_final_subclasses
7
+
8
+ from djhtmx.component import REGISTRY, HtmxComponent
9
+
10
+
11
def bold(msg):
    """Return ``str(msg)`` styled as bold text for the terminal."""
    text = str(msg)
    return click.style(text, bold=True)
13
+
14
+
15
def yellow(msg):
    """Return ``str(msg)`` styled with a yellow foreground for the terminal."""
    text = str(msg)
    return click.style(text, fg="yellow")
17
+
18
+
19
# Command group holding the HTMX checks defined below.  Documented with a comment on
# purpose: a docstring would become the help text shown by click.
@click.group()
def htmx():
    pass
22
+
23
+
24
@htmx.command("check-missing")  # type: ignore
@click.argument("fname", type=click.File())
def check_missing(fname):
    r"""Check if there are any missing HTMX components.

    Expected usage:

        find -type f -name '*.html' | while read f; do grep -P '{% htmx .(\w+)' -o $f \
        | awk '{print $3}' | cut -b2-; done | sort -u \
        | python manage.py htmx check-missing -

    """
    # One component name per line.  Blank lines are skipped: they are not component names
    # and would otherwise be reported as a missing component with an empty name.
    names = {name for line in fname.readlines() if (name := line.strip())}
    missing = sorted(names - set(REGISTRY))
    if missing:
        for n in missing:
            click.echo(
                f"Missing component detected {bold(yellow(n))}",
                file=sys.stderr,
            )
        # A non-zero exit status signals the failure to the calling script.
        sys.exit(1)
47
+
48
+
49
@htmx.command("check-unused")  # type: ignore
@click.argument("fname", type=click.File())
def check_unused(fname):
    r"""Check if there are any unused HTMX components.

    Expected usage:

        find -type f -name '*.html' | while read f; do grep -P '{% htmx .(\w+)' -o $f \
        | awk '{print $3}' | cut -b2-; done | sort -u \
        | python manage.py htmx check-unused -

    """
    # NOTE: the docstring is raw, so the regex in the example must be spelled `\w` (a
    # doubled backslash would be shown literally and match a backslash followed by "w").
    names = {n.strip() for n in fname.readlines()}
    unused = sorted(set(REGISTRY) - names)
    if unused:
        for n in unused:
            click.echo(
                f"Unused component detected {bold(yellow(n))}",
                file=sys.stderr,
            )
        # A non-zero exit status signals the failure to the calling script.
        sys.exit(1)
72
+
73
+
74
@htmx.command("check-unused-non-public")  # type: ignore
def check_unused_non_public():
    """Check if there are any unused non-public HTMX components.

    Non-public components that are final subclasses (have no subclasses themselves)
    are considered unused since they can't be instantiated from templates and serve
    no purpose as base classes.
    """
    leaves = set(
        get_final_subclasses(
            HtmxComponent,  # type: ignore
            without_duplicates=True,
        )
    )
    # Leaf component classes that are not among the registered ones.
    offenders = sorted(leaves - set(REGISTRY.values()), key=lambda cls: cls.__name__)
    if not offenders:
        return

    for cls in offenders:
        click.echo(
            f"Unused non-public component detected {bold(yellow(cls.__name__))}",
            file=sys.stderr,
        )
    sys.exit(1)
99
+
100
+
101
@htmx.command("check-shadowing")  # type: ignore
def check_shadowing():
    "Checks if there are components that might shadow one another."
    clashes = defaultdict(list)
    leaves = get_final_subclasses(
        HtmxComponent,  # type: ignore
        without_duplicates=True,
    )
    for cls in leaves:
        name = cls.__name__
        registered = REGISTRY.get(name)
        # A clash: the registry holds a *different* class under this class' name.
        if registered is None or registered is cls:
            continue
        clashes[name].extend((cls, registered))

    if not clashes:
        return

    for name, shadows in clashes.items():
        shadows = delete_duplicates(shadows)
        if not shadows:
            continue
        click.echo(f"HtmxComponent {bold(name)} might be shadowed by:")
        for shadow in shadows:
            click.echo(f" - {bold(shadow.__module__)}.{bold(shadow.__name__)}")

    sys.exit(1)
djhtmx/middleware.py ADDED
@@ -0,0 +1,36 @@
1
+ import asyncio
2
+ from collections.abc import Awaitable, Callable
3
+
4
+ from asgiref.sync import sync_to_async
5
+ from django.http import HttpRequest, HttpResponse
6
+
7
+
8
def middleware(
    get_response: Callable[[HttpRequest], HttpResponse]
    | Callable[[HttpRequest], Awaitable[HttpResponse]],
):
    """
    Middleware function that wraps get_response and ensures the HTMX repository
    is flushed and removed from the request after handling each request. It can handle
    both sync and async get_response automatically.

    After the response is produced, a truthy ``request.htmx_repo`` gets its session flushed
    and the attribute is deleted from the request.
    """

    if not asyncio.iscoroutinefunction(get_response):
        # Sync version
        def middleware(request: HttpRequest) -> HttpResponse:
            response = get_response(request)
            repo = getattr(request, "htmx_repo", None)
            if repo:
                repo.session.flush()
                delattr(request, "htmx_repo")
            return response  # type: ignore

        return middleware

    # Async version: the flush is a sync call, so it goes through `sync_to_async`.
    async def middleware(request: HttpRequest) -> HttpResponse:  # type: ignore
        response = await get_response(request)
        repo = getattr(request, "htmx_repo", None)
        if repo:
            await sync_to_async(repo.session.flush)()
            delattr(request, "htmx_repo")
        return response

    return middleware