Fast-Controller 0.2.0b0__tar.gz → 0.3.0b0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: Fast-Controller
3
- Version: 0.2.0b0
3
+ Version: 0.3.0b0
4
4
  Summary: The fastest way to turn your models into a full ReST API
5
5
  Keywords: controller,base,rest,api,backend
6
6
  Author-Email: Cody M Sommer <bassmastacod@gmail.com>
@@ -0,0 +1,324 @@
1
+ from contextlib import contextmanager
2
+ from enum import Enum, auto
3
+ from typing import Optional, Callable
4
+ import inspect
5
+
6
+ from daomodel import DAOModel
7
+ from daomodel.dao import NotFound
8
+ from daomodel.db import DAOFactory
9
+ from daomodel.transaction import Conflict
10
+ from fastapi import FastAPI, APIRouter, Request, Response, Depends, Path, Body, Query, Header
11
+ from fastapi.responses import JSONResponse, RedirectResponse
12
+ from sqlalchemy import Engine
13
+ from sqlalchemy.orm import sessionmaker
14
+ from sqlmodel import SQLModel
15
+
16
+ from fast_controller.resource import Resource, get_field_type
17
+ from fast_controller.util import docstring_format, InvalidInput, expose_path_params, extract_values
18
+
19
+
20
class Action(Enum):
    """The REST operations that can be exposed for a Resource.

    Iterating over the enum yields the members in declaration order.
    """
    SEARCH = auto()
    CREATE = auto()
    UPSERT = auto()
    VIEW = auto()
    UPDATE = auto()
    MODIFY = auto()
    DELETE = auto()
    RENAME = auto()

    def register_endpoint(self, controller, router: APIRouter, resource: type[Resource]):
        """Registers the endpoint that implements this action.

        :param controller: The Controller supplying the DAO dependency and per-action dependencies
        :param router: The APIRouter on which to register the endpoint
        :param resource: The Resource type the endpoint operates on
        """
        # Built at call time: the registrar functions are defined further down the module
        registrars = {
            Action.SEARCH: _register_search_endpoint,
            Action.CREATE: _register_create_endpoint,
            Action.UPSERT: _register_upsert_endpoint,
            Action.VIEW: _register_view_endpoint,
            Action.UPDATE: _register_update_endpoint,
            Action.MODIFY: _register_modify_endpoint,
            Action.DELETE: _register_delete_endpoint,
            Action.RENAME: _register_rename_endpoint,
        }
        registrar = registrars[self]
        registrar(controller, router, resource)
41
+
42
+
43
+ def _construct_path(pk):
44
+ path = '/'.join([''] + ['{' + p + '}' for p in pk])
45
+ return path
46
+
47
+
48
def _register_search_endpoint(controller, router: APIRouter, resource: type[Resource]):
    """Registers the collection-level GET endpoint that searches for resources.

    A second route, hidden from the schema, catches the trailing-slash form of
    the URL and redirects it (307) to the slash-less form.

    :param controller: The Controller supplying the DAO dependency and per-action dependencies
    :param router: The APIRouter on which to register the endpoint
    :param resource: The Resource type being searched
    """
    @router.get('/', include_in_schema=False)
    async def redirect(request: Request):
        # An empty Location resolves back to this same trailing-slash URL,
        # so the target is derived from the URL that was actually requested.
        target = request.url.replace(path=request.url.path.rstrip('/'))
        return RedirectResponse(url=str(target), status_code=307)

    @router.get(
        '',
        response_model=list[resource.get_output_schema()],
        dependencies=controller.dependencies_for(resource, Action.SEARCH))
    @docstring_format(resource=resource.doc_name())
    def search(response: Response,
               filters: resource.get_search_schema() = Query(),
               x_page: Optional[int] = Header(default=None, gt=0),
               x_per_page: Optional[int] = Header(default=None, gt=0),
               daos: DAOFactory = controller.daos) -> list[DAOModel]:
        """Searches for {resource} by criteria"""
        # Only filters the client actually supplied are forwarded to the DAO
        results = daos[resource].find(x_page, x_per_page, **filters.model_dump(exclude_unset=True))
        # Paging details are reported through response headers, mirroring the request headers
        response.headers["x-total-count"] = str(results.total)
        response.headers["x-page"] = str(results.page)
        response.headers["x-per-page"] = str(results.per_page)
        return results
69
+
70
+
71
def _register_create_endpoint(controller, router: APIRouter, resource: type[Resource]):
    """Registers the collection-level POST endpoint that creates a resource (201).

    A second route, hidden from the schema, catches the trailing-slash form of
    the URL and redirects it (307) to the slash-less form.

    :param controller: The Controller supplying the DAO dependency and per-action dependencies
    :param router: The APIRouter on which to register the endpoint
    :param resource: The Resource type being created
    """
    @router.post('/', include_in_schema=False)
    async def redirect(request: Request):
        # An empty Location resolves back to this same trailing-slash URL,
        # so the target is derived from the URL that was actually requested.
        target = request.url.replace(path=request.url.path.rstrip('/'))
        return RedirectResponse(url=str(target), status_code=307)

    @router.post(
        '',
        response_model=resource.get_detailed_output_schema(),
        status_code=201,
        dependencies=controller.dependencies_for(resource, Action.CREATE))
    @docstring_format(resource=resource.doc_name())
    def create(model: resource.get_input_schema(),
               daos: DAOFactory = controller.daos) -> DAOModel:
        """Creates a new {resource}"""
        # Only the fields present in the request body are passed on
        return daos[resource].create_with(**model.model_dump(exclude_unset=True))
