VaultAPI 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- VaultAPI-0.1.1.dist-info/LICENSE +21 -0
- VaultAPI-0.1.1.dist-info/METADATA +240 -0
- VaultAPI-0.1.1.dist-info/RECORD +19 -0
- VaultAPI-0.1.1.dist-info/WHEEL +5 -0
- VaultAPI-0.1.1.dist-info/entry_points.txt +2 -0
- VaultAPI-0.1.1.dist-info/top_level.txt +1 -0
- vaultapi/__init__.py +77 -0
- vaultapi/api.py +53 -0
- vaultapi/auth.py +51 -0
- vaultapi/database.py +122 -0
- vaultapi/exceptions.py +9 -0
- vaultapi/models.py +282 -0
- vaultapi/payload.py +37 -0
- vaultapi/rate_limit.py +68 -0
- vaultapi/routes.py +413 -0
- vaultapi/server.py +20 -0
- vaultapi/transit.py +85 -0
- vaultapi/util.py +86 -0
- vaultapi/version.py +1 -0
vaultapi/models.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import pathlib
|
|
5
|
+
import re
|
|
6
|
+
import socket
|
|
7
|
+
import sqlite3
|
|
8
|
+
from typing import Any, Dict, List, Set
|
|
9
|
+
|
|
10
|
+
import yaml
|
|
11
|
+
from cryptography.fernet import Fernet
|
|
12
|
+
from pydantic import (
|
|
13
|
+
BaseModel,
|
|
14
|
+
Field,
|
|
15
|
+
FilePath,
|
|
16
|
+
HttpUrl,
|
|
17
|
+
NewPath,
|
|
18
|
+
PositiveInt,
|
|
19
|
+
field_validator,
|
|
20
|
+
)
|
|
21
|
+
from pydantic_settings import BaseSettings
|
|
22
|
+
|
|
23
|
+
# Module logger; reuses the logger named "uvicorn.default" (presumably configured
# by the uvicorn server this app runs under — confirm).
LOGGER = logging.getLogger("uvicorn.default")
# Local bind addresses. When ``env.host`` is one of these, ``__init__`` adds all of
# them to ``session.allowed_origins``; otherwise only ``env.host`` itself is added.
DEFAULT_ALLOWED = ("0.0.0.0", "127.0.0.1", "localhost")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def complexity_checker(secret: str) -> None:
    """Verifies the strength of a secret.

    See Also:
        A secret is considered strong if it at least has:

        - 32 characters
        - 1 digit
        - 1 symbol (space or any ASCII punctuation character)
        - 1 uppercase letter
        - 1 lowercase letter

    Args:
        secret: The secret (e.g. an API key) to check.

    Raises:
        AssertionError: When at least 1 of the above conditions fail to match.
    """
    # Explicit raises instead of ``assert`` statements: asserts are removed when
    # Python runs with ``-O``, which would silently disable this check. The
    # exception type stays ``AssertionError`` because callers catch exactly that.
    if len(secret) < 32:
        raise AssertionError(
            f"secret length must be at least 32, received {len(secret)}"
        )

    # searches for digits
    if not re.search(r"\d", secret):
        raise AssertionError("secret must include an integer")

    # searches for uppercase
    if not re.search(r"[A-Z]", secret):
        raise AssertionError("secret must include at least one uppercase letter")

    # searches for lowercase
    if not re.search(r"[a-z]", secret):
        raise AssertionError("secret must include at least one lowercase letter")

    # searches for symbols: space plus every ASCII punctuation character
    # (the previous character class missed ``: ; < > ?``)
    symbols = " !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    if not any(char in symbols for char in secret):
        raise AssertionError("secret must contain at least one special character")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Database:
