digitalkin 0.3.0rc1__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/core/common/__init__.py +9 -0
- digitalkin/core/common/factories.py +156 -0
- digitalkin/core/job_manager/base_job_manager.py +128 -28
- digitalkin/core/job_manager/single_job_manager.py +80 -25
- digitalkin/core/job_manager/taskiq_broker.py +114 -19
- digitalkin/core/job_manager/taskiq_job_manager.py +291 -39
- digitalkin/core/task_manager/base_task_manager.py +539 -0
- digitalkin/core/task_manager/local_task_manager.py +108 -0
- digitalkin/core/task_manager/remote_task_manager.py +87 -0
- digitalkin/core/task_manager/surrealdb_repository.py +43 -4
- digitalkin/core/task_manager/task_executor.py +249 -0
- digitalkin/core/task_manager/task_session.py +107 -19
- digitalkin/grpc_servers/module_server.py +2 -2
- digitalkin/grpc_servers/module_servicer.py +21 -12
- digitalkin/grpc_servers/registry_server.py +1 -1
- digitalkin/grpc_servers/registry_servicer.py +4 -4
- digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
- digitalkin/models/core/task_monitor.py +17 -0
- digitalkin/models/grpc_servers/models.py +4 -4
- digitalkin/models/module/module_context.py +5 -0
- digitalkin/models/module/module_types.py +304 -16
- digitalkin/modules/_base_module.py +66 -28
- digitalkin/services/cost/grpc_cost.py +8 -41
- digitalkin/services/filesystem/grpc_filesystem.py +9 -38
- digitalkin/services/services_config.py +11 -0
- digitalkin/services/services_models.py +3 -1
- digitalkin/services/setup/default_setup.py +5 -6
- digitalkin/services/setup/grpc_setup.py +51 -14
- digitalkin/services/storage/grpc_storage.py +2 -2
- digitalkin/services/user_profile/__init__.py +12 -0
- digitalkin/services/user_profile/default_user_profile.py +55 -0
- digitalkin/services/user_profile/grpc_user_profile.py +69 -0
- digitalkin/services/user_profile/user_profile_strategy.py +40 -0
- digitalkin/utils/__init__.py +28 -0
- digitalkin/utils/dynamic_schema.py +483 -0
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/METADATA +9 -29
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/RECORD +42 -30
- modules/dynamic_setup_module.py +362 -0
- digitalkin/core/task_manager/task_manager.py +0 -439
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/WHEEL +0 -0
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/top_level.txt +0 -0
|
@@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator
|
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
7
|
import grpc
|
|
8
|
-
from digitalkin_proto.
|
|
8
|
+
from digitalkin_proto.agentic_mesh_protocol.module.v1 import (
|
|
9
9
|
information_pb2,
|
|
10
10
|
lifecycle_pb2,
|
|
11
11
|
module_service_pb2_grpc,
|
|
@@ -112,7 +112,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
112
112
|
# TODO: Secret should be used here as well
|
|
113
113
|
setup_version = request.setup_version
|
|
114
114
|
config_setup_data = self.module_class.create_config_setup_model(json_format.MessageToDict(request.content))
|
|
115
|
-
setup_version_data = self.module_class.create_setup_model(
|
|
115
|
+
setup_version_data = await self.module_class.create_setup_model(
|
|
116
116
|
json_format.MessageToDict(request.setup_version.content),
|
|
117
117
|
config_fields=True,
|
|
118
118
|
)
|
|
@@ -172,7 +172,8 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
172
172
|
)
|
|
173
173
|
# Process the module input
|
|
174
174
|
# TODO: Check failure of input data format
|
|
175
|
-
input_data = self.module_class.create_input_model(
|
|
175
|
+
input_data = self.module_class.create_input_model(json_format.MessageToDict(request.input))
|
|
176
|
+
|
|
176
177
|
setup_data_class = self.setup.get_setup(
|
|
177
178
|
setup_dict={
|
|
178
179
|
"setup_id": request.setup_id,
|
|
@@ -184,7 +185,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
184
185
|
msg = "No setup data returned."
|
|
185
186
|
raise ServicerError(msg)
|
|
186
187
|
|
|
187
|
-
setup_data = self.module_class.create_setup_model(setup_data_class.current_setup_version.content)
|
|
188
|
+
setup_data = await self.module_class.create_setup_model(setup_data_class.current_setup_version.content)
|
|
188
189
|
|
|
189
190
|
# create a task to run the module in background
|
|
190
191
|
job_id = await self.job_manager.create_module_instance_job(
|
|
@@ -219,13 +220,17 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
219
220
|
break
|
|
220
221
|
|
|
221
222
|
if message.get("code", None) is not None and message.get("code") == "__END_OF_STREAM__":
|
|
222
|
-
|
|
223
|
+
logger.info(
|
|
224
|
+
"End of stream via __END_OF_STREAM__",
|
|
225
|
+
extra={"job_id": job_id, "mission_id": request.mission_id},
|
|
226
|
+
)
|
|
223
227
|
break
|
|
224
228
|
|
|
229
|
+
logger.info("Yielding message from job %s: %s", job_id, message)
|
|
225
230
|
proto = json_format.ParseDict(message, struct_pb2.Struct(), ignore_unknown_fields=True)
|
|
226
231
|
yield lifecycle_pb2.StartModuleResponse(success=True, output=proto, job_id=job_id)
|
|
227
232
|
finally:
|
|
228
|
-
await self.job_manager.
|
|
233
|
+
await self.job_manager.wait_for_completion(job_id)
|
|
229
234
|
await self.job_manager.clean_session(job_id, mission_id=request.mission_id)
|
|
230
235
|
|
|
231
236
|
logger.info("Job %s finished", job_id)
|
|
@@ -345,7 +350,9 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
345
350
|
# Get input schema if available
|
|
346
351
|
try:
|
|
347
352
|
# Convert schema to proto format
|
|
348
|
-
input_schema_proto = self.module_class.get_input_format(
|
|
353
|
+
input_schema_proto = await self.module_class.get_input_format(
|
|
354
|
+
llm_format=request.llm_format,
|
|
355
|
+
)
|
|
349
356
|
input_format_struct = json_format.Parse(
|
|
350
357
|
text=input_schema_proto,
|
|
351
358
|
message=struct_pb2.Struct(), # pylint: disable=no-member
|
|
@@ -381,7 +388,9 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
381
388
|
# Get output schema if available
|
|
382
389
|
try:
|
|
383
390
|
# Convert schema to proto format
|
|
384
|
-
output_schema_proto = self.module_class.get_output_format(
|
|
391
|
+
output_schema_proto = await self.module_class.get_output_format(
|
|
392
|
+
llm_format=request.llm_format,
|
|
393
|
+
)
|
|
385
394
|
output_format_struct = json_format.Parse(
|
|
386
395
|
text=output_schema_proto,
|
|
387
396
|
message=struct_pb2.Struct(), # pylint: disable=no-member
|
|
@@ -417,7 +426,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
417
426
|
# Get setup schema if available
|
|
418
427
|
try:
|
|
419
428
|
# Convert schema to proto format
|
|
420
|
-
setup_schema_proto = self.module_class.get_setup_format(llm_format=request.llm_format)
|
|
429
|
+
setup_schema_proto = await self.module_class.get_setup_format(llm_format=request.llm_format)
|
|
421
430
|
setup_format_struct = json_format.Parse(
|
|
422
431
|
text=setup_schema_proto,
|
|
423
432
|
message=struct_pb2.Struct(), # pylint: disable=no-member
|
|
@@ -434,7 +443,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
434
443
|
setup_schema=setup_format_struct,
|
|
435
444
|
)
|
|
436
445
|
|
|
437
|
-
def GetModuleSecret( # noqa: N802
|
|
446
|
+
async def GetModuleSecret( # noqa: N802
|
|
438
447
|
self,
|
|
439
448
|
request: information_pb2.GetModuleSecretRequest,
|
|
440
449
|
context: grpc.ServicerContext,
|
|
@@ -453,7 +462,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
453
462
|
# Get secret schema if available
|
|
454
463
|
try:
|
|
455
464
|
# Convert schema to proto format
|
|
456
|
-
secret_schema_proto = self.module_class.get_secret_format(llm_format=request.llm_format)
|
|
465
|
+
secret_schema_proto = await self.module_class.get_secret_format(llm_format=request.llm_format)
|
|
457
466
|
secret_format_struct = json_format.Parse(
|
|
458
467
|
text=secret_schema_proto,
|
|
459
468
|
message=struct_pb2.Struct(), # pylint: disable=no-member
|
|
@@ -489,7 +498,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
489
498
|
# Get setup schema if available
|
|
490
499
|
try:
|
|
491
500
|
# Convert schema to proto format
|
|
492
|
-
config_setup_schema_proto = self.module_class.get_config_setup_format(llm_format=request.llm_format)
|
|
501
|
+
config_setup_schema_proto = await self.module_class.get_config_setup_format(llm_format=request.llm_format)
|
|
493
502
|
config_setup_format_struct = json_format.Parse(
|
|
494
503
|
text=config_setup_schema_proto,
|
|
495
504
|
message=struct_pb2.Struct(), # pylint: disable=no-member
|
|
@@ -9,7 +9,7 @@ from collections.abc import Iterator
|
|
|
9
9
|
from enum import Enum
|
|
10
10
|
|
|
11
11
|
import grpc
|
|
12
|
-
from digitalkin_proto.
|
|
12
|
+
from digitalkin_proto.agentic_mesh_protocol.module_registry.v1 import (
|
|
13
13
|
discover_pb2,
|
|
14
14
|
metadata_pb2,
|
|
15
15
|
module_registry_service_pb2_grpc,
|
|
@@ -344,7 +344,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
|
|
|
344
344
|
return status_pb2.ModuleStatusResponse()
|
|
345
345
|
|
|
346
346
|
module = self.registered_modules[request.module_id]
|
|
347
|
-
return status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.
|
|
347
|
+
return status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.value)
|
|
348
348
|
|
|
349
349
|
def ListModuleStatus( # noqa: N802
|
|
350
350
|
self,
|
|
@@ -379,7 +379,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
|
|
|
379
379
|
list_size = len(self.registered_modules)
|
|
380
380
|
|
|
381
381
|
modules_statuses = [
|
|
382
|
-
status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.
|
|
382
|
+
status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.value)
|
|
383
383
|
for module in list(self.registered_modules.values())[request.offset : request.offset + list_size]
|
|
384
384
|
]
|
|
385
385
|
|
|
@@ -409,7 +409,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
|
|
|
409
409
|
for module in self.registered_modules.values():
|
|
410
410
|
yield status_pb2.ModuleStatusResponse(
|
|
411
411
|
module_id=module.module_id,
|
|
412
|
-
status=module.status.
|
|
412
|
+
status=module.status.value,
|
|
413
413
|
)
|
|
414
414
|
|
|
415
415
|
def UpdateModuleStatus( # noqa: N802
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Shared error handling utilities for gRPC services."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Generator
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from digitalkin.grpc_servers.utils.exceptions import ServerError
|
|
8
|
+
from digitalkin.logger import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GrpcErrorHandlerMixin:
|
|
12
|
+
"""Mixin class providing common gRPC error handling functionality."""
|
|
13
|
+
|
|
14
|
+
@contextmanager
|
|
15
|
+
def handle_grpc_errors( # noqa: PLR6301
|
|
16
|
+
self,
|
|
17
|
+
operation: str,
|
|
18
|
+
service_error_class: type[Exception] | None = None,
|
|
19
|
+
) -> Generator[Any, Any, Any]:
|
|
20
|
+
"""Handle gRPC errors for the given operation.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
operation: Name of the operation being performed.
|
|
24
|
+
service_error_class: Optional specific service exception class to raise.
|
|
25
|
+
If not provided, uses the generic ServerError.
|
|
26
|
+
|
|
27
|
+
Yields:
|
|
28
|
+
Context for the operation.
|
|
29
|
+
|
|
30
|
+
Raises:
|
|
31
|
+
ServerError: For gRPC-related errors.
|
|
32
|
+
service_error_class: For service-specific errors if provided.
|
|
33
|
+
"""
|
|
34
|
+
if service_error_class is None:
|
|
35
|
+
service_error_class = ServerError
|
|
36
|
+
|
|
37
|
+
try:
|
|
38
|
+
yield
|
|
39
|
+
except service_error_class as e:
|
|
40
|
+
# Re-raise service-specific errors as-is
|
|
41
|
+
msg = f"{service_error_class.__name__} in {operation}: {e}"
|
|
42
|
+
logger.exception(msg)
|
|
43
|
+
raise service_error_class(msg) from e
|
|
44
|
+
except ServerError as e:
|
|
45
|
+
# Handle gRPC server errors
|
|
46
|
+
msg = f"gRPC {operation} failed: {e}"
|
|
47
|
+
logger.exception(msg)
|
|
48
|
+
raise ServerError(msg) from e
|
|
49
|
+
except Exception as e:
|
|
50
|
+
# Handle unexpected errors
|
|
51
|
+
msg = f"Unexpected error in {operation}: {e}"
|
|
52
|
+
logger.exception(msg)
|
|
53
|
+
raise service_error_class(msg) from e
|
|
@@ -17,6 +17,23 @@ class TaskStatus(Enum):
|
|
|
17
17
|
FAILED = "failed"
|
|
18
18
|
|
|
19
19
|
|
|
20
|
+
class CancellationReason(Enum):
|
|
21
|
+
"""Reason for task cancellation - helps distinguish cleanup vs real cancellation."""
|
|
22
|
+
|
|
23
|
+
# Cleanup cancellations (not errors)
|
|
24
|
+
SUCCESS_CLEANUP = "success_cleanup" # Main task completed, cleaning up helper tasks
|
|
25
|
+
FAILURE_CLEANUP = "failure_cleanup" # Main task failed, cleaning up helper tasks
|
|
26
|
+
|
|
27
|
+
# Real cancellations
|
|
28
|
+
SIGNAL = "signal" # External signal requested cancellation
|
|
29
|
+
HEARTBEAT_FAILURE = "heartbeat_failure" # Heartbeat stopped working
|
|
30
|
+
TIMEOUT = "timeout" # Task timed out
|
|
31
|
+
SHUTDOWN = "shutdown" # Manager is shutting down
|
|
32
|
+
|
|
33
|
+
# Unknown/unset
|
|
34
|
+
UNKNOWN = "unknown" # Reason not determined
|
|
35
|
+
|
|
36
|
+
|
|
20
37
|
class SignalType(Enum):
|
|
21
38
|
"""Signal type enumeration."""
|
|
22
39
|
|
|
@@ -175,8 +175,8 @@ class ClientConfig(ChannelConfig):
|
|
|
175
175
|
credentials: ClientCredentials | None = Field(None, description="Client credentials for secure mode")
|
|
176
176
|
channel_options: list[tuple[str, Any]] = Field(
|
|
177
177
|
default_factory=lambda: [
|
|
178
|
-
("grpc.max_receive_message_length",
|
|
179
|
-
("grpc.max_send_message_length",
|
|
178
|
+
("grpc.max_receive_message_length", 100 * 1024 * 1024), # 100MB
|
|
179
|
+
("grpc.max_send_message_length", 100 * 1024 * 1024), # 100MB
|
|
180
180
|
],
|
|
181
181
|
description="Additional channel options",
|
|
182
182
|
)
|
|
@@ -223,8 +223,8 @@ class ServerConfig(ChannelConfig):
|
|
|
223
223
|
credentials: ServerCredentials | None = Field(None, description="Server credentials for secure mode")
|
|
224
224
|
server_options: list[tuple[str, Any]] = Field(
|
|
225
225
|
default_factory=lambda: [
|
|
226
|
-
("grpc.max_receive_message_length",
|
|
227
|
-
("grpc.max_send_message_length",
|
|
226
|
+
("grpc.max_receive_message_length", 100 * 1024 * 1024), # 100MB
|
|
227
|
+
("grpc.max_send_message_length", 100 * 1024 * 1024), # 100MB
|
|
228
228
|
],
|
|
229
229
|
description="Additional server options",
|
|
230
230
|
)
|
|
@@ -10,6 +10,7 @@ from digitalkin.services.identity.identity_strategy import IdentityStrategy
|
|
|
10
10
|
from digitalkin.services.registry.registry_strategy import RegistryStrategy
|
|
11
11
|
from digitalkin.services.snapshot.snapshot_strategy import SnapshotStrategy
|
|
12
12
|
from digitalkin.services.storage.storage_strategy import StorageStrategy
|
|
13
|
+
from digitalkin.services.user_profile.user_profile_strategy import UserProfileStrategy
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
class Session(SimpleNamespace):
|
|
@@ -89,6 +90,7 @@ class ModuleContext:
|
|
|
89
90
|
registry: RegistryStrategy
|
|
90
91
|
snapshot: SnapshotStrategy
|
|
91
92
|
storage: StorageStrategy
|
|
93
|
+
user_profile: UserProfileStrategy
|
|
92
94
|
|
|
93
95
|
session: Session
|
|
94
96
|
callbacks: SimpleNamespace
|
|
@@ -105,6 +107,7 @@ class ModuleContext:
|
|
|
105
107
|
registry: RegistryStrategy,
|
|
106
108
|
snapshot: SnapshotStrategy,
|
|
107
109
|
storage: StorageStrategy,
|
|
110
|
+
user_profile: UserProfileStrategy,
|
|
108
111
|
session: dict[str, Any],
|
|
109
112
|
metadata: dict[str, Any] = {},
|
|
110
113
|
helpers: dict[str, Any] = {},
|
|
@@ -120,6 +123,7 @@ class ModuleContext:
|
|
|
120
123
|
registry: RegistryStrategy.
|
|
121
124
|
snapshot: SnapshotStrategy.
|
|
122
125
|
storage: StorageStrategy.
|
|
126
|
+
user_profile: UserProfileStrategy.
|
|
123
127
|
metadata: dict defining differents Module metadata.
|
|
124
128
|
helpers: dict different user defined helpers.
|
|
125
129
|
session: dict referring the session IDs or informations.
|
|
@@ -133,6 +137,7 @@ class ModuleContext:
|
|
|
133
137
|
self.registry = registry
|
|
134
138
|
self.snapshot = snapshot
|
|
135
139
|
self.storage = storage
|
|
140
|
+
self.user_profile = user_profile
|
|
136
141
|
|
|
137
142
|
self.metadata = SimpleNamespace(**metadata)
|
|
138
143
|
self.session = Session(**session)
|
|
@@ -1,11 +1,25 @@
|
|
|
1
1
|
"""Types for module models."""
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import types
|
|
7
|
+
import typing
|
|
3
8
|
from datetime import datetime, timezone
|
|
4
|
-
from typing import Any, ClassVar, Generic, TypeVar, cast
|
|
9
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
|
5
10
|
|
|
6
11
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
7
12
|
|
|
8
13
|
from digitalkin.logger import logger
|
|
14
|
+
from digitalkin.utils.dynamic_schema import (
|
|
15
|
+
DynamicField,
|
|
16
|
+
get_fetchers,
|
|
17
|
+
has_dynamic,
|
|
18
|
+
resolve_safe,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from pydantic.fields import FieldInfo
|
|
9
23
|
|
|
10
24
|
|
|
11
25
|
class DataTrigger(BaseModel):
|
|
@@ -24,7 +38,11 @@ class DataTrigger(BaseModel):
|
|
|
24
38
|
"""
|
|
25
39
|
|
|
26
40
|
protocol: ClassVar[str]
|
|
27
|
-
created_at: str =
|
|
41
|
+
created_at: str = Field(
|
|
42
|
+
default_factory=lambda: datetime.now(tz=timezone.utc).isoformat(),
|
|
43
|
+
title="Created At",
|
|
44
|
+
description="Timestamp when the payload was created.",
|
|
45
|
+
)
|
|
28
46
|
|
|
29
47
|
|
|
30
48
|
DataTriggerT = TypeVar("DataTriggerT", bound=DataTrigger)
|
|
@@ -57,27 +75,50 @@ SetupModelT = TypeVar("SetupModelT", bound="SetupModel")
|
|
|
57
75
|
class SetupModel(BaseModel):
|
|
58
76
|
"""Base definition of setup model showing mandatory root fields.
|
|
59
77
|
|
|
60
|
-
Optionally, the setup model can define a config option in json_schema_extra
|
|
78
|
+
Optionally, the setup model can define a config option in json_schema_extra
|
|
79
|
+
to be used to initialize the Kin. Supports dynamic schema providers for
|
|
80
|
+
runtime value generation.
|
|
61
81
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
82
|
+
Attributes:
|
|
83
|
+
model_fields: Inherited from Pydantic BaseModel, contains field definitions.
|
|
84
|
+
|
|
85
|
+
See Also:
|
|
86
|
+
- Documentation: docs/api/dynamic_schema.md
|
|
87
|
+
- Tests: tests/modules/test_setup_model.py
|
|
66
88
|
"""
|
|
67
89
|
|
|
68
90
|
@classmethod
|
|
69
|
-
def get_clean_model(
|
|
70
|
-
|
|
91
|
+
async def get_clean_model(
|
|
92
|
+
cls,
|
|
93
|
+
*,
|
|
94
|
+
config_fields: bool,
|
|
95
|
+
hidden_fields: bool,
|
|
96
|
+
force: bool = False,
|
|
97
|
+
) -> type[SetupModelT]:
|
|
98
|
+
"""Dynamically builds and returns a new BaseModel subclass with filtered fields.
|
|
71
99
|
|
|
72
|
-
|
|
100
|
+
This method filters fields based on their `json_schema_extra` metadata:
|
|
101
|
+
- Fields with `{"config": True}` are included only when `config_fields=True`
|
|
102
|
+
- Fields with `{"hidden": True}` are included only when `hidden_fields=True`
|
|
73
103
|
|
|
74
|
-
|
|
75
|
-
|
|
104
|
+
When `force=True`, fields with dynamic schema providers will have their
|
|
105
|
+
providers called to fetch fresh values for schema metadata like enums.
|
|
106
|
+
This includes recursively processing nested BaseModel fields.
|
|
76
107
|
|
|
77
|
-
|
|
78
|
-
|
|
108
|
+
Args:
|
|
109
|
+
config_fields: If True, include fields marked with `{"config": True}`.
|
|
110
|
+
These are typically initial configuration fields.
|
|
111
|
+
hidden_fields: If True, include fields marked with `{"hidden": True}`.
|
|
112
|
+
These are typically runtime-only fields not shown in initial config.
|
|
113
|
+
force: If True, refresh dynamic schema fields by calling their providers.
|
|
114
|
+
Use this when you need up-to-date values from external sources like
|
|
115
|
+
databases or APIs. Default is False for performance.
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
A new BaseModel subclass with filtered fields.
|
|
79
119
|
"""
|
|
80
120
|
clean_fields: dict[str, Any] = {}
|
|
121
|
+
|
|
81
122
|
for name, field_info in cls.model_fields.items():
|
|
82
123
|
extra = getattr(field_info, "json_schema_extra", {}) or {}
|
|
83
124
|
is_config = bool(extra.get("config", False))
|
|
@@ -93,7 +134,27 @@ class SetupModel(BaseModel):
|
|
|
93
134
|
logger.debug("Skipping '%s' (hidden-only)", name)
|
|
94
135
|
continue
|
|
95
136
|
|
|
96
|
-
|
|
137
|
+
# Refresh dynamic schema fields when force=True
|
|
138
|
+
current_field_info = field_info
|
|
139
|
+
current_annotation = field_info.annotation
|
|
140
|
+
|
|
141
|
+
if force:
|
|
142
|
+
# Check if this field has DynamicField metadata
|
|
143
|
+
if has_dynamic(field_info):
|
|
144
|
+
current_field_info = await cls._refresh_field_schema(name, field_info)
|
|
145
|
+
|
|
146
|
+
# Check if the annotation is a nested BaseModel that might have dynamic fields
|
|
147
|
+
nested_model = cls._get_base_model_type(current_annotation)
|
|
148
|
+
if nested_model is not None:
|
|
149
|
+
refreshed_nested = await cls._refresh_nested_model(nested_model)
|
|
150
|
+
if refreshed_nested is not nested_model:
|
|
151
|
+
# Update annotation to use refreshed nested model
|
|
152
|
+
current_annotation = refreshed_nested
|
|
153
|
+
# Create new field_info with updated annotation (deep copy for safety)
|
|
154
|
+
current_field_info = copy.deepcopy(current_field_info)
|
|
155
|
+
setattr(current_field_info, "annotation", current_annotation)
|
|
156
|
+
|
|
157
|
+
clean_fields[name] = (current_annotation, current_field_info)
|
|
97
158
|
|
|
98
159
|
# Dynamically create a model e.g. "SetupModel"
|
|
99
160
|
m = create_model(
|
|
@@ -102,4 +163,231 @@ class SetupModel(BaseModel):
|
|
|
102
163
|
__config__=ConfigDict(arbitrary_types_allowed=True),
|
|
103
164
|
**clean_fields,
|
|
104
165
|
)
|
|
105
|
-
return cast("type[SetupModelT]", m)
|
|
166
|
+
return cast("type[SetupModelT]", m)
|
|
167
|
+
|
|
168
|
+
@classmethod
|
|
169
|
+
def _get_base_model_type(cls, annotation: type | None) -> type[BaseModel] | None:
|
|
170
|
+
"""Extract BaseModel type from an annotation.
|
|
171
|
+
|
|
172
|
+
Handles direct types, Optional, Union, list, dict, set, tuple, and other generics.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
annotation: The type annotation to inspect.
|
|
176
|
+
|
|
177
|
+
Returns:
|
|
178
|
+
The BaseModel subclass if found, None otherwise.
|
|
179
|
+
"""
|
|
180
|
+
if annotation is None:
|
|
181
|
+
return None
|
|
182
|
+
|
|
183
|
+
# Direct BaseModel subclass check
|
|
184
|
+
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
|
185
|
+
return annotation
|
|
186
|
+
|
|
187
|
+
origin = get_origin(annotation)
|
|
188
|
+
if origin is None:
|
|
189
|
+
return None
|
|
190
|
+
|
|
191
|
+
args = get_args(annotation)
|
|
192
|
+
return cls._extract_base_model_from_args(origin, args)
|
|
193
|
+
|
|
194
|
+
@classmethod
|
|
195
|
+
def _extract_base_model_from_args(
|
|
196
|
+
cls,
|
|
197
|
+
origin: type,
|
|
198
|
+
args: tuple[type, ...],
|
|
199
|
+
) -> type[BaseModel] | None:
|
|
200
|
+
"""Extract BaseModel from generic type arguments.
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
origin: The generic origin type (list, dict, Union, etc.).
|
|
204
|
+
args: The type arguments.
|
|
205
|
+
|
|
206
|
+
Returns:
|
|
207
|
+
The BaseModel subclass if found, None otherwise.
|
|
208
|
+
"""
|
|
209
|
+
# Union/Optional: check each arg (supports both typing.Union and types.UnionType)
|
|
210
|
+
# Python 3.10+ uses types.UnionType for X | Y syntax
|
|
211
|
+
if origin is typing.Union or origin is types.UnionType:
|
|
212
|
+
return cls._find_base_model_in_args(args)
|
|
213
|
+
|
|
214
|
+
# list, set, frozenset: check first arg
|
|
215
|
+
if origin in {list, set, frozenset} and args:
|
|
216
|
+
return cls._check_base_model(args[0])
|
|
217
|
+
|
|
218
|
+
# dict: check value type (second arg)
|
|
219
|
+
dict_value_index = 1
|
|
220
|
+
if origin is dict and len(args) > dict_value_index:
|
|
221
|
+
return cls._check_base_model(args[dict_value_index])
|
|
222
|
+
|
|
223
|
+
# tuple: check first non-ellipsis arg
|
|
224
|
+
if origin is tuple:
|
|
225
|
+
return cls._find_base_model_in_args(args, skip_ellipsis=True)
|
|
226
|
+
|
|
227
|
+
return None
|
|
228
|
+
|
|
229
|
+
@classmethod
|
|
230
|
+
def _check_base_model(cls, arg: type) -> type[BaseModel] | None:
|
|
231
|
+
"""Check if arg is a BaseModel subclass.
|
|
232
|
+
|
|
233
|
+
Returns:
|
|
234
|
+
The BaseModel subclass if arg is one, None otherwise.
|
|
235
|
+
"""
|
|
236
|
+
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
|
237
|
+
return arg
|
|
238
|
+
return None
|
|
239
|
+
|
|
240
|
+
@classmethod
|
|
241
|
+
def _find_base_model_in_args(
|
|
242
|
+
cls,
|
|
243
|
+
args: tuple[type, ...],
|
|
244
|
+
*,
|
|
245
|
+
skip_ellipsis: bool = False,
|
|
246
|
+
) -> type[BaseModel] | None:
|
|
247
|
+
"""Find first BaseModel in args.
|
|
248
|
+
|
|
249
|
+
Returns:
|
|
250
|
+
The first BaseModel subclass found, None otherwise.
|
|
251
|
+
"""
|
|
252
|
+
for arg in args:
|
|
253
|
+
if arg is type(None):
|
|
254
|
+
continue
|
|
255
|
+
if skip_ellipsis and arg is ...:
|
|
256
|
+
continue
|
|
257
|
+
result = cls._check_base_model(arg)
|
|
258
|
+
if result is not None:
|
|
259
|
+
return result
|
|
260
|
+
return None
|
|
261
|
+
|
|
262
|
+
@classmethod
|
|
263
|
+
async def _refresh_nested_model(cls, model_cls: type[BaseModel]) -> type[BaseModel]:
|
|
264
|
+
"""Refresh dynamic fields in a nested BaseModel.
|
|
265
|
+
|
|
266
|
+
Creates a new model class with all DynamicField metadata resolved.
|
|
267
|
+
|
|
268
|
+
Args:
|
|
269
|
+
model_cls: The nested model class to refresh.
|
|
270
|
+
|
|
271
|
+
Returns:
|
|
272
|
+
A new model class with refreshed fields, or the original if no changes.
|
|
273
|
+
"""
|
|
274
|
+
has_changes = False
|
|
275
|
+
clean_fields: dict[str, Any] = {}
|
|
276
|
+
|
|
277
|
+
for name, field_info in model_cls.model_fields.items():
|
|
278
|
+
current_field_info = field_info
|
|
279
|
+
current_annotation = field_info.annotation
|
|
280
|
+
|
|
281
|
+
# Check if field has DynamicField metadata
|
|
282
|
+
if has_dynamic(field_info):
|
|
283
|
+
current_field_info = await cls._refresh_field_schema(name, field_info)
|
|
284
|
+
has_changes = True
|
|
285
|
+
|
|
286
|
+
# Recursively check nested models
|
|
287
|
+
nested_model = cls._get_base_model_type(current_annotation)
|
|
288
|
+
if nested_model is not None:
|
|
289
|
+
refreshed_nested = await cls._refresh_nested_model(nested_model)
|
|
290
|
+
if refreshed_nested is not nested_model:
|
|
291
|
+
current_annotation = refreshed_nested
|
|
292
|
+
current_field_info = copy.deepcopy(current_field_info)
|
|
293
|
+
setattr(current_field_info, "annotation", current_annotation)
|
|
294
|
+
has_changes = True
|
|
295
|
+
|
|
296
|
+
clean_fields[name] = (current_annotation, current_field_info)
|
|
297
|
+
|
|
298
|
+
if not has_changes:
|
|
299
|
+
return model_cls
|
|
300
|
+
|
|
301
|
+
# Create new model with refreshed fields
|
|
302
|
+
logger.debug("Creating refreshed nested model for '%s'", model_cls.__name__)
|
|
303
|
+
return create_model(
|
|
304
|
+
model_cls.__name__,
|
|
305
|
+
__base__=BaseModel,
|
|
306
|
+
__config__=ConfigDict(arbitrary_types_allowed=True),
|
|
307
|
+
**clean_fields,
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
@classmethod
|
|
311
|
+
async def _refresh_field_schema(cls, field_name: str, field_info: FieldInfo) -> FieldInfo:
|
|
312
|
+
"""Refresh a field's json_schema_extra with fresh values from dynamic providers.
|
|
313
|
+
|
|
314
|
+
This method calls all dynamic providers registered for a field (via Annotated
|
|
315
|
+
metadata) and creates a new FieldInfo with the resolved values. The original
|
|
316
|
+
field_info is not modified.
|
|
317
|
+
|
|
318
|
+
Uses `resolve_safe()` for structured error handling, allowing partial success
|
|
319
|
+
when some fetchers fail. Successfully resolved values are still applied.
|
|
320
|
+
|
|
321
|
+
Args:
|
|
322
|
+
field_name: The name of the field being refreshed (used for logging).
|
|
323
|
+
field_info: The original FieldInfo object containing the dynamic providers.
|
|
324
|
+
|
|
325
|
+
Returns:
|
|
326
|
+
A new FieldInfo object with the same attributes as the original, but with
|
|
327
|
+
`json_schema_extra` containing resolved values and Dynamic metadata removed.
|
|
328
|
+
|
|
329
|
+
Note:
|
|
330
|
+
If all fetchers fail, the original field_info is returned unchanged.
|
|
331
|
+
If some fetchers fail, successfully resolved values are still applied.
|
|
332
|
+
"""
|
|
333
|
+
fetchers = get_fetchers(field_info)
|
|
334
|
+
|
|
335
|
+
if not fetchers:
|
|
336
|
+
return field_info
|
|
337
|
+
|
|
338
|
+
fetcher_keys = list(fetchers.keys())
|
|
339
|
+
logger.debug(
|
|
340
|
+
"Refreshing dynamic schema for field '%s' with fetchers: %s",
|
|
341
|
+
field_name,
|
|
342
|
+
fetcher_keys,
|
|
343
|
+
extra={"field_name": field_name, "fetcher_keys": fetcher_keys},
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
# Resolve all fetchers with structured error handling
|
|
347
|
+
result = await resolve_safe(fetchers)
|
|
348
|
+
|
|
349
|
+
# Log any errors that occurred with full details
|
|
350
|
+
if result.errors:
|
|
351
|
+
for key, error in result.errors.items():
|
|
352
|
+
logger.warning(
|
|
353
|
+
"Failed to resolve '%s' for field '%s': %s: %s",
|
|
354
|
+
key,
|
|
355
|
+
field_name,
|
|
356
|
+
type(error).__name__,
|
|
357
|
+
str(error) or "(no message)",
|
|
358
|
+
extra={
|
|
359
|
+
"field_name": field_name,
|
|
360
|
+
"fetcher_key": key,
|
|
361
|
+
"error_type": type(error).__name__,
|
|
362
|
+
"error_message": str(error),
|
|
363
|
+
"error_repr": repr(error),
|
|
364
|
+
},
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
# If no values were resolved, return original field_info
|
|
368
|
+
if not result.values:
|
|
369
|
+
logger.warning(
|
|
370
|
+
"All fetchers failed for field '%s', keeping original",
|
|
371
|
+
field_name,
|
|
372
|
+
)
|
|
373
|
+
return field_info
|
|
374
|
+
|
|
375
|
+
# Build new json_schema_extra with resolved values merged
|
|
376
|
+
extra = getattr(field_info, "json_schema_extra", {}) or {}
|
|
377
|
+
new_extra = {**extra, **result.values}
|
|
378
|
+
|
|
379
|
+
# Create a deep copy of the FieldInfo to avoid shared mutable state
|
|
380
|
+
new_field_info = copy.deepcopy(field_info)
|
|
381
|
+
setattr(new_field_info, "json_schema_extra", new_extra)
|
|
382
|
+
|
|
383
|
+
# Remove Dynamic from metadata (it's been resolved)
|
|
384
|
+
new_metadata = [m for m in new_field_info.metadata if not isinstance(m, DynamicField)]
|
|
385
|
+
setattr(new_field_info, "metadata", new_metadata)
|
|
386
|
+
|
|
387
|
+
logger.debug(
|
|
388
|
+
"Refreshed '%s' with dynamic values: %s",
|
|
389
|
+
field_name,
|
|
390
|
+
list(result.values.keys()),
|
|
391
|
+
)
|
|
392
|
+
|
|
393
|
+
return new_field_info
|