clue-api 1.0.0.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. clue/.gitignore +21 -0
  2. clue/__init__.py +0 -0
  3. clue/api/__init__.py +211 -0
  4. clue/api/base.py +99 -0
  5. clue/api/v1/__init__.py +82 -0
  6. clue/api/v1/actions.py +92 -0
  7. clue/api/v1/auth.py +243 -0
  8. clue/api/v1/configs.py +83 -0
  9. clue/api/v1/fetchers.py +94 -0
  10. clue/api/v1/lookup.py +221 -0
  11. clue/api/v1/registration.py +109 -0
  12. clue/api/v1/static.py +94 -0
  13. clue/app.py +166 -0
  14. clue/cache/__init__.py +129 -0
  15. clue/common/__init__.py +0 -0
  16. clue/common/classification.py +1006 -0
  17. clue/common/classification.yml +130 -0
  18. clue/common/dict_utils.py +130 -0
  19. clue/common/exceptions.py +199 -0
  20. clue/common/forge.py +152 -0
  21. clue/common/json_utils.py +10 -0
  22. clue/common/list_utils.py +11 -0
  23. clue/common/logging/__init__.py +291 -0
  24. clue/common/logging/audit.py +157 -0
  25. clue/common/logging/format.py +42 -0
  26. clue/common/regex.py +31 -0
  27. clue/common/str_utils.py +213 -0
  28. clue/common/swagger.py +139 -0
  29. clue/common/uid.py +47 -0
  30. clue/config.py +60 -0
  31. clue/constants/__init__.py +0 -0
  32. clue/constants/supported_types.py +38 -0
  33. clue/cronjobs/__init__.py +30 -0
  34. clue/cronjobs/plugins.py +32 -0
  35. clue/error.py +129 -0
  36. clue/gunicorn_config.py +29 -0
  37. clue/healthz.py +74 -0
  38. clue/helper/discover.py +53 -0
  39. clue/helper/headers.py +30 -0
  40. clue/helper/oauth.py +128 -0
  41. clue/models/__init__.py +0 -0
  42. clue/models/actions.py +243 -0
  43. clue/models/config.py +456 -0
  44. clue/models/fetchers.py +136 -0
  45. clue/models/graph.py +162 -0
  46. clue/models/model_list.py +52 -0
  47. clue/models/network.py +430 -0
  48. clue/models/results/__init__.py +34 -0
  49. clue/models/results/base.py +10 -0
  50. clue/models/results/graph.py +26 -0
  51. clue/models/results/image.py +22 -0
  52. clue/models/results/status.py +55 -0
  53. clue/models/results/validation.py +57 -0
  54. clue/models/selector.py +67 -0
  55. clue/models/utils.py +52 -0
  56. clue/models/validators.py +19 -0
  57. clue/patched.py +8 -0
  58. clue/plugin/__init__.py +1008 -0
  59. clue/plugin/helpers/__init__.py +0 -0
  60. clue/plugin/helpers/central_server.py +27 -0
  61. clue/plugin/helpers/email_render.py +228 -0
  62. clue/plugin/helpers/token.py +34 -0
  63. clue/plugin/helpers/trino.py +103 -0
  64. clue/plugin/interactive.py +270 -0
  65. clue/plugin/models.py +19 -0
  66. clue/plugin/utils.py +78 -0
  67. clue/remote/__init__.py +0 -0
  68. clue/remote/datatypes/__init__.py +130 -0
  69. clue/remote/datatypes/cache.py +62 -0
  70. clue/remote/datatypes/events.py +118 -0
  71. clue/remote/datatypes/hash.py +193 -0
  72. clue/remote/datatypes/queues/__init__.py +0 -0
  73. clue/remote/datatypes/queues/comms.py +62 -0
  74. clue/remote/datatypes/set.py +96 -0
  75. clue/remote/datatypes/user_quota_tracker.py +54 -0
  76. clue/security/__init__.py +211 -0
  77. clue/security/obo.py +95 -0
  78. clue/security/utils.py +34 -0
  79. clue/services/action_service.py +186 -0
  80. clue/services/auth_service.py +348 -0
  81. clue/services/config_service.py +38 -0
  82. clue/services/fetcher_service.py +203 -0
  83. clue/services/jwt_service.py +233 -0
  84. clue/services/lookup_service.py +786 -0
  85. clue/services/type_service.py +165 -0
  86. clue/services/user_service.py +152 -0
  87. clue_api-1.0.0.dev7.dist-info/METADATA +111 -0
  88. clue_api-1.0.0.dev7.dist-info/RECORD +91 -0
  89. clue_api-1.0.0.dev7.dist-info/WHEEL +4 -0
  90. clue_api-1.0.0.dev7.dist-info/entry_points.txt +8 -0
  91. clue_api-1.0.0.dev7.dist-info/licenses/LICENSE +11 -0
