data-designer-engine 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. data_designer/engine/__init__.py +2 -0
  2. data_designer/engine/_version.py +34 -0
  3. data_designer/engine/analysis/column_profilers/base.py +49 -0
  4. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
  5. data_designer/engine/analysis/column_profilers/registry.py +22 -0
  6. data_designer/engine/analysis/column_statistics.py +145 -0
  7. data_designer/engine/analysis/dataset_profiler.py +149 -0
  8. data_designer/engine/analysis/errors.py +9 -0
  9. data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
  10. data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
  11. data_designer/engine/column_generators/__init__.py +2 -0
  12. data_designer/engine/column_generators/generators/__init__.py +2 -0
  13. data_designer/engine/column_generators/generators/base.py +122 -0
  14. data_designer/engine/column_generators/generators/embedding.py +35 -0
  15. data_designer/engine/column_generators/generators/expression.py +55 -0
  16. data_designer/engine/column_generators/generators/llm_completion.py +116 -0
  17. data_designer/engine/column_generators/generators/samplers.py +69 -0
  18. data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
  19. data_designer/engine/column_generators/generators/validation.py +140 -0
  20. data_designer/engine/column_generators/registry.py +60 -0
  21. data_designer/engine/column_generators/utils/errors.py +15 -0
  22. data_designer/engine/column_generators/utils/generator_classification.py +43 -0
  23. data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
  24. data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
  25. data_designer/engine/compiler.py +97 -0
  26. data_designer/engine/configurable_task.py +71 -0
  27. data_designer/engine/dataset_builders/artifact_storage.py +283 -0
  28. data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
  29. data_designer/engine/dataset_builders/errors.py +15 -0
  30. data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
  31. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  32. data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
  33. data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
  34. data_designer/engine/dataset_builders/utils/dag.py +62 -0
  35. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
  36. data_designer/engine/dataset_builders/utils/errors.py +15 -0
  37. data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
  38. data_designer/engine/errors.py +51 -0
  39. data_designer/engine/model_provider.py +77 -0
  40. data_designer/engine/models/__init__.py +2 -0
  41. data_designer/engine/models/errors.py +300 -0
  42. data_designer/engine/models/facade.py +284 -0
  43. data_designer/engine/models/factory.py +42 -0
  44. data_designer/engine/models/litellm_overrides.py +179 -0
  45. data_designer/engine/models/parsers/__init__.py +2 -0
  46. data_designer/engine/models/parsers/errors.py +34 -0
  47. data_designer/engine/models/parsers/parser.py +235 -0
  48. data_designer/engine/models/parsers/postprocessors.py +93 -0
  49. data_designer/engine/models/parsers/tag_parsers.py +62 -0
  50. data_designer/engine/models/parsers/types.py +84 -0
  51. data_designer/engine/models/recipes/base.py +81 -0
  52. data_designer/engine/models/recipes/response_recipes.py +293 -0
  53. data_designer/engine/models/registry.py +151 -0
  54. data_designer/engine/models/telemetry.py +362 -0
  55. data_designer/engine/models/usage.py +73 -0
  56. data_designer/engine/models/utils.py +101 -0
  57. data_designer/engine/processing/ginja/__init__.py +2 -0
  58. data_designer/engine/processing/ginja/ast.py +65 -0
  59. data_designer/engine/processing/ginja/environment.py +463 -0
  60. data_designer/engine/processing/ginja/exceptions.py +56 -0
  61. data_designer/engine/processing/ginja/record.py +32 -0
  62. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  63. data_designer/engine/processing/gsonschema/exceptions.py +15 -0
  64. data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
  65. data_designer/engine/processing/gsonschema/types.py +10 -0
  66. data_designer/engine/processing/gsonschema/validators.py +202 -0
  67. data_designer/engine/processing/processors/base.py +13 -0
  68. data_designer/engine/processing/processors/drop_columns.py +42 -0
  69. data_designer/engine/processing/processors/registry.py +25 -0
  70. data_designer/engine/processing/processors/schema_transform.py +71 -0
  71. data_designer/engine/processing/utils.py +169 -0
  72. data_designer/engine/registry/base.py +99 -0
  73. data_designer/engine/registry/data_designer_registry.py +39 -0
  74. data_designer/engine/registry/errors.py +12 -0
  75. data_designer/engine/resources/managed_dataset_generator.py +39 -0
  76. data_designer/engine/resources/managed_dataset_repository.py +197 -0
  77. data_designer/engine/resources/managed_storage.py +65 -0
  78. data_designer/engine/resources/resource_provider.py +77 -0
  79. data_designer/engine/resources/seed_reader.py +154 -0
  80. data_designer/engine/sampling_gen/column.py +91 -0
  81. data_designer/engine/sampling_gen/constraints.py +100 -0
  82. data_designer/engine/sampling_gen/data_sources/base.py +217 -0
  83. data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
  84. data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
  85. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  86. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  87. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
  88. data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
  89. data_designer/engine/sampling_gen/entities/errors.py +10 -0
  90. data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
  91. data_designer/engine/sampling_gen/entities/person.py +144 -0
  92. data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
  93. data_designer/engine/sampling_gen/errors.py +26 -0
  94. data_designer/engine/sampling_gen/generator.py +122 -0
  95. data_designer/engine/sampling_gen/jinja_utils.py +64 -0
  96. data_designer/engine/sampling_gen/people_gen.py +199 -0
  97. data_designer/engine/sampling_gen/person_constants.py +56 -0
  98. data_designer/engine/sampling_gen/schema.py +147 -0
  99. data_designer/engine/sampling_gen/schema_builder.py +61 -0
  100. data_designer/engine/sampling_gen/utils.py +46 -0
  101. data_designer/engine/secret_resolver.py +82 -0
  102. data_designer/engine/testing/__init__.py +12 -0
  103. data_designer/engine/testing/stubs.py +133 -0
  104. data_designer/engine/testing/utils.py +20 -0
  105. data_designer/engine/validation.py +367 -0
  106. data_designer/engine/validators/__init__.py +19 -0
  107. data_designer/engine/validators/base.py +38 -0
  108. data_designer/engine/validators/local_callable.py +39 -0
  109. data_designer/engine/validators/python.py +254 -0
  110. data_designer/engine/validators/remote.py +89 -0
  111. data_designer/engine/validators/sql.py +65 -0
  112. data_designer_engine-0.4.0.dist-info/METADATA +50 -0
  113. data_designer_engine-0.4.0.dist-info/RECORD +114 -0
  114. data_designer_engine-0.4.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,81 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import abc
