clue-api 1.0.0.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. clue/.gitignore +21 -0
  2. clue/__init__.py +0 -0
  3. clue/api/__init__.py +211 -0
  4. clue/api/base.py +99 -0
  5. clue/api/v1/__init__.py +82 -0
  6. clue/api/v1/actions.py +92 -0
  7. clue/api/v1/auth.py +243 -0
  8. clue/api/v1/configs.py +83 -0
  9. clue/api/v1/fetchers.py +94 -0
  10. clue/api/v1/lookup.py +221 -0
  11. clue/api/v1/registration.py +109 -0
  12. clue/api/v1/static.py +94 -0
  13. clue/app.py +166 -0
  14. clue/cache/__init__.py +129 -0
  15. clue/common/__init__.py +0 -0
  16. clue/common/classification.py +1006 -0
  17. clue/common/classification.yml +130 -0
  18. clue/common/dict_utils.py +130 -0
  19. clue/common/exceptions.py +199 -0
  20. clue/common/forge.py +152 -0
  21. clue/common/json_utils.py +10 -0
  22. clue/common/list_utils.py +11 -0
  23. clue/common/logging/__init__.py +291 -0
  24. clue/common/logging/audit.py +157 -0
  25. clue/common/logging/format.py +42 -0
  26. clue/common/regex.py +31 -0
  27. clue/common/str_utils.py +213 -0
  28. clue/common/swagger.py +139 -0
  29. clue/common/uid.py +47 -0
  30. clue/config.py +60 -0
  31. clue/constants/__init__.py +0 -0
  32. clue/constants/supported_types.py +38 -0
  33. clue/cronjobs/__init__.py +30 -0
  34. clue/cronjobs/plugins.py +32 -0
  35. clue/error.py +129 -0
  36. clue/gunicorn_config.py +29 -0
  37. clue/healthz.py +74 -0
  38. clue/helper/discover.py +53 -0
  39. clue/helper/headers.py +30 -0
  40. clue/helper/oauth.py +128 -0
  41. clue/models/__init__.py +0 -0
  42. clue/models/actions.py +243 -0
  43. clue/models/config.py +456 -0
  44. clue/models/fetchers.py +136 -0
  45. clue/models/graph.py +162 -0
  46. clue/models/model_list.py +52 -0
  47. clue/models/network.py +430 -0
  48. clue/models/results/__init__.py +34 -0
  49. clue/models/results/base.py +10 -0
  50. clue/models/results/graph.py +26 -0
  51. clue/models/results/image.py +22 -0
  52. clue/models/results/status.py +55 -0
  53. clue/models/results/validation.py +57 -0
  54. clue/models/selector.py +67 -0
  55. clue/models/utils.py +52 -0
  56. clue/models/validators.py +19 -0
  57. clue/patched.py +8 -0
  58. clue/plugin/__init__.py +1008 -0
  59. clue/plugin/helpers/__init__.py +0 -0
  60. clue/plugin/helpers/central_server.py +27 -0
  61. clue/plugin/helpers/email_render.py +228 -0
  62. clue/plugin/helpers/token.py +34 -0
  63. clue/plugin/helpers/trino.py +103 -0
  64. clue/plugin/interactive.py +270 -0
  65. clue/plugin/models.py +19 -0
  66. clue/plugin/utils.py +78 -0
  67. clue/remote/__init__.py +0 -0
  68. clue/remote/datatypes/__init__.py +130 -0
  69. clue/remote/datatypes/cache.py +62 -0
  70. clue/remote/datatypes/events.py +118 -0
  71. clue/remote/datatypes/hash.py +193 -0
  72. clue/remote/datatypes/queues/__init__.py +0 -0
  73. clue/remote/datatypes/queues/comms.py +62 -0
  74. clue/remote/datatypes/set.py +96 -0
  75. clue/remote/datatypes/user_quota_tracker.py +54 -0
  76. clue/security/__init__.py +211 -0
  77. clue/security/obo.py +95 -0
  78. clue/security/utils.py +34 -0
  79. clue/services/action_service.py +186 -0
  80. clue/services/auth_service.py +348 -0
  81. clue/services/config_service.py +38 -0
  82. clue/services/fetcher_service.py +203 -0
  83. clue/services/jwt_service.py +233 -0
  84. clue/services/lookup_service.py +786 -0
  85. clue/services/type_service.py +165 -0
  86. clue/services/user_service.py +152 -0
  87. clue_api-1.0.0.dev7.dist-info/METADATA +111 -0
  88. clue_api-1.0.0.dev7.dist-info/RECORD +91 -0
  89. clue_api-1.0.0.dev7.dist-info/WHEEL +4 -0
  90. clue_api-1.0.0.dev7.dist-info/entry_points.txt +8 -0
  91. clue_api-1.0.0.dev7.dist-info/licenses/LICENSE +11 -0
