clue-api 1.0.0.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clue/.gitignore +21 -0
- clue/__init__.py +0 -0
- clue/api/__init__.py +211 -0
- clue/api/base.py +99 -0
- clue/api/v1/__init__.py +82 -0
- clue/api/v1/actions.py +92 -0
- clue/api/v1/auth.py +243 -0
- clue/api/v1/configs.py +83 -0
- clue/api/v1/fetchers.py +94 -0
- clue/api/v1/lookup.py +221 -0
- clue/api/v1/registration.py +109 -0
- clue/api/v1/static.py +94 -0
- clue/app.py +166 -0
- clue/cache/__init__.py +129 -0
- clue/common/__init__.py +0 -0
- clue/common/classification.py +1006 -0
- clue/common/classification.yml +130 -0
- clue/common/dict_utils.py +130 -0
- clue/common/exceptions.py +199 -0
- clue/common/forge.py +152 -0
- clue/common/json_utils.py +10 -0
- clue/common/list_utils.py +11 -0
- clue/common/logging/__init__.py +291 -0
- clue/common/logging/audit.py +157 -0
- clue/common/logging/format.py +42 -0
- clue/common/regex.py +31 -0
- clue/common/str_utils.py +213 -0
- clue/common/swagger.py +139 -0
- clue/common/uid.py +47 -0
- clue/config.py +60 -0
- clue/constants/__init__.py +0 -0
- clue/constants/supported_types.py +38 -0
- clue/cronjobs/__init__.py +30 -0
- clue/cronjobs/plugins.py +32 -0
- clue/error.py +129 -0
- clue/gunicorn_config.py +29 -0
- clue/healthz.py +74 -0
- clue/helper/discover.py +53 -0
- clue/helper/headers.py +30 -0
- clue/helper/oauth.py +128 -0
- clue/models/__init__.py +0 -0
- clue/models/actions.py +243 -0
- clue/models/config.py +456 -0
- clue/models/fetchers.py +136 -0
- clue/models/graph.py +162 -0
- clue/models/model_list.py +52 -0
- clue/models/network.py +430 -0
- clue/models/results/__init__.py +34 -0
- clue/models/results/base.py +10 -0
- clue/models/results/graph.py +26 -0
- clue/models/results/image.py +22 -0
- clue/models/results/status.py +55 -0
- clue/models/results/validation.py +57 -0
- clue/models/selector.py +67 -0
- clue/models/utils.py +52 -0
- clue/models/validators.py +19 -0
- clue/patched.py +8 -0
- clue/plugin/__init__.py +1008 -0
- clue/plugin/helpers/__init__.py +0 -0
- clue/plugin/helpers/central_server.py +27 -0
- clue/plugin/helpers/email_render.py +228 -0
- clue/plugin/helpers/token.py +34 -0
- clue/plugin/helpers/trino.py +103 -0
- clue/plugin/interactive.py +270 -0
- clue/plugin/models.py +19 -0
- clue/plugin/utils.py +78 -0
- clue/remote/__init__.py +0 -0
- clue/remote/datatypes/__init__.py +130 -0
- clue/remote/datatypes/cache.py +62 -0
- clue/remote/datatypes/events.py +118 -0
- clue/remote/datatypes/hash.py +193 -0
- clue/remote/datatypes/queues/__init__.py +0 -0
- clue/remote/datatypes/queues/comms.py +62 -0
- clue/remote/datatypes/set.py +96 -0
- clue/remote/datatypes/user_quota_tracker.py +54 -0
- clue/security/__init__.py +211 -0
- clue/security/obo.py +95 -0
- clue/security/utils.py +34 -0
- clue/services/action_service.py +186 -0
- clue/services/auth_service.py +348 -0
- clue/services/config_service.py +38 -0
- clue/services/fetcher_service.py +203 -0
- clue/services/jwt_service.py +233 -0
- clue/services/lookup_service.py +786 -0
- clue/services/type_service.py +165 -0
- clue/services/user_service.py +152 -0
- clue_api-1.0.0.dev7.dist-info/METADATA +111 -0
- clue_api-1.0.0.dev7.dist-info/RECORD +91 -0
- clue_api-1.0.0.dev7.dist-info/WHEEL +4 -0
- clue_api-1.0.0.dev7.dist-info/entry_points.txt +8 -0
- clue_api-1.0.0.dev7.dist-info/licenses/LICENSE +11 -0
clue/gunicorn_config.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import multiprocessing
|
|
2
|
+
from os import environ as env
|
|
3
|
+
|
|
4
|
+
def _env_int(name, default):
    """Return environment variable ``name`` converted to int, or ``int(default)`` when it is unset."""
    return int(env.get(name, default))