7
+ from collections.abc import Callable
8
+ from typing import Generic, TypeVar
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
+ class ResponseRecipe(abc.ABC, Generic[T]):
14
+ """Base class for defining response recipes.
15
+
16
+ Response recipes contain all necessary information for
17
+ getting an LLM to perform a particular common task,
18
+ like outputting code in a desired format or following
19
+ structured output.
20
+ """
21
+
22
+ @abc.abstractmethod
23
+ def _build_parser_fn(self) -> Callable[[str], T]:
24
+ """Build the recipe's output parser function."""
25
+ ...
26
+
27
+ @property
28
+ @abc.abstractmethod
29
+ def example_template(self) -> str: ...
30
+
31
+ @abc.abstractmethod
32
+ def serialize_output(self, output: T) -> str:
33
+ """Serialize an instance of the parser output."""
34
+ ...
35
+
36
+ @abc.abstractmethod
37
+ def deserialize_output(self, serialized_output: str) -> T:
38
+ """Deserialize a serialized instance of the parser output."""
39
+ ...
40
+
41
+ def __init__(self):
42
+ self._parse_fn = self._build_parser_fn()
43
+
44
+ @property
45
+ def task_instructions(self) -> str | None:
46
+ """Specifies task instructions.
47
+
48
+ These instructions lay out the particular task information the
49
+ LLM requires in order to carry out the function of the recipe.
50
+ """
51
+ return None
52
+
53
+ def parse(self, response: str) -> T:
54
+ """Apply the recipe's parser to a raw model output."""
55
+ return self._parse_fn(response)
56
+
57
+ def generate_response_example(self, example: T) -> str:
58
+ """Create a serialized response example that the parser would admit."""
59
+ return self.example_template.format(example=example)
60
+
61
+ def apply_recipe_to_user_prompt(self, user_prompt: str) -> str:
62
+ """Appends recipe specific task instructions if applicable.
63
+
64
+ Args:
65
+ user_prompt (str): User prompt to be appended with recipe specific task instructions if applicable.
66
+
67
+ Returns:
68
+ str: Final user prompt
69
+ """
70
+ return f"{user_prompt}\n\n{self.task_instructions}" if self.task_instructions else user_prompt
71
+
72
+ def apply_recipe_to_system_prompt(self, system_prompt: str | None) -> str:
73
+ """Appends recipe specific task instructions if applicable.
74
+
75
+ Args:
76
+ system_prompt (str): System prompt to be appended with recipe specific task instructions if applicable.
77
+
78
+ Returns:
79
+ str: Final system prompt
80
+ """
81
+ return system_prompt
@@ -0,0 +1,293 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ from collections.abc import Callable
8
+
9
+ from pydantic import BaseModel
10
+
11
+ from data_designer.config.utils.code_lang import CodeLang
12
+ from data_designer.engine.models.parsers.errors import ParserException
13
+ from data_designer.engine.models.parsers.parser import LLMResponseParser
14
+ from data_designer.engine.models.parsers.postprocessors import (
15
+ StructuredDataBlock,
16
+ deserialize_json_code,
17
+ merge_text_blocks,
18
+ )
19
+ from data_designer.engine.models.parsers.types import CodeBlock
20
+ from data_designer.engine.models.recipes.base import (
21
+ ResponseRecipe,
22
+ )
23
+ from data_designer.engine.processing.gsonschema.validators import JSONSchemaValidationError, validate
24
+
25
+
26
class TextResponseRecipe(ResponseRecipe[str]):
    """Default text recipe.

    Covers the "pass-through" case of natural language LLM responses:
    serialization and deserialization both return their input as-is.
    """

    def _build_parser_fn(self) -> Callable[[str], str]:
        """Return a function that parses a raw response and yields its ``response`` text."""
        text_parser = LLMResponseParser(postprocessors=[merge_text_blocks])

        def parse_fn(raw_response: str) -> str:
            parsed = text_parser.parse(raw_response)
            return parsed.response

        return parse_fn

    @property
    def example_template(self) -> str:
        """Template that renders the example verbatim."""
        return "{example}"

    def serialize_output(self, output: str) -> str:
        """Return the text unchanged."""
        return output

    def deserialize_output(self, serialized_output: str) -> str:
        """Return the serialized text unchanged."""
        return serialized_output
