clue-api 1.0.0.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clue/.gitignore +21 -0
- clue/__init__.py +0 -0
- clue/api/__init__.py +211 -0
- clue/api/base.py +99 -0
- clue/api/v1/__init__.py +82 -0
- clue/api/v1/actions.py +92 -0
- clue/api/v1/auth.py +243 -0
- clue/api/v1/configs.py +83 -0
- clue/api/v1/fetchers.py +94 -0
- clue/api/v1/lookup.py +221 -0
- clue/api/v1/registration.py +109 -0
- clue/api/v1/static.py +94 -0
- clue/app.py +166 -0
- clue/cache/__init__.py +129 -0
- clue/common/__init__.py +0 -0
- clue/common/classification.py +1006 -0
- clue/common/classification.yml +130 -0
- clue/common/dict_utils.py +130 -0
- clue/common/exceptions.py +199 -0
- clue/common/forge.py +152 -0
- clue/common/json_utils.py +10 -0
- clue/common/list_utils.py +11 -0
- clue/common/logging/__init__.py +291 -0
- clue/common/logging/audit.py +157 -0
- clue/common/logging/format.py +42 -0
- clue/common/regex.py +31 -0
- clue/common/str_utils.py +213 -0
- clue/common/swagger.py +139 -0
- clue/common/uid.py +47 -0
- clue/config.py +60 -0
- clue/constants/__init__.py +0 -0
- clue/constants/supported_types.py +38 -0
- clue/cronjobs/__init__.py +30 -0
- clue/cronjobs/plugins.py +32 -0
- clue/error.py +129 -0
- clue/gunicorn_config.py +29 -0
- clue/healthz.py +74 -0
- clue/helper/discover.py +53 -0
- clue/helper/headers.py +30 -0
- clue/helper/oauth.py +128 -0
- clue/models/__init__.py +0 -0
- clue/models/actions.py +243 -0
- clue/models/config.py +456 -0
- clue/models/fetchers.py +136 -0
- clue/models/graph.py +162 -0
- clue/models/model_list.py +52 -0
- clue/models/network.py +430 -0
- clue/models/results/__init__.py +34 -0
- clue/models/results/base.py +10 -0
- clue/models/results/graph.py +26 -0
- clue/models/results/image.py +22 -0
- clue/models/results/status.py +55 -0
- clue/models/results/validation.py +57 -0
- clue/models/selector.py +67 -0
- clue/models/utils.py +52 -0
- clue/models/validators.py +19 -0
- clue/patched.py +8 -0
- clue/plugin/__init__.py +1008 -0
- clue/plugin/helpers/__init__.py +0 -0
- clue/plugin/helpers/central_server.py +27 -0
- clue/plugin/helpers/email_render.py +228 -0
- clue/plugin/helpers/token.py +34 -0
- clue/plugin/helpers/trino.py +103 -0
- clue/plugin/interactive.py +270 -0
- clue/plugin/models.py +19 -0
- clue/plugin/utils.py +78 -0
- clue/remote/__init__.py +0 -0
- clue/remote/datatypes/__init__.py +130 -0
- clue/remote/datatypes/cache.py +62 -0
- clue/remote/datatypes/events.py +118 -0
- clue/remote/datatypes/hash.py +193 -0
- clue/remote/datatypes/queues/__init__.py +0 -0
- clue/remote/datatypes/queues/comms.py +62 -0
- clue/remote/datatypes/set.py +96 -0
- clue/remote/datatypes/user_quota_tracker.py +54 -0
- clue/security/__init__.py +211 -0
- clue/security/obo.py +95 -0
- clue/security/utils.py +34 -0
- clue/services/action_service.py +186 -0
- clue/services/auth_service.py +348 -0
- clue/services/config_service.py +38 -0
- clue/services/fetcher_service.py +203 -0
- clue/services/jwt_service.py +233 -0
- clue/services/lookup_service.py +786 -0
- clue/services/type_service.py +165 -0
- clue/services/user_service.py +152 -0
- clue_api-1.0.0.dev7.dist-info/METADATA +111 -0
- clue_api-1.0.0.dev7.dist-info/RECORD +91 -0
- clue_api-1.0.0.dev7.dist-info/WHEEL +4 -0
- clue_api-1.0.0.dev7.dist-info/entry_points.txt +8 -0
- clue_api-1.0.0.dev7.dist-info/licenses/LICENSE +11 -0
clue/plugin/utils.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
|
|
5
|
+
from flask import request
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from clue.common.logging import get_logger
|
|
9
|
+
|
|
10
|
+
# Default settings
# Upper bound applied to the ``limit`` request argument (overridable via the MAX_LIMIT env var).
MAX_LIMIT = int(os.environ.get("MAX_LIMIT", 100))
# Fallback ``max_timeout`` in seconds, used when the request omits it or sends an unparsable value
# (overridable via the MAX_TIMEOUT env var).
MAX_TIMEOUT = float(os.environ.get("MAX_TIMEOUT", 3))

