sqlmodel-graphql 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,56 @@
1
+ """SQLModel GraphQL - GraphQL SDL generation and query optimization for SQLModel.
2
+
3
+ This package provides:
4
+ - Automatic GraphQL SDL generation from SQLModel classes
5
+ - @query/@mutation decorators for defining GraphQL operations
6
+ - QueryMeta extraction from GraphQL queries for query optimization
7
+ - SQLAlchemy query optimization via to_options()
8
+
9
+ Example:
10
+ ```python
11
+ from sqlmodel import SQLModel, Field, Relationship, select
12
+ from sqlmodel_graphql import query, mutation, SDLGenerator, QueryParser, QueryMeta
13
+
14
+ class User(SQLModel, table=True):
15
+ id: int = Field(primary_key=True)
16
+ name: str
17
+ posts: list["Post"] = Relationship(back_populates="author")
18
+
19
+ @query(name='users')
20
+ async def get_all(cls, query_meta: QueryMeta | None = None) -> list['User']:
21
+ stmt = select(cls)
22
+ if query_meta:
23
+ stmt = stmt.options(*query_meta.to_options(cls))
24
+ return await fetch_users(stmt)
25
+
26
+ # Generate SDL
27
+ generator = SDLGenerator([User, Post])
28
+ print(generator.generate())
29
+ ```
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ __version__ = "0.1.0"
35
+
36
+ from sqlmodel_graphql.decorator import mutation, query
37
+ from sqlmodel_graphql.handler import GraphQLHandler
38
+ from sqlmodel_graphql.query_parser import QueryParser
39
+ from sqlmodel_graphql.sdl_generator import SDLGenerator
40
+ from sqlmodel_graphql.types import FieldSelection, QueryMeta, RelationshipSelection
41
+
42
+ __all__ = [
43
+ # Version
44
+ "__version__",
45
+ # Decorators
46
+ "query",
47
+ "mutation",
48
+ # Core classes
49
+ "SDLGenerator",
50
+ "QueryParser",
51
+ "GraphQLHandler",
52
+ # Types
53
+ "QueryMeta",
54
+ "FieldSelection",
55
+ "RelationshipSelection",
56
+ ]
@@ -0,0 +1,145 @@
1
+ """Decorators for marking SQLModel methods as GraphQL queries and mutations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+ from typing import overload
7
+
8
+
9
@overload
def query(func: Callable) -> classmethod: ...


@overload
def query(
    *, name: str | None = None, description: str | None = None
) -> Callable[[Callable], classmethod]: ...


def query(
    name_or_func: Callable | str | None = None,
    *,
    name: str | None = None,
    description: str | None = None,
) -> classmethod | Callable[[Callable], classmethod]:
    """Mark a method as a GraphQL query.

    The decorated function is tagged with ``_graphql_query``,
    ``_graphql_query_name`` and ``_graphql_query_description`` attributes and
    wrapped in a ``classmethod``. Both ``@query`` and ``@query(...)`` forms
    are supported.

    Args:
        name_or_func: The function being decorated (bare ``@query`` form), a
            query name passed positionally, or None.
        name: GraphQL query name. Stored as None when omitted, leaving the
            consumer of the marker to derive a name from the method.
        description: Description text for the GraphQL schema.

    Returns:
        A ``classmethod`` when used bare, otherwise a decorator that returns
        a ``classmethod``.

    Example:
        ```python
        from sqlmodel import SQLModel
        from sqlmodel_graphql import query

        class User(SQLModel, table=True):
            id: int
            name: str

            @query(name='users', description='Get all users')
            async def get_all(cls, limit: int = 10) -> list['User']:
                return await fetch_users(limit)
        ```
    """

    def mark(func: Callable, query_name: str | None) -> classmethod:
        # A function that is already wrapped in ``classmethod`` is unwrapped
        # first: classmethod objects are not callable, and the markers must
        # live on the underlying function so that bound-method attribute
        # lookup can find them later.
        target = func.__func__ if isinstance(func, classmethod) else func
        target._graphql_query = True  # type: ignore[attr-defined]
        target._graphql_query_name = query_name  # type: ignore[attr-defined]
        target._graphql_query_description = description  # type: ignore[attr-defined]
        return classmethod(target)

    # Bare form: @query
    if isinstance(name_or_func, classmethod) or callable(name_or_func):
        return mark(name_or_func, name)

    # Parameterised form: @query(name='...', description='...').
    # A positional string is accepted as the name for convenience.
    query_name = name or name_or_func

    def decorator(func: Callable) -> classmethod:
        return mark(func, query_name)

    return decorator