50
+
51
+
52
class StructuredResponseRecipe(ResponseRecipe[dict]):
    """Recipe for structured responses.

    This recipe is intended to cover the generic case of
    prompting-based requests for structured data outputs,
    and the structure in question is determined by a
    provided JSON Schema.

    The LLM's response is validated against the provided
    JSON Schema; however, the object returned is the Python
    dictionary obtained from deserializing the LLM's
    JSON response.
    """

    json_schema: dict
    pruning: bool
    no_extra_properties: bool

    def __init__(
        self,
        json_schema: dict,
        pruning: bool = True,
        no_extra_properties: bool = True,
        **kwargs,
    ):
        """Initialize StructuredResponseRecipe.

        Args:
            json_schema (dict): A target JSON schema that the LLM
                should adhere to when making its response.
            pruning (bool): If `True`, then any extra fields in the returned
                JSON object will be removed. Otherwise, they are retained,
                which could raise validation errors. Default=True.
            no_extra_properties (bool): If `True`, then validation will fail
                if extra properties are encountered in the returned JSON response.
                Default=True.
            **kwargs: Forwarded to the base class initializer.
        """
        super().__init__(**kwargs)
        self.json_schema = json_schema
        self.pruning = pruning
        self.no_extra_properties = no_extra_properties

    @property
    def task_instructions(self) -> str:
        """Instructions asking for a fenced JSON response, with the schema embedded."""
        return (
            "* Your response must be in JSON format.\n"
            "* Your JSON response must be returned within a Markdown, ```json code fence.\n"
            "* The JSON format is given as a JSON Schema description within <response_schema> markup tags.\n\n"
            f"<response_schema>\n{self.schema}\n</response_schema>"
        )

    @property
    def example_template(self) -> str:
        """Template that wraps an example in a ```json code fence."""
        return "```json\n{example}\n```"

    def generate_response_example(self, example: dict) -> str:
        """Render ``example`` as a fenced JSON response.

        Non-ASCII characters are kept as-is (``ensure_ascii=False``) so the
        example uses the same encoding as ``serialize_output``.
        """
        return self.example_template.format(example=json.dumps(example, ensure_ascii=False))

    @property
    def schema(self) -> str:
        """The target JSON schema serialized to a JSON string."""
        return json.dumps(self.json_schema)

    def serialize_output(self, output: dict) -> str:
        """Serialize a parsed dictionary to JSON, keeping non-ASCII characters."""
        return json.dumps(output, ensure_ascii=False)

    def deserialize_output(self, serialized_output: str) -> dict:
        """Load a dictionary back from its JSON serialization."""
        return json.loads(serialized_output)

    @property
    def _validate_args(self):
        """Keyword arguments passed to ``validate`` alongside the parsed object."""
        return {
            "schema": self.json_schema,
            "pruning": self.pruning,
            "no_extra_properties": self.no_extra_properties,
        }

    def _build_parser_fn(self) -> Callable[[str], dict]:
        """Build a parser that extracts the JSON block and validates it against the schema."""
        parser = LLMResponseParser(
            postprocessors=[
                merge_text_blocks,
                deserialize_json_code,
            ]
        )

        def parse_fn(response: str) -> dict:
            try:
                # pop() takes the last structured block; it raises IndexError when there is none.
                obj = parser.parse(response).filter([StructuredDataBlock]).parsed.pop().obj
                # Validation arguments are read at call time, after __init__ has set them.
                return validate(obj, **self._validate_args)
            except IndexError:
                raise ParserException(
                    "No parsable JSON structure within ```json markdown fence.",
                    source=response,
                ) from None
            except JSONSchemaValidationError as exc:
                raise ParserException(
                    "Response doesn't match requested <response_schema>\n" + str(exc),
                    source=response,
                ) from None

        return parse_fn
