clue-api 1.0.0.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. clue/.gitignore +21 -0
  2. clue/__init__.py +0 -0
  3. clue/api/__init__.py +211 -0
  4. clue/api/base.py +99 -0
  5. clue/api/v1/__init__.py +82 -0
  6. clue/api/v1/actions.py +92 -0
  7. clue/api/v1/auth.py +243 -0
  8. clue/api/v1/configs.py +83 -0
  9. clue/api/v1/fetchers.py +94 -0
  10. clue/api/v1/lookup.py +221 -0
  11. clue/api/v1/registration.py +109 -0
  12. clue/api/v1/static.py +94 -0
  13. clue/app.py +166 -0
  14. clue/cache/__init__.py +129 -0
  15. clue/common/__init__.py +0 -0
  16. clue/common/classification.py +1006 -0
  17. clue/common/classification.yml +130 -0
  18. clue/common/dict_utils.py +130 -0
  19. clue/common/exceptions.py +199 -0
  20. clue/common/forge.py +152 -0
  21. clue/common/json_utils.py +10 -0
  22. clue/common/list_utils.py +11 -0
  23. clue/common/logging/__init__.py +291 -0
  24. clue/common/logging/audit.py +157 -0
  25. clue/common/logging/format.py +42 -0
  26. clue/common/regex.py +31 -0
  27. clue/common/str_utils.py +213 -0
  28. clue/common/swagger.py +139 -0
  29. clue/common/uid.py +47 -0
  30. clue/config.py +60 -0
  31. clue/constants/__init__.py +0 -0
  32. clue/constants/supported_types.py +38 -0
  33. clue/cronjobs/__init__.py +30 -0
  34. clue/cronjobs/plugins.py +32 -0
  35. clue/error.py +129 -0
  36. clue/gunicorn_config.py +29 -0
  37. clue/healthz.py +74 -0
  38. clue/helper/discover.py +53 -0
  39. clue/helper/headers.py +30 -0
  40. clue/helper/oauth.py +128 -0
  41. clue/models/__init__.py +0 -0
  42. clue/models/actions.py +243 -0
  43. clue/models/config.py +456 -0
  44. clue/models/fetchers.py +136 -0
  45. clue/models/graph.py +162 -0
  46. clue/models/model_list.py +52 -0
  47. clue/models/network.py +430 -0
  48. clue/models/results/__init__.py +34 -0
  49. clue/models/results/base.py +10 -0
  50. clue/models/results/graph.py +26 -0
  51. clue/models/results/image.py +22 -0
  52. clue/models/results/status.py +55 -0
  53. clue/models/results/validation.py +57 -0
  54. clue/models/selector.py +67 -0
  55. clue/models/utils.py +52 -0
  56. clue/models/validators.py +19 -0
  57. clue/patched.py +8 -0
  58. clue/plugin/__init__.py +1008 -0
  59. clue/plugin/helpers/__init__.py +0 -0
  60. clue/plugin/helpers/central_server.py +27 -0
  61. clue/plugin/helpers/email_render.py +228 -0
  62. clue/plugin/helpers/token.py +34 -0
  63. clue/plugin/helpers/trino.py +103 -0
  64. clue/plugin/interactive.py +270 -0
  65. clue/plugin/models.py +19 -0
  66. clue/plugin/utils.py +78 -0
  67. clue/remote/__init__.py +0 -0
  68. clue/remote/datatypes/__init__.py +130 -0
  69. clue/remote/datatypes/cache.py +62 -0
  70. clue/remote/datatypes/events.py +118 -0
  71. clue/remote/datatypes/hash.py +193 -0
  72. clue/remote/datatypes/queues/__init__.py +0 -0
  73. clue/remote/datatypes/queues/comms.py +62 -0
  74. clue/remote/datatypes/set.py +96 -0
  75. clue/remote/datatypes/user_quota_tracker.py +54 -0
  76. clue/security/__init__.py +211 -0
  77. clue/security/obo.py +95 -0
  78. clue/security/utils.py +34 -0
  79. clue/services/action_service.py +186 -0
  80. clue/services/auth_service.py +348 -0
  81. clue/services/config_service.py +38 -0
  82. clue/services/fetcher_service.py +203 -0
  83. clue/services/jwt_service.py +233 -0
  84. clue/services/lookup_service.py +786 -0
  85. clue/services/type_service.py +165 -0
  86. clue/services/user_service.py +152 -0
  87. clue_api-1.0.0.dev7.dist-info/METADATA +111 -0
  88. clue_api-1.0.0.dev7.dist-info/RECORD +91 -0
  89. clue_api-1.0.0.dev7.dist-info/WHEEL +4 -0
  90. clue_api-1.0.0.dev7.dist-info/entry_points.txt +8 -0
  91. clue_api-1.0.0.dev7.dist-info/licenses/LICENSE +11 -0