@@ -0,0 +1,786 @@
1
+ import functools
2
+ import itertools
3
+ import json
4
+ import math
5
+ import os
6
+ import time
7
+ from datetime import datetime, timedelta, timezone
8
+ from hashlib import sha256
9
+ from typing import Any, Optional
10
+ from urllib.parse import urlparse
11
+
12
+ import elasticapm
13
+ from elasticapm.traces import Transaction, execution_context
14
+ from flask import Request, request
15
+ from gevent import Greenlet
16
+ from gevent.pool import Pool
17
+ from geventhttpclient import HTTPClient
18
+ from geventhttpclient.response import HTTPResponse
19
+ from pydantic import BaseModel, ValidationError
20
+ from requests import Response
21
+
22
+ from clue.common.exceptions import AuthenticationException, InvalidDataException
23
+ from clue.common.logging import get_logger, log_error
24
+ from clue.common.logging.audit import audit
25
+ from clue.config import CLASSIFICATION as CLASSIFICATION
26
+ from clue.config import DEBUG, config
27
+ from clue.helper.headers import generate_headers
28
+ from clue.models.config import ExternalSource
29
+ from clue.models.network import QueryEntry, QueryResult
30
+ from clue.models.selector import Selector
31
+ from clue.services import auth_service, type_service, user_service
32
+
33
# Module-level logger used by every function in this service.
logger = get_logger(__file__)
# HTTPClient instances created by get_client, keyed by the sha256 hex digest of the base url and timeout.
CLIENTS: dict[str, HTTPClient] = {}
35
+
36
+
37
def get_client(base_url: str, timeout: float) -> HTTPClient:
    """Return the shared HTTPClient for a base url/timeout pair, creating it on first use.

    Clients are kept in the module-level CLIENTS dict under the sha256 hex digest of the base url followed by the
    string form of the timeout, so each distinct (base_url, timeout) combination gets its own client.

    Args:
        base_url (str): The base url the client connects to.
        timeout (float): The base timeout. The client is created with three times this value as both its
            connection timeout and its network timeout.

    Returns:
        HTTPClient: The cached or newly created client for the provided base_url and timeout.
    """
    cache_key = sha256(base_url.encode() + str(timeout).encode()).hexdigest()

    try:
        return CLIENTS[cache_key]
    except KeyError:
        pass

    # Concurrency is half of EXECUTOR_THREADS, which is 16 connections with the default of 32.
    executor_threads = int(os.environ.get("EXECUTOR_THREADS", 32))
    extended_timeout = timeout * 3

    client = HTTPClient.from_url(
        base_url,
        concurrency=executor_threads // 2,
        connection_timeout=extended_timeout,
        network_timeout=extended_timeout,
    )
    CLIENTS[cache_key] = client

    return client
61
+
62
+
63
def build_result(
    type_name: str, value: str, source: ExternalSource, error: Optional[str] = None, latency: Optional[float] = None
):
    """Assemble a QueryResult for one selector/source pair.

    Any error other than "invalid_type" or a (case-insensitive) "request timed out" is logged as a warning before
    the result is built.

    Args:
        type_name (str): The name of the type of result.
        value (str): The value of the result.
        source (ExternalSource): The ExternalSource that provided the result.
        error (Optional[str], optional): The error that occurred during the request. Defaults to None.
        latency (Optional[float], optional): The amount of time between the request and the response
            (in milliseconds). A missing or zero latency is stored as 0. Defaults to None.

    Returns:
        QueryResult: The QueryResult object built.
    """
    if error:
        # Unsupported types and timeouts are not logged here; every other error is.
        expected = error == "invalid_type" or error.lower() == "request timed out"
        if not expected:
            logger.warning(error)

    if DEBUG:
        logger.debug("Building query result for source %s", source.name)

    fields = {
        "type": type_name,
        "value": value,
        "source": source.name,
        "maintainer": source.maintainer,
        "datahub_link": source.datahub_link,
        "documentation_link": source.documentation_link,
        "error": error,
        "latency": latency or 0,
    }

    return QueryResult(**fields)
95
+
96
+
97
class ParsedParams(BaseModel):
    """Validated container for the query parameters parsed from a lookup request."""

    # Source names requested through the "sources" query parameter (empty when none were given).
    query_sources: list[str]
    # Value of the "max_timeout" query parameter, or the default timeout.
    max_timeout: float
    # Value of the "limit" query parameter, or the default limit.
    limit: int
    # Value of the "classification" query parameter.
    type_classification: str
    # Flags parsed from the query parameters of the same name ("include_raw" for include_raw).
    no_annotation: bool
    include_raw: bool
    exclude_unset: bool
    no_cache: bool
