half-orm-gen 1.0.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,34 @@
1
+ """
2
+ Frontend store generator for halfORM/Litestar projects.
3
+ """
4
+
5
+ from pathlib import Path
6
+ from half_orm_gen.gen_store.base import StoreGenerator
7
+
8
+
9
class GenStore:
    """
    Generate frontend stores from CRUD_ACCESS introspection.

    Instantiating the class runs the generation immediately: the relation
    classes are read from ``repo.model.classes()`` and handed, together with
    ``api_version`` and the output directory, to ``generator.generate``.

    Parameters
    ----------
    repo:
        A ``half_orm_dev.repo.Repo`` instance.
    generator:
        A :class:`StoreGenerator` subclass instance (e.g. SvelteGenerator).
    output_dir:
        Directory where the generated files will be written. A plain string
        (or any path-like object) is accepted and converted to a
        :class:`pathlib.Path` before it reaches the generator.
    api_version:
        Integer API version (used to build route prefixes). Forwarded
        unchanged to the generator; may be *None*.
    """

    def __init__(
        self,
        repo,
        *,
        generator: StoreGenerator,
        output_dir: Path | str,
        api_version: int | None = None,
    ):
        # Materialise once: ``classes()`` may be a one-shot iterator.
        classes = list(repo.model.classes())
        # Generators call Path methods on the directory; normalising here lets
        # callers pass a raw string, mirroring how GenApi treats ``base_dir``.
        generator.generate(classes, api_version, Path(output_dir))