@@ -0,0 +1,96 @@
1
+ import json
2
+ import time
3
+
4
+ from clue.remote.datatypes import get_client, retry_call
5
+
6
# Lua script: remove one member from a set and return the set's remaining cardinality.
# Expects ARGV[1] = set name and ARGV[2] = member to remove (no KEYS are read).
_drop_card_script = """
local set_name = ARGV[1]
local key = ARGV[2]

redis.call('srem', set_name, key)
return redis.call('scard', set_name)
"""

# Lua script: add a member only while the set's cardinality is below a limit.
# Expects KEYS[1] = set name, ARGV[1] = member, ARGV[2] = size limit.
# Returns true when the member was added, false when the set was already at the limit.
_limited_add = """
local set_name = KEYS[1]
local key = ARGV[1]
local limit = tonumber(ARGV[2])

if redis.call('scard', set_name) < limit then
redis.call('sadd', set_name, key)
return true
end
return false
"""
25
+
26
+
27
class Set(object):
    """Wrapper around a Redis set whose members are JSON-encoded values.

    Every value is passed through ``json.dumps`` before being sent to Redis and
    through ``json.loads`` when read back. Used as a context manager, the
    underlying key is deleted on exit.
    """

    def __init__(self, name, host=None, port=None):
        """Connect to Redis and register the Lua helper scripts.

        Args:
            name: Name of the Redis key backing this set.
            host: Forwarded to ``get_client``.
            port: Forwarded to ``get_client``.
        """
        self.c = get_client(host, port, False)
        self.name = name
        self._drop_card = self.c.register_script(_drop_card_script)
        self._limited_add = self.c.register_script(_limited_add)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context removes the whole set from Redis.
        self.delete()

    def add(self, *values):
        """Add one or more values to the set and return the SADD result."""
        return retry_call(self.c.sadd, self.name, *[json.dumps(v) for v in values])

    def limited_add(self, value, size_limit):
        """Add a single value to the set, but only if that wouldn't make the set grow past a given size."""
        return retry_call(self._limited_add, keys=[self.name], args=[json.dumps(value), size_limit])

    def exist(self, value):
        """Return the SISMEMBER result for ``value``."""
        return retry_call(self.c.sismember, self.name, json.dumps(value))

    def length(self):
        """Return the number of members in the set."""
        return retry_call(self.c.scard, self.name)

    def members(self):
        """Return every member of the set, decoded from JSON."""
        return [json.loads(s) for s in retry_call(self.c.smembers, self.name)]

    def remove(self, *values):
        """Remove one or more values from the set and return the SREM result."""
        return retry_call(self.c.srem, self.name, *[json.dumps(v) for v in values])

    def drop(self, value):
        """Remove ``value`` from the set and return the set's remaining size.

        The drop script reads the set name from ARGV[1] and the member from
        ARGV[2], so both must be supplied. The member is JSON-encoded so that it
        matches the representation written by ``add``.
        """
        return retry_call(self._drop_card, args=[self.name, json.dumps(value)])

    def random(self, num=None):
        """Return random member(s) without removing them.

        Args:
            num: When given, the number of members requested; a list is returned.
                When omitted, a single decoded member is returned.

        Returns:
            A list of decoded members, a single decoded member, or ``None`` when
            no member is available.
        """
        ret_val = retry_call(self.c.srandmember, self.name, num)
        if ret_val is None:
            # Empty/missing set with no count requested: nothing to decode.
            return None
        if isinstance(ret_val, list):
            return [json.loads(s) for s in ret_val]
        return json.loads(ret_val)

    def pop(self):
        """Remove and return one member, or ``None`` when nothing was popped."""
        data = retry_call(self.c.spop, self.name)
        return json.loads(data) if data else None

    def pop_all(self):
        """Remove and return as many members as the set held when its length was read."""
        return [json.loads(s) for s in retry_call(self.c.spop, self.name, self.length())]

    def delete(self):
        """Delete the Redis key backing this set."""
        retry_call(self.c.delete, self.name)
