django-ninja-aio-crud 2.17.0__py3-none-any.whl → 2.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- django_ninja_aio_crud-2.18.1.dist-info/METADATA +431 -0
- {django_ninja_aio_crud-2.17.0.dist-info → django_ninja_aio_crud-2.18.1.dist-info}/RECORD +10 -10
- ninja_aio/__init__.py +1 -1
- ninja_aio/models/serializers.py +715 -99
- ninja_aio/models/utils.py +235 -35
- ninja_aio/types.py +38 -0
- ninja_aio/views/api.py +108 -9
- ninja_aio/views/mixins.py +25 -9
- django_ninja_aio_crud-2.17.0.dist-info/METADATA +0 -379
- {django_ninja_aio_crud-2.17.0.dist-info → django_ninja_aio_crud-2.18.1.dist-info}/WHEEL +0 -0
- {django_ninja_aio_crud-2.17.0.dist-info → django_ninja_aio_crud-2.18.1.dist-info}/licenses/LICENSE +0 -0
ninja_aio/models/utils.py
CHANGED
|
@@ -94,6 +94,9 @@ class ModelUtil:
|
|
|
94
94
|
- Stateless wrapper; safe per-request instantiation.
|
|
95
95
|
"""
|
|
96
96
|
|
|
97
|
+
# Performance: Class-level cache for relation discovery (model structure is static)
|
|
98
|
+
_relation_cache: dict[tuple[type, str, str], list[str]] = {}
|
|
99
|
+
|
|
97
100
|
def __init__(
|
|
98
101
|
self, model: type["ModelSerializer"] | models.Model, serializer_class=None
|
|
99
102
|
):
|
|
@@ -192,6 +195,47 @@ class ModelUtil:
|
|
|
192
195
|
"""
|
|
193
196
|
return [field.name for field in self.model._meta.get_fields()]
|
|
194
197
|
|
|
198
|
+
def get_valid_input_fields(
    self, is_serializer: bool, serializer: "ModelSerializer | None" = None
) -> set[str]:
    """
    Build the allowlist of field names accepted in input payloads.

    The allowlist always contains every name in ``self.model_fields``. When a
    serializer is in use, the names it reports for each schema type are added
    on top.

    Parameters
    ----------
    is_serializer : bool
        Whether a ModelSerializer is being used.
    serializer : ModelSerializer, optional
        Serializer queried through ``get_fields(schema_type)``.

    Returns
    -------
    set[str]
        Field names that may appear in an input payload.
    """
    valid_fields = set(self.model_fields)

    # Without a serializer there are no extra (custom) fields to add.
    if not (is_serializer and serializer):
        return valid_fields

    for schema_type in ("create", "update", "read", "detail"):
        # Best effort: a schema type the serializer cannot report on is
        # skipped instead of failing the whole allowlist computation.
        try:
            schema_fields = serializer.get_fields(schema_type)
            if schema_fields:
                valid_fields.update(schema_fields)
        except (AttributeError, TypeError):
            continue

    return valid_fields