@@ -0,0 +1,88 @@
1
+ """
2
+ Abstract base class for frontend store generators.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from pathlib import Path
7
+
8
+
9
+ class StoreGenerator(ABC):
10
+
11
+ PY_TO_TS = {
12
+ 'str': 'string',
13
+ 'int': 'number',
14
+ 'float': 'number',
15
+ 'bool': 'boolean',
16
+ 'uuid.UUID': 'string',
17
+ 'datetime.datetime': 'string',
18
+ 'datetime.date': 'string',
19
+ 'datetime.time': 'string',
20
+ 'datetime.timedelta':'string',
21
+ 'decimal.Decimal': 'number',
22
+ }
23
+
24
+ def ts_type(self, py_type_str: str) -> str:
25
+ return self.PY_TO_TS.get(py_type_str, 'unknown')
26
+
27
+ def resource_name(self, schema: str, table: str) -> str:
28
+ """blogAuthor (camelCase)"""
29
+ parts = schema.split('_') + table.split('_')
30
+ return parts[0].lower() + ''.join(p.capitalize() for p in parts[1:])
31
+
32
+ def interface_name(self, schema: str, table: str) -> str:
33
+ """BlogAuthor (PascalCase)"""
34
+ parts = schema.split('_') + table.split('_')
35
+ return ''.join(p.capitalize() for p in parts)
36
+
37
+ def _fk_deps(self, inst, out_names: list, crud_resources: set) -> list:
38
+ """Return (local_field, remote_schema, remote_table, remote_pk) for each
39
+ simple non-reverse FK whose local field is in out_names and whose remote
40
+ table is in crud_resources."""
41
+ deps = []
42
+ for fk in getattr(inst, '_ho_fkeys', {}).values():
43
+ if fk.is_reverse:
44
+ continue
45
+ local_fields = fk.names
46
+ remote_pks = fk.fk_names
47
+ if len(local_fields) != 1 or len(remote_pks) != 1:
48
+ continue
49
+ local_field = local_fields[0]
50
+ if local_field not in out_names:
51
+ continue
52
+ fqtn = fk.remote['fqtn']
53
+ remote_schema = fqtn[0].replace('.', '_')
54
+ remote_table = fqtn[1]
55
+ if (remote_schema, remote_table) not in crud_resources:
56
+ continue
57
+ deps.append((local_field, remote_schema, remote_table, remote_pks[0]))
58
+ return deps
59
+
60
+ def _reverse_fk_deps(self, inst, pk_field: str | None, crud_resources: set) -> list:
61
+ """Return (remote_schema, remote_table, fk_field) for each simple reverse FK
62
+ whose remote table is in crud_resources. Deduplicated by remote table."""
63
+ if not pk_field:
64
+ return []
65
+ deps = []
66
+ seen: set[tuple[str, str]] = set()
67
+ for fk in getattr(inst, '_ho_fkeys', {}).values():
68
+ if not fk.is_reverse:
69
+ continue
70
+ our_pk_fields = fk.names
71
+ remote_fk_fields = fk.fk_names
72
+ if len(our_pk_fields) != 1 or len(remote_fk_fields) != 1:
73
+ continue
74
+ if our_pk_fields[0] != pk_field:
75
+ continue
76
+ fqtn = fk.remote['fqtn']
77
+ remote_schema = fqtn[0].replace('.', '_')
78
+ remote_table = fqtn[1]
79
+ if (remote_schema, remote_table) not in crud_resources:
80
+ continue
81
+ if (remote_schema, remote_table) in seen:
82
+ continue
83
+ seen.add((remote_schema, remote_table))
84
+ deps.append((remote_schema, remote_table, remote_fk_fields[0]))
85
+ return deps
86
+
87
+ @abstractmethod
88
+ def generate(self, classes, api_version, output_dir: Path) -> None: ...
@@ -0,0 +1,282 @@
1
+ """
2
+ Svelte 5 / TypeScript store generator (.svelte.ts, $state runes).
3
+ """
4
+
5
+ import importlib
6
+ import shutil
7
+ from pathlib import Path
8
+
9
+ from half_orm_gen.crud_routes import (
10
+ _gen_out_fields,
11
+ _gen_in_fields,
12
+ _pk_info,
13
+ _simple_pk,
14
+ _instance,
15
+ _py_type_str,
16
+ )
17
+ from half_orm_gen.gen_store.base import StoreGenerator
18
+
19
+
20
class SvelteGenerator(StoreGenerator):
    """Store generator that emits Svelte 5 ``.svelte.ts`` modules.

    For each relation it writes a TypeScript file holding the ``Out`` /
    ``PostIn`` / ``PutIn`` interfaces, a state class and an ``Api`` object of
    fetch helpers, plus a shared ``base.svelte.ts`` and an ``index.svelte.ts``.
    """

    def generate(self, classes, api_version, output_dir: Path) -> None:
        """Regenerate every store file under *output_dir*.

        WARNING: an existing *output_dir* is deleted recursively before
        anything is written.

        Parameters
        ----------
        classes:
            Iterable of ``(relation, relation_type)`` pairs; only the relation
            is used.
        api_version:
            When not *None*, routes are prefixed with ``/v<api_version>``.
        output_dir:
            Destination directory (recreated from scratch).

        The path of each written file is printed to stdout.
        """
        if output_dir.exists():
            shutil.rmtree(output_dir)
        output_dir.mkdir(parents=True)
        self._write_base(output_dir)
        version_prefix = f'/v{api_version}' if api_version is not None else ''

        # Pass 1: collect every relation whose module can be imported. A module
        # without CRUD_ACCESS (or with a falsy one) falls back to a default that
        # lists all four verbs, so it is still treated as a resource.
        resources = []
        crud_resources: set[tuple[str, str]] = set()

        for relation, _relation_type in classes:
            module_str = relation.__module__
            try:
                mod = importlib.import_module(module_str)
            except ImportError:
                # Relations whose module fails to import are skipped silently.
                continue
            crud_access = getattr(mod, 'CRUD_ACCESS', None) or {'GET': {}, 'POST': {}, 'PUT': {}, 'DELETE': {}}
            schema_name = relation._t_fqrn[1]
            table_name = relation._t_fqrn[2]
            # NOTE(review): _fk_deps replaces '.' with '_' in remote schema names
            # before looking them up here, but these pairs are stored raw; a
            # schema name containing '.' would presumably never match - confirm.
            crud_resources.add((schema_name, table_name))
            resources.append((relation, mod, crud_access, schema_name, table_name))

        # Pass 2: generate one .svelte.ts per resource
        stems = []

        for relation, mod, crud_access, schema_name, table_name in resources:
            api_excluded = getattr(mod, 'API_EXCLUDED_FIELDS', [])
            inst = _instance(relation)
            all_fields = getattr(inst, '_ho_fields', {})
            all_names = list(all_fields.keys())
            # pk_cols entries are unpacked as 3-tuples below: index 0 is used as
            # the column name and index 2 is passed to ts_type().
            pk_cols = _pk_info(relation)
            pk_info = pk_cols  # truthy iff non-empty
            if len(pk_cols) == 1:
                pk_field = pk_cols[0][0]
                pk_ts_type = self.ts_type(pk_cols[0][2])
                pk_extractor = f'i => String(i.{pk_field})'
            elif len(pk_cols) > 1:
                # Composite key: the store key is the stringified columns joined
                # with "::"; pk_field keeps only the first column.
                pk_field = pk_cols[0][0]
                pk_ts_type = 'string'
                pk_extractor = 'i => [' + ', '.join(f'i.{f}' for f, _, _ in pk_cols) + '].map(String).join("::")'
            else:
                pk_field = pk_ts_type = pk_extractor = None

            iname = self.interface_name(schema_name, table_name)
            rname = self.resource_name(schema_name, table_name)
            base_path = f'{version_prefix}/{schema_name}/{table_name}'
            stem = f'{schema_name}_{table_name}'

            out_names = _gen_out_fields(crud_access, 'GET', api_excluded, all_names)
            if not out_names:
                # Fallback: expose every field that is not explicitly excluded.
                out_names = [f for f in all_names if f not in api_excluded]

            # Write operations are only generated for relations with a primary key.
            has_post = 'POST' in crud_access and pk_info
            has_put = 'PUT' in crud_access and pk_info
            has_del = 'DELETE' in crud_access and pk_info

            post_in_names = _gen_in_fields(
                crud_access, 'POST', pk_field, api_excluded, all_names
            ) if has_post else []
            put_in_names = _gen_in_fields(
                crud_access, 'PUT', pk_field, api_excluded, all_names
            ) if has_put else []

            fk_deps = self._fk_deps(inst, out_names, crud_resources)

            lines = []

            # Imports
            lines.append("import { BaseState } from './base.svelte.ts';")
            lines.append("import { auth } from '$lib/auth.svelte.ts';")
            lines.append("import { registerClear } from '$lib/stateRegistry';")
            lines.append('')

            # FK imports (deduplicated: skip self-referential FKs and multi-FK to same table)
            # Seeding the set with this file's own stem is what skips self-references.
            seen_stems: set[str] = {stem}
            for local_field, remote_schema, remote_table, remote_pk in fk_deps:
                remote_stem = f'{remote_schema}_{remote_table}'
                if remote_stem in seen_stems:
                    continue
                seen_stems.add(remote_stem)
                remote_rname = self.resource_name(remote_schema, remote_table)
                lines.append(
                    f"import {{ {remote_rname}State }} from './{remote_stem}.svelte.ts';"
                )
            if fk_deps:
                lines.append('')

            # Interfaces
            lines.append(self._interface(f'{iname}Out', out_names, all_fields))
            if has_post:
                lines.append(self._interface(f'{iname}PostIn', post_in_names, all_fields))
            if has_put:
                lines.append(self._interface(f'{iname}PutIn', put_in_names, all_fields))

            # State class
            if pk_info:
                lines.append(f'class {iname}State extends BaseState<{iname}Out> {{')
                lines.append(f' constructor() {{ super({pk_extractor}); }}')
            else:
                # Without a primary key there is nothing to index on: the class
                # keeps a plain array and both setters replace it wholesale.
                lines.append(f'class {iname}State {{')
                lines.append(f' items = $state<{iname}Out[]>([]);')
                lines.append(f' setItems(data: {iname}Out[]) {{ this.items = data; }}')
                lines.append(f' mergeItems(data: {iname}Out[]) {{ this.items = data; }}')

            if fk_deps:
                lines.append('')
                # One derived lookup map per FK field, keyed by the remote column.
                for local_field, remote_schema, remote_table, remote_pk in fk_deps:
                    remote_rname = self.resource_name(remote_schema, remote_table)
                    map_name = f'_{local_field}Map'
                    lines.append(
                        f' {map_name} = $derived('
                        f'Object.fromEntries({remote_rname}State.items.map('
                        f'r => [r.{remote_pk}, r])));'
                    )
                lines.append('')
                # Each item gains a `_<field>` property holding the related row or null.
                enriched = ', '.join(
                    f'_{lf}: this._{lf}Map[item.{lf}] ?? null'
                    for lf, _, _, _ in fk_deps
                )
                lines.append(
                    f' itemsWithRelations = $derived('
                    f'this.items.map(item => ({{...item, {enriched}}})));'
                )

            lines.append('}')
            lines.append('')
            lines.append(f'export const {rname}State = new {iname}State();')
            if pk_field:
                # clear() is only provided by BaseState, i.e. by the PK variant.
                lines.append(f'registerClear(() => {rname}State.clear());')
            lines.append('')

            # API
            lines.append(f"const _BASE = '{base_path}';")
            lines.append("const _hdrs = (extra?: Record<string, string>) => ({")
            lines.append(" ...(auth.token ? { Authorization: `Bearer ${auth.token}` } : {}),")
            lines.append(" ...extra,")
            lines.append("});")
            lines.append("const _fetch = (url: string, opts?: RequestInit) => {")
            lines.append(" const method = opts?.method ?? 'GET';")
            lines.append(" if (method === 'GET') auth.fetchedRoutes.add(url);")
            lines.append(" return fetch(url, opts);")
            lines.append("};")
            lines.append('')
            api_entries = []
            if 'GET' in crud_access:
                api_entries.append(
                    f" listUrl: (params: Partial<{iname}Out> = {{}}) =>\n"
                    f" _BASE + '?' + new URLSearchParams(params as any),"
                )
                api_entries.append(
                    f" list: (params: Partial<{iname}Out> = {{}}) =>\n"
                    f" _fetch(_BASE + '?' + new URLSearchParams(params as any),\n"
                    f" {{ headers: _hdrs() }}),"
                )
                if pk_info:
                    api_entries.append(
                        f" getUrl: (id: {pk_ts_type}) => `${{_BASE}}/${{id}}`,"
                    )
                    # The generated `get` answers from the in-memory store when
                    # the item is already cached and only then hits the network.
                    api_entries.append(
                        f" get: (id: {pk_ts_type}) => {{\n"
                        f" const _c = {rname}State.byId.get(String(id));\n"
                        f" if (_c) return Promise.resolve(new Response(JSON.stringify(_c),\n"
                        f" {{ status: 200, headers: {{ 'Content-Type': 'application/json' }} }}));\n"
                        f" return _fetch(`${{_BASE}}/${{id}}`, {{ headers: _hdrs() }});\n"
                        f" }},"
                    )
            if has_post:
                api_entries.append(
                    f" create: (data: {iname}PostIn) =>\n"
                    f" _fetch(_BASE, {{ method: 'POST',\n"
                    f" headers: _hdrs({{'Content-Type': 'application/json'}}),\n"
                    f" body: JSON.stringify(data) }}),"
                )
            if has_put:
                api_entries.append(
                    f" update: (id: {pk_ts_type}, data: {iname}PutIn) =>\n"
                    f" _fetch(`${{_BASE}}/${{id}}`, {{ method: 'PUT',\n"
                    f" headers: _hdrs({{'Content-Type': 'application/json'}}),\n"
                    f" body: JSON.stringify(data) }}),"
                )
            if has_del:
                api_entries.append(
                    f" remove: (id: {pk_ts_type}) =>\n"
                    f" _fetch(`${{_BASE}}/${{id}}`,\n"
                    f" {{ method: 'DELETE', headers: _hdrs() }}),"
                )
            lines.append(f'export const {rname}Api = {{')
            lines.extend(api_entries)
            lines.append('};')
            lines.append('')

            out_file = output_dir / f'{stem}.svelte.ts'
            out_file.write_text('\n'.join(lines), encoding='utf-8')
            print(f' {out_file}')
            stems.append(stem)

        # The index is only written when at least one resource file exists.
        if stems:
            self._write_index(output_dir, stems, version_prefix)

    def _write_base(self, output_dir: Path) -> None:
        """Write ``base.svelte.ts``, the generic ``BaseState`` class.

        The emitted class stores items in a ``Map`` keyed by the string that
        the constructor-supplied extractor returns; every mutator assigns a
        fresh ``Map`` to ``byId`` instead of mutating the current one.
        """
        content = """\