78
+
79
+
80
class ExpiringSet(Set):
    """A ``Set`` whose Redis key gets a time-to-live refreshed when values are added."""

    def __init__(self, name, ttl=86400, host=None, port=None):
        """Create the set and remember the TTL (in seconds) to apply to its key."""
        super().__init__(name, host, port)
        self.ttl = ttl
        # Timestamp of the last EXPIRE issued by this instance; 0 means "never".
        self.last_expire_time = 0

    def _conditional_expire(self):
        """Re-apply the TTL, but at most once per half-TTL window for this instance."""
        if not self.ttl:
            return

        now = time.time()
        refresh_due_at = self.last_expire_time + (self.ttl / 2)
        if now > refresh_due_at:
            retry_call(self.c.expire, self.name, self.ttl)
            self.last_expire_time = now

    def add(self, *values):
        """Add values like ``Set.add`` and then refresh the key's TTL if due."""
        added = super().add(*values)
        self._conditional_expire()
        return added
@@ -0,0 +1,54 @@
1
+ import redis
2
+
3
+ from clue.remote.datatypes import get_client, retry_call
4
+
5
# Lua script backing UserQuotaTracker.begin.
# Expects ARGV[1] = sorted-set name, ARGV[2] = maximum number of entries, ARGV[3] = timeout.
# It builds a numeric key from the two fields of the Redis TIME reply (second field zero-padded
# to six digits), removes entries scored at or below (key - timeout), then adds the new key only
# if fewer than `max` entries remain. The timeout has "000000" appended, so it is scaled to the
# same units as the key. Returns true when an entry was added, false otherwise.
begin_script = """
local t = redis.call('time')
local key = tonumber(t[1] .. string.format("%06d", t[2]))

local name = ARGV[1]
local max = tonumber(ARGV[2])
local timeout = tonumber(ARGV[3] .. "000000")

redis.call('zremrangebyscore', name, 0, key - timeout)
if redis.call('zcard', name) < max then
redis.call('zadd', name, key, key)
return true
else
return false
end
"""
21
+
22
+
23
class UserQuotaTracker(object):
    """Tracks per-user quota entries in Redis sorted sets named ``<prefix>-<user>``."""

    def __init__(self, prefix, timeout=120, redis=None, host=None, port=None, private=False):
        """Bind to a Redis client and register the quota script.

        Args:
            prefix: Prefix used to build each user's key name.
            timeout: Value handed to the script as the entry timeout.
            redis: An existing client to reuse; when falsy, one is obtained from ``get_client``.
            host: Forwarded to ``get_client``.
            port: Forwarded to ``get_client``.
            private: Forwarded to ``get_client``.
        """
        # Note: inside this method the ``redis`` parameter hides the ``redis`` module.
        client = redis if redis else get_client(host, port, private)
        self.c = client
        self.bs = client.register_script(begin_script)
        self.prefix = prefix
        self.timeout = timeout

    def _queue_name(self, user):
        """Return the Redis key name used for ``user``."""
        return "{}-{}".format(self.prefix, user)

    def begin(self, user: str, max_quota: int) -> bool:
        """Try to claim a quota entry for ``user``; True when the script reports success."""

        def _claim() -> bool:
            script_args = [self._queue_name(user), max_quota, self.timeout]
            return retry_call(self.bs, args=script_args) == 1

        try:
            return _claim()
        except redis.exceptions.ResponseError as er:
            # TODO: This is a failsafe for upgrade purposes could be removed in a future version
            if "WRONGTYPE" not in str(er):
                raise
            # The key holds a value of another type: discard it and try again.
            retry_call(self.c.delete, self._queue_name(user))
            return _claim()

    def end(self, user: str):
        """Release one quota entry for ``user`` by popping the lowest-scored member."""

        def _release():
            retry_call(self.c.zpopmin, self._queue_name(user))

        try:
            _release()
        except redis.exceptions.ResponseError as er:
            # TODO: This is a failsafe for upgrade purposes could be removed in a future version
            if "WRONGTYPE" not in str(er):
                raise
            # The key holds a value of another type: discard it and try again.
            retry_call(self.c.delete, self._queue_name(user))
            _release()
@@ -0,0 +1,211 @@
1
+ import functools
2
+ from typing import Callable, Optional
3
+
4
+ import elasticapm
5
+ import requests
6
+ from flask import request
7
+ from jwt import ExpiredSignatureError
8
+ from prometheus_client import Counter
9
+
10
+ import clue.services.auth_service as auth_service
11
+ from clue.api import bad_request, forbidden, internal_error, not_found, unauthorized
12
+ from clue.common.exceptions import (
13
+ AccessDeniedException,
14
+ AuthenticationException,
15
+ ClueAttributeError,
16
+ ClueNotImplementedError,
17
+ ClueRuntimeError,
18
+ InvalidDataException,
19
+ NotFoundException,
20
+ )
21
+ from clue.common.forge import APP_NAME
22
+ from clue.common.logging import get_logger
23
+ from clue.common.logging.audit import audit
24
+ from clue.config import AUDIT, config
25
+
26
logger = get_logger(__file__)