clue/models/config.py ADDED
@@ -0,0 +1,456 @@
1
+ # ruff: noqa: D101
2
+ import logging
3
+ import os
4
+ from email.utils import parseaddr
5
+ from enum import Enum
6
+ from pathlib import Path
7
+ from uuid import uuid4
8
+
9
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
10
+ from pydantic_core import Url
11
+ from pydantic_settings import (
12
+ BaseSettings,
13
+ PydanticBaseSettingsSource,
14
+ SettingsConfigDict,
15
+ YamlConfigSettingsSource,
16
+ )
17
+ from typing_extensions import Self
18
+
19
+ from clue.common import forge
20
+ from clue.common.exceptions import ClueValueError
21
+ from clue.common.logging.format import BRL_DATE_FORMAT, BRL_LOG_FORMAT
22
+ from clue.common.str_utils import default_string_value
23
+
24
# Module-level constants. NOTE(review): most consumers live outside this file - confirm usage before editing.
AUTO_PROPERTY_TYPE = [
    "access",
    "classification",
    "type",
    "role",
    "remove_role",
    "group",
]
DEFAULT_EMAIL_FIELDS = [
    "email",
    "emails",
    "extension_selectedEmailAddress",
    "otherMails",
    "preferred_username",
    "upn",
]
DEFAULT_USER_FIELDS = [
    "uname",
    "preferred_username",
    "upn",
]
DEFAULT_USER_NAME_FIELDS = [
    "name",
    "displayName",
]

# Resolved through default_string_value (presumably from the APP_NAME environment variable, falling back to
# "clue" - verify against the helper); every "-dev" occurrence is then stripped from the result.
APP_NAME = default_string_value(env_name="APP_NAME", default="clue").replace("-dev", "")  # type: ignore[union-attr]

# Classification engine shared by the validators in this module.
CLASSIFICATION = forge.get_classification()
30
+
31
+
32
# Password complexity rules: which character classes are required, plus the minimum length.
# (Kept as a comment rather than a docstring so the generated JSON schema is unchanged.)
class PasswordRequirement(BaseModel):
    lower: bool = Field(default=False, description="Password must contain lowercase letters")
    number: bool = Field(default=False, description="Password must contain numbers")
    special: bool = Field(default=False, description="Password must contain special characters")
    upper: bool = Field(default=False, description="Password must contain uppercase letters")
    min_length: int = Field(default=12, description="Minimum password length")
