coredis-5.5.0-cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. 22fe76227e35f92ab5c3__mypyc.cpython-313-darwin.so +0 -0
  2. coredis/__init__.py +42 -0
  3. coredis/_enum.py +42 -0
  4. coredis/_json.py +11 -0
  5. coredis/_packer.cpython-313-darwin.so +0 -0
  6. coredis/_packer.py +71 -0
  7. coredis/_protocols.py +50 -0
  8. coredis/_py_311_typing.py +20 -0
  9. coredis/_py_312_typing.py +17 -0
  10. coredis/_sidecar.py +114 -0
  11. coredis/_utils.cpython-313-darwin.so +0 -0
  12. coredis/_utils.py +440 -0
  13. coredis/_version.py +34 -0
  14. coredis/_version.pyi +1 -0
  15. coredis/cache.py +801 -0
  16. coredis/client/__init__.py +6 -0
  17. coredis/client/basic.py +1240 -0
  18. coredis/client/cluster.py +1265 -0
  19. coredis/commands/__init__.py +64 -0
  20. coredis/commands/_key_spec.py +517 -0
  21. coredis/commands/_utils.py +108 -0
  22. coredis/commands/_validators.py +159 -0
  23. coredis/commands/_wrappers.py +175 -0
  24. coredis/commands/bitfield.py +110 -0
  25. coredis/commands/constants.py +662 -0
  26. coredis/commands/core.py +8484 -0
  27. coredis/commands/function.py +408 -0
  28. coredis/commands/monitor.py +168 -0
  29. coredis/commands/pubsub.py +905 -0
  30. coredis/commands/request.py +108 -0
  31. coredis/commands/script.py +296 -0
  32. coredis/commands/sentinel.py +246 -0
  33. coredis/config.py +50 -0
  34. coredis/connection.py +906 -0
  35. coredis/constants.cpython-313-darwin.so +0 -0
  36. coredis/constants.py +37 -0
  37. coredis/credentials.py +45 -0
  38. coredis/exceptions.py +360 -0
  39. coredis/experimental/__init__.py +1 -0
  40. coredis/globals.py +23 -0
  41. coredis/modules/__init__.py +121 -0
  42. coredis/modules/autocomplete.py +138 -0
  43. coredis/modules/base.py +262 -0
  44. coredis/modules/filters.py +1319 -0
  45. coredis/modules/graph.py +362 -0
  46. coredis/modules/json.py +691 -0
  47. coredis/modules/response/__init__.py +0 -0
  48. coredis/modules/response/_callbacks/__init__.py +0 -0
  49. coredis/modules/response/_callbacks/autocomplete.py +42 -0
  50. coredis/modules/response/_callbacks/graph.py +237 -0
  51. coredis/modules/response/_callbacks/json.py +21 -0
  52. coredis/modules/response/_callbacks/search.py +221 -0
  53. coredis/modules/response/_callbacks/timeseries.py +158 -0
  54. coredis/modules/response/types.py +179 -0
  55. coredis/modules/search.py +1089 -0
  56. coredis/modules/timeseries.py +1139 -0
  57. coredis/parser.cpython-313-darwin.so +0 -0
  58. coredis/parser.py +344 -0
  59. coredis/pipeline.py +1225 -0
  60. coredis/pool/__init__.py +11 -0
  61. coredis/pool/basic.py +453 -0
  62. coredis/pool/cluster.py +517 -0
  63. coredis/pool/nodemanager.py +340 -0
  64. coredis/py.typed +0 -0
  65. coredis/recipes/__init__.py +0 -0
  66. coredis/recipes/credentials/__init__.py +5 -0
  67. coredis/recipes/credentials/iam_provider.py +63 -0
  68. coredis/recipes/locks/__init__.py +5 -0
  69. coredis/recipes/locks/extend.lua +17 -0
  70. coredis/recipes/locks/lua_lock.py +281 -0
  71. coredis/recipes/locks/release.lua +10 -0
  72. coredis/response/__init__.py +5 -0
  73. coredis/response/_callbacks/__init__.py +538 -0
  74. coredis/response/_callbacks/acl.py +32 -0
  75. coredis/response/_callbacks/cluster.py +183 -0
  76. coredis/response/_callbacks/command.py +86 -0
  77. coredis/response/_callbacks/connection.py +31 -0
  78. coredis/response/_callbacks/geo.py +58 -0
  79. coredis/response/_callbacks/hash.py +85 -0
  80. coredis/response/_callbacks/keys.py +59 -0
  81. coredis/response/_callbacks/module.py +33 -0
  82. coredis/response/_callbacks/script.py +85 -0
  83. coredis/response/_callbacks/sentinel.py +179 -0
  84. coredis/response/_callbacks/server.py +241 -0
  85. coredis/response/_callbacks/sets.py +44 -0
  86. coredis/response/_callbacks/sorted_set.py +204 -0
  87. coredis/response/_callbacks/streams.py +185 -0
  88. coredis/response/_callbacks/strings.py +70 -0
  89. coredis/response/_callbacks/vector_sets.py +159 -0
  90. coredis/response/_utils.py +33 -0
  91. coredis/response/types.py +416 -0
  92. coredis/retry.py +233 -0
  93. coredis/sentinel.py +477 -0
  94. coredis/stream.py +369 -0
  95. coredis/tokens.py +2286 -0
  96. coredis/typing.py +593 -0
  97. coredis-5.5.0.dist-info/METADATA +211 -0
  98. coredis-5.5.0.dist-info/RECORD +100 -0
  99. coredis-5.5.0.dist-info/WHEEL +6 -0
  100. coredis-5.5.0.dist-info/licenses/LICENSE +23 -0
