hypern 0.2.0__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. hypern/__init__.py +4 -0
  2. hypern/application.py +412 -0
  3. hypern/auth/__init__.py +0 -0
  4. hypern/auth/authorization.py +2 -0
  5. hypern/background.py +4 -0
  6. hypern/caching/__init__.py +0 -0
  7. hypern/caching/base/__init__.py +8 -0
  8. hypern/caching/base/backend.py +3 -0
  9. hypern/caching/base/key_maker.py +8 -0
  10. hypern/caching/cache_manager.py +56 -0
  11. hypern/caching/cache_tag.py +10 -0
  12. hypern/caching/custom_key_maker.py +11 -0
  13. hypern/caching/redis_backend.py +3 -0
  14. hypern/cli/__init__.py +0 -0
  15. hypern/cli/commands.py +0 -0
  16. hypern/config.py +149 -0
  17. hypern/datastructures.py +40 -0
  18. hypern/db/__init__.py +0 -0
  19. hypern/db/nosql/__init__.py +25 -0
  20. hypern/db/nosql/addons/__init__.py +4 -0
  21. hypern/db/nosql/addons/color.py +16 -0
  22. hypern/db/nosql/addons/daterange.py +30 -0
  23. hypern/db/nosql/addons/encrypted.py +53 -0
  24. hypern/db/nosql/addons/password.py +134 -0
  25. hypern/db/nosql/addons/unicode.py +10 -0
  26. hypern/db/sql/__init__.py +179 -0
  27. hypern/db/sql/addons/__init__.py +14 -0
  28. hypern/db/sql/addons/color.py +16 -0
  29. hypern/db/sql/addons/daterange.py +23 -0
  30. hypern/db/sql/addons/datetime.py +22 -0
  31. hypern/db/sql/addons/encrypted.py +58 -0
  32. hypern/db/sql/addons/password.py +171 -0
  33. hypern/db/sql/addons/ts_vector.py +46 -0
  34. hypern/db/sql/addons/unicode.py +15 -0
  35. hypern/db/sql/repository.py +290 -0
  36. hypern/enum.py +13 -0
  37. hypern/exceptions.py +97 -0
  38. hypern/hypern.cp311-win_amd64.pyd +0 -0
  39. hypern/hypern.pyi +266 -0
  40. hypern/i18n/__init__.py +0 -0
  41. hypern/logging/__init__.py +3 -0
  42. hypern/logging/logger.py +82 -0
  43. hypern/middleware/__init__.py +5 -0
  44. hypern/middleware/base.py +18 -0
  45. hypern/middleware/cors.py +38 -0
  46. hypern/middleware/i18n.py +1 -0
  47. hypern/middleware/limit.py +176 -0
  48. hypern/openapi/__init__.py +5 -0
  49. hypern/openapi/schemas.py +53 -0
  50. hypern/openapi/swagger.py +3 -0
  51. hypern/processpool.py +106 -0
  52. hypern/py.typed +0 -0
  53. hypern/response/__init__.py +3 -0
  54. hypern/response/response.py +134 -0
  55. hypern/routing/__init__.py +4 -0
  56. hypern/routing/dispatcher.py +67 -0
  57. hypern/routing/endpoint.py +30 -0
  58. hypern/routing/parser.py +100 -0
  59. hypern/routing/route.py +284 -0
  60. hypern/scheduler.py +5 -0
  61. hypern/security.py +44 -0
  62. hypern/worker.py +30 -0
  63. hypern-0.2.0.dist-info/METADATA +127 -0
  64. hypern-0.2.0.dist-info/RECORD +66 -0
  65. hypern-0.2.0.dist-info/WHEEL +4 -0
  66. hypern-0.2.0.dist-info/licenses/LICENSE +24 -0