152
+
153
+
154
class PydanticResponseRecipe(ResponseRecipe[BaseModel]):
    """Recipe for Pydantic responses.

    Use this recipe when a Pydantic data type (BaseModel) is already
    available in the runtime making LLM calls and the parser should
    hand back an instance of that same data type.

    It works much like `StructuredResponseRecipe`, except that it is
    initialized from a Pydantic `BaseModel` and its return value is
    produced by `BaseModel.model_validate`.
    """

    data_type: type[BaseModel]

    def __init__(self, data_type: type[BaseModel], **kwargs):
        """Initialize a PydanticResponseRecipe.

        Args:
            data_type (type(BaseModel)): The target Pydantic BaseModel
                subclass that the LLM should adhere to in its response,
                and defines the output type of the parser.
            **kwargs: Forwarded to the base class initializer.
        """
        super().__init__(**kwargs)
        self.data_type = data_type

    @property
    def schema(self) -> str:
        """JSON Schema of the target data type, serialized to a JSON string."""
        model_schema = self.data_type.model_json_schema()
        return json.dumps(model_schema)

    @property
    def task_instructions(self) -> str:
        """Instructions asking for a fenced JSON response, with the schema embedded."""
        return (
            "* Your response must be in JSON format.\n"
            "* Your JSON response must be returned within a Markdown, ```json code fence.\n"
            "* The JSON format is given as a JSON Schema description within <response_schema> markup tags.\n\n"
            f"<response_schema>\n{self.schema}\n</response_schema>"
        )

    @property
    def example_template(self) -> str:
        """Template that wraps an example in a ```json code fence."""
        return "```json\n{example}\n```"

    def generate_response_example(self, example: BaseModel) -> str:
        """Render a model instance as a fenced JSON response."""
        example_json = example.model_dump_json()
        return self.example_template.format(example=example_json)

    def serialize_output(self, output: BaseModel) -> str:
        """Dump a model instance to its JSON string."""
        return output.model_dump_json()

    def deserialize_output(self, serialized_output: str) -> BaseModel:
        """Rebuild a model instance from its JSON string."""
        return self.data_type.model_validate_json(serialized_output)

    def _build_parser_fn(self) -> Callable[[str], BaseModel]:
        """Build a parser that extracts the JSON block and validates it into ``data_type``."""
        json_parser = LLMResponseParser(postprocessors=[merge_text_blocks, deserialize_json_code])

        def parse_fn(response: str) -> BaseModel:
            try:
                structured_blocks = json_parser.parse(response).filter([StructuredDataBlock]).parsed
                # pop() takes the last structured block; IndexError means there was none.
                payload = structured_blocks.pop().obj
                return self.data_type.model_validate(payload)
            except IndexError:
                raise ParserException(
                    "No parsable JSON structure within ```json markdown fence.",
                    source=response,
                ) from None
            except Exception as err:
                # Any other failure is reported as a schema mismatch.
                raise ParserException(
                    "Response doesn't match requested <response_schema>\n" + str(err),
                    source=response,
                ) from None

        return parse_fn