coredis/response/types.py ADDED
@@ -0,0 +1,416 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import datetime
5
+ import re
6
+ import shlex
7
+ from collections.abc import Set
8
+ from re import Pattern
9
+
10
+ from coredis.typing import (
11
+ ClassVar,
12
+ Literal,
13
+ Mapping,
14
+ NamedTuple,
15
+ OrderedDict,
16
+ RedisValueT,
17
+ StringT,
18
+ TypedDict,
19
+ )
20
+
21
#: Response from `CLIENT INFO <https://redis.io/commands/client-info>`__
#:
#: - ``id``: a unique 64-bit client ID
#: - ``addr``: address/port of the client
#: - ``laddr``: address/port of the local address the client connected to (bind address)
#: - ``fd``: file descriptor corresponding to the socket
#: - ``name``: the name set by the client with CLIENT SETNAME
#: - ``age``: total duration of the connection in seconds
#: - ``idle``: idle time of the connection in seconds
#: - ``flags``: client flags
#: - ``db``: current database ID
#: - ``sub``: number of channel subscriptions
#: - ``psub``: number of pattern matching subscriptions
#: - ``multi``: number of commands in a MULTI/EXEC context
#: - ``qbuf``: query buffer length (0 means no query pending)
#: - ``qbuf-free``: free space of the query buffer (0 means the buffer is full)
#: - ``argv-mem``: incomplete arguments for the next command (already extracted from query buffer)
#: - ``multi-mem``: memory used by buffered MULTI commands. Added in Redis 7.0
#: - ``obl``: output buffer length
#: - ``oll``: output list length (replies are queued in this list when the buffer is full)
#: - ``omem``: output buffer memory usage
#: - ``tot-mem``: total memory consumed by this client in its various buffers
#: - ``events``: file descriptor events
#: - ``cmd``: last command played
#: - ``user``: the authenticated username of the client
#: - ``redir``: client id of current client tracking redirection
#: - ``resp``: client RESP protocol version. Added in Redis 7.0
#:
#: The functional ``TypedDict`` syntax is required because several keys
#: (e.g. ``qbuf-free``) contain hyphens and are not valid identifiers.
# NOTE(review): ``resp`` is annotated as ``str`` although RESP versions are
# numeric -- confirm against the CLIENT INFO response callback.
ClientInfo = TypedDict(
    "ClientInfo",
    {
        "id": int,
        "addr": str,
        "laddr": str,
        "fd": int,
        "name": str,
        "age": int,
        "idle": int,
        "flags": str,
        "db": int,
        "sub": int,
        "psub": int,
        "multi": int,
        "qbuf": int,
        "qbuf-free": int,
        "argv-mem": int,
        "multi-mem": int,
        "obl": int,
        "oll": int,
        "omem": int,
        "tot-mem": int,
        "events": str,
        "cmd": str,
        "user": str,
        "redir": int,
        "resp": str,
    },
)
78
+
79
#: Script/Function flags
#: See: `<https://redis.io/topics/lua-api#a-namescriptflagsa-script-flags>`__
#:
#: Every flag is listed both as ``str`` and as ``bytes`` (presumably so the
#: alias covers both decoded and raw responses).
ScriptFlag = Literal[
    "no-writes",
    "allow-oom",
    "allow-stale",
    "no-cluster",
    b"no-writes",
    b"allow-oom",
    b"allow-stale",
    b"no-cluster",
]
91
+
92
+
93
class FunctionDefinition(TypedDict):
    """
    Function definition as returned by `FUNCTION LIST <https://redis.io/commands/function-list>`__
    """

    #: the name of the function
    name: StringT
    #: the description of the function
    description: StringT
    #: function flags (each one a :data:`ScriptFlag`)
    flags: set[ScriptFlag]
