nx-framework 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nx_framework-0.0.3/LICENSE +21 -0
- nx_framework-0.0.3/PKG-INFO +24 -0
- nx_framework-0.0.3/README.md +2 -0
- nx_framework-0.0.3/nx/__init__.py +15 -0
- nx_framework-0.0.3/nx/__main__.py +110 -0
- nx_framework-0.0.3/nx/config/__init__.py +55 -0
- nx_framework-0.0.3/nx/config/config_model.py +82 -0
- nx_framework-0.0.3/nx/config/fields.py +79 -0
- nx_framework-0.0.3/nx/db.py +185 -0
- nx_framework-0.0.3/nx/exceptions.py +19 -0
- nx_framework-0.0.3/nx/ffmpeg/__init__.py +11 -0
- nx_framework-0.0.3/nx/ffmpeg/ffmpeg.py +184 -0
- nx_framework-0.0.3/nx/ffmpeg/ffprobe.py +32 -0
- nx_framework-0.0.3/nx/initialize.py +17 -0
- nx_framework-0.0.3/nx/logging.py +73 -0
- nx_framework-0.0.3/nx/py.typed +0 -0
- nx_framework-0.0.3/nx/redis.py +178 -0
- nx_framework-0.0.3/nx/server/__init__.py +0 -0
- nx_framework-0.0.3/nx/server/app.py +31 -0
- nx_framework-0.0.3/nx/server/bubblewrap.py +103 -0
- nx_framework-0.0.3/nx/server/gatekeeper.py +43 -0
- nx_framework-0.0.3/nx/server/lifespan.py +16 -0
- nx_framework-0.0.3/nx/utils/__init__.py +20 -0
- nx_framework-0.0.3/nx/utils/coalesce.py +86 -0
- nx_framework-0.0.3/nx/utils/utils.py +45 -0
- nx_framework-0.0.3/nx/version.py +1 -0
- nx_framework-0.0.3/nx_framework.egg-info/PKG-INFO +24 -0
- nx_framework-0.0.3/nx_framework.egg-info/SOURCES.txt +31 -0
- nx_framework-0.0.3/nx_framework.egg-info/dependency_links.txt +1 -0
- nx_framework-0.0.3/nx_framework.egg-info/requires.txt +11 -0
- nx_framework-0.0.3/nx_framework.egg-info/top_level.txt +1 -0
- nx_framework-0.0.3/pyproject.toml +81 -0
- nx_framework-0.0.3/setup.cfg +4 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Martin Wacker
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: nx-framework
|
|
3
|
+
Version: 0.0.3
|
|
4
|
+
Summary: Low-level FastAPI, asyncpg and Redis framework used by Nebula Broadcast and other projects.
|
|
5
|
+
Author-email: Martin Wacker <martas@imm.cz>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: aiofiles>=24.1.0
|
|
11
|
+
Requires-Dist: aioshutil>=1.5
|
|
12
|
+
Requires-Dist: asyncpg>=0.30.0
|
|
13
|
+
Requires-Dist: fastapi>=0.115.8
|
|
14
|
+
Requires-Dist: loguru>=0.7.3
|
|
15
|
+
Requires-Dist: pydantic-settings>=2.7.1
|
|
16
|
+
Requires-Dist: pydantic[email]>=2.10.6
|
|
17
|
+
Requires-Dist: python-dotenv>=1.0.1
|
|
18
|
+
Requires-Dist: redis>=5.2.1
|
|
19
|
+
Requires-Dist: shortuuid>=1.0.13
|
|
20
|
+
Requires-Dist: unidecode>=1.3.8
|
|
21
|
+
Dynamic: license-file
|
|
22
|
+
|
|
23
|
+
# nx
|
|
24
|
+
Very opinionated set of tools
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
__all__ = [
|
|
2
|
+
"coalesce",
|
|
3
|
+
"config",
|
|
4
|
+
"db",
|
|
5
|
+
"initialize",
|
|
6
|
+
"log",
|
|
7
|
+
"redis",
|
|
8
|
+
]
|
|
9
|
+
|
|
10
|
+
from nx.config import config
|
|
11
|
+
from nx.db import db
|
|
12
|
+
from nx.initialize import initialize
|
|
13
|
+
from nx.logging import logger as log
|
|
14
|
+
from nx.redis import redis
|
|
15
|
+
from nx.utils import coalesce
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""nx cli
|
|
2
|
+
======
|
|
3
|
+
|
|
4
|
+
This entrypoint is used to run nx in a development environment.
|
|
5
|
+
It is used to test various features of nx and to run the vanilla server
|
|
6
|
+
without any modifications.
|
|
7
|
+
|
|
8
|
+
It is not intended to be used in production.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import os
|
|
13
|
+
import signal
|
|
14
|
+
import subprocess
|
|
15
|
+
import sys
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import nx
|
|
19
|
+
from nx.version import __version__
|
|
20
|
+
|
|
21
|
+
# Initialize nx as soon as the CLI module is imported, before any command runs.
# NOTE(review): what "standalone" changes is defined in nx.initialize, which is
# not visible here — confirm there.
nx.initialize(standalone=True)


