rediskit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rediskit/__init__.py +19 -0
- rediskit/config.py +23 -0
- rediskit/encrypter.py +168 -0
- rediskit/memoize.py +258 -0
- rediskit/redisClient.py +270 -0
- rediskit/redisLock.py +18 -0
- rediskit/utils.py +167 -0
- rediskit-0.0.1.dist-info/METADATA +235 -0
- rediskit-0.0.1.dist-info/RECORD +12 -0
- rediskit-0.0.1.dist-info/WHEEL +5 -0
- rediskit-0.0.1.dist-info/licenses/LICENSE +201 -0
- rediskit-0.0.1.dist-info/top_level.txt +1 -0
rediskit/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
rediskit - Redis-backed performance and concurrency primitives for Python applications.
|
|
3
|
+
|
|
4
|
+
Provides caching, distributed coordination, and data protection using Redis.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from rediskit.encrypter import Encrypter

# The installed package is importable as ``rediskit``; the ``src.`` prefix only
# exists in the source checkout and breaks ``import rediskit`` from the wheel.
from rediskit.memoize import RedisMemoize

# NOTE(review): the connection helpers are presumed to be defined in
# rediskit.redisClient (that module is not visible in this diff) - confirm.
from rediskit.redisClient import (
    GetAsyncRedisConnection,
    GetRedisConnection,
    InitAsyncRedisConnectionPool,
    InitRedisConnectionPool,
)
from rediskit.redisLock import GetAsyncRedisMutexLock, GetRedisMutexLock

# Every name listed here must actually be bound in this module, otherwise
# ``from rediskit import *`` raises AttributeError.
__all__ = [
    "RedisMemoize",
    "InitRedisConnectionPool",
    "InitAsyncRedisConnectionPool",
    "GetRedisConnection",
    "GetAsyncRedisConnection",
    "GetRedisMutexLock",
    "GetAsyncRedisMutexLock",
    "Encrypter",
]
|
rediskit/config.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import os

from DockerBuildSystem import TerminalTools

# Import from the installed package name; ``src.rediskit`` only resolves inside
# the source checkout and fails once the wheel is installed.
from rediskit.utils import base64JsonToDict

# Environment files are loaded at import time, "private.env" before ".env".
# NOTE(review): presumably variables that are already set are left untouched -
# confirm against the DockerBuildSystem documentation.
TerminalTools.LoadDefaultEnvironmentVariablesFile("private.env")
TerminalTools.LoadDefaultEnvironmentVariablesFile(".env")

# Redis Settings
REDISKIT_REDIS_HOST = os.environ.get("REDISKIT_REDIS_HOST", "localhost")
REDISKIT_REDIS_PORT = int(os.environ.get("REDISKIT_REDIS_PORT", "6379"))
REDISKIT_REDIS_PASSWORD = os.environ.get("REDISKIT_REDIS_PASSWORD", "")
REDISKIT_REDIS_TOP_NODE = os.environ.get("REDISKIT_REDIS_TOP_NODE", "redis_kit_node")
REDISKIT_REDIS_SCAN_COUNT = int(os.environ.get("REDISKIT_REDIS_SCAN_COUNT", "10000"))
# Only the (case-insensitive) string "true" enables this flag.
REDISKIT_REDIS_SKIP_CACHING = os.environ.get("REDISKIT_REDIS_SKIP_CACHING", "false").upper() == "TRUE"

# Lock Settings
REDISKIT_LOCK_SETTINGS_REDIS_NAMESPACE = os.environ.get("REDISKIT_LOCK_SETTINGS_REDIS_NAMESPACE", f"{REDISKIT_REDIS_TOP_NODE}:LOCK")
REDISKIT_LOCK_ASYNC_SETTINGS_REDIS_NAMESPACE = os.environ.get("REDISKIT_LOCK_ASYNC_SETTINGS_REDIS_NAMESPACE", f"{REDISKIT_REDIS_TOP_NODE}:LOCK_ASYNC")
REDISKIT_LOCK_CACHE_REDIS_MUTEX = os.environ.get("REDISKIT_LOCK_CACHE_REDIS_MUTEX", "REDISKIT_LOCK_CACHE_REDIS_MUTEX")