hypern/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .application import Hypern
2
+ from .hypern import Request, Response
3
+
4
+ __all__ = ["Hypern", "Request", "Response"]
hypern/application.py ADDED
@@ -0,0 +1,412 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ import socket
6
+ from typing import Any, Callable, List, TypeVar
7
+
8
+ import orjson
9
+ from typing_extensions import Annotated, Doc
10
+
11
+ from hypern.datastructures import Contact, HTTPMethod, Info, License
12
+ from hypern.exceptions import InvalidPortNumber
13
+ from hypern.hypern import FunctionInfo, Router
14
+ from hypern.hypern import Route as InternalRoute
15
+ from hypern.logging import logger
16
+ from hypern.openapi import SchemaGenerator, SwaggerUI
17
+ from hypern.processpool import run_processes
18
+ from hypern.response import HTMLResponse, JSONResponse
19
+ from hypern.routing import Route
20
+ from hypern.scheduler import Scheduler
21
+ from hypern.middleware import Middleware
22
+
23
+ AppType = TypeVar("AppType", bound="Hypern")
24
+
25
+
26
+ class Hypern:
27
def __init__(
    self: AppType,
    routes: Annotated[
        List[Route] | None,
        Doc(
            """
            A list of routes to serve incoming HTTP and WebSocket requests.
            You can define routes using the `Route` class from `Hypern.routing`.
            **Example**
            ---
            ```python
            class DefaultRoute(HTTPEndpoint):
                async def get(self, global_dependencies):
                    return PlainTextResponse("/hello")
            Route("/test", DefaultRoute)

            # Or you can define routes using the decorator
            route = Route("/test")
            @route.get("/route")
            def def_get():
                return PlainTextResponse("Hello")
            ```
            """
        ),
    ] = None,
    title: Annotated[
        str,
        Doc(
            """
            The title of the API.

            It will be added to the generated OpenAPI (e.g. visible at `/docs`).
            """
        ),
    ] = "Hypern",
    summary: Annotated[
        str | None,
        Doc(
            """
            A short summary of the API.

            It will be added to the generated OpenAPI (e.g. visible at `/docs`).
            """
        ),
    ] = None,
    description: Annotated[
        str,
        Doc(
            """
            A description of the API. Supports Markdown (using
            [CommonMark syntax](https://commonmark.org/)).

            It will be added to the generated OpenAPI (e.g. visible at `/docs`).
            """
        ),
    ] = "",
    version: Annotated[
        str,
        Doc(
            """
            The version of the API.

            **Note** This is the version of your application, not the version of
            the OpenAPI specification nor the version of Application being used.

            It will be added to the generated OpenAPI (e.g. visible at `/docs`).
            """
        ),
    ] = "0.0.1",
    contact: Annotated[
        Contact | None,
        Doc(
            """
            A dictionary with the contact information for the exposed API.

            It can contain several fields.

            * `name`: (`str`) The name of the contact person/organization.
            * `url`: (`str`) A URL pointing to the contact information. MUST be in
                the format of a URL.
            * `email`: (`str`) The email address of the contact person/organization.
                MUST be in the format of an email address.
            """
        ),
    ] = None,
    openapi_url: Annotated[
        str | None,
        Doc(
            """
            The URL where the OpenAPI schema will be served from.

            If you set it to `None`, no OpenAPI schema will be served publicly, and
            the default automatic endpoints `/docs` and `/redoc` will also be
            disabled.
            """
        ),
    ] = "/openapi.json",
    docs_url: Annotated[
        str | None,
        Doc(
            """
            The path to the automatic interactive API documentation.
            It is handled in the browser by Swagger UI.

            The default URL is `/docs`. You can disable it by setting it to `None`.

            If `openapi_url` is set to `None`, this will be automatically disabled.
            """
        ),
    ] = "/docs",
    license_info: Annotated[
        License | None,
        Doc(
            """
            A dictionary with the license information for the exposed API.

            It can contain several fields.

            * `name`: (`str`) **REQUIRED** (if a `license_info` is set). The
                license name used for the API.
            * `identifier`: (`str`) An [SPDX](https://spdx.dev/) license expression
                for the API. The `identifier` field is mutually exclusive of the `url`
                field. Available since OpenAPI 3.1.0
            * `url`: (`str`) A URL to the license used for the API. This MUST be
                the format of a URL.

            It will be added to the generated OpenAPI (e.g. visible at `/docs`).

            **Example**

            ```python
            app = Hypern(
                license_info={
                    "name": "Apache 2.0",
                    "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
                }
            )
            ```
            """
        ),
    ] = None,
    scheduler: Annotated[
        Scheduler | None,
        Doc(
            """
            A scheduler to run background tasks.
            """
        ),
    ] = None,
    default_injectables: Annotated[
        dict[str, Any] | None,
        Doc(
            """
            A dictionary of default injectables to be passed to all routes.
            """
        ),
    ] = None,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Build the application: create the root router, register the given routes
    and, when both `openapi_url` and `docs_url` are set, add the OpenAPI schema
    and Swagger UI endpoints.

    Extra positional/keyword arguments are forwarded unchanged to
    `super().__init__`.
    """
    super().__init__(*args, **kwargs)
    self.router = Router(path="/")
    self.scheduler = scheduler
    self.injectables = default_injectables or {}
    self.middleware_before_request = []
    self.middleware_after_request = []
    self.response_headers = {}

    # `routes` defaults to None; treat that as "no routes" instead of
    # failing with "TypeError: 'NoneType' object is not iterable".
    for route in routes or []:
        # Each Route is called with the app and exposes the resulting
        # internal routes through its `.routes` attribute.
        self.router.extend_route(route(app=self).routes)

    # Documentation endpoints are only added when BOTH URLs are truthy.
    if openapi_url and docs_url:
        self.__add_openapi(
            info=Info(
                title=title,
                summary=summary,
                description=description,
                version=version,
                contact=contact,
                license=license_info,
            ),
            openapi_url=openapi_url,
            docs_url=docs_url,
        )
213
+
214
def __add_openapi(
    self,
    info: Info,
    openapi_url: str,
    docs_url: str,
):
    """
    Register the OpenAPI schema endpoint and the Swagger UI endpoint.

    Args:
        info (Info): API metadata; its `model_dump()` output becomes the `info` section of the schema.
        openapi_url (str): Path on which the JSON schema is served (GET).
        docs_url (str): Path on which the Swagger UI page is served (GET).

    Two closures are created and registered through `self.add_route`:
        - schema: builds the schema with `SchemaGenerator` and returns it as a `JSONResponse`.
        - template_render: returns the Swagger UI HTML as an `HTMLResponse`.
    """

    def schema(*args, **kwargs):
        # The base document is rebuilt on every call.
        base_document = {
            "openapi": "3.0.0",
            "info": info.model_dump(),
            "components": {"securitySchemes": {}},
        }
        generator = SchemaGenerator(base_document)
        payload = orjson.dumps(generator.get_schema(self))
        return JSONResponse(content=payload)

    def template_render(*args, **kwargs):
        ui = SwaggerUI(
            title="Swagger",
            openapi_url=openapi_url,
        )
        return HTMLResponse(ui.get_html_content())

    # Schema endpoint first, then the UI endpoint.
    for path, handler in ((openapi_url, schema), (docs_url, template_render)):
        self.add_route(HTTPMethod.GET, path, handler)
255
+
256
def add_response_header(self, key: str, value: str):
    """
    Store a header in `self.response_headers`, replacing any existing value for the same key.

    Args:
        key (str): The header field name.
        value (str): The header field value.
    """
    self.response_headers.update({key: value})
265
+
266
def before_request(self):
    """
    Return a decorator that registers a function in `self.middleware_before_request`.

    The decorated function may be synchronous or asynchronous; this is
    detected with `asyncio.iscoroutinefunction` and recorded on the
    `FunctionInfo` wrapper that is appended to the list.

    Returns:
        function: The registering decorator. It returns the decorated function unchanged.
    """

    def decorator(func):
        registration = FunctionInfo(
            handler=func,
            is_async=asyncio.iscoroutinefunction(func),
        )
        self.middleware_before_request.append(registration)
        return func

    return decorator
285
+
286
def after_request(self):
    """
    Return a decorator that registers a function in `self.middleware_after_request`.

    Both synchronous and asynchronous functions are accepted; the kind is
    detected with `asyncio.iscoroutinefunction` and stored on the
    `FunctionInfo` wrapper that is appended to the list.

    Returns:
        function: The registering decorator. It returns the decorated function unchanged.
    """

    def decorator(func):
        registration = FunctionInfo(
            handler=func,
            is_async=asyncio.iscoroutinefunction(func),
        )
        self.middleware_after_request.append(registration)
        return func

    return decorator
305
+
306
def inject(self, key: str, value: Any):
    """
    Add (or overwrite) one entry in `self.injectables`.

    Args:
        key (str): Name under which the value is stored.
        value (Any): The value to store.

    Returns:
        self: The same instance, so calls can be chained.
    """
    registry = self.injectables
    registry[key] = value
    return self
319
+
320
def add_middleware(self, middleware: Middleware):
    """
    Attach a middleware object to the application.

    The application is stored on the middleware as `middleware.app`. If the
    middleware has truthy `before_request` / `after_request` attributes, they
    are registered through `self.before_request()` / `self.after_request()`.

    Args:
        middleware (Middleware): The middleware instance to be added.

    Returns:
        self: The application instance, to allow chaining.
    """
    middleware.app = self

    # Pair each registration factory with the hook found on the middleware.
    hooks = (
        (self.before_request, getattr(middleware, "before_request", None)),
        (self.after_request, getattr(middleware, "after_request", None)),
    )
    for make_decorator, hook in hooks:
        if hook:
            make_decorator()(hook)
    return self
342
+
343
def is_port_in_use(self, port: int) -> bool:
    """
    Report whether something accepts TCP connections on `localhost:port`.

    Args:
        port (int): The port to probe.

    Returns:
        bool: True when `connect_ex` to ("localhost", port) returns 0, False otherwise.

    Raises:
        InvalidPortNumber: If the probe itself raises. The original exception
            is attached as `__cause__` so the real reason is not lost.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            return probe.connect_ex(("localhost", port)) == 0
    except Exception as exc:
        # Same exception type as before, now chained to the underlying error.
        raise InvalidPortNumber(f"Invalid port number: {port}") from exc
349
+
350
def start(
    self,
    host: Annotated[str, Doc("The host to run the server on. Defaults to `127.0.0.1`")] = "127.0.0.1",
    port: Annotated[int, Doc("The port to run the server on. Defaults to `8080`")] = 8080,
    workers: Annotated[int, Doc("The number of workers to run. Defaults to `1`")] = 1,
    processes: Annotated[int, Doc("The number of processes to run. Defaults to `1`")] = 1,
    max_blocking_threads: Annotated[int, Doc("The maximum number of blocking threads. Defaults to `1`")] = 1,
    check_port: Annotated[bool, Doc("Check if the port is already in use. Defaults to `False`")] = False,
):
    """
    Starts the server with the specified configuration.

    Args:
        host (str): The host to run the server on. Defaults to `127.0.0.1`.
        port (int): The port to run the server on. Defaults to `8080`.
        workers (int): The number of workers to run. Defaults to `1`.
        processes (int): The number of processes to run. Defaults to `1`.
        max_blocking_threads (int): The maximum number of blocking threads. Defaults to `1`.
        check_port (bool): Check if the port is already in use. Defaults to `False`.
            When enabled and the port is busy, a replacement port is requested
            interactively via `input()` until a free one is given.

    Raises:
        InvalidPortNumber: Propagated from `is_port_in_use` when probing a port fails.
        EOFError: If a replacement port is requested but standard input is exhausted.
    """
    if check_port:
        while self.is_port_in_use(port):
            logger.error("Port %s is already in use. Please use a different port.", port)
            try:
                port = int(input("Enter a different port: "))
            except ValueError:
                # Only a non-numeric answer is retried. Catching every
                # Exception here also swallowed EOFError, which made this
                # loop spin forever when stdin was closed.
                logger.error("Invalid port number. Please enter a valid port number.")

    if self.scheduler:
        self.scheduler.start()

    run_processes(
        host=host,
        port=port,
        workers=workers,
        processes=processes,
        max_blocking_threads=max_blocking_threads,
        router=self.router,
        injectables=self.injectables,
        before_request=self.middleware_before_request,
        after_request=self.middleware_after_request,
        response_headers=self.response_headers,
    )
398
+
399
def add_route(self, method: HTTPMethod, endpoint: str, handler: Callable[..., Any]):
    """
    Register a single handler on the application's router.

    Args:
        method (HTTPMethod): The HTTP method for the route; its `.name` is passed to the internal route.
        endpoint (str): The endpoint path for the route.
        handler (Callable[..., Any]): The function that handles requests to the route.
            Whether it is a coroutine function is detected automatically.
    """
    func_info = FunctionInfo(
        handler=handler,
        is_async=asyncio.iscoroutinefunction(handler),
    )
    self.router.add_route(
        route=InternalRoute(path=endpoint, function=func_info, method=method.name),
    )
File without changes
@@ -0,0 +1,2 @@
1
class Authorization:
    """Empty placeholder class; it currently defines no attributes or behavior."""
hypern/background.py ADDED
@@ -0,0 +1,4 @@
1
+ from .hypern import BackgroundTask
2
+ from .hypern import BackgroundTasks
3
+
4
+ __all__ = ["BackgroundTask", "BackgroundTasks"]
File without changes
@@ -0,0 +1,8 @@
1
+ # -*- coding: utf-8 -*-
2
+ from .backend import BaseBackend
3
+ from .key_maker import BaseKeyMaker
4
+
5
+ __all__ = [
6
+ "BaseKeyMaker",
7
+ "BaseBackend",
8
+ ]
@@ -0,0 +1,3 @@
1
+ from hypern.hypern import BaseBackend
2
+
3
+ __all__ = ["BaseBackend"]
@@ -0,0 +1,8 @@
1
+ # -*- coding: utf-8 -*-
2
+ from abc import ABC, abstractmethod
3
+ from typing import Callable
4
+
5
+
6
class BaseKeyMaker(ABC):
    """Abstract interface for objects that build cache keys."""

    @abstractmethod
    async def make(self, function: Callable, prefix: str, identify_key: str) -> str:
        """Return the cache key for `function`, given a `prefix` and an `identify_key`.

        Concrete subclasses must override this coroutine.
        """
@@ -0,0 +1,56 @@
1
+ # -*- coding: utf-8 -*-
2
+ from functools import wraps
3
+ from typing import Callable, Dict, Type
4
+
5
+ from .base import BaseBackend, BaseKeyMaker
6
+ from .cache_tag import CacheTag
7
+ import orjson
8
+
9
+
10
class CacheManager:
    """Caches results of async functions through a pluggable backend and key maker.

    `init()` must be called with a backend and a key maker before any
    decorated function is invoked or any removal method is used.
    """

    def __init__(self):
        # Both stay None until init() is called.
        self.backend = None
        self.key_maker = None

    def init(self, backend: BaseBackend, key_maker: BaseKeyMaker) -> None:
        """Set the backend used for storage and the key maker used to build keys."""
        self.backend = backend
        self.key_maker = key_maker

    def cached(self, tag: CacheTag, ttl: int = 60, identify: Dict | None = None) -> Type[Callable]:
        """Return a decorator that caches the JSON-serialised result of an async function.

        Args:
            tag: Cache tag; its `.value` is passed to the key maker as the key prefix.
            ttl: Value forwarded to `backend.set` as `ttl`.
            identify: Optional mapping of keyword-argument name -> iterable of
                attribute names. For each entry, the named keyword argument
                must be present and the listed attributes are folded into
                the cache key as `attr=value` parts joined with ":".

        Raises (when the decorated function is called):
            ValueError: If the manager was not initialised, or a keyword
                argument named in `identify` is missing (or None).
        """
        # Avoid a shared mutable default; None means "no identifying attributes".
        identify_map = identify if identify is not None else {}

        def _cached(function):
            @wraps(function)
            async def __cached(*args, **kwargs):
                if not self.backend or not self.key_maker:
                    raise ValueError("Backend or KeyMaker not initialized")

                identify_parts = []
                for key, attrs in identify_map.items():
                    target = kwargs.get(key)
                    # Check for absence explicitly: a present but falsy
                    # object must not be reported as missing.
                    if target is None:
                        raise ValueError(f"Caching: Identify key {key} not found in kwargs")
                    identify_parts.extend(f"{attr}={getattr(target, attr)}" for attr in attrs)
                identify_key = ":".join(identify_parts)

                cache_key = await self.key_maker.make(function=function, prefix=tag.value, identify_key=identify_key)

                cached_response = self.backend.get(key=cache_key)
                if cached_response:
                    return orjson.loads(cached_response)

                # Cache miss: compute, store the JSON text, return the live value.
                response = await function(*args, **kwargs)
                self.backend.set(response=orjson.dumps(response).decode("utf-8"), key=cache_key, ttl=ttl)
                return response

            return __cached

        return _cached  # type: ignore

    def _require_backend(self):
        """Return the configured backend or raise the same error `cached` uses."""
        if not self.backend:
            raise ValueError("Backend or KeyMaker not initialized")
        return self.backend

    async def remove_by_tag(self, tag: CacheTag) -> None:
        """Delete every cache entry whose key starts with `tag.value`."""
        await self._require_backend().delete_startswith(value=tag.value)

    async def remove_by_prefix(self, prefix: str) -> None:
        """Delete every cache entry whose key starts with `prefix`."""
        await self._require_backend().delete_startswith(value=prefix)
54
+
55
+
56
+ Cache = CacheManager()
@@ -0,0 +1,10 @@
1
+ # -*- coding: utf-8 -*-
2
+ from enum import Enum
3
+
4
+
5
class CacheTag(Enum):
    """Names of cache groups; each member's value is a string identifier.

    `CacheManager` passes the member's `.value` to the key maker as the key
    prefix and to the backend when removing entries by tag.
    """

    GET_HEALTH_CHECK = "get_health_check"
    GET_USER_INFO = "get_user_info"
    GET_CATEGORIES = "get_categories"
    # Note: the member name and its value differ for this entry.
    GET_HISTORY = "get_chat_history"
    GET_QUESTION = "get_question"
@@ -0,0 +1,11 @@
1
+ # -*- coding: utf-8 -*-
2
+ from typing import Callable
3
+ import inspect
4
+
5
+ from hypern.caching.base import BaseKeyMaker
6
+
7
+
8
class CustomKeyMaker(BaseKeyMaker):
    """Key maker producing keys of the form `<prefix>:<module>.<function>:<identify_key>`."""

    async def make(self, function: Callable, prefix: str, identify_key: str = "") -> str:
        """Build the cache key for `function`.

        The module name comes from `inspect.getmodule(function)`; the
        function name from `function.__name__`.
        """
        module_name = inspect.getmodule(function).__name__  # type: ignore
        function_name = function.__name__
        return str(f"{prefix}:{module_name}.{function_name}:{identify_key}")
@@ -0,0 +1,3 @@
1
+ from hypern.hypern import RedisBackend
2
+
3
+ __all__ = ["RedisBackend"]
hypern/cli/__init__.py ADDED
File without changes
hypern/cli/commands.py ADDED
File without changes
hypern/config.py ADDED
@@ -0,0 +1,149 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import typing
5
+ import warnings
6
+ from pathlib import Path
7
+
8
+ """
9
+
10
+ refer: https://github.com/encode/starlette/blob/master/starlette/config.py
11
+ # Config will be read from environment variables and/or ".env" files.
12
+ config = Config(".env")
13
+
14
+ DEBUG = config('DEBUG', cast=bool, default=False)
15
+ DATABASE_URL = config('DATABASE_URL')
16
+ ALLOWED_HOSTS = config('ALLOWED_HOSTS', cast=CommaSeparatedStrings)
17
+ """
18
+
19
+
20
class undefined:
    """Sentinel type used as the "no default supplied" marker for `Config` lookups.

    The class object itself (not an instance) is used as the marker value.
    """
22
+
23
+
24
class EnvironError(Exception):
    """Raised by `Environ` when a key is set or deleted after it has been read."""
26
+
27
+
28
class Environ(typing.MutableMapping[str, str]):
    """Mapping wrapper that forbids changing a key once it has been read.

    All operations are delegated to the wrapped mapping (`os.environ` by
    default). Every key fetched through `__getitem__` is remembered; a later
    attempt to set or delete that key raises `EnvironError`.
    """

    def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
        self._has_been_read: set[str] = set()
        self._environ = environ

    def __getitem__(self, key: str) -> str:
        # The key is recorded before the lookup, so even a failed lookup
        # marks it as read.
        self._has_been_read.add(key)
        return self._environ[key]

    def __setitem__(self, key: str, value: str) -> None:
        already_read = key in self._has_been_read
        if already_read:
            raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
        self._environ[key] = value

    def __delitem__(self, key: str) -> None:
        already_read = key in self._has_been_read
        if already_read:
            raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
        del self._environ[key]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._environ)

    def __len__(self) -> int:
        return len(self._environ)
