appkit-commons 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of appkit-commons might be problematic. Click here for more details.

@@ -0,0 +1,52 @@
1
+ import contextlib
2
+ from collections.abc import AsyncIterator, Iterator
3
+ from contextlib import contextmanager
4
+ from typing import Any
5
+
6
+ from sqlalchemy import create_engine
7
+ from sqlalchemy.ext.asyncio import (
8
+ AsyncSession,
9
+ async_sessionmaker,
10
+ create_async_engine,
11
+ )
12
+ from sqlalchemy.orm import Session, sessionmaker
13
+
14
+
15
class AsyncSessionManager:
    """Owns an async SQLAlchemy engine and hands out transactional sessions.

    Typical use is ``async with manager.session() as s: ...``: the transaction is
    committed when the block finishes normally and rolled back when an
    ``Exception`` escapes it. Call :meth:`close` on shutdown.
    """

    def __init__(self, host: str, engine_kwargs: dict[str, Any] | None = None):
        """Build the async engine and a session factory bound to it.

        :param host: passed as the URL argument of ``create_async_engine``
            (i.e. a database URL, despite the parameter name).
        :param engine_kwargs: optional extra keyword arguments forwarded to
            ``create_async_engine``; ``None`` means no extras.
        """
        options = engine_kwargs or {}
        self._engine = create_async_engine(host, **options)
        self._sessionmaker = async_sessionmaker(bind=self._engine)

    async def close(self) -> None:
        """Dispose of the engine created in ``__init__``."""
        if not self._engine:
            return
        await self._engine.dispose()

    @contextlib.asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        """Yield a fresh ``AsyncSession`` wrapped in a commit/rollback guard.

        The commit runs after the caller's block completes. If the caller's block
        or the commit itself raises an ``Exception``, the session is rolled back
        and the exception is re-raised unchanged.
        """
        async with self._sessionmaker() as db_session:
            try:
                yield db_session
                await db_session.commit()
            except Exception:
                await db_session.rollback()
                raise
33
+
34
+
35
class SessionManager:
    """Owns a synchronous SQLAlchemy engine and hands out transactional sessions.

    Typical use is ``with manager.session() as s: ...``: the transaction is
    committed when the block finishes normally and rolled back when an
    ``Exception`` escapes it. Call :meth:`close` on shutdown.
    """

    def __init__(self, host: str, engine_kwargs: dict[str, Any] | None = None):
        """Build the engine and a session factory bound to it.

        :param host: passed as the URL argument of ``create_engine``
            (i.e. a database URL, despite the parameter name).
        :param engine_kwargs: optional extra keyword arguments forwarded to
            ``create_engine``; ``None`` means no extras.
        """
        options = engine_kwargs or {}
        self._engine = create_engine(host, **options)
        self._sessionmaker = sessionmaker(bind=self._engine)

    def close(self) -> None:
        """Dispose of the engine created in ``__init__``."""
        if not self._engine:
            return
        self._engine.dispose()

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a fresh ``Session`` wrapped in a commit/rollback guard.

        The commit runs after the caller's block completes. If the caller's block
        or the commit itself raises an ``Exception``, the session is rolled back
        and the exception is re-raised unchanged.
        """
        with self._sessionmaker() as db_session:
            try:
                yield db_session
                db_session.commit()
            except Exception:
                db_session.rollback()
                raise
@@ -0,0 +1,19 @@
1
+ from starlette.types import ASGIApp, Receive, Scope, Send
2
+
3
+
4
# Trust the X-Forwarded-Proto header set by a TLS-terminating reverse proxy
# (e.g. Azure). Despite the class name, nothing is redirected: when the proxy
# reports that the client used HTTPS, the ASGI scope's ``scheme`` is rewritten so
# the downstream app sees https:// (or wss:// for websockets) instead of the
# plain scheme of the proxy-to-app hop.
class ForceHTTPSMiddleware:
    """Pure ASGI middleware that honours ``X-Forwarded-Proto: https``.

    ``http`` scopes get ``scheme = "https"`` and ``websocket`` scopes get
    ``scheme = "wss"`` (ASGI websocket scopes use ``ws``/``wss``). Other scope
    types, and requests that do not carry the header, pass through untouched.

    NOTE(review): the header is trusted unconditionally. This is only safe behind
    a proxy that overwrites X-Forwarded-Proto; any client reaching the app
    directly can set it.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        scope_type = scope["type"]
        if scope_type in ("http", "websocket") and self._forwarded_https(scope):
            secure_scheme = "https" if scope_type == "http" else "wss"
            # Copy instead of mutating in place so the change does not leak back
            # to outer layers that still hold the original scope (ASGI spec).
            scope = {**scope, "scheme": secure_scheme}
        await self.app(scope, receive, send)

    @staticmethod
    def _forwarded_https(scope: Scope) -> bool:
        """Return True if the ``X-Forwarded-Proto`` header reports HTTPS.

        Scope headers are ``(name, value)`` byte pairs. If the header repeats, the
        last occurrence wins. A comma-separated value (several chained proxies)
        is judged by its first, client-facing entry. The comparison ignores case
        and surrounding whitespace.
        """
        headers = dict(scope.get("headers", ()))
        forwarded = headers.get(b"x-forwarded-proto")
        if forwarded is None:
            return False
        first_hop = forwarded.split(b",", 1)[0].strip().lower()
        return first_hop == b"https"
