digitalkin 0.3.2.dev2__py3-none-any.whl → 0.3.2.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,573 @@
1
+ """Setup model types with dynamic schema resolution."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import types
7
+ import typing
8
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
9
+
10
+ from pydantic import BaseModel, ConfigDict, PrivateAttr, create_model
11
+
12
+ from digitalkin.logger import logger
13
+ from digitalkin.models.module.tool_cache import ToolCache
14
+ from digitalkin.models.module.tool_reference import ToolReference
15
+ from digitalkin.utils.dynamic_schema import (
16
+ DynamicField,
17
+ get_fetchers,
18
+ has_dynamic,
19
+ resolve_safe,
20
+ )
21
+
22
+ if TYPE_CHECKING:
23
+ from pydantic.fields import FieldInfo
24
+
25
+ from digitalkin.services.registry import RegistryStrategy
26
+
27
# Type variable bound to SetupModel so class helpers can return the subclass type.
SetupModelT = TypeVar("SetupModelT", bound="SetupModel")
28
+
29
+
30
class SetupModel(BaseModel, Generic[SetupModelT]):
    """Base definition of setup model showing mandatory root fields.

    Optionally, the setup model can define a config option in json_schema_extra
    to be used to initialize the Kin. Supports dynamic schema providers for
    runtime value generation.

    The tool_cache is populated during run_config_setup and contains resolved
    ModuleInfo indexed by slug. It is validated during initialize.

    Attributes:
        model_fields: Inherited from Pydantic BaseModel, contains field definitions.

    See Also:
        - Documentation: docs/api/dynamic_schema.md
        - Tests: tests/modules/test_setup_model.py
    """

    # Cache of filtered model classes produced by get_clean_model, keyed by
    # (owner class, config_fields, hidden_fields). Shared across subclasses,
    # but the class in the key keeps entries distinct.
    _clean_model_cache: ClassVar[dict[tuple[type, bool, bool], type]] = {}
    # Per-instance cache of resolved tools; filled by build_tool_cache().
    _tool_cache: ToolCache = PrivateAttr(default_factory=ToolCache)
50
+
51
+ @classmethod
52
+ async def get_clean_model(
53
+ cls,
54
+ *,
55
+ config_fields: bool,
56
+ hidden_fields: bool,
57
+ force: bool = False,
58
+ ) -> type[SetupModelT]:
59
+ """Dynamically builds and returns a new BaseModel subclass with filtered fields.
60
+
61
+ This method filters fields based on their `json_schema_extra` metadata:
62
+ - Fields with `{"config": True}` are included only when `config_fields=True`
63
+ - Fields with `{"hidden": True}` are included only when `hidden_fields=True`
64
+
65
+ When `force=True`, fields with dynamic schema providers will have their
66
+ providers called to fetch fresh values for schema metadata like enums.
67
+ This includes recursively processing nested BaseModel fields.
68
+
69
+ Args:
70
+ config_fields: If True, include fields marked with `{"config": True}`.
71
+ These are typically initial configuration fields.
72
+ hidden_fields: If True, include fields marked with `{"hidden": True}`.
73
+ These are typically runtime-only fields not shown in initial config.
74
+ force: If True, refresh dynamic schema fields by calling their providers.
75
+ Use this when you need up-to-date values from external sources like
76
+ databases or APIs. Default is False for performance.
77
+
78
+ Returns:
79
+ A new BaseModel subclass with filtered fields.
80
+ """
81
+ # Check cache for non-forced requests
82
+ cache_key = (cls, config_fields, hidden_fields)
83
+ if not force and cache_key in cls._clean_model_cache:
84
+ return cast("type[SetupModelT]", cls._clean_model_cache[cache_key])
85
+
86
+ clean_fields: dict[str, Any] = {}
87
+
88
+ for name, field_info in cls.model_fields.items():
89
+ extra = field_info.json_schema_extra or {}
90
+ is_config = bool(extra.get("config", False)) if isinstance(extra, dict) else False
91
+ is_hidden = bool(extra.get("hidden", False)) if isinstance(extra, dict) else False
92
+
93
+ # Skip config unless explicitly included
94
+ if is_config and not config_fields:
95
+ logger.debug("Skipping '%s' (config-only)", name)
96
+ continue
97
+
98
+ # Skip hidden unless explicitly included
99
+ if is_hidden and not hidden_fields:
100
+ logger.debug("Skipping '%s' (hidden-only)", name)
101
+ continue
102
+
103
+ # Refresh dynamic schema fields when force=True
104
+ current_field_info = field_info
105
+ current_annotation = field_info.annotation
106
+
107
+ if force:
108
+ # Check if this field has DynamicField metadata
109
+ if has_dynamic(field_info):
110
+ current_field_info = await cls._refresh_field_schema(name, field_info)
111
+
112
+ # Check if the annotation is a nested BaseModel that might have dynamic fields
113
+ nested_model = cls._get_base_model_type(current_annotation)
114
+ if nested_model is not None:
115
+ refreshed_nested = await cls._refresh_nested_model(nested_model)
116
+ if refreshed_nested is not nested_model:
117
+ # Update annotation to use refreshed nested model
118
+ current_annotation = refreshed_nested
119
+ # Create new field_info with updated annotation (deep copy for safety)
120
+ current_field_info = copy.deepcopy(current_field_info)
121
+ current_field_info.annotation = current_annotation
122
+
123
+ clean_fields[name] = (current_annotation, current_field_info)
124
+
125
+ # Dynamically create a model e.g. "SetupModel"
126
+ m = create_model(
127
+ f"{cls.__name__}",
128
+ __base__=BaseModel,
129
+ __config__=ConfigDict(arbitrary_types_allowed=True),
130
+ **clean_fields,
131
+ )
132
+
133
+ # Cache for non-forced requests
134
+ if not force:
135
+ cls._clean_model_cache[cache_key] = m
136
+
137
+ return cast("type[SetupModelT]", m)
138
+
139
+ @classmethod
140
+ def _get_base_model_type(cls, annotation: type | None) -> type[BaseModel] | None:
141
+ """Extract BaseModel type from an annotation.
142
+
143
+ Handles direct types, Optional, Union, list, dict, set, tuple, and other generics.
144
+
145
+ Args:
146
+ annotation: The type annotation to inspect.
147
+
148
+ Returns:
149
+ The BaseModel subclass if found, None otherwise.
150
+ """
151
+ if annotation is None:
152
+ return None
153
+
154
+ # Direct BaseModel subclass check
155
+ if isinstance(annotation, type) and issubclass(annotation, BaseModel):
156
+ return annotation
157
+
158
+ origin = get_origin(annotation)
159
+ if origin is None:
160
+ return None
161
+
162
+ args = get_args(annotation)
163
+ return cls._extract_base_model_from_args(origin, args)
164
+
165
+ @classmethod
166
+ def _extract_base_model_from_args(
167
+ cls,
168
+ origin: type,
169
+ args: tuple[type, ...],
170
+ ) -> type[BaseModel] | None:
171
+ """Extract BaseModel from generic type arguments.
172
+
173
+ Args:
174
+ origin: The generic origin type (list, dict, Union, etc.).
175
+ args: The type arguments.
176
+
177
+ Returns:
178
+ The BaseModel subclass if found, None otherwise.
179
+ """
180
+ # Union/Optional: check each arg (supports both typing.Union and types.UnionType)
181
+ # Python 3.10+ uses types.UnionType for X | Y syntax
182
+ if origin is typing.Union or origin is types.UnionType:
183
+ return cls._find_base_model_in_args(args)
184
+
185
+ # list, set, frozenset: check first arg
186
+ if origin in {list, set, frozenset} and args:
187
+ return cls._check_base_model(args[0])
188
+
189
+ # dict: check value type (second arg)
190
+ dict_value_index = 1
191
+ if origin is dict and len(args) > dict_value_index:
192
+ return cls._check_base_model(args[dict_value_index])
193
+
194
+ # tuple: check first non-ellipsis arg
195
+ if origin is tuple:
196
+ return cls._find_base_model_in_args(args, skip_ellipsis=True)
197
+
198
+ return None
199
+
200
+ @classmethod
201
+ def _check_base_model(cls, arg: type) -> type[BaseModel] | None:
202
+ """Check if arg is a BaseModel subclass.
203
+
204
+ Returns:
205
+ The BaseModel subclass if arg is one, None otherwise.
206
+ """
207
+ if isinstance(arg, type) and issubclass(arg, BaseModel):
208
+ return arg
209
+ return None
210
+
211
+ @classmethod
212
+ def _find_base_model_in_args(
213
+ cls,
214
+ args: tuple[type, ...],
215
+ *,
216
+ skip_ellipsis: bool = False,
217
+ ) -> type[BaseModel] | None:
218
+ """Find first BaseModel in args.
219
+
220
+ Returns:
221
+ The first BaseModel subclass found, None otherwise.
222
+ """
223
+ for arg in args:
224
+ if arg is type(None):
225
+ continue
226
+ if skip_ellipsis and arg is ...:
227
+ continue
228
+ result = cls._check_base_model(arg)
229
+ if result is not None:
230
+ return result
231
+ return None
232
+
233
+ @classmethod
234
+ async def _refresh_nested_model(cls, model_cls: type[BaseModel]) -> type[BaseModel]:
235
+ """Refresh dynamic fields in a nested BaseModel.
236
+
237
+ Creates a new model class with all DynamicField metadata resolved.
238
+
239
+ Args:
240
+ model_cls: The nested model class to refresh.
241
+
242
+ Returns:
243
+ A new model class with refreshed fields, or the original if no changes.
244
+ """
245
+ has_changes = False
246
+ clean_fields: dict[str, Any] = {}
247
+
248
+ for name, field_info in model_cls.model_fields.items():
249
+ current_field_info = field_info
250
+ current_annotation = field_info.annotation
251
+
252
+ # Check if field has DynamicField metadata
253
+ if has_dynamic(field_info):
254
+ current_field_info = await cls._refresh_field_schema(name, field_info)
255
+ has_changes = True
256
+
257
+ # Recursively check nested models
258
+ nested_model = cls._get_base_model_type(current_annotation)
259
+ if nested_model is not None:
260
+ refreshed_nested = await cls._refresh_nested_model(nested_model)
261
+ if refreshed_nested is not nested_model:
262
+ current_annotation = refreshed_nested
263
+ current_field_info = copy.deepcopy(current_field_info)
264
+ current_field_info.annotation = current_annotation
265
+ has_changes = True
266
+
267
+ clean_fields[name] = (current_annotation, current_field_info)
268
+
269
+ if not has_changes:
270
+ return model_cls
271
+
272
+ # Create new model with refreshed fields
273
+ logger.debug("Creating refreshed nested model for '%s'", model_cls.__name__)
274
+ return create_model(
275
+ model_cls.__name__,
276
+ __base__=BaseModel,
277
+ __config__=ConfigDict(arbitrary_types_allowed=True),
278
+ **clean_fields,
279
+ )
280
+
281
+ @classmethod
282
+ async def _refresh_field_schema(cls, field_name: str, field_info: FieldInfo) -> FieldInfo:
283
+ """Refresh a field's json_schema_extra with fresh values from dynamic providers.
284
+
285
+ This method calls all dynamic providers registered for a field (via Annotated
286
+ metadata) and creates a new FieldInfo with the resolved values. The original
287
+ field_info is not modified.
288
+
289
+ Uses `resolve_safe()` for structured error handling, allowing partial success
290
+ when some fetchers fail. Successfully resolved values are still applied.
291
+
292
+ Args:
293
+ field_name: The name of the field being refreshed (used for logging).
294
+ field_info: The original FieldInfo object containing the dynamic providers.
295
+
296
+ Returns:
297
+ A new FieldInfo object with the same attributes as the original, but with
298
+ `json_schema_extra` containing resolved values and Dynamic metadata removed.
299
+
300
+ Note:
301
+ If all fetchers fail, the original field_info is returned unchanged.
302
+ If some fetchers fail, successfully resolved values are still applied.
303
+ """
304
+ fetchers = get_fetchers(field_info)
305
+
306
+ if not fetchers:
307
+ return field_info
308
+
309
+ fetcher_keys = list(fetchers.keys())
310
+ logger.debug(
311
+ "Refreshing dynamic schema for field '%s' with fetchers: %s",
312
+ field_name,
313
+ fetcher_keys,
314
+ extra={"field_name": field_name, "fetcher_keys": fetcher_keys},
315
+ )
316
+
317
+ # Resolve all fetchers with structured error handling
318
+ result = await resolve_safe(fetchers)
319
+
320
+ # Log any errors that occurred with full details
321
+ if result.errors:
322
+ for key, error in result.errors.items():
323
+ logger.warning(
324
+ "Failed to resolve '%s' for field '%s': %s: %s",
325
+ key,
326
+ field_name,
327
+ type(error).__name__,
328
+ str(error) or "(no message)",
329
+ extra={
330
+ "field_name": field_name,
331
+ "fetcher_key": key,
332
+ "error_type": type(error).__name__,
333
+ "error_message": str(error),
334
+ "error_repr": repr(error),
335
+ },
336
+ )
337
+
338
+ # If no values were resolved, return original field_info
339
+ if not result.values:
340
+ logger.warning(
341
+ "All fetchers failed for field '%s', keeping original",
342
+ field_name,
343
+ )
344
+ return field_info
345
+
346
+ # Build new json_schema_extra with resolved values merged
347
+ extra = field_info.json_schema_extra or {}
348
+ new_extra = {**extra, **result.values} if isinstance(extra, dict) else result.values
349
+
350
+ # Create a deep copy of the FieldInfo to avoid shared mutable state
351
+ new_field_info = copy.deepcopy(field_info)
352
+ new_field_info.json_schema_extra = new_extra
353
+
354
+ # Remove Dynamic from metadata (it's been resolved)
355
+ new_metadata = [m for m in new_field_info.metadata if not isinstance(m, DynamicField)]
356
+ new_field_info.metadata = new_metadata
357
+
358
+ logger.debug(
359
+ "Refreshed '%s' with dynamic values: %s",
360
+ field_name,
361
+ list(result.values.keys()),
362
+ )
363
+
364
+ return new_field_info
365
+
366
+ def resolve_tool_references(self, registry: RegistryStrategy) -> None:
367
+ """Resolve all ToolReference fields in this setup instance.
368
+
369
+ Recursively walks through all fields, including nested BaseModel instances,
370
+ and resolves any ToolReference fields using the provided registry.
371
+
372
+ Args:
373
+ registry: Registry service to use for resolution.
374
+ """
375
+ self._resolve_tool_references_recursive(self, registry)
376
+
377
+ @classmethod
378
+ def _resolve_tool_references_recursive(
379
+ cls,
380
+ model_instance: BaseModel,
381
+ registry: RegistryStrategy,
382
+ ) -> None:
383
+ """Recursively resolve ToolReference fields in a model instance.
384
+
385
+ Args:
386
+ model_instance: The model instance to process.
387
+ registry: Registry service to use for resolution.
388
+ """
389
+ for field_name, field_value in model_instance.__dict__.items():
390
+ if field_value is None:
391
+ continue
392
+
393
+ cls._resolve_field_value(field_name, field_value, registry)
394
+
395
+ @classmethod
396
+ def _resolve_field_value(
397
+ cls,
398
+ field_name: str,
399
+ field_value: BaseModel | ToolReference | list | dict,
400
+ registry: RegistryStrategy,
401
+ ) -> None:
402
+ """Resolve a single field value, handling different types.
403
+
404
+ Args:
405
+ field_name: Name of the field for logging.
406
+ field_value: The value to process.
407
+ registry: Registry service to use for resolution.
408
+ """
409
+ if isinstance(field_value, ToolReference):
410
+ cls._resolve_single_tool_reference(field_name, field_value, registry)
411
+ elif isinstance(field_value, BaseModel):
412
+ cls._resolve_tool_references_recursive(field_value, registry)
413
+ elif isinstance(field_value, list):
414
+ cls._resolve_list_items(field_value, registry)
415
+ elif isinstance(field_value, dict):
416
+ cls._resolve_dict_values(field_value, registry)
417
+
418
+ @classmethod
419
+ def _resolve_single_tool_reference(
420
+ cls,
421
+ field_name: str,
422
+ tool_ref: ToolReference,
423
+ registry: RegistryStrategy,
424
+ ) -> None:
425
+ """Resolve a single ToolReference instance.
426
+
427
+ Args:
428
+ field_name: Name of the field for logging.
429
+ tool_ref: The ToolReference instance.
430
+ registry: Registry service to use for resolution.
431
+ """
432
+ try:
433
+ tool_ref.resolve(registry)
434
+ logger.debug(
435
+ "Resolved ToolReference field '%s'",
436
+ field_name,
437
+ extra={"field_name": field_name, "mode": tool_ref.config.mode.value},
438
+ )
439
+ except Exception:
440
+ logger.exception(
441
+ "Failed to resolve ToolReference field '%s'",
442
+ field_name,
443
+ extra={"field_name": field_name, "config": tool_ref.config.model_dump()},
444
+ )
445
+
446
+ @classmethod
447
+ def _resolve_list_items(
448
+ cls,
449
+ items: list,
450
+ registry: RegistryStrategy,
451
+ ) -> None:
452
+ """Resolve ToolReference instances in a list.
453
+
454
+ Args:
455
+ items: List of items to process.
456
+ registry: Registry service to use for resolution.
457
+ """
458
+ for item in items:
459
+ if isinstance(item, ToolReference):
460
+ cls._resolve_single_tool_reference("list_item", item, registry)
461
+ elif isinstance(item, BaseModel):
462
+ cls._resolve_tool_references_recursive(item, registry)
463
+
464
+ @classmethod
465
+ def _resolve_dict_values(
466
+ cls,
467
+ mapping: dict,
468
+ registry: RegistryStrategy,
469
+ ) -> None:
470
+ """Resolve ToolReference instances in a dict's values.
471
+
472
+ Args:
473
+ mapping: Dict to process.
474
+ registry: Registry service to use for resolution.
475
+ """
476
+ for item in mapping.values():
477
+ if isinstance(item, ToolReference):
478
+ cls._resolve_single_tool_reference("dict_value", item, registry)
479
+ elif isinstance(item, BaseModel):
480
+ cls._resolve_tool_references_recursive(item, registry)
481
+
482
    @property
    def tool_cache(self) -> ToolCache:
        """Get the tool cache for this setup instance.

        Returns:
            The ToolCache containing resolved tools.
        """
        return self._tool_cache
490
+
491
+ def build_tool_cache(self) -> ToolCache:
492
+ """Build the tool cache from resolved ToolReferences.
493
+
494
+ This should be called during run_config_setup after resolve_tool_references.
495
+ It walks all ToolReference fields and adds resolved ones to the cache.
496
+
497
+ Returns:
498
+ The populated ToolCache.
499
+ """
500
+ self._build_tool_cache_recursive(self)
501
+ logger.debug(
502
+ "Tool cache built",
503
+ extra={"slugs": self._tool_cache.list_slugs()},
504
+ )
505
+ return self._tool_cache
506
+
507
+ def _build_tool_cache_recursive(self, model_instance: BaseModel) -> None:
508
+ """Recursively build tool cache from model fields.
509
+
510
+ Args:
511
+ model_instance: The model instance to process.
512
+ """
513
+ for field_name, field_value in model_instance.__dict__.items():
514
+ if field_value is None:
515
+ continue
516
+
517
+ if isinstance(field_value, ToolReference):
518
+ self._add_tool_reference_to_cache(field_name, field_value)
519
+ elif isinstance(field_value, BaseModel):
520
+ self._build_tool_cache_recursive(field_value)
521
+ elif isinstance(field_value, list):
522
+ self._build_tool_cache_from_list(field_value)
523
+ elif isinstance(field_value, dict):
524
+ self._build_tool_cache_from_dict(field_value)
525
+
526
+ def _add_tool_reference_to_cache(self, field_name: str, tool_ref: ToolReference) -> None:
527
+ """Add a resolved ToolReference to the cache.
528
+
529
+ Args:
530
+ field_name: Name of the field (used as fallback slug).
531
+ tool_ref: The ToolReference instance.
532
+ """
533
+ if tool_ref.module_info:
534
+ # Use slug from config, or field name as fallback
535
+ slug = tool_ref.slug or field_name
536
+ self._tool_cache.add(slug, tool_ref.module_info)
537
+
538
+ def _build_tool_cache_from_list(self, items: list) -> None:
539
+ """Build tool cache from list items.
540
+
541
+ Args:
542
+ items: List of items to process.
543
+ """
544
+ for idx, item in enumerate(items):
545
+ if isinstance(item, ToolReference):
546
+ self._add_tool_reference_to_cache(f"list_{idx}", item)
547
+ elif isinstance(item, BaseModel):
548
+ self._build_tool_cache_recursive(item)
549
+
550
+ def _build_tool_cache_from_dict(self, mapping: dict) -> None:
551
+ """Build tool cache from dict values.
552
+
553
+ Args:
554
+ mapping: Dict to process.
555
+ """
556
+ for key, item in mapping.items():
557
+ if isinstance(item, ToolReference):
558
+ self._add_tool_reference_to_cache(str(key), item)
559
+ elif isinstance(item, BaseModel):
560
+ self._build_tool_cache_recursive(item)
561
+
562
    def validate_tool_cache(self, registry: RegistryStrategy) -> list[str]:
        """Validate all cached tools are still available.

        Should be called during initialize to ensure tools are still valid.
        Delegates entirely to ToolCache.validate.

        Args:
            registry: Registry to validate against.

        Returns:
            List of slugs that are no longer valid.
        """
        return self._tool_cache.validate(registry)