VaultAPI 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vaultapi/models.py ADDED
@@ -0,0 +1,282 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ import socket
7
+ import sqlite3
8
+ from typing import Any, Dict, List, Set
9
+
10
+ import yaml
11
+ from cryptography.fernet import Fernet
12
+ from pydantic import (
13
+ BaseModel,
14
+ Field,
15
+ FilePath,
16
+ HttpUrl,
17
+ NewPath,
18
+ PositiveInt,
19
+ field_validator,
20
+ )
21
+ from pydantic_settings import BaseSettings
22
+
23
# Logger in the ``uvicorn`` namespace (presumably so records flow through the
# server's handlers — confirm against the logging config in use).
LOGGER = logging.getLogger("uvicorn.default")
# Loopback/wildcard hosts; when ``env.host`` is any one of them, every entry
# is whitelisted together by the module-level ``__init__``.
DEFAULT_ALLOWED = "0.0.0.0", "127.0.0.1", "localhost"
25
+
26
+
27
def complexity_checker(secret: str) -> None:
    """Verifies the strength of a secret.

    See Also:
        A secret is considered strong if it at least has:

        - 32 characters
        - 1 digit
        - 1 symbol (any ASCII punctuation character, or a space)
        - 1 uppercase letter
        - 1 lowercase letter

    Args:
        secret: Secret to be validated.

    Raises:
        AssertionError: When at least 1 of the above conditions fail to match.
            The checks run in the order listed in the code below, and only
            the first failure is reported.
    """
    import string  # only needed for the ASCII punctuation set

    # Explicit ``raise`` instead of ``assert``: assert statements are removed
    # under ``python -O``, which would silently disable every check here.
    # AssertionError is kept because callers catch exactly that type.
    if len(secret) < 32:
        raise AssertionError(
            f"secret length must be at least 32, received {len(secret)}"
        )

    # searches for digits
    if not re.search(r"\d", secret):
        raise AssertionError("secret must include an integer")

    # searches for uppercase
    if not re.search(r"[A-Z]", secret):
        raise AssertionError("secret must include at least one uppercase letter")

    # searches for lowercase
    if not re.search(r"[a-z]", secret):
        raise AssertionError("secret must include at least one lowercase letter")

    # searches for symbols; string.punctuation covers every ASCII symbol (the
    # previous hand-written character class missed ``:;<>?``), and a space
    # still counts as before.
    symbols = set(string.punctuation + " ")
    if not any(char in symbols for char in secret):
        raise AssertionError("secret must contain at least one special character")
64
+
65
+
66
+ class Database:
67
+ """Creates a connection and instantiates the cursor.
68
+
69
+ >>> Database
70
+
71
+ Args:
72
+ filepath: Name of the database file.
73
+ timeout: Timeout for the connection to database.
74
+ """
75
+
76
+ def __init__(self, filepath: FilePath | str, timeout: int = 10):
77
+ """Instantiates the class ``Database`` to create a connection and a cursor."""
78
+ if not filepath.endswith(".db"):
79
+ filepath = filepath + ".db"
80
+ self.connection = sqlite3.connect(
81
+ database=filepath, check_same_thread=False, timeout=timeout
82
+ )
83
+
84
+
85
class RateLimit(BaseModel):
    """Object to store the rate limit settings.

    >>> RateLimit

    One entry of ``EnvConfig.rate_limit``, which accepts either a single
    object or a list of them.
    """

    # Maximum number of requests accepted within the window (must be > 0).
    max_requests: PositiveInt
    # Length of the window in seconds (must be > 0).
    seconds: PositiveInt
94
+
95
+
96
class Session(BaseModel):
    """Object to store session information.

    >>> Session

    """

    # Fernet cipher built from ``env.secret`` at start-up; None until then.
    fernet: Fernet | None = None
    info: Dict[str, str] = Field(default_factory=dict)
    rps: Dict[str, int] = Field(default_factory=dict)
    # Hosts and IP addresses collected from the env config at start-up.
    allowed_origins: Set[str] = Field(default_factory=set)

    # Pydantic v2 configuration; the nested ``class Config`` form is deprecated
    # (it emits a deprecation warning in v2 and is slated for removal in v3).
    # ``arbitrary_types_allowed`` is needed because ``Fernet`` is not a
    # pydantic-aware type.
    model_config = {"arbitrary_types_allowed": True}