logger = get_logger(__file__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Params(BaseModel):
    "A model built to parse arguments provided in the request object into a model"

    deadline: float
    "The epoch the central server wants to have a response by"

    max_timeout: float
    "The raw timeout value provided by the user"

    annotate: bool
    "Should the plugin return annotations about the given selector(s)?"

    raw: bool
    "Should the plugin return the raw data applicable to the given selector(s)?"

    limit: int
    "What is the maximum number of query entries that should be returned?"

    use_cache: bool
    "May cached results be used? False when the request asked to bypass the cache (no_cache)."

    @classmethod
    def from_request(cls):
        """Create a Params object from flask's request object.

        Query arguments read: ``max_timeout``, ``deadline``, ``no_annotation``, ``no_cache``,
        ``include_raw`` and ``limit``. Unparsable or non-finite (``inf``/``nan``) timing values
        fall back to their defaults instead of failing the request.

        Raises:
            RuntimeError: if a positive ``deadline`` is already in the past.
        """
        import math

        def _format_epoch(epoch: float) -> str:
            # Render an epoch for logging in UTC (the same zone as ``current_time``) without
            # raising, even for timestamps outside the range ``datetime`` can represent.
            try:
                return str(datetime.fromtimestamp(epoch, timezone.utc))
            except (OverflowError, OSError, ValueError):
                return str(epoch)

        max_timeout = request.args.get("max_timeout", MAX_TIMEOUT)
        try:
            max_timeout = float(max_timeout)
        except (ValueError, TypeError):
            max_timeout = MAX_TIMEOUT
        # float() accepts "inf"/"nan", which would poison every deadline computation below
        if not math.isfinite(max_timeout):
            max_timeout = MAX_TIMEOUT

        default_deadline = time.time() + max_timeout
        deadline = request.args.get("deadline", default_deadline, type=float)
        if not math.isfinite(deadline):
            deadline = default_deadline

        current_time = datetime.now(timezone.utc)
        # Non-positive deadlines skip the expiry check
        if deadline > 0 and current_time.timestamp() > deadline:
            logger.warning(
                "Deadline %s was earlier than the current time, %s",
                _format_epoch(deadline),
                str(current_time),
            )

            raise RuntimeError("Deadline exceeded")

        logger.debug(
            "Deadline %s hits in %sms",
            _format_epoch(deadline),
            round((deadline - current_time.timestamp()) * 1000),
        )

        # The no_* flags are negated: an empty value, "true" or "1" turns the feature off
        annotate = request.args.get("no_annotation", "false").lower() not in ("true", "1", "")
        use_cache = request.args.get("no_cache", "false").lower() not in ("true", "1", "")
        raw = request.args.get("include_raw", "false").lower() in ("true", "1", "")
        limit = request.args.get("limit", 100, type=int)
        if limit > int(MAX_LIMIT):
            limit = int(MAX_LIMIT)

        return cls(
            deadline=deadline, max_timeout=max_timeout, annotate=annotate, raw=raw, limit=limit, use_cache=use_cache
        )

    def __str__(self):
        # Make a string representation of the params that can be used for caching purposes.
        # Deadline and max_timeout are explicitly ignored, otherwise it would never hit the cache.
        return f"a={self.annotate},r={self.raw},l={self.limit}"
|
clue/remote/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
|
|
8
|
+
import redis
|
|
9
|
+
from packaging.version import parse
|
|
10
|
+
|
|
11
|
+
from clue.common.logging import get_logger
|
|
12
|
+
from clue.common.uid import get_random_id
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__file__)