38
+
39
+
40
# Settings for a single OAuth provider.
# Fields declared without a default (client_id, access_token_url, authorize_url, api_base_url, audience, scope,
# jwks_uri) are required, even when their type allows None.
class OAuthProvider(BaseModel):
    # --- user creation / synchronisation ---
    auto_create: bool = Field(description="Auto-create users if they are missing", default=True)
    auto_sync: bool = Field(description="Should we automatically sync with OAuth provider?", default=False)

    # --- user ID generation ---
    uid_randomize: bool = Field(
        description="Should we generate a random username for the authenticated user?",
        default=False,
    )
    uid_randomize_digits: int = Field(
        description="How many digits should we add at the end of the username?",
        default=0,
    )
    uid_randomize_delimiter: str = Field(
        description="What is the delimiter used by the random name generator?",
        default="-",
    )
    uid_regex: str | None = Field(
        default=None,
        description="Regex used to parse an email address and capture parts to create a user ID out of it",
    )
    uid_format: str | None = Field(
        default=None,
        description="Format of the user ID based on the captured parts from the regex",
    )

    # --- client credentials ---
    client_id: str | None = Field(description="ID of your application to authenticate to the OAuth provider")
    client_secret: str | None = Field(
        default=None,
        description="Password to your application to authenticate to the OAuth provider",
    )

    # --- authorisation mappings ---
    required_groups: list[str] = Field(
        description="The groups the JWT must contain in order to allow access",
        default=[],
    )
    role_map: dict[str, str] = Field(description="A mapping of OAuth groups to clue roles", default={})
    classification_map: dict[str, str] = Field(
        description="A mapping of OAuth groups to classification levels",
        default={},
    )

    # --- endpoints and token validation ---
    access_token_url: str | None = Field(description="URL to get access token")
    authorize_url: str | None = Field(description="URL used to authorize access to a resource")
    api_base_url: str | None = Field(description="Base URL for downloading the user's and groups info")
    audience: str | None = Field(
        description="The audience to validate against. Only must be set if audience is different than the client id."
    )
    scope: str = Field(description="The scope to validate against")
    iss: str | None = Field(default=None, description="Optional issuer field for JWT validation")
    jwks_uri: str = Field(description="URL used to verify if a returned JWKS token is valid")
81
+
82
+
83
# Top-level OAuth settings: the feature switches plus the per-provider configuration, keyed by provider name.
class OAuth(BaseModel):
    enabled: bool = Field(default=False, description="Enable use of OAuth?")
    gravatar_enabled: bool = Field(default=False, description="Enable gravatar?")
    providers: dict[str, OAuthProvider] = Field(description="OAuth provider configuration", default={})
    other_audiences: list[str] | None = Field(
        description="What other audiences in JWT tokens should Clue accept?",
        default=None,
    )

    @model_validator(mode="before")
    @classmethod
    def prepare_model(
        cls,  # noqa: ANN102
        oauth_data: dict[str, dict[str, dict | OAuthProvider]],  # noqa: ANN102
    ) -> dict[str, dict[str, dict | OAuthProvider]]:
        """Resolve the client secret of every configured provider before field validation runs.

        Each provider's ``client_secret`` is replaced by the result of ``default_string_value``, called with the
        current value and the environment variable name ``<PROVIDER NAME, UPPERCASED>_CLIENT_SECRET``.
        Providers that are neither a dict nor an ``OAuthProvider`` are left untouched, as is the whole input
        when it has no dict under ``"providers"``.

        Args:
            oauth_data (dict[str, dict[str, dict | OAuthProvider]]): The raw data about to be validated.

        Returns:
            dict[str, dict[str, dict | OAuthProvider]]: The same object, with provider secrets updated in place.
        """
        if "providers" not in oauth_data or not isinstance(oauth_data["providers"], dict):
            return oauth_data

        for name, provider in oauth_data["providers"].items():
            if isinstance(provider, OAuthProvider):
                # Already-built model: update the attribute.
                existing_secret = provider.client_secret
                provider.client_secret = default_string_value(
                    existing_secret,
                    env_name=f"{name.upper()}_CLIENT_SECRET",
                )
            elif isinstance(provider, dict):
                # Raw mapping: update (or create) the key.
                existing_secret = provider.get("client_secret", None)
                provider["client_secret"] = default_string_value(
                    existing_secret,
                    env_name=f"{name.upper()}_CLIENT_SECRET",
                )

        return oauth_data
119
+
120
+
121
# Credentials for one service account, tied to the OAuth provider it authenticates against.
class ServiceAccountCreds(BaseModel):
    username: str = Field(description="Username of the service account")
    password: str = Field(description="Password of the service account")
    provider: str = Field(description="What OAuth provider does this service account connect to?")

    @model_validator(mode="before")
    @classmethod
    def prepare_model(cls, data: dict[str, str]) -> dict[str, str]:  # noqa: ANN102
        """Fill in a missing password from the environment.

        When ``data`` has a ``provider`` but no ``password``, the environment variable
        ``SA_<PROVIDER, UPPERCASED>_PASSWORD`` is consulted; a non-empty value is stored under ``password``.

        Args:
            data (dict[str, str]): The raw data about to be validated.

        Returns:
            dict[str, str]: The same object, possibly with ``password`` added.
        """
        # Nothing to do when a password is already given, or when there is no provider to derive the variable from.
        if "password" in data or "provider" not in data:
            return data

        env_pass = os.getenv(f'SA_{data["provider"].upper()}_PASSWORD')
        if env_pass:
            data["password"] = env_pass

        return data