export class BaseState<V> {
byId = $state(new Map<string, V>());
items = $derived([...this.byId.values()]);

constructor(private readonly pk: (item: V) => string) {}

clear() {
this.byId = new Map();
}
setItems(data: V[]) {
this.byId = new Map(data.map(i => [this.pk(i), i]));
}
mergeItems(data: V[]) {
const m = new Map(this.byId);
data.forEach(i => m.set(this.pk(i), i));
this.byId = m;
}
setItem(item: V) {
const m = new Map(this.byId);
m.set(this.pk(item), item);
this.byId = m;
}
removeItem(id: string) {
const m = new Map(this.byId);
m.delete(id);
this.byId = m;
}
}
"""
        base_file = output_dir / 'base.svelte.ts'
        base_file.write_text(content, encoding='utf-8')
        print(f' {base_file}')

    def _interface(self, name: str, field_names: list, all_fields: dict) -> str:
        """Return the source of an exported TypeScript interface called *name*.

        One property is emitted per entry of *field_names* that is also a key
        of *all_fields*; names missing from *all_fields* are dropped silently.
        The property type comes from ``ts_type(_py_type_str(field.py_type))``.
        An empty *field_names* yields an empty interface.
        """
        if not field_names:
            return f'export interface {name} {{}}\n'
        props = '\n'.join(
            f' {f}: {self.ts_type(_py_type_str(all_fields[f].py_type))};'
            for f in field_names if f in all_fields
        )
        return f'export interface {name} {{\n{props}\n}}\n'

    def _write_index(self, output_dir: Path, stems: list, version_prefix: str) -> None:
        """Write ``index.svelte.ts``.

        The file re-exports every generated module listed in *stems* and
        defines ``hoAccess``, a helper that fetches
        ``<version_prefix>/ho_access`` (with an optional bearer token) and
        throws on a non-OK response.
        """
        lines = [f"export * from './{s}.svelte.ts';" for s in stems]
        lines += [
            '',
            'export async function hoAccess(token?: string): Promise<Record<string, any>> {',
            ' const headers: Record<string, string> = token',
            ' ? { Authorization: `Bearer ${token}` }',
            ' : {};',
            f" const res = await fetch('{version_prefix}/ho_access', {{ headers }});",
            " if (!res.ok) throw new Error(`ho_access: ${res.status}`);",
            ' return res.json();',
            '}',
            '',
        ]
        index = output_dir / 'index.svelte.ts'
        index.write_text('\n'.join(lines), encoding='utf-8')
        print(f' {index}')
@@ -0,0 +1,120 @@
1
+ """
2
+ Litestar API generator for halfORM projects.
3
+
4
+ Orchestrates generation of api/app.py by combining:
5
+ - @api_* decorated route handlers (api_routes.py)
6
+ - Auto-CRUD handlers from CRUD_ACCESS (crud_routes.py)
7
+ - Scaffolding of missing api/ files (scaffold.py)
8
+ """
9
+
10
+ import os
11
+ from pathlib import Path
12
+ from typing import Iterable, Tuple, Type
13
+
14
+ from half_orm.relation import Relation
15
+
16
+ from half_orm_gen import templates as T
17
+ from half_orm_gen.scaffold import scaffold_api_dir
18
+ from half_orm_gen.api_routes import generate_api_routes
19
+ from half_orm_gen.crud_routes import generate_crud_routes
20
+
21
+
22
class GenApi:
    """
    Generate ``api/app.py`` from a halfORM project.

    The file is generated as a side effect of instantiation.

    Parameters
    ----------
    repo:
        A ``half_orm_dev.repo.Repo`` instance. When *None*, supply
        *relation_classes*, *module_name*, and *base_dir* directly.
    relation_classes:
        Iterable of ``(RelationClass, relation_type)`` pairs (used when
        *repo* is *None*).
    module_name:
        Top-level Python package name of the halfORM model (e.g. ``"mydb"``).
    base_dir:
        Root directory of the project (``api/`` is created inside it).
    api_version:
        Integer API version (written as ``/vN/`` prefix in routes).
    framework:
        ``'fastapi'`` selects the ``templates_fastapi`` templates and skips
        the ``@api_*`` route generation; any other value uses the default
        (Litestar) templates.

    Raises
    ------
    ValueError
        If *repo* is *None* and one of *relation_classes*, *module_name* or
        *base_dir* is missing.
    """

    def __init__(
        self,
        repo=None,
        *,
        relation_classes: Iterable[Tuple[Type[Relation], str]] | None = None,
        module_name: str | None = None,
        base_dir: str | None = None,
        api_version: int | None = None,
        framework: str = 'litestar',
    ):
        if repo is None:
            explicit = (relation_classes, module_name, base_dir)
            if any(value is None for value in explicit):
                raise ValueError(
                    "Provide either a repo or (relation_classes, module_name, base_dir)."
                )
            name, root, classes = module_name, base_dir, relation_classes
        else:
            name, root, classes = repo.name, repo.base_dir, repo.model.classes()

        self._module_name = name
        self._base_dir = Path(root)
        self._classes = list(classes)

        self._api_version = api_version
        self._framework = framework
        self._api_dir = self._base_dir / 'api'
        self._generate()

    def _generate(self) -> None:
        """Build the source of ``app.py``, scaffold ``api/`` and write the file."""
        # Only set when the caller has not already chosen a value.
        os.environ.setdefault('API_GEN_MODE', '1')

        use_fastapi = self._framework == 'fastapi'
        if use_fastapi:
            from half_orm_gen import templates_fastapi as tpl
            # No decorated routes on this path: nothing is marked as covered.
            api_blocks, api_handlers, covered = [], [], set()
        else:
            tpl = T
            api_blocks, api_handlers, covered = generate_api_routes(
                self._classes, self._api_version
            )

        # Auto-CRUD routes.
        crud_blocks, crud_handlers = generate_crud_routes(
            self._classes, self._api_version, covered, templates=tpl
        )

        # Assemble app.py: header, optional helpers, route blocks, footer.
        if self._api_version is None:
            openapi_config = ''
        else:
            openapi_config = tpl.OPENAPI_CONFIG.format(
                title=self._module_name,
                version=f'v{self._api_version}',
            )

        pieces = [tpl.HEADER.format(module=self._module_name)]
        if crud_blocks:
            pieces.append(tpl.CRUD_HELPERS)
        pieces.extend(api_blocks)
        pieces.extend(crud_blocks)

        # Only the non-fastapi footer takes the handler list.
        footer_args = {'openapi_config': openapi_config}
        if not use_fastapi:
            footer_args['route_handlers'] = ', '.join(api_handlers + crud_handlers)
        pieces.append(tpl.FOOTER.format(**footer_args))
        output = ''.join(pieces)

        # Scaffold missing api/ files.
        print(f'\nScaffolding {self._api_dir} ...')
        scaffold_api_dir(self._api_dir)

        # Write app.py.
        self._api_dir.mkdir(parents=True, exist_ok=True)
        target = self._api_dir / 'app.py'
        target.write_text(output, encoding='utf-8')
        print(f'\nGenerated {target}')
@@ -0,0 +1,37 @@
1
+ """
2
+ Scaffolding helpers for half-orm-gen.
3
+
4
+ Creates missing api/ files on first generate. Never overwrites existing files.
5
+ """
6
+
7
+ import shutil
8
+ from pathlib import Path
9
+
10
# Directory holding the template files shipped next to this module.
_SCAFFOLDING_DIR = Path(__file__).parent / 'scaffolding'


def scaffold_api_dir(api_dir: Path) -> None:
    """Populate *api_dir* with the scaffolding files that are still missing.

    Each destination that does not exist yet is copied from the matching
    template (parent directories are created on demand). Destinations that
    already exist are left untouched. One status line per file is printed.
    """
    # (destination path parts relative to api_dir, template file name)
    layout = (
        (('guards.py',), 'guards.py'),
        (('__init__.py',), 'api_init.py'),
        (('custom', 'routes.py'), 'custom_routes.py'),
        (('custom', '__init__.py'), 'custom_init.py'),
        (('custom', 'middlewares', '__init__.py'), 'custom_middlewares_init.py'),
        (('custom', 'middlewares', 'authorization.py'), 'custom_authorization.py'),
        (('roles', 'core.py'), 'roles_core.py'),
    )
    for rel_parts, template_name in layout:
        dest = api_dir.joinpath(*rel_parts)
        if dest.exists():
            print(f' exists {dest}')
            continue
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(_SCAFFOLDING_DIR / template_name, dest)
        print(f' created {dest}')
@@ -0,0 +1 @@
1
+ # api package
@@ -0,0 +1,36 @@
1
+ """
2
+ Authorization middleware.
3
+
4
+ Implement the ``Authorization`` class below to handle authentication for your
5
+ API (JWT, session cookies, API keys, etc.).
6
+
7
+ When present, it is automatically placed first in the middleware stack by
8
+ ``api/main.py``.
9
+
10
+ Example (JWT bearer token)::
11
+
12
+ import jwt
13
+ from litestar.middleware import AbstractAuthenticationMiddleware, AuthenticationResult
14
+ from litestar.connection import ASGIConnection
15
+
16
+ SECRET = "change-me"
17
+
18
+ class Authorization(AbstractAuthenticationMiddleware):
19
+ async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult:
20
+ token = connection.headers.get("Authorization", "").removeprefix("Bearer ").strip()
21
+ try:
22
+ payload = jwt.decode(token, SECRET, algorithms=["HS256"])
23
+ return AuthenticationResult(user=payload["sub"], auth=token)
24
+ except Exception:
25
+ return AuthenticationResult(user=None, auth=None)
26
+
27
+ See https://docs.litestar.dev/latest/usage/security/abstract-authentication-middleware.html
28
+ """
29
+
30
+ # TODO: implement Authorization
31
+ # from litestar.middleware import AbstractAuthenticationMiddleware, AuthenticationResult
32
+ # from litestar.connection import ASGIConnection
33
+ #
34
+ # class Authorization(AbstractAuthenticationMiddleware):
35
+ # async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult:
36
+ # raise NotImplementedError
@@ -0,0 +1 @@
1
+ # api/custom package
@@ -0,0 +1,18 @@
1
+ """
2
+ Custom Litestar middlewares for this project.
3
+
4
+ Add your middleware classes here and list them in ``middlewares``.
5
+ They will be prepended to the middleware stack (after the optional
6
+ ``Authorization`` middleware).
7
+
8
+ Example::
9
+
10
+ from litestar.middleware import AbstractMiddleware
11
+
12
+ class MyMiddleware(AbstractMiddleware):
13
+ ...
14
+
15
+ middlewares = [MyMiddleware]
16
+ """
17
+
18
# Project middleware classes; empty until the project lists its own here.
middlewares = []
@@ -0,0 +1,18 @@
1
+ """
2
+ Custom Litestar route handlers for this project.
3
+
4
+ Add any hand-written route handlers here and list them in ``routes``.
5
+ They will be registered alongside the auto-generated routes.
6
+
7
+ Example::
8
+
9
+ from litestar import get
10
+
11
+ @get('/health')
12
+ async def health_check() -> dict:
13
+ return {'status': 'ok'}
14
+
15
+ routes = [health_check]
16
+ """
17
+
18
# Hand-written route handlers; empty until the project lists its own here.
routes = []
@@ -0,0 +1,40 @@
1
+ """
2
+ API guards for this project.
3
+
4
+ A guard is an async callable with the signature::
5
+
6
+ async def my_guard(connection: ASGIConnection, handler: BaseRouteHandler) -> None:
7
+ ...
8
+
9
+ It should raise ``NotAuthorizedException`` or ``HTTPException`` to deny access,
10
+ or return ``None`` to allow it.
11
+
12
+ Reference the guards by name in your ``@tools.api_*`` decorators::
13
+
14
+ @tools.api_get('/items/{id: uuid}', guards=['connected'])
15
+ async def get_item(self, request):
16
+ ...
17
+
18
+ See https://docs.litestar.dev/latest/usage/security/guards.html
19
+ """
20
+
21
+ from litestar.connection import ASGIConnection
22
+ from litestar.handlers.base import BaseRouteHandler
23
+ from litestar.exceptions import NotAuthorizedException, HTTPException
24
+
25
+
26
async def public(connection: ASGIConnection, handler: BaseRouteHandler = None) -> None:
    """Guard that never denies access: every request is let through."""
    return None
29
+
30
+
31
async def connected(connection: ASGIConnection, handler: BaseRouteHandler = None) -> None:
    """Guard that lets a request through only when ``connection.user`` is truthy.

    Raises ``NotAuthorizedException`` otherwise.
    """
    if not connection.user:
        raise NotAuthorizedException()
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Add your project-specific guards below.
40
+ # ---------------------------------------------------------------------------