|
|
67
|
+
"""Creates a connection and instantiates the cursor.
|
|
68
|
+
|
|
69
|
+
>>> Database
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
filepath: Name of the database file.
|
|
73
|
+
timeout: Timeout for the connection to database.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
def __init__(self, filepath: FilePath | str, timeout: int = 10):
|
|
77
|
+
"""Instantiates the class ``Database`` to create a connection and a cursor."""
|
|
78
|
+
if not filepath.endswith(".db"):
|
|
79
|
+
filepath = filepath + ".db"
|
|
80
|
+
self.connection = sqlite3.connect(
|
|
81
|
+
database=filepath, check_same_thread=False, timeout=timeout
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class RateLimit(BaseModel):
    """Object to store the rate limit settings.

    >>> RateLimit

    """

    # Maximum number of requests allowed within the window.
    max_requests: PositiveInt
    # Length of the window, in seconds.
    seconds: PositiveInt
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class Session(BaseModel):
    """Object to store session information.

    >>> Session

    """

    # Cipher built from ``env.secret`` by the module-level ``__init__``.
    fernet: Fernet | None = None
    # NOTE(review): populated elsewhere in the package; contents not visible here.
    info: Dict[str, str] = Field(default_factory=dict)
    # NOTE(review): populated elsewhere in the package; contents not visible here.
    rps: Dict[str, int] = Field(default_factory=dict)
    # Hosts/IPs filled in by the module-level ``__init__``.
    allowed_origins: Set[str] = Field(default_factory=set)

    # Pydantic v2 configuration (the nested ``class Config`` form is deprecated
    # since pydantic 2.0). Arbitrary types are required for the ``Fernet`` field.
    model_config = {"arbitrary_types_allowed": True}
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class EnvConfig(BaseSettings):
    """Object to load environment variables.

    >>> EnvConfig

    """

    apikey: str
    secret: str
    transit_key_length: PositiveInt = 32
    transit_time_bucket: PositiveInt = 60
    # NOTE(review): the unescaped "." in the pattern matches any character, so a
    # value like "secrets_db" also passes; ``Database`` appends ".db" when missing.
    database: FilePath | NewPath | str = Field("secrets.db", pattern=".*.db$")
    # Resolved once, when the class body executes (i.e. at import time).
    host: str = socket.gethostbyname("localhost") or "0.0.0.0"
    port: PositiveInt = 9010
    workers: PositiveInt = 1
    log_config: FilePath | Dict[str, Any] | None = None
    allowed_origins: HttpUrl | List[HttpUrl] = Field(default_factory=list)
    allowed_ip_range: List[str] = Field(default_factory=list)
    # This is a base rate limit configuration
    rate_limit: RateLimit | List[RateLimit] = Field(
        default=[
            # Burst limit: Prevents excessive load on the server
            {
                "max_requests": 5,
                "seconds": 2,
            },
            # Sustained limit: Prevents too many trial and errors
            {
                "max_requests": 10,
                "seconds": 30,
            },
        ]
    )

    @field_validator("allowed_origins", mode="after", check_fields=True)
    def validate_allowed_origins(
        cls, value: HttpUrl | List[HttpUrl]  # noqa: PyMethodParameters
    ) -> List[HttpUrl]:
        """Validate allowed origins to enable CORS policy.

        A single URL is wrapped in a list so the value is always iterable.
        """
        if isinstance(value, list):
            return value
        return [value]

    @field_validator("allowed_ip_range", mode="after", check_fields=True)
    def validate_allowed_ip_range(
        cls, value: List[str]  # noqa: PyMethodParameters
    ) -> List[str]:
        """Validate allowed IP range to whitelist.

        Each entry must look like ``<prefix>.<start>-<end>`` with numeric bounds,
        e.g. ``192.168.1.10-19``.

        Raises:
            ValueError: If an entry is not in that form.
        """
        hint = "\n\tInput should be a list of IP range (eg: ['192.168.1.10-19', '10.120.1.5-35'])"
        for ip_range in value:
            # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
            octets = ip_range.split(".")
            if len(octets) <= 1:
                raise ValueError(
                    f"Expected a valid IP address, received {ip_range}{hint}"
                )
            bounds = octets[-1].split("-")
            if len(bounds) != 2:
                raise ValueError(f"Expected a valid IP range, received {ip_range}{hint}")
            # ``__init__`` converts both bounds with ``int()``; reject anything it
            # could not parse here instead of crashing later at import time.
            if not all(bound.isascii() and bound.isdigit() for bound in bounds):
                raise ValueError(
                    f"Expected a numeric IP range, received {ip_range}{hint}"
                )
        return value

    @field_validator("apikey", mode="after")
    def validate_apikey(cls, value: str) -> str:  # noqa: PyMethodParameters
        """Validate API key for complexity.

        Raises:
            ValueError: If ``complexity_checker`` rejects the key.
        """
        try:
            complexity_checker(value)
        except AssertionError as error:
            raise ValueError(str(error)) from error
        return value

    @field_validator("secret", mode="after")
    def validate_api_secret(cls, value: str) -> str:  # noqa: PyMethodParameters
        """Validate API secret to Fernet compatible.

        Raises:
            ValueError: If ``Fernet`` rejects the value as a key.
        """
        try:
            Fernet(value)
        except ValueError as error:
            exc = f"{error}\n\tConsider using 'vaultapi keygen' command to generate a valid secret."
            raise ValueError(exc) from error
        return value

    @classmethod
    def from_env_file(cls, env_file: pathlib.Path) -> "EnvConfig":
        """Create Settings instance from environment file.

        Args:
            env_file: Name of the env file.

        Returns:
            EnvConfig:
            Loads the ``EnvConfig`` model.
        """
        return cls(_env_file=env_file)

    # Pydantic v2 configuration (the nested ``class Config`` form is deprecated
    # since pydantic 2.0); pydantic merges this with the BaseSettings defaults.
    model_config = {
        "extra": "ignore",
        # Keep submitted values (e.g. apikey/secret) out of validation errors.
        "hide_input_in_errors": True,
        "arbitrary_types_allowed": True,
    }
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def envfile_loader(filename: str | os.PathLike) -> EnvConfig:
    """Loads environment variables based on filetypes.

    Args:
        filename: Filename from where env vars have to be loaded.

    Returns:
        EnvConfig:
        Returns a reference to the ``EnvConfig`` object.

    Raises:
        ValueError: If the file extension is unsupported, or a JSON/YAML file
            does not contain a top-level mapping.
    """
    env_file = pathlib.Path(filename)
    suffix = env_file.suffix.lower()
    if suffix == ".json":
        with open(env_file, encoding="utf-8") as stream:
            env_data = json.load(stream)
    elif suffix in (".yaml", ".yml"):
        with open(env_file, encoding="utf-8") as stream:
            # NOTE(review): FullLoader can construct some Python-specific objects;
            # ``yaml.safe_load`` would be stricter if this file may be untrusted.
            env_data = yaml.load(stream, yaml.FullLoader)
    elif suffix in ("", ".env", ".text", ".txt"):
        # Dotenv-style files (including ".env", whose pathlib suffix is "") are
        # handed to pydantic-settings.
        return EnvConfig.from_env_file(env_file)
    else:
        raise ValueError(
            "\n\tUnsupported format for 'env_file', can be one of (.json, .yaml, .yml, .env, .txt, .text, or null)"
        )
    # An empty YAML file loads as None and a JSON file may hold a list/scalar;
    # fail with a clear message instead of an AttributeError on ``.items()``.
    if not isinstance(env_data, dict):
        raise ValueError(
            f"\n\tExpected a mapping of env vars in {env_file}, received {type(env_data).__name__}"
        )
    # Keys are lower-cased to match the field names on ``EnvConfig``.
    return EnvConfig(**{k.lower(): v for k, v in env_data.items()})
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def load_env() -> EnvConfig:
    """Loads the env vars, using an env file when one exists.

    See Also:
        The file name comes from the ``env_file`` or ``ENV_FILE`` environment
        variable, falling back to ``.env``. When no such file exists, ``EnvConfig``
        is built without a file.

    Returns:
        EnvConfig:
        Returns a reference to the ``EnvConfig`` object.
    """
    candidate = os.getenv("env_file") or os.getenv("ENV_FILE") or ".env"
    if not os.path.isfile(candidate):
        return EnvConfig()
    return envfile_loader(candidate)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def __init__() -> None:
    """Populates the module-level ``session`` from the loaded ``env``.

    Builds the Fernet cipher from ``env.secret`` and fills
    ``session.allowed_origins`` with the bind host(s), the hosts of the configured
    origins, and every address in each configured IP range.
    """
    session.fernet = Fernet(env.secret)

    # A local bind address admits every local alias; any other host admits itself.
    if env.host not in DEFAULT_ALLOWED:
        session.allowed_origins.add(env.host)
    else:
        session.allowed_origins.update(DEFAULT_ALLOWED)

    session.allowed_origins.update(origin.host for origin in env.allowed_origins)

    # Expand "a.b.c.START-END" into a.b.c.START ... a.b.c.END (END inclusive).
    for ip_range in env.allowed_ip_range:
        *prefix_parts, last_octet = ip_range.split(".")
        prefix = ".".join(prefix_parts)
        low_text, high_text = last_octet.split("-")
        low, high = int(low_text), int(high_text)
        session.allowed_origins.update(
            f"{prefix}.{octet}" for octet in range(low, high + 1)
        )
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
# Module-level singletons, created at import time: importing this module loads the
# configuration, opens the SQLite connection, and builds the session state.
env: EnvConfig = load_env()
database: Database = Database(env.database)
session = Session()
# Must run after ``env`` and ``session`` exist: it reads the former and fills the latter.
__init__()
|
vaultapi/payload.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DeleteSecret(BaseModel):
    """Payload for delete-secret API call.

    >>> DeleteSecret

    """

    # Name of the secret to delete.
    key: str
    # Table the secret lives in; defaults to "default".
    table_name: str = "default"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class PutSecret(BaseModel):
    """Payload for put-secret API call.

    >>> PutSecret

    """

    # Name of the secret to store.
    key: str
    # Secret value to store.
    value: str
    # Table to store the secret in; defaults to "default".
    table_name: str = "default"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PutSecrets(BaseModel):
    """Payload for put-secrets API call (several secrets at once).

    >>> PutSecrets

    """

    # Mapping of secret name to secret value.
    secrets: Dict[str, str]
    # Table to store the secrets in; defaults to "default".
    table_name: str = "default"
|
vaultapi/rate_limit.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import math
|
|
3
|
+
import time
|
|
4
|
+
from http import HTTPStatus
|
|
5
|
+
from threading import Lock
|
|
6
|
+
|
|
7
|
+
from fastapi import HTTPException, Request
|
|
8
|
+
|
|
9
|
+
from . import models
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _get_identifier(request: Request) -> str:
    """Generate a unique identifier for the request.

    The identifier is ``<client-ip>:<url-path>``, so limits are tracked per client
    and per endpoint.

    Args:
        request: The incoming request object.

    Returns:
        str:
        Rate-limit key for the request.
    """
    # NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
    # overwrites it, so a client can rotate this header to evade the limiter.
    if forwarded := request.headers.get("x-forwarded-for"):
        # The first hop is the originating client; drop padding around commas.
        client_ip = forwarded.split(",")[0].strip()
        if client_ip:
            return f"{client_ip}:{request.url.path}"
    # ``request.client`` may be None when the server does not supply the peer
    # address; fall back to a fixed placeholder instead of raising AttributeError.
    client_host = request.client.host if request.client else "unknown"
    return f"{client_host}:{request.url.path}"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RateLimiter:
    """Rate limiter for incoming requests.

    Keeps, per identifier (client IP + URL path), the timestamps of requests seen
    within the last ``seconds`` and rejects a request once ``max_requests`` are
    already in that window.

    >>> RateLimiter

    """

    def __init__(self, rps: models.RateLimit):
        # noinspection PyUnresolvedReferences
        """Instantiates the object with the necessary args.

        Args:
            rps: RateLimit object with ``max_requests`` and ``seconds``.

        Attributes:
            max_requests: Maximum requests to allow in a given time frame.
            seconds: Number of seconds after which the cache is set to expire.
        """
        self.max_requests = rps.max_requests
        self.seconds = rps.seconds
        self.locks = collections.defaultdict(Lock)  # For thread-safe access
        # NOTE(review): entries in ``locks`` and ``requests`` are never removed, so
        # both grow with every distinct identifier seen.
        self.requests = collections.defaultdict(list)

    def init(self, request: Request) -> None:
        """Checks if the number of calls exceeds the rate limit for the given identifier.

        Args:
            request: The incoming request object.

        Raises:
            HTTPException: 429 Too Many Requests, with a ``Retry-After`` header
            holding the seconds until the oldest request leaves the window.
        """
        identifier = _get_identifier(request)
        # Monotonic clock: interval math is unaffected by wall-clock adjustments
        # (e.g. NTP corrections or manual clock changes).
        current_time = time.monotonic()

        with self.locks[identifier]:
            # Clean up expired timestamps
            recent = [
                timestamp
                for timestamp in self.requests[identifier]
                if current_time - timestamp < self.seconds
            ]
            self.requests[identifier] = recent

            if len(recent) >= self.max_requests:
                # Wait until the oldest request in the window expires (min() since
                # threads may append slightly out of order); at least 1 second.
                oldest = min(recent) if recent else current_time
                retry_after = max(1, math.ceil(self.seconds - (current_time - oldest)))
                raise HTTPException(
                    status_code=HTTPStatus.TOO_MANY_REQUESTS.value,
                    detail=HTTPStatus.TOO_MANY_REQUESTS.phrase,
                    headers={"Retry-After": str(retry_after)},
                )
            recent.append(current_time)
|