@@ -0,0 +1,188 @@
1
+ import logging
2
+ from functools import lru_cache
3
+ from typing import TYPE_CHECKING, Any, TypeVar, cast
4
+
5
+ if TYPE_CHECKING:
6
+ from appkit_commons.configuration.configuration import (
7
+ ApplicationConfig,
8
+ Configuration,
9
+ )
10
+
11
# Module logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)

# Generic placeholder for "the class type being registered / retrieved".
T = TypeVar("T")
# Application configuration type. The bound is a string forward reference
# because ApplicationConfig is only imported under TYPE_CHECKING.
ConfigT = TypeVar("ConfigT", bound="ApplicationConfig")
15
+
16
+
17
class ServiceRegistry:
    """Registry for storing and retrieving initialized instances by their class type.

    Each class type maps to at most one instance; registering again under the same
    type replaces the previous instance and logs a warning.
    """

    # Class names whose instances are never registered as service configurations.
    # They are matched against the whole MRO so subclasses are skipped too: an
    # enum member's class is the concrete ``StrEnum`` subclass (never ``StrEnum``
    # itself), so checking only the exact class name would never fire.
    _SKIPPED_CLASS_NAMES = ("SecretStr", "StrEnum")

    def __init__(self) -> None:
        self._instances: dict[type[Any], Any] = {}

    def _is_config_candidate(self, value: Any) -> bool:
        """Return True if *value* should be registered and searched recursively.

        Instances of built-in types, and of classes named in
        ``_SKIPPED_CLASS_NAMES`` or subclasses of them, are rejected.
        """
        value_class = value.__class__
        if value_class.__module__ == "builtins":
            return False
        return not any(
            base.__name__ in self._SKIPPED_CLASS_NAMES for base in value_class.__mro__
        )

    def _register_attribute(
        self,
        attr_name: str,
        attr_value: Any,
        visited: set[int],
        *,
        annotated: bool,
    ) -> None:
        """Register *attr_value* under its class (first one wins) and recurse into it.

        :param attr_name: attribute name, used only for logging.
        :param attr_value: the (non-None) attribute value.
        :param visited: ids of objects already inspected, shared across recursion.
        :param annotated: True when the attribute was found via class annotations
            (only selects the debug log message).
        """
        if not self._is_config_candidate(attr_value):
            return

        attr_class = attr_value.__class__
        if not self.has(attr_class):
            self.register_as(attr_class, attr_value)
            if annotated:
                logger.debug(
                    "Registered service configuration: %s from annotated attribute %s",  # noqa: E501
                    attr_class.__name__,
                    attr_name,
                )
            else:
                logger.debug(
                    "Registered service configuration: %s from attribute %s",  # noqa: E501
                    attr_class.__name__,
                    attr_name,
                )

        # Recurse even when this type was registered earlier: this instance may
        # still carry nested configuration objects. ``visited`` breaks cycles.
        self._register_config_recursively(attr_value, visited)

    def _register_config_recursively(
        self, obj: Any, visited: set[int] | None = None
    ) -> None:
        """Recursively register configuration objects reachable from *obj*.

        Two passes are made over *obj*. The first covers its instance ``__dict__``.
        The second covers the names annotated on its class, which also reaches
        values exposed through properties/descriptors. Public, non-None values
        accepted by :meth:`_is_config_candidate` are registered under their class;
        the first instance found for a type wins. Each is then searched in turn.
        Failures on individual attributes are logged and skipped.

        :param obj: the object to inspect.
        :param visited: ids of objects already inspected; guards against cycles.
        """
        if visited is None:
            visited = set()

        obj_id = id(obj)
        if obj_id in visited:
            return
        visited.add(obj_id)

        # Pass 1: instance attributes.
        if hasattr(obj, "__dict__"):
            for attr_name, attr_value in obj.__dict__.items():
                if attr_name.startswith("_") or attr_value is None:
                    continue
                try:
                    self._register_attribute(
                        attr_name, attr_value, visited, annotated=False
                    )
                except Exception as e:
                    logger.warning(
                        "Failed to process attribute %s: %s", attr_name, str(e)
                    )

        # Pass 2: class annotations, to handle properties/descriptors.
        if hasattr(obj.__class__, "__annotations__"):
            for attr_name in obj.__class__.__annotations__:
                if attr_name.startswith("_"):
                    continue
                try:
                    attr_value = getattr(obj, attr_name, None)
                    if attr_value is not None:
                        self._register_attribute(
                            attr_name, attr_value, visited, annotated=True
                        )
                except Exception as e:
                    logger.warning(
                        "Failed to access annotated attribute %s: %s", attr_name, str(e)
                    )

    def configure(
        self, app_config_class: type[ConfigT], env_file: str = ".env"
    ) -> "Configuration[ConfigT]":
        """Create the application configuration and register it and its parts.

        Instantiates ``Configuration[app_config_class]``, passing ``env_file`` as
        ``_env_file``. The result is registered under ``Configuration``, and its
        attributes are then walked so that nested configuration objects can be
        retrieved by their own class via :meth:`get`.

        :param app_config_class: the application's configuration class.
        :param env_file: value passed to the configuration as ``_env_file``.
        :return: the created configuration instance.
        """
        # Imported here rather than at module level (the module-level import is
        # TYPE_CHECKING-only), presumably to avoid an import cycle.
        from appkit_commons.configuration.configuration import (  # noqa: PLC0415
            Configuration,
        )

        logger.debug(
            "Configuring application with config class: %s", app_config_class.__name__
        )

        configuration = Configuration[app_config_class](_env_file=env_file)

        self.register_as(Configuration, configuration)
        self._register_config_recursively(configuration)

        logger.info("Application configuration initialized and registered")
        logger.info("Total registered instances: %d", len(self._instances))
        for registered_type in self.list_registered():
            logger.debug("Registered: %s", registered_type.__name__)

        return configuration

    def register(self, instance: object) -> None:
        """Register *instance* under its exact runtime class, ``type(instance)``.

        An existing entry for that class is overwritten (with a warning).
        """
        instance_type = type(instance)

        if instance_type in self._instances:
            logger.warning(
                "Overwriting existing instance of type: %s", instance_type.__name__
            )

        self._instances[instance_type] = instance
        logger.debug("Registered instance of type %s", instance_type.__name__)

    def register_as(self, instance_type: type[T], instance: T) -> None:
        """Register *instance* under the explicit key *instance_type*.

        An existing entry for that key is overwritten (with a warning).
        """
        if instance_type in self._instances:
            logger.warning(
                "Overwriting existing instance of type: %s", instance_type.__name__
            )

        self._instances[instance_type] = instance
        logger.debug("Registered instance as type %s", instance_type.__name__)

    def get(self, instance_type: type[T]) -> T:
        """Retrieve the instance registered for *instance_type*.

        :raises KeyError: if nothing is registered for *instance_type*.
        """
        # Membership test (not ``.get() is None``) keeps ``get`` consistent with
        # ``has`` even if a ``None`` value was registered explicitly.
        if instance_type not in self._instances:
            logger.error(
                "Instance of type %s not found in registry", instance_type.__name__
            )
            raise KeyError(
                f"Instance of type {instance_type.__name__} not found in registry"
            )
        return cast(T, self._instances[instance_type])

    def unregister(self, instance_type: type[T]) -> None:
        """Remove the entry for *instance_type*; log a warning if there is none."""
        if instance_type in self._instances:
            del self._instances[instance_type]
            logger.debug("Unregistered instance of type: %s", instance_type.__name__)
        else:
            logger.warning(
                "Attempted to unregister non-existent type: %s", instance_type.__name__
            )

    def list_registered(self) -> list[type[Any]]:
        """Return the registered keys, in registration order."""
        return list(self._instances.keys())

    def has(self, instance_type: type[T]) -> bool:
        """Return True if an entry exists for *instance_type*."""
        return instance_type in self._instances

    def clear(self) -> None:
        """Remove every registered instance."""
        count = len(self._instances)
        self._instances.clear()
        logger.debug("Cleared %d instances from registry", count)