52
+
53
+
54
+ environ = Environ()
55
+
56
+ T = typing.TypeVar("T")
57
+
58
+
59
class Config:
    """Read configuration values from an environment mapping and an optional env file.

    Lookup order (see `get`): the `environ` mapping first, then the values
    parsed from `env_file`, then the supplied default. `env_prefix` is
    prepended to every key before lookup.
    """

    def __init__(
        self,
        env_file: str | Path | None = None,
        environ: typing.Mapping[str, str] = environ,
        env_prefix: str = "",
    ) -> None:
        """Store the lookup sources; parse `env_file` if it exists.

        A missing `env_file` only triggers a warning; `file_values` then
        stays empty.
        """
        self.environ = environ
        self.env_prefix = env_prefix
        self.file_values: dict[str, str] = {}
        if env_file is not None:
            if not os.path.isfile(env_file):
                warnings.warn(f"Config file '{env_file}' not found.")
            else:
                self.file_values = self._read_file(env_file)

    @typing.overload
    def __call__(self, key: str, *, default: None) -> str | None: ...

    @typing.overload
    def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...

    @typing.overload
    def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...

    @typing.overload
    def __call__(
        self,
        key: str,
        cast: typing.Callable[[typing.Any], T] = ...,
        default: typing.Any = ...,
    ) -> T: ...

    @typing.overload
    def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...

    def __call__(
        self,
        key: str,
        cast: typing.Callable[[typing.Any], typing.Any] | None = None,
        default: typing.Any = undefined,
    ) -> typing.Any:
        """Shorthand for `get(key, cast, default)`."""
        return self.get(key, cast, default)

    def get(
        self,
        key: str,
        cast: typing.Callable[[typing.Any], typing.Any] | None = None,
        default: typing.Any = undefined,
    ) -> typing.Any:
        """Return the value for `key` (after prefixing), optionally converted by `cast`.

        Raises:
            KeyError: If the key is found nowhere and no default was given.
            ValueError: If the value cannot be converted by `cast`.
        """
        key = self.env_prefix + key
        if key in self.environ:
            return self._perform_cast(key, self.environ[key], cast)
        if key in self.file_values:
            return self._perform_cast(key, self.file_values[key], cast)
        # `undefined` (the class itself) marks "no default supplied", so that
        # None remains usable as a real default.
        if default is not undefined:
            return self._perform_cast(key, default, cast)
        raise KeyError(f"Config '{key}' is missing, and has no default.")

    def _read_file(self, file_name: str | Path) -> dict[str, str]:
        """Parse `KEY=VALUE` lines from `file_name`.

        Lines without "=" and lines starting with "#" are skipped. Keys and
        values are stripped of surrounding whitespace; values additionally
        lose surrounding single/double quote characters.
        """
        file_values: dict[str, str] = {}
        with open(file_name) as input_file:
            # Iterate the file directly instead of materialising all lines.
            for line in input_file:
                line = line.strip()
                if "=" in line and not line.startswith("#"):
                    key, value = line.split("=", 1)
                    file_values[key.strip()] = value.strip().strip("\"'")
        return file_values

    def _perform_cast(
        self,
        key: str,
        value: typing.Any,
        cast: typing.Callable[[typing.Any], typing.Any] | None = None,
    ) -> typing.Any:
        """Convert `value` with `cast`; `key` is only used in error messages.

        No conversion happens when `cast` or `value` is None. `cast is bool`
        on a string accepts only "true"/"1"/"false"/"0" (case-insensitive).

        Raises:
            ValueError: If the value is not a valid bool string, or `cast`
                raises TypeError/ValueError.
        """
        if cast is None or value is None:
            return value
        elif cast is bool and isinstance(value, str):
            mapping = {"true": True, "1": True, "false": False, "0": False}
            value = value.lower()
            if value not in mapping:
                raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
            return mapping[value]
        try:
            return cast(value)
        except (TypeError, ValueError) as exc:
            # Not every callable has __name__ (e.g. functools.partial);
            # fall back to repr() so the ValueError is still raised
            # instead of an AttributeError from building the message.
            cast_name = getattr(cast, "__name__", repr(cast))
            raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast_name}.") from exc