clue-api 1.0.0.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clue/.gitignore +21 -0
- clue/__init__.py +0 -0
- clue/api/__init__.py +211 -0
- clue/api/base.py +99 -0
- clue/api/v1/__init__.py +82 -0
- clue/api/v1/actions.py +92 -0
- clue/api/v1/auth.py +243 -0
- clue/api/v1/configs.py +83 -0
- clue/api/v1/fetchers.py +94 -0
- clue/api/v1/lookup.py +221 -0
- clue/api/v1/registration.py +109 -0
- clue/api/v1/static.py +94 -0
- clue/app.py +166 -0
- clue/cache/__init__.py +129 -0
- clue/common/__init__.py +0 -0
- clue/common/classification.py +1006 -0
- clue/common/classification.yml +130 -0
- clue/common/dict_utils.py +130 -0
- clue/common/exceptions.py +199 -0
- clue/common/forge.py +152 -0
- clue/common/json_utils.py +10 -0
- clue/common/list_utils.py +11 -0
- clue/common/logging/__init__.py +291 -0
- clue/common/logging/audit.py +157 -0
- clue/common/logging/format.py +42 -0
- clue/common/regex.py +31 -0
- clue/common/str_utils.py +213 -0
- clue/common/swagger.py +139 -0
- clue/common/uid.py +47 -0
- clue/config.py +60 -0
- clue/constants/__init__.py +0 -0
- clue/constants/supported_types.py +38 -0
- clue/cronjobs/__init__.py +30 -0
- clue/cronjobs/plugins.py +32 -0
- clue/error.py +129 -0
- clue/gunicorn_config.py +29 -0
- clue/healthz.py +74 -0
- clue/helper/discover.py +53 -0
- clue/helper/headers.py +30 -0
- clue/helper/oauth.py +128 -0
- clue/models/__init__.py +0 -0
- clue/models/actions.py +243 -0
- clue/models/config.py +456 -0
- clue/models/fetchers.py +136 -0
- clue/models/graph.py +162 -0
- clue/models/model_list.py +52 -0
- clue/models/network.py +430 -0
- clue/models/results/__init__.py +34 -0
- clue/models/results/base.py +10 -0
- clue/models/results/graph.py +26 -0
- clue/models/results/image.py +22 -0
- clue/models/results/status.py +55 -0
- clue/models/results/validation.py +57 -0
- clue/models/selector.py +67 -0
- clue/models/utils.py +52 -0
- clue/models/validators.py +19 -0
- clue/patched.py +8 -0
- clue/plugin/__init__.py +1008 -0
- clue/plugin/helpers/__init__.py +0 -0
- clue/plugin/helpers/central_server.py +27 -0
- clue/plugin/helpers/email_render.py +228 -0
- clue/plugin/helpers/token.py +34 -0
- clue/plugin/helpers/trino.py +103 -0
- clue/plugin/interactive.py +270 -0
- clue/plugin/models.py +19 -0
- clue/plugin/utils.py +78 -0
- clue/remote/__init__.py +0 -0
- clue/remote/datatypes/__init__.py +130 -0
- clue/remote/datatypes/cache.py +62 -0
- clue/remote/datatypes/events.py +118 -0
- clue/remote/datatypes/hash.py +193 -0
- clue/remote/datatypes/queues/__init__.py +0 -0
- clue/remote/datatypes/queues/comms.py +62 -0
- clue/remote/datatypes/set.py +96 -0
- clue/remote/datatypes/user_quota_tracker.py +54 -0
- clue/security/__init__.py +211 -0
- clue/security/obo.py +95 -0
- clue/security/utils.py +34 -0
- clue/services/action_service.py +186 -0
- clue/services/auth_service.py +348 -0
- clue/services/config_service.py +38 -0
- clue/services/fetcher_service.py +203 -0
- clue/services/jwt_service.py +233 -0
- clue/services/lookup_service.py +786 -0
- clue/services/type_service.py +165 -0
- clue/services/user_service.py +152 -0
- clue_api-1.0.0.dev7.dist-info/METADATA +111 -0
- clue_api-1.0.0.dev7.dist-info/RECORD +91 -0
- clue_api-1.0.0.dev7.dist-info/WHEEL +4 -0
- clue_api-1.0.0.dev7.dist-info/entry_points.txt +8 -0
- clue_api-1.0.0.dev7.dist-info/licenses/LICENSE +11 -0
clue/models/config.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
1
|
+
# ruff: noqa: D101
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from email.utils import parseaddr
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from uuid import uuid4
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
10
|
+
from pydantic_core import Url
|
|
11
|
+
from pydantic_settings import (
|
|
12
|
+
BaseSettings,
|
|
13
|
+
PydanticBaseSettingsSource,
|
|
14
|
+
SettingsConfigDict,
|
|
15
|
+
YamlConfigSettingsSource,
|
|
16
|
+
)
|
|
17
|
+
from typing_extensions import Self
|
|
18
|
+
|
|
19
|
+
from clue.common import forge
|
|
20
|
+
from clue.common.exceptions import ClueValueError
|
|
21
|
+
from clue.common.logging.format import BRL_DATE_FORMAT, BRL_LOG_FORMAT
|
|
22
|
+
from clue.common.str_utils import default_string_value
|
|
23
|
+
|
|
24
|
+
# Property types that may be assigned automatically to users.
# NOTE(review): the consumers of this list live outside this module — confirm the
# exact meaning of each entry (e.g. "remove_role") against them.
AUTO_PROPERTY_TYPE = [
    "access",
    "classification",
    "type",
    "role",
    "remove_role",
    "group",
]