112
+
113
+
114
class EnvConfig(BaseSettings):
    """Object to load environment variables.

    >>> EnvConfig

    """

    apikey: str
    secret: str
    transit_key_length: PositiveInt = 32
    transit_time_bucket: PositiveInt = 60
    # NOTE(review): the dot in ``.*.db$`` is unescaped, so any character (not
    # only a literal ".") may precede "db". Left unchanged to avoid rejecting
    # existing configs; ``Database`` appends ".db" when it is missing anyway.
    database: FilePath | NewPath | str = Field("secrets.db", pattern=".*.db$")
    # NOTE(review): evaluated once at import time. ``gethostbyname`` raises
    # ``socket.gaierror`` instead of returning a falsy value, so the
    # ``or "0.0.0.0"`` fallback is never reached.
    host: str = socket.gethostbyname("localhost") or "0.0.0.0"
    port: PositiveInt = 9010
    workers: PositiveInt = 1
    log_config: FilePath | Dict[str, Any] | None = None
    # A single URL is accepted and normalised to a list by its validator.
    allowed_origins: HttpUrl | List[HttpUrl] = Field(default_factory=list)
    # Entries look like "192.168.1.10-19": a dotted prefix plus an inclusive
    # range for the last octet.
    allowed_ip_range: List[str] = Field(default_factory=list)
    # This is a base rate limit configuration
    rate_limit: RateLimit | List[RateLimit] = Field(
        default=[
            # Burst limit: Prevents excessive load on the server
            {
                "max_requests": 5,
                "seconds": 2,
            },
            # Sustained limit: Prevents too many trial and errors
            {
                "max_requests": 10,
                "seconds": 30,
            },
        ]
    )

    @field_validator("allowed_origins", mode="after", check_fields=True)
    @classmethod
    def validate_allowed_origins(
        cls, value: HttpUrl | List[HttpUrl]
    ) -> List[HttpUrl]:
        """Validate allowed origins to enable CORS policy (always returns a list)."""
        if isinstance(value, list):
            return value
        return [value]

    @field_validator("allowed_ip_range", mode="after", check_fields=True)
    @classmethod
    def validate_allowed_ip_range(cls, value: List[str]) -> List[str]:
        """Validate allowed IP range to whitelist.

        Raises:
            ValueError: If an entry has no dotted prefix, or its last segment
                is not two integers joined by ``-``.
        """
        hint = "\n\tInput should be a list of IP range (eg: ['192.168.1.10-19', '10.120.1.5-35'])"
        for ip_range in value:
            # Explicit raises instead of ``assert``: asserts are stripped under
            # ``python -O``, which let malformed ranges through to the start-up
            # expansion, where unpacking/``int()`` would then fail at import.
            if len(ip_range.split(".")) <= 1:
                raise ValueError(
                    f"Expected a valid IP address, received {ip_range}{hint}"
                )
            bounds = ip_range.split(".")[-1].split("-")
            try:
                if len(bounds) != 2:
                    raise ValueError
                # Both bounds are later passed to ``int()`` when the range is
                # expanded, so reject anything that would fail there.
                for bound in bounds:
                    int(bound)
            except ValueError:
                raise ValueError(
                    f"Expected a valid IP range, received {ip_range}{hint}"
                ) from None
        return value

    @field_validator("apikey", mode="after")
    @classmethod
    def validate_apikey(cls, value: str) -> str:
        """Validate API key for complexity."""
        try:
            complexity_checker(value)
        except AssertionError as error:
            raise ValueError(str(error)) from error
        return value

    @field_validator("secret", mode="after")
    @classmethod
    def validate_api_secret(cls, value: str) -> str:
        """Validate API secret to Fernet compatible."""
        try:
            Fernet(value)
        except ValueError as error:
            exc = f"{error}\n\tConsider using 'vaultapi keygen' command to generate a valid secret."
            raise ValueError(exc) from error
        return value

    @classmethod
    def from_env_file(cls, env_file: pathlib.Path) -> "EnvConfig":
        """Create Settings instance from environment file.

        Args:
            env_file: Name of the env file.

        Returns:
            EnvConfig:
            Loads the ``EnvConfig`` model.
        """
        return cls(_env_file=env_file)

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    # It is merged with ``BaseSettings``' own defaults exactly as before.
    model_config = {
        "extra": "ignore",
        "hide_input_in_errors": True,
        "arbitrary_types_allowed": True,
    }