104
+
105
+
106
class LibraryDefinition(TypedDict):
    """
    Library definition as returned by `FUNCTION LIST <https://redis.io/commands/function-list>`__
    """

    #: the name of the library
    name: StringT
    #: the engine used by the library (only ``"LUA"`` is modelled by this type)
    engine: Literal["LUA"]
    #: Mapping of function names to functions in the library
    functions: dict[StringT, FunctionDefinition]
    #: The library's source code, or ``None`` when it is not part of the response
    library_code: StringT | None
119
+
120
+
121
class ScoredMember(NamedTuple):
    """
    Member of a sorted set together with its score
    """

    #: The sorted set member name
    member: StringT
    #: Score of the member
    score: float
130
+
131
+
132
class GeoCoordinates(NamedTuple):
    """
    A longitude/latitude pair identifying a location

    Note the field order: longitude comes first, latitude second.
    """

    #: Longitude
    longitude: float
    #: Latitude
    latitude: float
141
+
142
+
143
#: A variable-length tuple of :class:`ScoredMember` entries
ScoredMembers = tuple[ScoredMember, ...]
144
+
145
+
146
class GeoSearchResult(NamedTuple):
    """
    Structure of a single result of a geo query

    The optional fields are ``None`` when the corresponding value is not
    present in the response.
    """

    #: Place name
    name: StringT
    #: Distance (``None`` if not present in the response)
    distance: float | None
    #: GeoHash (``None`` if not present in the response)
    geohash: int | None
    #: Lon/Lat coordinates (``None`` if not present in the response)
    coordinates: GeoCoordinates | None