# Port to bind to
bind = ":{}".format(_env_int("PORT", 5000))

# Number of processes to launch
workers = _env_int("WORKERS", multiprocessing.cpu_count())

# Number of concurrent handled connections
threads = _env_int("THREADS", 4)
worker_connections = _env_int("WORKER_CONNECTIONS", "1000")

# Recycle the process after X request randomized by the jitter
max_requests = _env_int("MAX_REQUESTS", "1000")
max_requests_jitter = _env_int("MAX_REQUESTS_JITTER", "100")

# Connection timeouts
graceful_timeout = _env_int("GRACEFUL_TIMEOUT", "30")
timeout = _env_int("TIMEOUT", "30")

# TLS/SSL Configuration (both stay None when the variables are not set)
certfile = env.get("CERTFILE")
keyfile = env.get("KEYFILE")

# Request Max Size Configuration
limit_request_line = _env_int("LIMIT_REQUEST_LINE", "4094")
limit_request_fields = _env_int("LIMIT_REQUEST_FIELDS", "100")
limit_request_field_size = _env_int("LIMIT_REQUEST_FIELD_SIZE", "8190")
|
clue/healthz.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from flasgger import swag_from
|
|
2
|
+
from flask import Blueprint, abort, make_response
|
|
3
|
+
|
|
4
|
+
from clue.config import get_redis
|
|
5
|
+
|
|
6
|
+
API_PREFIX = "/healthz"
|
|
7
|
+
healthz = Blueprint("healthz", __name__, url_prefix=API_PREFIX)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@swag_from(
    {
        "parameters": [],
        "definitions": {},
        "responses": {"200": {"description": "Liveness Probe"}},
        "tags": ["Health"],
        "operationId": "clue.healthz.liveness",
    }
)
@healthz.route("/live")
def liveness(**_):
    """Check if the API is live

    Variables:
    None

    Arguments:
    None

    Data Block:
    None

    Result example:
    OK or FAIL
    """
    # Nothing is probed here: the route answers "OK" unconditionally, so a response
    # only shows that the process is able to serve requests.
    live_response = make_response("OK")
    return live_response
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@swag_from(
    {
        "parameters": [],
        "definitions": {},
        "responses": {"200": {"description": "Readyness Probe"}},
        "tags": ["Health"],
        "operationId": "clue.healthz.readyness",
    }
)
@healthz.route("/ready")
def readyness(**_):
    """Check if the API is Ready

    Variables:
    None

    Arguments:
    None

    Data Block:
    None

    Result example:
    OK or FAIL
    """
    # The API is considered ready when the redis backend answers a ping.
    # A failing ping typically surfaces as an exception rather than a falsy return
    # value, so both cases are folded into "not ready" to keep the 503 contract
    # instead of leaking an unhandled error.
    try:
        redis_ready = bool(get_redis().ping())
    except Exception:
        redis_ready = False

    # abort() is kept outside the try block so the exception it raises is not swallowed above.
    if not redis_ready:
        abort(503)

    return make_response("OK")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@healthz.errorhandler(503)
def error(_):
    """Turn a 503 raised inside a healthz route into a plain ``FAIL`` body with the same status."""
    status_code = 503
    return "FAIL", status_code
