surrealdb-orm 0.1.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- surreal_orm/__init__.py +72 -3
- surreal_orm/aggregations.py +164 -0
- surreal_orm/auth/__init__.py +15 -0
- surreal_orm/auth/access.py +167 -0
- surreal_orm/auth/mixins.py +302 -0
- surreal_orm/cli/__init__.py +15 -0
- surreal_orm/cli/commands.py +369 -0
- surreal_orm/connection_manager.py +58 -18
- surreal_orm/fields/__init__.py +36 -0
- surreal_orm/fields/encrypted.py +166 -0
- surreal_orm/fields/relation.py +465 -0
- surreal_orm/migrations/__init__.py +51 -0
- surreal_orm/migrations/executor.py +380 -0
- surreal_orm/migrations/generator.py +272 -0
- surreal_orm/migrations/introspector.py +305 -0
- surreal_orm/migrations/migration.py +188 -0
- surreal_orm/migrations/operations.py +531 -0
- surreal_orm/migrations/state.py +406 -0
- surreal_orm/model_base.py +530 -44
- surreal_orm/query_set.py +609 -33
- surreal_orm/relations.py +645 -0
- surreal_orm/surreal_function.py +95 -0
- surreal_orm/surreal_ql.py +113 -0
- surreal_orm/types.py +86 -0
- surreal_sdk/README.md +79 -0
- surreal_sdk/__init__.py +151 -0
- surreal_sdk/connection/__init__.py +17 -0
- surreal_sdk/connection/base.py +516 -0
- surreal_sdk/connection/http.py +421 -0
- surreal_sdk/connection/pool.py +244 -0
- surreal_sdk/connection/websocket.py +519 -0
- surreal_sdk/exceptions.py +71 -0
- surreal_sdk/functions.py +607 -0
- surreal_sdk/protocol/__init__.py +13 -0
- surreal_sdk/protocol/rpc.py +218 -0
- surreal_sdk/py.typed +0 -0
- surreal_sdk/pyproject.toml +49 -0
- surreal_sdk/streaming/__init__.py +31 -0
- surreal_sdk/streaming/change_feed.py +278 -0
- surreal_sdk/streaming/live_query.py +265 -0
- surreal_sdk/streaming/live_select.py +369 -0
- surreal_sdk/transaction.py +386 -0
- surreal_sdk/types.py +346 -0
- surrealdb_orm-0.5.0.dist-info/METADATA +465 -0
- surrealdb_orm-0.5.0.dist-info/RECORD +52 -0
- {surrealdb_orm-0.1.4.dist-info → surrealdb_orm-0.5.0.dist-info}/WHEEL +1 -1
- surrealdb_orm-0.5.0.dist-info/entry_points.txt +2 -0
- {surrealdb_orm-0.1.4.dist-info → surrealdb_orm-0.5.0.dist-info}/licenses/LICENSE +1 -1
- surrealdb_orm-0.1.4.dist-info/METADATA +0 -184
- surrealdb_orm-0.1.4.dist-info/RECORD +0 -12
|
@@ -0,0 +1,516 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base Connection Interface for SurrealDB SDK.
|
|
3
|
+
|
|
4
|
+
Defines the abstract interface that all connection types must implement.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Self
|
|
9
|
+
|
|
10
|
+
from ..protocol.rpc import RPCRequest, RPCResponse
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from ..functions import FunctionNamespace
|
|
14
|
+
from ..transaction import BaseTransaction
|
|
15
|
+
from ..types import (
|
|
16
|
+
QueryResponse,
|
|
17
|
+
RecordResponse,
|
|
18
|
+
RecordsResponse,
|
|
19
|
+
AuthResponse,
|
|
20
|
+
InfoResponse,
|
|
21
|
+
DeleteResponse,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class BaseSurrealConnection(ABC):
    """
    Abstract base class for SurrealDB connections.

    All connection implementations (HTTP, WebSocket) must inherit from this class.
    Subclasses supply the transport (``connect``, ``close``, ``_send_rpc``) and the
    transaction factory (``transaction``). This base class layers the high-level
    API (authentication, CRUD, queries, function calls) on top of ``rpc``.
    """

    def __init__(
        self,
        url: str,
        namespace: str,
        database: str,
        timeout: float = 30.0,
    ):
        """
        Initialize connection parameters.

        Only stores configuration; no network I/O is performed here. Call
        ``connect()`` (or use the instance as an async context manager) to open
        the connection.

        Args:
            url: SurrealDB server URL (trailing slashes are stripped)
            namespace: Target namespace
            database: Target database
            timeout: Request timeout in seconds
        """
        self.url = url.rstrip("/")
        self.namespace = namespace
        self.database = database
        self.timeout = timeout
        self._connected = False
        self._authenticated = False
        self._token: str | None = None

    @property
    def is_connected(self) -> bool:
        """Check if connection is established."""
        return self._connected

    @property
    def is_authenticated(self) -> bool:
        """Check if authenticated (set by ``signin`` / ``signup``)."""
        return self._authenticated

    @property
    def token(self) -> str | None:
        """Get the current authentication token, or ``None`` if none was issued."""
        return self._token

    # Abstract methods that must be implemented

    @abstractmethod
    async def connect(self) -> None:
        """Establish connection to SurrealDB."""
        ...

    @abstractmethod
    async def close(self) -> None:
        """Close the connection."""
        ...

    @abstractmethod
    async def _send_rpc(self, request: RPCRequest) -> RPCResponse:
        """
        Send an RPC request and receive response.

        Args:
            request: The RPC request to send

        Returns:
            The RPC response
        """
        ...

    # Context manager support

    async def __aenter__(self) -> Self:
        """Async context manager entry: connect and return this connection."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit: close the connection.

        Returns ``None``, so any exception raised inside the ``async with``
        block is not suppressed.
        """
        await self.close()

    # High-level API methods

    async def rpc(self, method: str, params: list[Any] | dict[str, Any] | None = None) -> Any:
        """
        Execute an RPC call.

        Args:
            method: RPC method name
            params: Method parameters (any falsy value is sent as an empty list)

        Returns:
            The result from SurrealDB

        Raises:
            QueryError: If the RPC response reports an error
        """
        from ..exceptions import QueryError

        request = RPCRequest(method=method, params=params or [])
        response = await self._send_rpc(request)

        if response.is_error:
            error = response.error
            raise QueryError(
                message=error.message if error else "Unknown error",
                code=error.code if error else None,
            )

        return response.result

    async def signin(
        self,
        user: str | None = None,
        password: str | None = None,
        namespace: str | None = None,
        database: str | None = None,
        access: str | None = None,
        **credentials: Any,
    ) -> AuthResponse:
        """
        Authenticate with SurrealDB.

        Only arguments with a truthy value are sent. They are mapped to the
        SurrealDB keys ``user``, ``pass``, ``ns``, ``db`` and ``ac``. Extra
        keyword arguments are merged in last and may override those keys.

        Args:
            user: Username (for root/namespace/database auth)
            password: Password (for root/namespace/database auth)
            namespace: Optional namespace scope
            database: Optional database scope
            access: Optional access method (for record access auth)
            **credentials: Additional credentials for record access (email, password, etc.)

        Returns:
            AuthResponse with token and success status

        Raises:
            AuthenticationError: Wrapping any failure during sign-in (including
                ``QueryError`` from ``rpc``). The original exception is chained
                as ``__cause__``.
        """
        from ..exceptions import AuthenticationError

        # Insertion order is preserved: user, pass, ns, db, ac, then extras.
        optional_fields = {
            "user": user,
            "pass": password,
            "ns": namespace,
            "db": database,
            "ac": access,
        }
        params: dict[str, Any] = {key: value for key, value in optional_fields.items() if value}
        params.update(credentials)

        try:
            result = await self.rpc("signin", params)
            response = AuthResponse.from_rpc_result(result)
            # Keep any previously stored token if the server did not return one.
            if response.token:
                self._token = response.token
            self._authenticated = response.success
            return response
        except Exception as e:
            # Explicit chaining keeps the underlying error visible to callers/logs.
            raise AuthenticationError(f"Authentication failed: {e}") from e

    async def signup(
        self,
        namespace: str,
        database: str,
        access: str,
        **credentials: Any,
    ) -> AuthResponse:
        """
        Sign up a new user.

        Args:
            namespace: Namespace
            database: Database
            access: Access method
            **credentials: Additional credentials (email, password, etc.)

        Returns:
            AuthResponse with token and success status

        Raises:
            QueryError: If the RPC response reports an error. Unlike ``signin``,
                errors are not wrapped.
        """
        params = {
            "ns": namespace,
            "db": database,
            "ac": access,
            **credentials,
        }
        result = await self.rpc("signup", params)
        response = AuthResponse.from_rpc_result(result)
        if response.token:
            self._token = response.token
        self._authenticated = response.success
        return response

    async def use(self, namespace: str, database: str) -> None:
        """
        Set the namespace and database to use.

        The local ``namespace``/``database`` attributes are updated only after
        the RPC call succeeds.

        Args:
            namespace: Target namespace
            database: Target database
        """
        await self.rpc("use", [namespace, database])
        self.namespace = namespace
        self.database = database

    async def info(self) -> InfoResponse:
        """Get information about the current user."""
        result = await self.rpc("info")
        return InfoResponse.from_rpc_result(result)

    async def version(self) -> str:
        """Get SurrealDB server version (empty string if the result is falsy)."""
        result = await self.rpc("version")
        return str(result) if result else ""

    async def ping(self) -> bool:
        """Check if connection is alive; returns ``False`` instead of raising on any error."""
        try:
            await self.rpc("ping")
            return True
        except Exception:
            return False

    # Query methods

    async def query(self, sql: str, vars: dict[str, Any] | None = None) -> QueryResponse:
        """
        Execute a SurrealQL query.

        Args:
            sql: SurrealQL query string
            vars: Query variables (bound as ``$name`` in ``sql``)

        Returns:
            QueryResponse containing results for each statement
        """
        result = await self.rpc("query", [sql, vars or {}])
        return QueryResponse.from_rpc_result(result)

    async def select(self, thing: str) -> RecordsResponse:
        """
        Select records from a table or specific record.

        Args:
            thing: Table name or record ID (e.g., "users" or "users:123")

        Returns:
            RecordsResponse containing selected records
        """
        result = await self.rpc("select", [thing])
        return RecordsResponse.from_rpc_result(result)

    async def create(self, thing: str, data: dict[str, Any] | None = None) -> RecordResponse:
        """
        Create a new record.

        Args:
            thing: Table name or record ID
            data: Record data (an empty object is sent when omitted)

        Returns:
            RecordResponse containing the created record
        """
        result = await self.rpc("create", [thing, data or {}])
        return RecordResponse.from_rpc_result(result)

    async def insert(self, table: str, data: list[dict[str, Any]] | dict[str, Any]) -> RecordsResponse:
        """
        Insert one or more records.

        Args:
            table: Table name
            data: Record(s) to insert

        Returns:
            RecordsResponse containing inserted records
        """
        result = await self.rpc("insert", [table, data])
        return RecordsResponse.from_rpc_result(result)

    async def update(self, thing: str, data: dict[str, Any]) -> RecordsResponse:
        """
        Update record(s), replacing all fields.

        Args:
            thing: Table name or record ID
            data: New record data

        Returns:
            RecordsResponse containing updated record(s)
        """
        result = await self.rpc("update", [thing, data])
        return RecordsResponse.from_rpc_result(result)

    async def merge(self, thing: str, data: dict[str, Any]) -> RecordsResponse:
        """
        Merge data into record(s), updating only specified fields.

        Args:
            thing: Table name or record ID
            data: Fields to merge

        Returns:
            RecordsResponse containing updated record(s)
        """
        result = await self.rpc("merge", [thing, data])
        return RecordsResponse.from_rpc_result(result)

    async def patch(self, thing: str, patches: list[dict[str, Any]]) -> RecordsResponse:
        """
        Apply JSON Patch operations to record(s).

        Args:
            thing: Table name or record ID
            patches: List of JSON Patch operations

        Returns:
            RecordsResponse containing updated record(s)
        """
        result = await self.rpc("patch", [thing, patches])
        return RecordsResponse.from_rpc_result(result)

    async def delete(self, thing: str) -> DeleteResponse:
        """
        Delete record(s).

        Args:
            thing: Table name or record ID

        Returns:
            DeleteResponse containing deleted record(s)
        """
        result = await self.rpc("delete", [thing])
        return DeleteResponse.from_rpc_result(result)

    async def relate(
        self,
        from_thing: str,
        relation: str,
        to_thing: str,
        data: dict[str, Any] | None = None,
    ) -> RecordResponse:
        """
        Create a graph relationship between records.

        Args:
            from_thing: Source record ID
            relation: Relation table name
            to_thing: Target record ID
            data: Optional relation data (omitted from the RPC call when empty)

        Returns:
            RecordResponse containing the created relation record
        """
        params: list[Any] = [from_thing, relation, to_thing]
        if data:
            params.append(data)
        result = await self.rpc("relate", params)
        return RecordResponse.from_rpc_result(result)

    # Transaction support

    @abstractmethod
    def transaction(self) -> "BaseTransaction":
        """
        Create a new transaction context.

        Usage:
            async with conn.transaction() as tx:
                await tx.update("players:abc", {"is_ready": True})
                await tx.update("game_tables:xyz", {"ready_count": 1})
                # Auto-commit on success, auto-rollback on exception

        Returns:
            Transaction context manager
        """
        ...

    # Function call API

    @property
    def fn(self) -> "FunctionNamespace":
        """
        Access SurrealDB function call API.

        A new ``FunctionNamespace`` bound to this connection is created on each
        access.

        Usage:
            # Built-in functions
            result = await conn.fn.math.sqrt(16)
            result = await conn.fn.time.now()

            # Custom user-defined functions
            result = await conn.fn.cast_vote(user_id, table_id, "yes")

        Returns:
            Function namespace for building calls
        """
        from ..functions import FunctionNamespace

        return FunctionNamespace(self)

    async def call(
        self,
        function: str,
        params: dict[str, Any] | None = None,
        return_type: type | None = None,
    ) -> Any:
        """
        Call a SurrealDB function with typed return value.

        This method provides a clean interface for calling custom functions
        with automatic type conversion using Pydantic models or dataclasses.

        The values in ``params`` are sent as query variables. They are passed to
        the function positionally, in the dict's insertion order, so the keys
        only name the variables. Keys are interpolated into the query as
        ``$key``, so they must be valid SurrealQL variable names.

        Args:
            function: Function name (e.g., "fn::cast_vote" or just "cast_vote").
                A name without "::" is prefixed with "fn::".
            params: Arguments to pass to the function, in positional order
            return_type: Optional Pydantic model or dataclass to convert result to

        Returns:
            The result of the first statement (or ``None``), optionally
            converted to ``return_type``

        Usage:
            # Without type
            result = await conn.call("fn::cast_vote", {
                "user_id": "users:alice",
                "table_id": "game_tables:xyz",
                "vote": "yes"
            })

            # With typed return
            @dataclass
            class VoteResult:
                success: bool
                new_count: int
                total_votes: int

            result: VoteResult = await conn.call(
                "fn::cast_vote",
                params={"user_id": "users:alice", "table_id": "game_tables:xyz", "vote": "yes"},
                return_type=VoteResult
            )
        """
        # Bare names are user-defined functions; already-namespaced names
        # ("fn::x", "math::sqrt", ...) are used as given.
        if "::" not in function:
            function = f"fn::{function}"

        # Build a parameterized call: values travel as query variables.
        arguments = ", ".join(f"${key}" for key in params) if params else ""
        sql = f"RETURN {function}({arguments});"

        result = await self.query(sql, params or {})

        first = result.first_result
        value = first.result if first else None

        if return_type is not None and value is not None:
            return self._convert_to_type(value, return_type)

        return value

    def _convert_to_type(self, value: Any, target_type: type) -> Any:
        """
        Convert a value to the target type.

        Rules are tried in this order:

        - Pydantic models: dicts are passed as keyword arguments; any other
          value goes through ``model_validate``.
        - Dataclasses: dicts are passed as keyword arguments.
        - Any other class: ``target_type(value)``. ``TypeError`` and
          ``ValueError`` are ignored.

        If no rule applies, the original value is returned unchanged.
        """
        import dataclasses

        # pydantic is optional: only the import is guarded, not the conversion.
        pydantic_base: type | None
        try:
            from pydantic import BaseModel

            pydantic_base = BaseModel
        except ImportError:
            pydantic_base = None

        is_class = isinstance(target_type, type)

        if pydantic_base is not None and is_class and issubclass(target_type, pydantic_base):
            if isinstance(value, dict):
                return target_type(**value)
            return target_type.model_validate(value)

        if is_class and dataclasses.is_dataclass(target_type) and isinstance(value, dict):
            return target_type(**value)

        # For simple types (int, str, ...), try direct conversion.
        if is_class:
            try:
                return target_type(value)
            except (TypeError, ValueError):
                pass

        return value
|