# Location of the gunicorn PID file: serve() passes it to gunicorn via --pid,
# and reload() reads it back to find the process to send SIGHUP to.
GUNICORN_PID_FILE = "/tmp/gunicorn.pid"  # noqa: S108
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def version() -> None:
    """Print the nx package version to stdout with no trailing newline."""
    version_text = str(__version__)
    print(version_text, end="")  # noqa: T201
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def run(*args: Any) -> None:
    """Log the requested command arguments.

    Currently a placeholder: the arguments are only logged, nothing is executed.
    """
    message = "Running command, " + str(args)
    nx.log.info(message)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def serve() -> None:
    """Run the development server under gunicorn and block until it exits.

    Gunicorn is started with the uvicorn worker class and ``--reload``, serving
    ``nx.server.app:app`` on ``nx.config.server_port``. Its PID is written to
    ``GUNICORN_PID_FILE`` so that ``reload()`` can signal it later.

    SIGTERM and SIGINT received by this process are logged and forwarded to the
    gunicorn child; the function returns once gunicorn has terminated.
    """
    cmd = [
        "gunicorn",
        "--bind",
        f":{nx.config.server_port}",
        "--reload",
        "--worker-class",
        "uvicorn_worker.UvicornWorker",
        "--max-requests",
        "1000",
        "--log-level",
        "warning",
        "--pid",
        GUNICORN_PID_FILE,
        "nx.server.app:app",
    ]

    process = subprocess.Popen(cmd)  # noqa: S603

    def forward_signal(signum, frame) -> None:  # type: ignore[no-untyped-def]
        """Log the received signal and pass the same signal on to gunicorn."""
        _ = frame
        nx.log.warning(f"Received {signal.Signals(signum).name}")
        # Popen.send_signal does nothing once the child has exited, whereas a
        # bare os.kill() would raise ProcessLookupError inside this handler.
        process.send_signal(signum)

    signal.signal(signal.SIGTERM, forward_signal)
    signal.signal(signal.SIGINT, forward_signal)

    process.wait()
    nx.log.info("Gunicorn process terminated.")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def reload() -> None:
    """Reload the server by sending SIGHUP to the gunicorn master.

    The PID is read from ``GUNICORN_PID_FILE`` (written by ``serve()``). Every
    failure (missing file, malformed or non-positive PID, process gone, or no
    permission to signal it) is logged as an error rather than raised.
    """
    # EAFP: opening directly avoids the race between an exists() check and open().
    try:
        with open(GUNICORN_PID_FILE) as f:
            content = f.read()
    except FileNotFoundError:
        nx.log.error("Gunicorn PID file not found.")
        return

    try:
        gunicorn_pid = int(content.strip())
    except ValueError:
        nx.log.error("Invalid PID in Gunicorn PID file.")
        return

    # os.kill() interprets 0 and negative PIDs as process groups (-1 means
    # "every process we may signal"), so never pass such values through.
    if gunicorn_pid <= 0:
        nx.log.error("Invalid PID in Gunicorn PID file.")
        return

    try:
        os.kill(gunicorn_pid, signal.SIGHUP)
    except ProcessLookupError:
        nx.log.error("Gunicorn process not found.")
        return
    except PermissionError:
        nx.log.error(f"Not permitted to signal Gunicorn (PID: {gunicorn_pid}).")
        return

    nx.log.info(f"Sent SIGHUP to Gunicorn (PID: {gunicorn_pid}).")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