# Candidate field names that may carry a user's email address.
# NOTE(review): presumably looked up in OAuth/JWT user data by the auth services — confirm.
DEFAULT_EMAIL_FIELDS = [
    "email",
    "emails",
    "extension_selectedEmailAddress",
    "otherMails",
    "preferred_username",
    "upn",
]

# Candidate field names that may carry a user's username.
DEFAULT_USER_FIELDS = [
    "uname",
    "preferred_username",
    "upn",
]

# Candidate field names that may carry a user's display name.
DEFAULT_USER_NAME_FIELDS = [
    "name",
    "displayName",
]

# Application name resolved by default_string_value from the APP_NAME environment
# variable (with "clue" as the default); every "-dev" substring is removed so the
# /etc/<APP_NAME> configuration root is the same for dev builds.
APP_NAME = default_string_value(env_name="APP_NAME", default="clue").replace("-dev", "")  # type: ignore[union-attr]

# Classification engine used for classification defaults and validation in this module.
CLASSIFICATION = forge.get_classification()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Password complexity policy. All character-class rules are off by default; only
# the 12-character minimum length applies out of the box.
class PasswordRequirement(BaseModel):
    lower: bool = Field(default=False, description="Password must contain lowercase letters")
    number: bool = Field(default=False, description="Password must contain numbers")
    special: bool = Field(default=False, description="Password must contain special characters")
    upper: bool = Field(default=False, description="Password must contain uppercase letters")
    min_length: int = Field(default=12, description="Minimum password length")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Settings for a single OAuth provider.
