rediskit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rediskit/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ rediskit - Redis-backed performance and concurrency primitives for Python applications.
3
+
4
+ Provides caching, distributed coordination, and data protection using Redis.
5
+ """
6
+
7
from rediskit.encrypter import Encrypter
from rediskit.memoize import RedisMemoize
from rediskit.redisLock import GetAsyncRedisMutexLock, GetRedisMutexLock

# Connection helpers advertised in ``__all__`` but not bound eagerly above; they are
# looked up in ``rediskit.redisClient`` on first access by ``__getattr__`` below.
# NOTE(review): assumes these four names are defined in rediskit.redisClient — confirm.
_REDIS_CLIENT_EXPORTS = frozenset(
    {
        "InitRedisConnectionPool",
        "InitAsyncRedisConnectionPool",
        "GetRedisConnection",
        "GetAsyncRedisConnection",
    }
)

__all__ = [
    "RedisMemoize",
    "InitRedisConnectionPool",
    "InitAsyncRedisConnectionPool",
    "GetRedisConnection",
    "GetAsyncRedisConnection",
    "GetRedisMutexLock",
    "GetAsyncRedisMutexLock",
    "Encrypter",
]


def __getattr__(name: str):
    """Resolve the connection helpers from ``rediskit.redisClient`` lazily (PEP 562).

    Keeps ``from rediskit import *`` and ``rediskit.<name>`` working for every
    name in ``__all__`` without importing ``redisClient`` eagerly.

    Raises:
        AttributeError: for any other unknown attribute, or if redisClient lacks ``name``.
    """
    if name in _REDIS_CLIENT_EXPORTS:
        from rediskit import redisClient

        return getattr(redisClient, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
rediskit/config.py ADDED
@@ -0,0 +1,23 @@
1
import os

from DockerBuildSystem import TerminalTools

# Package-relative import: the installed wheel has no top-level ``src`` package.
from rediskit.utils import base64JsonToDict

# Load local env files into the process environment before any setting below is read.
# NOTE(review): "LoadDefault..." presumably does not override already-set variables — confirm.
TerminalTools.LoadDefaultEnvironmentVariablesFile("private.env")
TerminalTools.LoadDefaultEnvironmentVariablesFile(".env")

# Redis Settings
REDISKIT_REDIS_HOST = os.environ.get("REDISKIT_REDIS_HOST", "localhost")
REDISKIT_REDIS_PORT = int(os.environ.get("REDISKIT_REDIS_PORT", "6379"))
REDISKIT_REDIS_PASSWORD = os.environ.get("REDISKIT_REDIS_PASSWORD", "")
REDISKIT_REDIS_TOP_NODE = os.environ.get("REDISKIT_REDIS_TOP_NODE", "redis_kit_node")
REDISKIT_REDIS_SCAN_COUNT = int(os.environ.get("REDISKIT_REDIS_SCAN_COUNT", "10000"))
# Only the value "true" (any letter case) enables skipping; anything else keeps caching on.
REDISKIT_REDIS_SKIP_CACHING = os.environ.get("REDISKIT_REDIS_SKIP_CACHING", "false").upper() == "TRUE"

# Lock Settings (namespaces default to sub-keys of REDISKIT_REDIS_TOP_NODE)
REDISKIT_LOCK_SETTINGS_REDIS_NAMESPACE = os.environ.get("REDISKIT_LOCK_SETTINGS_REDIS_NAMESPACE", f"{REDISKIT_REDIS_TOP_NODE}:LOCK")
REDISKIT_LOCK_ASYNC_SETTINGS_REDIS_NAMESPACE = os.environ.get("REDISKIT_LOCK_ASYNC_SETTINGS_REDIS_NAMESPACE", f"{REDISKIT_REDIS_TOP_NODE}:LOCK_ASYNC")
# Prefix of the per-tenant/per-key mutex names used by RedisMemoize.
REDISKIT_LOCK_CACHE_REDIS_MUTEX = os.environ.get("REDISKIT_LOCK_CACHE_REDIS_MUTEX", "REDISKIT_LOCK_CACHE_REDIS_MUTEX")

# Base64-encoded JSON mapping such as {"__enc_v1": "<hex key>", ...}, used as Encrypter's default keys.
# NOTE(review): what base64JsonToDict returns for an unset/empty variable is defined in rediskit.utils — confirm.
REDISKIT_ENCRYPTION_SECRET = base64JsonToDict(os.environ.get("REDISKIT_ENCRYPTION_SECRET", ""))
rediskit/encrypter.py ADDED
@@ -0,0 +1,168 @@
1
+ import base64
2
+ import gzip
3
+ import json
4
+ import re
5
+
6
+ import zstd
7
+ from nacl import encoding, secret
8
+ from nacl.utils import random
9
+
10
+ from rediskit import config
11
+
12
+
13
+ class Encrypter:
14
+ VERSION_PREFIX = "__enc_v"
15
+
16
+ # EncryptedCompressedBase64Box
17
+ def __init__(self, keyHexDict: dict[str, str] = config.REDISKIT_ENCRYPTION_SECRET) -> None:
18
+ """
19
+ keysBase64 shall have the following format {"__enc_v1": "32-byte key"...,"__enc_vn": ...}
20
+ """
21
+ self.encryptionKeys = keyHexDict
22
+ self.latestVersion = list(self.encryptionKeys.keys())[-1]
23
+
24
+ def _getSecretBox(self, version: str) -> secret.SecretBox:
25
+ hexKey = self.encryptionKeys.get(version)
26
+ if not hexKey:
27
+ raise ValueError(f"Encryption key for version {version} not found")
28
+ # Ensure the key is in the proper format (32 bytes after decoding from hex)
29
+ return secret.SecretBox(hexKey.encode(), encoder=encoding.HexEncoder)
30
+
31
+ def encrypt[T: str | bytes | None](self, data: T, raiseIfEncrypted: bool = True, useZstd: bool = True) -> T:
32
+ if data is None:
33
+ return None # type: ignore # not able to check this properly
34
+ elif isinstance(data, str):
35
+ dataToEncrypt: bytes = data.encode()
36
+ isText = True
37
+ elif isinstance(data, bytes):
38
+ dataToEncrypt = data
39
+ isText = False
40
+ else:
41
+ raise ValueError("data expected to be bytes or str")
42
+
43
+ if self.isEncrypted(dataToEncrypt, raiseIfEncrypted):
44
+ return data
45
+
46
+ compressedData: bytes = zstd.compress(dataToEncrypt) if useZstd else gzip.compress(dataToEncrypt)
47
+ tagBytes = b"zstd" if useZstd else b"gzip"
48
+
49
+ cipherText = self._getSecretBox(self.latestVersion).encrypt(compressedData, encoder=encoding.Base64Encoder)
50
+ token = self.latestVersion.encode() + b"|" + tagBytes + b":" + cipherText
51
+
52
+ return token.decode() if isText else token # type: ignore # not able to check this properly
53
+
54
+ def decrypt[T: str | bytes | None](self, data: T) -> T:
55
+ if data is None:
56
+ return None # type: ignore # not able to check this properly
57
+ elif isinstance(data, str):
58
+ dataToDecrypt: bytes = data.encode()
59
+ isText = True
60
+ elif isinstance(data, bytes):
61
+ dataToDecrypt = data
62
+ isText = False
63
+ else:
64
+ raise ValueError("data expected to be bytes or str")
65
+
66
+ try:
67
+ if dataToDecrypt.startswith(self.VERSION_PREFIX.encode()):
68
+ if b"|" in dataToDecrypt:
69
+ # Format: __enc_vX|compression:ciphertext
70
+ pre, rest = dataToDecrypt.split(b"|", 1)
71
+ version = pre
72
+ compressionTag, ciphertext = rest.split(b":", 1)
73
+ elif b":" in dataToDecrypt:
74
+ # Format: __enc_vX:ciphertext
75
+ version, ciphertext = dataToDecrypt.split(b":", 1)
76
+ compressionTag = b"gzip"
77
+ else:
78
+ raise ValueError("Invalid encrypted data format")
79
+ else:
80
+ # No version: legacy, fallback to v1 + gzip
81
+ version = f"{self.VERSION_PREFIX}1".encode()
82
+ compressionTag = b"gzip"
83
+ ciphertext = dataToDecrypt
84
+ except ValueError:
85
+ raise ValueError("Invalid encrypted data format; expected version prefix.")
86
+
87
+ secret_box = self._getSecretBox(version.decode())
88
+ compressed_data = secret_box.decrypt(ciphertext, encoder=encoding.Base64Encoder)
89
+
90
+ if compressionTag == b"zstd":
91
+ deCompressed = zstd.decompress(compressed_data)
92
+ elif compressionTag == b"gzip":
93
+ deCompressed = gzip.decompress(compressed_data)
94
+ else:
95
+ raise ValueError(f"Unknown compression '{compressionTag.decode()}' in encrypted data.")
96
+
97
+ return deCompressed.decode() if isText else deCompressed # type: ignore # not able to check this properly
98
+
99
+ @staticmethod
100
+ def getEncryptionKeyVersionNumber(versionKey: str) -> int:
101
+ match = re.match(r"^__enc_v(\d+)$", versionKey)
102
+ if not match:
103
+ raise ValueError(f"Invalid version string: {versionKey}")
104
+ return int(match.group(1))
105
+
106
+ @staticmethod
107
+ def isEncrypted(data: str | bytes | None, raiseIfEncrypted: bool = False) -> bool:
108
+ # NB! can be false positive if a secret starts with "__enc_v"..., this risk is just accepted for now
109
+ if data is None:
110
+ return False
111
+ dataBytes = data.encode() if isinstance(data, str) else data
112
+ if not dataBytes.startswith(Encrypter.VERSION_PREFIX.encode()) or b":" not in dataBytes:
113
+ return False
114
+ if raiseIfEncrypted:
115
+ raise Exception("The data is already encrypted")
116
+ return True
117
+
118
+ @staticmethod
119
+ def generateNewHexKey() -> str:
120
+ """
121
+ Generate a new 32-byte key and return it as a hex-encoded string.
122
+
123
+ >>> key = Encrypter.generateNewHexKey()
124
+ >>> isinstance(key, str)
125
+ True
126
+ >>> len(key) == 64
127
+ True
128
+ """
129
+ newKey = random(secret.SecretBox.KEY_SIZE) # Generates 32 random bytes.
130
+ hexKey = encoding.HexEncoder.encode(newKey).decode("utf-8")
131
+ return hexKey
132
+
133
+ @staticmethod
134
+ def encodeKeysDictToBase64(keys: dict) -> str:
135
+ """
136
+ Convert the keys dictionary to a JSON string and encode it in base64.
137
+
138
+ >>> Encrypter.encodeKeysDictToBase64({"v1": "abcdef"})
139
+ 'eyJ2MSI6ICJhYmNkZWYifQ=='
140
+ """
141
+ json_str = json.dumps(keys)
142
+ base64_str = base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
143
+ return base64_str
144
+
145
+ @staticmethod
146
+ def decodeKeysFromBase64(encodedKeys: str) -> dict:
147
+ """
148
+ Decode a base64-encoded JSON string into a dictionary of keys.
149
+
150
+ >>> Encrypter.decodeKeysFromBase64("eyJ2MSI6ICJhYmNkZWYifQ==")
151
+ {'v1': 'abcdef'}
152
+ """
153
+ try:
154
+ # Decode the base64 string into a JSON string.
155
+ json_str = base64.b64decode(encodedKeys).decode("utf-8")
156
+ # Convert the JSON string into a dictionary.
157
+ keys_dict = json.loads(json_str)
158
+ return keys_dict
159
+ except Exception as e:
160
+ raise ValueError("Failed to decode keys from base64 string.") from e
161
+
162
+ @staticmethod
163
+ def appendNewEncryptKey(currentKeys: str):
164
+ decodedkeys = Encrypter.decodeKeysFromBase64(currentKeys)
165
+ latestKey = list(decodedkeys.keys())[-1]
166
+ newVersion = f"{Encrypter.VERSION_PREFIX}{Encrypter.getEncryptionKeyVersionNumber(latestKey) + 1}"
167
+ decodedkeys.update({newVersion: Encrypter.generateNewHexKey()})
168
+ return {"newKeysDecoded": decodedkeys, "EncryptedKeys": Encrypter.encodeKeysDictToBase64(decodedkeys)}
rediskit/memoize.py ADDED
@@ -0,0 +1,258 @@
1
+ import base64
2
+ import functools
3
+ import inspect
4
+ import json
5
+ import logging
6
+ import pickle
7
+ from typing import Any, Awaitable, Callable, Literal, overload
8
+
9
+ import zstd
10
+ from redis import Redis
11
+
12
+ from rediskit import config, redisClient
13
+ from rediskit.encrypter import Encrypter
14
+ from rediskit.redisClient import HGetCacheFromRedis, HSetCacheToRedis
15
+ from rediskit.redisLock import GetAsyncRedisMutexLock, GetRedisMutexLock
16
+
17
log = logging.getLogger(name=__name__)

# How cached values are turned into bytes: pickle (any Python object) or JSON.
CacheTypeOptions = Literal["zipPickled", "zipJson"]
# Redis layout: a plain key, or a field of a hash (key split on its last ':').
RedisStorageOptions = Literal["string", "hash"]
20
+
21
+
22
def splitHashKey(key: str) -> tuple[str, str]:
    """Split ``key`` at its last ':' into ``(hashKey, field)`` for hash storage.

    >>> splitHashKey("users:42:profile")
    ('users:42', 'profile')

    Raises:
        ValueError: if ``key`` contains no ':'.
    """
    hashName, separator, fieldName = key.rpartition(":")
    if not separator:
        raise ValueError("Cannot use a single-part key with hash storage.")
    return hashName, fieldName
27
+
28
+
29
def compressAndSign(data: Any, serializeFn: Callable[[Any], bytes], enableEncryption: bool = False) -> str:
    """Serialize ``data`` with ``serializeFn``, pack it, and return it as base64 text.

    Packing is ``Encrypter().encrypt`` (which compresses internally) when
    ``enableEncryption`` is set, otherwise plain ``zstd.compress``.
    Despite the name, no separate signature is computed here.
    """
    rawBytes = serializeFn(data)
    packed = Encrypter().encrypt(rawBytes) if enableEncryption else zstd.compress(rawBytes)
    return base64.b64encode(packed).decode("utf-8")
37
+
38
+
39
def verifyAndDecompress(payload: bytes, deserializeFn: Callable[[bytes], Any], enableEncryption: bool = False) -> Any:
    """Undo the packing step of ``compressAndSign`` (after base64 decoding) and deserialize.

    ``payload`` is decrypted with ``Encrypter().decrypt`` when ``enableEncryption`` is set,
    otherwise zstd-decompressed; the resulting bytes go to ``deserializeFn``.
    """
    if not enableEncryption:
        return deserializeFn(zstd.decompress(payload))
    return deserializeFn(Encrypter().decrypt(payload))
45
+
46
+
47
def deserializeData(
    data: Any,
    cacheType: CacheTypeOptions,
    enableEncryption: bool = False,
) -> Any:
    """Turn a payload written by ``serializeData`` back into a Python object.

    Params:
        data: base64 payload as read from Redis.
        cacheType: "zipPickled" or "zipJson"; must match the type used when writing.
        enableEncryption: whether the payload was encrypted when written.

    Returns:
        The deserialized object (whatever was cached), not raw bytes.

    Raises:
        ValueError: if ``cacheType`` is unknown.
    """

    def _loadJson(raw: bytes) -> Any:
        return json.loads(raw.decode("utf-8"))

    if cacheType == "zipPickled":
        # SECURITY: pickle.loads can execute arbitrary code from crafted input; use
        # "zipPickled" only when everything that can write to this Redis is trusted.
        deserializeFn: Callable[[bytes], Any] = pickle.loads
    elif cacheType == "zipJson":
        deserializeFn = _loadJson
    else:
        raise ValueError("Unknown cacheType specified.")

    return verifyAndDecompress(base64.b64decode(data), deserializeFn, enableEncryption)
60
+
61
+
62
+ def serializeData(
63
+ data: Any,
64
+ cacheType: CacheTypeOptions,
65
+ enableEncryption: bool = False,
66
+ ) -> str:
67
+ if cacheType == "zipPickled":
68
+ payload = compressAndSign(data, lambda d: pickle.dumps(d), enableEncryption)
69
+ elif cacheType == "zipJson":
70
+ payload = compressAndSign(data, lambda d: json.dumps(d).encode("utf-8"), enableEncryption)
71
+ else:
72
+ raise ValueError("Unknown cacheType specified.")
73
+ return payload
74
+
75
+
76
+ def computeValue[T](param: T | Callable[..., T], *args, **kwargs) -> T:
77
+ if callable(param):
78
+ sig = inspect.signature(param)
79
+ accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
80
+
81
+ if accepts_kwargs:
82
+ # Pass all kwargs directly
83
+ value = param(*args, **kwargs)
84
+ else:
85
+ # Filter only matching kwargs
86
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
87
+ bound = sig.bind(*args, **filtered_kwargs)
88
+ bound.apply_defaults()
89
+ value = param(*bound.args, **bound.kwargs)
90
+ return value
91
+ else:
92
+ return param
93
+
94
+
95
+ def maybeDataInCache(
96
+ tenantId: str | None,
97
+ computedMemoizeKey: str,
98
+ computedTtl: int | None,
99
+ cacheType: CacheTypeOptions,
100
+ resetTtlUponRead: bool,
101
+ byPassCachedData: bool,
102
+ enableEncryption: bool,
103
+ storageType: RedisStorageOptions = "string",
104
+ connection: Redis | None = None,
105
+ ) -> Any:
106
+ if byPassCachedData:
107
+ log.info(f"Cache bypassed for tenantId: {tenantId}, key {computedMemoizeKey}")
108
+ return None
109
+
110
+ cachedData = None
111
+ if storageType == "string":
112
+ cached = redisClient.LoadBlobFromRedis(
113
+ tenantId, match=computedMemoizeKey, setTtlOnRead=computedTtl if resetTtlUponRead and computedTtl is not None else None, connection=connection
114
+ )
115
+ if cached:
116
+ log.info(f"Cache hit tenantId: {tenantId}, key: {computedMemoizeKey}")
117
+ cachedData = cached
118
+ elif storageType == "hash":
119
+ hashKey, field = splitHashKey(computedMemoizeKey)
120
+ cachedDict = HGetCacheFromRedis(
121
+ tenantId, hashKey, field, setTtlOnRead=computedTtl if resetTtlUponRead and computedTtl is not None else None, connection=connection
122
+ )
123
+ if cachedDict and field in cachedDict and cachedDict[field] is not None:
124
+ log.info(f"HASH cache hit tenantId: {tenantId}, key: {hashKey}, field: {field}")
125
+ cachedData = cachedDict[field]
126
+ else:
127
+ raise ValueError(f"Unknown storageType: {storageType}")
128
+
129
+ if cachedData:
130
+ return deserializeData(cachedData, cacheType, enableEncryption)
131
+ else:
132
+ log.info(f"No cache found tenantId: {tenantId}, key: {computedMemoizeKey}")
133
+ return None
134
+
135
+
136
+ def dumpData(
137
+ data: Any,
138
+ tenantId: str,
139
+ computedMemoizeKey: str,
140
+ cacheType: CacheTypeOptions,
141
+ computedTtl: int | None,
142
+ enableEncryption: bool,
143
+ storageType: RedisStorageOptions = "string",
144
+ connection: Redis | None = None,
145
+ ) -> None:
146
+ payload = serializeData(data, cacheType, enableEncryption)
147
+ if storageType == "string":
148
+ redisClient.DumpBlobToRedis(tenantId, computedMemoizeKey, payload=payload, ttl=computedTtl, connection=connection)
149
+ elif storageType == "hash":
150
+ hashKey, field = splitHashKey(computedMemoizeKey)
151
+ HSetCacheToRedis(tenantId, hashKey, fields={field: payload}, ttl=computedTtl, connection=connection)
152
+ else:
153
+ raise ValueError(f"Unknown storageType: {storageType}")
154
+
155
+
156
+ def RedisMemoize[T](
157
+ memoizeKey: Callable[..., str] | str,
158
+ ttl: Callable[..., int] | int | None = None,
159
+ bypassCache: Callable[..., bool] | bool = False,
160
+ cacheType: CacheTypeOptions = "zipJson",
161
+ resetTtlUponRead: bool = True,
162
+ enableEncryption: bool = False,
163
+ storageType: RedisStorageOptions = "string",
164
+ connection: Redis | None = None,
165
+ ) -> Callable[[Callable[..., T]], Callable[..., T]]:
166
+ """Caches the result of any function in Redis using either pickle or JSON.
167
+
168
+ The decorated function must have 'tenantId' as an arg or kwarg.
169
+
170
+ Params:
171
+ -------
172
+ - memoizeKey: Callable computing a memoize key based on wrapped funcs args and kwargs, callable shall define the logic to compute the correct memoize key.
173
+ - ttl: Time To Live, either fixed value, or callable consuming args+kwargs to return a ttl. Default None, if None no ttl is set.
174
+ - bypassCache: Don't get data from cache, run wrapped func and update cache. run new values.
175
+ - cacheType: "zipPickled" Uses pickle for arbitrary Python objects, "zipJson" Uses JSON for data that is JSON serializable.
176
+ - resetTtlUponRead: Set the ttl to the initial value upon reading the value from redis cache
177
+ - connection: Custom Redis connection to use instead of the default connection pool
178
+ """
179
+
180
+ def computeMemoizeKey(*args, **kwargs) -> str:
181
+ if not (isinstance(memoizeKey, str) or callable(memoizeKey)):
182
+ raise ValueError(f"Expected memoizeKey to be Callable or a str. got {type(memoizeKey)}")
183
+ return computeValue(memoizeKey, *args, **kwargs)
184
+
185
+ def computeTtl(*args, **kwargs) -> int | None:
186
+ if ttl is None:
187
+ return None
188
+ if not (isinstance(ttl, int) or callable(ttl)):
189
+ raise ValueError(f"Expected ttl to be Callable or an int. got {type(ttl)}")
190
+ return computeValue(ttl, *args, **kwargs)
191
+
192
+ def computeByPassCache(*args, **kwargs) -> bool:
193
+ if not (isinstance(bypassCache, bool) or callable(bypassCache)):
194
+ raise ValueError(f"Expected bypassCache to be Callable or an int. got {type(bypassCache)}")
195
+ return computeValue(bypassCache, *args, **kwargs)
196
+
197
+ def computeTenantId(wrappedFunc: Callable[..., Any], *args, **kwargs) -> str | None:
198
+ boundArgs = inspect.signature(wrappedFunc).bind(*args, **kwargs)
199
+ boundArgs.apply_defaults()
200
+ tenantId = boundArgs.arguments.get("tenantId") or boundArgs.kwargs.get("tenantId")
201
+ # if tenantId is None:
202
+ # raise ValueError("tenantId not provided in either args or kwargs")
203
+ return tenantId
204
+
205
+ def getLockName(tenantId: str, computedMemoizeKey: str) -> str:
206
+ lockName = f"{config.REDISKIT_LOCK_CACHE_REDIS_MUTEX}:{tenantId}:{computedMemoizeKey}"
207
+ return lockName
208
+
209
+ def getParams(func, *args, **kwargs) -> tuple[str, int | None, str, str, bool]:
210
+ computedMemoizeKey = computeMemoizeKey(*args, **kwargs)
211
+ computedTtl = computeTtl(*args, **kwargs)
212
+ tenantId = computeTenantId(func, *args, **kwargs)
213
+ if tenantId is None:
214
+ raise ValueError("tenantId cannot be None")
215
+ lockName = getLockName(tenantId, computedMemoizeKey)
216
+ byPassCachedData = computeByPassCache(*args, **kwargs)
217
+
218
+ return computedMemoizeKey, computedTtl, tenantId, lockName, byPassCachedData
219
+
220
+ def decorator(func: Callable[..., T]) -> Callable[..., T]:
221
+ isAsyncFunc = inspect.iscoroutinefunction(func)
222
+ if isAsyncFunc:
223
+
224
+ @functools.wraps(func)
225
+ async def async_wrapper(*args, **kwargs) -> T:
226
+ computedMemoizeKey, computedTtl, tenantId, lockName, byPassCachedData = getParams(func, *args, **kwargs)
227
+ async with await GetAsyncRedisMutexLock(lockName, expire=60):
228
+ inCache = maybeDataInCache(
229
+ tenantId, computedMemoizeKey, computedTtl, cacheType, resetTtlUponRead, byPassCachedData, enableEncryption, storageType, connection
230
+ )
231
+ if inCache is not None:
232
+ return inCache
233
+ result = await func(*args, **kwargs)
234
+ if result is not None:
235
+ dumpData(result, tenantId, computedMemoizeKey, cacheType, computedTtl, enableEncryption, storageType, connection)
236
+ return result
237
+
238
+ return async_wrapper
239
+
240
+ else:
241
+
242
+ @functools.wraps(func)
243
+ def wrapper(*args, **kwargs) -> T:
244
+ computedMemoizeKey, computedTtl, tenantId, lockName, byPassCachedData = getParams(func, *args, **kwargs)
245
+ with GetRedisMutexLock(lockName, auto_renewal=True, expire=60):
246
+ inCache = maybeDataInCache(
247
+ tenantId, computedMemoizeKey, computedTtl, cacheType, resetTtlUponRead, byPassCachedData, enableEncryption, storageType, connection
248
+ )
249
+ if inCache is not None:
250
+ return inCache
251
+ result = func(*args, **kwargs)
252
+ if result is not None:
253
+ dumpData(result, tenantId, computedMemoizeKey, cacheType, computedTtl, enableEncryption, storageType, connection)
254
+ return result
255
+
256
+ return wrapper
257
+
258
+ return decorator