# Prometheus counter incremented once per request that passes authentication.
# The metric name is derived from APP_NAME with dashes replaced by underscores.
SUCCESSFUL_ATTEMPTS = Counter(
    f"{APP_NAME.replace('-', '_')}_auth_success_total",
    "Successful Authentication Attempts",
)

# Prometheus counter incremented once per rejected request, labelled by a "status" value.
FAILED_ATTEMPTS = Counter(
    f"{APP_NAME.replace('-', '_')}_auth_fail_total",
    "Failed Authentication Attempts",
    ["status"],
)

# Default used for the `check_xsrf_token` argument of `api_login`.
XSRF_ENABLED = True
40
+
41
+
42
+ ####################################
43
+ # API Helper func and decorators
44
+ # noinspection PyPep8Naming
45
class api_login(object):  # noqa: N801
    """Adds authentication to an endpoint

    Used as a decorator: the instance is called with the endpoint function and
    returns a wrapper that authenticates the request from its Authorization
    header before invoking the endpoint with a `user` keyword argument.
    """

    def __init__(
        self,  # noqa: ANN101
        # TODO: Fix type parsing and checks
        # required_type: Optional[list[str]] = None,
        username_key: str = "username",
        audit: bool = True,
        required_priv: Optional[list[str]] = None,
        required_method: Optional[list[str]] = None,
        check_xsrf_token: bool = XSRF_ENABLED,
    ):
        """Configure the authentication requirements for the decorated endpoint.

        Args:
            username_key (str): Stored on the instance; not otherwise read in this class.
            audit (bool): Request audit logging; only effective when the AUDIT config flag is also set.
            required_priv (Optional[list[str]]): Privileges of which the caller needs at least one.
                Defaults to ["R", "W"].
            required_method (Optional[list[str]]): Allowed authentication methods, a subset of
                userpass, apikey, internal and oauth. Defaults to all four.
            check_xsrf_token (bool): Stored and exposed on the wrapped function; not otherwise
                read in this class.

        Raises:
            ClueAttributeError: If required_method contains an unknown method name.
        """
        if required_priv is None:
            required_priv = ["R", "W"]

        # TODO: Fix type parsing and checks
        # if required_type is None:
        #     required_type = ["admin", "user"]

        required_method_set: set[str]
        if required_method is None:
            required_method_set = {"userpass", "apikey", "internal", "oauth"}
        else:
            required_method_set = set(required_method)

            if len(required_method_set - {"userpass", "apikey", "internal", "oauth"}) > 0:
                raise ClueAttributeError("required_method must be a subset of {userpass, apikey, internal, oauth}")

        # TODO: Fix type parsing and checks
        # self.required_type = required_type
        self.audit = audit and AUDIT
        self.required_priv = required_priv
        self.required_method = required_method_set
        self.username_key = username_key
        self.check_xsrf_token = check_xsrf_token

    def __call__(self, func: Callable) -> Callable:  # noqa: ANN101, C901
        """Wraps any function calls with authentication logic that uses either userpass, apikey, internal or oauth.

        The exceptions listed below occur inside the wrapper while authenticating. They are caught
        there and converted into the matching error response rather than propagating to the caller.

        Args:
            func (Callable): The function to wrap with auth.

        Handled internally:
            AuthenticationException: Raised whenever there's an actual problem with the provided authentication.
                Produces an `unauthorized` response.
            InvalidDataException: Raised whenever data is incorrectly formatted. Produces a `bad_request` response.
            ClueRuntimeError: Raised whenever there is a connection error with the oauth provider. Produces an
                `internal_error` response.
            AccessDeniedException: Raised whenever the authentication is valid but the authenticated identity doesn't
                have the required access. Produces a `forbidden` response.

        Returns:
            Callable: The wrapper around `func`. It carries `protected`, `audit`, `required_priv`,
                `required_method` and `check_xsrf_token` attributes describing this decorator's settings.
        """

        @functools.wraps(func)
        def base(*args, **kwargs):  # noqa: C901
            try:
                # All authorization (except impersonation) must go through the Authorization header, in one of
                # four formats:
                # 1. Basic user/pass authentication
                #       Authorization: Basic username:password (but in base64)
                # 2. Basic user/apikey authentication
                #       Authorization: Basic username:keyname:keydata (but in base64)
                # 3. Bearer internal token authentication (obtained from the login endpoint)
                #       Authorization: Bearer username:token
                # 4. Bearer OAuth authentication (obtained from external authentication provider i.e. azure, keycloak)
                #       Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjMifQ (example)
                authorization = request.headers.get("Authorization", None)
                if not authorization:
                    raise AuthenticationException("No Authorization header present")
                elif " " not in authorization or len(authorization.split(" ")) > 2:
                    # The header must be exactly "<type> <data>": one space, two parts.
                    raise InvalidDataException("Incorrectly formatted Authorization header")

                logger.debug("Authenticating user for path %s", request.path)

                [auth_type, data] = authorization.split(" ")

                user = None
                if auth_type == "Basic" and len(self.required_method & {"userpass", "apikey"}) > 0:
                    # Authenticate case (1) and (2) above
                    user, priv = auth_service.basic_auth(
                        data,
                        skip_apikey="apikey" not in self.required_method,
                        skip_password="userpass" not in self.required_method,
                    )
                elif auth_type == "Bearer" and len(self.required_method & {"internal", "oauth"}) > 0:
                    # Authenticate case (3) and (4) above
                    try:
                        user, priv = auth_service.bearer_auth(
                            data,
                            skip_jwt="oauth" not in self.required_method,
                            skip_internal="internal" not in self.required_method,
                        )
                    except ExpiredSignatureError as e:
                        raise AuthenticationException("Token Expired") from e
                    except (requests.exceptions.ConnectionError, ConnectionError) as e:
                        logger.exception("Failed to connect to OAuth Provider:")
                        raise ClueRuntimeError("Failed to connect to OAuth Provider") from e
                else:
                    # Unknown scheme, or a scheme whose methods are all disabled for this endpoint.
                    raise InvalidDataException("Not a valid authentication type for this endpoint.")

                if not user:
                    raise AuthenticationException("No authenticated user found")

                # Ensure that the provided api key allows access to this API
                # (at least one of the required privileges must be present)
                if not priv or not set(self.required_priv) & set(priv):
                    raise AccessDeniedException("You do not have access to this API.")

                # Make sure the user has the correct type for this endpoint
                # TODO: Fix type parsing and checks
                # if not set(self.required_type) & set(user["type"]):
                #     logger.warning(
                #         f"{user['uname']} is missing one of the types: {', '.join(self.required_type)}. "
                #         "Cannot access {request.path}"
                #     )
                #     raise AccessDeniedException(
                #         f"{request.path} requires one of the following user types: {', '.join(self.required_type)}"
                #     )

                # Prefer the X-Forwarded-For header value when present, otherwise the peer address.
                ip = request.headers.get("X-Forwarded-For", request.remote_addr)
                logger.info(f"Logged in as {user['uname']} from {ip}")

                # If auditing is enabled, write this successful access to the audit logs
                if self.audit:
                    audit(
                        args,
                        kwargs,
                        user,
                        func,
                    )
            except (InvalidDataException, ClueNotImplementedError) as e:
                FAILED_ATTEMPTS.labels("400").inc()
                return bad_request(err=e.message)
            except AuthenticationException as e:
                FAILED_ATTEMPTS.labels("401").inc()
                return unauthorized(err=e.message)
            except AccessDeniedException as e:
                FAILED_ATTEMPTS.labels("403").inc()
                return forbidden(err=e.message)
            except NotFoundException:
                FAILED_ATTEMPTS.labels("404").inc()
                return not_found()
            except ClueRuntimeError as e:
                FAILED_ATTEMPTS.labels("500").inc()
                return internal_error(err=e.message)

            # Attach the authenticated user to the APM context only when an APM server is configured.
            if config.core.metrics.apm_server.server_url is not None:
                elasticapm.set_user_context(
                    username=user.get("name", None),
                    email=user.get("email", None),
                    user_id=user.get("uname", None),
                )

            # Save user data in kwargs for future reference in the wrapped method
            kwargs["user"] = user

            SUCCESSFUL_ATTEMPTS.inc()
            return func(*args, **kwargs)

        # Expose this decorator's settings as attributes of the wrapper.
        base.protected = True
        # TODO: Fix type parsing and checks
        # base.required_type = self.required_type
        base.audit = self.audit
        base.required_priv = self.required_priv
        base.required_method = self.required_method
        base.check_xsrf_token = self.check_xsrf_token
        return base