76
+
77
+
78
@overload
def mutation(func: Callable) -> classmethod: ...


@overload
def mutation(
    *, name: str | None = None, description: str | None = None
) -> Callable[[Callable], classmethod]: ...


def mutation(
    name_or_func: Callable | str | None = None,
    *,
    name: str | None = None,
    description: str | None = None,
) -> classmethod | Callable[[Callable], classmethod]:
    """Mark a method as a GraphQL mutation.

    The decorated function is tagged with ``_graphql_mutation``,
    ``_graphql_mutation_name`` and ``_graphql_mutation_description``
    attributes and wrapped in a ``classmethod``. Both ``@mutation`` and
    ``@mutation(...)`` forms are supported.

    Args:
        name_or_func: The function being decorated (bare ``@mutation`` form),
            a mutation name passed positionally, or None.
        name: GraphQL mutation name. Stored as None when omitted, leaving the
            consumer of the marker to derive a name from the method.
        description: Description text for the GraphQL schema.

    Returns:
        A ``classmethod`` when used bare, otherwise a decorator that returns
        a ``classmethod``.

    Example:
        ```python
        from sqlmodel import SQLModel
        from sqlmodel_graphql import mutation

        class User(SQLModel, table=True):
            id: int
            name: str
            email: str

            @mutation(name='createUser', description='Create a new user')
            async def create(cls, name: str, email: str) -> 'User':
                return await create_user(name, email)
        ```
    """

    def mark(func: Callable, mutation_name: str | None) -> classmethod:
        # A function that is already wrapped in ``classmethod`` is unwrapped
        # first: classmethod objects are not callable, and the markers must
        # live on the underlying function so that bound-method attribute
        # lookup can find them later.
        target = func.__func__ if isinstance(func, classmethod) else func
        target._graphql_mutation = True  # type: ignore[attr-defined]
        target._graphql_mutation_name = mutation_name  # type: ignore[attr-defined]
        target._graphql_mutation_description = description  # type: ignore[attr-defined]
        return classmethod(target)

    # Bare form: @mutation
    if isinstance(name_or_func, classmethod) or callable(name_or_func):
        return mark(name_or_func, name)

    # Parameterised form: @mutation(name='...', description='...').
    # A positional string is accepted as the name for convenience.
    mutation_name = name or name_or_func

    def decorator(func: Callable) -> classmethod:
        return mark(func, mutation_name)

    return decorator