|
clue/helper/discover.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
import geventhttpclient
|
|
5
|
+
|
|
6
|
+
from clue.common.logging import get_logger
|
|
7
|
+
from clue.config import config
|
|
8
|
+
|
|
9
|
+
logger = get_logger(__file__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_apps_list() -> list[dict[str, str]]:
    """Get a list of apps from the discovery service

    Discovery is skipped, and an empty list returned, when pytest is loaded or when the
    SKIP_DISCOVERY environment variable holds any non-empty value. It is also a no-op when no
    discovery URL is configured. Failures while contacting the service or while parsing an
    individual entry are logged and never raised to the caller.

    Returns:
        list[dict[str, str]]: A list of other apps, sorted by name. Each entry has the keys
            "alt", "name", "img_d", "img_l", "route" and "classification".
    """
    apps: list[dict[str, str]] = []

    if "pytest" in sys.modules or bool(os.getenv("SKIP_DISCOVERY", "")):
        logger.info("Skipping discovery, running in a test environment")
        # Actually skip: without this return the lookup below would still run.
        return apps

    if not config.api.discover_url:
        return apps

    try:
        resp = geventhttpclient.get(
            config.api.discover_url,
            headers={"accept": "application/json"},
        )
        if resp.ok:
            data = resp.json()
            for app in data["applications"]["application"]:
                # Each entry is parsed in its own try block so one malformed app does not drop the rest.
                try:
                    # Only the first registered instance of each application is considered.
                    instance = app["instance"][0]
                    url = instance["hostName"]

                    # Entries whose host name contains "clue" are left out of the list.
                    if "clue" not in url:
                        metadata = instance["metadata"]
                        apps.append(
                            {
                                "alt": metadata["alternateText"],
                                "name": app["name"],
                                "img_d": metadata["imageDark"],
                                "img_l": metadata["imageLight"],
                                "route": url,
                                "classification": metadata["classification"],
                            }
                        )
                except Exception:
                    logger.exception(f"Failed to parse get app: {str(app)}")
        else:
            logger.warning(f"Invalid response from server for apps discovery: {config.api.discover_url}")
    except Exception:
        logger.exception(f"Failed to get apps from discover URL: {config.api.discover_url}")

    return sorted(apps, key=lambda k: k["name"])
|
clue/helper/headers.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from clue.common.logging import get_logger
|
|
2
|
+
from clue.config import DEBUG, cache, config
|
|
3
|
+
|
|
4
|
+
logger = get_logger(__file__)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@cache.memoize(timeout=1 if DEBUG else 5 * 60)  # 5 minutes, or 1 second when DEBUG is set
def generate_headers(access_token: str | None, clue_access_token: str | None) -> dict[str, str]:
    """Build the headers for an outgoing JSON request.

    Args:
        access_token (str | None): Token sent as a bearer token in the Authorization header.
            The header is omitted when the value is empty.
        clue_access_token (str | None): Token sent as-is in the X-Clue-Authorization header.
            Only added when it is non-empty and config.auth.propagate_clue_key is enabled.

    Returns:
        dict[str, str]: A dict of the request headers
    """
    request_headers = {"accept": "application/json", "content-type": "application/json"}

    if access_token:
        logger.debug("Appending authorization header")
        request_headers.update({"Authorization": f"Bearer {access_token}"})

    forward_clue_token = config.auth.propagate_clue_key and clue_access_token
    if forward_clue_token:
        logger.debug("Appending custom authorization header")
        request_headers.update({"X-Clue-Authorization": clue_access_token})

    return request_headers
|
clue/helper/oauth.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import re
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
import elasticapm
|
|
6
|
+
|
|
7
|
+
from clue.common.logging import get_logger
|
|
8
|
+
from clue.config import CLASSIFICATION as CL_ENGINE
|
|
9
|
+
from clue.config import USER_TYPES, config
|
|
10
|
+
from clue.models.config import (
|
|
11
|
+
DEFAULT_EMAIL_FIELDS,
|
|
12
|
+
DEFAULT_USER_FIELDS,
|
|
13
|
+
DEFAULT_USER_NAME_FIELDS,
|
|
14
|
+
OAuthProvider,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Characters allowed in a username: ASCII digits, upper case letters, lower case letters and the hyphen.
VALID_CHARS = (
    [str(digit) for digit in range(10)]
    + [chr(code) for code in range(ord("A"), ord("Z") + 1)]
    + [chr(code) for code in range(ord("a"), ord("z") + 1)]
    + ["-"]
)
|
|
18
|
+
|
|
19
|
+
logger = get_logger(__file__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def reorder_name(name: Optional[str]) -> Optional[str]:
    """Turn a "Last, First" name into "First Last".

    Only the first ", " separator is considered; names without it are returned unchanged.

    Args:
        name (Optional[str]): The name to reorder

    Returns:
        Optional[str]: The reordered name, or None when no name was given
    """
    if name is None:
        return None

    last_part, separator, first_part = name.partition(", ")
    if not separator:
        # No "Last, First" structure to swap.
        return name

    return f"{first_part} {last_part}"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@elasticapm.capture_span(span_type="authentication")
def parse_profile(profile: dict[str, Any], provider_config: OAuthProvider) -> dict[str, Any]:  # noqa: C901
    """Build a user record from an OAuth profile.

    Extracts and normalizes the email address, display name and username from the profile,
    picks an avatar, and infers the user's roles from the profile groups.

    Args:
        profile (dict[str, Any]): The profile data returned by the OAuth provider.
        provider_config (OAuthProvider): Provider settings; uid_regex, uid_format and role_map are read.

    Returns:
        dict[str, Any]: A dict with the keys access, type, classification, uname, name, email,
            password, avatar and groups.
    """
    # Take the email from the first known email field present in the profile
    email_adr: str | None = None
    for field in DEFAULT_EMAIL_FIELDS:
        if field in profile:
            email_adr = profile[field]
            if isinstance(email_adr, list):
                email_adr = email_adr[0]
            break

    # Second unwrap; it only has an effect when the first element was itself a list
    if isinstance(email_adr, list):
        email_adr = email_adr[0]

    # Lowercase the address and discard values that do not contain an "@"
    if email_adr:
        email_adr = email_adr.lower()
        if "@" not in email_adr:
            email_adr = None

    # Find the name of the user
    name = None
    for field in DEFAULT_USER_NAME_FIELDS:
        if field in profile:
            name = reorder_name(profile[field])
            break

    # Try to find a username or use email address
    uname = None
    for field in DEFAULT_USER_FIELDS:
        if field in profile:
            uname: str = profile[field]
            break
    uname = uname or email_adr

    # Did we use the email address?
    if uname is not None and email_adr is not None and uname.lower() == email_adr.lower():
        # 1. Use provided regex matcher
        if provider_config.uid_regex:
            match = re.match(provider_config.uid_regex, uname)
            if match:
                if provider_config.uid_format:
                    # Groups that did not participate in the match are passed as empty strings
                    uname = provider_config.uid_format.format(*[x or "" for x in match.groups()]).lower()
                else:
                    # Without a format, the non-empty groups are simply concatenated
                    uname = "".join([x for x in match.groups() if x]).lower()

        # 2. Parse name and domain from email if regex failed or missing
        if uname is not None and uname == email_adr:
            e_name, e_dom = uname.split("@", 1)
            # "<local part>-<first label of the domain>"
            uname = f"{e_name}-{e_dom.split('.')[0]}"

    # 3. Use name as username if there are no username found yet
    if uname is None and name is not None:
        uname = name.replace(" ", "-")

    # Cleanup username
    if uname:
        # Drop every character that is not listed in VALID_CHARS
        uname = "".join([c for c in uname if c in VALID_CHARS])

    # Get avatar from gravatar
    if config.auth.oauth.gravatar_enabled and email_adr:
        # The MD5 hex digest of the email is used as the identifier in the gravatar URL
        email_hash = hashlib.md5(email_adr.encode("utf-8")).hexdigest()  # noqa: S324
        alternate = f"https://www.gravatar.com/avatar/{email_hash}?s=256&d=404&r=pg"
    else:
        alternate = None

    # Compute access, roles and classification using auto_properties
    # NOTE(review): access is always True and the classification is always UNRESTRICTED here
    access = True
    roles = ["user"]
    # TODO: correctly figure out the classification
    classification = CL_ENGINE.UNRESTRICTED

    # Infer roles from groups
    if profile.get("groups") and provider_config.role_map:
        for user_type in USER_TYPES:
            # A role is granted when the group mapped to it in role_map is among the profile groups
            if (
                user_type in provider_config.role_map
                and provider_config.role_map[user_type] in profile.get("groups", [])
                and user_type not in roles
            ):
                roles.append(user_type)

    return dict(
        access=access,
        type=roles,
        classification=classification,
        uname=uname,
        name=name,
        email=email_adr,
        password="__NO_PASSWORD__",  # noqa: S106
        # The profile picture wins; the gravatar URL (or None) is only the fallback
        avatar=profile.get("picture", alternate),
        groups=profile.get("groups", []),
    )
|
clue/models/__init__.py
ADDED
|
File without changes
|
clue/models/actions.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
# ruff: noqa: D101
|
|
2
|
+
import re
|
|
3
|
+
from inspect import isclass
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
Generic,
|
|
7
|
+
Literal,
|
|
8
|
+
TypeVar,
|
|
9
|
+
Union,
|
|
10
|
+
cast,
|
|
11
|
+
get_args,
|
|
12
|
+
get_origin,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
|
|
16
|
+
from pydantic_core import Url, ValidationError
|
|
17
|
+
from typing_extensions import Self
|
|
18
|
+
|
|
19
|
+
from clue.common.exceptions import ClueValueError
|
|
20
|
+
from clue.common.logging import get_logger
|
|
21
|
+
from clue.constants.supported_types import SUPPORTED_TYPES
|
|
22
|
+
from clue.models.results import DATA
|
|
23
|
+
from clue.models.results.validation import validate_result
|
|
24
|
+
from clue.models.selector import Selector
|
|
25
|
+
from clue.models.validators import validate_classification
|
|
26
|
+
|
|
27
|
+
logger = get_logger(__file__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ExecuteRequest(BaseModel):
    # Request body for executing an action: a single selector, a list of selectors, or both.
    selector: Selector | None = Field(description="The selector to execute the action on.", default=None)
    selectors: list[Selector] = Field(description="The selectors to execute the action on.", default=[])

    @model_validator(mode="after")
    def validate_model(self: Self, info: ValidationInfo) -> Self:  # noqa: C901
        """Check that selectors were provided and keep `selector` and `selectors` in sync.

        When only `selector` is set it is copied into `selectors`; when only a single entry
        is present in `selectors` it is copied into `selector`. An empty request is accepted
        only if the validation context carries an "action" whose accept_empty flag is set.

        Raises:
            ClueValueError: Raised when neither selector nor selectors is provided and the
                action (if any) does not accept empty requests.

        Returns:
            Self: The validated model.
        """
        action_to_validate: Action | None = info.context.get("action", None) if info.context else None

        has_single = self.selector is not None
        has_many = self.selectors is not None and len(self.selectors) >= 1

        if not has_single and not has_many:
            if not action_to_validate or not action_to_validate.accept_empty:
                raise ClueValueError(
                    "Either selector (single entry) or selectors (multiple entries) must not be empty."
                )
        elif not has_many:
            self.selectors = [cast(Selector, self.selector)]
        elif not has_single and len(self.selectors) == 1:
            self.selector = self.selectors[0]

        return self
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
ER = TypeVar("ER", bound=ExecuteRequest)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ActionBase(BaseModel):
    # Fields shared by every action definition, together with their validators.
    id: str = Field(description="Unique identifier for the action.")
    name: str = Field(description="Name of the action.")
    classification: str = Field(
        description="Classification of the action. Denotes the maximum classification of data sent to the action.",
    )
    summary: str | None = Field(description="A plaintext summary of the action.", default=None)
    supported_types: set[str] = Field(description="A list of types this action supports.")
    action_icon: str | None = Field(
        description=(
            "Formatted string to present an icon for this analytic on the UI using iconify/react format: "
            "https://iconify.design/docs/icon-components/react/. External icons not yet supported."
        ),
        default=None,
    )
    accept_empty: bool = Field(description="Does this action support execution with no selectors?", default=False)
    accept_multiple: bool = Field(description="Does this action support multiple values?", default=False)
    format: str | None = Field(
        description="What is the format of the output, if known?",
        default=None,
    )
    extra_schema: Any | None = Field(
        description="Extra key values for the form schema. These will overwrite default behaviour", default={}
    )

    @field_validator("id")
    @classmethod
    def validate_id(cls, action_id: str) -> str:  # noqa: ANN102
        """Validates the action ID field.

        Args:
            action_id (str): The ID to validate.

        Raises:
            ClueValueError: Raised whenever the ID is not in a valid format.

        Returns:
            str: The validated ID.
        """
        # re.search scans the whole ID; re.match would only inspect its first character
        # and let IDs such as "abc-123" through.
        if re.search(r"[^a-z_]", action_id):
            raise ClueValueError("Invalid action id - can only contain lowercase letters and underscores.")

        return action_id

    @field_validator("classification")
    @classmethod
    def check_classification(cls, classification: str) -> str:  # noqa: ANN102
        """Validates the provided classification.

        The check is delegated to validate_classification, whose return value is used as the
        field value.

        Args:
            classification (str): The classification to validate.

        Returns:
            str: The validated classification.
        """
        return validate_classification(classification)

    @field_validator("supported_types")
    @classmethod
    def validate_supported_types(cls, supported_types: set[str]) -> set[str]:  # noqa: ANN102
        """Validate that every declared type is one of the keys of SUPPORTED_TYPES.

        Args:
            supported_types (set[str]): The types declared by the action.

        Raises:
            AssertionError: Raised when at least one declared type is not supported.

        Returns:
            set[str]: The validated set of types.
        """
        invalid_types = supported_types - set(SUPPORTED_TYPES.keys())

        if invalid_types:
            # Sorted so that the message is stable regardless of set iteration order
            raise AssertionError(f"{', '.join(sorted(invalid_types))} are not supported types.")

        return supported_types
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class Action(ActionBase, Generic[ER]):
    # An action definition whose extra parameters are described either by an ExecuteRequest
    # subclass (the generic argument) or by a dict expected to be a JSON schema.
    params: ER | dict[str, Any] | None = Field(description="Specification of additional parameters.", default=None)

    @model_validator(mode="before")
    @classmethod
    def check_structure(cls, data: Any) -> Any:  # noqa: ANN102
        """Checks the structure of the model.

        When params is given as a dict, it only needs to contain a "$defs" key. Otherwise the
        generic argument of the class is inspected: it must be a concrete ExecuteRequest
        subclass, all of its fields must be annotated, and none of them (other than selector
        and selectors) may be a pydantic model.

        Args:
            data (Any): The model data to validate.

        Raises:
            ClueValueError: Raised whenever the class is used without a concrete type argument,
                a dict params has no "$defs" key, or the type argument doesn't inherit
                ExecuteRequest.
            AssertionError: Raised whenever a field is not valid.

        Returns:
            Any: The validated data.
        """
        # First member of the `params` annotation, i.e. the type substituted for ER
        additional_annotations: type[Any] = cast(type[Any], cls.model_fields["params"].annotation).__args__[0]

        # NOTE(review): data.get assumes the raw input is a mapping — confirm callers never pass anything else
        # A TypeVar here means the class was used without a concrete type argument
        if not isinstance(data.get("params", None), dict) and isinstance(additional_annotations, TypeVar):
            raise ClueValueError(
                "you must provide a non-generic class as a type annotation. To accept no additional parameters, use "
                "Action[ExecuteRequest]."
            )

        if isinstance(data.get("params", None), dict):
            # Only the presence of "$defs" is checked; the schema itself is not validated here
            if "$defs" not in data["params"]:
                raise ClueValueError("If params is a dict, it must be a valid json schema.")

            return data
        elif not issubclass(additional_annotations, ExecuteRequest):
            raise ClueValueError(
                "params does not inherit from ExecuteRequest. When extending the params, it is necessary to inherit "
                "from ExecuteRequest."
            )

        missing_annotations = [key for key, info in additional_annotations.model_fields.items() if not info.annotation]

        if missing_annotations:
            raise AssertionError(
                f"{','.join(missing_annotations)} do not have type annotations. All fields must be annotated"
            )

        nested_fields: list[str] = []
        for key, info in additional_annotations.model_fields.items():
            field_type = cast(type[Any], info.annotation)

            # For unions, only the first member is inspected
            # NOTE(review): annotations written as `X | Y` may report a different origin than
            # typing.Union on some Python versions and would then skip this unwrapping — confirm
            if get_origin(field_type) is Union:
                field_type = get_args(field_type)[0]

            # NOTE(review): the error message below mentions raw_data, but the fields exempted
            # here are selector and selectors — confirm which is intended
            if key not in ["selector", "selectors"] and isclass(field_type) and BaseModel in field_type.__mro__:
                nested_fields.append(key)

        if nested_fields:
            raise AssertionError(
                f"{','.join(nested_fields)} are not primitive types. params cannot require nested fields, "
                "except raw_data."
            )

        return data
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class ActionResult(BaseModel, Generic[DATA]):
    # Outcome of an action execution, with an optional output and a hint on how to render it.
    outcome: Union[Literal["success"], Literal["failure"]] = Field(description="Did the action succeed or fail?")
    summary: str | None = Field(description="Message explaining the outcome of the action.", default=None)
    output: DATA | Url | None = Field(description="The output of the action.", default=None)
    format: str | None = Field(
        description="What is the format of the output? Used to indicate what component to use when rendering "
        "the output.",
        default=None,
    )
    link: Url | None = Field(description="Link to more information on the outcome of the action", default=None)

    @model_validator(mode="after")
    def validate_model(self: Self, info: ValidationInfo) -> Self:  # noqa: C901
        """Check that format and output are consistent with each other.

        A format is mandatory unless the outcome is a failure. The "pivot" format requires the
        output to be a Url (a string output is converted when it parses as one), and a Url
        output is only allowed with the "pivot" format. Any other output with a format set is
        passed through validate_result.

        Raises:
            ClueValueError: Raised whenever one of the rules above is broken.

        Returns:
            Self: The validated model.
        """
        if self.outcome != "failure" and not self.format:
            raise ClueValueError("You must set a format if outcome is not failure.")

        is_pivot = self.format == "pivot"

        if is_pivot and not (self.output and isinstance(self.output, Url)):
            # Give string outputs a chance to be parsed as a Url before rejecting them
            if isinstance(self.output, str):
                try:
                    self.output = Url(self.output)
                except ValidationError:
                    pass
                else:
                    return self

            raise ClueValueError("When returning a pivot, output must be a Url.")

        output_is_url = isinstance(self.output, Url)

        if output_is_url and not is_pivot:
            raise ClueValueError("You can only return a Url if format is set to pivot.")

        if self.format and not output_is_url:
            self.output = validate_result(self.format, self.output, info)

        return self
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class ActionSpec(ActionBase):
    # Variant of the action definition whose params is a required plain dict.
    params: dict[str, Any]
|