clue/security/obo.py ADDED
@@ -0,0 +1,95 @@
1
+ from datetime import datetime
2
+ from typing import Optional
3
+
4
+ # from hogwarts.auth.vault.exceptions import VaultRequestException
5
+ # from hogwarts.auth.vault.vault_client import VaultClient
6
+ from clue.common.exceptions import InvalidDataException
7
+ from clue.common.logging import get_logger
8
+ from clue.config import config, get_redis
9
+ from clue.remote.datatypes.set import ExpiringSet
10
+ from clue.security.utils import decode_jwt_payload
11
+
12
+ logger = get_logger(__file__)
13
+
14
+
15
def _get_obo_token_store(service: str, user: str) -> ExpiringSet:
    """Build the expiring redis set that holds a user's token for one service.

    Args:
        service (str): The service the token is for
        user (str): The user the token corresponds to

    Returns:
        ExpiringSet: The set in which we'll store the token (five minute TTL)
    """
    store_name = f"{service}_token_{user}"
    five_minutes = 60 * 5

    return ExpiringSet(store_name, host=get_redis(), ttl=five_minutes)
25
+
26
+
27
def _get_token_raw(service: str, user: str) -> Optional[str]:
    """Return a cached token for the given service and user, if one is stored.

    Args:
        service (str): The service the token is for
        user (str): The user the token corresponds to

    Returns:
        Optional[str]: A stored token, or None when the store is empty
    """
    token_store = _get_obo_token_store(service, user)

    # Ask for the member directly instead of checking the length first: the key
    # can expire between the two calls, which would make indexing the result fail.
    tokens = token_store.random(1)

    return tokens[0] if tokens else None