142
+
143
+
144
# Service account settings: a switch plus at most one set of credentials per OAuth provider.
class ServiceAccount(BaseModel):
    enabled: bool = Field(default=False, description="Enable use of a service account?")
    accounts: list[ServiceAccountCreds] = Field(
        default=[],
        description="A list of service accounts on a per-provider basis",
    )

    @model_validator(mode="after")
    def validate_model(self: Self) -> Self:
        """Ensure no provider appears on more than one service account.

        Raises:
            ClueValueError: Raised when two or more accounts share a provider.

        Returns:
            Self: The validated model.
        """
        seen_providers: set[str] = set()

        for account in self.accounts:
            if account.provider in seen_providers:
                raise ClueValueError("You may only have one service account per provider.")

            seen_providers.add(account.provider)

        return self
166
+
167
+
168
# Authentication settings: API keys, OAuth, and service accounts.
class Auth(BaseModel):
    allow_apikeys: bool = Field(default=False, description="Allow API keys?")
    apikeys: dict[str, str] = Field(description="API Keys available in the system", default={})
    propagate_clue_key: bool = Field(
        description="Should clue include the root clue token in requests when OBO is used?",
        default=True,
    )
    oauth: OAuth = OAuth()
    service_account: ServiceAccount = ServiceAccount()

    @model_validator(mode="after")
    def validate_model(self: Self) -> Self:
        """Check that service accounts, when enabled, are backed by the OAuth configuration.

        Raises:
            ClueValueError: Raised when service accounts are enabled without OAuth, or when an account names a
                provider that is not configured under ``oauth.providers``.

        Returns:
            Self: The validated model.
        """
        if self.service_account.enabled:
            if not self.oauth.enabled:
                raise ClueValueError(
                    "In order to use service accounts to connect to plugins, you must have oauth enabled."
                )

            # First account (in declaration order) whose provider is not configured, if any.
            orphan = next(
                (entry for entry in self.service_account.accounts if entry.provider not in self.oauth.providers),
                None,
            )
            if orphan is not None:
                raise ClueValueError(
                    f"{orphan.username} is used to connect to non-existent provider {orphan.provider}."
                )

        return self
200
+
201
+
202
# Connection details for a Redis instance; defaults point at a local server on the standard port.
class RedisServer(BaseModel):
    host: str = Field(default="127.0.0.1", description="Hostname of Redis instance")
    port: int = Field(default=6379, description="Port of Redis instance")
205
+
206
+
207
# Optional APM server connection; both values default to None.
class APMServer(BaseModel):
    server_url: str | None = Field(default=None, description="URL to API server")
    token: str | None = Field(default=None, description="Authentication token for server")
210
+
211
+
212
# Metrics settings: APM server, export interval, and the Redis instance used for metrics.
class Metrics(BaseModel):
    apm_server: APMServer = APMServer()
    export_interval: int = Field(default=5, description="How often should we be exporting metrics?")
    redis: RedisServer = RedisServer()
216
+
217
+
218
# Core infrastructure settings: metrics configuration and the core Redis instance.
class Core(BaseModel):
    metrics: Metrics = Metrics()
    redis: RedisServer = RedisServer()
221
+
222
+
223
# Accepted values for the configured log level. Members are also plain strings (str mixin), and each
# member's value equals its name.
class LogLevel(str, Enum):
    # Standard severities, least to most severe.
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"

    # Extra value beyond the standard severities.
    DISABLED = "DISABLED"