109
+
110
def parse_timeout(timeout: float = 5.0) -> float:
    """Gets the max_timeout value from the request object, otherwise uses the provided timeout.

    The query parameter is ignored, and the provided timeout is returned instead, when it cannot be converted to a
    float or when the converted value is not a finite, strictly positive number.

    Args:
        timeout (float, optional): The timeout to use if no usable max_timeout is provided in the request.
            Defaults to 5.0.

    Returns:
        float: The parsed max_timeout value, or the provided timeout.
    """
    try:
        max_timeout = request.args.get("max_timeout", timeout, type=float)
    except (ValueError, TypeError):
        return timeout

    if max_timeout is None:
        return timeout

    # float() accepts "nan", "inf" and negative numbers, so a successful conversion is not enough: the value is later
    # used to build a timedelta and client/pool timeouts, which non-finite values break.
    if not math.isfinite(max_timeout) or max_timeout <= 0:
        return timeout

    return max_timeout
125
+
126
+
127
def parse_query_params(request: Request, limit: int = 10, timeout: float = 5.0):
    """Parse the standard lookup query params into a ParsedParams instance.

    Args:
        request (Request): The request whose query string is read.
        limit (int, optional): The limit to use when the request does not provide one. Defaults to 10.
        timeout (float, optional): The timeout handed to parse_timeout as its fallback. Defaults to 5.0.

    Returns:
        ParsedParams: The parsed parameters.
    """
    args = request.args

    def flag(name: str) -> bool:
        # A flag is enabled by "true", "1" or an empty value (e.g. "?no_cache"), compared case-insensitively.
        return args.get(name, "false").lower() in ("true", "1", "")

    raw_sources = args.get("sources")
    limit = args.get("limit", limit, type=int)
    classification = args.get("classification", CLASSIFICATION.UNRESTRICTED)

    no_annotation = flag("no_annotation")
    no_cache = flag("no_cache")
    include_raw = flag("include_raw")
    exclude_unset = flag("exclude_unset")

    if not raw_sources:
        sources = []
    else:
        # Sources are pipe-separated when a pipe is present, comma-separated otherwise.
        separator = "|" if "|" in raw_sources else ","
        sources = raw_sources.split(separator)

    return ParsedParams(
        query_sources=sources,
        max_timeout=parse_timeout(timeout),
        limit=limit,
        type_classification=classification,
        no_annotation=no_annotation,
        include_raw=include_raw,
        exclude_unset=exclude_unset,
        no_cache=no_cache,
    )
157
+
158
+
159
def generate_params(
    limit: int, timeout: float, no_annotation: bool = False, include_raw: bool = False, no_cache: bool = False
):
    """Build the query string sent along with a lookup request to a source.

    The source is given 95% of the timeout (but never less than half a second), both as a duration ("max_timeout")
    and as an absolute UTC timestamp ("deadline"). Flags are only added to the query string when they are enabled.

    Args:
        limit (int): The maximum number of results to return.
        timeout (float): The maximum amount of time to wait for a response.
        no_annotation (bool): Adds "no_annotation=True" when enabled. Defaults to False.
        include_raw (bool): Adds "include_raw=True" when enabled. Defaults to False.
        no_cache (bool): Adds "no_cache=True" when enabled. Defaults to False.

    Returns:
        str: A string of HTTP params formatted so that it can be appended to a url
            (in the format "?param1=value1&param2=value2")
    """
    source_timeout = max(timeout * 0.95, 0.5)
    deadline = datetime.now(timezone.utc) + timedelta(seconds=source_timeout)

    pairs: list[tuple[str, Any]] = [
        ("limit", limit),
        ("max_timeout", source_timeout),
        ("deadline", deadline.timestamp()),
    ]

    optional_flags = (
        ("no_annotation", no_annotation),
        ("include_raw", include_raw),
        ("no_cache", no_cache),
    )
    for name, enabled in optional_flags:
        if enabled:
            pairs.append((name, True))

    return "?" + "&".join(f"{key}={val}" for key, val in pairs)
191
+
192
+
193
def process_exception(source_name: str, rsp: Response, exception: Exception):
    """Turn an exception raised while querying a source into a user-facing error message.

    Args:
        source_name (str): The name of the source from which the exception came.
        rsp (Response): The response object, if one was received before the exception was raised.
        exception (Exception): The exception to parse.

    Returns:
        The formatted error message, or None when the body could not be decoded as JSON and the source answered
        with a 404 (which is not reported as an error).
    """
    if isinstance(exception, ConnectionError):
        return f"Could not connect to the specified plugin: {source_name}."

    body_not_json = rsp and isinstance(exception, json.JSONDecodeError)
    if body_not_json:
        status = rsp.status_code
        logger.warning("%s: %s", source_name, status)

        if status == 404:
            return None
        if status == 422:
            return f"{source_name} was unable to process this selector."
        if status > 299:
            return f"{source_name} experienced an unknown error"

    # Anything else is logged under an error id that is handed back to the caller.
    message = f"{source_name} did not return a response in the expected format"
    error_id = log_error(logger, message, exception)

    return f"{message}. Error ID: {error_id}"