async def debug() -> None:
    """Print the explicitly set config values, then all rows of the config table."""
    config_json = nx.config.model_dump_json(indent=2, exclude_unset=True)
    print(config_json)  # noqa: T201

    rows = await nx.db.fetch("SELECT * FROM config")
    print(rows)  # noqa: T201
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
if __name__ == "__main__":
    # Usage: python -m nx <command> [args...]
    # Dispatch on the first argument only, so that arguments forwarded to
    # `run` can never be mistaken for a command name.
    command = sys.argv[1] if len(sys.argv) > 1 else None

    if command == "version":
        version()
    elif command == "run":
        # Forward the remaining arguments individually, matching run(*args).
        run(*sys.argv[2:])
    elif command == "serve":
        serve()
    elif command == "reload":
        reload()
    elif command == "debug":
        asyncio.run(debug())
    else:
        nx.log.error(
            "Invalid command. Use 'version', 'run', 'serve', 'reload' or 'debug'."
        )
        sys.exit(1)
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
__all__ = [
|
|
2
|
+
"ConfigModel",
|
|
3
|
+
"ConfigProxy",
|
|
4
|
+
"config",
|
|
5
|
+
]
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from typing import Any, Generic, TypeVar, cast
|
|
9
|
+
|
|
10
|
+
from dotenv import load_dotenv
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
|
|
13
|
+
from .config_model import ConfigModel
|
|
14
|
+
|
|
15
|
+
# Type variable for the pydantic model class wrapped by ConfigProxy.
T = TypeVar("T", bound=BaseModel)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ConfigProxy(Generic[T]):
    """Process-wide singleton that exposes a validated configuration model.

    ``initialize()`` loads a ``.env`` file (via ``load_dotenv``), collects the
    matching environment variables and builds the given pydantic model from
    them. Attribute access on the proxy is then delegated to that model.
    """

    _instance: "ConfigProxy[T] | None" = None
    _config_model: type[T]
    _fields: set[str]
    _env_prefix: str
    _config: BaseModel | None = None

    def __new__(cls, *args: Any, **kwargs: Any) -> "ConfigProxy[T]":
        # Always return the same object so every importer shares one config.
        _ = args, kwargs
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def initialize(self, config_model: type[T], env_prefix: str) -> None:
        """Build the configuration from ``<ENV_PREFIX>_<FIELD>`` variables.

        Prefix and field names are matched case-insensitively. Variables whose
        suffix is not a field of ``config_model`` are ignored. Values are passed
        to the model as strings, so validation and coercion are up to the model.

        Args:
            config_model: Pydantic model class describing the configuration.
            env_prefix: Prefix, without the trailing underscore, that selects
                which environment variables are read.
        """
        self._config_model = config_model
        self._fields = set(config_model.model_fields)
        self._env_prefix = env_prefix

        full_env_prefix = f"{env_prefix}_".lower()
        load_dotenv()

        env_data: dict[str, str] = {}
        for key, value in dict(os.environ).items():
            lowered = key.lower()
            if not lowered.startswith(full_env_prefix):
                continue

            field_name = lowered.removeprefix(full_env_prefix)
            if field_name in self._fields:
                env_data[field_name] = value

        self._config = self._config_model(**env_data)

    def __getattr__(self, key: str) -> Any:
        # Invoked only for names not found on the proxy itself, i.e. model fields
        # and model methods. Compare against None explicitly instead of relying
        # on the truthiness of the model instance.
        if self._config is None:
            raise AttributeError("Config not initialized. Call initialize() first.")
        return getattr(self._config, key)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# The shared ConfigProxy singleton. It holds no configuration until