86
+
87
+
88
def _register_upsert_endpoint(controller, router: APIRouter, resource: type[Resource]):
    """Registers the collection-level PUT endpoint that creates or modifies a resource.

    A second route, hidden from the schema, catches the trailing-slash form of
    the URL and redirects it (307) to the slash-less form.

    :param controller: The Controller supplying the DAO dependency and per-action dependencies
    :param router: The APIRouter on which to register the endpoint
    :param resource: The Resource type being upserted
    """
    @router.put('/', include_in_schema=False)
    async def redirect(request: Request):
        # An empty Location resolves back to this same trailing-slash URL,
        # so the target is derived from the URL that was actually requested.
        target = request.url.replace(path=request.url.path.rstrip('/'))
        return RedirectResponse(url=str(target), status_code=307)

    @router.put(
        '',
        response_model=resource.get_detailed_output_schema(),
        dependencies=controller.dependencies_for(resource, Action.UPSERT))
    @docstring_format(resource=resource.doc_name())
    def upsert(model: resource.get_input_schema(),
               daos: DAOFactory = controller.daos) -> SQLModel:
        """Creates/modifies a {resource}"""
        daos[resource].upsert(model)
        # The submitted model itself is echoed back as the response
        return model
103
+
104
+
105
def _register_view_endpoint(controller, router: APIRouter, resource: type[Resource]):
    """Registers the GET endpoint that returns a single resource identified by its primary key.

    A second route, hidden from the schema, catches the trailing-slash form of
    the URL and redirects it (307) to the slash-less form.

    :param controller: The Controller supplying the DAO dependency and per-action dependencies
    :param router: The APIRouter on which to register the endpoint
    :param resource: The Resource type being viewed
    """
    pk = [p.name for p in resource.get_pk()]
    path = _construct_path(pk)

    @router.get(f'{path}/', include_in_schema=False)
    async def redirect(request: Request):
        # `path` is only the route template (e.g. '/{id}'), so the target is
        # derived from the URL that was actually requested.
        target = request.url.replace(path=request.url.path.rstrip('/'))
        return RedirectResponse(url=str(target), status_code=307)

    @docstring_format(resource=resource.doc_name())
    def view(daos: DAOFactory = controller.daos, **kwargs) -> DAOModel:
        """Retrieves a detailed view of a {resource}"""
        return daos[resource].get(*extract_values(kwargs, pk))

    # Publish the PK names on the signature before registering,
    # so the route is built from the final signature.
    expose_path_params(view, pk)
    router.get(
        path,
        response_model=resource.get_detailed_output_schema(),
        dependencies=controller.dependencies_for(resource, Action.VIEW))(view)
123
+
124
+
125
def _register_update_endpoint(controller, router: APIRouter, resource: type[Resource]):
    """Registers the PUT endpoint that overwrites a resource identified by its primary key.

    Unlike the modify (PATCH) endpoint, every field of the update schema is
    applied, including those the client left unset.

    A second route, hidden from the schema, catches the trailing-slash form of
    the URL and redirects it (307) to the slash-less form.

    :param controller: The Controller supplying the DAO dependency and per-action dependencies
    :param router: The APIRouter on which to register the endpoint
    :param resource: The Resource type being updated
    """
    pk = [p.name for p in resource.get_pk()]
    path = _construct_path(pk)

    @router.put(f'{path}/', include_in_schema=False)
    async def redirect(request: Request):
        # `path` is only the route template (e.g. '/{id}'), so the target is
        # derived from the URL that was actually requested.
        target = request.url.replace(path=request.url.path.rstrip('/'))
        return RedirectResponse(url=str(target), status_code=307)

    @docstring_format(resource=resource.doc_name())
    def update(model: resource.get_update_schema(),  # TODO - Remove PK from input schema
               daos: DAOFactory = controller.daos, **kwargs) -> DAOModel:
        """Updates an existing {resource}, replacing all of its fields"""
        dao = daos[resource]
        # The PK values arrive through **kwargs under the names exposed below,
        # which also covers multi-column primary keys.
        result = dao.get(*extract_values(kwargs, pk))
        result.set_values(**model.model_dump(exclude_unset=False))
        dao.commit(result)
        return result

    # Publish the PK names on the signature before registering,
    # so the route is built from the final signature.
    expose_path_params(update, pk)
    router.put(
        path,
        response_model=resource.get_detailed_output_schema(),
        dependencies=controller.dependencies_for(resource, Action.UPDATE))(update)
148
+
149
+
150
def _register_modify_endpoint(controller, router: APIRouter, resource: type[Resource]):
    """Registers the PATCH endpoint that changes selected fields of a resource.

    Only the fields present in the request body are applied.

    A second route, hidden from the schema, catches the trailing-slash form of
    the URL and redirects it (307) to the slash-less form.

    :param controller: The Controller supplying the DAO dependency and per-action dependencies
    :param router: The APIRouter on which to register the endpoint
    :param resource: The Resource type being modified
    """
    pk = [p.name for p in resource.get_pk()]
    path = _construct_path(pk)

    @router.patch(f'{path}/', include_in_schema=False)
    async def redirect(request: Request):
        # `path` is only the route template (e.g. '/{id}'), so the target is
        # derived from the URL that was actually requested.
        target = request.url.replace(path=request.url.path.rstrip('/'))
        return RedirectResponse(url=str(target), status_code=307)

    @docstring_format(resource=resource.doc_name())
    def modify(model: resource.get_update_schema(),  # TODO - Remove PK from input schema
               daos: DAOFactory = controller.daos, **kwargs) -> DAOModel:
        """Modifies specific fields of a {resource} while leaving others unchanged"""
        dao = daos[resource]
        result = dao.get(*extract_values(kwargs, pk))
        result.set_values(**model.model_dump(exclude_unset=True))
        dao.commit(result)
        return result

    # Publish the PK names on the signature before registering,
    # so the route is built from the final signature.
    expose_path_params(modify, pk)
    router.patch(
        path,
        response_model=resource.get_detailed_output_schema(),
        dependencies=controller.dependencies_for(resource, Action.MODIFY))(modify)