221
+
222
+
223
def get_sources(user: dict[str, str]):
    """List the configured external sources whose classification is accessible to the user.

    This only checks which systems the user is allowed to see. Additional type level checking is done later to
    provide feedback to the user.

    Args:
        user (dict[str, Any]): The user for which we want all sources; its "classification" entry is compared
            against each source's classification.

    Returns:
        list[ExternalSource]: The sources that the user has access to.
    """

    def can_access(candidate) -> bool:
        return CLASSIFICATION.is_accessible(user["classification"], candidate.classification)

    return list(filter(can_access, config.api.external_sources))
238
+
239
+
240
def parse_response(source: ExternalSource, user: dict[str, Any], api_response: Any) -> list[QueryEntry]:
    """Parses the response from a source.

    A single dict is treated as a one-item list. Entries from production sources are built without validation;
    entries from every other source are validated with the user passed as validation context.

    Args:
        source (ExternalSource): The source that returned the response.
        user (dict[str, Any]): The user that initiated the request.
        api_response (Any): The response provided by the source (a dict, or a list of dicts).

    Returns:
        list[QueryEntry]: The list of results contained in the response.

    Raises:
        ValidationError: If an entry from a non-production source fails validation.
    """
    with elasticapm.capture_span(source.name, "parsing"):
        if isinstance(api_response, dict):
            api_response = [api_response]

        logger.debug(
            "Validating response from source %s, returning %s annotations in %s items",
            source.name,
            sum(len(entry.get("annotations", [])) for entry in api_response),
            len(api_response),
        )

        if source.production:
            logger.debug(f"Skipping validation for production source {source.name}")
            # model_construct's only positional parameter is `_fields_set`; the field values have to be passed as
            # keyword arguments, otherwise the entry is built without any of the data returned by the source.
            items: list[QueryEntry] = [QueryEntry.model_construct(**data) for data in api_response]
        else:
            items = [QueryEntry.model_validate(data, context={"user": user}) for data in api_response]

        return items
269
+
270
+
271
def parse_bulk_response(
    source: ExternalSource,
    user: dict[str, Any],
    api_response: dict[str, dict[str, Any]],
    latency: Optional[float] = None,
) -> dict[str, dict[str, QueryResult]]:
    """Convert the payload of a bulk lookup into QueryResult objects grouped by type and value.

    Each payload is merged over defaults taken from the source (name, maintainer and links), so a plugin may
    override them. The latency is always set from the argument. Production sources are built without validation.

    Args:
        source (ExternalSource): The source that returned the response.
        user (dict[str, Any]): The user that initiated the request.
        api_response (dict[str, dict[str, Any]]): The response provided by the source, keyed by type then value.
        latency (Optional[float]): The time between the request and the response, in milliseconds.

    Returns:
        dict[str, dict[str, QueryResult]]: A dict containing each type and their corresponding result sets.
    """
    parsed: dict[str, dict[str, QueryResult]] = {}
    skip_validation = source.production

    if skip_validation:
        logger.debug(f"Skipping validation for production source {source.name}")

    source_defaults: dict[str, Any] = {
        "source": source.name,
        "maintainer": source.maintainer,
        "datahub_link": source.datahub_link,
        "documentation_link": source.documentation_link,
    }
    mode_label = "production" if skip_validation else "not production"
    user_label = user.get("uname", user.get("email", None))

    with elasticapm.capture_span(f"{source.name}-bulk", "parsing"):
        for type_name, payloads in api_response.items():
            results_for_type = parsed.setdefault(type_name, {})

            for value, payload in payloads.items():
                # Payload keys win over the defaults; latency always comes from the caller.
                record: dict[str, Any] = {
                    "type": type_name,
                    "value": value,
                    **source_defaults,
                    **payload,
                    "latency": latency or 0.0,
                }

                entries = record.get("items", [])
                annotation_count = sum(
                    1 for _ in itertools.chain.from_iterable(entry.get("annotations", []) for entry in entries)
                )
                logger.debug(
                    "Validating bulk response from source %s (%s), returning %s annotations in %s items, using user %s",
                    source.name,
                    mode_label,
                    annotation_count,
                    len(entries),
                    user_label,
                )

                if skip_validation:
                    results_for_type[value] = QueryResult.model_construct(**record)
                else:
                    results_for_type[value] = QueryResult.model_validate(record, context={"user": user})

    return parsed