#
# NOTE: fields typed ``str | None`` without a default (client_id, access_token_url,
# authorize_url, api_base_url) are still *required* by pydantic: the key must be
# present in the configuration, although its value may be null.
class OAuthProvider(BaseModel):
    auto_create: bool = Field(default=True, description="Auto-create users if they are missing")
    auto_sync: bool = Field(default=False, description="Should we automatically sync with OAuth provider?")
    uid_randomize: bool = Field(
        default=False,
        description="Should we generate a random username for the authenticated user?",
    )
    uid_randomize_digits: int = Field(
        default=0,
        description="How many digits should we add at the end of the username?",
    )
    uid_randomize_delimiter: str = Field(
        default="-",
        description="What is the delimiter used by the random name generator?",
    )
    uid_regex: str | None = Field(
        description="Regex used to parse an email address and capture parts to create a user ID out of it", default=None
    )
    uid_format: str | None = Field(
        description="Format of the user ID based on the captured parts from the regex", default=None
    )
    client_id: str | None = Field(description="ID of your application to authenticate to the OAuth provider")
    client_secret: str | None = Field(
        description="Password to your application to authenticate to the OAuth provider", default=None
    )
    required_groups: list[str] = Field(
        default=[], description="The groups the JWT must contain in order to allow access"
    )
    role_map: dict[str, str] = Field(default={}, description="A mapping of OAuth groups to clue roles")
    classification_map: dict[str, str] = Field(
        default={}, description="A mapping of OAuth groups to classification levels"
    )
    access_token_url: str | None = Field(description="URL to get access token")
    authorize_url: str | None = Field(description="URL used to authorize access to a resource")
    api_base_url: str | None = Field(description="Base URL for downloading the user's and groups info")
    # Defaults to None: as its description states, it only needs to be set when the
    # audience differs from the client id (it used to be required with no default).
    audience: str | None = Field(
        default=None,
        description="The audience to validate against. Only must be set if audience is different than the client id.",
    )
    scope: str = Field(description="The scope to validate against")
    iss: str | None = Field(description="Optional issuer field for JWT validation", default=None)
    jwks_uri: str = Field(description="URL used to verify if a returned JWKS token is valid")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Global OAuth settings plus the per-provider configuration, keyed by provider name.