183
+
184
+
185
@lru_cache(maxsize=1)
def service_registry() -> ServiceRegistry:
    """Return the shared ServiceRegistry, creating it on the first call.

    The function takes no arguments, so ``lru_cache`` memoizes exactly one
    result. Every later call returns that same instance until
    ``service_registry.cache_clear()`` is called.
    """
    logger.debug("Creating the service registry instance")
    registry = ServiceRegistry()
    return registry
@@ -0,0 +1,112 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import hmac
5
+ import secrets
6
+
7
# Alphabet for generated salts: ASCII lowercase, ASCII uppercase, then digits.
# None of these is "$", the field separator used in stored hashes.
SALT_CHARS = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
)
# Iteration count used for "pbkdf2" when the method string does not give one.
DEFAULT_PBKDF2_ITERATIONS = 10**6
9
+
10
+
11
def _gen_salt(length: int) -> str:
    """Return ``length`` characters, each picked from SALT_CHARS by ``secrets.choice``.

    :raises ValueError: if ``length`` is zero or negative.
    """
    if length <= 0:
        raise ValueError("Salt length must be at least 1.")

    picked = [secrets.choice(SALT_CHARS) for _ in range(length)]
    return "".join(picked)
17
+
18
+
19
def _hash_internal(method: str, salt: str, password: str) -> tuple[str, str]:
    """Derive a hex digest of *password* using the KDF spec given in *method*.

    *method* is colon-separated: ``scrypt[:n:r:p]`` or
    ``pbkdf2[:hash_name[:iterations]]``. Omitted parameters fall back to
    ``scrypt:32768:8:1`` and ``pbkdf2:sha256:DEFAULT_PBKDF2_ITERATIONS``.

    :return: ``(hex_digest, normalized_method)``. The normalized method spells out
        every parameter actually used, so a stored hash can be re-verified even if
        the defaults change later.
    :raises ValueError: on an unknown method name or a malformed parameter list.
    """
    name, *params = method.split(":")
    salt_bytes = salt.encode()
    password_bytes = password.encode()

    if name == "scrypt":
        if params:
            try:
                n, r, p = map(int, params)
            except ValueError:
                raise ValueError("'scrypt' takes 3 arguments.") from None
        else:
            n, r, p = 2**15, 8, 1

        # scrypt needs about 128 * n * r * p bytes; allow a little headroom.
        memory_cap = 132 * n * r * p
        digest = hashlib.scrypt(
            password_bytes, salt=salt_bytes, n=n, r=r, p=p, maxmem=memory_cap
        )
        return digest.hex(), f"scrypt:{n}:{r}:{p}"

    if name == "pbkdf2":
        if len(params) > 2:  # noqa: PLR2004
            raise ValueError("'pbkdf2' takes 2 arguments.")

        hash_name = params[0] if params else "sha256"
        if len(params) == 2:  # noqa: PLR2004
            iterations = int(params[1])
        else:
            iterations = DEFAULT_PBKDF2_ITERATIONS

        digest = hashlib.pbkdf2_hmac(hash_name, password_bytes, salt_bytes, iterations)
        return digest.hex(), f"pbkdf2:{hash_name}:{iterations}"

    raise ValueError(f"Invalid hash method '{name}'.")