34
+
35
+
36
def get_obo_token(service: str, access_token: str, user: str, force_refresh: bool = False):
    """Gets an On-Behalf-Of token from either the Redis cache or from the Vault API.

    The Vault exchange is currently commented out, so when no usable cached token
    exists this logs an error and returns None.

    Args:
        service (str): The target application we want a token for.
        access_token (str): The access token we want to use for the exchange.
        user (str): The name of the user.
        force_refresh (bool, optional): Allows to skip the Redis cache and get a new token. Defaults to False.

    Raises:
        InvalidDataException: Raised whenever an invalid OBO target is provided.

    Returns:
        Optional[str]: The access token for the targeted application. None is also returned when
            any exception occurs during lookup (it is logged, not propagated).
    """
    if service not in config.api.obo_targets:
        raise InvalidDataException("Not a valid OBO target")

    # For testing purposes, we special-case test-obo
    if service == "test-obo":
        return access_token

    try:
        cached_token: str | None = None if force_refresh else _get_token_raw(service, user)

        # A cached token is only usable while its "exp" claim lies in the future
        if cached_token is not None:
            expires_at = datetime.fromtimestamp(decode_jwt_payload(cached_token)["exp"])

            if expires_at < datetime.now():
                logger.warning("Cached token has expired")
                cached_token = None

        if cached_token is not None:
            logger.debug("Using cached OBO token")
            return cached_token

        logger.info(f"Fetching OBO token for user {user} to service {service}")

        logger.debug("Contacting vault for new OBO token")
        # vault_client = VaultClient(url=config.api.vault_url)
        # obo_access_token, _ = vault_client.on_behalf_of(
        #     config.api.obo_targets[service].scope,
        #     access_token,
        #     token_client_name=APP_NAME.replace("-dev", ""),
        # )
        fresh_token = None

        if not fresh_token:
            logger.error("Vault OBO failed, no token received.")
        else:
            # Replace whatever was cached with the newly issued token
            service_token_store = _get_obo_token_store(service, user)
            service_token_store.pop_all()
            service_token_store.add(fresh_token)

        return fresh_token
    except Exception:
        # except VaultRequestException:
        logger.exception("VaultRequestException on OBO:")
        return None