333
+
334
+
335
def handle_validation_error(source: ExternalSource, err: ValidationError) -> str:
    """Format and log a validation failure raised while parsing a source's response.

    Every error is rendered as `"<location>": <message>`, where string parts of the location are joined with dots
    and non-string parts (such as list indexes) are wrapped in square brackets.

    Args:
        source (ExternalSource): The source from which the invalid response came.
        err (ValidationError): The error in question.

    Returns:
        str: A formatted error message, ending with the error id returned by log_error.
    """

    def render_location(parts) -> str:
        return ".".join(part if isinstance(part, str) else f"[{part}]" for part in parts)

    details: list[str] = []
    for problem in err.errors():
        location = render_location(problem["loc"])
        message = problem["msg"]
        details.append(f'"{location}": {message}')

    summary = f"{source.name} returned an improperly formatted response: {', '.join(details)}"
    error_id = log_error(logger, summary, err)

    return f"{summary}. Error ID: {error_id}"
356
+
357
+
358
def query_external(
    user: dict[str, Any],
    source: ExternalSource,
    type_name: str,
    value: str,
    limit: int,
    timeout: float,
    access_token: str,
    clue_access_token: str | None,
    no_annotation: bool = False,
    no_cache: bool = False,
    include_raw: bool = True,
    apm_transaction: Optional[Transaction] = None,
) -> Optional[QueryResult]:
    """Query a single external source for details about one selector.

    Args:
        user (dict[str, Any]): The user that initiated the request.
        source (ExternalSource): The source to query.
        type_name (str): The type of the value to query.
        value (str): The value to query.
        limit (int): The maximum number of results to request.
        timeout (float): The timeout used for the HTTP client and forwarded to the source.
        access_token (str): The token sent to the source.
        clue_access_token (str | None): The additional token handed to generate_headers, if any.
        no_annotation (bool): Forwarded to the source as a query parameter. Defaults to False.
        no_cache (bool): Forwarded to the source as a query parameter. Defaults to False.
        include_raw (bool): Forwarded to the source as a query parameter. Defaults to True.
        apm_transaction (Optional[Transaction]): The APM transaction to attach this call to, if any.

    Returns:
        Optional[QueryResult]: The result of the lookup. Failures (unsupported type, quota, connection problems,
        error responses, malformed or invalid responses) are reported through the result's error field.
    """
    if apm_transaction:
        execution_context.set_transaction(apm_transaction)

    finish_result = functools.partial(build_result, type_name, value, source)

    with elasticapm.capture_span(query_external.__name__, span_type="greenlet"):
        if type_name not in type_service.all_supported_types(user, access_token=access_token).get(source.name, {}):
            return finish_result(error="invalid_type")

        if config.api.audit:
            audit(
                [
                    f"source={source.name}",
                    f"type={type_name}",
                    f"value={value}",
                    f"no_annotation={no_annotation}",
                    f"include_raw={include_raw}",
                    f"no_cache={no_cache}",
                ],
                {},
                user,
                query_external,
            )

        if quota_error := user_service.check_quota(source, user):
            return finish_result(error=quota_error)

        # perform the lookup, ensuring access controls are applied
        # NOTE(review): value is placed in the path as-is, and only the parsed path is sent. Presumably values
        # containing "?" or "#" are encoded by the caller - confirm.
        url = f"{source.url}/lookup/{type_name}/{value}/"
        response: Any = None
        rsp = None
        start = time.perf_counter()
        try:
            with elasticapm.capture_span(url, "http"):
                parsed_url = urlparse(url)
                rsp = get_client(f"{parsed_url.scheme}://{parsed_url.netloc}", timeout).get(
                    parsed_url.path + generate_params(limit, timeout, no_annotation, include_raw, no_cache),
                    headers=generate_headers(access_token, clue_access_token),
                )

            response = json.load(rsp)
        except Exception as exception:
            return finish_result(
                error=process_exception(source.name, rsp, exception),
                latency=(time.perf_counter() - start) * 1000,
            )
        finally:
            # The quota checked above is released whether the request succeeded or not.
            user_service.release_quota(source, user)

        if isinstance(response, dict) and response.get("api_error_message"):
            logger.warning(f"Error response from {url}: {response['api_error_message']}")
            return finish_result(
                error=response["api_error_message"],
                latency=(time.perf_counter() - start) * 1000,
            )

        # A body that decodes to something other than the expected envelope is reported as an error instead of
        # raising an uncaught KeyError/TypeError inside the greenlet.
        if not isinstance(response, dict) or "api_response" not in response:
            return finish_result(
                error=f"{source.name} did not return a response in the expected format",
                latency=(time.perf_counter() - start) * 1000,
            )

        try:
            result = finish_result(latency=(time.perf_counter() - start) * 1000)

            api_response = response["api_response"]
            if api_response:
                result.items = parse_response(source, user, api_response)

            logger.debug("Returning valid result from source %s", source)

            return result
        except ValidationError as err:
            logger.exception("Validation error on response from %s", source)
            return finish_result(
                error=handle_validation_error(source, err),
                latency=(time.perf_counter() - start) * 1000,
            )