# Add a version warning if the redis python client is <= 2.10.0. Older versions
# have a connection bug that can manifest with the dispatcher.
if parse(redis.__version__) <= parse("2.10.0"):
    import warnings

    warnings.warn(
        "%s works best with redis > 2.10.0. You're running"
        " redis %s. You should upgrade." % (__name__, redis.__version__)
    )


# Process-wide cache of shared connection pools keyed by (host, port); filled lazily by get_pool().
pool: dict[tuple[str, int], redis.ConnectionPool] = {}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def now_as_iso():
    """Return the current UTC time as an ISO 8601 string with a trailing ``Z``.

    Example output: ``2024-01-01T12:00:00.123456Z`` (``isoformat`` omits the microseconds when
    they are zero). Uses a timezone-aware conversion instead of ``datetime.utcfromtimestamp``,
    which is deprecated since Python 3.12.
    """
    from datetime import timezone

    # Convert to aware UTC, then drop tzinfo so isoformat() doesn't append "+00:00"
    s = datetime.fromtimestamp(time.time(), timezone.utc).replace(tzinfo=None).isoformat()
    return f"{s}Z"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def reply_queue_name(prefix=None, suffix=None):
    """Return a unique queue name shaped ``[<prefix>-]<random id>[-<suffix>]``.

    Falsy ``prefix`` / ``suffix`` values are left out; ``suffix`` is passed through ``str``.
    """
    head = [prefix] if prefix else []
    unique = [get_random_id()]
    tail = [str(suffix)] if suffix else []
    return "-".join(head + unique + tail)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def retry_call(func, *args, **kw):
    """Call ``func(*args, **kw)`` and return its result, retrying on Redis connection failures.

    ``redis.ConnectionError`` / ``ConnectionResetError`` are retried with an exponential backoff
    that starts at 2**-7 seconds and is capped at 2**2 seconds per wait. After 10 retries, the
    11th consecutive failure is logged and re-raised. Any other exception propagates immediately.
    """
    max_attempts = 10
    maximum = 2
    backoff_exp = -7
    failures = 0

    while True:
        try:
            result = func(*args, **kw)
        except (redis.ConnectionError, ConnectionResetError):
            failures += 1
            if failures > max_attempts:
                logger.exception("Redis connection failed.")
                raise
            logger.exception("No connection to Redis, reconnecting...")
            time.sleep(2**backoff_exp)
            backoff_exp = min(backoff_exp + 1, maximum)
            continue

        # The exponent only moves after a failure, so this success follows at least one retry
        if backoff_exp != -7:
            logger.info("Reconnected to Redis!")
        return result
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _redis_ssl_kwargs(host: str) -> dict:
|
|
77
|
+
return dict(ssl_ca_certs=os.environ.get(f"{host.upper()}_ROOT_CA_PATH", "/etc/clue/ssl/clue_root-ca.crt"))
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_client(host, port, private):
    """Return a Redis client for ``host``:``port``.

    Args:
        host: Redis hostname, or an existing ``redis.Redis`` client, which is returned unchanged.
            When empty, ``config.core.redis.host`` is used.
        port: Redis port. When empty, ``config.core.redis.port`` is used.
        private: When true, build a standalone client; otherwise the client uses the shared
            pool returned by ``get_pool``.

    TLS is switched on when the CA bundle resolved by ``_redis_ssl_kwargs`` exists on disk.
    """
    # A structure may hand us an already-built client instead of a hostname
    if isinstance(host, (redis.Redis, redis.StrictRedis)):
        return host

    if not (host and port):
        from clue.config import config

        host = host or config.core.redis.host
        port = int(port or config.core.redis.port)

    candidate = _redis_ssl_kwargs(host)
    use_ssl = os.path.exists(candidate["ssl_ca_certs"])

    if not private:
        return redis.StrictRedis(connection_pool=get_pool(host, port, ssl=use_ssl), socket_keepalive=True)

    extra = {**candidate, "ssl": True} if use_ssl else {}
    return redis.StrictRedis(host=host, port=port, socket_keepalive=True, **extra)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_pool(host, port, ssl=False):
    """Return the shared blocking connection pool for ``host``:``port``, creating it on first use.

    Pools are cached in the module-level ``pool`` dict; each one allows at most 200 connections.

    Args:
        host: Redis hostname.
        port: Redis port.
        ssl: When true, a newly created pool uses ``SSLConnection`` with the CA bundle resolved
            by ``_redis_ssl_kwargs``.

    Note:
        The cache key is ``(host, port)`` only, so ``ssl`` takes effect just when the pool for
        that address is first created.
    """
    key = (host, port)
    connection_pool = pool.get(key, None)
    if connection_pool is not None:
        return connection_pool

    connection_class = redis.connection.Connection
    # Keepalive has to be configured on the pool's connections: redis-py ignores connection
    # options given to Redis() together with an explicit connection_pool.
    connection_kwargs = {"socket_keepalive": True}
    if ssl:
        connection_class = redis.connection.SSLConnection  # type: ignore[assignment]
        connection_kwargs.update(_redis_ssl_kwargs(host))

    connection_pool = redis.BlockingConnectionPool(
        host=host, port=port, max_connections=200, connection_class=connection_class, **connection_kwargs
    )
    pool[key] = connection_pool

    return connection_pool
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def decode(data):
    """Deserialize a JSON message read from a queue.

    Args:
        data: JSON payload as ``str``, ``bytes`` or ``bytearray``.

    Returns:
        The decoded object, or ``None`` (after logging a warning) when ``data`` is not valid JSON
        or is not a type ``json.loads`` accepts.
    """
    try:
        return json.loads(data)
    except (TypeError, ValueError):
        # TypeError: data is not str/bytes (e.g. None or an int). ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError for undecodable bytes.
        logger.warning("Invalid data on queue: %s", str(data))
        return None
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from redis.exceptions import ConnectionError
|
|
4
|
+
|
|
5
|
+
from clue.common.uid import get_id_from_data
|
|
6
|
+
from clue.remote.datatypes import get_client, retry_call
|
|
7
|
+
|
|
8
|
+
DEFAULT_TTL = 60 * 60  # 1 Hour: default time-to-live, in seconds, for RedisCache entries
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RedisCache(object):
    """JSON value cache stored in Redis under ``<prefix>-<key>`` names with a sliding TTL.

    Every ``set`` stores the value together with its expiry, and every successful ``get``
    refreshes it, so entries expire ``ttl`` seconds after their last use.
    """

    def __init__(self, prefix="brl_cache", host=None, port=None, ttl=DEFAULT_TTL):
        self.c = get_client(host, port, False)
        self.prefix = prefix
        self.ttl = ttl

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release (the client comes from the shared pool). Defined so that
        # ``with RedisCache() as cache:`` works; exceptions are not suppressed.
        return None

    def _get_key(self, name):
        return f"{self.prefix}-{name}"

    def clear(self):
        """Delete every entry belonging to this cache (all keys matching ``<prefix>-*``)."""
        # SCAN walks the keyspace incrementally instead of blocking the server like KEYS does.
        keys = retry_call(lambda: list(self.c.scan_iter(match="%s-*" % self.prefix)))
        if keys:
            retry_call(self.c.delete, *keys)

    def create_key(self, *args):
        """Derive a deterministic cache key: ``args`` joined with ``-`` and passed to get_id_from_data."""
        key_str = "-".join([str(x) for x in args])
        return get_id_from_data(key_str)

    def get(self, key, ttl=None):
        """Return the cached value for ``key``, or ``None`` when it is absent.

        A hit resets the entry's expiry to ``ttl`` seconds (or the cache default).
        """
        cache_name = self._get_key(key)

        item = retry_call(self.c.get, cache_name)
        if not item:
            return item

        # Reset the expiry while the entry is still being used
        retry_call(self.c.expire, cache_name, ttl or self.ttl)
        return json.loads(item)

    def ready(self):
        """Return True when the Redis server answers a PING, False on connection errors."""
        try:
            self.c.ping()
        except ConnectionError:
            return False

        return True

    def set(self, key, value, ttl=None):
        """Store ``value`` (JSON-encoded) under ``key``, expiring after ``ttl`` (or the default) seconds."""
        cache_name = self._get_key(key)

        # A single SET ... EX command writes value and expiry atomically, so the key can never
        # be left behind without a TTL.
        retry_call(self.c.set, cache_name, json.dumps(value), ex=ttl or self.ttl)