173
+
174
+
175
def _register_delete_endpoint(controller, router: APIRouter, resource: type[Resource]):
    """Registers the DELETE endpoint that removes a resource identified by its primary key (204).

    A second route, hidden from the schema, catches the trailing-slash form of
    the URL and redirects it (307) to the slash-less form.

    :param controller: The Controller supplying the DAO dependency and per-action dependencies
    :param router: The APIRouter on which to register the endpoint
    :param resource: The Resource type being deleted
    """
    pk = [p.name for p in resource.get_pk()]
    path = _construct_path(pk)

    @router.delete(f'{path}/', include_in_schema=False)
    async def redirect(request: Request):
        # `path` is only the route template (e.g. '/{id}'), so the target is
        # derived from the URL that was actually requested.
        target = request.url.replace(path=request.url.path.rstrip('/'))
        return RedirectResponse(url=str(target), status_code=307)

    @docstring_format(resource=resource.doc_name())
    def delete(daos: DAOFactory = controller.daos, **kwargs) -> None:
        """Deletes a {resource}"""
        daos[resource].remove(*extract_values(kwargs, pk))

    # Publish the PK names on the signature before registering,
    # so the route is built from the final signature.
    expose_path_params(delete, pk)
    router.delete(
        path,
        status_code=204,
        dependencies=controller.dependencies_for(resource, Action.DELETE))(delete)
193
+
194
+
195
def _register_rename_endpoint(controller, router: APIRouter, resource: type[Resource]):
    """Registers the POST `.../rename` endpoint that changes the primary key of a resource.

    The new key is read from the request body (`new_pk`): a single value for a
    one-column primary key, otherwise an object following the PK schema.

    A second route, hidden from the schema, catches the trailing-slash form of
    the URL and redirects it (307) to the slash-less form.

    :param controller: The Controller supplying the DAO dependency and per-action dependencies
    :param router: The APIRouter on which to register the endpoint
    :param resource: The Resource type being renamed
    """
    pk = [p.name for p in resource.get_pk()]
    path = f'{_construct_path(pk)}/rename'

    @router.post(f'{path}/', include_in_schema=False)
    async def redirect(request: Request):
        # `path` is only the route template (e.g. '/{id}/rename'), so the target
        # is derived from the URL that was actually requested.
        target = request.url.replace(path=request.url.path.rstrip('/'))
        return RedirectResponse(url=str(target), status_code=307)

    @docstring_format(resource=resource.doc_name())
    def rename(daos: DAOFactory = controller.daos, **kwargs) -> DAOModel:
        """Renames a {resource}"""
        dao = daos[resource]
        current = dao.get(*extract_values(kwargs, pk))

        if len(pk) == 1:
            new_pk_values = [kwargs['new_pk']]
        else:
            new_values = kwargs.get('new_pk', {})
            # NOTE(review): the body is annotated with the PK schema, so it presumably
            # arrives as a model instance without a dict-style .get() - confirm.
            if hasattr(new_values, 'model_dump'):
                new_values = new_values.model_dump()
            # Columns missing from the body keep the value taken from the path
            new_pk_values = [new_values.get(field, kwargs[field]) for field in pk]

        dao.rename(current, dao.get(*new_pk_values))
        return current

    # Publish the PK names and the `new_pk` body parameter on the signature
    # before registering, so the route is built from the final signature.
    expose_path_params(rename, pk)
    sig = inspect.signature(rename)
    body_param = inspect.Parameter(
        'new_pk',
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        default=Body(),
        annotation=resource.get_pk_schema() if len(pk) > 1 else get_field_type(next(iter(resource.get_pk())))
    )
    rename.__signature__ = sig.replace(parameters=[*sig.parameters.values(), body_param])

    router.post(
        path,
        response_model=resource.get_detailed_output_schema(),
        dependencies=controller.dependencies_for(resource, Action.RENAME))(rename)
234
+
235
+
236
class Controller:
    """Generates REST endpoints for Resources and attaches them to a FastAPI app.

    Each registered Resource gets its own APIRouter, one endpoint per Action,
    with data access provided through a DAOFactory bound to the configured engine.
    """
    def __init__(self,
                 prefix: Optional[str] = '',
                 app: Optional[FastAPI] = None,
                 engine: Optional[Engine] = None) -> None:
        """
        :param prefix: Path prefix placed in front of every resource path
        :param app: The FastAPI app to attach to; only used if `engine` is also given
        :param engine: The SQLAlchemy engine to bind sessions to; only used if `app` is also given
        """
        self.prefix = prefix
        self.app = None
        self.engine = None
        self.models = None
        if app is not None and engine is not None:
            self.init_app(app, engine)
        self.daos = Depends(self.dao_generator)

    def init_app(self, app: FastAPI, engine: Engine) -> None:
        """Binds this controller to an app and engine and installs the error handlers.

        :param app: The FastAPI app that routers and exception handlers are added to
        :param engine: The SQLAlchemy engine used to create sessions
        """
        self.app = app
        self.engine = engine

        @app.exception_handler(InvalidInput)
        async def invalid_input_handler(request: Request, exc: InvalidInput):
            return JSONResponse(status_code=400, content={"detail": exc.detail})

        @app.exception_handler(NotFound)
        async def not_found_handler(request: Request, exc: NotFound):
            return JSONResponse(status_code=404, content={"detail": exc.detail})

        @app.exception_handler(Conflict)
        async def conflict_handler(request: Request, exc: Conflict):
            return JSONResponse(status_code=409, content={"detail": exc.detail})

    def dao_generator(self) -> DAOFactory:
        """Yields a DAOFactory.

        This is a generator function: the factory is entered as a context
        manager and yielded from inside the `with` block.
        """
        with DAOFactory(sessionmaker(bind=self.engine)) as daos:
            yield daos

    @contextmanager
    def dao_context(self):
        """Provides the DAOFactory of `dao_generator` for use in a `with` statement."""
        yield from self.dao_generator()

    def dependencies_for(self, resource: type[Resource], action: Action) -> list[Depends]:
        """Returns the dependencies to attach to the endpoint of the given resource and action.

        This implementation returns no dependencies for any combination.

        :param resource: The Resource type the endpoint belongs to
        :param action: The Action the endpoint implements
        :return: The list of dependencies, which is always empty here
        """
        return []

    def register_resource(self,
                          resource: type[Resource],
                          skip: Optional[set[Action]] = frozenset(),
                          additional_endpoints: Optional[Callable] = None) -> None:
        """Builds a router holding the endpoints of a resource and includes it in the app.

        :param resource: The Resource type to expose
        :param skip: Actions for which no endpoint should be registered (None means skip nothing)
        :param additional_endpoints: Optional callback, called with the router and this controller,
            to add endpoints before the router is included in the app
        """
        api_router = APIRouter(
            # The annotation permits prefix=None, which is treated as no prefix
            prefix=(self.prefix or '') + resource.get_resource_path(),
            tags=[resource.doc_name()])
        self._register_resource_endpoints(api_router, resource, skip)
        if additional_endpoints:
            additional_endpoints(api_router, self)
        self.app.include_router(api_router)

    def _register_resource_endpoints(self,
                                     router: APIRouter,
                                     resource: type[Resource],
                                     skip: Optional[set[Action]] = frozenset()) -> None:
        """Registers one endpoint per Action, in Action declaration order, except the skipped ones.

        :param router: The APIRouter on which to register the endpoints
        :param resource: The Resource type to expose
        :param skip: Actions for which no endpoint should be registered (None means skip nothing)
        """
        # None was the default in earlier releases and is still accepted
        skipped = frozenset() if skip is None else skip
        for action in Action:
            if action not in skipped:
                action.register_endpoint(self, router, resource)

    # TODO: finish implementing merge endpoint
    def _register_merge_endpoint(self,
                                 router: APIRouter,
                                 resource: type[Resource],
                                 path: str,
                                 pk: list[str]):
        """Registers the unfinished POST `{path}/merge` endpoint.

        The handler currently only looks up the source row and returns nothing.

        :param router: The APIRouter on which to register the endpoint
        :param resource: The Resource type being merged
        :param path: The path template identifying a single resource
        :param pk: The names of the primary key columns; only the first is used
        """
        @router.post(
            f'{path}/merge',
            response_model=resource.get_detailed_output_schema(),
            dependencies=self.dependencies_for(resource, Action.RENAME))
        @docstring_format(resource=resource.doc_name())
        def merge(pk0=Path(alias=pk[0]),
                  target_id=Body(alias=pk[0]),
                  daos: DAOFactory = self.daos) -> DAOModel:
            source = daos[resource].get(pk0)
            # for model in all_models(self.engine):
            #     for column in model.get_references_of(resource):
            #         daos[type[model]].find(column.name=)
            #     if fk.column.table.name == target_table_name and fk.column.name in target_column_values:
            #         print(f"Foreign key in table {table.name} references the column '{fk.column.name}' in {target_table.name}")
            #         # Retrieve rows in this table that reference the target row
            #         conn = engine.connect()
            #         condition = (table.c[fk.parent.name] == target_column_values[fk.column.name])
            #         result = conn.execute(table.select().where(condition))
            #         referencing_rows.extend(result.fetchall())
            #         conn.close()
            #
            # return referencing_rows
@@ -7,8 +7,7 @@ from sqlmodel import SQLModel
7
7
 
8
8
 
9
9
  def either(preferred: Any, default: type[SQLModel]) -> type[SQLModel]:
10
- """
11
- Returns the preferred type if present, otherwise the default type.
10
+ """Returns the preferred type if present, otherwise the default type.
12
11
 
13
12
  :param preferred: The type to return if not None
14
13
  :param default: The type to return if the preferred is not a model
@@ -17,6 +16,15 @@ def either(preferred: Any, default: type[SQLModel]) -> type[SQLModel]:
17
16
  return preferred if isclass(preferred) and issubclass(preferred, SQLModel) else default
18
17
 
19
18
 
19
def get_field_type(field) -> type:
    """Returns the equivalent type for the given field.

    When the column type wraps another type through an `impl` attribute,
    the wrapped type is the one consulted.

    :param field: The Column of an SQLModel
    :return: the Python type used to represent the DB Column value
    """
    column_type = field.type
    try:
        column_type = column_type.impl
    except AttributeError:
        # No wrapped implementation: the column type itself is used
        pass
    return column_type.python_type
26
+
27
+
20
28
  class Resource(DAOModel):
21
29
  __abstract__ = True
22
30
  _default_schema: type[SQLModel]
@@ -27,8 +35,8 @@ class Resource(DAOModel):
27
35
 
28
36
  @classmethod
29
37
  def get_resource_path(cls) -> str:
30
- """
31
- Returns the URI path to this resource as defined by the 'path' class variable.
38
+ """Returns the URI path to this resource as defined by the 'path' class variable.
39
+
32
40
  A default value of `/api/{resource_name} is returned unless overridden.
33
41
 
34
42
  :return: The URI path to be used for this Resource
@@ -48,11 +56,6 @@ class Resource(DAOModel):
48
56
  if hasattr(field, 'class_') and field.class_ is not cls and hasattr(field, 'table') and field.table.name:
49
57
  field_name = f'{field.table.name}_{field_name}'
50
58
  return field_name
51
- def get_field_type(field) -> type:
52
- field_type = field.type
53
- if hasattr(field_type, 'impl'):
54
- field_type = field_type.impl
55
- return field_type.python_type
56
59
  fields = [field[-1] if isinstance(field, tuple) else field for field in cls.get_searchable_properties()]
57
60
  field_types = {
58
61
  get_field_name(field): (get_field_type(field), None) for field in fields
@@ -62,6 +65,14 @@ class Resource(DAOModel):
62
65
  **field_types
63
66
  )
64
67
 
68
@classmethod
def get_pk_schema(cls) -> type[SQLModel]:
    """Returns an SQLModel representing the primary key fields"""
    # Every primary key column becomes a required field (`...`) of its Python type
    pk_fields = {}
    for column in cls.get_pk():
        pk_fields[column.name] = (get_field_type(column), ...)
    schema_name = f'{cls.doc_name()}PKSchema'
    return create_model(schema_name, **pk_fields)
75
+
65
76
  @classmethod
66
77
  def get_base(cls) -> type[SQLModel]:
67
78
  return cls
@@ -0,0 +1,77 @@
1
+ import inspect
2
+ from typing import Callable, get_type_hints, Optional
3
+ from warnings import deprecated
4
+
5
+ from sqlmodel import SQLModel
6
+
7
+
8
class InvalidInput(Exception):
    """Indicates that the user provided bad input.

    :param detail: A message describing what is wrong with the input
    """
    def __init__(self, detail: str):
        # Hand the message to Exception as well so str(exc) and tracebacks are not blank
        super().__init__(detail)
        self.detail = detail
12
+
13
+
14
def docstring_format(**kwargs):
    """A decorator that formats the docstring of a function with specified values.

    The docstring is treated as a `str.format` template. A function without a
    docstring (for example when docstrings are stripped by `python -OO`) is
    returned unchanged.

    :param kwargs: The values to inject into the docstring
    :return: A decorator that returns the very function it is given
    """
    def decorator(func: Callable):
        if func.__doc__ is not None:
            func.__doc__ = func.__doc__.format(**kwargs)
        return func
    return decorator
24
+
25
+
26
# NOTE: `warnings.deprecated` is only available from Python 3.13 onwards
@deprecated("No usages and no test coverage")
def all_optional(superclass: type[SQLModel]):
    """Creates a new SQLModel for the specified class but having no required fields.

    :param superclass: The SQLModel of which to make all fields Optional
    :return: The newly wrapped Model
    """
    # The annotations are rewritten on a subclass created just for this purpose
    class OptionalModel(superclass):
        pass
    for field, field_type in get_type_hints(OptionalModel).items():
        # NOTE(review): presumably meant to skip hints that are already Optional -
        # confirm that comparing against type(Optional) actually detects them
        if not isinstance(field_type, type(Optional)):
            OptionalModel.__annotations__[field] = Optional[field_type]
    return OptionalModel
39
+
40
+
41
def expose_path_params(func: Callable, field_names: list[str]) -> Callable:
    """Converts implicit path parameters from **kwargs to explicit parameters.

    Takes a function using **kwargs and overrides its reported signature
    (`__signature__`) so that the given field names appear as explicit
    positional-or-keyword parameters, without annotation or default, making them
    visible to FastAPI's routing system. All existing parameters (except
    **kwargs) are preserved in their original order, after the exposed ones.

    Only the reported signature changes: the function still receives the
    exposed values through its **kwargs. A field name listed more than once, or
    already declared by the function, appears only once in the signature.

    :param func: The function to modify
    :param field_names: List of field names to expose as path parameters
    :return: The modified function with an updated signature
    """
    sig = inspect.signature(func)
    declared = [
        param for param in sig.parameters.values()
        if param.kind != inspect.Parameter.VAR_KEYWORD
    ]
    declared_names = {param.name for param in declared}

    # dict.fromkeys drops repeated names while keeping their first-seen order
    exposed = [
        inspect.Parameter(field_name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for field_name in dict.fromkeys(field_names)
        if field_name not in declared_names
    ]

    func.__signature__ = sig.replace(parameters=exposed + declared)
    return func
68
+
69
+
70
def extract_values(kwargs: dict, field_names: list[str]) -> list:
    """Looks up each of the given field names in kwargs, keeping the order of the names.

    :param kwargs: Dictionary containing the function arguments
    :param field_names: List of field names in the desired order
    :return: List of values in the same order as field_names
    :raises KeyError: if a field name is not present in kwargs
    """
    return list(map(kwargs.__getitem__, field_names))
@@ -40,7 +40,7 @@ classifiers = [
40
40
  "Topic :: Software Development :: Libraries",
41
41
  "Typing :: Typed",
42
42
  ]
43
- version = "0.2.0b0"
43
+ version = "0.3.0b0"
44
44
 
45
45
  [project.license]
46
46
  text = "MIT"
@@ -0,0 +1,184 @@
1
+ import inspect
2
+ from unittest.mock import Mock
3
+
4
+ from fast_controller import docstring_format
5
+ from fast_controller.util import expose_path_params, extract_values
6
+
7
+
8
@docstring_format(key="value")
def test_docstring_format():
    """{key}"""
    # The decorator ran when this function was defined, so the placeholder above is already filled in
    assert inspect.getdoc(test_docstring_format) == "value"
12
+
13
+
14
@docstring_format(key="value")
def test_docstring_format__empty():
    """"""
    # An empty docstring has no placeholders, so formatting leaves it empty
    assert inspect.getdoc(test_docstring_format__empty) == ""
18
+
19
+
20
@docstring_format(key="value")
def test_docstring_format__multiple_values():
    """{key}1, {key}2"""
    # The same key is substituted at every place it occurs
    assert inspect.getdoc(test_docstring_format__multiple_values) == "value1, value2"
24
+
25
+
26
@docstring_format(key1="value1", key2="value2")
def test_docstring_format__multiple_keys():
    """{key1}, {key2}"""
    # Each placeholder is filled from the keyword argument of the same name
    assert inspect.getdoc(test_docstring_format__multiple_keys) == "value1, value2"
30
+
31
+
32
+ # TODO: Convert to labeled_tests
33
def test_expose_path_params():
    """A single field name becomes the only parameter of the reported signature."""
    def test_func(**kwargs):
        return kwargs

    exposed = expose_path_params(test_func, ["id"])

    assert list(inspect.signature(exposed).parameters) == ["id"]
41
+
42
+
43
def test_expose_path_params__multiple_params():
    """Several field names are exposed in the order they were given."""
    def test_func(**kwargs):
        pass

    names = ["field1", "field2", "field3"]
    exposed = expose_path_params(test_func, ["field1", "field2", "field3"])

    assert list(inspect.signature(exposed).parameters) == names
51
+
52
+
53
def test_expose_path_params__no_params():
    """With no field names, dropping **kwargs leaves an empty signature."""
    def test_func(**kwargs):
        pass

    exposed = expose_path_params(test_func, [])

    assert list(inspect.signature(exposed).parameters) == []
61
+
62
+
63
def test_expose_path_params__existing_params():
    """Existing parameters follow the exposed ones and keep their annotation and default."""
    def test_func(model: str = "default_model", daos: Mock = Mock(), **kwargs):
        pass

    exposed = expose_path_params(test_func, ["user_id", "role_id"])
    params = inspect.signature(exposed).parameters

    assert list(params) == ["user_id", "role_id", "model", "daos"]

    model_param = params["model"]
    assert model_param.annotation == str
    assert model_param.default == "default_model"

    daos_param = params["daos"]
    assert daos_param.annotation == Mock
    assert isinstance(daos_param.default, Mock)
76
+
77
+
78
def test_expose_path_params__original_preserved():
    """The same function object is returned and it still accepts the values via **kwargs."""
    def test_func(**kwargs):
        return kwargs

    assert expose_path_params(test_func, ["field1"]) is test_func

    received = test_func(field1="test_value")
    assert received["field1"] == "test_value"
88
+
89
+
90
def test_extract_values():
    """All values are returned when every key is requested in its original order."""
    source = {"field1": "value1", "field2": "value2", "field3": "value3"}

    extracted = extract_values(source, ["field1", "field2", "field3"])

    assert extracted == ["value1", "value2", "value3"]
100
+
101
+
102
def test_extract_values___order():
    """The result follows the order of the requested names, not of the dictionary."""
    source = {"field1": "value1", "field2": "value2", "field3": "value3"}

    extracted = extract_values(source, ["field3", "field1", "field2"])

    assert extracted == ["value3", "value1", "value2"]
112
+
113
+
114
def test_extract_values__single_field():
    """A single requested name yields a one-element list."""
    assert extract_values({"id": "test_id"}, ["id"]) == ["test_id"]
120
+
121
+
122
def test_extract_values__matching_values():
    """Equal values under different keys are each returned."""
    source = {"field1": "value", "field2": "value"}

    extracted = extract_values(source, ["field1", "field2"])

    assert extracted == ["value", "value"]
131
+
132
+
133
def test_extract_values__duplicated_field():
    """A name requested several times yields its value once per request."""
    source = {"field1": "value1", "field2": "value2"}

    extracted = extract_values(source, ["field1", "field1", "field2", "field1"])

    assert extracted == ["value1", "value1", "value2", "value1"]
142
+
143
+
144
def test_extract_values__partial_extract():
    """Keys that were not requested are ignored."""
    source = {
        "field1": "value1",
        "field2": "value2",
        "extra1": "ignored",
        "extra2": "ignored",
    }

    extracted = extract_values(source, ["field1", "field2"])

    assert extracted == ["value1", "value2"]
155
+
156
+
157
def test_extract_values__no_fields():
    """Requesting no names yields an empty list even when values are available."""
    assert extract_values({"field": "value"}, []) == []
163
+
164
+
165
def test_extract_values__empty():
    """Empty input and no requested names yield an empty list."""
    assert extract_values({}, []) == []
171
+
172
+
173
def test_extract_values__different_types():
    """Values of any type, including None, are returned as they are."""
    source = {
        "string_field": "string_value",
        "int_field": 123,
        "bool_field": True,
        "list_field": [1, 2, 3],
        "none_field": None
    }
    requested = ["string_field", "int_field", "bool_field", "list_field", "none_field"]

    extracted = extract_values(source, requested)

    assert extracted == ["string_value", 123, True, [1, 2, 3], None]
@@ -1,272 +0,0 @@
1
- from contextlib import contextmanager
2
- from enum import Enum, auto
3
- from typing import Optional, Callable
4
-
5
- from daomodel import DAOModel
6
- from daomodel.dao import NotFound
7
- from daomodel.db import DAOFactory
8
- from daomodel.transaction import Conflict
9
- from fastapi import FastAPI, APIRouter, Request, Response, Depends, Path, Body, Query, Header
10
- from fastapi.responses import JSONResponse
11
- from sqlalchemy import Engine
12
- from sqlalchemy.orm import sessionmaker
13
- from sqlmodel import SQLModel
14
-
15
- from fast_controller.resource import Resource
16
- from fast_controller.util import docstring_format, InvalidInput
17
-
18
-
19
- class Action(Enum):
20
- VIEW = auto()
21
- SEARCH = auto()
22
- CREATE = auto()
23
- UPSERT = auto()
24
- MODIFY = auto()
25
- RENAME = auto()
26
- DELETE = auto()
27
-
28
-
29
- class Controller:
30
- def __init__(self,
31
- prefix: Optional[str] = '',
32
- app: Optional[FastAPI] = None,
33
- engine: Optional[Engine] = None) -> None:
34
- self.prefix = prefix
35
- self.app = None
36
- self.engine = None
37
- self.models = None
38
- if app is not None and engine is not None:
39
- self.init_app(app, engine)
40
- self.daos = Depends(self.dao_generator)
41
-
42
- def init_app(self, app: FastAPI, engine: Engine) -> None:
43
- self.app = app
44
- self.engine = engine
45
-
46
- @app.exception_handler(InvalidInput)
47
- async def not_found_handler(request: Request, exc: InvalidInput):
48
- return JSONResponse(status_code=400, content={"detail": exc.detail})
49
-
50
- @app.exception_handler(NotFound)
51
- async def not_found_handler(request: Request, exc: NotFound):
52
- return JSONResponse(status_code=404, content={"detail": exc.detail})
53
-
54
- @app.exception_handler(Conflict)
55
- async def not_found_handler(request: Request, exc: Conflict):
56
- return JSONResponse(status_code=409, content={"detail": exc.detail})
57
-
58
- def dao_generator(self) -> DAOFactory:
59
- """Yields a DAOFactory."""
60
- with DAOFactory(sessionmaker(bind=self.engine)) as daos:
61
- yield daos
62
-
63
- @contextmanager
64
- def dao_context(self):
65
- yield from self.dao_generator()
66
-
67
- def dependencies_for(self, resource: type[Resource], action: Action) -> list[Depends]:
68
- return []
69
-
70
- def register_resource(self,
71
- resource: type[Resource],
72
- skip: Optional[set[Action]] = None,
73
- additional_endpoints: Optional[Callable] = None) -> None:
74
- api_router = APIRouter(
75
- prefix=self.prefix + resource.get_resource_path(),
76
- tags=[resource.doc_name()])
77
- self._register_resource_endpoints(api_router, resource, skip)
78
- if additional_endpoints:
79
- additional_endpoints(api_router, self)
80
- self.app.include_router(api_router)
81
-
82
- def _register_resource_endpoints(self,
83
- router: APIRouter,
84
- resource: type[Resource],
85
- skip: Optional[set[Action]] = None) -> None:
86
- if skip is None:
87
- skip = set()
88
- if Action.SEARCH not in skip:
89
- self._register_search_endpoint(router, resource)
90
- if Action.CREATE not in skip:
91
- self._register_create_endpoint(router, resource)
92
- if Action.UPSERT not in skip:
93
- self._register_update_endpoint(router, resource)
94
-
95
- pk = [p.name for p in resource.get_pk()]
96
- path = "/".join([""] + ["{" + p + "}" for p in pk])
97
-
98
- # Caveat: Only up to 2 columns are supported within a primary key.
99
- # This allows us to avoid resorting to exec() while **kwargs is unsupported for Path variables
100
- if len(pk) == 1:
101
- if Action.VIEW not in skip:
102
- self._register_view_endpoint(router, resource, path, pk)
103
-
104
- # Caveat: Rename action is only supported for resources with a single column primary key
105
- if Action.RENAME not in skip:
106
- self._register_rename_endpoint(router, resource, path, pk)
107
-
108
- # Caveat: Modify action is only supported for resources with a single column primary key
109
- # Use Upsert instead for multi-column PK resources
110
- if Action.MODIFY not in skip:
111
- self._register_modify_endpoint(router, resource, path, pk)
112
-
113
- # Caveat: Delete action is only supported for resources with a single column primary key
114
- if Action.DELETE not in skip:
115
- self._register_delete_endpoint(router, resource, path, pk)
116
- elif len(pk) == 2:
117
- if Action.VIEW not in skip:
118
- self._register_view_endpoint_dual_pk(router, resource, path, pk)
119
-
120
- def _register_search_endpoint(self, router: APIRouter, resource: type[Resource]):
121
- @router.get(
122
- "/",
123
- response_model=list[resource.get_output_schema()],
124
- dependencies=self.dependencies_for(resource, Action.SEARCH))
125
- @docstring_format(resource=resource.doc_name())
126
- def search(response: Response,
127
- filters: resource.get_search_schema() = Query(),
128
- x_page: Optional[int] = Header(default=None, gt=0),
129
- x_per_page: Optional[int] = Header(default=None, gt=0),
130
- daos: DAOFactory = self.daos) -> list[DAOModel]:
131
- """Searches for {resource} by criteria"""
132
- results = daos[resource].find(x_page, x_per_page, **filters.model_dump(exclude_unset=True))
133
- response.headers["x-total-count"] = str(results.total)
134
- response.headers["x-page"] = str(results.page)
135
- response.headers["x-per-page"] = str(results.per_page)
136
- return results
137
-
138
- def _register_create_endpoint(self, router: APIRouter, resource: type[Resource]):
139
- @router.post(
140
- "/",
141
- response_model=resource.get_detailed_output_schema(),
142
- status_code=201,
143
- dependencies=self.dependencies_for(resource, Action.CREATE))
144
- @docstring_format(resource=resource.doc_name())
145
- def create(model: resource.get_input_schema(),
146
- daos: DAOFactory = self.daos) -> DAOModel:
147
- """Creates a new {resource}"""
148
- return daos[resource].create_with(**model.model_dump(exclude_unset=True))
149
-
150
- def _register_update_endpoint(self, router: APIRouter, resource: type[Resource]):
151
- @router.put(
152
- "/",
153
- response_model=resource.get_detailed_output_schema(),
154
- dependencies=self.dependencies_for(resource, Action.UPSERT))
155
- @docstring_format(resource=resource.doc_name())
156
- def upsert(model: resource.get_input_schema(),
157
- daos: DAOFactory = self.daos) -> SQLModel:
158
- """Creates/modifies a {resource}"""
159
- daos[resource].upsert(model)
160
- return model
161
-
162
- def _register_view_endpoint(self,
163
- router: APIRouter,
164
- resource: type[Resource],
165
- path: str,
166
- pk: list[str]):
167
- @router.get(
168
- path,
169
- response_model=resource.get_detailed_output_schema(),
170
- dependencies=self.dependencies_for(resource, Action.VIEW))
171
- @docstring_format(resource=resource.doc_name())
172
- def view(pk0=Path(alias=pk[0]),
173
- daos: DAOFactory = self.daos) -> DAOModel:
174
- """Retrieves a detailed view of a {resource}"""
175
- return daos[resource].get(pk0)
176
-
177
- def _register_view_endpoint_dual_pk(self,
178
- router: APIRouter,
179
- resource: type[Resource],
180
- path: str,
181
- pk: list[str]):
182
- @router.get(
183
- path,
184
- response_model=resource.get_detailed_output_schema(),
185
- dependencies=self.dependencies_for(resource, Action.VIEW))
186
- @docstring_format(resource=resource.doc_name())
187
- def view(pk0=Path(alias=pk[0]),
188
- pk1=Path(alias=pk[1]),
189
- daos: DAOFactory = self.daos) -> DAOModel:
190
- """Retrieves a detailed view of a {resource}"""
191
- return daos[resource].get(pk0, pk1)
192
-
193
- def _register_rename_endpoint(self,
194
- router: APIRouter,
195
- resource: type[Resource],
196
- path: str,
197
- pk: list[str]):
198
- @router.post(
199
- f'{path}/rename',
200
- response_model=resource.get_detailed_output_schema(),
201
- dependencies=self.dependencies_for(resource, Action.RENAME))
202
- @docstring_format(resource=resource.doc_name())
203
- def rename(pk0=Path(alias=pk[0]),
204
- new_id=Body(alias=pk[0]),
205
- daos: DAOFactory = self.daos) -> DAOModel:
206
- """Renames a {resource}"""
207
- dao = daos[resource]
208
- current = dao.get(pk0)
209
- dao.rename(current, dao.get(new_id))
210
- return current
211
-
212
- def _register_merge_endpoint(self,
213
- router: APIRouter,
214
- resource: type[Resource],
215
- path: str,
216
- pk: list[str]):
217
- @router.post(
218
- f'{path}/merge',
219
- response_model=resource.get_detailed_output_schema(),
220
- dependencies=self.dependencies_for(resource, Action.RENAME))
221
- @docstring_format(resource=resource.doc_name())
222
- def merge(pk0=Path(alias=pk[0]),
223
- target_id=Body(alias=pk[0]),
224
- daos: DAOFactory = self.daos) -> DAOModel:
225
- source = daos[resource].get(pk0)
226
- # for model in all_models(self.engine):
227
- # for column in model.get_references_of(resource):
228
- #daos[type[model]].find(column.name=)
229
- # if fk.column.table.name == target_table_name and fk.column.name in target_column_values:
230
- # print(f"Foreign key in table {table.name} references the column '{fk.column.name}' in {target_table.name}")
231
- # # Retrieve rows in this table that reference the target row
232
- # conn = engine.connect()
233
- # condition = (table.c[fk.parent.name] == target_column_values[fk.column.name])
234
- # result = conn.execute(table.select().where(condition))
235
- # referencing_rows.extend(result.fetchall())
236
- # conn.close()
237
- #
238
- # return referencing_rows
239
-
240
- def _register_modify_endpoint(self,
241
- router: APIRouter,
242
- resource: type[Resource],
243
- path: str,
244
- pk: list[str]):
245
- @router.put(
246
- path,
247
- response_model=resource.get_detailed_output_schema(),
248
- dependencies=self.dependencies_for(resource, Action.MODIFY))
249
- @docstring_format(resource=resource.doc_name())
250
- def update(model: resource.get_update_schema(), # TODO - Remove PK from input schema
251
- pk0=Path(alias=pk[0]),
252
- daos: DAOFactory = self.daos) -> DAOModel:
253
- """Creates/modifies a {resource}"""
254
- result = daos[resource].get(pk0)
255
- result.set_values(**model.model_dump(exclude_unset=True))
256
- daos[resource].commit(result)
257
- return result
258
-
259
- def _register_delete_endpoint(self,
260
- router: APIRouter,
261
- resource: type[Resource],
262
- path: str,
263
- pk: list[str]):
264
- @router.delete(
265
- path,
266
- status_code=204,
267
- dependencies=self.dependencies_for(resource, Action.DELETE))
268
- @docstring_format(resource=resource.doc_name())
269
- def delete(pk0=Path(alias=pk[0]),
270
- daos: DAOFactory = self.daos) -> None:
271
- """Deletes a {resource}"""
272
- daos[resource].remove(pk0)
@@ -1,36 +0,0 @@
1
- from typing import Callable, get_type_hints, Optional
2
-
3
- from sqlmodel import SQLModel
4
-
5
-
6
class InvalidInput(Exception):
    """Indicates that the user provided bad input.

    :param detail: A description of what was wrong with the input
    """
    def __init__(self, detail: str):
        # Pass the message on to Exception so that args and str() are populated
        # even when the exception is built with a keyword: InvalidInput(detail=...)
        super().__init__(detail)
        self.detail = detail
10
-
11
-
12
def docstring_format(**kwargs):
    """
    A decorator that formats the docstring of a function with specified values.

    A function that has no docstring (its __doc__ is None, which is also the
    case when Python runs with -OO) is returned untouched instead of raising.

    :param kwargs: The values to inject into the docstring
    :return: A decorator which returns the very function it was given
    """
    def decorator(func: Callable):
        # None has no .format(); there is nothing to substitute in that case
        if func.__doc__ is not None:
            func.__doc__ = func.__doc__.format(**kwargs)
        return func
    return decorator
22
-
23
-
24
# TODO: Determine best way to add test coverage
def all_optional(superclass: type[SQLModel]):
    """Builds a subclass of the given SQLModel whose annotations are all wrapped in Optional.

    :param superclass: The SQLModel of which to make all fields Optional
    :return: The newly created subclass
    """
    class OptionalModel(superclass):
        pass

    # NOTE(review): Optional[...] hints are not instances of type(Optional), so this
    # check presumably never skips an already-optional field - confirm the intent.
    special_form = type(Optional)
    resolved_hints = get_type_hints(OptionalModel)
    for name in resolved_hints:
        hint = resolved_hints[name]
        if isinstance(hint, special_form):
            continue
        # NOTE(review): the annotation is rewritten after the class has been created;
        # verify that the model actually picks up the change.
        OptionalModel.__annotations__[name] = Optional[hint]
    return OptionalModel
@@ -1,27 +0,0 @@
1
- import inspect
2
-
3
- from fast_controller import docstring_format
4
-
5
-
6
@docstring_format(key="value")
def test_docstring_format():
    """{key}"""
    # The decorator formats this test's own docstring, so inspect it directly
    formatted = inspect.getdoc(test_docstring_format)
    assert formatted == "value"
10
-
11
-
12
@docstring_format(key="value")
def test_docstring_format__empty():
    """"""
    # An empty docstring has no placeholders and must stay empty
    formatted = inspect.getdoc(test_docstring_format__empty)
    assert formatted == ""
16
-
17
-
18
@docstring_format(key="value")
def test_docstring_format__multiple_values():
    """{key}1, {key}2"""
    # The same key appears twice; each occurrence should be substituted
    formatted = inspect.getdoc(test_docstring_format__multiple_values)
    assert formatted == "value1, value2"
22
-
23
-
24
@docstring_format(key1="value1", key2="value2")
def test_docstring_format__multiple_keys():
    """{key1}, {key2}"""
    # Two distinct keys are each replaced by their own value
    formatted = inspect.getdoc(test_docstring_format__multiple_keys)
    assert formatted == "value1, value2"