213
+
214
+
215
def envfile_loader(filename: str | os.PathLike) -> EnvConfig:
    """Loads environment variables based on filetypes.

    Supported formats:
        - ``.json``: a JSON object of ``KEY: value`` pairs.
        - ``.yaml`` / ``.yml``: a YAML mapping of ``KEY: value`` pairs.
        - ``.txt`` / ``.text`` / ``.env`` / no suffix: dotenv-style text passed
          to ``EnvConfig.from_env_file``.

    Keys read from JSON/YAML are lower-cased to match the ``EnvConfig`` fields.

    Args:
        filename: Filename from where env vars have to be loaded.

    Returns:
        EnvConfig:
        Returns a reference to the ``EnvConfig`` object.

    Raises:
        ValueError: If the extension is unsupported, or a JSON/YAML file does
            not hold a mapping at the top level.
    """
    env_file = pathlib.Path(filename)
    suffix = env_file.suffix.lower()

    if suffix in (".json", ".yaml", ".yml"):
        with open(env_file, encoding="utf-8") as stream:
            if suffix == ".json":
                env_data = json.load(stream)
            else:
                # safe_load: a config file only needs plain scalars, lists and
                # mappings, and safe_load cannot construct arbitrary Python
                # objects the way the full loader's python/* tags can.
                env_data = yaml.safe_load(stream)
        # An empty YAML document loads as None; treat it as "no values".
        if env_data is None:
            env_data = {}
        if not isinstance(env_data, dict):
            raise ValueError(
                f"\n\tExpected a mapping of env vars in {env_file.name!r}, "
                f"received {type(env_data).__name__}"
            )
        return EnvConfig(**{str(key).lower(): value for key, value in env_data.items()})

    # Note: a bare ".env" has an empty suffix per pathlib; "prod.env" has ".env".
    if suffix in ("", ".text", ".txt", ".env"):
        return EnvConfig.from_env_file(env_file)

    raise ValueError(
        "\n\tUnsupported format for 'env_file', can be one of (.json, .yaml, .yml, .txt, .text, .env, or null)"
    )
244
+
245
+
246
def load_env() -> EnvConfig:
    """Loads the env vars, using an env file when one is available.

    See Also:
        The env file path is taken from the ``env_file`` variable, then
        ``ENV_FILE``, and defaults to ``.env``. If that path is an existing
        file it is parsed by ``envfile_loader``; otherwise ``EnvConfig()`` is
        instantiated directly, without an env file.

    Returns:
        EnvConfig:
        Returns a reference to the ``EnvConfig`` object.
    """
    candidates = (os.getenv("env_file"), os.getenv("ENV_FILE"))
    # First non-empty candidate wins; fall back to ".env".
    env_file = next((path for path in candidates if path), ".env")
    if not os.path.isfile(env_file):
        return EnvConfig()
    return envfile_loader(env_file)
260
+
261
+
262
def __init__() -> None:
    """Populate the module-level ``session`` from the loaded ``env``.

    Builds the Fernet cipher from ``env.secret`` and fills
    ``session.allowed_origins`` with:

    - every entry of ``DEFAULT_ALLOWED`` when ``env.host`` is one of them,
      otherwise ``env.host`` alone;
    - the host part of each URL in ``env.allowed_origins``;
    - every address covered by each ``env.allowed_ip_range`` entry of the form
      ``<prefix>.<start>-<end>`` (both bounds inclusive).
    """
    session.fernet = Fernet(env.secret)

    origins = session.allowed_origins
    if env.host not in DEFAULT_ALLOWED:
        origins.add(env.host)
    else:
        origins.update(DEFAULT_ALLOWED)

    origins.update(url.host for url in env.allowed_origins)

    for ip_range in env.allowed_ip_range:
        # "192.168.1.10-19" -> prefix "192.168.1", span "10-19"
        prefix, _, span = ip_range.rpartition(".")
        first, last = span.split("-")
        origins.update(
            f"{prefix}.{octet}" for octet in range(int(first), int(last) + 1)
        )
277
+
278
+
279
# Module-level singletons, created once at import time. Order matters:
# ``database`` and ``__init__`` both read ``env``, and ``__init__`` mutates
# ``session``, so it must run after all three exist.
env: EnvConfig = load_env()
database: Database = Database(env.database)
session: Session = Session()
__init__()
vaultapi/payload.py ADDED
@@ -0,0 +1,37 @@
1
+ from typing import Dict
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class DeleteSecret(BaseModel):
    """Payload for delete-secret API call.

    >>> DeleteSecret

    Attributes:
        key: Key identifying the secret to delete.
        table_name: Table holding the secret (presumably a table in the
            secrets database — confirm against the route); defaults to
            ``"default"``.
    """

    key: str
    table_name: str = "default"
15
+
16
+
17
class PutSecret(BaseModel):
    """Payload for put-secret API call.

    >>> PutSecret

    Attributes:
        key: Key under which the secret is stored.
        value: Secret value to store.
        table_name: Table to store the secret in (presumably a table in the
            secrets database — confirm against the route); defaults to
            ``"default"``.
    """

    key: str
    value: str
    table_name: str = "default"