445
+
446
+
447
def enrich(type_name: str, value: str, user: dict[str, Any]):  # noqa: C901
    """Queries all available sources with the provided value.

    The sources queried are the ones named in the request's "sources" parameter, or every accessible source with
    include_default set when none are named. Each source is queried in its own greenlet. Sources that have not
    answered when the pool join times out are reported as "Request Timed Out".

    Args:
        type_name (str): The type of the value to query.
        value (str): The value to query.
        user (dict[str, Any]): The user requesting the query.

    Raises:
        AuthenticationException: Raised when the Authorization header is missing or does not contain a token.

    Returns:
        dict[str, QueryResult]: A dict of each source and their query result. Sources that do not support the
        type are left out.
    """
    query_params = parse_query_params(request=request)
    query_sources = query_params.query_sources
    available_sources = get_sources(user)

    # The header is expected to look like "<scheme> <token>"; anything shorter carries no usable token.
    authorization = request.headers.get("Authorization", type=str)
    token_parts = authorization.split() if authorization else []
    if len(token_parts) < 2:
        raise AuthenticationException("Access token is required to enrich.")
    access_token = token_parts[1]

    logger.debug(
        f"Beginning enrichment for single selector on sources "
        f"[{','.join(query_sources or [source.name for source in available_sources])}]"
    )

    results: dict[str, QueryResult] = {}

    pool_size = min(len(query_sources or available_sources), int(os.environ.get("EXECUTOR_THREADS", 32)))
    thread_pool = Pool(pool_size)

    greenlets: list[tuple[str, str, ExternalSource, Greenlet[Any, Optional[QueryResult]]]] = []
    try:
        # create searches for external sources
        for source in available_sources:
            if query_sources and source.name not in query_sources:
                continue
            elif not query_sources and not source.include_default:
                continue

            finish_result = functools.partial(build_result, type_name, value, source)

            obo_access_token, error = auth_service.check_obo(source, access_token, user["uname"])

            # TODO: sa-clue support
            if not obo_access_token and source.obo_target:
                results[source.name] = finish_result(error="You must have a valid JWT to access this plugin.")
                continue

            if error:
                results[source.name] = finish_result(error=error)
                continue

            # check query against the max supported classification of the external system
            # if this is not supported, we should let the user know.
            if not CLASSIFICATION.is_accessible(
                source.max_classification or CLASSIFICATION.UNRESTRICTED, query_params.type_classification
            ):
                results[source.name] = finish_result(
                    error=f"Type classification exceeds max classification of source: {source.name}."
                )
                continue

            greenlets.append(
                (
                    type_name,
                    value,
                    source,
                    thread_pool.spawn(
                        query_external,
                        user=user,
                        source=source,
                        type_name=type_name,
                        value=value,
                        limit=query_params.limit,
                        timeout=query_params.max_timeout,
                        access_token=obo_access_token or access_token,
                        clue_access_token=access_token if obo_access_token else None,
                        no_annotation=query_params.no_annotation,
                        include_raw=query_params.include_raw,
                        no_cache=query_params.no_cache,
                        apm_transaction=execution_context.get_transaction(),
                    ),
                )
            )

        thread_pool.join(timeout=query_params.max_timeout * 2)

        for queried_type, queried_value, source, greenlet in greenlets:
            result = greenlet.value
            if result:
                # Sources that do not support this type are silently left out of the results.
                if result.error == "invalid_type":
                    continue

                results[source.name] = result
            else:
                # No value means the greenlet either failed or had not finished when the join timed out.
                results[source.name] = build_result(
                    queried_type,
                    queried_value,
                    source,
                    "Request Timed Out" if not greenlet.exception else str(greenlet.exception),
                )
    finally:
        # Stop whatever is still running, including when an exception interrupts the loops above.
        thread_pool.kill(block=False)

    return results