class OAuth(BaseModel):
    enabled: bool = Field(description="Enable use of OAuth?", default=False)
    gravatar_enabled: bool = Field(description="Enable gravatar?", default=False)
    providers: dict[str, OAuthProvider] = Field(default={}, description="OAuth provider configuration")
    other_audiences: list[str] | None = Field(
        default=None, description="What other audiences in JWT tokens should Clue accept?"
    )

    @model_validator(mode="before")
    @classmethod
    def prepare_model(
        cls,  # noqa: ANN102
        oauth_data: dict[str, dict[str, dict | OAuthProvider]],
    ) -> dict[str, dict[str, dict | OAuthProvider]]:
        """Validates the oauth data, and adds the client secret if it's not already there.

        For every entry under ``providers`` (either a raw dict or an already-built
        ``OAuthProvider``), ``client_secret`` is passed through ``default_string_value``
        together with the ``<PROVIDER_NAME>_CLIENT_SECRET`` environment variable name.
        Provider entries are updated in place.

        Args:
            oauth_data (dict[str, dict[str, dict | OAuthProvider]]): The raw data about to be
                validated into an ``OAuth`` model.

        Returns:
            dict[str, dict[str, dict | OAuthProvider]]: The same object, with client secrets
            populated where possible. Input that is not a dict (e.g. an empty YAML section
            parsed as None) is returned untouched so pydantic reports a validation error
            rather than this hook raising a TypeError.
        """
        if not isinstance(oauth_data, dict):
            return oauth_data

        providers = oauth_data.get("providers")
        if not isinstance(providers, dict):
            return oauth_data

        for name, provider in providers.items():
            if not isinstance(name, str):
                # Non-string keys are rejected by field validation; don't crash on .upper().
                continue

            env_name = f"{name.upper()}_CLIENT_SECRET"
            if isinstance(provider, OAuthProvider):
                provider.client_secret = default_string_value(provider.client_secret, env_name=env_name)
            elif isinstance(provider, dict):
                provider["client_secret"] = default_string_value(
                    provider.get("client_secret", None),
                    env_name=env_name,
                )

        return oauth_data
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# Credentials of one service account and the OAuth provider it authenticates against.
class ServiceAccountCreds(BaseModel):
    username: str = Field(description="Username of the service account")
    password: str = Field(description="Password of the service account")
    provider: str = Field(description="What OAuth provider does this service account connect to?")

    @model_validator(mode="before")
    @classmethod
    def prepare_model(cls, data: dict[str, str]) -> dict[str, str]:  # noqa: ANN102
        """Adds the service account password to the data if missing.

        When ``password`` is absent, it is read from the ``SA_<PROVIDER>_PASSWORD``
        environment variable (provider name upper-cased). An unset or empty variable
        leaves the data unchanged.

        Args:
            data (dict[str, str]): The raw data about to be validated.

        Returns:
            dict[str, str]: The data, including the password when one could be found.
            Non-dict input, or a missing / non-string ``provider``, is returned untouched
            so pydantic reports it as a validation error instead of an AttributeError.
        """
        if not isinstance(data, dict) or "password" in data:
            return data

        provider = data.get("provider")
        if not isinstance(provider, str):
            return data

        if env_pass := os.getenv(f"SA_{provider.upper()}_PASSWORD"):
            data["password"] = env_pass

        return data
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# Service-account configuration: a global switch plus at most one credential set
# per OAuth provider.
class ServiceAccount(BaseModel):
    enabled: bool = Field(default=False, description="Enable use of a service account?")
    accounts: list[ServiceAccountCreds] = Field(
        default=[], description="A list of service accounts on a per-provider basis"
    )

    @model_validator(mode="after")
    def validate_model(self: Self) -> Self:
        """Reject configurations that list the same provider more than once.

        Raises:
            ClueValueError: If two or more accounts name the same provider.

        Returns:
            Self: The unchanged model.
        """
        seen_providers: set[str] = set()
        for account in self.accounts:
            seen_providers.add(account.provider)

        # Fewer distinct providers than accounts means at least one duplicate.
        if len(self.accounts) > len(seen_providers):
            raise ClueValueError("You may only have one service account per provider.")

        return self
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Authentication settings: API keys, OAuth and service accounts.
class Auth(BaseModel):
    allow_apikeys: bool = Field(default=False, description="Allow API keys?")
    apikeys: dict[str, str] = Field(default={}, description="API Keys available in the system")
    propagate_clue_key: bool = Field(
        default=True, description="Should clue include the root clue token in requests when OBO is used?"
    )
    oauth: OAuth = OAuth()
    service_account: ServiceAccount = ServiceAccount()

    @model_validator(mode="after")
    def validate_model(self: Self) -> Self:
        """Check that enabled service accounts are backed by a usable OAuth configuration.

        Nothing is checked while service accounts are disabled.

        Raises:
            ClueValueError: If service accounts are enabled while OAuth is disabled, or if an
                account names a provider missing from ``oauth.providers``.

        Returns:
            Self: The unchanged model.
        """
        if self.service_account.enabled:
            if not self.oauth.enabled:
                raise ClueValueError(
                    "In order to use service accounts to connect to plugins, you must have oauth enabled."
                )

            # The first account pointing at an unknown provider, if any.
            orphan = next(
                (
                    account
                    for account in self.service_account.accounts
                    if account.provider not in self.oauth.providers
                ),
                None,
            )
            if orphan is not None:
                raise ClueValueError(
                    f"{orphan.username} is used to connect to non-existent provider {orphan.provider}."
                )

        return self
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# Address of a Redis instance; defaults to 127.0.0.1:6379.
class RedisServer(BaseModel):
    host: str = Field(default="127.0.0.1", description="Hostname of Redis instance")
    port: int = Field(default=6379, description="Port of Redis instance")
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# Optional APM server connection; both settings are unset (None) by default.
class APMServer(BaseModel):
    server_url: str | None = Field(default=None, description="URL to API server")
    token: str | None = Field(default=None, description="Authentication token for server")
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Metrics settings: APM connection, export cadence and a dedicated Redis block.
class Metrics(BaseModel):
    apm_server: APMServer = APMServer()
    export_interval: int = Field(default=5, description="How often should we be exporting metrics?")
    redis: RedisServer = RedisServer()
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# Core settings: the metrics section plus a Redis block separate from ``metrics.redis``.
class Core(BaseModel):
    metrics: Metrics = Metrics()
    redis: RedisServer = RedisServer()
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# Accepted logging levels. The ``str`` mixin makes each member compare equal to its
# plain-string value (e.g. ``LogLevel.INFO == "INFO"``). ``DISABLED`` is an extra
# level not present in the standard ``logging`` module names.
class LogLevel(str, Enum):
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"
    DISABLED = "DISABLED"
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
# Logging verbosity and destinations (console, files, syslog), plus an optional
# heartbeat file used as a health signal.
class Logging(BaseModel):
    log_level: LogLevel = Field(  # type: ignore
        description="What level of logging should we have?", default=LogLevel.DEBUG
    )
    log_to_console: bool = Field(description="Should we log to console?", default=True)
    log_to_file: bool = Field(description="Should we log to files on the server?", default=False)
    log_directory: str = Field(
        description="If `log_to_file: true`, what is the directory to store logs?", default="/var/log/clue/"
    )
    log_to_syslog: bool = Field(description="Should logs be sent to a syslog server?", default=False)
    syslog_host: str = Field(
        description="If `log_to_syslog: true`, provide hostname/IP of the syslog server?", default="localhost"
    )
    syslog_port: int = Field(description="If `log_to_syslog: true`, provide port of the syslog server?", default=514)
    export_interval: int = Field(description="How often, in seconds, should counters log their values?", default=5)
    log_as_json: bool = Field(description="Log in JSON format?", default=False)
    # This is a path, not a flag: the description used to say "If `true`", which
    # does not match the ``str | None`` type.
    heartbeat_file: str | None = Field(
        description=(
            "Add a health check to core components.<br>"
            "If set, core components will touch this path regularly to tell the container environment it is healthy"
        ),
        default=None,
    )
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# An upstream lookup source Clue can query, with its classification bounds and
# metadata. Field validators also run on attribute assignment (validate_assignment).
class ExternalSource(BaseModel):
    name: str = Field(description="Name of the source.")
    classification: str | None = Field(
        description="Minimum classification applied to information from the source and required to know the "
        "existence of the source.",
        default=CLASSIFICATION.UNRESTRICTED,
    )
    max_classification: str | None = Field(
        description="Maximum classification of data that may be handled by the source", default=None
    )
    url: str = Field(description="URL of the upstream source's lookup service.")
    obo_target: str | None = Field(
        description="The name of a target clue should OBO to before forwarding the token", default=None
    )
    maintainer: str | None = Field(
        description="Email contact in the RFC-5322 format 'Full Name <email_address>'.", default=None
    )
    datahub_link: Url | None = Field(description="Link to datahub entry on this enrichment", default=None)
    documentation_link: Url | None = Field(description="Link to documentation on this enrichment", default=None)
    production: bool = Field(
        description="Is this source ready for production? This will disable model validation for increased speeds",
        default=False,
    )
    include_default: bool = Field(
        description="Should this source be included by default, or only when specifically requested?", default=True
    )
    built_in: bool = Field(default=True, description="Is this a source included in the clue configuration files?")
    default_timeout: float = Field(
        default=30.0, description="How long should clue wait by default for action execution?"
    )

    model_config = ConfigDict(validate_assignment=True)

    @field_validator("maintainer")
    @classmethod
    def validate_maintainer(cls, maintainer: str | None) -> str | None:  # noqa: ANN102
        """Validates the maintainer field.

        Args:
            maintainer (str | None): The maintainer field to validate. Falsy values (None or "")
                are passed through unchanged.

        Raises:
            AssertionError: Raised when the parsed address lacks a display name, lacks an
                address, or the address has no "@".

        Returns:
            str | None: The validated maintainer field, unchanged.
        """
        if maintainer:
            parsed_addr = parseaddr(maintainer)
            # parseaddr returns (display_name, address); both must be non-empty.
            if not (all(parsed_addr) and "@" in parsed_addr[1]):
                raise AssertionError("Maintainer string must be in RFC-5322 format.")

        return maintainer

    @field_validator("classification", "max_classification")
    @classmethod
    def validate_classification(cls, cls_str: str | None) -> str | None:  # noqa: ANN102
        """Validates the classification and max_classification fields.

        Args:
            cls_str (str | None): The classification value to validate. None is accepted,
                since both fields are optional.

        Raises:
            AssertionError: Raised whenever the provided classification is not valid.

        Returns:
            str | None: The upper-cased classification, or None.
        """
        # An explicit null (or assigning None under validate_assignment) used to crash
        # with AttributeError on ``None.upper()``.
        if cls_str is None:
            return None

        cls_str = cls_str.upper()

        if not CLASSIFICATION.is_valid(cls_str):
            raise AssertionError(f"{cls_str} is not a valid classification")

        return cls_str
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# Example ExternalSource payloads showing the minimal fields, with in-cluster
# "<service>.<namespace>.svc.cluster.local" style hostnames.
EXAMPLE_EXTERNAL_SOURCE_VT = {
    # This is an example on how this would work with VirusTotal
    "name": "VirusTotal",
    "url": "vt-lookup.namespace.svc.cluster.local",
    "classification": "TLP:CLEAR",
    "max_classification": "TLP:CLEAR",
}

EXAMPLE_EXTERNAL_SOURCE_MB = {
    # This is an example on how this would work with Malware Bazaar
    "name": "Malware Bazaar",
    # Fixed typo: "scv" -> "svc", matching the VirusTotal example above.
    "url": "mb-lookup.namespace.svc.cluster.local",
    "classification": "TLP:CLEAR",
    "max_classification": "TLP:CLEAR",
}
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
# A target service Clue can perform an on-behalf-of (OBO) token exchange for,
# identified by the scope to request. Disabled by default.
class OBOService(BaseModel):
    enabled: bool = Field(default=False, description="Is this service available?")
    scope: str = Field(description="The scope to OBO to.")
    quota: int | None = Field(
        default=None,
        description="Optional quota for the number of concurrent requests per user to this service",
    )
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
# Front-end settings; the CORS origin list is empty by default.
class UI(BaseModel):
    cors_origins: list[str] = Field(description="List of valid deployments", default=[])
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
# API server settings: auditing/debugging, external sources, OBO targets and
# session hardening options.
class API(BaseModel):
    audit: bool = Field(default=True, description="Should API calls be audited and saved to a separate log file?")
    debug: bool = Field(default=False, description="Enable debugging?")
    discover_url: str | None = Field(default=None, description="Discover URL")
    external_sources: list[ExternalSource] = Field(default=[], description="List of external sources to query")
    obo_targets: dict[str, OBOService] = Field(default={}, description="List of targets clue can OBO to")
    # A fresh random hex string is generated per model instance unless configured.
    # NOTE(review): a generated key differs between processes — confirm whether
    # multi-worker deployments must set it explicitly.
    secret_key: str = Field(default_factory=lambda: uuid4().hex, description="Flask secret key to store cookies, etc.")
    session_duration: int = Field(
        default=3600, description="Duration of the user session before the user has to login again"
    )
    validate_session_ip: bool = Field(
        default=True, description="Validate if the session IP matches the IP the session was created from"
    )
    validate_session_useragent: bool = Field(
        default=True,
        description="Validate if the session useragent matches the useragent the session was created with",
    )
    validate_session_xsrf_token: bool = Field(
        default=True, description="Validate if the XSRF token matches the randomly generated token for the session"
    )
    vault_url: str = Field(default="https://vault.vault.svc.cluster.local:8200")
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
# Directory holding this application's configuration: /etc/<APP_NAME>.
root_path = Path("/etc") / APP_NAME

# YAML files handed to Config's YAML settings source (via ``yaml_file``).
config_locations = [
    root_path / "conf" / "config.yml",
    Path(os.environ.get("CLUE_CONF_FOLDER", root_path)) / "config.yml",
]

# When AZURE_TEST_CONFIG is set, also look for a test configuration under
# /__w/<digits>/s/test/config/config.yml.
# NOTE(review): /__w is presumably the Azure Pipelines container work directory
# (with "s" holding the checked-out sources) — confirm against the pipeline setup.
if os.getenv("AZURE_TEST_CONFIG", None) is not None:
    import re

    logger = logging.getLogger("clue.models.config")
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(BRL_LOG_FORMAT, BRL_DATE_FORMAT))
    logger.addHandler(console)

    logger.info("Azure build environment detected, adding additional config path")

    work_dir_parent = Path("/__w")
    work_dir: Path | None = None

    if not work_dir_parent.is_dir():
        # iterdir() would raise FileNotFoundError here and make importing this module fail.
        logger.error("Work directory %s not found, skipping test config lookup", work_dir_parent)
    else:
        for sub_path in work_dir_parent.iterdir():
            if not sub_path.is_dir():
                continue

            logger.info("Testing sub path %s", sub_path)

            if re.match(r"\d+", str(sub_path.name)):
                # iterdir() already yields paths rooted at work_dir_parent.
                work_dir = sub_path

            if work_dir is not None:
                logger.info("Subpath %s exists, checking for test path", work_dir)
                test_config_path = work_dir / "s" / "test" / "config" / "config.yml"

                if test_config_path.exists():
                    config_locations.append(test_config_path)
                    logger.info("Path %s added as config path", test_config_path)
                    break

                logger.error("Config path not found at path %s", test_config_path)
                logger.info(
                    "Available files:\n%s",
                    "\n".join(sorted(str(path) for path in (work_dir / "s").glob("**/*"))),
                )
                # Reset so the next non-numeric entry is not treated as a candidate.
                work_dir = None
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
# Root application settings. Nested values can come from environment variables
# using "__" as the nesting delimiter. The YAML files listed in config_locations
# are appended after pydantic-settings' default sources, so they have the lowest
# priority.
class Config(BaseSettings):
    api: API = API()
    ui: UI = UI()
    auth: Auth = Auth()
    core: Core = Core()
    logging: Logging = Logging()

    model_config = SettingsConfigDict(
        yaml_file=config_locations,
        yaml_file_encoding="utf-8",
        strict=True,
        env_nested_delimiter="__",
    )

    @classmethod
    def settings_customise_sources(
        cls,  # noqa: ANN102
        *args,  # noqa: ANN002
        **kwargs,  # noqa: ANN003
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Return the default settings sources followed by a YAML source (lowest priority)."""
        default_sources = super().settings_customise_sources(*args, **kwargs)
        yaml_source = YamlConfigSettingsSource(cls)
        return tuple(default_sources) + (yaml_source,)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
if __name__ == "__main__":
    # Print the effective configuration (defaults merged with any environment and
    # YAML sources) as YAML.
    import yaml

    rendered = yaml.safe_dump(Config().model_dump(mode="json"))
    print(rendered)  # noqa: T201
|
clue/models/fetchers.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# ruff: noqa: D101
|
|
2
|
+
import re
|
|
3
|
+
from typing import Dict, Generic, Literal, Optional, Self, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field, JsonValue, ValidationInfo, field_validator, model_validator
|
|
6
|
+
from pydantic_core import Url
|
|
7
|
+
|
|
8
|
+
from clue.common.exceptions import ClueValueError
|
|
9
|
+
from clue.common.logging import get_logger
|
|
10
|
+
from clue.constants.supported_types import SUPPORTED_TYPES
|
|
11
|
+
from clue.models.results import DATA, FORMAT_MAPPINGS_REVERSE
|
|
12
|
+
from clue.models.results.validation import validate_result
|
|
13
|
+
from clue.models.validators import validate_classification
|
|
14
|
+
|
|
15
|
+
# Module logger, created via the project's get_logger helper from this file's path.
logger = get_logger(__file__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Declaration of a fetcher: its ID, classification ceiling, output format and the
# entity types it accepts.
class FetcherDefinition(BaseModel):
    id: str = Field(description="An ID for the given fetcher. Structured as <plugin_id>.<fetcher_id>.")
    classification: str = Field(
        description="Classification of the fetcher. Denotes the maximum classification of data sent to the fetcher.",
    )
    description: str = Field(description="A basic description of the fetcher's usage.")
    format: str = Field(description="The output format of the fetcher's result.")
    supported_types: set[str] = Field(description="A list of types this fetcher supports.")
    extra_data: Optional[Dict[str, JsonValue]] = Field(
        default=None, description="Extra data you want to define for a fetcher"
    )

    @field_validator("id")
    @classmethod
    def validate_id(cls, fetcher_id: str) -> str:  # noqa: ANN102
        """Validates the fetcher ID field.

        The whole ID must consist of lowercase letters and underscores. Dots are allowed
        only as separators between non-empty segments, to permit the documented
        ``<plugin_id>.<fetcher_id>`` form.

        Args:
            fetcher_id (str): The ID to validate.

        Raises:
            ClueValueError: Raised whenever the ID is empty or not in a valid format.

        Returns:
            str: The validated ID.
        """
        # The previous ``re.match(r"[^a-z_]", ...)`` only inspected the first character,
        # so IDs such as "abc-DEF" (or "") were accepted.
        if not re.fullmatch(r"[a-z_]+(?:\.[a-z_]+)*", fetcher_id):
            raise ClueValueError("Invalid fetcher id - can only contain lowercase letters and underscores.")

        return fetcher_id

    @field_validator("classification")
    @classmethod
    def check_classification(cls, classification: str) -> str:  # noqa: ANN102
        """Validates the provided classification.

        Args:
            classification (str): The classification to validate.

        Raises:
            AssertionError: Raised whenever the provided classification is not valid.

        Returns:
            str: The value returned by the shared ``validate_classification`` helper.
        """
        return validate_classification(classification)

    @field_validator("format")
    @classmethod
    def check_format(cls, format: str) -> str:  # noqa: ANN102
        """Validates the provided format.

        Args:
            format (str): The format to validate; must be a registered result format.

        Raises:
            ClueValueError: Raised whenever the provided format is not valid.

        Returns:
            str: The validated format.
        """
        if format not in FORMAT_MAPPINGS_REVERSE:
            raise ClueValueError("Invalid format. To use custom results, register your result using register_result.")

        return format

    @field_validator("supported_types")
    @classmethod
    def validate_supported_types(cls, supported_types: set[str]) -> set[str]:  # noqa: ANN102
        """Ensure every supported type is a key of ``SUPPORTED_TYPES``.

        Raises:
            AssertionError: Listing (in sorted order) every unknown type.

        Returns:
            set[str]: The validated set of types.
        """
        invalid_types = supported_types - set(SUPPORTED_TYPES.keys())

        if invalid_types:
            # Sorted so the error message is deterministic (set order is not).
            raise AssertionError(f"{', '.join(sorted(invalid_types))} are not supported types.")

        return supported_types
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Result returned by a fetcher: either a success carrying data in a given format,
# or a failure carrying an error message with the "error" format.
class FetcherResult(BaseModel, Generic[DATA]):
    outcome: Union[Literal["success"], Literal["failure"]] = Field(description="Did the fetcher succeed or fail?")
    data: DATA | None = Field(description="The output of the fetcher.", default=None)
    error: str | None = Field(description="If the fetcher failed, contains the relevant error message.", default=None)
    format: str = Field(
        description="What is the format of the output? Used to indicate what component to use when rendering "
        "the output.",
    )
    link: Optional[Url] = Field(description="Link to more information on the fetcher", default=None)

    @model_validator(mode="after")
    def validate_model(self: Self, info: ValidationInfo) -> Self:
        """Enforce consistency between outcome, data, error and format.

        Failures must have no data, the "error" format and a non-empty error; they are
        returned without further payload validation. Successes must carry data and no
        error; their data is then replaced by the value ``validate_result`` returns for
        the declared format.

        Raises:
            ClueValueError: Raised whenever one of the rules above is violated.

        Returns:
            Self: The validated model.
        """
        if self.outcome == "failure":
            if self.data is not None:
                raise ClueValueError("Failed fetcher results cannot return data.")
            if self.format != "error" or not self.error:
                raise ClueValueError("Returning an error fetcher result must specify an error.")
            return self

        if self.outcome == "success" and self.data is None:
            raise ClueValueError("Successful fetcher results must return data.")

        if self.error:
            raise ClueValueError("Errors can only be specified if the outcome is failure.")

        self.data = validate_result(self.format, self.data, info)
        return self

    @staticmethod
    def error_result(err: str) -> "FetcherResult":
        """Build a failed result carrying ``err`` with the "error" format."""
        return FetcherResult(outcome="failure", format="error", error=err)
|