232
+
233
+
234
class CodeResponseRecipe(ResponseRecipe[str]):
    """Obtain a code snippet from an LLM."""

    def __init__(self, syntax: str | CodeLang, **kwargs):
        """Initialize a CodeResponseRecipe.

        Args:
            syntax (str | CodeLang): The code syntax that the
                LLM should adhere to, e.g. `"python"`, `"sql"`, etc.
            **kwargs: Forwarded to the base class initializer.
        """
        super().__init__(**kwargs)
        self.syntax = CodeLang.parse_lang(syntax)

    @property
    def task_instructions(self) -> str:
        """Instructions asking for a single fenced code block in the requested syntax."""
        return (
            f"* Your response must be code written in {self.syntax}.\n"
            "* You will follow accepted and common syntax and best-practices.\n"
            "* Your response will be given in markdown code fences specifying the correct language.\n"
            "* Only respond with a SINGLE code block."
        )

    @property
    def example_template(self) -> str:
        """Template that wraps an example in a code fence tagged with the syntax."""
        return f"```{self.syntax}\n{{example}}\n```\n"

    def serialize_output(self, output: str) -> str:
        """Return the code unchanged."""
        return output

    def deserialize_output(self, serialized_output: str) -> str:
        """Return the serialized code unchanged."""
        return serialized_output

    def _build_parser_fn(self) -> Callable[[str], str]:
        """Build a parser that returns the stripped code of the last code block.

        The returned function raises ``ParserException`` when the response
        has no code block, or when the block declares a language different
        from the requested syntax.
        """
        parser = LLMResponseParser(
            postprocessors=[
                merge_text_blocks,
            ]
        )

        def parse_fn(response: str) -> str:
            try:
                # pop() takes the last code block; it raises IndexError when there is none.
                code_block = parser.parse(response).filter([CodeBlock]).parsed.pop()
            except IndexError:
                raise ParserException(
                    "No parsable code response.",
                    source=response,
                ) from None

            # The filter is expected to yield only CodeBlock instances. An explicit
            # check narrows the type without relying on `assert`, which is stripped
            # under -O and would otherwise surface as AssertionError.
            if not isinstance(code_block, CodeBlock):
                raise ParserException(
                    "No parsable code response.",
                    source=response,
                )

            # Only report this as a parser error if there was a mismatch.
            if code_block.code_lang and code_block.code_lang != self.syntax:
                raise ParserException(
                    f"Responded with code not matching the requested syntax ({self.syntax}).",
                    source=response,
                )

            return code_block.code.strip()

        return parse_fn
