rapyer 1.1.4__tar.gz → 1.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {rapyer-1.1.4 → rapyer-1.1.5}/PKG-INFO +6 -1
  2. {rapyer-1.1.4 → rapyer-1.1.5}/pyproject.toml +9 -6
  3. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/base.py +133 -19
  4. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/config.py +6 -1
  5. rapyer-1.1.5/rapyer/errors/__init__.py +17 -0
  6. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/errors/base.py +12 -0
  7. rapyer-1.1.5/rapyer/fields/__init__.py +5 -0
  8. rapyer-1.1.5/rapyer/fields/safe_load.py +27 -0
  9. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/init.py +10 -1
  10. rapyer-1.1.5/rapyer/scripts.py +86 -0
  11. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/base.py +20 -0
  12. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/convert.py +10 -2
  13. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/dct.py +13 -2
  14. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/lst.py +35 -2
  15. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/utils/fields.py +25 -2
  16. rapyer-1.1.4/rapyer/errors/__init__.py +0 -8
  17. rapyer-1.1.4/rapyer/fields/__init__.py +0 -4
  18. {rapyer-1.1.4 → rapyer-1.1.5}/README.md +0 -0
  19. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/__init__.py +0 -0
  20. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/context.py +0 -0
  21. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/fields/expression.py +0 -0
  22. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/fields/index.py +0 -0
  23. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/fields/key.py +0 -0
  24. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/links.py +0 -0
  25. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/__init__.py +0 -0
  26. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/byte.py +0 -0
  27. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/datetime.py +0 -0
  28. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/float.py +0 -0
  29. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/init.py +0 -0
  30. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/integer.py +0 -0
  31. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/types/string.py +0 -0
  32. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/typing_support.py +0 -0
  33. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/utils/__init__.py +0 -0
  34. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/utils/annotation.py +0 -0
  35. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/utils/pythonic.py +0 -0
  36. {rapyer-1.1.4 → rapyer-1.1.5}/rapyer/utils/redis.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rapyer
3
- Version: 1.1.4
3
+ Version: 1.1.5
4
4
  Summary: Pydantic models with Redis as the backend
5
5
  License: MIT
6
6
  Keywords: redis,redis-json,pydantic,pydantic-v2,orm,database,async,nosql,cache,key-value,data-modeling,python,backend,storage,serialization,validation
@@ -23,7 +23,12 @@ Classifier: Topic :: Database :: Database Engines/Servers
23
23
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
24
  Classifier: Typing :: Typed
25
25
  Classifier: Operating System :: OS Independent
26
+ Provides-Extra: test
27
+ Requires-Dist: fakeredis[json,lua] (>=2.20.0) ; extra == "test"
26
28
  Requires-Dist: pydantic (>=2.11.0,<2.13.0)
29
+ Requires-Dist: pytest (>=8.4.2) ; extra == "test"
30
+ Requires-Dist: pytest-asyncio (>=0.25.0) ; extra == "test"
31
+ Requires-Dist: pytest-cov (>=6.0.0) ; extra == "test"
27
32
  Requires-Dist: redis[async] (>=6.0.0,<7.1.0)
28
33
  Project-URL: Bug Tracker, https://github.com/imaginary-cherry/rapyer/issues
29
34
  Project-URL: Changelog, https://github.com/imaginary-cherry/rapyer/releases
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [project]
6
6
  name = "rapyer"
7
- version = "1.1.4"
7
+ version = "1.1.5"
8
8
  description = "Pydantic models with Redis as the backend"
9
9
  authors = [{name = "YedidyaHKfir", email = "yedidyakfir@gmail.com"}]
10
10
  readme = "README.md"
@@ -51,6 +51,14 @@ dependencies = [
51
51
  "pydantic>=2.11.0, <2.13.0",
52
52
  ]
53
53
 
54
+ [project.optional-dependencies]
55
+ test = [
56
+ "pytest>=8.4.2",
57
+ "pytest-asyncio>=0.25.0",
58
+ "pytest-cov>=6.0.0",
59
+ "fakeredis[lua,json]>=2.20.0",
60
+ ]
61
+
54
62
  [project.urls]
55
63
  Homepage = "https://imaginary-cherry.github.io/rapyer/"
56
64
  Documentation = "https://imaginary-cherry.github.io/rapyer/"
@@ -67,11 +75,6 @@ packages = [{include = "rapyer"}]
67
75
  black = "^25.9.0"
68
76
  mypy = "^1.0.0"
69
77
 
70
- [tool.poetry.group.tests.dependencies]
71
- pytest = "^8.4.2"
72
- pytest-asyncio = "^0.25.0"
73
- pytest-cov = "^6.0.0"
74
-
75
78
  [tool.coverage.run]