159
+
160
+
161
#: Definition of a redis command
#: See: `<https://redis.io/topics/key-specs>`__
#:
#: - ``name``: This is the command's name in lowercase
#: - ``arity``: Arity is the number of arguments a command expects.
#: - ``flags``: See `<https://redis.io/commands/command#flags>`__
#: - ``first-key``: This value identifies the position of the command's first key name argument
#: - ``last-key``: This value identifies the position of the command's last key name argument
#: - ``step``: This value is the step, or increment, between the first key and last key values
#:   where the keys are.
#: - ``acl-categories``: This is an array of simple strings that are the ACL categories to which
#:   the command belongs
#: - ``tips``: Helpful information about the command
#: - ``key-specifications``: This is an array consisting of the command's key specifications
#: - ``sub-commands``: This is an array containing all of the command's subcommands, if any
#:
#: The functional ``TypedDict`` syntax is required because the keys contain hyphens.
Command = TypedDict(
    "Command",
    {
        "name": str,
        "arity": int,
        "flags": Set[str],
        "first-key": int,
        "last-key": int,
        "step": int,
        "acl-categories": Set[str] | None,
        "tips": Set[str] | None,
        "key-specifications": Set[Mapping[str, int | str | Mapping]] | None,  # type: ignore
        "sub-commands": tuple[str, ...] | None,
    },
)
191
+
192
+
193
class RoleInfo(NamedTuple):
    """
    Redis instance role information

    Only ``role`` is always set; the remaining fields default to ``None``
    and are presumably filled in depending on the role (confirm against the
    ROLE response callback).
    """

    #: The role reported by the instance
    role: str
    #: Offset reported for the role (presumably the replication offset), if any
    offset: int | None = None
    #: Status reported for the role, if any
    status: str | None = None
    #: Per-replica details, if any
    slaves: tuple[dict[str, str | int], ...] | None = None
    #: Master names, if any
    masters: tuple[str, ...] | None = None
208
+
209
+
210
class StreamEntry(NamedTuple):
    """
    Structure representing an entry in a redis stream
    """

    #: The entry identifier
    identifier: StringT
    #: The entry's field/value pairs, kept in an ordered mapping
    field_values: OrderedDict[StringT, StringT]
217
+
218
+
219
#: Details of a stream
#: See: `<https://redis.io/commands/xinfo-stream>`__
#:
#: The functional ``TypedDict`` syntax is required because the keys contain
#: hyphens. ``groups`` is typed either as a count or as a list of per-group
#: details (presumably the latter for the ``FULL`` form of XINFO STREAM --
#: confirm against the response callback).
StreamInfo = TypedDict(
    "StreamInfo",
    {
        "first-entry": StreamEntry | None,
        "last-entry": StreamEntry | None,
        "length": int,
        "radix-tree-keys": int,
        "radix-tree-nodes": int,
        "groups": int | list[dict],  # type: ignore
        "last-generated-id": str,
        "max-deleted-entry-id": str,
        "recorded-first-entry-id": str,
        "entries-added": int,
        "entries-read": int,
        "entries": tuple[StreamEntry, ...] | None,
    },
)
238
+
239
+
240
class StreamPending(NamedTuple):
    """
    Summary response from
    `XPENDING <https://redis.io/commands/xpending#summary-form-of-xpending>`__
    """

    #: Total number of pending entries
    pending: int
    #: Smallest pending entry id
    minimum_identifier: StringT
    #: Greatest pending entry id
    maximum_identifier: StringT
    #: Consumer name mapped to its pending-entry count
    consumers: OrderedDict[StringT, int]
250
+
251
+
252
class StreamPendingExt(NamedTuple):
    """
    Extended form response from
    `XPENDING <https://redis.io/commands/xpending#extended-form-of-xpending>`__
    """

    #: Id of the pending entry
    identifier: StringT
    #: Consumer that owns the pending entry
    consumer: StringT
    #: Idle time of the entry (milliseconds per the XPENDING documentation)
    idle: int
    #: Number of times the entry has been delivered
    delivered: int
262
+
263
+
264
#: Response from `SLOWLOG GET <https://redis.io/commands/slowlog-get>`__
class SlowLogInfo(NamedTuple):
    """
    A single slow log entry as returned by
    `SLOWLOG GET <https://redis.io/commands/slowlog-get>`__
    """

    #: A unique progressive identifier for every slow log entry.
    id: int
    #: The unix timestamp at which the logged command was processed.
    start_time: int
    #: The amount of time needed for its execution, in microseconds.
    duration: int
    #: The array composing the arguments of the command.
    command: tuple[StringT, ...]
    #: Client IP address and port
    client_addr: tuple[StringT, int]
    #: Client name
    client_name: str
278
+
279
+
280
class LCSMatch(NamedTuple):
    """
    An instance of an LCS match
    """

    #: Start/end offset of the first string
    first: tuple[int, int]
    #: Start/end offset of the second string
    second: tuple[int, int]
    #: Length of the match (``None`` when the length is not part of the response)
    length: int | None