230
+
231
+
232
# Logging settings: level, destinations (console / file / syslog), output format, and heartbeat file.
class Logging(BaseModel):
    log_level: LogLevel = Field(  # type: ignore
        default=LogLevel.DEBUG,
        description="What level of logging should we have?",
    )

    # --- console ---
    log_to_console: bool = Field(default=True, description="Should we log to console?")

    # --- file ---
    log_to_file: bool = Field(default=False, description="Should we log to files on the server?")
    log_directory: str = Field(
        default="/var/log/clue/",
        description="If `log_to_file: true`, what is the directory to store logs?",
    )

    # --- syslog ---
    log_to_syslog: bool = Field(default=False, description="Should logs be sent to a syslog server?")
    syslog_host: str = Field(
        default="localhost",
        description="If `log_to_syslog: true`, provide hostname/IP of the syslog server?",
    )
    syslog_port: int = Field(default=514, description="If `log_to_syslog: true`, provide port of the syslog server?")

    # --- misc ---
    export_interval: int = Field(default=5, description="How often, in seconds, should counters log their values?")
    log_as_json: bool = Field(default=False, description="Log in JSON format?")
    heartbeat_file: str | None = Field(
        default=None,
        description=(
            "Add a health check to core components.<br>"
            "If `true`, core components will touch this path regularly to tell the container environment it is healthy"
        ),
    )
255
+
256
+
257
# Configuration of one external enrichment source.
# validate_assignment is enabled, so the field validators below also run on attribute assignment.
class ExternalSource(BaseModel):
    name: str = Field(description="Name of the source.")
    classification: str | None = Field(
        description="Minimum classification applied to information from the source and required to know the "
        "existence of the source.",
        default=CLASSIFICATION.UNRESTRICTED,
    )
    max_classification: str | None = Field(
        description="Maximum classification of data that may be handled by the source", default=None
    )
    url: str = Field(description="URL of the upstream source's lookup service.")
    obo_target: str | None = Field(
        description="The name of a target clue should OBO to before forwarding the token", default=None
    )
    maintainer: str | None = Field(
        description="Email contact in the RFC-5322 format 'Full Name <email_address>'.", default=None
    )
    datahub_link: Url | None = Field(description="Link to datahub entry on this enrichment", default=None)
    documentation_link: Url | None = Field(description="Link to documentation on this enrichment", default=None)
    production: bool = Field(
        description="Is this source ready for production? This will disable model validation for increased speeds",
        default=False,
    )
    include_default: bool = Field(
        description="Should this source be included by default, or only when specifically requested?", default=True
    )
    built_in: bool = Field(default=True, description="Is this a source included in the clue configuration files?")
    default_timeout: float = Field(
        default=30.0, description="How long should clue wait by default for action execution?"
    )

    model_config = ConfigDict(validate_assignment=True)

    @field_validator("maintainer")
    @classmethod
    def validate_maintainer(cls, maintainer: str | None) -> str | None:  # noqa: ANN102
        """Validates the maintainer field.

        A non-empty value must parse (via ``email.utils.parseaddr``) into both a display name and an address
        containing ``@``. ``None`` and the empty string are passed through unchanged.

        Args:
            maintainer (str | None): The maintainer field to validate.

        Raises:
            AssertionError: Raised whenever the field is in an invalid format.

        Returns:
            str | None: The validated maintainer field.
        """
        if maintainer:
            parsed_addr = parseaddr(maintainer)
            # parseaddr returns (name, address); both parts must be present.
            if not (all(parsed_addr) and "@" in parsed_addr[1]):
                raise AssertionError("Maintainer string must be in RFC-5322 format.")

        return maintainer

    @field_validator("classification", "max_classification")
    @classmethod
    def validate_classification(cls, cls_str: str | None) -> str | None:  # noqa: ANN102
        """Validates the classification and max_classification fields.

        Both fields are declared ``str | None``, so ``None`` is accepted and returned as-is. Without this
        guard an explicit ``None`` would raise ``AttributeError`` on ``.upper()`` rather than a validation error.

        Args:
            cls_str (str | None): The classification value to validate.

        Raises:
            AssertionError: Raised whenever the provided classification is not valid.

        Returns:
            str | None: The upper-cased, validated classification value, or None when no value was given.
        """
        if cls_str is None:
            return None

        cls_str = cls_str.upper()

        if not CLASSIFICATION.is_valid(cls_str):
            raise AssertionError(f"{cls_str} is not a valid classification")

        return cls_str