clue/security/utils.py ADDED
@@ -0,0 +1,34 @@
1
+ import base64
2
+ import json
3
+ import os
4
+ from typing import Any
5
+
6
# Regular-expression character classes, one per category of character.
UPPERCASE = r"[A-Z]"
LOWERCASE = r"[a-z]"
NUMBER = r"[0-9]"
SPECIAL = r'[ !#$@%&\'()*+,-./[\\\]^_`{|}~"]'
# List of single characters: A-Z (code points 65-90), a-z (97-122), the digits 0-9,
# and a handful of punctuation characters.
PASS_BASIC = (
    [chr(x + 65) for x in range(26)]
    + [chr(x + 97) for x in range(26)]
    + [str(x) for x in range(10)]
    + ["!", "@", "$", "^", "?", "&", "*", "(", ")"]
)
16
+
17
+
18
def generate_random_secret(length: int = 25) -> str:
    """Generate a random, base32-encoded secret.

    Args:
        length (int, optional): The number of random bytes to draw before encoding. Defaults to 25.

    Returns:
        str: The base32 text of those bytes. This is longer than `length` characters
            (the default of 25 bytes yields 40 characters).
    """
    random_bytes = os.urandom(length)
    encoded = base64.b32encode(random_bytes)

    return encoded.decode("UTF-8")
28
+
29
+
30
def decode_jwt_payload(jwt: str) -> dict[str, Any]:
    """Decode a JWT payload. DOES NOT VALIDATE THE JWT, DO NOT USE THIS TO VALIDATE THE TOKEN.

    Args:
        jwt (str): The dot-separated token; the second segment is taken as the payload.

    Returns:
        dict[str, Any]: The JSON-decoded payload.
    """
    segments = jwt.split(".")
    encoded_payload = segments[1]

    # Restore the base64 padding so the length is a multiple of four again
    missing_padding = -len(encoded_payload) % 4
    padded_payload = encoded_payload + "=" * missing_padding

    raw_payload = base64.urlsafe_b64decode(padded_payload)
    return json.loads(raw_payload.decode())
@@ -0,0 +1,186 @@
1
+ from typing import Any, Optional
2
+ from urllib.parse import urljoin
3
+
4
+ import elasticapm
5
+ import requests
6
+ from flask import request
7
+ from pydantic import TypeAdapter, ValidationError
8
+ from requests import JSONDecodeError, exceptions
9
+
10
+ from clue.common.exceptions import ClueException, NotFoundException
11
+ from clue.common.logging import get_logger
12
+ from clue.config import CLASSIFICATION, config
13
+ from clue.helper.headers import generate_headers
14
+ from clue.models.actions import ActionResult, ActionSpec
15
+ from clue.models.config import ExternalSource
16
+ from clue.services import auth_service
17
+
18
+ logger = get_logger(__file__)
19
+
20
+
21
def get_supported_actions(
    source: ExternalSource, user: dict[str, Any], access_token: Optional[str] = None
) -> dict[str, ActionSpec]:
    """Gets all supported actions for a source

    Every failure (OBO check error, connection problem, error status, malformed or
    invalid response) is logged and results in an empty dict.

    Args:
        source (ExternalSource): The source to query; its `url` and `name` are used.
        user (dict[str, Any]): The requesting user; `uname` is used for the OBO check.
        access_token (Optional[str], optional): The access token to use, if necessary. Defaults to None.

    Returns:
        dict[str, ActionSpec]: A dict of each action and their schema
    """
    logger.info("Fetching actions for source %s", source.name)

    url = urljoin(source.url, "actions/")

    obo_access_token = None
    if access_token:
        obo_access_token, error = auth_service.check_obo(source, access_token, user["uname"])

        if error:
            logger.error("%s: %s", source.name, error)
            return {}

    headers = generate_headers(obo_access_token or access_token, access_token if obo_access_token else None)

    with elasticapm.capture_span(f"GET {url}", span_type="http"):
        try:
            rsp = requests.get(url, headers=headers, timeout=10.0)
            result = rsp.json()

            if not rsp.ok:
                err = result["api_error_message"]
                logger.error(f"Error from upstream server: {rsp.status_code=}, {err=}")
                # An error response carries no usable action list; stop here rather than
                # attempting to validate it and logging a second, misleading failure.
                return {}

            return TypeAdapter(dict[str, ActionSpec]).validate_python(result["api_response"])
        except exceptions.ConnectionError:
            # any errors are logged and no result is saved to local cache to enable retry on next query
            logger.exception("Unable to connect: %s", url)
            return {}
        except (JSONDecodeError, KeyError):
            logger.exception("External API did not return expected format. Full data:\n\n%s\n\nStack Trace:", rsp.text)
            return {}
        except ValidationError:
            # One placeholder for one argument; the validation details come from the traceback.
            logger.exception("ValidationError in response from %s:", source.url)
            return {}
        except Exception:
            logger.exception("Unknown exception occurred on action fetching:")
            return {}