|
|
238
|
+
|
|
195
239
|
@property
|
|
196
240
|
def model_name(self) -> str:
|
|
197
241
|
"""
|
|
@@ -518,6 +562,9 @@ class ModelUtil:
|
|
|
518
562
|
"""
|
|
519
563
|
Discover reverse relation names for safe prefetching.
|
|
520
564
|
|
|
565
|
+
Performance: Results are cached per (model, serializer_class, is_for) tuple
|
|
566
|
+
since model structure is static.
|
|
567
|
+
|
|
521
568
|
Parameters
|
|
522
569
|
----------
|
|
523
570
|
is_for : Literal["read", "detail"]
|
|
@@ -528,9 +575,17 @@ class ModelUtil:
|
|
|
528
575
|
list[str]
|
|
529
576
|
Relation attribute names.
|
|
530
577
|
"""
|
|
578
|
+
# Check cache first (performance optimization)
|
|
579
|
+
cache_key = (self.model, str(self.serializer_class), is_for)
|
|
580
|
+
if cache_key in self._relation_cache:
|
|
581
|
+
return self._relation_cache[cache_key].copy()
|
|
582
|
+
|
|
531
583
|
reverse_rels = self._get_read_optimizations(is_for).prefetch_related.copy()
|
|
532
584
|
if reverse_rels:
|
|
585
|
+
# Cache and return
|
|
586
|
+
self._relation_cache[cache_key] = reverse_rels
|
|
533
587
|
return reverse_rels
|
|
588
|
+
|
|
534
589
|
serializable_fields = self._get_serializable_field_names(is_for)
|
|
535
590
|
for f in serializable_fields:
|
|
536
591
|
field_obj = getattr(self.model, f)
|
|
@@ -542,6 +597,9 @@ class ModelUtil:
|
|
|
542
597
|
continue
|
|
543
598
|
if isinstance(field_obj, ReverseOneToOneDescriptor):
|
|
544
599
|
reverse_rels.append(field_obj.related.name)
|
|
600
|
+
|
|
601
|
+
# Cache the result
|
|
602
|
+
self._relation_cache[cache_key] = reverse_rels
|
|
545
603
|
return reverse_rels
|
|
546
604
|
|
|
547
605
|
def get_select_relateds(
|
|
@@ -550,6 +608,9 @@ class ModelUtil:
|
|
|
550
608
|
"""
|
|
551
609
|
Discover forward relation names for safe select_related.
|
|
552
610
|
|
|
611
|
+
Performance: Results are cached per (model, serializer_class, is_for) tuple
|
|
612
|
+
since model structure is static.
|
|
613
|
+
|
|
553
614
|
Parameters
|
|
554
615
|
----------
|
|
555
616
|
is_for : Literal["read", "detail"]
|
|
@@ -560,9 +621,17 @@ class ModelUtil:
|
|
|
560
621
|
list[str]
|
|
561
622
|
Relation attribute names.
|
|
562
623
|
"""
|
|
624
|
+
# Check cache first (performance optimization)
|
|
625
|
+
cache_key = (self.model, str(self.serializer_class) + "_select", is_for)
|
|
626
|
+
if cache_key in self._relation_cache:
|
|
627
|
+
return self._relation_cache[cache_key].copy()
|
|
628
|
+
|
|
563
629
|
select_rels = self._get_read_optimizations(is_for).select_related.copy()
|
|
564
630
|
if select_rels:
|
|
631
|
+
# Cache and return
|
|
632
|
+
self._relation_cache[cache_key] = select_rels
|
|
565
633
|
return select_rels
|
|
634
|
+
|
|
566
635
|
serializable_fields = self._get_serializable_field_names(is_for)
|
|
567
636
|
for f in serializable_fields:
|
|
568
637
|
field_obj = getattr(self.model, f)
|
|
@@ -571,12 +640,17 @@ class ModelUtil:
|
|
|
571
640
|
continue
|
|
572
641
|
if isinstance(field_obj, ForwardManyToOneDescriptor):
|
|
573
642
|
select_rels.append(f)
|
|
643
|
+
|
|
644
|
+
# Cache the result
|
|
645
|
+
self._relation_cache[cache_key] = select_rels
|
|
574
646
|
return select_rels
|
|
575
647
|
|
|
576
|
-
async def _get_field(self, k: str):
|
|
648
|
+
async def _get_field(self, k: str) -> models.Field:
|
|
649
|
+
"""Get Django field object for a given field name."""
|
|
577
650
|
return (await agetattr(self.model, k)).field
|
|
578
651
|
|
|
579
|
-
def _decode_binary(self, payload: dict, k: str, v: Any, field_obj: models.Field):
|
|
652
|
+
def _decode_binary(self, payload: dict, k: str, v: Any, field_obj: models.Field) -> None:
|
|
653
|
+
"""Decode base64-encoded binary field values in place."""
|
|
580
654
|
if not isinstance(field_obj, models.BinaryField):
|
|
581
655
|
return
|
|
582
656
|
try:
|
|
@@ -591,7 +665,8 @@ class ModelUtil:
|
|
|
591
665
|
k: str,
|
|
592
666
|
v: Any,
|
|
593
667
|
field_obj: models.Field,
|
|
594
|
-
):
|
|
668
|
+
) -> None:
|
|
669
|
+
"""Resolve foreign key ID to model instance in place."""
|
|
595
670
|
if not isinstance(field_obj, models.ForeignKey):
|
|
596
671
|
return
|
|
597
672
|
rel_util = ModelUtil(field_obj.related_model)
|
|
@@ -600,10 +675,11 @@ class ModelUtil:
|
|
|
600
675
|
|
|
601
676
|
async def _bump_object_from_schema(
|
|
602
677
|
self, obj: type["ModelSerializer"] | models.Model, schema: Schema
|
|
603
|
-
):
|
|
678
|
+
) -> dict:
|
|
679
|
+
"""Convert model instance to dict using Pydantic schema."""
|
|
604
680
|
return (await sync_to_async(schema.from_orm)(obj)).model_dump()
|
|
605
681
|
|
|
606
|
-
def _validate_read_params(self, request: HttpRequest, query_data: QuerySchema):
|
|
682
|
+
def _validate_read_params(self, request: HttpRequest, query_data: QuerySchema) -> None:
|
|
607
683
|
"""Validate required parameters for read operations."""
|
|
608
684
|
if request is None:
|
|
609
685
|
raise SerializeError(
|
|
@@ -667,12 +743,152 @@ class ModelUtil:
|
|
|
667
743
|
obj = await self.get_object(request, query_data=query_data, is_for=is_for)
|
|
668
744
|
return await self._bump_object_from_schema(obj, obj_schema)
|
|
669
745
|
|
|
746
|
+
def _validate_input_fields(
    self, payload: dict, is_serializer: bool, serializer
) -> None:
    """
    Reject payload keys that are neither custom fields nor model fields.

    Parameters
    ----------
    payload : dict
        Input payload to validate.
    is_serializer : bool
        Whether using a ModelSerializer.
    serializer : ModelSerializer | Serializer
        Object queried through ``is_custom(key)``; only consulted when
        ``is_serializer`` is true.

    Raises
    ------
    SerializeError
        With status 400 if any key is not custom and not in
        ``self.model_fields``. The offending names are reported sorted.
    """
    # Read the model field names once; a set gives O(1) membership tests.
    known_fields = set(self.model_fields)

    invalid_fields = sorted(
        key
        for key in payload
        # Custom fields are skipped here; everything else must be a model field.
        if not (is_serializer and serializer.is_custom(key))
        and key not in known_fields
    )

    if invalid_fields:
        raise SerializeError(
            {
                "detail": f"Invalid field names in payload: {', '.join(invalid_fields)}",
                "invalid_fields": invalid_fields,
            },
            400,
        )
|
|
783
|
+
|
|
784
|
+
def _collect_custom_and_optional_fields(
    self, payload: dict, is_serializer: bool, serializer
) -> tuple[dict[str, Any], list[str]]:
    """
    Split out the custom fields and the unset optional fields of a payload.

    Parameters
    ----------
    payload : dict
        Input payload.
    is_serializer : bool
        Whether using a ModelSerializer. When false, nothing is collected.
    serializer : ModelSerializer | Serializer
        Object queried through ``is_custom(key)`` and ``is_optional(key)``.

    Returns
    -------
    tuple[dict[str, Any], list[str]]
        ``(customs, optionals)``: ``customs`` maps each custom key that is not
        a model field to its value; ``optionals`` lists the optional keys whose
        value is ``None``, in payload order.
    """
    if not is_serializer:
        return {}, []

    # Read the model field names once instead of once per payload key.
    model_field_names = set(self.model_fields)

    customs = {
        key: value
        for key, value in payload.items()
        if serializer.is_custom(key) and key not in model_field_names
    }
    optionals = [
        key
        for key, value in payload.items()
        if serializer.is_optional(key) and value is None
    ]
    return customs, optionals
|
|
820
|
+
|
|
821
|
+
def _determine_skip_keys(
    self, payload: dict, is_serializer: bool, serializer
) -> set[str]:
    """
    Determine which payload keys are excluded from model field processing.

    Parameters
    ----------
    payload : dict
        Input payload.
    is_serializer : bool
        Whether using a ModelSerializer. When false, no key is skipped.
    serializer : ModelSerializer | Serializer
        Object queried through ``is_custom(key)`` and ``is_optional(key)``.

    Returns
    -------
    set[str]
        Keys that are custom (and not model fields), or optional with a
        ``None`` value.
    """
    if not is_serializer:
        return set()

    # Read the model field names once instead of once per payload key.
    model_field_names = set(self.model_fields)

    return {
        key
        for key, value in payload.items()
        if (serializer.is_custom(key) and key not in model_field_names)
        or (serializer.is_optional(key) and value is None)
    }
|
|
851
|
+
|
|
852
|
+
async def _process_payload_fields(
    self, request: HttpRequest, payload: dict, fields_to_process: list[tuple[str, Any]]
) -> None:
    """
    Decode binary values and resolve foreign keys for the given payload fields.

    Parameters
    ----------
    request : HttpRequest
        HTTP request object, forwarded to ``_resolve_fk``.
    payload : dict
        Payload dict handed to the helpers, which update it in place.
    fields_to_process : list[tuple[str, Any]]
        ``(field_name, field_value)`` pairs to process.
    """
    if not fields_to_process:
        return

    # Look up every field object concurrently.
    resolved_fields = await asyncio.gather(
        *(self._get_field(name) for name, _ in fields_to_process)
    )
    paired = list(zip(fields_to_process, resolved_fields))

    # Binary decoding is synchronous, so it runs one field at a time.
    for (name, value), field in paired:
        self._decode_binary(payload, name, value, field)

    # Foreign-key resolution is awaited for all fields together.
    await asyncio.gather(
        *(
            self._resolve_fk(request, payload, name, value, field)
            for (name, value), field in paired
        )
    )
|
|
884
|
+
|
|
670
885
|
async def parse_input_data(self, request: HttpRequest, data: Schema):
|
|
671
886
|
"""
|
|
672
887
|
Transform inbound schema data to a model-ready payload.
|
|
673
888
|
|
|
674
889
|
Steps
|
|
675
890
|
-----
|
|
891
|
+
- Validate fields against allowlist (security).
|
|
676
892
|
- Strip custom fields (retain separately).
|
|
677
893
|
- Drop optional fields with None (ModelSerializer only).
|
|
678
894
|
- Decode BinaryField base64 values.
|
|
@@ -692,7 +908,7 @@ class ModelUtil:
|
|
|
692
908
|
Raises
|
|
693
909
|
------
|
|
694
910
|
SerializeError
|
|
695
|
-
On base64 decoding failure.
|
|
911
|
+
On base64 decoding failure or invalid field names.
|
|
696
912
|
"""
|
|
697
913
|
payload = data.model_dump(mode="json")
|
|
698
914
|
|
|
@@ -701,36 +917,20 @@ class ModelUtil:
|
|
|
701
917
|
)
|
|
702
918
|
serializer = self.serializer if self.with_serializer else self.model
|
|
703
919
|
|
|
704
|
-
#
|
|
705
|
-
|
|
706
|
-
optionals: list[str] = []
|
|
707
|
-
if is_serializer:
|
|
708
|
-
customs = {
|
|
709
|
-
k: v
|
|
710
|
-
for k, v in payload.items()
|
|
711
|
-
if serializer.is_custom(k) and k not in self.model_fields
|
|
712
|
-
}
|
|
713
|
-
optionals = [
|
|
714
|
-
k for k, v in payload.items() if serializer.is_optional(k) and v is None
|
|
715
|
-
]
|
|
920
|
+
# Security: Validate non-custom payload keys against model fields
|
|
921
|
+
self._validate_input_fields(payload, is_serializer, serializer)
|
|
716
922
|
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
for k, v in payload.items():
|
|
729
|
-
if k in skip_keys:
|
|
730
|
-
continue
|
|
731
|
-
field_obj = await self._get_field(k)
|
|
732
|
-
self._decode_binary(payload, k, v, field_obj)
|
|
733
|
-
await self._resolve_fk(request, payload, k, v, field_obj)
|
|
923
|
+
# Collect custom and optional fields
|
|
924
|
+
customs, optionals = self._collect_custom_and_optional_fields(
|
|
925
|
+
payload, is_serializer, serializer
|
|
926
|
+
)
|
|
927
|
+
|
|
928
|
+
# Determine which keys to skip during model field processing
|
|
929
|
+
skip_keys = self._determine_skip_keys(payload, is_serializer, serializer)
|
|
930
|
+
|
|
931
|
+
# Process payload fields - gather field objects in parallel for better performance
|
|
932
|
+
fields_to_process = [(k, v) for k, v in payload.items() if k not in skip_keys]
|
|
933
|
+
await self._process_payload_fields(request, payload, fields_to_process)
|
|
734
934
|
|
|
735
935
|
# Preserve original exclusion semantics (customs if present else optionals)
|
|
736
936
|
exclude_keys = customs.keys() or optionals
|
ninja_aio/types.py
CHANGED
|
@@ -10,6 +10,44 @@ SCHEMA_TYPES = Literal["In", "Out", "Detail", "Patch", "Related"]
|
|
|
10
10
|
VIEW_TYPES = Literal["list", "retrieve", "create", "update", "delete", "all"]
|
|
11
11
|
JwtKeys: TypeAlias = jwk.RSAKey | jwk.ECKey | jwk.OctKey
|
|
12
12
|
|
|
13
|
+
# Django ORM field lookup suffixes for QuerySet filtering
|
|
14
|
+
# See: https://docs.djangoproject.com/en/stable/ref/models/querysets/#field-lookups
|
|
15
|
+
DjangoLookup = Literal[
    # Equality and substring matching
    "exact", "iexact", "contains", "icontains",
    # Membership and ordering comparisons
    "in", "gt", "gte", "lt", "lte",
    # Prefix / suffix matching
    "startswith", "istartswith", "endswith", "iendswith",
    # Range
    "range",
    # Date components
    "date", "year", "iso_year", "month", "day",
    "week", "week_day", "iso_week_day", "quarter",
    # Time components
    "time", "hour", "minute", "second",
    # Null checks and regular expressions
    "isnull", "regex", "iregex",
]

# Runtime view of the same names, used to validate lookup suffixes.
VALID_DJANGO_LOOKUPS: set[str] = {*DjangoLookup.__args__}
|
|
50
|
+
|
|
13
51
|
|
|
14
52
|
class SerializerMeta(type):
|
|
15
53
|
"""Metaclass for serializers - extend with custom behavior as needed."""
|
ninja_aio/views/api.py
CHANGED
|
@@ -6,6 +6,7 @@ from ninja.pagination import paginate, AsyncPaginationBase, PageNumberPagination
|
|
|
6
6
|
from django.http import HttpRequest
|
|
7
7
|
from django.db.models import Model, QuerySet
|
|
8
8
|
from django.conf import settings
|
|
9
|
+
from django.core.exceptions import FieldDoesNotExist
|
|
9
10
|
from pydantic import create_model
|
|
10
11
|
|
|
11
12
|
from ninja_aio.schemas.helpers import ModelQuerySetSchema, QuerySchema, DecoratorsSchema
|
|
@@ -16,7 +17,7 @@ from ninja_aio.schemas import (
|
|
|
16
17
|
M2MRelationSchema,
|
|
17
18
|
)
|
|
18
19
|
from ninja_aio.helpers.api import ManyToManyAPI
|
|
19
|
-
from ninja_aio.types import ModelSerializerMeta, VIEW_TYPES
|
|
20
|
+
from ninja_aio.types import ModelSerializerMeta, VIEW_TYPES, VALID_DJANGO_LOOKUPS
|
|
20
21
|
from ninja_aio.decorators import unique_view, decorate_view, aatomic
|
|
21
22
|
from ninja_aio.models import serializers
|
|
22
23
|
|
|
@@ -31,7 +32,7 @@ class API:
|
|
|
31
32
|
auth: list | None = NOT_SET
|
|
32
33
|
router: Router = None
|
|
33
34
|
|
|
34
|
-
def views(self):
|
|
35
|
+
def views(self) -> None:
|
|
35
36
|
"""
|
|
36
37
|
Override this method to add your custom views. For example:
|
|
37
38
|
@self.router.get(some_path, response=some_schema)
|
|
@@ -65,13 +66,15 @@ class API:
|
|
|
65
66
|
"""
|
|
66
67
|
pass
|
|
67
68
|
|
|
68
|
-
def _add_views(self):
|
|
69
|
+
def _add_views(self) -> Router:
|
|
70
|
+
"""Register views decorated with @api_register."""
|
|
69
71
|
for name in dir(self.__class__):
|
|
70
72
|
method = getattr(self.__class__, name)
|
|
71
73
|
if hasattr(method, "_api_register"):
|
|
72
74
|
method._api_register(self)
|
|
75
|
+
return self.router
|
|
73
76
|
|
|
74
|
-
def add_views_to_route(self):
|
|
77
|
+
def add_views_to_route(self) -> Router:
|
|
75
78
|
return self.api.add_router(f"{self.api_route_path}", self._add_views())
|
|
76
79
|
|
|
77
80
|
|
|
@@ -322,6 +325,98 @@ class APIViewSet(API):
|
|
|
322
325
|
filter
|
|
323
326
|
)
|
|
324
327
|
|
|
328
|
+
def _is_lookup_suffix(self, part: str) -> bool:
    """
    Tell whether a path segment names a known Django lookup suffix.

    Args:
        part: A single segment of a filter path

    Returns:
        bool: True when ``part`` is a member of VALID_DJANGO_LOOKUPS
    """
    is_known_lookup = part in VALID_DJANGO_LOOKUPS
    return is_known_lookup
|
|
339
|
+
|
|
340
|
+
def _get_related_model(self, field):
    """
    Return the model a field points to, if any.

    Args:
        field: The Django field object

    Returns:
        The truthy ``field.related_model`` when present; otherwise
        ``field.remote_field.model`` when ``remote_field`` is truthy and has
        a ``model`` attribute; otherwise None
    """
    direct_target = getattr(field, "related_model", None)
    if direct_target:
        return direct_target

    remote = getattr(field, "remote_field", None)
    if remote and hasattr(remote, "model"):
        return remote.model

    return None
|
|
355
|
+
|
|
356
|
+
def _validate_non_relation_field(self, parts: list[str], i: int) -> bool:
    """
    Check what follows a non-relation field inside a filter path.

    Args:
        parts: List of field path parts
        i: Index of the non-relation field within ``parts``

    Returns:
        bool: True when nothing follows position ``i``, or when the next
        part is a valid lookup suffix; False otherwise
    """
    has_successor = i < len(parts) - 1
    if has_successor:
        return self._is_lookup_suffix(parts[i + 1])
    return True
|
|
371
|
+
|
|
372
|
+
def _validate_filter_field(self, field_path: str) -> bool:
    """
    Validate that a filter field path corresponds to valid model fields.

    Security: only paths made of real model fields (followed across
    relations) plus allowlisted lookup suffixes are accepted.

    Args:
        field_path: The field path to validate (e.g., "name", "author__name")

    Returns:
        bool: True if the field path is valid, False otherwise

    Examples:
        "name" -> validates against direct model field
        "author__name" -> validates author is a relation, then name on related model
        "created_at__gte" -> validates created_at field, lookup suffix is allowed
        "created_at__date__gte" -> date transform chained before the final lookup
        "pk" / "author__pk" -> primary-key shortcut of the current model
        "exact" -> rejected: a lookup suffix with no field in front of it
    """
    if not field_path:
        return False

    # Suffixes that may appear *before* the final lookup (e.g. __year__gte).
    transform_suffixes = frozenset(
        {
            "date", "year", "iso_year", "month", "day", "week", "week_day",
            "iso_week_day", "quarter", "time", "hour", "minute", "second",
        }
    )

    parts = field_path.split("__")
    last_index = len(parts) - 1
    current_model = self.model

    for i, part in enumerate(parts):
        try:
            meta = current_model._meta
            # "pk" is a lookup alias, not a field name, so get_field() would
            # reject it; resolve it to the model's primary-key field instead.
            field = meta.pk if part == "pk" else meta.get_field(part)
        except (FieldDoesNotExist, AttributeError):
            # Not a field on this model. The only acceptable case is a single
            # trailing lookup right after a relation (e.g. "author__isnull");
            # a bare lookup with no preceding field is never valid.
            return i > 0 and i == last_index and self._is_lookup_suffix(part)

        related_model = self._get_related_model(field)
        if related_model:
            # Relation: continue validating the next part on the related model.
            current_model = related_model
            continue

        # Concrete field: whatever follows must be allowlisted suffixes, and
        # every suffix except the last one must be a chainable transform.
        suffixes = parts[i + 1:]
        return all(self._is_lookup_suffix(s) for s in suffixes) and all(
            s in transform_suffixes for s in suffixes[:-1]
        )

    return True
|
|
419
|
+
|
|
325
420
|
def _auth_view(self, view_type: str):
|
|
326
421
|
"""
|
|
327
422
|
Resolve auth for a specific HTTP verb; falls back to self.auth if NOT_SET.
|
|
@@ -329,16 +424,20 @@ class APIViewSet(API):
|
|
|
329
424
|
auth = getattr(self, f"{view_type}_auth", None)
|
|
330
425
|
return auth if auth is not NOT_SET else self.auth
|
|
331
426
|
|
|
332
|
-
def get_view_auth(self):
|
|
427
|
+
def get_view_auth(self) -> list | None:
|
|
428
|
+
"""Get authentication configuration for GET endpoints."""
|
|
333
429
|
return self._auth_view("get")
|
|
334
430
|
|
|
335
|
-
def post_view_auth(self):
|
|
431
|
+
def post_view_auth(self) -> list | None:
|
|
432
|
+
"""Get authentication configuration for POST endpoints."""
|
|
336
433
|
return self._auth_view("post")
|
|
337
434
|
|
|
338
|
-
def patch_view_auth(self):
|
|
435
|
+
def patch_view_auth(self) -> list | None:
|
|
436
|
+
"""Get authentication configuration for PATCH endpoints."""
|
|
339
437
|
return self._auth_view("patch")
|
|
340
438
|
|
|
341
|
-
def delete_view_auth(self):
|
|
439
|
+
def delete_view_auth(self) -> list | None:
|
|
440
|
+
"""Get authentication configuration for DELETE endpoints."""
|
|
342
441
|
return self._auth_view("delete")
|
|
343
442
|
|
|
344
443
|
def _generate_schema(self, fields: dict, name: str) -> Schema:
|
|
@@ -347,7 +446,7 @@ class APIViewSet(API):
|
|
|
347
446
|
"""
|
|
348
447
|
return create_model(f"{self.model_util.model_name}{name}", **fields)
|
|
349
448
|
|
|
350
|
-
def _generate_path_schema(self):
|
|
449
|
+
def _generate_path_schema(self) -> Schema:
|
|
351
450
|
"""
|
|
352
451
|
Schema containing only the primary key field for path resolution.
|
|
353
452
|
"""
|
ninja_aio/views/mixins.py
CHANGED
|
@@ -56,7 +56,9 @@ class IcontainsFilterViewSetMixin(APIViewSet):
|
|
|
56
56
|
**{
|
|
57
57
|
f"{key}__icontains": value
|
|
58
58
|
for key, value in filters.items()
|
|
59
|
-
if isinstance(value, str)
|
|
59
|
+
if isinstance(value, str)
|
|
60
|
+
and not self._is_special_filter(key)
|
|
61
|
+
and self._validate_filter_field(key)
|
|
60
62
|
}
|
|
61
63
|
)
|
|
62
64
|
|
|
@@ -98,7 +100,9 @@ class BooleanFilterViewSetMixin(APIViewSet):
|
|
|
98
100
|
**{
|
|
99
101
|
key: value
|
|
100
102
|
for key, value in filters.items()
|
|
101
|
-
if isinstance(value, bool)
|
|
103
|
+
if isinstance(value, bool)
|
|
104
|
+
and not self._is_special_filter(key)
|
|
105
|
+
and self._validate_filter_field(key)
|
|
102
106
|
}
|
|
103
107
|
)
|
|
104
108
|
|
|
@@ -140,7 +144,9 @@ class NumericFilterViewSetMixin(APIViewSet):
|
|
|
140
144
|
**{
|
|
141
145
|
key: value
|
|
142
146
|
for key, value in filters.items()
|
|
143
|
-
if isinstance(value, (int, float))
|
|
147
|
+
if isinstance(value, (int, float))
|
|
148
|
+
and not self._is_special_filter(key)
|
|
149
|
+
and self._validate_filter_field(key)
|
|
144
150
|
}
|
|
145
151
|
)
|
|
146
152
|
|
|
@@ -182,7 +188,9 @@ class DateFilterViewSetMixin(APIViewSet):
|
|
|
182
188
|
**{
|
|
183
189
|
f"{key}{self._compare_attr}": value
|
|
184
190
|
for key, value in filters.items()
|
|
185
|
-
if hasattr(value, "isoformat")
|
|
191
|
+
if hasattr(value, "isoformat")
|
|
192
|
+
and not self._is_special_filter(key)
|
|
193
|
+
and self._validate_filter_field(key)
|
|
186
194
|
}
|
|
187
195
|
)
|
|
188
196
|
|
|
@@ -341,7 +349,8 @@ class RelationFilterViewSetMixin(APIViewSet):
|
|
|
341
349
|
rel_filters = {}
|
|
342
350
|
for rel_filter in self.relations_filters:
|
|
343
351
|
value = filters.get(rel_filter.query_param)
|
|
344
|
-
|
|
352
|
+
# Validate the configured query_filter path for security
|
|
353
|
+
if value is not None and self._validate_filter_field(rel_filter.query_filter):
|
|
345
354
|
rel_filters[rel_filter.query_filter] = value
|
|
346
355
|
return base_qs.filter(**rel_filters) if rel_filters else base_qs
|
|
347
356
|
|
|
@@ -406,8 +415,15 @@ class MatchCaseFilterViewSetMixin(APIViewSet):
|
|
|
406
415
|
filter_match.cases.true if value else filter_match.cases.false
|
|
407
416
|
)
|
|
408
417
|
lookup = case_filter.query_filter
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
418
|
+
# Validate all filter fields in the lookup dictionary for security
|
|
419
|
+
validated_lookup = {
|
|
420
|
+
k: v
|
|
421
|
+
for k, v in lookup.items()
|
|
422
|
+
if self._validate_filter_field(k)
|
|
423
|
+
}
|
|
424
|
+
if validated_lookup:
|
|
425
|
+
if case_filter.include:
|
|
426
|
+
base_qs = base_qs.filter(**validated_lookup)
|
|
427
|
+
else:
|
|
428
|
+
base_qs = base_qs.exclude(**validated_lookup)
|
|
413
429
|
return base_qs
|