331
+
332
+
333
# Example external source definition, showing how a VirusTotal lookup would be configured.
EXAMPLE_EXTERNAL_SOURCE_VT = dict(
    name="VirusTotal",
    url="vt-lookup.namespace.svc.cluster.local",
    classification="TLP:CLEAR",
    max_classification="TLP:CLEAR",
)
340
+
341
EXAMPLE_EXTERNAL_SOURCE_MB = {
    # This is an example on how this would work with Malware Bazaar
    "name": "Malware Bazaar",
    # Cluster-internal service hostname; the segment is "svc" (previously misspelled "scv").
    "url": "mb-lookup.namespace.svc.cluster.local",
    "classification": "TLP:CLEAR",
    "max_classification": "TLP:CLEAR",
}
348
+
349
+
350
# An on-behalf-of (OBO) target. `scope` has no default and is therefore required.
class OBOService(BaseModel):
    enabled: bool = Field(description="Is this service available?", default=False)
    scope: str = Field(description="The scope to OBO to.")
    quota: int | None = Field(
        description="Optional quota for the number of concurrent requests per user to this service",
        default=None,
    )
356
+
357
+
358
# UI-related settings.
class UI(BaseModel):
    cors_origins: list[str] = Field(description="List of valid deployments", default=[])
360
+
361
+
362
# API server settings: auditing, external sources, OBO targets, and session handling.
class API(BaseModel):
    audit: bool = Field(default=True, description="Should API calls be audited and saved to a separate log file?")
    debug: bool = Field(default=False, description="Enable debugging?")
    discover_url: str | None = Field(default=None, description="Discover URL")
    external_sources: list[ExternalSource] = Field(default=[], description="List of external sources to query")
    obo_targets: dict[str, OBOService] = Field(default={}, description="List of targets clue can OBO to")

    # A fresh random hex key is generated per instance when none is configured.
    secret_key: str = Field(default_factory=lambda: uuid4().hex, description="Flask secret key to store cookies, etc.")

    # --- session handling ---
    session_duration: int = Field(
        default=3600,
        description="Duration of the user session before the user has to login again",
    )
    validate_session_ip: bool = Field(
        default=True,
        description="Validate if the session IP matches the IP the session was created from",
    )
    validate_session_useragent: bool = Field(
        default=True,
        description="Validate if the session useragent matches the useragent the session was created with",
    )
    validate_session_xsrf_token: bool = Field(
        default=True,
        description="Validate if the XSRF token matches the randomly generated token for the session",
    )

    vault_url: str = Field(default="https://vault.vault.svc.cluster.local:8200")
382
+
383
+
384
# Base configuration directory: /etc/<APP_NAME>.
root_path = Path("/etc") / APP_NAME

# Candidate YAML config files, handed to the settings model as its `yaml_file` list.
config_locations = [
    # Fixed location under the base directory.
    root_path.joinpath("conf", "config.yml"),
    # Folder overridable through CLUE_CONF_FOLDER; falls back to the base directory itself.
    Path(os.getenv("CLUE_CONF_FOLDER", root_path)).joinpath("config.yml"),
]
390
+
391
# When AZURE_TEST_CONFIG is set (to any value), look under /__w for a directory whose name starts with a digit
# and, if it contains s/test/config/config.yml, register that file as an additional config location.
# NOTE(review): the nesting below was reconstructed from a diff rendering that dropped indentation - confirm
# against the packaged file.
if os.getenv("AZURE_TEST_CONFIG") is not None:
    import re

    # Dedicated console logger so the discovery steps show up in the build output.
    logger = logging.getLogger("clue.models.config")
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(BRL_LOG_FORMAT, BRL_DATE_FORMAT))
    logger.addHandler(console)

    logger.info("Azure build environment detected, adding additional config path")

    work_dir_parent = Path("/__w")
    work_dir: Path | None = None
    for sub_path in work_dir_parent.iterdir():
        if not sub_path.is_dir():
            continue

        logger.info("Testing sub path %s", sub_path)

        # re.match anchors at the start only, so any name beginning with a digit qualifies.
        if re.match(r"\d+", str(sub_path.name)):
            work_dir = work_dir_parent / sub_path

        if work_dir is None:
            continue

        logger.info("Subpath %s exists, checking for test path", work_dir)
        test_config_path = work_dir / "s" / "test" / "config" / "config.yml"

        if test_config_path.exists():
            config_locations.append(test_config_path)
            logger.info("Path %s added as config path", test_config_path)
            break

        # Not found here: log what is available to help debugging, then try the next directory.
        logger.error("Config path not found at path %s", test_config_path)
        logger.info("Available files:\n%s", "\n".join(sorted(map(str, (work_dir / "s").glob("**/*")))))
        work_dir = None