# initialize() has been called on it.
_config_proxy = ConfigProxy()  # type: ignore[var-annotated]
# The same proxy object, typed as ConfigModel for static checkers. cast() has
# no runtime effect; attribute lookups are forwarded by ConfigProxy.__getattr__.
config = cast("ConfigModel", _config_proxy)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from typing import Any, Self, cast
|
|
2
|
+
from urllib.parse import urlparse
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, PostgresDsn, RedisDsn, field_validator, model_validator
|
|
5
|
+
|
|
6
|
+
from .fields import (
|
|
7
|
+
LogLevel,
|
|
8
|
+
LogMode,
|
|
9
|
+
PostgresHost,
|
|
10
|
+
PostgresName,
|
|
11
|
+
PostgresPassword,
|
|
12
|
+
PostgresPort,
|
|
13
|
+
PostgresUser,
|
|
14
|
+
ServerPort,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ConfigModel(BaseModel):
    """Base settings model for nx applications.

    Postgres connection settings can be given as a single ``postgres_url`` and
    overridden piecewise through the ``postgres_*`` fields. After validation,
    the URL and the individual fields always describe the same connection.
    """

    log_level: LogLevel = "DEBUG"
    log_mode: LogMode = "text"
    log_context: bool = True
    server_host: str = "0.0.0.0"
    server_port: ServerPort = 8765
    postgres_url: PostgresDsn = PostgresDsn("postgresql://nx:nx@postgres:5432/nx")
    redis_url: RedisDsn = RedisDsn("redis://redis")

    # Database connection overrides.
    # When set, the following fields override the corresponding parts of the
    # connection settings provided by POSTGRES_URL.

    postgres_host: PostgresHost = None
    postgres_port: PostgresPort = None
    postgres_name: PostgresName = None
    postgres_user: PostgresUser = None
    postgres_password: PostgresPassword = None

    @field_validator("log_level", mode="before")
    @classmethod
    def validate_log_level(cls, v: Any) -> LogLevel:
        """Accept log level names in any case by upper-casing them.

        Raises:
            ValueError: If the value is not a string. Pydantic reports it as a
                validation error.
        """
        # Raise explicitly instead of using ``assert``: asserts are stripped
        # under ``python -O``, which would let non-strings reach ``.upper()``
        # and escape pydantic's error handling as a raw AttributeError.
        if not isinstance(v, str):
            raise ValueError("Log level must be a string")
        return cast("LogLevel", v.upper())

    @model_validator(mode="after")
    def construct_final_postgres_url(self) -> Self:
        """Synchronize the postgres_url with the individual fields.

        Each ``postgres_*`` field that is set overrides the matching part of
        ``postgres_url``. The port falls back to 5432 when neither source
        provides one. The URL is then rebuilt, and all fields are re-populated
        from it.
        """
        parsed = urlparse(str(self.postgres_url))

        # Extract the relevant components, preferring explicit overrides.
        user = parsed.username if self.postgres_user is None else self.postgres_user
        password = (
            parsed.password
            if self.postgres_password is None
            else self.postgres_password
        )
        host = parsed.hostname if self.postgres_host is None else self.postgres_host
        if self.postgres_port is None:
            port = parsed.port or 5432
        else:
            port = self.postgres_port
        database = parsed.path[1:] if self.postgres_name is None else self.postgres_name

        # Rebuild the URL with the overrides applied.
        self.postgres_url = PostgresDsn.build(
            scheme="postgresql",
            username=user,
            password=password,
            host=host,
            port=port,
            path=database,
        )

        # Populate the fields again from the rebuilt URL so all values are in sync.
        parsed = urlparse(str(self.postgres_url))
        self.postgres_host = parsed.hostname
        self.postgres_port = parsed.port or 5432
        self.postgres_name = parsed.path[1:]
        self.postgres_user = parsed.username
        self.postgres_password = parsed.password

        return self

    def initialize(self, **kwargs: Any) -> None:
        """No-op hook that accepts and ignores arbitrary keyword arguments."""
        # NOTE(review): presumably intended to be overridden by subclasses — confirm.
        _ = kwargs
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from typing import Annotated, Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
|
|
5
|
+
# Annotated field types used by ConfigModel. Each alias bundles the value type
# with pydantic ``Field`` metadata (title, description, examples and, for the
# port types, numeric bounds), so the metadata applies wherever the alias is used.

# Output format of the logger.
LogMode = Annotated[
    Literal["text", "json"],
    Field(
        title="Log mode",
        description="The log mode for the server",
        examples=["text"],
    ),
]

# Log level name; only these upper-case spellings are valid for this type.
# NOTE(review): TRACE is not a stdlib logging level; presumably it refers to the
# loguru level of the same name — confirm against nx.logging.
LogLevel = Annotated[
    Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "TRACE"],
    Field(
        title="Log Level",
        description="The log level for the server",
        examples=["INFO"],
    ),
]