64
+
65
+
66
def generate_password_hash(
    password: str, method: str = "scrypt", salt_length: int = 16
) -> str:
    """Securely hash a password for storage. A password can be compared to a stored hash
    using :func:`check_password_hash`.

    The following methods are supported:

    - ``scrypt``, the default. The parameters are ``n``, ``r``, and ``p``, the default
      is ``scrypt:32768:8:1``. See :func:`hashlib.scrypt`.
    - ``pbkdf2``, less secure. The parameters are ``hash_method`` and ``iterations``,
      the default is ``pbkdf2:sha256:1000000`` (``DEFAULT_PBKDF2_ITERATIONS``).
      See :func:`hashlib.pbkdf2_hmac`.

    Omitted parameters fall back to the defaults above. The returned string always
    records the full parameter list actually used, so existing hashes keep
    verifying even if the defaults change later.

    Default parameters may be updated to reflect current guidelines, and methods may be
    deprecated and removed if they are no longer considered secure. To migrate old
    hashes, you may generate a new hash when checking an old hash, or you may contact
    users with a link to reset their password.

    :param password: The plaintext password.
    :param method: The key derivation function and parameters, e.g. ``"scrypt"``
        or ``"pbkdf2:sha256:600000"``.
    :param salt_length: The number of characters to generate for the salt
        (must be at least 1).
    :return: ``"<method>$<salt>$<hex digest>"``, where ``<method>`` is fully
        parameterized, e.g. ``"scrypt:32768:8:1$<salt>$<digest>"``.
    :raises ValueError: If ``salt_length`` is less than 1, or if ``method`` names an
        unknown algorithm or carries malformed parameters.

    .. note::
        The format and defaults appear to mirror Werkzeug's ``werkzeug.security``
        helpers. The pbkdf2 default here is ``DEFAULT_PBKDF2_ITERATIONS``
        (1,000,000 iterations).
    """
    salt = _gen_salt(salt_length)
    digest, actual_method = _hash_internal(method, salt, password)
    # "$" never appears in the salt alphabet or in a hex digest, so it can safely
    # separate the three fields that check_password_hash splits apart again.
    return f"{actual_method}${salt}${digest}"
94
+
95
+
96
def check_password_hash(pwhash: str, password: str) -> bool:
    """Check *password* against *pwhash*, a ``method$salt$digest`` string produced by
    :func:`generate_password_hash`.

    The candidate password is re-hashed with the stored method and salt. The
    result is compared to the stored digest with :func:`hmac.compare_digest`,
    which avoids leaking timing information.

    Methods may be deprecated and removed if they are no longer considered secure. To
    migrate old hashes, you may generate a new hash when checking an old hash, or you
    may contact users with a link to reset their password.

    :param pwhash: The hashed password.
    :param password: The plaintext password.
    :return: ``True`` on a match. ``False`` on a mismatch, or when *pwhash* contains
        fewer than two ``$`` separators.
    :raises ValueError: propagated from hashing when the stored method is unknown
        or its parameters are malformed.
    """
    method, _, remainder = pwhash.partition("$")
    salt, separator, stored_digest = remainder.partition("$")
    if not separator:
        # Fewer than two "$": not a hash this module produced.
        return False

    candidate_digest, _ = _hash_internal(method, salt, password)
    return hmac.compare_digest(candidate_digest, stored_digest)