426
+
427
+
428
# Root settings object. Values come from the default pydantic-settings sources plus the YAML files listed in
# `config_locations`; nested environment variables use "__" as the delimiter.
class Config(BaseSettings):
    api: API = API()
    ui: UI = UI()
    auth: Auth = Auth()
    core: Core = Core()
    logging: Logging = Logging()

    model_config = SettingsConfigDict(
        yaml_file=config_locations,
        yaml_file_encoding="utf-8",
        strict=True,
        env_nested_delimiter="__",
    )

    @classmethod
    def settings_customise_sources(
        cls,  # noqa: ANN102
        *args,  # noqa: ANN002
        **kwargs,  # noqa: ANN002, ANN102
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Return the inherited settings sources followed by a YAML source for this class.

        The ``YamlConfigSettingsSource`` is appended last in the returned tuple.

        Returns:
            tuple[PydanticBaseSettingsSource, ...]: The inherited sources plus the YAML source.
        """
        inherited_sources = super().settings_customise_sources(*args, **kwargs)

        return (*inherited_sources, YamlConfigSettingsSource(cls))
450
+
451
+
452
if __name__ == "__main__":
    # Running this module as a script dumps the resolved configuration (the defaults, when nothing overrides
    # them) as YAML on stdout.
    import yaml

    resolved_config = Config().model_dump(mode="json")
    print(yaml.safe_dump(resolved_config))  # noqa: T201
@@ -0,0 +1,136 @@
1
+ # ruff: noqa: D101
2
+ import re
3
+ from typing import Dict, Generic, Literal, Optional, Self, Union
4
+
5
+ from pydantic import BaseModel, Field, JsonValue, ValidationInfo, field_validator, model_validator
6
+ from pydantic_core import Url
7
+
8
+ from clue.common.exceptions import ClueValueError
9
+ from clue.common.logging import get_logger
10
+ from clue.constants.supported_types import SUPPORTED_TYPES
11
+ from clue.models.results import DATA, FORMAT_MAPPINGS_REVERSE
12
+ from clue.models.results.validation import validate_result
13
+ from clue.models.validators import validate_classification
14
+
15
+ logger = get_logger(__file__)
16
+
17
+
18
# Describes a fetcher: its ID, classification, output format, and the types it supports.
class FetcherDefinition(BaseModel):
    id: str = Field(description="An ID for the given fetcher. Structured as <plugin_id>.<fetcher_id>.")
    classification: str = Field(
        description="Classification of the fetcher. Denotes the maximum classification of data sent to the fetcher.",
    )
    description: str = Field(description="A basic description of the fetcher's usage.")
    format: str = Field(description="The output format of the fetcher's result.")
    supported_types: set[str] = Field(description="A list of types this fetcher supports.")
    extra_data: Optional[Dict[str, JsonValue]] = Field(
        default=None, description="Extra data you want to define for a fetcher"
    )

    @field_validator("id")
    @classmethod
    def validate_id(cls, fetcher_id: str) -> str:  # noqa: ANN102
        """Validates the fetcher ID field.

        Args:
            fetcher_id (str): The ID to validate.

        Raises:
            ClueValueError: Raised whenever the ID is not in a valid format.

        Returns:
            str: The validated ID.
        """
        # NOTE(review): re.match anchors at the start of the string, so only the FIRST character is checked
        # here, although the error message describes the whole ID. Left unchanged because the field is
        # documented as "<plugin_id>.<fetcher_id>" and a whole-string check would reject the dot - confirm
        # the intended rule before tightening.
        if re.match(r"[^a-z_]", fetcher_id):
            raise ClueValueError("Invalid fetcher id - can only contain lowercase letters and underscores.")

        return fetcher_id

    @field_validator("classification")
    @classmethod
    def check_classification(cls, classification: str) -> str:  # noqa: ANN102
        """Validates the provided classification.

        Delegates to ``clue.models.validators.validate_classification``.

        Args:
            classification (str): The classification to validate.

        Raises:
            AssertionError: Raised whenever the provided classification is not valid.

        Returns:
            str: The validated classification.
        """
        return validate_classification(classification)

    @field_validator("format")
    @classmethod
    def check_format(cls, format: str) -> str:  # noqa: ANN102
        """Validates the provided format.

        Args:
            format (str): The format to validate.

        Raises:
            ClueValueError: Raised whenever the format is not a key of ``FORMAT_MAPPINGS_REVERSE``.

        Returns:
            str: The validated format.
        """
        if format not in FORMAT_MAPPINGS_REVERSE:
            raise ClueValueError("Invalid format. To use custom results, register your result using register_result.")

        return format

    @field_validator("supported_types")
    @classmethod
    def validate_supported_types(cls, supported_types: set[str]) -> set[str]:  # noqa: ANN102
        """Validate that every declared type is a key of ``SUPPORTED_TYPES``.

        Args:
            supported_types (set[str]): The types the fetcher claims to support.

        Raises:
            AssertionError: Raised when one or more declared types are not supported.

        Returns:
            set[str]: The validated set of types.
        """
        invalid_types = supported_types - set(SUPPORTED_TYPES.keys())

        if invalid_types:
            # Sorted so the message is deterministic; set iteration order is not.
            raise AssertionError(f"{', '.join(sorted(invalid_types))} are not supported types.")

        return supported_types
94
+
95
+
96
# Result returned by a fetcher: either a success carrying data, or a failure carrying an error message.
class FetcherResult(BaseModel, Generic[DATA]):
    outcome: Union[Literal["success"], Literal["failure"]] = Field(description="Did the fetcher succeed or fail?")
    data: DATA | None = Field(default=None, description="The output of the fetcher.")
    error: str | None = Field(default=None, description="If the fetcher failed, contains the relevant error message.")
    format: str = Field(
        description="What is the format of the output? Used to indicate what component to use when rendering "
        "the output.",
    )
    link: Optional[Url] = Field(default=None, description="Link to more information on the fetcher")

    @model_validator(mode="after")
    def validate_model(self: Self, info: ValidationInfo) -> Self:  # noqa: C901
        """Check that outcome, data, error and format are consistent, then validate the data.

        A failure must carry no data, use the ``"error"`` format and provide an error message. Any other
        outcome must carry no error, and a success must carry data; its data is then replaced by the value
        returned from ``validate_result`` for the declared format.

        Args:
            info (ValidationInfo): Validation context, forwarded to ``validate_result``.

        Raises:
            ClueValueError: Raised whenever the fields contradict each other.

        Returns:
            Self: The validated model.
        """
        if self.outcome == "failure":
            if self.data is not None:
                raise ClueValueError("Failed fetcher results cannot return data.")

            if self.format != "error" or not self.error:
                raise ClueValueError("Returning an error fetcher result must specify an error.")

            # A well-formed failure has no data to validate.
            return self

        if self.outcome == "success" and self.data is None:
            raise ClueValueError("Successful fetcher results must return data.")

        if self.error:
            raise ClueValueError("Errors can only be specified if the outcome is failure.")

        self.data = validate_result(self.format, self.data, info)

        return self

    @staticmethod
    def error_result(err: str) -> "FetcherResult":
        """Build a failed fetcher result carrying ``err`` as its error message."""
        return FetcherResult(outcome="failure", format="error", error=err)