@@ -0,0 +1,398 @@
1
+ """GraphQL execution handler for SQLModel entities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import inspect
6
+ from collections.abc import Callable
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ from graphql import parse
10
+
11
+ from sqlmodel_graphql.introspection import IntrospectionGenerator
12
+ from sqlmodel_graphql.query_parser import QueryParser
13
+ from sqlmodel_graphql.sdl_generator import SDLGenerator
14
+
15
+ if TYPE_CHECKING:
16
+ from sqlmodel import SQLModel
17
+
18
+
19
def _serialize_value(
    value: Any,
    include: set[str] | dict[str, Any] | None = None,
) -> Any:
    """Serialize a value for a JSON response.

    Handles objects exposing ``model_dump()`` (e.g. SQLModel instances),
    lists, dicts, and basic types.

    Args:
        value: The value to serialize.
        include: Can be:
            - None: include all dumped fields and auto-discover public
              attributes holding models or lists (relationships).
            - set[str]: include only these field names (no nested selection).
            - dict[str, Any]: field selection tree where keys are field names
              and values are nested selection trees for relationships.

    Returns:
        A structure of dicts, lists and basic values. When relationships are
        auto-discovered and form a cycle (A -> B -> A), the model that is
        already being serialized higher up is emitted with its dumped fields
        only, so the recursion always terminates.
    """
    return _serialize_node(value, include, set())


def _serialize_node(
    value: Any,
    include: set[str] | dict[str, Any] | None,
    active: set[int],
) -> Any:
    """Dispatch on the kind of ``value``; ``active`` holds ids of models on the current path."""
    if value is None:
        return None

    # Model instances (anything exposing model_dump)
    if hasattr(value, "model_dump"):
        return _serialize_model(value, include, active)

    # Lists share the same selection for every item
    if isinstance(value, list):
        return [_serialize_node(item, include, active) for item in value]

    if isinstance(value, dict):
        if include:
            if isinstance(include, dict):
                return {
                    k: _serialize_node(v, include.get(k), active)
                    for k, v in value.items()
                    if k in include
                }
            return {
                k: _serialize_node(v, None, active)
                for k, v in value.items()
                if k in include
            }
        return {k: _serialize_node(v, None, active) for k, v in value.items()}

    # Basic types (int, str, bool, float)
    return value


def _serialize_model(
    value: Any,
    include: set[str] | dict[str, Any] | None,
    active: set[int],
) -> dict[str, Any]:
    """Serialize one model instance, expanding relationships per ``include``."""
    result = value.model_dump()

    marker = id(value)
    revisiting = marker in active
    if include is None and revisiting:
        # Cycle: this model is already being expanded further up the path.
        # Emitting only its dumped fields stops unbounded recursion on
        # bidirectional relationships.
        return result

    if not revisiting:
        active.add(marker)
    try:
        if include is None:
            # Include all fields, discovering relationships via public
            # attributes that hold a model or a list.
            for field_name in dir(value):
                if field_name.startswith("_") or field_name in result:
                    continue
                field_value = getattr(value, field_name, None)
                if field_value is not None and (
                    hasattr(field_value, "model_dump")
                    or isinstance(field_value, list)
                ):
                    result[field_name] = _serialize_node(field_value, None, active)
            return result

        # Explicit selection: keep selected dumped fields, then pull selected
        # attributes that model_dump() did not provide (relationships).
        selected = {k: v for k, v in result.items() if k in include}
        nested_trees = include if isinstance(include, dict) else {}
        for field_name in include:
            if field_name in selected or not hasattr(value, field_name):
                continue
            field_value = getattr(value, field_name)
            if field_value is not None:
                selected[field_name] = _serialize_node(
                    field_value, nested_trees.get(field_name), active
                )
        return selected
    finally:
        if not revisiting:
            active.discard(marker)
101
+
102
+
103
class GraphQLHandler:
    """Handles GraphQL query execution for SQLModel entities.

    This class scans entities for methods marked by the @query and @mutation
    decorators, keeps a field-name -> (entity, method) registry for each
    operation type, and executes GraphQL documents against those registries.

    Example:
        ```python
        from sqlmodel import SQLModel
        from sqlmodel_graphql import GraphQLHandler, query

        class User(SQLModel, table=True):
            id: int
            name: str

            @query(name='users')
            async def get_all(cls) -> list['User']:
                return await fetch_users()

        handler = GraphQLHandler(entities=[User])
        result = await handler.execute('{ users { id name } }')
        ```
    """

    # Matches the introspection meta-fields ``__schema`` / ``__type`` as whole
    # identifiers, so that ``__typename`` is NOT mistaken for introspection.
    _INTROSPECTION_FIELD = r"(?<![A-Za-z0-9_])__(?:schema|type)(?![A-Za-z0-9_])"

    def __init__(self, entities: list[type[SQLModel]]):
        """Initialize the GraphQL handler.

        Args:
            entities: List of SQLModel classes with @query/@mutation decorators.
        """
        self.entities = entities
        self._sdl_generator = SDLGenerator(entities)
        self._query_parser = QueryParser()

        # Method registries: field_name -> (entity, method)
        self._query_methods: dict[str, tuple[type[SQLModel], Callable[..., Any]]] = {}
        self._mutation_methods: dict[str, tuple[type[SQLModel], Callable[..., Any]]] = {}

        self._scan_methods()

        # The introspection generator shares the registries built above.
        self._introspection_generator = IntrospectionGenerator(
            entities=entities,
            query_methods=self._query_methods,
            mutation_methods=self._mutation_methods,
        )

    def _scan_methods(self) -> None:
        """Scan all entities for @query and @mutation methods."""
        for entity in self.entities:
            for name in dir(entity):
                try:
                    attr = getattr(entity, name)
                    if not callable(attr):
                        continue

                    # Markers live on the underlying function of a bound
                    # classmethod; fall back to the attribute itself.
                    func = attr.__func__ if hasattr(attr, "__func__") else attr

                    if hasattr(attr, "_graphql_query"):
                        gql_name = getattr(func, "_graphql_query_name", None)
                        if gql_name is None:
                            gql_name = func.__name__
                        self._query_methods[gql_name] = (entity, attr)

                    if hasattr(attr, "_graphql_mutation"):
                        gql_name = getattr(func, "_graphql_mutation_name", None)
                        if gql_name is None:
                            gql_name = func.__name__
                        self._mutation_methods[gql_name] = (entity, attr)

                except Exception:
                    # Best effort: attribute access on a model class may
                    # raise for reasons unrelated to GraphQL; such attributes
                    # are simply not registered.
                    continue

    def get_sdl(self) -> str:
        """Get the GraphQL Schema Definition Language string.

        Returns:
            SDL string representing the GraphQL schema.
        """
        return self._sdl_generator.generate()

    async def execute(
        self,
        query: str,
        variables: dict[str, Any] | None = None,
        operation_name: str | None = None,
    ) -> dict[str, Any]:
        """Execute a GraphQL query.

        Args:
            query: GraphQL query string.
            variables: Optional variables for the query.
            operation_name: Optional operation name for multi-operation
                documents. When given, only that operation is executed.

        Returns:
            Dictionary with 'data' and/or 'errors' keys. Any exception raised
            while handling the request is reported as an error entry rather
            than propagated.
        """
        try:
            if self._is_introspection_query(query):
                return await self._execute_introspection(query, variables)

            # Parse the query to get field selection info
            parsed = self._query_parser.parse(query)

            return await self._execute_query(query, variables, operation_name, parsed)

        except Exception as e:
            return {"errors": [{"message": str(e)}]}

    def _is_introspection_query(self, query: str) -> bool:
        """Check if the query selects the ``__schema`` or ``__type`` meta-field.

        Whole-identifier matching is used so that documents that merely
        request ``__typename`` are executed as regular queries.
        """
        import re

        return re.search(self._INTROSPECTION_FIELD, query) is not None

    async def _execute_introspection(
        self, query: str, variables: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Execute an introspection query.

        Args:
            query: GraphQL introspection query string.
            variables: Optional variables (accepted for symmetry; not
                forwarded to the introspection generator).

        Returns:
            Introspection result dictionary.
        """
        return self._introspection_generator.execute(query)

    @staticmethod
    def _accepts_keyword(method: Callable[..., Any], keyword: str) -> bool:
        """Return True if ``method`` can be called with ``keyword=...``."""
        try:
            parameters = inspect.signature(method).parameters
        except (TypeError, ValueError):
            # Signature unavailable: keep passing the keyword as before.
            return True

        declared = parameters.get(keyword)
        if declared is not None:
            return declared.kind is not inspect.Parameter.POSITIONAL_ONLY
        return any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()
        )

    async def _execute_query(
        self,
        query: str,
        variables: dict[str, Any] | None,
        operation_name: str | None,
        parsed_meta: dict[str, Any],
    ) -> dict[str, Any]:
        """Execute a regular GraphQL query.

        Args:
            query: GraphQL query string.
            variables: Optional variables.
            operation_name: Optional operation name; when given, only the
                operation with that name runs. When omitted, every operation
                in the document runs (historical behaviour).
            parsed_meta: Mapping of top-level field name to the metadata
                produced by the query parser.

        Returns:
            Query result dictionary with 'data' and/or 'errors'.
        """
        from graphql import FieldNode, OperationDefinitionNode

        document = parse(query)
        data: dict[str, Any] = {}
        errors: list[dict[str, Any]] = []

        operations = [
            definition
            for definition in document.definitions
            if isinstance(definition, OperationDefinitionNode)
        ]
        if operation_name is not None:
            operations = [
                op
                for op in operations
                if op.name is not None and op.name.value == operation_name
            ]
            if not operations:
                return {
                    "errors": [
                        {"message": f"Unknown operation named '{operation_name}'."}
                    ]
                }

        registries = {
            "query": self._query_methods,
            "mutation": self._mutation_methods,
        }

        for definition in operations:
            op_type = definition.operation.value  # 'query', 'mutation', ...
            registry = registries.get(op_type)
            if registry is None:
                # e.g. subscriptions: previously looked up among mutations.
                errors.append(
                    {"message": f"Operation type '{op_type}' is not supported"}
                )
                continue

            for selection in definition.selection_set.selections:
                if not isinstance(selection, FieldNode):
                    continue

                field_name = selection.name.value
                # The response is keyed by the alias when one is given.
                response_key = selection.alias.value if selection.alias else field_name

                try:
                    method_info = registry.get(field_name)
                    if method_info is None:
                        op_name = op_type.title()
                        msg = f"Cannot query field '{field_name}' on type '{op_name}'"
                        errors.append({"message": msg, "path": [response_key]})
                        continue

                    entity, method = method_info

                    args = self._build_arguments(selection, variables, method, entity)

                    # Inject query_meta only for queries, and only when the
                    # resolver can actually receive it.
                    if (
                        op_type == "query"
                        and field_name in parsed_meta
                        and self._accepts_keyword(method, "query_meta")
                    ):
                        args["query_meta"] = parsed_meta[field_name]

                    result = method(**args)
                    if inspect.isawaitable(result):
                        result = await result

                    # Serialize the result, only including requested fields
                    requested_fields = self._build_field_tree(selection)
                    data[response_key] = _serialize_value(
                        result, include=requested_fields
                    )

                except Exception as e:
                    errors.append({"message": str(e), "path": [response_key]})

        response: dict[str, Any] = {}
        if data:
            response["data"] = data
        if errors:
            response["errors"] = errors

        return response

    def _build_field_tree(self, selection: Any) -> dict[str, Any] | None:
        """Build a field selection tree from a GraphQL FieldNode.

        Args:
            selection: GraphQL node with an optional selection set.

        Returns:
            Dictionary where keys are field names and values are:
            - None for fields without a sub-selection (scalars)
            - {nested_field: ...} for fields with a sub-selection
            Selections inside inline fragments are merged into the enclosing
            level. Returns None if the node has no selection set.
        """
        if not selection.selection_set:
            return None

        field_tree: dict[str, Any] = {}
        for node in selection.selection_set.selections:
            has_children = bool(getattr(node, "selection_set", None))

            if not hasattr(node, "name"):
                # Unnamed node with children: an inline fragment. Its fields
                # belong to the current level.
                if has_children:
                    field_tree.update(self._build_field_tree(node) or {})
                continue

            field_tree[node.name.value] = (
                self._build_field_tree(node) if has_children else None
            )

        return field_tree

    def _build_arguments(
        self,
        selection: Any,
        variables: dict[str, Any] | None,
        method: Callable[..., Any],
        entity: type[SQLModel],
    ) -> dict[str, Any]:
        """Build method arguments from GraphQL field arguments.

        Literal values are converted from their AST form to plain Python
        values (ints and floats become numbers, ``null`` becomes None, lists
        and input objects become lists and dicts); variable references are
        resolved against ``variables``, with missing variables yielding None.

        Args:
            selection: GraphQL FieldNode with argument info.
            variables: GraphQL variables dict.
            method: The method to call.
            entity: The SQLModel entity class (currently unused).

        Returns:
            Dictionary of argument name to value, limited to names that the
            method declares. ``query_meta`` is reserved for the handler and
            never bound from client input.
        """
        args: dict[str, Any] = {}

        if not selection.arguments:
            return args

        from graphql import Undefined, value_from_ast_untyped

        resolved_variables = variables or {}
        # Signature of the bound method: the implicit ``cls`` is excluded, so
        # a client cannot collide with it.
        parameters = inspect.signature(method).parameters

        for arg in selection.arguments:
            arg_name = arg.name.value
            if arg_name == "query_meta" or arg_name not in parameters:
                continue

            value = value_from_ast_untyped(arg.value, resolved_variables)
            args[arg_name] = None if value is Undefined else value

        return args