291
+
292
+
293
class LCSResult(NamedTuple):
    """
    Results from `LCS <https://redis.io/commands/lcs>`__
    """

    #: The individual matches (see :class:`LCSMatch`)
    matches: tuple[LCSMatch, ...]
    #: Length of longest match
    length: int
302
+
303
+
304
@dataclasses.dataclass
class MonitorResult:
    """
    Details of issued commands received by the client when
    listening with the `MONITOR <https://redis.io/commands/monitor>`__
    command
    """

    #: Time command was received
    time: datetime.datetime
    #: db number
    db: int
    #: (host, port) or path if the server is listening on a unix domain socket
    client_addr: tuple[str, int] | str | None
    #: The type of the client that sent the command
    client_type: Literal["tcp", "unix", "lua"]
    #: The name of the command
    command: str
    #: Arguments passed to the command
    args: tuple[str, ...] | None

    #: Matches the ``[<db> <client>] <command and args>`` part of a MONITOR line
    EXPR: ClassVar[Pattern[str]] = re.compile(r"\[(\d+) (.*?)\] (.*)$")

    @classmethod
    def parse_response_string(cls, response: str) -> MonitorResult:
        """
        Parse a single line emitted by ``MONITOR``.

        :param response: a line of the form
         ``<timestamp> [<db> <client>] "<command>" "<arg>" ...`` where
         ``<client>`` is ``lua``, ``unix:<path>`` or ``<host>:<port>``
         (IPv6 hosts are written as ``[<addr>]:<port>``).
        :return: the parsed :class:`MonitorResult`. ``time`` is a naive
         datetime produced by :meth:`datetime.datetime.fromtimestamp`.
        :raises ValueError: if ``response`` is not in the expected format
        """
        command_time, command_data = response.split(" ", 1)
        match = cls.EXPR.match(command_data)
        if match is None:
            # Explicit check rather than ``assert`` so malformed input is still
            # rejected (instead of failing later) when Python runs with ``-O``.
            raise ValueError(f"Unable to parse MONITOR response: {response!r}")
        db_id, client_info, raw_command = match.groups()
        tokens = shlex.split(raw_command)
        client_addr: tuple[str, int] | str | None = None
        client_type: Literal["tcp", "unix", "lua"]
        if client_info == "lua":
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_type = "unix"
            # drop the ``unix:`` prefix, leaving the socket path
            client_addr = client_info[5:]
        else:
            host, port = client_info.rsplit(":", 1)
            # IPv6 peers are reported as ``[addr]:port``; return the bare address.
            if host.startswith("[") and host.endswith("]"):
                host = host[1:-1]
            client_addr = (host, int(port))
            client_type = "tcp"
        return cls(
            time=datetime.datetime.fromtimestamp(float(command_time)),
            db=int(db_id),
            client_addr=client_addr,
            client_type=client_type,
            command=tokens[0],
            args=tuple(tokens[1:]),
        )
353
+
354
+
355
class ClusterNode(TypedDict):
    """
    Address of a cluster node with optional identity and role
    """

    #: Host name or IP address of the node
    host: str
    #: Port of the node
    port: int
    #: Node id, if known
    node_id: str | None
    #: ``"master"`` or ``"slave"``, or ``None`` if unknown
    server_type: Literal["master", "slave"] | None
360
+
361
+
362
class ClusterNodeDetail(TypedDict):
    """
    Detailed information about a single cluster node
    (presumably parsed from ``CLUSTER NODES`` output -- confirm against the callback)
    """

    #: Node id
    id: str
    #: Node flags
    flags: tuple[str, ...]
    #: Host name or IP address of the node
    host: str
    #: Port of the node
    port: int
    #: Id of the node's master, or ``None``
    master: str | None
    #: ``ping-sent`` value reported for the node
    ping_sent: int
    #: ``pong-recv`` value reported for the node
    pong_recv: int
    #: Link state reported for the node
    link_state: str
    #: Slots associated with the node
    slots: list[int]
    #: Slot migration details, if any
    migrations: list[dict[str, RedisValueT]]