# Encryption keys: the environment value is decoded by base64JsonToDict into the
# key dictionary that Encrypter uses as its default.
REDISKIT_ENCRYPTION_SECRET = base64JsonToDict(os.environ.get("REDISKIT_ENCRYPTION_SECRET", ""))
|
rediskit/encrypter.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import gzip
|
|
3
|
+
import json
|
|
4
|
+
import re
|
|
5
|
+
|
|
6
|
+
import zstd
|
|
7
|
+
from nacl import encoding, secret
|
|
8
|
+
from nacl.utils import random
|
|
9
|
+
|
|
10
|
+
from rediskit import config
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Encrypter:
|
|
14
|
+
VERSION_PREFIX = "__enc_v"
|
|
15
|
+
|
|
16
|
+
# EncryptedCompressedBase64Box
|
|
17
|
+
def __init__(self, keyHexDict: dict[str, str] = config.REDISKIT_ENCRYPTION_SECRET) -> None:
|
|
18
|
+
"""
|
|
19
|
+
keysBase64 shall have the following format {"__enc_v1": "32-byte key"...,"__enc_vn": ...}
|
|
20
|
+
"""
|
|
21
|
+
self.encryptionKeys = keyHexDict
|
|
22
|
+
self.latestVersion = list(self.encryptionKeys.keys())[-1]
|
|
23
|
+
|
|
24
|
+
def _getSecretBox(self, version: str) -> secret.SecretBox:
|
|
25
|
+
hexKey = self.encryptionKeys.get(version)
|
|
26
|
+
if not hexKey:
|
|
27
|
+
raise ValueError(f"Encryption key for version {version} not found")
|
|
28
|
+
# Ensure the key is in the proper format (32 bytes after decoding from hex)
|
|
29
|
+
return secret.SecretBox(hexKey.encode(), encoder=encoding.HexEncoder)
|
|
30
|
+
|
|
31
|
+
def encrypt[T: str | bytes | None](self, data: T, raiseIfEncrypted: bool = True, useZstd: bool = True) -> T:
|
|
32
|
+
if data is None:
|
|
33
|
+
return None # type: ignore # not able to check this properly
|
|
34
|
+
elif isinstance(data, str):
|
|
35
|
+
dataToEncrypt: bytes = data.encode()
|
|
36
|
+
isText = True
|
|
37
|
+
elif isinstance(data, bytes):
|
|
38
|
+
dataToEncrypt = data
|
|
39
|
+
isText = False
|
|
40
|
+
else:
|
|
41
|
+
raise ValueError("data expected to be bytes or str")
|
|
42
|
+
|
|
43
|
+
if self.isEncrypted(dataToEncrypt, raiseIfEncrypted):
|
|
44
|
+
return data
|
|
45
|
+
|
|
46
|
+
compressedData: bytes = zstd.compress(dataToEncrypt) if useZstd else gzip.compress(dataToEncrypt)
|
|
47
|
+
tagBytes = b"zstd" if useZstd else b"gzip"
|
|
48
|
+
|
|
49
|
+
cipherText = self._getSecretBox(self.latestVersion).encrypt(compressedData, encoder=encoding.Base64Encoder)
|
|
50
|
+
token = self.latestVersion.encode() + b"|" + tagBytes + b":" + cipherText
|
|
51
|
+
|
|
52
|
+
return token.decode() if isText else token # type: ignore # not able to check this properly
|
|
53
|
+
|
|
54
|
+
def decrypt[T: str | bytes | None](self, data: T) -> T:
|
|
55
|
+
if data is None:
|
|
56
|
+
return None # type: ignore # not able to check this properly
|
|
57
|
+
elif isinstance(data, str):
|
|
58
|
+
dataToDecrypt: bytes = data.encode()
|
|
59
|
+
isText = True
|
|
60
|
+
elif isinstance(data, bytes):
|
|
61
|
+
dataToDecrypt = data
|
|
62
|
+
isText = False
|
|
63
|
+
else:
|
|
64
|
+
raise ValueError("data expected to be bytes or str")
|
|
65
|
+
|
|
66
|
+
try:
|
|
67
|
+
if dataToDecrypt.startswith(self.VERSION_PREFIX.encode()):
|
|
68
|
+
if b"|" in dataToDecrypt:
|
|
69
|
+
# Format: __enc_vX|compression:ciphertext
|
|
70
|
+
pre, rest = dataToDecrypt.split(b"|", 1)
|
|
71
|
+
version = pre
|
|
72
|
+
compressionTag, ciphertext = rest.split(b":", 1)
|
|
73
|
+
elif b":" in dataToDecrypt:
|
|
74
|
+
# Format: __enc_vX:ciphertext
|
|
75
|
+
version, ciphertext = dataToDecrypt.split(b":", 1)
|
|
76
|
+
compressionTag = b"gzip"
|
|
77
|
+
else:
|
|
78
|
+
raise ValueError("Invalid encrypted data format")
|
|
79
|
+
else:
|
|
80
|
+
# No version: legacy, fallback to v1 + gzip
|
|
81
|
+
version = f"{self.VERSION_PREFIX}1".encode()
|
|
82
|
+
compressionTag = b"gzip"
|
|
83
|
+
ciphertext = dataToDecrypt
|
|
84
|
+
except ValueError:
|
|
85
|
+
raise ValueError("Invalid encrypted data format; expected version prefix.")
|
|
86
|
+
|
|
87
|
+
secret_box = self._getSecretBox(version.decode())
|
|
88
|
+
compressed_data = secret_box.decrypt(ciphertext, encoder=encoding.Base64Encoder)
|
|
89
|
+
|
|
90
|
+
if compressionTag == b"zstd":
|
|
91
|
+
deCompressed = zstd.decompress(compressed_data)
|
|
92
|
+
elif compressionTag == b"gzip":
|
|
93
|
+
deCompressed = gzip.decompress(compressed_data)
|
|
94
|
+
else:
|
|
95
|
+
raise ValueError(f"Unknown compression '{compressionTag.decode()}' in encrypted data.")
|
|
96
|
+
|
|
97
|
+
return deCompressed.decode() if isText else deCompressed # type: ignore # not able to check this properly
|
|
98
|
+
|
|
99
|
+
@staticmethod
|
|
100
|
+
def getEncryptionKeyVersionNumber(versionKey: str) -> int:
|
|
101
|
+
match = re.match(r"^__enc_v(\d+)$", versionKey)
|
|
102
|
+
if not match:
|
|
103
|
+
raise ValueError(f"Invalid version string: {versionKey}")
|
|
104
|
+
return int(match.group(1))
|
|
105
|
+
|
|
106
|
+
@staticmethod
|
|
107
|
+
def isEncrypted(data: str | bytes | None, raiseIfEncrypted: bool = False) -> bool:
|
|
108
|
+
# NB! can be false positive if a secret starts with "__enc_v"..., this risk is just accepted for now
|
|
109
|
+
if data is None:
|
|
110
|
+
return False
|
|
111
|
+
dataBytes = data.encode() if isinstance(data, str) else data
|
|
112
|
+
if not dataBytes.startswith(Encrypter.VERSION_PREFIX.encode()) or b":" not in dataBytes:
|
|
113
|
+
return False
|
|
114
|
+
if raiseIfEncrypted:
|
|
115
|
+
raise Exception("The data is already encrypted")
|
|
116
|
+
return True
|
|
117
|
+
|
|
118
|
+
@staticmethod
|
|
119
|
+
def generateNewHexKey() -> str:
|
|
120
|
+
"""
|
|
121
|
+
Generate a new 32-byte key and return it as a hex-encoded string.
|
|
122
|
+
|
|
123
|
+
>>> key = Encrypter.generateNewHexKey()
|
|
124
|
+
>>> isinstance(key, str)
|
|
125
|
+
True
|
|
126
|
+
>>> len(key) == 64
|
|
127
|
+
True
|
|
128
|
+
"""
|
|
129
|
+
newKey = random(secret.SecretBox.KEY_SIZE) # Generates 32 random bytes.
|
|
130
|
+
hexKey = encoding.HexEncoder.encode(newKey).decode("utf-8")
|
|
131
|
+
return hexKey
|
|
132
|
+
|
|
133
|
+
@staticmethod
|
|
134
|
+
def encodeKeysDictToBase64(keys: dict) -> str:
|
|
135
|
+
"""
|
|
136
|
+
Convert the keys dictionary to a JSON string and encode it in base64.
|
|
137
|
+
|
|
138
|
+
>>> Encrypter.encodeKeysDictToBase64({"v1": "abcdef"})
|
|
139
|
+
'eyJ2MSI6ICJhYmNkZWYifQ=='
|
|
140
|
+
"""
|
|
141
|
+
json_str = json.dumps(keys)
|
|
142
|
+
base64_str = base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
|
|
143
|
+
return base64_str
|
|
144
|
+
|
|
145
|
+
@staticmethod
|
|
146
|
+
def decodeKeysFromBase64(encodedKeys: str) -> dict:
|
|
147
|
+
"""
|
|
148
|
+
Decode a base64-encoded JSON string into a dictionary of keys.
|
|
149
|
+
|
|
150
|
+
>>> Encrypter.decodeKeysFromBase64("eyJ2MSI6ICJhYmNkZWYifQ==")
|
|
151
|
+
{'v1': 'abcdef'}
|
|
152
|
+
"""
|
|
153
|
+
try:
|
|
154
|
+
# Decode the base64 string into a JSON string.
|
|
155
|
+
json_str = base64.b64decode(encodedKeys).decode("utf-8")
|
|
156
|
+
# Convert the JSON string into a dictionary.
|
|
157
|
+
keys_dict = json.loads(json_str)
|
|
158
|
+
return keys_dict
|
|
159
|
+
except Exception as e:
|
|
160
|
+
raise ValueError("Failed to decode keys from base64 string.") from e
|
|
161
|
+
|
|
162
|
+
@staticmethod
|
|
163
|
+
def appendNewEncryptKey(currentKeys: str):
|
|
164
|
+
decodedkeys = Encrypter.decodeKeysFromBase64(currentKeys)
|
|
165
|
+
latestKey = list(decodedkeys.keys())[-1]
|
|
166
|
+
newVersion = f"{Encrypter.VERSION_PREFIX}{Encrypter.getEncryptionKeyVersionNumber(latestKey) + 1}"
|
|
167
|
+
decodedkeys.update({newVersion: Encrypter.generateNewHexKey()})
|
|
168
|
+
return {"newKeysDecoded": decodedkeys, "EncryptedKeys": Encrypter.encodeKeysDictToBase64(decodedkeys)}
|
rediskit/memoize.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import functools
|
|
3
|
+
import inspect
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import pickle
|
|
7
|
+
from typing import Any, Awaitable, Callable, Literal, overload
|
|
8
|
+
|
|
9
|
+
import zstd
|
|
10
|
+
from redis import Redis
|
|
11
|
+
|
|
12
|
+
from rediskit import config, redisClient
|
|
13
|
+
from rediskit.encrypter import Encrypter
|
|
14
|
+
from rediskit.redisClient import HGetCacheFromRedis, HSetCacheToRedis
|
|
15
|
+
from rediskit.redisLock import GetAsyncRedisMutexLock, GetRedisMutexLock
|
|
16
|
+
|
|
17
|
+
log = logging.getLogger(__name__)
|
|
18
|
+
# Serialization format of cached values: "zipPickled" uses pickle, "zipJson" uses JSON.
CacheTypeOptions = Literal["zipPickled", "zipJson"]
|
|
19
|
+
# Where a cached value is stored: under a plain key ("string") or as one field of a Redis hash ("hash").
RedisStorageOptions = Literal["string", "hash"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def splitHashKey(key: str) -> tuple[str, str]:
    """Split ``key`` at its last ":" into ``(hash key, field)``.

    Raises:
        ValueError: if ``key`` contains no ":" at all.
    """
    hashName, separator, fieldName = key.rpartition(":")
    if separator == "":
        # No ":" present, so there is nothing to use as the hash name.
        raise ValueError("Cannot use a single-part key with hash storage.")
    return hashName, fieldName
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compressAndSign(data: Any, serializeFn: Callable[[Any], bytes], enableEncryption: bool = False) -> str:
    """Serialize ``data``, compress (or encrypt) it and return base64 text.

    With ``enableEncryption`` the serialized bytes go through
    ``Encrypter().encrypt``; otherwise they are zstd-compressed. Despite the
    name, no separate signature is computed in this function.
    """
    rawBytes = serializeFn(data)
    packed = Encrypter().encrypt(rawBytes) if enableEncryption else zstd.compress(rawBytes)
    encoded = base64.b64encode(packed)
    return encoded.decode("utf-8")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def verifyAndDecompress(payload: bytes, deserializeFn: Callable[[bytes], Any], enableEncryption: bool = False) -> Any:
    """Undo the compress/encrypt step and deserialize the result.

    With ``enableEncryption`` the payload is passed to ``Encrypter().decrypt``;
    otherwise it is zstd-decompressed. The recovered bytes are then handed to
    ``deserializeFn`` and its result is returned.
    """
    if not enableEncryption:
        return deserializeFn(zstd.decompress(payload))
    return deserializeFn(Encrypter().decrypt(payload))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def deserializeData(
    data: Any,
    cacheType: CacheTypeOptions,
    enableEncryption: bool = False,
) -> Any:
    """Decode a cached payload back into the original Python object.

    Args:
        data: base64 text/bytes as produced by ``serializeData``.
        cacheType: "zipPickled" (pickle) or "zipJson" (JSON).
        enableEncryption: whether the payload was encrypted when stored.

    Returns:
        The deserialized object (any type, not raw bytes).

    Raises:
        ValueError: if ``cacheType`` is not one of the supported options.
    """
    if cacheType == "zipPickled":
        # SECURITY: pickle.loads executes arbitrary code embedded in the payload.
        # Use "zipPickled" only when everything that can write to this Redis
        # instance is trusted; prefer "zipJson" otherwise.
        cachedData = verifyAndDecompress(base64.b64decode(data), lambda b: pickle.loads(b), enableEncryption)
    elif cacheType == "zipJson":
        cachedData = verifyAndDecompress(base64.b64decode(data), lambda b: json.loads(b.decode("utf-8")), enableEncryption)
    else:
        raise ValueError("Unknown cacheType specified.")

    return cachedData
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def serializeData(
    data: Any,
    cacheType: CacheTypeOptions,
    enableEncryption: bool = False,
) -> str:
    """Turn ``data`` into the base64 text that is stored in Redis.

    "zipPickled" serializes with pickle, "zipJson" with UTF-8 encoded JSON; the
    bytes are then passed through ``compressAndSign``.

    Raises:
        ValueError: if ``cacheType`` is not one of the supported options.
    """

    def _toJsonBytes(value: Any) -> bytes:
        return json.dumps(value).encode("utf-8")

    if cacheType == "zipPickled":
        return compressAndSign(data, pickle.dumps, enableEncryption)
    if cacheType == "zipJson":
        return compressAndSign(data, _toJsonBytes, enableEncryption)
    raise ValueError("Unknown cacheType specified.")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def computeValue[T](param: T | Callable[..., T], *args, **kwargs) -> T:
|
|
77
|
+
if callable(param):
|
|
78
|
+
sig = inspect.signature(param)
|
|
79
|
+
accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
|
|
80
|
+
|
|
81
|
+
if accepts_kwargs:
|
|
82
|
+
# Pass all kwargs directly
|
|
83
|
+
value = param(*args, **kwargs)
|
|
84
|
+
else:
|
|
85
|
+
# Filter only matching kwargs
|
|
86
|
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
|
|
87
|
+
bound = sig.bind(*args, **filtered_kwargs)
|
|
88
|
+
bound.apply_defaults()
|
|
89
|
+
value = param(*bound.args, **bound.kwargs)
|
|
90
|
+
return value
|
|
91
|
+
else:
|
|
92
|
+
return param
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def maybeDataInCache(
    tenantId: str | None,
    computedMemoizeKey: str,
    computedTtl: int | None,
    cacheType: CacheTypeOptions,
    resetTtlUponRead: bool,
    byPassCachedData: bool,
    enableEncryption: bool,
    storageType: RedisStorageOptions = "string",
    connection: Redis | None = None,
) -> Any:
    """Return the cached value for a memoize key, or ``None`` when there is none.

    ``None`` is also returned when ``byPassCachedData`` is set. For "hash"
    storage the key is split at its last ":" into hash key and field. When
    ``resetTtlUponRead`` is set and a TTL is known, that TTL is passed to the
    read helper as ``setTtlOnRead``.

    Raises:
        ValueError: if ``storageType`` is neither "string" nor "hash".
    """
    if byPassCachedData:
        log.info(f"Cache bypassed for tenantId: {tenantId}, key {computedMemoizeKey}")
        return None

    # TTL to re-apply on read; None leaves the expiry untouched.
    ttlOnRead = computedTtl if (resetTtlUponRead and computedTtl is not None) else None

    storedPayload = None
    if storageType == "hash":
        hashKey, field = splitHashKey(computedMemoizeKey)
        cachedDict = HGetCacheFromRedis(tenantId, hashKey, field, setTtlOnRead=ttlOnRead, connection=connection)
        if cachedDict and field in cachedDict:
            fieldValue = cachedDict[field]
            if fieldValue is not None:
                log.info(f"HASH cache hit tenantId: {tenantId}, key: {hashKey}, field: {field}")
                storedPayload = fieldValue
    elif storageType == "string":
        cached = redisClient.LoadBlobFromRedis(tenantId, match=computedMemoizeKey, setTtlOnRead=ttlOnRead, connection=connection)
        if cached:
            log.info(f"Cache hit tenantId: {tenantId}, key: {computedMemoizeKey}")
            storedPayload = cached
    else:
        raise ValueError(f"Unknown storageType: {storageType}")

    if not storedPayload:
        log.info(f"No cache found tenantId: {tenantId}, key: {computedMemoizeKey}")
        return None
    return deserializeData(storedPayload, cacheType, enableEncryption)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def dumpData(
    data: Any,
    tenantId: str,
    computedMemoizeKey: str,
    cacheType: CacheTypeOptions,
    computedTtl: int | None,
    enableEncryption: bool,
    storageType: RedisStorageOptions = "string",
    connection: Redis | None = None,
) -> None:
    """Serialize ``data`` and write it to Redis under the memoize key.

    "string" storage writes one blob under the key; "hash" storage splits the
    key at its last ":" and writes a single hash field.

    Raises:
        ValueError: if ``storageType`` is neither "string" nor "hash"
            (raised after the data has been serialized).
    """
    encoded = serializeData(data, cacheType, enableEncryption)

    if storageType == "hash":
        hashKey, field = splitHashKey(computedMemoizeKey)
        HSetCacheToRedis(tenantId, hashKey, fields={field: encoded}, ttl=computedTtl, connection=connection)
        return
    if storageType == "string":
        redisClient.DumpBlobToRedis(tenantId, computedMemoizeKey, payload=encoded, ttl=computedTtl, connection=connection)
        return
    raise ValueError(f"Unknown storageType: {storageType}")
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def RedisMemoize[T](
|
|
157
|
+
memoizeKey: Callable[..., str] | str,
|
|
158
|
+
ttl: Callable[..., int] | int | None = None,
|
|
159
|
+
bypassCache: Callable[..., bool] | bool = False,
|
|
160
|
+
cacheType: CacheTypeOptions = "zipJson",
|
|
161
|
+
resetTtlUponRead: bool = True,
|
|
162
|
+
enableEncryption: bool = False,
|
|
163
|
+
storageType: RedisStorageOptions = "string",
|
|
164
|
+
connection: Redis | None = None,
|
|
165
|
+
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
|
166
|
+
"""Caches the result of any function in Redis using either pickle or JSON.
|
|
167
|
+
|
|
168
|
+
The decorated function must have 'tenantId' as an arg or kwarg.
|
|
169
|
+
|
|
170
|
+
Params:
|
|
171
|
+
-------
|
|
172
|
+
- memoizeKey: Callable computing a memoize key based on wrapped funcs args and kwargs, callable shall define the logic to compute the correct memoize key.
|
|
173
|
+
- ttl: Time To Live, either fixed value, or callable consuming args+kwargs to return a ttl. Default None, if None no ttl is set.
|
|
174
|
+
- bypassCache: Don't get data from cache, run wrapped func and update cache. run new values.
|
|
175
|
+
- cacheType: "zipPickled" Uses pickle for arbitrary Python objects, "zipJson" Uses JSON for data that is JSON serializable.
|
|
176
|
+
- resetTtlUponRead: Set the ttl to the initial value upon reading the value from redis cache
|
|
177
|
+
- connection: Custom Redis connection to use instead of the default connection pool
|
|
178
|
+
"""
|
|
179
|
+
|
|
180
|
+
def computeMemoizeKey(*args, **kwargs) -> str:
|
|
181
|
+
if not (isinstance(memoizeKey, str) or callable(memoizeKey)):
|
|
182
|
+
raise ValueError(f"Expected memoizeKey to be Callable or a str. got {type(memoizeKey)}")
|
|
183
|
+
return computeValue(memoizeKey, *args, **kwargs)
|
|
184
|
+
|
|
185
|
+
def computeTtl(*args, **kwargs) -> int | None:
|
|
186
|
+
if ttl is None:
|
|
187
|
+
return None
|
|
188
|
+
if not (isinstance(ttl, int) or callable(ttl)):
|
|
189
|
+
raise ValueError(f"Expected ttl to be Callable or an int. got {type(ttl)}")
|
|
190
|
+
return computeValue(ttl, *args, **kwargs)
|
|
191
|
+
|
|
192
|
+
def computeByPassCache(*args, **kwargs) -> bool:
|
|
193
|
+
if not (isinstance(bypassCache, bool) or callable(bypassCache)):
|
|
194
|
+
raise ValueError(f"Expected bypassCache to be Callable or an int. got {type(bypassCache)}")
|
|
195
|
+
return computeValue(bypassCache, *args, **kwargs)
|
|
196
|
+
|
|
197
|
+
def computeTenantId(wrappedFunc: Callable[..., Any], *args, **kwargs) -> str | None:
|
|
198
|
+
boundArgs = inspect.signature(wrappedFunc).bind(*args, **kwargs)
|
|
199
|
+
boundArgs.apply_defaults()
|
|
200
|
+
tenantId = boundArgs.arguments.get("tenantId") or boundArgs.kwargs.get("tenantId")
|
|
201
|
+
# if tenantId is None:
|
|
202
|
+
# raise ValueError("tenantId not provided in either args or kwargs")
|
|
203
|
+
return tenantId
|
|
204
|
+
|
|
205
|
+
def getLockName(tenantId: str, computedMemoizeKey: str) -> str:
|
|
206
|
+
lockName = f"{config.REDISKIT_LOCK_CACHE_REDIS_MUTEX}:{tenantId}:{computedMemoizeKey}"
|
|
207
|
+
return lockName
|
|
208
|
+
|
|
209
|
+
def getParams(func, *args, **kwargs) -> tuple[str, int | None, str, str, bool]:
|
|
210
|
+
computedMemoizeKey = computeMemoizeKey(*args, **kwargs)
|
|
211
|
+
computedTtl = computeTtl(*args, **kwargs)
|
|
212
|
+
tenantId = computeTenantId(func, *args, **kwargs)
|
|
213
|
+
if tenantId is None:
|
|
214
|
+
raise ValueError("tenantId cannot be None")
|
|
215
|
+
lockName = getLockName(tenantId, computedMemoizeKey)
|
|
216
|
+
byPassCachedData = computeByPassCache(*args, **kwargs)
|
|
217
|
+
|
|
218
|
+
return computedMemoizeKey, computedTtl, tenantId, lockName, byPassCachedData
|
|
219
|
+
|
|
220
|
+
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
|
221
|
+
isAsyncFunc = inspect.iscoroutinefunction(func)
|
|
222
|
+
if isAsyncFunc:
|
|
223
|
+
|
|
224
|
+
@functools.wraps(func)
|
|
225
|
+
async def async_wrapper(*args, **kwargs) -> T:
|
|
226
|
+
computedMemoizeKey, computedTtl, tenantId, lockName, byPassCachedData = getParams(func, *args, **kwargs)
|
|
227
|
+
async with await GetAsyncRedisMutexLock(lockName, expire=60):
|
|
228
|
+
inCache = maybeDataInCache(
|
|
229
|
+
tenantId, computedMemoizeKey, computedTtl, cacheType, resetTtlUponRead, byPassCachedData, enableEncryption, storageType, connection
|
|
230
|
+
)
|
|
231
|
+
if inCache is not None:
|
|
232
|
+
return inCache
|
|
233
|
+
result = await func(*args, **kwargs)
|
|
234
|
+
if result is not None:
|
|
235
|
+
dumpData(result, tenantId, computedMemoizeKey, cacheType, computedTtl, enableEncryption, storageType, connection)
|
|
236
|
+
return result
|
|
237
|
+
|
|
238
|
+
return async_wrapper
|
|
239
|
+
|
|
240
|
+
else:
|
|
241
|
+
|
|
242
|
+
@functools.wraps(func)
|
|
243
|
+
def wrapper(*args, **kwargs) -> T:
|
|
244
|
+
computedMemoizeKey, computedTtl, tenantId, lockName, byPassCachedData = getParams(func, *args, **kwargs)
|
|
245
|
+
with GetRedisMutexLock(lockName, auto_renewal=True, expire=60):
|
|
246
|
+
inCache = maybeDataInCache(
|
|
247
|
+
tenantId, computedMemoizeKey, computedTtl, cacheType, resetTtlUponRead, byPassCachedData, enableEncryption, storageType, connection
|
|
248
|
+
)
|
|
249
|
+
if inCache is not None:
|
|
250
|
+
return inCache
|
|
251
|
+
result = func(*args, **kwargs)
|
|
252
|
+
if result is not None:
|
|
253
|
+
dumpData(result, tenantId, computedMemoizeKey, cacheType, computedTtl, enableEncryption, storageType, connection)
|
|
254
|
+
return result
|
|
255
|
+
|
|
256
|
+
return wrapper
|
|
257
|
+
|
|
258
|
+
return decorator
|