76
79
  source = ["rapyer"]
77
80
  omit = ["*/tests/*", "*/test_*"]
@@ -2,6 +2,7 @@ import asyncio
2
2
  import base64
3
3
  import contextlib
4
4
  import functools
5
+ import logging
5
6
  import pickle
6
7
  import uuid
7
8
  from contextlib import AbstractAsyncContextManager
@@ -18,16 +19,24 @@ from pydantic import (
18
19
  from pydantic_core.core_schema import FieldSerializationInfo, ValidationInfo
19
20
  from redis.commands.search.index_definition import IndexDefinition, IndexType
20
21
  from redis.commands.search.query import Query
22
+ from redis.exceptions import NoScriptError
21
23
  from typing_extensions import deprecated
22
24
 
23
25
  from rapyer.config import RedisConfig
24
26
  from rapyer.context import _context_var, _context_xx_pipe
25
- from rapyer.errors.base import KeyNotFound, UnsupportedIndexedFieldError
27
+ from rapyer.errors.base import (
28
+ KeyNotFound,
29
+ PersistentNoScriptError,
30
+ UnsupportedIndexedFieldError,
31
+ CantSerializeRedisValueError,
32
+ )
26
33
  from rapyer.fields.expression import ExpressionField, AtomicField, Expression
27
34
  from rapyer.fields.index import IndexAnnotation
28
35
  from rapyer.fields.key import KeyAnnotation
36
+ from rapyer.fields.safe_load import SafeLoadAnnotation
29
37
  from rapyer.links import REDIS_SUPPORTED_LINK
30
- from rapyer.types.base import RedisType, REDIS_DUMP_FLAG_NAME
38
+ from rapyer.scripts import handle_noscript_error
39
+ from rapyer.types.base import RedisType, REDIS_DUMP_FLAG_NAME, FAILED_FIELDS_KEY
31
40
  from rapyer.types.convert import RedisConverter
32
41
  from rapyer.typing_support import Self, Unpack
33
42
  from rapyer.utils.annotation import (
@@ -36,7 +45,11 @@ from rapyer.utils.annotation import (
36
45
  field_with_flag,
37
46
  DYNAMIC_CLASS_DOC,
38
47
  )
39
- from rapyer.utils.fields import get_all_pydantic_annotation, is_redis_field
48
+ from rapyer.utils.fields import (
49
+ get_all_pydantic_annotation,
50
+ is_redis_field,
51
+ is_type_json_serializable,
52
+ )
40
53
  from rapyer.utils.pythonic import safe_issubclass
41
54
  from rapyer.utils.redis import (
42
55
  acquire_lock,
@@ -44,26 +57,45 @@ from rapyer.utils.redis import (
44
57
  refresh_ttl_if_needed,
45
58
  )
46
59
 
60
+ logger = logging.getLogger("rapyer")
47
61
 
48
- def make_pickle_field_serializer(field: str):
62
+
63
+ def make_pickle_field_serializer(
64
+ field: str, safe_load: bool = False, can_json: bool = False
65
+ ):
49
66
  @field_serializer(field, when_used="json-unless-none")
50
- def pickle_field_serializer(v, info: FieldSerializationInfo):
67
+ @classmethod
68
+ def pickle_field_serializer(cls, v, info: FieldSerializationInfo):
51
69
  ctx = info.context or {}
52
70
  should_serialize_redis = ctx.get(REDIS_DUMP_FLAG_NAME, False)
53
- if should_serialize_redis:
71
+ # Skip pickling if field CAN be JSON serialized AND user prefers JSON dump
72
+ field_can_be_json = can_json and cls.Meta.prefer_normal_json_dump
73
+ if should_serialize_redis and not field_can_be_json:
54
74
  return base64.b64encode(pickle.dumps(v)).decode("utf-8")
55
75
  return v
56
76
 
57
77
  pickle_field_serializer.__name__ = f"__serialize_{field}"
58
78
 
59
79
  @field_validator(field, mode="before")
60
- def pickle_field_validator(v, info: ValidationInfo):
80
+ @classmethod
81
+ def pickle_field_validator(cls, v, info: ValidationInfo):
61
82
  if v is None:
62
83
  return v
63
84
  ctx = info.context or {}
64
85
  should_serialize_redis = ctx.get(REDIS_DUMP_FLAG_NAME, False)
65
86
  if should_serialize_redis:
66
- return pickle.loads(base64.b64decode(v))
87
+ try:
88
+ field_can_be_json = can_json and cls.Meta.prefer_normal_json_dump
89
+ if should_serialize_redis and not field_can_be_json:
90
+ return pickle.loads(base64.b64decode(v))
91
+ return v
92
+ except Exception as e:
93
+ if safe_load:
94
+ failed_fields = ctx.setdefault(FAILED_FIELDS_KEY, set())
95
+ failed_fields.add(field)
96
+ logger.warning("SafeLoad: Failed to deserialize field '%s'", field)
97
+ return None
98
+ raise CantSerializeRedisValueError() from e
67
99
  return v
68
100
 
69
101
  pickle_field_validator.__name__ = f"__deserialize_{field}"
@@ -71,15 +103,40 @@ def make_pickle_field_serializer(field: str):
71
103
  return pickle_field_serializer, pickle_field_validator
72
104
 
73
105
 
106
+ # TODO: Remove in next major version (2.0) - backward compatibility for pickled data
107
+ # This validator handles loading old pickled data for fields that are now JSON-serializable.
108
+ # In 2.0, remove this function and the validator registration in __init_subclass__.
109
+ def make_backward_compat_validator(field: str):
110
+ @field_validator(field, mode="before")
111
+ def backward_compat_validator(v, info: ValidationInfo):
112
+ ctx = info.context or {}
113
+ should_deserialize_redis = ctx.get(REDIS_DUMP_FLAG_NAME, False)
114
+ if should_deserialize_redis and isinstance(v, str):
115
+ try:
116
+ return pickle.loads(base64.b64decode(v))
117
+ except Exception:
118
+ pass
119
+ return v
120
+
121
+ backward_compat_validator.__name__ = f"__backward_compat_{field}"
122
+ return backward_compat_validator
123
+
124
+
74
125
  class AtomicRedisModel(BaseModel):
75
126
  _pk: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
76
127
  _base_model_link: Self | RedisType = PrivateAttr(default=None)
128
+ _failed_fields: set[str] = PrivateAttr(default_factory=set)
77
129
 
78
130
  Meta: ClassVar[RedisConfig] = RedisConfig()
79
131
  _key_field_name: ClassVar[str | None] = None
132
+ _safe_load_fields: ClassVar[set[str]] = set()
80
133
  _field_name: str = PrivateAttr(default="")
81
134
  model_config = ConfigDict(validate_assignment=True, validate_default=True)
82
135
 
136
+ @property
137
+ def failed_fields(self) -> set[str]:
138
+ return self._failed_fields
139
+
83
140
  @property
84
141
  def pk(self):
85
142
  if self._key_field_name:
@@ -184,11 +241,13 @@ class AtomicRedisModel(BaseModel):
184
241
  self._pk = value.split(":", maxsplit=1)[-1]
185
242
 
186
243
  def __init_subclass__(cls, **kwargs):
187
- # Find a field with KeyAnnotation and save its name
244
+ # Find fields with KeyAnnotation and SafeLoadAnnotation
245
+ cls._safe_load_fields = set()
188
246
  for field_name, annotation in cls.__annotations__.items():
189
247
  if has_annotation(annotation, KeyAnnotation):
190
248
  cls._key_field_name = field_name
191
- break
249
+ if has_annotation(annotation, SafeLoadAnnotation):
250
+ cls._safe_load_fields.add(field_name)
192
251
 
193
252
  # Redefine annotations to use redis types
194
253
  pydantic_annotation = get_all_pydantic_annotation(cls, AtomicRedisModel)
@@ -200,7 +259,13 @@ class AtomicRedisModel(BaseModel):
200
259
  original_annotations.update(new_annotation)
201
260
  new_annotations = {
202
261
  field_name: replace_to_redis_types_in_annotation(
203
- annotation, RedisConverter(cls.Meta.redis_type, f".{field_name}")
262
+ annotation,
263
+ RedisConverter(
264
+ cls.Meta.redis_type,
265
+ f".{field_name}",
266
+ safe_load=field_name in cls._safe_load_fields
267
+ or cls.Meta.safe_load_all,
268
+ ),
204
269
  )
205
270
  for field_name, annotation in original_annotations.items()
206
271
  if is_redis_field(field_name, annotation)
@@ -216,9 +281,22 @@ class AtomicRedisModel(BaseModel):
216
281
  if not is_redis_field(attr_name, attr_type):
217
282
  continue
218
283
  if original_annotations[attr_name] == attr_type:
219
- serializer, validator = make_pickle_field_serializer(attr_name)
220
- setattr(cls, serializer.__name__, serializer)
221
- setattr(cls, validator.__name__, validator)
284
+ default_value = cls.__dict__.get(attr_name, None)
285
+ can_json = is_type_json_serializable(attr_type, default_value)
286
+ should_json_serialize = can_json and cls.Meta.prefer_normal_json_dump
287
+
288
+ if not should_json_serialize:
289
+ is_field_marked_safe = attr_name in cls._safe_load_fields
290
+ is_safe_load = is_field_marked_safe or cls.Meta.safe_load_all
291
+ serializer, validator = make_pickle_field_serializer(
292
+ attr_name, safe_load=is_safe_load, can_json=can_json
293
+ )
294
+ setattr(cls, serializer.__name__, serializer)
295
+ setattr(cls, validator.__name__, validator)
296
+ else:
297
+ # TODO: Remove in 2.0 - backward compatibility for old pickled data
298
+ validator = make_backward_compat_validator(attr_name)
299
+ setattr(cls, validator.__name__, validator)
222
300
  continue
223
301
 
224
302
  # Update the redis model list for initialization
@@ -337,8 +415,10 @@ class AtomicRedisModel(BaseModel):
337
415
  raise KeyNotFound(f"{key} is missing in redis")
338
416
  model_dump = model_dump[0]
339
417
 
340
- instance = cls.model_validate(model_dump, context={REDIS_DUMP_FLAG_NAME: True})
418
+ context = {REDIS_DUMP_FLAG_NAME: True, FAILED_FIELDS_KEY: set()}
419
+ instance = cls.model_validate(model_dump, context=context)
341
420
  instance.key = key
421
+ instance._failed_fields = context.get(FAILED_FIELDS_KEY, set())
342
422
  await refresh_ttl_if_needed(
343
423
  cls.Meta.redis, key, cls.Meta.ttl, cls.Meta.refresh_ttl
344
424
  )
@@ -355,9 +435,11 @@ class AtomicRedisModel(BaseModel):
355
435
  if not model_dump:
356
436
  raise KeyNotFound(f"{self.key} is missing in redis")
357
437
  model_dump = model_dump[0]
358
- instance = self.__class__(**model_dump)
438
+ context = {REDIS_DUMP_FLAG_NAME: True, FAILED_FIELDS_KEY: set()}
439
+ instance = self.__class__.model_validate(model_dump, context=context)
359
440
  instance._pk = self._pk
360
441
  instance._base_model_link = self._base_model_link
442
+ instance._failed_fields = context.get(FAILED_FIELDS_KEY, set())
361
443
  await refresh_ttl_if_needed(
362
444
  self.Meta.redis, self.key, self.Meta.ttl, self.Meta.refresh_ttl
363
445
  )
@@ -402,8 +484,10 @@ class AtomicRedisModel(BaseModel):
402
484
 
403
485
  instances = []
404
486
  for model, key in zip(models, keys):
405
- model = cls.model_validate(model[0], context={REDIS_DUMP_FLAG_NAME: True})
487
+ context = {REDIS_DUMP_FLAG_NAME: True, FAILED_FIELDS_KEY: set()}
488
+ model = cls.model_validate(model[0], context=context)
406
489
  model.key = key
490
+ model._failed_fields = context.get(FAILED_FIELDS_KEY, set())
407
491
  instances.append(model)
408
492
  return instances
409
493
 
@@ -512,7 +596,7 @@ class AtomicRedisModel(BaseModel):
512
596
  async def apipeline(
513
597
  self, ignore_if_deleted: bool = False
514
598
  ) -> AbstractAsyncContextManager[Self]:
515
- async with self.Meta.redis.pipeline() as pipe:
599
+ async with self.Meta.redis.pipeline(transaction=True) as pipe:
516
600
  try:
517
601
  redis_model = await self.__class__.aget(self.key)
518
602
  unset_fields = {
@@ -527,7 +611,37 @@ class AtomicRedisModel(BaseModel):
527
611
  _context_var.set(pipe)
528
612
  _context_xx_pipe.set(ignore_if_deleted)
529
613
  yield redis_model
530
- await pipe.execute()
614
+ commands_backup = list(pipe.command_stack)
615
+ noscript_on_first_attempt = False
616
+ noscript_on_retry = False
617
+
618
+ try:
619
+ await pipe.execute()
620
+ except NoScriptError:
621
+ noscript_on_first_attempt = True
622
+
623
+ if noscript_on_first_attempt:
624
+ await handle_noscript_error(self.Meta.redis)
625
+ evalsha_commands = [
626
+ (args, options)
627
+ for args, options in commands_backup
628
+ if args[0] == "EVALSHA"
629
+ ]
630
+ # Retry execute the pipeline actions
631
+ async with self.Meta.redis.pipeline(transaction=True) as retry_pipe:
632
+ for args, options in evalsha_commands:
633
+ retry_pipe.execute_command(*args, **options)
634
+ try:
635
+ await retry_pipe.execute()
636
+ except NoScriptError:
637
+ noscript_on_retry = True
638
+
639
+ if noscript_on_retry:
640
+ raise PersistentNoScriptError(
641
+ "NOSCRIPT error persisted after re-registering scripts. "
642
+ "This indicates a server-side problem with Redis."
643
+ )
644
+
531
645
  await refresh_ttl_if_needed(
532
646
  self.Meta.redis, self.key, self.Meta.ttl, self.Meta.refresh_ttl
533
647
  )
@@ -22,4 +22,9 @@ class RedisConfig:
22
22
  redis_type: dict[type, type] = dataclasses.field(default_factory=create_all_types)
23
23
  ttl: int | None = None
24
24
  init_with_rapyer: bool = True
25
- refresh_ttl: bool = True # Enable TTL refresh on read/write operations by default
25
+ # Enable TTL refresh on read/write operations by default
26
+ refresh_ttl: bool = True
27
+ # If True, all non-Redis-supported fields are treated as SafeLoad
28
+ safe_load_all: bool = False
29
+ # If True, use JSON serialization for fields that support it instead of pickle
30
+ prefer_normal_json_dump: bool = False
@@ -0,0 +1,17 @@
1
+ from rapyer.errors.base import (
2
+ BadFilterError,
3
+ FindError,
4
+ PersistentNoScriptError,
5
+ RapyerError,
6
+ ScriptsNotInitializedError,
7
+ UnsupportedIndexedFieldError,
8
+ )
9
+
10
+ __all__ = [
11
+ "BadFilterError",
12
+ "FindError",
13
+ "PersistentNoScriptError",
14
+ "RapyerError",
15
+ "ScriptsNotInitializedError",
16
+ "UnsupportedIndexedFieldError",
17
+ ]
@@ -24,3 +24,15 @@ class BadFilterError(FindError):
24
24
 
25
25
  class UnsupportedIndexedFieldError(FindError):
26
26
  pass
27
+
28
+
29
+ class CantSerializeRedisValueError(RapyerError):
30
+ pass
31
+
32
+
33
+ class ScriptsNotInitializedError(RapyerError):
34
+ pass
35
+
36
+
37
+ class PersistentNoScriptError(RapyerError):
38
+ pass
@@ -0,0 +1,5 @@
1
+ from rapyer.fields.index import Index
2
+ from rapyer.fields.key import Key
3
+ from rapyer.fields.safe_load import SafeLoad
4
+
5
+ __all__ = ["Key", "Index", "SafeLoad"]
@@ -0,0 +1,27 @@
1
+ import dataclasses
2
+ from typing import TYPE_CHECKING, Annotated, Any, Generic, TypeAlias, TypeVar
3
+
4
+
5
+ @dataclasses.dataclass(frozen=True)
6
+ class SafeLoadAnnotation:
7
+ pass
8
+
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
+ class _SafeLoadType(Generic[T]):
14
+ def __new__(cls, typ: Any = None):
15
+ if typ is None:
16
+ return SafeLoadAnnotation()
17
+ return Annotated[typ, SafeLoadAnnotation()]
18
+
19
+ def __class_getitem__(cls, item):
20
+ return Annotated[item, SafeLoadAnnotation()]
21
+
22
+
23
+ SafeLoad = _SafeLoadType
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ SafeLoad: TypeAlias = Annotated[T, SafeLoadAnnotation()] # pragma: no cover
@@ -3,19 +3,28 @@ from redis import ResponseError
3
3
  from redis.asyncio.client import Redis
4
4
 
5
5
  from rapyer.base import REDIS_MODELS
6
+ from rapyer.scripts import register_scripts
6
7
 
7
8
 
8
9
  async def init_rapyer(
9
- redis: str | Redis = None, ttl: int = None, override_old_idx: bool = True
10
+ redis: str | Redis = None,
11
+ ttl: int = None,
12
+ override_old_idx: bool = True,
13
+ prefer_normal_json_dump: bool = None,
10
14
  ):
11
15
  if isinstance(redis, str):
12
16
  redis = redis_async.from_url(redis, decode_responses=True, max_connections=20)
13
17
 
18
+ if redis is not None:
19
+ await register_scripts(redis)
20
+
14
21
  for model in REDIS_MODELS:
15
22
  if redis is not None:
16
23
  model.Meta.redis = redis
17
24
  if ttl is not None:
18
25
  model.Meta.ttl = ttl
26
+ if prefer_normal_json_dump is not None:
27
+ model.Meta.prefer_normal_json_dump = prefer_normal_json_dump
19
28
 
20
29
  # Initialize model fields
21
30
  model.init_class()
@@ -0,0 +1,86 @@
1
+ from rapyer.errors import ScriptsNotInitializedError
2
+
3
+ REMOVE_RANGE_SCRIPT_NAME = "remove_range"
4
+
5
+ _REMOVE_RANGE_SCRIPT_TEMPLATE = """
6
+ local key = KEYS[1]
7
+ local path = ARGV[1]
8
+ local start_idx = tonumber(ARGV[2])
9
+ local end_idx = tonumber(ARGV[3])
10
+
11
+ local arr_json = redis.call('JSON.GET', key, path)
12
+ if not arr_json or arr_json == 'null' then
13
+ return nil
14
+ end
15
+
16
+ {extract_array}
17
+ local n = #arr
18
+
19
+ if start_idx < 0 then start_idx = n + start_idx end
20
+ if end_idx < 0 then end_idx = n + end_idx end
21
+ if start_idx < 0 then start_idx = 0 end
22
+ if end_idx < 0 then end_idx = 0 end
23
+ if end_idx > n then end_idx = n end
24
+ if start_idx >= n or start_idx >= end_idx then return true end
25
+
26
+ local new_arr = {{}}
27
+ local j = 1
28
+
29
+ for i = 1, start_idx do
30
+ new_arr[j] = arr[i]
31
+ j = j + 1
32
+ end
33
+
34
+ for i = end_idx + 1, n do
35
+ new_arr[j] = arr[i]
36
+ j = j + 1
37
+ end
38
+
39
+ local encoded = j == 1 and '[]' or cjson.encode(new_arr)
40
+ redis.call('JSON.SET', key, path, encoded)
41
+ return true
42
+ """
43
+
44
+ _EXTRACT_ARRAY_REDIS = "local arr = cjson.decode(arr_json)[1]"
45
+ _EXTRACT_ARRAY_FAKEREDIS = "local arr = cjson.decode(arr_json)"
46
+
47
+ REMOVE_RANGE_SCRIPT = _REMOVE_RANGE_SCRIPT_TEMPLATE.format(
48
+ extract_array=_EXTRACT_ARRAY_REDIS
49
+ )
50
+ REMOVE_RANGE_SCRIPT_FAKEREDIS = _REMOVE_RANGE_SCRIPT_TEMPLATE.format(
51
+ extract_array=_EXTRACT_ARRAY_FAKEREDIS
52
+ )
53
+
54
+ SCRIPTS: dict[str, str] = {
55
+ REMOVE_RANGE_SCRIPT_NAME: REMOVE_RANGE_SCRIPT,
56
+ }
57
+
58
+ SCRIPTS_FAKEREDIS: dict[str, str] = {
59
+ REMOVE_RANGE_SCRIPT_NAME: REMOVE_RANGE_SCRIPT_FAKEREDIS,
60
+ }
61
+
62
+ _REGISTERED_SCRIPT_SHAS: dict[str, str] = {}
63
+
64
+
65
+ def is_fakeredis(client) -> bool:
66
+ return "fakeredis" in type(client).__module__
67
+
68
+
69
+ async def register_scripts(redis_client):
70
+ scripts = SCRIPTS_FAKEREDIS if is_fakeredis(redis_client) else SCRIPTS
71
+ for name, script_text in scripts.items():
72
+ sha = await redis_client.script_load(script_text)
73
+ _REGISTERED_SCRIPT_SHAS[name] = sha
74
+
75
+
76
+ def run_sha(pipeline, script_name: str, keys: int, *args):
77
+ sha = _REGISTERED_SCRIPT_SHAS.get(script_name)
78
+ if sha is None:
79
+ raise ScriptsNotInitializedError(
80
+ f"Script '{script_name}' not loaded. Did you forget to call init_rapyer()?"
81
+ )
82
+ pipeline.evalsha(sha, keys, *args)
83
+
84
+
85
+ async def handle_noscript_error(redis_client):
86
+ await register_scripts(redis_client)
@@ -1,5 +1,6 @@
1
1
  import abc
2
2
  import base64
3
+ import logging
3
4
  import pickle
4
5
  from abc import ABC
5
6
  from typing import get_args, Any, TypeVar, Generic
@@ -11,10 +12,15 @@ from redis.commands.search.field import TextField
11
12
  from typing_extensions import deprecated
12
13
 
13
14
  from rapyer.context import _context_var
15
+ from rapyer.errors.base import CantSerializeRedisValueError
14
16
  from rapyer.typing_support import Self
15
17
  from rapyer.utils.redis import refresh_ttl_if_needed
16
18
 
19
+ logger = logging.getLogger("rapyer")
20
+
17
21
  REDIS_DUMP_FLAG_NAME = "__rapyer_dumped__"
22
+ FAILED_FIELDS_KEY = "__rapyer_failed_fields__"
23
+ SKIP_SENTINEL = object()
18
24
 
19
25
 
20
26
  class RedisType(ABC):
@@ -129,6 +135,8 @@ T = TypeVar("T")
129
135
 
130
136
 
131
137
  class GenericRedisType(RedisType, Generic[T], ABC):
138
+ safe_load: bool = False
139
+
132
140
  def __init__(self, *args, **kwargs):
133
141
  super().__init__(*args, **kwargs)
134
142
  for key, val in self.iterate_items():
@@ -139,6 +147,18 @@ class GenericRedisType(RedisType, Generic[T], ABC):
139
147
  args = get_args(type_)
140
148
  return args[0] if args else Any
141
149
 
150
+ @classmethod
151
+ def try_deserialize_item(cls, item, identifier):
152
+ try:
153
+ return cls.deserialize_unknown(item)
154
+ except Exception as e:
155
+ if cls.safe_load:
156
+ logger.warning(
157
+ "SafeLoad: Failed to deserialize item at '%s'.", identifier
158
+ )
159
+ return SKIP_SENTINEL
160
+ raise CantSerializeRedisValueError() from e
161
+
142
162
  @abc.abstractmethod
143
163
  def iterate_items(self):
144
164
  pass # pragma: no cover
@@ -1,4 +1,4 @@
1
- from typing import Any, get_origin
1
+ from typing import get_origin
2
2
 
3
3
  from pydantic import BaseModel, PrivateAttr, TypeAdapter
4
4
 
@@ -8,9 +8,15 @@ from rapyer.utils.pythonic import safe_issubclass
8
8
 
9
9
 
10
10
  class RedisConverter(TypeConverter):
11
- def __init__(self, supported_types: dict[type, type], field_name: str):
11
+ def __init__(
12
+ self,
13
+ supported_types: dict[type, type],
14
+ field_name: str,
15
+ safe_load: bool = False,
16
+ ):
12
17
  self.supported_types = supported_types
13
18
  self.field_name = field_name
19
+ self.safe_load = safe_load
14
20
 
15
21
  def is_redis_type(self, type_to_check: type) -> bool:
16
22
  origin = get_origin(type_to_check) or type_to_check
@@ -62,6 +68,7 @@ class RedisConverter(TypeConverter):
62
68
  dict(
63
69
  field_name=self.field_name,
64
70
  original_type=original_type,
71
+ safe_load=self.safe_load,
65
72
  __doc__=DYNAMIC_CLASS_DOC,
66
73
  ),
67
74
  )
@@ -86,6 +93,7 @@ class RedisConverter(TypeConverter):
86
93
  dict(
87
94
  field_name=self.field_name,
88
95
  original_type=original_type,
96
+ safe_load=self.safe_load,
89
97
  __doc__=DYNAMIC_CLASS_DOC,
90
98
  ),
91
99
  )
@@ -2,7 +2,12 @@ from typing import TypeVar, Generic, get_args, Any, TypeAlias, TYPE_CHECKING
2
2
 
3
3
  from pydantic_core import core_schema
4
4
 
5
- from rapyer.types.base import GenericRedisType, RedisType, REDIS_DUMP_FLAG_NAME
5
+ from rapyer.types.base import (
6
+ GenericRedisType,
7
+ RedisType,
8
+ REDIS_DUMP_FLAG_NAME,
9
+ SKIP_SENTINEL,
10
+ )
6
11
  from rapyer.utils.redis import refresh_ttl_if_needed
7
12
  from rapyer.utils.redis import update_keys_in_pipeline
8
13
 
@@ -239,9 +244,15 @@ class RedisDict(dict[str, T], GenericRedisType, Generic[T]):
239
244
  def full_deserializer(cls, value: dict, info: core_schema.ValidationInfo):
240
245
  ctx = info.context or {}
241
246
  should_serialize_redis = ctx.get(REDIS_DUMP_FLAG_NAME)
247
+
248
+ if not should_serialize_redis:
249
+ return value
250
+
242
251
  return {
243
- key: cls.deserialize_unknown(item) if should_serialize_redis else item
252
+ key: deserialized
244
253
  for key, item in value.items()
254
+ if (deserialized := cls.try_deserialize_item(item, f"key '{key}'"))
255
+ is not SKIP_SENTINEL
245
256
  }
246
257
 
247
258
  @classmethod
@@ -1,13 +1,22 @@
1
1
  import json
2
+ import logging
2
3
  from typing import TypeVar, TYPE_CHECKING
3
4
 
4
5
  from pydantic_core import core_schema
5
6
  from pydantic_core.core_schema import ValidationInfo, SerializationInfo
6
7
  from typing_extensions import TypeAlias
7
8
 
8
- from rapyer.types.base import GenericRedisType, RedisType, REDIS_DUMP_FLAG_NAME
9
+ from rapyer.scripts import run_sha, REMOVE_RANGE_SCRIPT_NAME
10
+ from rapyer.types.base import (
11
+ GenericRedisType,
12
+ RedisType,
13
+ REDIS_DUMP_FLAG_NAME,
14
+ SKIP_SENTINEL,
15
+ )
9
16
  from rapyer.utils.redis import refresh_ttl_if_needed
10
17
 
18
+ logger = logging.getLogger("rapyer")
19
+
11
20
  T = TypeVar("T")
12
21
 
13
22
 
@@ -66,6 +75,24 @@ class RedisList(list, GenericRedisType[T]):
66
75
  self.pipeline.json().set(self.key, self.json_path, [])
67
76
  return super().clear()
68
77
 
78
+ def remove_range(self, start: int, end: int):
79
+ if self.pipeline:
80
+ run_sha(
81
+ self.pipeline,
82
+ REMOVE_RANGE_SCRIPT_NAME,
83
+ 1,
84
+ self.key,
85
+ self.json_path,
86
+ start,
87
+ end,
88
+ )
89
+ del self[start:end]
90
+ else:
91
+ logger.warning(
92
+ "remove_range() called without a pipeline context. "
93
+ "No changes were made. Use 'async with model.apipeline():' to execute."
94
+ )
95
+
69
96
  async def aappend(self, __object):
70
97
  self.append(__object)
71
98
 
@@ -162,8 +189,14 @@ class RedisList(list, GenericRedisType[T]):
162
189
  ctx = info.context or {}
163
190
  is_redis_data = ctx.get(REDIS_DUMP_FLAG_NAME)
164
191
 
192
+ if not is_redis_data:
193
+ return value
194
+
165
195
  return [
166
- cls.deserialize_unknown(item) if is_redis_data else item for item in value
196
+ deserialized
197
+ for idx, item in enumerate(value)
198
+ if (deserialized := cls.try_deserialize_item(item, f"index {idx}"))
199
+ is not SKIP_SENTINEL
167
200
  ]
168
201
 
169
202
  @classmethod
@@ -1,7 +1,8 @@
1
- from typing import get_origin, ClassVar
1
+ from typing import get_origin, ClassVar, Any
2
2
 
3
- from pydantic import BaseModel
3
+ from pydantic import BaseModel, TypeAdapter
4
4
  from pydantic.fields import FieldInfo
5
+ from pydantic_core import PydanticUndefined
5
6
 
6
7
 
7
8
  def _collect_annotations_recursive(
@@ -62,3 +63,25 @@ def is_redis_field(field_name, field_annotation):
62
63
  or field_name.endswith("_")
63
64
  or get_origin(field_annotation) is ClassVar
64
65
  )
66
+
67
+
68
+ def is_field_default_has_value(field_default):
69
+ return field_default is not PydanticUndefined and field_default is not None
70
+
71
+
72
+ def is_type_json_serializable(typ: type, test_value: Any) -> bool:
73
+ try:
74
+ adapter = TypeAdapter(typ)
75
+ if isinstance(test_value, FieldInfo):
76
+ if is_field_default_has_value(test_value.default):
77
+ test_value = test_value.default
78
+ elif is_field_default_has_value(test_value.default_factory):
79
+ test_value = test_value.default_factory()
80
+ else:
81
+ return False
82
+ if test_value is None:
83
+ return False
84
+ adapter.dump_python(test_value, mode="json")
85
+ return True
86
+ except Exception:
87
+ return False
@@ -1,8 +0,0 @@
1
- from rapyer.errors.base import (
2
- BadFilterError,
3
- FindError,
4
- RapyerError,
5
- UnsupportedIndexedFieldError,
6
- )
7
-
8
- __all__ = ["BadFilterError", "FindError", "RapyerError", "UnsupportedIndexedFieldError"]
@@ -1,4 +0,0 @@
1
- from rapyer.fields.index import Index
2
- from rapyer.fields.key import Key
3
-
4
- __all__ = ["Key", "Index"]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes