VaultAPI 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vaultapi/routes.py ADDED
@@ -0,0 +1,413 @@
1
+ import logging
2
+ import sqlite3
3
+ from http import HTTPStatus
4
+ from typing import Dict, List
5
+
6
+ from fastapi import Depends, Request
7
+ from fastapi.responses import RedirectResponse
8
+ from fastapi.routing import APIRoute
9
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
10
+
11
+ from . import auth, database, exceptions, models, payload, rate_limit, transit
12
+
13
# Module logger. It uses uvicorn's logger name, presumably so that route messages share the
# handlers uvicorn configures for it (confirm against the configured log_config).
LOGGER = logging.getLogger("uvicorn.default")
# Bearer-token extractor. Every protected endpoint below receives it via ``Depends(security)``.
security = HTTPBearer()
15
+
16
+
17
async def retrieve_secret(key: str, table_name: str) -> str | None:
    """Look up a single stored secret.

    Args:
        key: Name of the secret to look up.
        table_name: Table that holds the secret.

    Raises:
        APIResponse:
            Status 400 (BAD_REQUEST) carrying the SQLite error message when the lookup
            raises ``sqlite3.OperationalError`` (presumably e.g. a missing table).

    Returns:
        str | None:
            Whatever ``database.get_secret`` returns for ``key``. Callers treat a falsy
            result as "not found"; a truthy one is the stored (Fernet-encrypted) value.
    """
    try:
        secret = database.get_secret(key=key, table_name=table_name)
    except sqlite3.OperationalError as db_error:
        LOGGER.error(db_error)
        raise exceptions.APIResponse(
            status_code=HTTPStatus.BAD_REQUEST.real, detail=db_error.args[0]
        )
    return secret
35
+
36
+
37
async def retrieve_secrets(
    table_name: str, keys: List[str] | None = None
) -> Dict[str, str]:
    """Retrieve multiple secrets from a table or retrieve the table as a whole.

    Args:
        table_name: Name of the table where the secrets are stored.
        keys: Keys whose values should be retrieved. When ``None`` or empty, every
            key-value pair in the table is returned instead.

    Raises:
        APIResponse:
            Status 400 (BAD_REQUEST) with the SQLite error message when the database
            lookup raises ``sqlite3.OperationalError``.

    Returns:
        Dict[str, str]:
            Mapping of secret key to its stored (still Fernet-encrypted) value. When
            ``keys`` is given, keys with no (or a falsy) stored value are omitted.
    """
    if keys:
        values = {}
        for key in keys:
            # ``retrieve_secret`` already converts database errors into a 400 response.
            if value := await retrieve_secret(key, table_name):
                values[key] = value
        return values
    try:
        return dict(database.get_table(table_name))
    except sqlite3.OperationalError as error:
        LOGGER.error(error)
        raise exceptions.APIResponse(
            status_code=HTTPStatus.BAD_REQUEST.real, detail=error.args[0]
        )
62
+
63
+
64
async def get_secret(
    request: Request,
    key: str,
    table_name: str = "default",
    apikey: HTTPAuthorizationCredentials = Depends(security),
):
    """**API function to fetch one secret by name.**

    **Args:**

        request: Reference to the FastAPI request object.
        key: Name of the secret to fetch.
        table_name: Name of the table holding the secret.
        apikey: API Key to authenticate the request.

    **Raises:**

        APIResponse:
        Carries the response: 200 with the transit-encrypted ``{key: value}`` payload
        when the secret exists, otherwise 404.
    """
    await auth.validate(request, apikey)
    stored = await retrieve_secret(key, table_name)
    if not stored:
        LOGGER.info("Secret value for '%s' NOT found in the datastore", key)
        raise exceptions.APIResponse(
            status_code=HTTPStatus.NOT_FOUND.real, detail=HTTPStatus.NOT_FOUND.phrase
        )
    LOGGER.info("Secret value for '%s' was retrieved", key)
    # Undo the at-rest Fernet encryption, then re-encrypt for transit.
    plaintext = models.session.fernet.decrypt(stored).decode(encoding="UTF-8")
    raise exceptions.APIResponse(
        status_code=HTTPStatus.OK.real, detail=transit.encrypt({key: plaintext})
    )
95
+
96
+
97
async def get_secrets(
    request: Request,
    keys: str,
    table_name: str = "default",
    apikey: HTTPAuthorizationCredentials = Depends(security),
):
    """**API function to retrieve multiple secrets at a time.**

    **Args:**

        request: Reference to the FastAPI request object.
        keys: Comma separated list of secret names to be retrieved. Blank entries are
            ignored and duplicate names are collapsed.
        table_name: Name of the table where the secrets are stored.
        apikey: API Key to authenticate the request.

    **Raises:**

        APIResponse:
        Raises the HTTPStatus object with a status code and detail as response:

        - 400 when no usable key name was supplied (or the database lookup fails).
        - 200 when every requested key was found, 206 when only some were found;
          the detail is the transit-encrypted mapping of the keys that were found.
        - 404 when none of the requested keys were found.
    """
    await auth.validate(request, apikey)
    # Strip whitespace, drop blank entries, and de-duplicate while keeping request order.
    # Without de-duplication, "a,a" counts 2 keys against 1 value and wrongly yields 206.
    key_list = list(dict.fromkeys(filter(None, map(str.strip, keys.split(",")))))
    keys_ct = len(key_list)
    # Explicit check rather than ``assert``: assertions are stripped under ``python -O``,
    # and an empty key list would make ``retrieve_secrets`` return the ENTIRE table.
    if keys_ct < 1:
        message = f"Expected at least one key, received {keys_ct}"
        LOGGER.error(message)
        raise exceptions.APIResponse(
            status_code=HTTPStatus.BAD_REQUEST.real, detail=message
        )
    if values := await retrieve_secrets(table_name, key_list):
        values_ct = len(values)
        if values_ct == keys_ct:
            LOGGER.info("Secret value for %d (%s) were retrieved", keys_ct, key_list)
            code = HTTPStatus.OK.real
        else:
            LOGGER.warning(
                "Number of keys [%d] requested didn't match the number of values [%d] retrieved.",
                keys_ct,
                values_ct,
            )
            code = HTTPStatus.PARTIAL_CONTENT.real
        decrypted = {
            key: models.session.fernet.decrypt(value).decode(encoding="UTF-8")
            for key, value in values.items()
        }
        raise exceptions.APIResponse(
            status_code=code, detail=transit.encrypt(decrypted)
        )
    if keys_ct == 1:
        LOGGER.info("Secret value for '%s' NOT found in the datastore", key_list[0])
    else:
        LOGGER.info(
            "Secret values for %d keys %s were NOT found in the datastore",
            keys_ct,
            key_list,
        )
    raise exceptions.APIResponse(
        status_code=HTTPStatus.NOT_FOUND.real, detail=HTTPStatus.NOT_FOUND.phrase
    )
157
+
158
+
159
async def list_tables(
    request: Request,
    apikey: HTTPAuthorizationCredentials = Depends(security),
):
    """**API function to enumerate every table in the datastore.**

    **Args:**

        request: Reference to the FastAPI request object.
        apikey: API Key to authenticate the request.

    **Raises:**

        APIResponse:
        Carries status 200 with whatever ``database.list_tables()`` returns as the
        detail. Unlike the secret endpoints, this detail is not transit-encrypted.
    """
    await auth.validate(request, apikey)
    table_names = database.list_tables()
    raise exceptions.APIResponse(status_code=HTTPStatus.OK.real, detail=table_names)
179
+
180
+
181
async def get_table(
    request: Request,
    table_name: str = "default",
    apikey: HTTPAuthorizationCredentials = Depends(security),
):
    """**API function to fetch every key-value pair held in one table.**

    **Args:**

        request: Reference to the FastAPI request object.
        table_name: Name of the table whose secrets are returned.
        apikey: API Key to authenticate the request.

    **Raises:**

        APIResponse:
        Carries the response. On success this is status 200 with the transit-encrypted
        mapping of every decrypted secret in the table.
    """
    await auth.validate(request, apikey)
    plaintext = {}
    for name, ciphertext in (await retrieve_secrets(table_name)).items():
        plaintext[name] = models.session.fernet.decrypt(ciphertext).decode(
            encoding="UTF-8"
        )
    raise exceptions.APIResponse(
        status_code=HTTPStatus.OK.real, detail=transit.encrypt(plaintext)
    )
208
+
209
+
210
async def put_secret(
    request: Request,
    data: payload.PutSecret,
    apikey: HTTPAuthorizationCredentials = Depends(security),
):
    """**API function to store (or overwrite) one secret.**

    **Args:**

        request: Reference to the FastAPI request object.
        data: Body carrying ``key``, ``value`` and ``table_name``.
        apikey: API Key to authenticate the request.

    **Raises:**

        APIResponse:
        Carries status 200 once the Fernet-encrypted value has been written.
    """
    await auth.validate(request, apikey)
    # This lookup only decides which log line to emit. It also surfaces database errors
    # (e.g. an unusable table) as a 400 before anything is written.
    already_stored = await retrieve_secret(data.key, data.table_name)
    if already_stored:
        LOGGER.info("Secret value for '%s' will be overridden", data.key)
    else:
        LOGGER.info(
            "Storing a secret value for '%s' to the table '%s' in the datastore",
            data.key,
            data.table_name,
        )
    ciphertext = models.session.fernet.encrypt(data.value.encode(encoding="UTF-8"))
    database.put_secret(key=data.key, value=ciphertext, table_name=data.table_name)
    raise exceptions.APIResponse(
        status_code=HTTPStatus.OK.real, detail=HTTPStatus.OK.phrase
    )
242
+
243
+
244
async def put_secrets(
    request: Request,
    data: payload.PutSecrets,
    apikey: HTTPAuthorizationCredentials = Depends(security),
):
    """**API function to add multiple secrets to a table in the database.**

    **Args:**

        request: Reference to the FastAPI request object.
        data: Payload with ``secrets`` (a mapping of secret names to values) and
            ``table_name`` as body.
        apikey: API Key to authenticate the request.

    **Raises:**

        APIResponse:
        Raises the HTTPStatus object with a status code and detail as response.
        Responds 400 with the SQLite error message if a write raises
        ``sqlite3.OperationalError`` (presumably e.g. a missing table). Secrets written
        before the failure may already be persisted.
    """
    await auth.validate(request, apikey)
    try:
        for key, value in data.secrets.items():
            encrypted = models.session.fernet.encrypt(value.encode(encoding="UTF-8"))
            database.put_secret(key=key, value=encrypted, table_name=data.table_name)
    except sqlite3.OperationalError as error:
        # Mirror ``retrieve_secret``: surface database errors as a 400 rather than
        # letting them escape as an unhandled server error.
        LOGGER.error(error)
        raise exceptions.APIResponse(
            status_code=HTTPStatus.BAD_REQUEST.real, detail=error.args[0]
        )
    LOGGER.info(
        "%d secrets stored in the table '%s'", len(data.secrets), data.table_name
    )
    raise exceptions.APIResponse(
        status_code=HTTPStatus.OK.real, detail=HTTPStatus.OK.phrase
    )
269
+
270
+
271
async def delete_secret(
    request: Request,
    data: payload.DeleteSecret,
    apikey: HTTPAuthorizationCredentials = Depends(security),
):
    """**API function to remove one secret from a table.**

    **Args:**

        request: Reference to the FastAPI request object.
        data: Body carrying ``key`` and ``table_name``.
        apikey: API Key to authenticate the request.

    **Raises:**

        APIResponse:
        Carries the response: 404 when the secret is absent, otherwise 200 after removal.
    """
    await auth.validate(request, apikey)
    if not await retrieve_secret(data.key, data.table_name):
        LOGGER.warning("Secret value for '%s' NOT found", data.key)
        raise exceptions.APIResponse(
            status_code=HTTPStatus.NOT_FOUND.real, detail=HTTPStatus.NOT_FOUND.phrase
        )
    LOGGER.info("Secret value for '%s' will be removed", data.key)
    database.remove_secret(key=data.key, table_name=data.table_name)
    raise exceptions.APIResponse(
        status_code=HTTPStatus.OK.real, detail=HTTPStatus.OK.phrase
    )
301
+
302
+
303
async def create_table(
    request: Request,
    table_name: str,
    apikey: HTTPAuthorizationCredentials = Depends(security),
):
    """**API function to add a new ``key``/``value`` table to the datastore.**

    **Args:**

        request: Reference to the FastAPI request object.
        table_name: Name of the table to create.
        apikey: API Key to authenticate the request.

    **Raises:**

        APIResponse:
        Carries status 417 (EXPECTATION_FAILED) with the SQLite error message when
        creation fails, otherwise status 200.
    """
    await auth.validate(request, apikey)
    # NOTE(review): ``table_name`` comes straight from the request. SQLite cannot bind
    # identifiers as parameters, so confirm ``database.create_table`` validates or quotes it.
    columns = ["key", "value"]
    try:
        database.create_table(table_name, columns)
    except sqlite3.OperationalError as db_error:
        LOGGER.error(db_error)
        raise exceptions.APIResponse(
            status_code=HTTPStatus.EXPECTATION_FAILED.real, detail=db_error.args[0]
        )
    else:
        raise exceptions.APIResponse(
            status_code=HTTPStatus.OK.real, detail=HTTPStatus.OK.phrase
        )
332
+
333
+
334
async def health() -> Dict[str, str]:
    """Report that the API process is up.

    Returns:
        Dict[str, str]:
            The fixed payload ``{"STATUS": "OK"}``.
    """
    status = dict(STATUS="OK")
    return status
342
+
343
+
344
async def docs() -> RedirectResponse:
    """Send visitors of the site root to the interactive API docs.

    Returns:
        RedirectResponse:
            A redirect pointing at ``/docs``.
    """
    docs_path = "/docs"
    return RedirectResponse(docs_path)
352
+
353
+
354
def get_all_routes() -> List[APIRoute]:
    """Get all the routes to be added for the API server.

    The root redirect and the health check are public and hidden from the schema.
    Every other route is rate-limited by the limiters configured in
    ``models.env.rate_limit``, and each of those endpoints authenticates the caller
    itself via ``auth.validate``.

    Returns:
        List[APIRoute]:
            Returns the routes as a list of APIRoute objects.
    """
    # One limiter per configured limit. The same dependency objects are attached to every
    # protected route, so limits presumably apply across those routes combined.
    dependencies = [
        Depends(dependency=rate_limit.RateLimiter(each_rate_limit).init)
        for each_rate_limit in models.env.rate_limit
    ]
    # (path, endpoint, HTTP method) for every rate-limited route. ``/put-secrets`` was
    # previously defined but never registered, leaving its endpoint unreachable.
    protected = (
        ("/get-secret", get_secret, "GET"),
        ("/get-secrets", get_secrets, "GET"),
        ("/get-table", get_table, "GET"),
        ("/list-tables", list_tables, "GET"),
        ("/put-secret", put_secret, "PUT"),
        ("/put-secrets", put_secrets, "PUT"),
        ("/delete-secret", delete_secret, "DELETE"),
        ("/create-table", create_table, "POST"),
    )
    routes = [
        APIRoute(path="/", endpoint=docs, methods=["GET"], include_in_schema=False),
        APIRoute(
            path="/health", endpoint=health, methods=["GET"], include_in_schema=False
        ),
    ]
    routes.extend(
        APIRoute(
            path=path,
            endpoint=endpoint,
            methods=[method],
            dependencies=dependencies,
        )
        for path, endpoint, method in protected
    )
    return routes
vaultapi/server.py ADDED
@@ -0,0 +1,20 @@
1
+ import pathlib
2
+
3
+ import uvicorn
4
+
5
+ from . import database, models
6
+
7
+
8
def start() -> None:
    """Create the ``default`` table, then launch the API under uvicorn.

    The app import string is built from this package's directory name. The
    host/port/worker settings, plus an optional ``log_config``, are read from ``models.env``.
    """
    database.create_table("default", ["key", "value"])
    package = pathlib.Path(__file__).parent.stem
    uvicorn_kwargs = {
        "host": models.env.host,
        "port": models.env.port,
        "workers": models.env.workers,
        "app": f"{package}.api:VaultAPI",
    }
    # Only override uvicorn's default logging when a config was actually provided.
    if models.env.log_config:
        uvicorn_kwargs["log_config"] = models.env.log_config
    uvicorn.run(**uvicorn_kwargs)
vaultapi/transit.py ADDED
@@ -0,0 +1,85 @@
1
"""Module that performs transit encryption/decryption.

This lets the server send retrieved secrets encrypted with a key derived from the API key and the
current time bucket, so that the client can decrypt them using the same API key.
"""
5
+
6
+ import base64
7
+ import hashlib
8
+ import json
9
+ import secrets
10
+ import time
11
+ from typing import Any, ByteString, Dict
12
+
13
+ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
14
+
15
+ from . import models
16
+
17
+
18
def string_to_aes_key(input_string: str, key_length: int) -> bytes:
    """Derive an AES key from a string by hashing it with SHA-256.

    Args:
        input_string: String from which the key is derived (UTF-8 encoded before hashing).
        key_length: Desired AES key size in bytes.

    See Also:
        AES supports three key lengths:

        - 128 bits (16 bytes)
        - 192 bits (24 bytes)
        - 256 bits (32 bytes)

        A SHA-256 digest is only 32 bytes, so any ``key_length`` above 32 yields 32 bytes.
        Lengths other than 16/24/32 are rejected later by ``AESGCM``, not here.

    Returns:
        bytes:
            The first ``key_length`` bytes of the SHA-256 digest of ``input_string``.
    """
    return hashlib.sha256(input_string.encode()).digest()[:key_length]
37
+
38
+
39
def encrypt(payload: Dict[str, Any], url_safe: bool = True) -> ByteString | str:
    """Seal a JSON-serializable payload with AES-GCM using a fresh 12-byte nonce.

    The AES key is derived from ``"<time bucket>.<API key>"``. Decryption therefore
    needs the same API key, and must happen within the same time bucket.

    Args:
        payload: JSON-serializable mapping to encrypt.
        url_safe: When true, return the sealed bytes as base64 text so they can be
            embedded in JSON. NOTE: this uses the standard base64 alphabet (``+``/``/``),
            not the URL-safe one.

    Returns:
        ByteString | str:
            ``nonce + AES-GCM output`` as base64 text when ``url_safe`` is true,
            otherwise as raw bytes.
    """
    nonce = secrets.token_bytes(12)
    plaintext = json.dumps(payload).encode()
    bucket = int(time.time()) // models.env.transit_time_bucket
    key = string_to_aes_key(
        f"{bucket}.{models.env.apikey}", models.env.transit_key_length
    )
    # The nonce is prepended so the receiver can split it back off (first 12 bytes).
    sealed = nonce + AESGCM(key).encrypt(nonce, plaintext, b"")
    if not url_safe:
        return sealed
    return base64.b64encode(sealed).decode("utf-8")
60
+
61
+
62
def decrypt(ciphertext: ByteString | str) -> Dict[str, Any]:
    """Open a payload produced by ``encrypt``.

    Args:
        ciphertext: The sealed bytes (12-byte nonce followed by the AES-GCM output),
            or their base64 text form.

    Raises:
        Raises ``InvalidTag`` if using wrong key or corrupted ciphertext. Because the key
        depends on the current time bucket, a payload sealed in a different bucket also
        fails this way.

    Returns:
        Dict[str, Any]:
            The JSON-decoded plaintext payload.
    """
    raw = base64.b64decode(ciphertext) if isinstance(ciphertext, str) else ciphertext
    bucket = int(time.time()) // models.env.transit_time_bucket
    key = string_to_aes_key(
        f"{bucket}.{models.env.apikey}", models.env.transit_key_length
    )
    nonce, sealed = raw[:12], raw[12:]
    plaintext = AESGCM(key).decrypt(nonce, sealed, b"")
    return json.loads(plaintext)
80
+
81
+
82
if __name__ == "__main__":
    # Round-trip self-check. ``encrypt`` already returns base64 text by default
    # (url_safe=True), and ``decrypt`` accepts that directly. Re-encoding it with
    # ``base64.b64encode`` raised TypeError because that function needs bytes, not str.
    print(decrypt(encrypt({"key": "value"})))
vaultapi/util.py ADDED
@@ -0,0 +1,86 @@
1
+ import base64
2
+ import hashlib
3
+ import importlib
4
+ import json
5
+ import logging
6
+ import sqlite3
7
+ import time
8
+ from typing import Any, ByteString, Dict
9
+
10
+ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
11
+ from dotenv import dotenv_values
12
+
13
+ from . import database, models
14
+
15
# NOTE(review): reloading ``logging`` re-executes the stdlib module, which creates a fresh
# root logger and logger registry for the whole process. Loggers obtained before this import
# keep pointing at the old objects. This is presumably done to discard any earlier logging
# configuration; confirm the process-wide side effect is intended. It must stay first.
importlib.reload(logging)
# Module logger for the CLI utilities; emits DEBUG and above.
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)
# StreamHandler() with no argument writes to sys.stderr.
HANDLER = logging.StreamHandler()
# Layout: "<Mon-DD-YYYY hh:mm:ss AM/PM> - LEVEL - [module:line] - function - message"
DEFAULT_FORMATTER = logging.Formatter(
    datefmt="%b-%d-%Y %I:%M:%S %p",
    fmt="%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(funcName)s - %(message)s",
)
HANDLER.setFormatter(DEFAULT_FORMATTER)
LOGGER.addHandler(HANDLER)
25
+
26
+
27
def dotenv_to_table(
    table_name: str, dotenv_file: str, drop_existing: bool = False
) -> None:
    """Store all the env vars from a .env file into the database.

    Each value is Fernet-encrypted with the session key before being written. Keys
    declared without a value in the file are skipped with a warning.

    Args:
        table_name: Name of the table to store secrets.
        dotenv_file: Dot env filename.
        drop_existing: Boolean flag to drop the existing table (if present) and
            recreate it empty before loading.

    Raises:
        sqlite3.OperationalError:
            Re-raised for any database error other than the table not existing.
    """
    if drop_existing and database.table_exists(table_name):
        LOGGER.info("Dropping table '%s' from '%s'", table_name, models.env.database)
        database.drop_table(table_name)
        database.create_table(table_name, ["key", "value"])
    else:
        try:
            if existing := database.get_table(table_name):
                LOGGER.warning(
                    "Table '%s' exists already in %s. %d secrets will be overwritten",
                    table_name,
                    models.env.database,
                    len(existing),
                )
        except sqlite3.OperationalError as error:
            # A missing table is expected on first load: create it instead of failing.
            if str(error) == f"no such table: {table_name}":
                LOGGER.info(
                    "Creating a new table '%s' in '%s'", table_name, models.env.database
                )
                database.create_table(table_name, ["key", "value"])
            else:
                raise
    stored = 0
    for key, value in dotenv_values(dotenv_file).items():
        # ``dotenv_values`` maps a bare ``KEY`` line (no ``=``) to ``None``. Calling
        # ``.encode`` on it would crash partway through the import, so skip it instead.
        if value is None:
            LOGGER.warning("Skipping '%s' since it has no value in %s", key, dotenv_file)
            continue
        encrypted = models.session.fernet.encrypt(value.encode(encoding="UTF-8"))
        database.put_secret(key, encrypted, table_name)
        stored += 1
    LOGGER.info(
        "%d secrets stored in the table %s, in the database %s.",
        stored,
        table_name,
        models.env.database,
    )
68
+
69
+
70
def transit_decrypt(ciphertext: str | ByteString) -> Dict[str, Any]:
    """Open a transit-encrypted payload using the locally configured API key.

    The key derivation (SHA-256 of ``"<time bucket>.<API key>"``, truncated to
    ``transit_key_length`` bytes) must stay in sync with ``transit.string_to_aes_key``.

    Args:
        ciphertext: Sealed bytes (12-byte nonce followed by the AES-GCM output),
            or their base64 text form.

    Returns:
        Dict[str, Any]:
            The JSON-decoded plaintext payload.
    """
    bucket = int(time.time()) // models.env.transit_time_bucket
    digest = hashlib.sha256(f"{bucket}.{models.env.apikey}".encode()).digest()
    cipher = AESGCM(digest[: models.env.transit_key_length])
    raw = base64.b64decode(ciphertext) if isinstance(ciphertext, str) else ciphertext
    return json.loads(cipher.decrypt(raw[:12], raw[12:], b""))
vaultapi/version.py ADDED
@@ -0,0 +1 @@
1
# Release version string for the VaultAPI package.
__version__ = "0.1.1"