373
+
374
+
375
class PubSubMessage(TypedDict):
    """
    A message received by a pub/sub subscriber
    """

    #: One of the following:
    #:
    #: subscribe
    #:  Server response when a client subscribes to one or more channels
    #: unsubscribe
    #:  Server response when a client unsubscribes from one or more channels
    #: psubscribe
    #:  Server response when a client subscribes to one or more patterns
    #: punsubscribe
    #:  Server response when a client unsubscribes from one or more patterns
    #: ssubscribe
    #:  Server response when a client subscribes to one or more shard channels
    #: sunsubscribe
    #:  Server response when a client unsubscribes from one or more shard channels
    #: message
    #:  A message received from subscribing to a channel
    #: pmessage
    #:  A message received from subscribing to a pattern
    type: str
    #: The channel subscribed to or unsubscribed from, or the channel a message was published to
    channel: StringT
    #: The pattern that was subscribed to or unsubscribed from, or the pattern through which
    #: a received message was routed
    pattern: StringT | None
    #: - If ``type`` is one of ``{message, pmessage}`` this is the actual published message
    #: - If ``type`` is one of
    #:   ``{subscribe, psubscribe, ssubscribe, unsubscribe, punsubscribe, sunsubscribe}``
    #:   this will be an :class:`int` corresponding to the number of channels and patterns that the
    #:   connection is currently subscribed to.
    data: int | StringT
406
+
407
+
408
class VectorData(TypedDict):
    """
    Raw vector data together with its quantization details
    (presumably the raw form of a vector-set element -- confirm against the callback)
    """

    #: The quantization type as a string (``fp32``, ``bin`` or ``q8``)
    quantization: str
    #: Raw bytes representation of the vector
    blob: bytes
    #: The L2 norm of the vector before normalization
    l2_norm: float
    #: If the vector is quantized as q8, the quantization range
    quantization_range: float
coredis/retry.py ADDED
@@ -0,0 +1,233 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from abc import ABC, abstractmethod
6
+ from functools import wraps
7
+ from typing import Any
8
+
9
+ from coredis.typing import Awaitable, Callable, P, R
10
+
11
logger = logging.getLogger(__name__)


class RetryPolicy(ABC):
    """
    Abstract retry policy

    Subclasses implement :meth:`delay` to decide how long to wait before each
    attempt; :meth:`call_with_retries` drives the attempt loop.
    """

    def __init__(self, retries: int, retryable_exceptions: tuple[type[BaseException], ...]) -> None:
        """
        :param retries: number of times to retry if a :paramref:`retryable_exception`
         is encountered.
        :param retryable_exceptions: The exceptions to trigger a retry for
        """
        self.retryable_exceptions = retryable_exceptions
        self.retries = retries

    @abstractmethod
    async def delay(self, attempt_number: int) -> None:
        """
        Wait (if required) before attempt ``attempt_number``.

        :param attempt_number: ``0`` for the initial attempt, ``1`` for the
         first retry, and so on.
        """

    async def call_with_retries(
        self,
        func: Callable[..., Awaitable[R]],
        before_hook: Callable[..., Awaitable[Any]] | None = None,
        failure_hook: Callable[..., Awaitable[Any]]
        | dict[type[BaseException], Callable[..., Awaitable[Any]]]
        | None = None,
    ) -> R:
        """
        :param func: a function that should return the coroutine that will be
         awaited when retrying if :paramref:`RetryPolicy.retryable_exceptions` is encountered.
        :param before_hook: if provided will be called on every attempt.
        :param failure_hook: if provided and is a callable it will be
         called every time a retryable exception is encountered. If it is a mapping
         of exception types to callables, the callable of the first exception type
         that the encountered exception is an instance of will be called.
         Any :class:`Exception` raised by the hook itself is logged at debug
         level and ignored.
        :raises: the last retryable exception once all attempts are exhausted;
         non-retryable exceptions propagate immediately.
        """
        last_error: BaseException | None = None
        # ``max`` guarantees at least one call even if ``retries`` is negative.
        for attempt in range(max(self.retries, 0) + 1):
            try:
                await self.delay(attempt)
                if before_hook:
                    await before_hook()
                return await func()
            except self.retryable_exceptions as e:
                logger.info("Retry attempt %s due to error: %s", attempt + 1, e)
                if failure_hook:
                    try:
                        if isinstance(failure_hook, dict):
                            for exc_type, hook in failure_hook.items():
                                if isinstance(e, exc_type):
                                    await hook(e)
                                    break
                        else:
                            await failure_hook(e)
                    except Exception:
                        # Hooks are best-effort and must not mask ``e``. Only
                        # ``Exception`` is caught so that ``asyncio.CancelledError``
                        # and ``KeyboardInterrupt`` (BaseException subclasses)
                        # still propagate.
                        logger.debug("failure hook raised while handling %r", e, exc_info=True)
                last_error = e
        # The loop runs at least once and every iteration either returns or
        # records an error, so ``last_error`` is always set here.
        assert last_error is not None
        raise last_error

    def will_retry(self, exc: BaseException) -> bool:
        """
        Whether ``exc`` is one of :paramref:`RetryPolicy.retryable_exceptions`.
        """
        return isinstance(exc, self.retryable_exceptions)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}<"
            f"retries={self.retries}, "
            f"retryable_exceptions={','.join(e.__name__ for e in self.retryable_exceptions)}"
            ">"
        )