@@ -0,0 +1,151 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import logging
7
+ from collections.abc import Callable
8
+ from typing import TYPE_CHECKING
9
+
10
+ from data_designer.config.models import GenerationType, ModelConfig
11
+ from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
12
+ from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
13
+ from data_designer.engine.secret_resolver import SecretResolver
14
+
15
+ if TYPE_CHECKING:
16
+ from data_designer.engine.models.facade import ModelFacade
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class ModelRegistry:
    """Registry of model configurations and their lazily created facades.

    Configurations are stored by alias. A facade is only built, through the
    injected ``model_facade_factory``, the first time its alias is requested
    with ``get_model``, and is then cached.
    """

    def __init__(
        self,
        *,
        secret_resolver: SecretResolver,
        model_provider_registry: ModelProviderRegistry,
        model_configs: list[ModelConfig] | None = None,
        model_facade_factory: Callable[[ModelConfig, SecretResolver, ModelProviderRegistry], ModelFacade] | None = None,
    ):
        """Store the collaborators and register the initial model configurations.

        Args:
            secret_resolver: Passed to the facade factory when a facade is built.
            model_provider_registry: Used to look up providers and passed to the facade factory.
            model_configs: Initial configurations; ``None`` means none.
            model_facade_factory: Callable that builds a facade from a config.
                May be ``None``, in which case ``get_model`` raises ``RuntimeError``.
        """
        self._secret_resolver = secret_resolver
        self._model_provider_registry = model_provider_registry
        self._model_facade_factory = model_facade_factory
        self._model_configs: dict[str, ModelConfig] = {}
        self._models: dict[str, ModelFacade] = {}
        self._set_model_configs(model_configs)

    @property
    def model_configs(self) -> dict[str, ModelConfig]:
        """Registered model configurations keyed by alias."""
        return self._model_configs

    @property
    def models(self) -> dict[str, ModelFacade]:
        """Facades created so far, keyed by alias."""
        return self._models

    def register_model_configs(self, model_configs: list[ModelConfig]) -> None:
        """Register new Model configurations at runtime.

        Args:
            model_configs: New Model configurations to register. If a
                Model configuration already exists in the registry
                with the same alias, then it will be overwritten and any
                facade cached for that alias is rebuilt on next request.
        """
        self._set_model_configs(list(self._model_configs.values()) + model_configs)

    def get_model(self, *, model_alias: str) -> ModelFacade:
        """Return the facade for ``model_alias``, creating it on first use.

        Raises:
            ValueError: If no configuration is registered under the alias.
            RuntimeError: If a facade must be built but no factory was provided.
        """
        # Check if model config exists first
        if model_alias not in self._model_configs:
            raise ValueError(f"No model config with alias {model_alias!r} found!")

        # Lazy initialization: only create model facade when first requested
        if model_alias not in self._models:
            self._models[model_alias] = self._get_model(self._model_configs[model_alias])

        return self._models[model_alias]

    def get_model_config(self, *, model_alias: str) -> ModelConfig:
        """Return the configuration registered under ``model_alias``.

        Raises:
            ValueError: If no configuration is registered under the alias.
        """
        if model_alias not in self._model_configs:
            raise ValueError(f"No model config with alias {model_alias!r} found!")
        return self._model_configs[model_alias]

    def get_model_usage_stats(self, total_time_elapsed: float) -> dict[str, dict]:
        """Return usage stats, keyed by model name, for created models that have usage."""
        return {
            model.model_name: model.usage_stats.get_usage_stats(total_time_elapsed=total_time_elapsed)
            for model in self._models.values()
            if model.usage_stats.has_usage
        }

    def get_model_usage_snapshot(self) -> dict[str, ModelUsageStats]:
        """Return deep copies of the usage stats, keyed by model name, for models that have usage."""
        return {
            model.model_name: model.usage_stats.model_copy(deep=True)
            for model in self._models.values()
            if model.usage_stats.has_usage
        }

    def get_usage_deltas(self, snapshot: dict[str, ModelUsageStats]) -> dict[str, ModelUsageStats]:
        """Return the usage accumulated since ``snapshot`` was taken.

        Args:
            snapshot: A previous result of ``get_model_usage_snapshot``.

        Returns:
            Per-model differences in token and request counts. Models missing
            from ``snapshot`` are measured from zero; models with no positive
            difference are omitted.
        """
        deltas = {}
        for model_name, current in self.get_model_usage_snapshot().items():
            prev = snapshot.get(model_name)
            delta_input = current.token_usage.input_tokens - (prev.token_usage.input_tokens if prev else 0)
            delta_output = current.token_usage.output_tokens - (prev.token_usage.output_tokens if prev else 0)
            delta_successful = current.request_usage.successful_requests - (
                prev.request_usage.successful_requests if prev else 0
            )
            delta_failed = current.request_usage.failed_requests - (prev.request_usage.failed_requests if prev else 0)

            if delta_input > 0 or delta_output > 0 or delta_successful > 0 or delta_failed > 0:
                deltas[model_name] = ModelUsageStats(
                    token_usage=TokenUsageStats(input_tokens=delta_input, output_tokens=delta_output),
                    request_usage=RequestUsageStats(successful_requests=delta_successful, failed_requests=delta_failed),
                )
        return deltas

    def get_model_provider(self, *, model_alias: str) -> ModelProvider:
        """Return the provider named by the configuration registered under ``model_alias``."""
        model_config = self.get_model_config(model_alias=model_alias)
        return self._model_provider_registry.get_provider(model_config.provider)

    def run_health_check(self, model_aliases: list[str]) -> None:
        """Send a minimal request to each model and re-raise the first failure.

        Aliases whose configuration sets ``skip_health_check`` are skipped.
        Requests are made with ``skip_usage_tracking=True``.

        Raises:
            ValueError: If an alias is unknown or a model has an unsupported generation type.
            Exception: Whatever the underlying model call raises.
        """
        logger.info("🩺 Running health checks for models...")
        for model_alias in model_aliases:
            model_config = self.get_model_config(model_alias=model_alias)
            if model_config.skip_health_check:
                logger.info(f" |-- ⏭️ Skipping health check for model alias {model_alias!r} (skip_health_check=True)")
                continue

            model = self.get_model(model_alias=model_alias)
            logger.info(
                f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
            )
            try:
                if model.model_generation_type == GenerationType.EMBEDDING:
                    model.generate_text_embeddings(
                        input_texts=["Hello!"],
                        skip_usage_tracking=True,
                        purpose="running health checks",
                    )
                elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
                    model.generate(
                        prompt="Hello!",
                        parser=lambda x: x,
                        system_prompt="You are a helpful assistant.",
                        max_correction_steps=0,
                        max_conversation_restarts=0,
                        skip_usage_tracking=True,
                        purpose="running health checks",
                    )
                else:
                    raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
                logger.info(" |-- ✅ Passed!")
            except Exception:
                logger.error(" |-- ❌ Failed!")
                # Bare raise re-raises the active exception with its traceback intact.
                raise

    def _set_model_configs(self, model_configs: list[ModelConfig] | None) -> None:
        """Replace the registered configurations, keyed by alias (later entries win).

        Cached facades whose configuration changed or disappeared are dropped,
        so ``get_model`` rebuilds them from the current configuration.
        """
        new_configs = {mc.alias: mc for mc in (model_configs or [])}
        # Evict stale facades; without this, an overwritten alias keeps serving the old model.
        for alias in list(self._models):
            if self._model_configs.get(alias) != new_configs.get(alias):
                del self._models[alias]
        self._model_configs = new_configs
        # Models are now lazily initialized in get_model() when first requested

    def _get_model(self, model_config: ModelConfig) -> ModelFacade:
        """Build a facade for ``model_config`` using the injected factory.

        Raises:
            RuntimeError: If the registry was created without a facade factory.
        """
        if self._model_facade_factory is None:
            raise RuntimeError("ModelRegistry was not initialized with a model_facade_factory")
        return self._model_facade_factory(model_config, self._secret_resolver, self._model_provider_registry)