# Port the server listens on, validated to the inclusive range 0-65535.
ServerPort = Annotated[
    int,
    Field(
        title="Port",
        description="The port the server will listen on",
        examples=[8765],
        ge=0,
        le=65535,
    ),
]

# The Postgres* aliases below are optional (None is allowed) overrides for
# individual parts of the POSTGRES_URL connection string.

PostgresHost = Annotated[
    str | None,
    Field(
        title="Postgres host",
        description="Override the default Postgres host",
        examples=["localhost", "postgres"],
    ),
]

# Validated to the inclusive range 0-65535 when not None.
PostgresPort = Annotated[
    int | None,
    Field(
        title="Postgres port",
        description="Override the default Postgres port provided by POSTGRES_URL",
        examples=[5432],
        ge=0,
        le=65535,
    ),
]

PostgresName = Annotated[
    str | None,
    Field(
        title="Postgres database name",
        description="Override the default Postgres database name from POSTGRES_URL",
        examples=["nx"],
    ),
]

PostgresUser = Annotated[
    str | None,
    Field(
        title="Postgres user",
        description="Override the default Postgres user from POSTGRES_URL",
        examples=["nx"],
    ),
]

PostgresPassword = Annotated[
    str | None,
    Field(
        title="Postgres password",
        description="Override the default Postgres password from POSTGRES_URL",
        examples=["nx"],
    ),
]
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
__all__ = ["db"]
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import sys
|
|
5
|
+
import uuid
|
|
6
|
+
from collections.abc import AsyncGenerator, AsyncIterator
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from contextvars import ContextVar
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
11
|
+
import asyncpg
|
|
12
|
+
|
|
13
|
+
from nx.config import config
|
|
14
|
+
from nx.logging import logger
|
|
15
|
+
from nx.utils import json_dumps, json_loads, normalize_uuid
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from asyncpg.pool import PoolConnectionProxy
|
|
19
|
+
from asyncpg.prepared_stmt import PreparedStatement
|
|
20
|
+
|
|
21
|
+
# Connection bound to the current context by DB.acquire(). While it is set,
# nested DB calls in the same context reuse it instead of taking another pool
# connection, and DB.acquire() resets it when its block exits.
_current_connection: ContextVar["PoolConnectionProxy | None"] = ContextVar(  # type: ignore[type-arg]
    "_current_connection", default=None
)