84
+
85
+
86
class NoRetryPolicy(RetryPolicy):
    """
    Retry policy that never retries.

    ``retries`` is recorded as ``1``, but no exception type is registered as
    retryable, so nothing raised by the wrapped call is ever caught for a retry.
    """

    def __init__(self) -> None:
        super().__init__(retries=1, retryable_exceptions=())

    async def delay(self, attempt_number: int) -> None:
        # No pause is ever needed.
        return None
92
+
93
+
94
class ConstantRetryPolicy(RetryPolicy):
    """
    Retry policy that pauses :paramref:`ConstantRetryPolicy.delay`
    seconds between :paramref:`ConstantRetryPolicy.retries`
    if any of :paramref:`ConstantRetryPolicy.retryable_exceptions` are
    encountered.

    The first attempt is made without any pause.
    """

    def __init__(
        self,
        retryable_exceptions: tuple[type[BaseException], ...],
        retries: int,
        delay: float,
    ) -> None:
        super().__init__(retries=retries, retryable_exceptions=retryable_exceptions)
        # The parameter shares its name with the ``delay`` coroutine method, so it
        # is kept in a private (name-mangled) attribute instead of ``self.delay``.
        self.__delay = delay

    async def delay(self, attempt_number: int) -> None:
        if attempt_number <= 0:
            return
        await asyncio.sleep(self.__delay)
114
+
115
+
116
class ExponentialBackoffRetryPolicy(RetryPolicy):
    """
    Retry policy that exponentially backs off before retrying up to
    :paramref:`ExponentialBackoffRetryPolicy.retries` if any
    of :paramref:`ExponentialBackoffRetryPolicy.retryable_exceptions` are
    encountered. :paramref:`ExponentialBackoffRetryPolicy.initial_delay`
    is used as the initial value for calculating the exponential backoff.

    Before retry ``n`` (``n >= 1``) the policy sleeps for
    ``2 ** n * initial_delay`` seconds; the first attempt is not delayed.
    """

    def __init__(
        self,
        retryable_exceptions: tuple[type[BaseException], ...],
        retries: int,
        initial_delay: float,
    ) -> None:
        super().__init__(retries=retries, retryable_exceptions=retryable_exceptions)
        self.__initial_delay = initial_delay

    async def delay(self, attempt_number: int) -> None:
        if attempt_number <= 0:
            return
        backoff = 2**attempt_number * self.__initial_delay
        await asyncio.sleep(backoff)