27
+
28
+
29
class PutSecrets(BaseModel):
    """Payload for put-secrets API call.

    >>> PutSecrets

    Attributes:
        secrets: Mapping of key to secret value, stored in one call.
        table_name: Table to store the secrets in (presumably a table in the
            secrets database — confirm against the route); defaults to
            ``"default"``.
    """

    secrets: Dict[str, str]
    table_name: str = "default"
vaultapi/rate_limit.py ADDED
@@ -0,0 +1,68 @@
1
+ import collections
2
+ import math
3
+ import time
4
+ from http import HTTPStatus
5
+ from threading import Lock
6
+
7
+ from fastapi import HTTPException, Request
8
+
9
+ from . import models
10
+
11
+
12
def _get_identifier(request: Request) -> str:
    """Generate a unique identifier for the request.

    The identifier is ``<client-ip>:<path>``, so limits are tracked per client
    and per endpoint.

    Args:
        request: The incoming request object.

    Returns:
        str:
        Identifier used as the rate-limit bucket key.
    """
    # SECURITY NOTE(review): ``X-Forwarded-For`` is client-supplied and trusted
    # as-is. Unless a reverse proxy overwrites it, a caller can rotate its
    # value to obtain a fresh rate-limit bucket on every request.
    if forwarded := request.headers.get("x-forwarded-for"):
        # Left-most entry is the originating client; drop the whitespace that
        # commonly follows the comma separator so buckets don't split.
        client_ip = forwarded.split(",")[0].strip()
    elif request.client is not None:
        client_ip = request.client.host
    else:
        # Starlette reports ``client`` as None when the ASGI scope has no
        # client address; bucket such requests together instead of crashing.
        client_ip = "unknown"
    return f"{client_ip}:{request.url.path}"
17
+
18
+
19
class RateLimiter:
    """Rate limiter for incoming requests.

    >>> RateLimiter

    Sliding-window log: for each identifier (client + path) it keeps the
    timestamps of accepted requests from the last ``seconds`` seconds and
    rejects a request once ``max_requests`` of them are present.
    """

    def __init__(self, rps: models.RateLimit):
        # noinspection PyUnresolvedReferences
        """Instantiates the object with the necessary args.

        Args:
            rps: RateLimit object with ``max_requests`` and ``seconds``.

        Attributes:
            max_requests: Maximum requests to allow in a given time frame.
            seconds: Number of seconds after which the cache is set to expire.
        """
        self.max_requests = rps.max_requests
        self.seconds = rps.seconds
        # A single lock guards all state. Per-identifier locks could never be
        # removed safely, so both maps used to grow without bound (one entry
        # for every client/path combination ever seen).
        self.lock = Lock()
        self.requests: dict[str, list[float]] = collections.defaultdict(list)
        # ``time.monotonic()`` deadline for the next purge of idle identifiers.
        self._next_sweep = time.monotonic() + self.seconds

    def _purge_idle(self, now: float) -> None:
        """Drop identifiers with no timestamp inside the window; caller must hold ``self.lock``."""
        idle = [
            identifier
            for identifier, stamps in self.requests.items()
            # Lists are ordered oldest-first, so the last stamp is the newest.
            if not stamps or now - stamps[-1] >= self.seconds
        ]
        for identifier in idle:
            del self.requests[identifier]
        self._next_sweep = now + self.seconds

    def init(self, request: Request) -> None:
        """Checks if the number of calls exceeds the rate limit for the given identifier.

        Args:
            request: The incoming request object.

        Raises:
            429: Too many requests.
        """
        identifier = _get_identifier(request)

        with self.lock:
            # Read the clock under the lock so each list stays ordered
            # oldest-first; monotonic() is unaffected by wall-clock changes
            # that would skew time.time()-based windows.
            now = time.monotonic()
            # At most one O(n) purge per window keeps memory bounded.
            if now >= self._next_sweep:
                self._purge_idle(now)

            # Clean up expired timestamps
            recent = [
                stamp
                for stamp in self.requests[identifier]
                if now - stamp < self.seconds
            ]
            self.requests[identifier] = recent

            if len(recent) >= self.max_requests:
                # A slot frees up once the oldest in-window request ages out.
                retry_after = math.ceil(self.seconds - (now - recent[0]))
                raise HTTPException(
                    status_code=HTTPStatus.TOO_MANY_REQUESTS.value,
                    detail=HTTPStatus.TOO_MANY_REQUESTS.phrase,
                    headers={"Retry-After": str(max(1, retry_after))},
                )
            recent.append(now)