70
+
71
+
72
def all_supported_actions(user: dict[str, Any], access_token: Optional[str] = None) -> dict[str, ActionSpec]:
    """Collects the supported actions of every configured external source

    Args:
        user (dict[str, Any]): The requesting user, passed through to each source lookup.
        access_token (Optional[str], optional): The access token to use, if necessary. Defaults to None.

    Returns:
        dict[str, ActionSpec]: All actions and their matching schema, keyed as "<source name>.<action key>"
    """
    combined: dict[str, ActionSpec] = {}

    for source in config.api.external_sources:
        source_actions = get_supported_actions(source, user, access_token=access_token)

        # Namespace each action by its source so keys from different plugins cannot collide
        combined.update({f"{source.name}.{key}": spec for key, spec in source_actions.items()})

        logger.debug("Plugin %s exposes %s action(s)", source.name, len(source_actions))

    return combined
92
+
93
+
94
def get_plugins_supported_actions(user: dict[str, Any]) -> dict[str, ActionSpec]:
    """Return the supported actions of each external service, filtered to what the user has access to.

    The access token is taken from the second space-separated part of the request's
    Authorization header, when that header is present.
    """
    auth_header = request.headers.get("Authorization", type=str)
    access_token = auth_header.split(" ")[1] if auth_header else auth_header

    all_actions = all_supported_actions(
        user,
        access_token=access_token,
    )

    logger.info("Fetching actions for classification %s", user["classification"])

    available_actions: dict[str, ActionSpec] = {}
    for action_id, action in all_actions.items():
        # Keep the action when there is no user to check against, or when the user's
        # classification gives access to the action's classification
        visible = not user or CLASSIFICATION.is_accessible(user["classification"], action.classification)

        if visible:
            available_actions[action_id] = action
        else:
            logger.info(
                "Not including actions from source %s at classification %s", action.name, user["classification"]
            )

    logger.info("%s actions are available for user %s", len(available_actions), user["uname"])

    return available_actions
123
+
124
+
125
def execute_action(plugin_id: str, action_id: str, user: dict[str, Any]) -> ActionResult:
    """Executes a specified action.

    The action parameters are taken from the JSON body of the current request, and the
    timeout from its `max_timeout` query argument (falling back to the plugin default).

    Args:
        plugin_id (str): The ID of the plugin.
        action_id (str): The ID of the action to run.
        user (dict[str, Any]): The user dict of the user running the action.

    Raises:
        NotFoundException: Raised whenever the plugin or the action doesn't exist.
        ClueException: Raised whenever an error is returned by the plugin endpoint, or when the
            plugin cannot be reached, times out, or returns a body that is not JSON.

    Returns:
        ActionResult: The result of the action. A failure result is returned (not raised) when
            the OBO check for the provided token reports an error.
    """
    plugin = next((source for source in config.api.external_sources if source.name == plugin_id), None)

    if not plugin:
        raise NotFoundException(f"Plugin {plugin_id} does not exist.")

    access_token = request.headers.get("Authorization", type=str)
    if access_token:
        access_token = access_token.split(" ")[1]

    obo_access_token = None
    if access_token:
        obo_access_token, error = auth_service.check_obo(plugin, access_token, user["uname"])

        if error:
            logger.error("%s: %s", plugin.name, error)
            return ActionResult(outcome="failure", summary="Invalid token provided for this enrichment.")

    headers = generate_headers(obo_access_token or access_token, access_token if obo_access_token else None)

    # Compare the bare mimetype: the raw Content-Type header may carry parameters
    # (e.g. "application/json; charset=utf-8") that would defeat an exact match.
    if request.mimetype == "application/json":
        parameters = request.json
    else:
        # TODO: Pass parameters via urlencode?
        parameters = {}

    try:
        req_url = urljoin(plugin.url, f"actions/{action_id}")
        logger.debug("Executing action %s for user %s", req_url, user["uname"])

        response = requests.post(
            req_url,
            json=parameters,
            headers=headers,
            timeout=request.args.get("max_timeout", plugin.default_timeout, type=float),
        )

        result = response.json()

        if not response.ok:
            raise ClueException(result["api_error_message"])

        return ActionResult.model_validate(result["api_response"])
    except (JSONDecodeError, exceptions.ConnectionError, exceptions.Timeout) as err:
        # Timeout is listed explicitly: a read timeout is not a ConnectionError and would
        # otherwise escape as an unhandled error instead of the documented ClueException.
        logger.exception(f"Something went wrong when retrieving the result from plugin '{plugin_id}'")
        raise ClueException(
            f"Something went wrong when retrieving the result from plugin '{plugin_id}': {err.__class__.__name__}."
        ) from err