138
+
139
+
140
class CompositeRetryPolicy(RetryPolicy):
    """
    Convenience class to combine multiple retry policies

    Each wrapped policy keeps its own attempt counter. When the wrapped call
    raises, the first wrapped policy that both handles the exception and still
    has retries left is charged one attempt, and its ``delay`` is awaited
    before the next call.
    """

    def __init__(self, *retry_policies: RetryPolicy) -> None:
        # ``RetryPolicy.__init__`` is not called: this class has no
        # ``retries``/``retryable_exceptions`` of its own and delegates to the
        # wrapped policies instead.
        self._retry_policies = set(retry_policies)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}<{','.join(str(p) for p in self._retry_policies)}>"

    def add_retry_policy(self, policy: RetryPolicy) -> None:
        """
        Add to the retry policies that this instance was created with
        """
        self._retry_policies.add(policy)

    def will_retry(self, exc: BaseException) -> bool:
        """
        Whether any of the wrapped policies would retry ``exc``.

        Overrides :meth:`RetryPolicy.will_retry`, which reads
        ``retryable_exceptions`` -- an attribute this class never sets.
        """
        return any(policy.will_retry(exc) for policy in self._retry_policies)

    async def delay(self, attempt_number: int) -> None:
        # Delays are delegated to the wrapped policies in ``call_with_retries``.
        raise NotImplementedError()

    async def call_with_retries(
        self,
        func: Callable[..., Awaitable[R]],
        before_hook: Callable[..., Awaitable[Any]] | None = None,
        failure_hook: None
        | (
            Callable[..., Awaitable[Any]] | dict[type[BaseException], Callable[..., Awaitable[Any]]]
        ) = None,
    ) -> R:
        """
        Calls :paramref:`func` repeatedly according to the retry policies that
        this class was instantiated with (:paramref:`CompositeRetryPolicy.retry_policies`).

        :param func: a function that should return the coroutine that will be
         awaited when retrying if :paramref:`RetryPolicy.retryable_exceptions` is encountered.
        :param before_hook: if provided will be called before every attempt.
        :param failure_hook: if provided and is a callable it will be
         called every time :paramref:`func` raises an :class:`Exception`, whether or
         not it will be retried. If it is a mapping of exception types to callables,
         the callable of the first exception type that the raised exception is an
         instance of will be called. Errors raised by the hook propagate.
        :raises: the exception raised by :paramref:`func` once no wrapped policy
         will retry it.
        """
        attempts = {policy: 0 for policy in self._retry_policies}
        while True:
            try:
                if before_hook:
                    await before_hook()
                return await func()
            except Exception as e:
                will_retry = False
                retry_number = 0
                for policy in attempts:
                    if policy.will_retry(e) and attempts[policy] < policy.retries:
                        attempts[policy] += 1
                        retry_number = attempts[policy]
                        await policy.delay(retry_number)
                        will_retry = True
                        break

                if failure_hook:
                    if isinstance(failure_hook, dict):
                        for exc_type, hook in failure_hook.items():
                            if isinstance(e, exc_type):
                                await hook(e)
                                break
                    else:
                        await failure_hook(e)

                if will_retry:
                    # Report the 1-based retry number, matching RetryPolicy's log line.
                    logger.info("Retry attempt %s due to error: %s", retry_number, e)
                    continue

                raise e
212
+
213
+
214
def retryable(
    policy: RetryPolicy,
    failure_hook: Callable[..., Awaitable[Any]] | None = None,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    """
    Decorator to be used to apply a retry policy to a coroutine

    :param policy: the policy whose ``call_with_retries`` drives every call
    :param failure_hook: forwarded unchanged to ``call_with_retries``
    :return: a decorator that wraps a coroutine function so that each
     invocation is retried according to ``policy``
    """

    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            def attempt() -> Awaitable[R]:
                # A fresh awaitable is created for every attempt.
                return func(*args, **kwargs)

            return await policy.call_with_retries(attempt, failure_hook=failure_hook)

        return wrapper

    return decorator