# Serializes pool creation in DB._connect(), so that concurrent first callers
# end up sharing a single pool.
_connection_lock = asyncio.Lock()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DB:
    """Process-wide singleton wrapping a lazily created asyncpg connection pool.

    ``acquire()`` binds the connection it hands out to the ``_current_connection``
    context variable. Nested ``db`` calls in the same context therefore reuse
    that connection, including any transaction open on it.
    """

    _instance: "DB | None" = None
    _pool: asyncpg.pool.Pool | None = None  # type: ignore[type-arg]

    def __new__(cls, *args: Any, **kwargs: Any) -> "DB":
        # Always return the same instance so the whole process shares one pool.
        _ = args, kwargs
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    async def _init_connection(self, conn) -> None:  # type: ignore[no-untyped-def]
        """Register the JSON and UUID codecs on every new pool connection."""
        # jsonb <-> Python values via the project's JSON helpers.
        await conn.set_type_codec(
            "jsonb",
            encoder=json_dumps,
            decoder=json_loads,
            schema="pg_catalog",
        )
        # uuid values pass through normalize_uuid(x, True) in both directions.
        # Registering one uuid codec is enough: a second registration for the
        # same type would replace the first.
        await conn.set_type_codec(
            "uuid",
            encoder=lambda x: normalize_uuid(x, True),
            decoder=lambda x: normalize_uuid(x, True),
            schema="pg_catalog",
        )

    @staticmethod
    def _redacted_postgres_url() -> str:
        """Return the configured Postgres URL with its password masked as ``***``."""
        from urllib.parse import urlsplit, urlunsplit

        url = str(config.postgres_url)
        parts = urlsplit(url)
        if parts.password is None:
            return url
        userinfo, _, hostinfo = parts.netloc.rpartition("@")
        username = userinfo.partition(":")[0]
        return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))

    async def _connect(self) -> None:
        """Create the Postgres connection pool once; exit the process on failure."""
        async with _connection_lock:
            if self._pool is not None:
                return  # Another task created the pool while this one waited.

            try:
                self._pool = await asyncpg.create_pool(
                    str(config.postgres_url),
                    init=self._init_connection,
                )
            except Exception as e:
                logger.error(f"Failed to connect to the database: {e}")
                # Never log the raw DSN: it embeds the database password.
                logger.error(
                    f"Unrecoverable error while "
                    f"connecting to '{self._redacted_postgres_url()}'. Exiting."
                )
                sys.exit(1)

    @asynccontextmanager
    async def acquire(
        self,
        *,
        timeout: int | None = None,  # noqa: ASYNC109
        force_new: bool = False,
    ) -> AsyncIterator["PoolConnectionProxy"]:  # type: ignore[type-arg]
        """Resolve the current connection from the contextvar or acquire a new one.

        A connection already bound to the current context is yielded as-is
        (unless ``force_new``). Otherwise a connection is taken from the pool,
        bound to the context for the duration of the block, then released.

        Args:
            timeout: How long to wait for a free pool connection
                (passed to asyncpg's ``Pool.acquire``).
            force_new: Take a fresh pool connection even if one is already bound.

        Raises:
            ConnectionError: If no pool connection became available in time.
        """
        conn = _current_connection.get()
        if not force_new and conn is not None:
            yield conn
            return

        if self._pool is None:
            await self._connect()
        assert self._pool is not None, "Database pool is not initialized"

        try:
            connection_proxy = await self._pool.acquire(timeout=timeout)
        except TimeoutError as e:
            raise ConnectionError("Database pool timeout") from e

        token = _current_connection.set(connection_proxy)

        try:
            yield connection_proxy
        finally:
            _current_connection.reset(token)
            await self._pool.release(connection_proxy)

    @asynccontextmanager
    async def transaction(
        self,
        timeout: int | None = None,  # noqa: ASYNC109
        force_new: bool = False,
    ) -> AsyncIterator["PoolConnectionProxy"]:  # type: ignore[type-arg]
        """Acquire a connection from the pool and manage transaction state.

        If the connection is already inside a transaction, that transaction is
        joined rather than nested. Otherwise a new one is opened and closed
        with the block.
        """
        async with self.acquire(timeout=timeout, force_new=force_new) as connection:
            if connection.is_in_transaction():
                yield connection
            else:
                async with connection.transaction():
                    yield connection

    @property
    def is_in_transaction(self) -> bool:
        """Check if the current connection is in a transaction."""
        conn = _current_connection.get()
        if conn is None:
            return False
        return conn.is_in_transaction()

    async def execute(self, query: str, *args: Any) -> str:
        """Execute a query and return the status."""
        async with self.acquire() as conn:
            return await conn.execute(query, *args)

    async def executemany(self, query: str, *args: Any) -> None:
        """Execute a query once per argument set; nothing is returned.

        ``args`` is forwarded unchanged to asyncpg's ``executemany``, so callers
        pass the iterable of argument sets as the single extra argument.
        """
        async with self.acquire() as conn:
            await conn.executemany(query, *args)

    async def prepare(self, query: str, *args: Any) -> "PreparedStatement":  # type: ignore[type-arg]
        """Prepare a statement on the current connection.

        Raises:
            RuntimeError: If the connection is not inside a transaction.
        """
        async with self.acquire() as conn:
            if not conn.is_in_transaction():
                raise RuntimeError("Transaction not started")
            return await conn.prepare(query, *args)

    async def fetch(self, query: str, *args: Any) -> list[asyncpg.Record]:
        """Fetch a query and return the result."""
        async with self.acquire() as conn:
            return await conn.fetch(query, *args)

    async def fetchrow(self, query: str, *args: Any) -> asyncpg.Record | None:
        """Fetch a query and return the first result."""
        async with self.acquire() as conn:
            return await conn.fetchrow(query, *args)

    async def iterate(
        self,
        query: str,
        *args: Any,
        timeout: int | None = None,  # noqa: ASYNC109
    ) -> AsyncGenerator[asyncpg.Record]:
        """Run a query and yield its rows one by one through a server-side cursor.

        If a connection is already bound to the current context (e.g. inside
        ``db.transaction()``), it is reused, so the query sees that transaction's
        state and ``timeout`` is not used. Otherwise a connection is taken from
        the pool for the duration of the iteration and released afterwards.
        Because asyncpg cursors need a transaction, one is opened when the
        connection is not already in one.
        """
        conn = _current_connection.get()
        # The borrowed context connection is not bound here on purpose: setting
        # a ContextVar across yields of an async generator is not reliably
        # resettable. Only connections acquired here are released here.
        owned = conn is None
        if owned:
            if self._pool is None:
                await self._connect()
            assert self._pool is not None, "Database pool is not initialized"
            conn = await self._pool.acquire(timeout=timeout)

        try:
            if conn.is_in_transaction():
                statement = await conn.prepare(query)
                async for record in statement.cursor(*args):
                    yield record
            else:
                async with conn.transaction():
                    statement = await conn.prepare(query)
                    async for record in statement.cursor(*args):
                        yield record
        finally:
            if owned:
                await self._pool.release(conn)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# Shared DB instance for the package; DB.__new__ guarantees there is only one.
db = DB()
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
class BaseNXError(Exception):
|
|
2
|
+
"""Base class for all exceptions raised by the NX library."""
|
|
3
|
+
|
|
4
|
+
status = 500
|
|
5
|
+
|
|
6
|
+
def __init__(self, detail: str | None = None) -> None:
|
|
7
|
+
super().__init__(detail)
|
|
8
|
+
if detail is not None:
|
|
9
|
+
self.detail = detail
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class NotFoundError(BaseNXError):
    """Raised when a requested resource does not exist (HTTP 404)."""

    detail = "Not Found"
    status = 404
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UnauthorizedError(BaseNXError):
    """Raised when a request lacks valid authentication (HTTP 401)."""

    detail = "Unauthorized"
    status = 401
|