Cache = RedisCache
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import threading
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar
|
|
7
|
+
|
|
8
|
+
from clue.remote.datatypes import get_client, retry_call
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from redis import Redis
|
|
12
|
+
from redis.client import PubSub
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)

# Payload type handled by EventSender (before serialization) and EventWatcher (after deserialization).
MessageType = TypeVar("MessageType")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class EventSender(Generic[MessageType]):
    """Publish serialized messages on Redis channels that share a common dotted prefix.

    Channel names are lower-cased and built as ``<prefix>.<name>``.
    """

    def __init__(
        self, prefix: str, host=None, port=None, private=None, serializer: Callable[[MessageType], str] = json.dumps
    ):
        self.client: Redis[Any] = get_client(host, port, private)
        normalized = prefix.lower()
        self.prefix = normalized if normalized.endswith(".") else normalized + "."
        self.serializer = serializer

    def send(self, name: str, data: MessageType):
        """Publish ``data``, passed through ``serializer``, on the lower-cased ``<prefix><name>`` channel."""
        channel = "".join((self.prefix, name.lower().lstrip(".")))
        payload = self.serializer(data)
        retry_call(self.client.publish, channel, payload)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PubSubWorkerThread(threading.Thread):
    """Worker thread that continually reads messages from pubsub.

    We reimplement the worker thread here rather than use the one in the redis
    package because we want to use the subscribe messages for disconnect/reconnect
    detection. Every ``psubscribe`` confirmation is forwarded to the pattern's
    handler as ``None``; with ``skip_first_refresh`` the first confirmation per
    pattern is ignored.

    Args:
        watcher: Object exposing ``pubsub`` and ``client`` attributes (e.g. an EventWatcher).
        exception_handler: Optional ``handler(exc, pubsub, thread)`` called when reading or
            dispatching a message fails. When omitted, the exception is logged.
        skip_first_refresh: Don't notify handlers for the initial subscription confirmation.
    """

    def __init__(self, watcher, exception_handler=None, skip_first_refresh=True):
        super().__init__(daemon=True)
        self.watcher = watcher
        self.exception_handler = exception_handler
        self._running = threading.Event()
        # Set by stop(). Unlike clearing _running, this is still honoured when stop() is called
        # before run() has started the loop.
        self._stop_requested = threading.Event()
        self.skip_first_refresh = skip_first_refresh

    def run(self):
        if self._running.is_set():
            return
        self._running.set()

        pubsub: PubSub = self.watcher.pubsub
        ping = self.watcher.client.ping
        sleep_time = 1
        initialized = set()

        try:
            while self._running.is_set() and not self._stop_requested.is_set():
                try:
                    message = pubsub.get_message(ignore_subscribe_messages=False, timeout=sleep_time)
                    if message is not None and message["type"] == "psubscribe":
                        channel = message.get("channel")
                        if channel is None:
                            continue

                        if self.skip_first_refresh and channel not in initialized:
                            initialized.add(channel)
                            continue

                        handler = pubsub.patterns.get(channel, None)
                        if handler:
                            handler(None)

                except BaseException as e:
                    if self.exception_handler is not None:
                        self.exception_handler(e, pubsub, self)
                    else:
                        # Present the error
                        logger.exception("Exception in pubsub watcher:")

                    # Wait until we can reach the server
                    retry_call(ping)
        finally:
            # Always release the pubsub connection, even if reconnecting ultimately failed
            pubsub.close()

    def stop(self):
        # trip the flags so the run loop exits. the run loop will
        # close the pubsub connection, which disconnects the socket
        # and returns the connection to the pool.
        self._stop_requested.set()
        self._running.clear()
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class EventWatcher(Generic[MessageType]):
    """Subscribe to Redis channel patterns and dispatch their messages to callbacks.

    Messages are read by a background :class:`PubSubWorkerThread` launched by ``start``.
    Callbacks receive the deserialized payload, or ``None`` when the worker thread reports a
    (re)subscription of their pattern.
    """

    def __init__(self, host=None, port=None, deserializer: Callable[[str], MessageType] = json.loads):
        self.client: Redis[Any] = get_client(host, port, False)
        self.pubsub = retry_call(self.client.pubsub)
        # Keep subscribe confirmations: the worker thread uses them to detect reconnects
        self.pubsub.ignore_subscribe_messages = False
        self.thread: Optional[PubSubWorkerThread] = None
        self.deserializer = deserializer
        self.skip_first_refresh = True

    def register(self, path: str, callback: Callable[[Optional[MessageType]], None]):
        """Subscribe ``callback`` to the (lower-cased) channel pattern ``path``."""

        def _dispatch(message: Optional[dict[str, Any]]):
            if message is None:
                # Refresh notification coming from the worker thread
                callback(None)
                return
            if message["type"] != "pmessage":
                return
            callback(self.deserializer(message.get("data", "")))

        self.pubsub.psubscribe(**{path.lower(): _dispatch})

    def start(self):
        """Launch and return the background thread that reads and dispatches messages."""
        worker = PubSubWorkerThread(self, skip_first_refresh=self.skip_first_refresh)
        self.thread = worker
        worker.start()
        return worker

    def stop(self):
        """Ask the background thread, if one was started, to stop."""
        thread = self.thread
        if thread is None:
            return
        thread.stop()
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
|
|
6
|
+
|
|
7
|
+
from clue.common.exceptions import ClueTypeError
|
|
8
|
+
from clue.remote.datatypes import get_client, retry_call
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from redis import Redis
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Lua: delete field ARGV[1] of hash KEYS[1] only if its current value equals ARGV[2] (the caller
# passes the JSON-encoded expected value). Returns 1 when the field was removed, 0 otherwise.
_conditional_remove_script = """
local hash_name = KEYS[1]
local key_in_hash = ARGV[1]
local expected_value = ARGV[2]
local result = redis.call('hget', hash_name, key_in_hash)
if result == expected_value then
    redis.call('hdel', hash_name, key_in_hash)
    return 1
end
return 0
"""


# Lua: read and delete field ARGV[2] of hash ARGV[1] in one step; returns the old value (nil if absent).
# NOTE(review): the hash name is passed via ARGV instead of KEYS; Redis scripting rules (and
# Redis Cluster routing) expect every key a script touches to be declared in KEYS.
h_pop_script = """
local result = redis.call('hget', ARGV[1], ARGV[2])
if result then redis.call('hdel', ARGV[1], ARGV[2]) end
return result
"""


# Lua: HSETNX field ARGV[1] = ARGV[2] in hash KEYS[1], but only while the hash holds fewer than
# ARGV[3] fields. Returns the HSETNX result, or nil (None in Python) when the hash is full.
_limited_add = """
local set_name = KEYS[1]
local key = ARGV[1]
local value = ARGV[2]
local limit = tonumber(ARGV[3])

if redis.call('hlen', set_name) < limit then
    return redis.call('hsetnx', set_name, key, value)
end
return nil
"""

# Type of the (JSON-serializable) values stored in a Hash
T = TypeVar("T")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class HashIterator(Generic[T]):
    """Iterator over the ``(key, value)`` pairs of a :class:`Hash`, fetched page by page with HSCAN.

    Keys are decoded as UTF-8 and values are JSON-decoded. HSCAN's guarantees apply: a field
    may be returned more than once, and fields added or removed during iteration may or may not
    be returned.
    """

    def __init__(self, hash_object: Hash[T]):
        self.hash_object = hash_object
        # HSCAN cursor: 0 starts a scan, and a returned 0 means the server has no more pages
        self.cursor = 0
        self.buffer: list[tuple[str, T]] = []
        self._load_next()

    def __iter__(self) -> HashIterator[T]:
        # Iterators must return themselves so they can be used wherever an iterable is expected
        return self

    def __next__(self) -> tuple[str, T]:
        while True:
            if self.buffer:
                return self.buffer.pop(0)
            if self.cursor == 0:
                raise StopIteration()
            # A page may be empty while the cursor is still non-zero; keep fetching
            self._load_next()

    def _load_next(self):
        self.cursor, data = retry_call(self.hash_object.client.hscan, self.hash_object.name, self.cursor)
        for key, value in data.items():
            self.buffer.append((key.decode("utf-8"), json.loads(value)))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class Hash(Generic[T]):
    """A Redis hash mapping ``str`` keys to JSON-encoded values.

    Values are serialized with ``json.dumps`` on write and ``json.loads`` on read. Iterating the
    object yields ``(key, value)`` pairs through HSCAN. When used as a context manager, the whole
    hash is deleted on exit.
    """

    def __init__(self, name: str, host: Union[str, Redis] = None, port: int = None):
        self.client = get_client(host, port, False)
        self.name = name
        # Server-side scripts for the read-modify-write operations that must run atomically
        self._pop = self.client.register_script(h_pop_script)
        self._limited_add = self.client.register_script(_limited_add)
        self._conditional_remove = self.client.register_script(_conditional_remove_script)

    def __iter__(self):
        return HashIterator(self)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.delete()

    def add(self, key: str, value: T) -> int:
        """Add the (key, value) pair to the hash for new keys.

        If a key already exists this operation doesn't add it.

        Returns:
            The HSETNX result: truthy if the key has been added to the table, falsy otherwise.

        Raises:
            ClueTypeError: if ``key`` is ``bytes``.
        """
        if isinstance(key, bytes):
            raise ClueTypeError("Cannot use bytes for hashmap keys")

        return retry_call(self.client.hsetnx, self.name, key, json.dumps(value))

    def increment(self, key, increment: int = 1):
        """Add ``increment`` to the integer stored at ``key`` (HINCRBY) and return the new value."""
        return int(retry_call(self.client.hincrby, self.name, key, increment))

    def limited_add(self, key, value, size_limit):
        """Add a single (key, value) pair, but only if that wouldn't make the hash grow past a given size.

        If the hash has hit the size limit returns None.
        Otherwise, returns the result of hsetnx (same as `add`).
        """
        return retry_call(self._limited_add, keys=[self.name], args=[key, json.dumps(value), size_limit])

    def exists(self, key: str) -> bool:
        return retry_call(self.client.hexists, self.name, key)

    def get(self, key: str) -> Optional[T]:
        """Return the JSON-decoded value stored at ``key``, or ``None`` when the key is absent."""
        item = retry_call(self.client.hget, self.name, key)
        if not item:
            return item
        return json.loads(item)

    def keys(self) -> list[str]:
        return [k.decode("utf-8") for k in retry_call(self.client.hkeys, self.name)]

    def length(self):
        return retry_call(self.client.hlen, self.name)

    def items(self) -> dict[str, T]:
        """Return the whole hash as a dict with UTF-8 keys and JSON-decoded values."""
        items = retry_call(self.client.hgetall, self.name)
        if not isinstance(items, dict):
            return {}
        return {k.decode("utf-8"): json.loads(v) for k, v in items.items()}

    def conditional_remove(self, key: str, value) -> bool:
        """Remove ``key`` only if its stored value equals ``value``.

        The comparison is done on the JSON-encoded strings. Returns whether the field was removed.
        """
        return bool(retry_call(self._conditional_remove, keys=[self.name], args=[key, json.dumps(value)]))

    def pop(self, key: str):
        """Remove ``key`` and return its JSON-decoded value, or ``None`` if it was absent."""
        item = retry_call(self._pop, args=[self.name, key])
        if not item:
            return item
        return json.loads(item)

    def set(self, key: str, value: T):
        if isinstance(key, bytes):
            raise ClueTypeError("Cannot use bytes for hashmap keys")

        return retry_call(self.client.hset, self.name, key, json.dumps(value))

    def multi_set(self, data: dict[str, T]):
        """Set several fields at once and return the HSET result (number of newly created fields).

        An empty ``data`` is a no-op that returns 0; HSET itself requires at least one field.

        Raises:
            ValueError: if any key is ``bytes`` (kept as ValueError for existing callers).
        """
        if any(isinstance(key, bytes) for key in data.keys()):
            raise ValueError("Cannot use bytes for hashmap keys")
        if not data:
            return 0
        encoded = {key: json.dumps(value) for key, value in data.items()}
        return retry_call(self.client.hset, self.name, mapping=encoded)

    def delete(self):
        retry_call(self.client.delete, self.name)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class ExpiringHash(Hash):
    """A :class:`Hash` whose Redis key gets a TTL of ``ttl`` seconds from write operations.

    To save round trips, an instance re-issues EXPIRE at most once per half TTL.
    A falsy ``ttl`` disables expiry handling entirely.
    """

    def __init__(self, name, ttl=86400, host=None, port=None):
        super().__init__(name, host, port)
        self.ttl = ttl
        # time.time() of this instance's last EXPIRE; 0 lets the first write apply the TTL
        self.last_expire_time = 0

    def _conditional_expire(self):
        if not self.ttl:
            return
        now = time.time()
        refresh_due = now > self.last_expire_time + (self.ttl / 2)
        if not refresh_due:
            return
        retry_call(self.client.expire, self.name, self.ttl)
        self.last_expire_time = now

    def add(self, key, value):
        added = super().add(key, value)
        self._conditional_expire()
        return added

    def set(self, key, value):
        outcome = super().set(key, value)
        self._conditional_expire()
        return outcome

    def multi_set(self, data):
        outcome = super().multi_set(data)
        self._conditional_expire()
        return outcome

    def increment(self, key, increment=1):
        new_value = super().increment(key, increment)
        self._conditional_expire()
        return new_value

    def limited_add(self, key, value, size_limit):
        outcome = super().limited_add(key, value, size_limit)
        self._conditional_expire()
        return outcome
|
|
File without changes
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
import redis
|
|
4
|
+
|
|
5
|
+
from clue.common.logging import get_logger
|
|
6
|
+
from clue.remote.datatypes import decode, get_client, retry_call
|
|
7
|
+
|
|
8
|
+
# Module logger; CommsQueue.listen reports Redis disconnects and reconnects through it.
logger = get_logger(__file__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CommsQueue(object):
    """Redis pub/sub channel(s) used to exchange JSON messages.

    ``names`` is a single channel name or a list/tuple of names. ``publish`` sends to every name,
    and ``listen`` yields messages received on any of them.
    """

    def __init__(self, names, host=None, port=None, private=False):
        self.c = get_client(host, port, private)
        self.p = retry_call(self.c.pubsub)
        if isinstance(names, tuple):
            # A tuple is a collection of names, not a single name
            names = list(names)
        elif not isinstance(names, list):
            names = [names]
        self.names = names
        self._connected = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            retry_call(self.p.unsubscribe)
        finally:
            # Unsubscribing alone keeps the pubsub connection checked out; close() releases it
            self.close()

    def _connect(self):
        if not self._connected:
            retry_call(self.p.subscribe, self.names)
            self._connected = True

    def close(self):
        """Close the pubsub connection. A later ``listen`` subscribes again on a fresh connection."""
        retry_call(self.p.close)
        # Closing drops the subscriptions, so make _connect() subscribe again next time
        self._connected = False

    def listen(self, blocking=True):
        """Yield decoded messages published on the subscribed channel(s).

        Subscribes lazily on first use and again after a ``redis.ConnectionError``. Subscription
        confirmations and other non-``message`` entries are skipped. A payload that ``decode``
        rejects is yielded as ``None``.

        Args:
            blocking: When True, wait for each message. When False, yield ``None`` whenever no
                message is immediately available.
        """
        retried = False
        while True:
            self._connect()
            try:
                if blocking:
                    i = self.p.listen()
                    v = next(i)
                else:
                    v = self.p.get_message()
                    if v is None:
                        yield None
                        continue

                if isinstance(v, dict) and v.get("type", None) == "message":
                    data = decode(v.get("data", "null"))
                    yield data
            except redis.ConnectionError:
                logger.warning("No connection to Redis, reconnecting...")
                self._connected = False
                retried = True
            finally:
                if self._connected and retried:
                    logger.info("Reconnected to Redis!")
                    retried = False

    def publish(self, message):
        """JSON-encode ``message`` and publish it on every channel in ``names``."""
        for name in self.names:
            retry_call(self.c.publish, name, json.dumps(message))
|