551
+
552
+
553
def bulk_query_external(  # noqa: C901
    data: list[Selector],
    user: dict[str, Any],
    source: ExternalSource,
    limit: int,
    timeout: float,
    access_token: str,
    clue_access_token: str | None,
    no_annotation: bool = False,
    no_cache: bool = False,
    include_raw: bool = True,
    apm_transaction: Optional[Transaction] = None,
) -> dict[str, dict[str, QueryResult]]:
    """Query a single external source for details about several selectors at once.

    Selectors whose type the source does not support are answered locally with an "invalid_type" result and are
    not sent to the source. Nothing is sent at all when no selector remains.

    Args:
        data (list[Selector]): The selectors to look up.
        user (dict[str, Any]): The user that initiated the request.
        source (ExternalSource): The source to query.
        limit (int): The maximum number of results to request.
        timeout (float): The timeout forwarded to the source; the HTTP client is created with 1.1 times this value.
        access_token (str): The token sent to the source.
        clue_access_token (str | None): The additional token handed to generate_headers, if any.
        no_annotation (bool): Forwarded to the source as a query parameter. Defaults to False.
        no_cache (bool): Forwarded to the source as a query parameter. Defaults to False.
        include_raw (bool): Forwarded to the source as a query parameter. Defaults to True.
        apm_transaction (Optional[Transaction]): The APM transaction to attach this call to, if any.

    Returns:
        dict[str, dict[str, QueryResult]]: The results keyed by type, then value. Failures are reported through
        the error field of each result.
    """
    if apm_transaction:
        execution_context.set_transaction(apm_transaction)

    with elasticapm.capture_span(bulk_query_external.__name__, span_type="greenlet"):
        supported_types = type_service.all_supported_types(user, access_token=access_token).get(source.name, {})
        bulk_result: dict[str, dict[str, QueryResult]] = {}

        filtered_data: list[Selector] = []
        for entry in data:
            bulk_result.setdefault(entry.type, {})

            if entry.type not in supported_types:
                bulk_result[entry.type][entry.value] = build_result(entry.type, entry.value, source, "invalid_type")
                continue

            filtered_data.append(entry)

        data = filtered_data

        # Nothing left to ask the source: skip the audit entry, the quota and the request altogether.
        if not data:
            return bulk_result

        if config.api.audit:
            values = ",".join(f"{entry.type}:{entry.value}" for entry in data)

            audit(
                [
                    f"source={source.name}",
                    f"values={values}",
                    f"no_annotation={no_annotation}",
                    f"include_raw={include_raw}",
                    f"no_cache={no_cache}",
                ],
                {},
                user,
                bulk_query_external,
            )

        if quota_error := user_service.check_quota(source, user):
            # Report the quota error for every selector, in the same nested shape as any other outcome.
            for entry in data:
                bulk_result[entry.type][entry.value] = build_result(
                    entry.type, entry.value, source, error=quota_error
                )

            return bulk_result

        # perform the lookup, ensuring access controls are applied
        error: Optional[str] = None
        latency: Optional[float] = None
        request_failed = False

        url = f"{source.url}/lookup/"
        response: Any = None
        rsp: Optional[HTTPResponse] = None
        start = time.perf_counter()
        try:
            with elasticapm.capture_span(url, "http"):
                parsed_url = urlparse(url)

                rsp = get_client(f"{parsed_url.scheme}://{parsed_url.netloc}", timeout * 1.1).post(
                    parsed_url.path + generate_params(limit, timeout, no_annotation, include_raw, no_cache),
                    body=json.dumps([entry.model_dump(exclude_none=True, exclude_unset=True) for entry in data]),
                    headers=generate_headers(access_token, clue_access_token),
                )

            logger.debug(f"{rsp.status_code}: {url}")
            response = json.load(rsp)
        except Exception as exception:
            request_failed = True
            # process_exception returns None for a 404, which is treated as "nothing found" below.
            error = process_exception(source.name, rsp, exception)
        finally:
            latency = (time.perf_counter() - start) * 1000
            # Release the quota checked above, mirroring query_external.
            user_service.release_quota(source, user)

        if not request_failed:
            if isinstance(response, dict) and response.get("api_error_message"):
                error = response["api_error_message"]
            elif not isinstance(response, dict) or "api_response" not in response:
                # Report an unexpected envelope instead of raising an uncaught KeyError/TypeError in the greenlet.
                error = f"{source.name} did not return a response in the expected format"

        if error:
            for entry in data:
                bulk_result[entry.type][entry.value] = build_result(
                    entry.type, entry.value, source, error=error, latency=latency
                )

            return bulk_result

        api_response = None if request_failed else response["api_response"]

        try:
            # handle case of 200 OK for not found.
            if not api_response:
                for entry in data:
                    bulk_result[entry.type][entry.value] = build_result(
                        entry.type, entry.value, source, latency=latency
                    )
            else:
                bulk_result = parse_bulk_response(source, user, api_response, latency)
        except ValidationError as err:
            error_message = handle_validation_error(source, err)

            for entry in data:
                bulk_result[entry.type][entry.value] = build_result(
                    entry.type, entry.value, source, error=error_message, latency=latency
                )

        return bulk_result
661
+
662
+
663
def bulk_enrich(data: list[Selector], user: dict[str, Any]):  # noqa: C901
    """Queries the available sources for every provided selector.

    The sources queried are the ones named in the request's "sources" parameter, or every accessible source with
    include_default set when none are named. A selector can further restrict the sources through its own
    `sources` attribute. Each source is queried once, in its own greenlet, with all the selectors that apply to it.

    Args:
        data (list[Selector]): The selectors to look up.
        user (dict[str, Any]): The user requesting the query.

    Raises:
        AuthenticationException: Raised when the Authorization header is missing or does not contain a token.
        InvalidDataException: Raised when no selector is provided.

    Returns:
        dict[str, dict[str, dict[str, QueryResult]]]: The results keyed by type, then value, then source name.
        Results for types a source does not support are left out.
    """
    query_params = parse_query_params(request=request)
    query_sources = query_params.query_sources
    available_sources = get_sources(user)

    logger.debug(
        f"Beginning enrichment for {len(data)} selectors on sources "
        f"[{','.join(query_sources or [source.name for source in available_sources])}]"
    )

    # The header is expected to look like "<scheme> <token>"; anything shorter carries no usable token.
    authorization = request.headers.get("Authorization", type=str)
    token_parts = authorization.split() if authorization else []
    if len(token_parts) < 2:
        raise AuthenticationException("Access token is required to enrich.")
    access_token = token_parts[1]

    if not data:
        raise InvalidDataException("You must provide at least one value to lookup.")

    bulk_result: dict[str, dict[str, dict[str, QueryResult]]] = {}
    for entry in data:
        bulk_result.setdefault(entry.type, {}).setdefault(entry.value, {})

    pool_size = min(len(data) * len(query_sources or available_sources), int(os.environ.get("EXECUTOR_THREADS", 32)))
    thread_pool = Pool(pool_size)

    greenlets: list[tuple[list[Selector], ExternalSource, Greenlet[Any, dict[str, dict[str, QueryResult]]]]] = []
    try:
        for source in available_sources:
            if query_sources and source.name not in query_sources:
                continue
            elif not query_sources and not source.include_default:
                continue

            obo_access_token, error = auth_service.check_obo(source, access_token, user["uname"])

            if error:
                logger.error("%s: %s", source.name, error)

            # TODO: sa-clue support
            if not obo_access_token and source.obo_target:
                for entry in data:
                    bulk_result[entry.type][entry.value][source.name] = build_result(
                        entry.type, entry.value, source, "You must have a valid JWT to access this plugin."
                    )
                continue

            data_for_source: list[Selector] = []

            # check query against the max supported classification of the external system
            # if this is not supported, we should let the user know.
            for entry in data:
                if entry.sources is not None and source.name not in entry.sources:
                    continue

                if not CLASSIFICATION.is_accessible(
                    source.max_classification or CLASSIFICATION.UNRESTRICTED,
                    entry.classification,
                ):
                    # Only report the mismatch when the selector explicitly asked for this source.
                    if source.name in (entry.sources or []):
                        bulk_result[entry.type][entry.value][source.name] = build_result(
                            entry.type,
                            entry.value,
                            source,
                            (
                                f"Selector classification ({entry.classification}) exceeds max classification "
                                f"of source: {source.name} ({source.max_classification})."
                            ),
                        )

                    continue

                data_for_source.append(entry)

            # No selector applies to this source, so there is nothing to query it for.
            if not data_for_source:
                continue

            greenlets.append(
                (
                    data_for_source,
                    source,
                    thread_pool.spawn(
                        bulk_query_external,
                        data=data_for_source,
                        user=user,
                        source=source,
                        limit=query_params.limit,
                        timeout=query_params.max_timeout,
                        access_token=obo_access_token or access_token,
                        clue_access_token=access_token if obo_access_token else None,
                        no_annotation=query_params.no_annotation,
                        no_cache=query_params.no_cache,
                        include_raw=query_params.include_raw,
                        apm_transaction=execution_context.get_transaction(),
                    ),
                )
            )

        start = time.perf_counter()

        thread_pool.join(timeout=query_params.max_timeout * 2.2)

        for greenlet_data, source, greenlet in greenlets:
            result = greenlet.value

            if not result:
                # No value means the greenlet either failed or had not finished when the join timed out.
                for entry in greenlet_data:
                    bulk_result[entry.type][entry.value][source.name] = build_result(
                        entry.type,
                        entry.value,
                        source,
                        "Request Timed Out" if not greenlet.exception else str(greenlet.exception),
                        (time.perf_counter() - start) * 1000,
                    )

                continue

            for type_name, values in result.items():
                for value, query_result in values.items():
                    if query_result.error == "invalid_type":
                        continue

                    # A source may answer with a type/value that was not part of the request; keep its result
                    # instead of failing the whole lookup with a KeyError.
                    bulk_result.setdefault(type_name, {}).setdefault(value, {})[source.name] = query_result
    finally:
        # Stop whatever is still running, including when an exception interrupts the loops above.
        thread_pool.kill(block=False)

    return bulk_result