data-designer 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +36 -27
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +8 -0
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +6 -8
  31. data_designer/config/processors.py +63 -2
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +50 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +7 -8
  42. data_designer/config/utils/visualization.py +32 -17
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +1 -2
  57. data_designer/engine/dataset_builders/column_wise_builder.py +58 -15
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +47 -12
  66. data_designer/engine/models/telemetry.py +355 -0
  67. data_designer/engine/models/usage.py +7 -9
  68. data_designer/engine/processing/ginja/ast.py +1 -2
  69. data_designer/engine/processing/utils.py +40 -2
  70. data_designer/engine/registry/base.py +12 -12
  71. data_designer/engine/sampling_gen/constraints.py +1 -2
  72. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  73. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  74. data_designer/engine/sampling_gen/people_gen.py +3 -7
  75. data_designer/engine/validators/base.py +2 -2
  76. data_designer/logging.py +2 -2
  77. data_designer/plugin_manager.py +3 -3
  78. data_designer/plugins/plugin.py +3 -3
  79. data_designer/plugins/registry.py +2 -2
  80. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/METADATA +32 -1
  81. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/RECORD +84 -77
  82. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/WHEEL +0 -0
  83. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/entry_points.txt +0 -0
  84. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,7 @@
3
3
 
4
4
  from data_designer.config.base import ConfigBase
5
5
  from data_designer.config.column_configs import (
6
+ EmbeddingColumnConfig,
6
7
  ExpressionColumnConfig,
7
8
  LLMCodeColumnConfig,
8
9
  LLMJudgeColumnConfig,
@@ -12,8 +13,9 @@ from data_designer.config.column_configs import (
12
13
  )
13
14
  from data_designer.config.column_types import DataDesignerColumnType
14
15
  from data_designer.engine.column_generators.generators.base import ColumnGenerator
16
+ from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
15
17
  from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
16
- from data_designer.engine.column_generators.generators.llm_generators import (
18
+ from data_designer.engine.column_generators.generators.llm_completion import (
17
19
  LLMCodeCellGenerator,
18
20
  LLMJudgeCellGenerator,
19
21
  LLMStructuredCellGenerator,
@@ -40,11 +42,11 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
40
42
  registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
41
43
  registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
42
44
  registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
45
+ registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
43
46
  registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
44
47
  registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
45
48
  registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
46
49
  registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
47
-
48
50
  if with_plugins:
49
51
  for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
50
52
  registry.register(
@@ -2,7 +2,6 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  from enum import Enum
5
- from typing import Type
6
5
 
7
6
  from pydantic import BaseModel, ConfigDict, Field, create_model
8
7
 
@@ -19,7 +18,7 @@ class BaseJudgeResponse(BaseModel):
19
18
  reasoning: str = Field(..., description="Reasoning for the assigned score.")
20
19
 
21
20
 
22
- def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
21
+ def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
23
22
  """Convert score descriptions into a single text block."""
24
23
  list_block = "\n".join(
25
24
  [SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
@@ -27,7 +26,7 @@ def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
27
26
  return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
28
27
 
29
28
 
30
- def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
29
+ def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
31
30
  """Create a JudgeResponse data type."""
32
31
  enum_members = {}
33
32
  for option in score.options.keys():
@@ -46,12 +45,12 @@ def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
46
45
 
47
46
 
48
47
  def create_judge_structured_output_model(
49
- judge_responses: list[Type[BaseJudgeResponse]],
50
- ) -> Type[BaseModel]:
48
+ judge_responses: list[type[BaseJudgeResponse]],
49
+ ) -> type[BaseModel]:
51
50
  """Create a JudgeStructuredOutput class dynamically."""
52
51
  return create_model(
53
52
  "JudgeStructuredOutput",
54
53
  __doc__=f"Response schema for scores with the following names: {[response.__name__ for response in judge_responses]}.",
55
54
  __base__=BaseModel,
56
- **{response.__name__.lower(): (response, ...) for response in judge_responses},
55
+ **{response.__name__: (response, ...) for response in judge_responses},
57
56
  )
@@ -3,7 +3,7 @@
3
3
 
4
4
  from abc import ABC, abstractmethod
5
5
  from pathlib import Path
6
- from typing import Generic, Type, TypeVar, get_origin
6
+ from typing import Generic, TypeVar, get_origin
7
7
 
8
8
  import pandas as pd
9
9
 
@@ -30,7 +30,7 @@ class ConfigurableTask(ABC, Generic[TaskConfigT]):
30
30
  self._initialize()
31
31
 
32
32
  @classmethod
33
- def get_config_type(cls) -> Type[TaskConfigT]:
33
+ def get_config_type(cls) -> type[TaskConfigT]:
34
34
  for base in cls.__orig_bases__:
35
35
  if hasattr(base, "__args__") and len(base.__args__) == 1:
36
36
  arg = base.__args__[0]
@@ -7,7 +7,6 @@ import shutil
7
7
  from datetime import datetime
8
8
  from functools import cached_property
9
9
  from pathlib import Path
10
- from typing import Union
11
10
 
12
11
  import pandas as pd
13
12
  from pydantic import BaseModel, field_validator, model_validator
@@ -77,7 +76,7 @@ class ArtifactStorage(BaseModel):
77
76
  return self.base_dataset_path / self.processors_outputs_folder_name
78
77
 
79
78
  @field_validator("artifact_path")
80
- def validate_artifact_path(cls, v: Union[Path, str]) -> Path:
79
+ def validate_artifact_path(cls, v: Path | str) -> Path:
81
80
  v = Path(v)
82
81
  if not v.is_dir():
83
82
  raise ArtifactStorageError("Artifact path must exist and be a directory")
@@ -1,24 +1,30 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
+ from __future__ import annotations
3
4
 
4
5
  import functools
6
+ import importlib.metadata
5
7
  import json
6
8
  import logging
7
9
  import time
10
+ import uuid
8
11
  from pathlib import Path
9
- from typing import Callable
12
+ from typing import TYPE_CHECKING, Callable
10
13
 
11
14
  import pandas as pd
12
15
 
13
- from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated
16
+ from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated
14
17
  from data_designer.config.dataset_builders import BuildStage
15
18
  from data_designer.config.processors import (
16
19
  DropColumnsProcessorConfig,
17
20
  ProcessorConfig,
18
21
  ProcessorType,
19
22
  )
20
- from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy
21
- from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration
23
+ from data_designer.engine.column_generators.generators.base import (
24
+ ColumnGenerator,
25
+ GenerationStrategy,
26
+ WithModelGeneration,
27
+ )
22
28
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
23
29
  from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
24
30
  from data_designer.engine.dataset_builders.multi_column_configs import (
@@ -32,14 +38,21 @@ from data_designer.engine.dataset_builders.utils.concurrency import (
32
38
  from data_designer.engine.dataset_builders.utils.dataset_batch_manager import (
33
39
  DatasetBatchManager,
34
40
  )
41
+ from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
35
42
  from data_designer.engine.processing.processors.base import Processor
36
43
  from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
37
44
  from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
38
45
  from data_designer.engine.resources.resource_provider import ResourceProvider
39
46
 
47
+ if TYPE_CHECKING:
48
+ from data_designer.engine.models.usage import ModelUsageStats
49
+
40
50
  logger = logging.getLogger(__name__)
41
51
 
42
52
 
53
+ _CLIENT_VERSION: str = importlib.metadata.version("data_designer")
54
+
55
+
43
56
  class ColumnWiseDatasetBuilder:
44
57
  def __init__(
45
58
  self,
@@ -72,7 +85,7 @@ class ColumnWiseDatasetBuilder:
72
85
 
73
86
  @functools.cached_property
74
87
  def llm_generated_column_configs(self) -> list[ColumnConfigT]:
75
- return [config for config in self.single_column_configs if column_type_is_llm_generated(config.column_type)]
88
+ return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
76
89
 
77
90
  def build(
78
91
  self,
@@ -86,11 +99,12 @@ class ColumnWiseDatasetBuilder:
86
99
 
87
100
  generators = self._initialize_generators()
88
101
  start_time = time.perf_counter()
102
+ group_id = uuid.uuid4().hex
89
103
 
90
104
  self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
91
105
  for batch_idx in range(self.batch_manager.num_batches):
92
106
  logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}")
93
- self._run_batch(generators)
107
+ self._run_batch(generators, batch_mode="batch", group_id=group_id)
94
108
  df_batch = self._run_processors(
95
109
  stage=BuildStage.POST_BATCH,
96
110
  dataframe=self.batch_manager.get_current_batch(as_dataframe=True),
@@ -111,10 +125,10 @@ class ColumnWiseDatasetBuilder:
111
125
  self._run_model_health_check_if_needed()
112
126
 
113
127
  generators = self._initialize_generators()
114
-
128
+ group_id = uuid.uuid4().hex
115
129
  start_time = time.perf_counter()
116
130
  self.batch_manager.start(num_records=num_records, buffer_size=num_records)
117
- self._run_batch(generators, save_partial_results=False)
131
+ self._run_batch(generators, batch_mode="preview", save_partial_results=False, group_id=group_id)
118
132
  dataset = self.batch_manager.get_current_batch(as_dataframe=True)
119
133
  self.batch_manager.reset()
120
134
 
@@ -140,7 +154,10 @@ class ColumnWiseDatasetBuilder:
140
154
  for config in self._column_configs
141
155
  ]
142
156
 
143
- def _run_batch(self, generators: list[ColumnGenerator], *, save_partial_results: bool = True) -> None:
157
+ def _run_batch(
158
+ self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str
159
+ ) -> None:
160
+ pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot()
144
161
  for generator in generators:
145
162
  generator.log_pre_generation()
146
163
  try:
@@ -163,16 +180,20 @@ class ColumnWiseDatasetBuilder:
163
180
  )
164
181
  raise DatasetGenerationError(f"🛑 Failed to process {column_error_str}:\n{e}")
165
182
 
183
+ try:
184
+ usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot)
185
+ self._emit_batch_inference_events(batch_mode, usage_deltas, group_id)
186
+ except Exception:
187
+ pass
188
+
166
189
  def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None:
167
190
  df = generator.generate_from_scratch(self.batch_manager.num_records_batch)
168
191
  self.batch_manager.add_records(df.to_dict(orient="records"))
169
192
 
170
193
  def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
171
194
  max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
172
- if isinstance(generator, WithLLMGeneration):
195
+ if isinstance(generator, WithModelGeneration):
173
196
  max_workers = generator.inference_parameters.max_parallel_requests
174
- elif hasattr(generator.config, "max_parallel_requests"):
175
- max_workers = generator.config.max_parallel_requests
176
197
  self._fan_out_with_threads(generator, max_workers=max_workers)
177
198
 
178
199
  def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
@@ -180,12 +201,12 @@ class ColumnWiseDatasetBuilder:
180
201
  self.batch_manager.update_records(df.to_dict(orient="records"))
181
202
 
182
203
  def _run_model_health_check_if_needed(self) -> bool:
183
- if any(column_type_is_llm_generated(config.column_type) for config in self.single_column_configs):
204
+ if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs):
184
205
  self._resource_provider.model_registry.run_health_check(
185
- set(config.model_alias for config in self.llm_generated_column_configs)
206
+ list(set(config.model_alias for config in self.llm_generated_column_configs))
186
207
  )
187
208
 
188
- def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None:
209
+ def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None:
189
210
  if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
190
211
  raise DatasetGenerationError(
191
212
  f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
@@ -288,3 +309,25 @@ class ColumnWiseDatasetBuilder:
288
309
  json_file_name="model_configs.json",
289
310
  configs=self._resource_provider.model_registry.model_configs.values(),
290
311
  )
312
+
313
+ def _emit_batch_inference_events(
314
+ self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str
315
+ ) -> None:
316
+ if not usage_deltas:
317
+ return
318
+
319
+ events = [
320
+ InferenceEvent(
321
+ nemo_source=NemoSourceEnum.DATADESIGNER,
322
+ task=batch_mode,
323
+ task_status=TaskStatusEnum.SUCCESS,
324
+ model=model_name,
325
+ input_tokens=delta.token_usage.input_tokens,
326
+ output_tokens=delta.token_usage.output_tokens,
327
+ )
328
+ for model_name, delta in usage_deltas.items()
329
+ ]
330
+
331
+ with TelemetryHandler(source_client_version=_CLIENT_VERSION, session_id=group_id) as telemetry_handler:
332
+ for event in events:
333
+ telemetry_handler.enqueue(event)
@@ -8,7 +8,7 @@ import json
8
8
  import logging
9
9
  from concurrent.futures import Future, ThreadPoolExecutor
10
10
  from threading import Lock, Semaphore
11
- from typing import Any, Optional, Protocol
11
+ from typing import Any, Protocol
12
12
 
13
13
  from pydantic import BaseModel, Field
14
14
 
@@ -46,13 +46,13 @@ class ExecutorResults(BaseModel):
46
46
  class CallbackWithContext(Protocol):
47
47
  """Executor callback functions must accept a context kw argument."""
48
48
 
49
- def __call__(self, result: Any, *, context: Optional[dict] = None) -> Any: ...
49
+ def __call__(self, result: Any, *, context: dict | None = None) -> Any: ...
50
50
 
51
51
 
52
52
  class ErrorCallbackWithContext(Protocol):
53
53
  """Error callbacks take the Exception instance and context."""
54
54
 
55
- def __call__(self, exc: Exception, *, context: Optional[dict] = None) -> Any: ...
55
+ def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ...
56
56
 
57
57
 
58
58
  class ConcurrentThreadExecutor:
@@ -92,8 +92,8 @@ class ConcurrentThreadExecutor:
92
92
  *,
93
93
  max_workers: int,
94
94
  column_name: str,
95
- result_callback: Optional[CallbackWithContext] = None,
96
- error_callback: Optional[ErrorCallbackWithContext] = None,
95
+ result_callback: CallbackWithContext | None = None,
96
+ error_callback: ErrorCallbackWithContext | None = None,
97
97
  shutdown_error_rate: float = 0.50,
98
98
  shutdown_error_window: int = 10,
99
99
  ):
@@ -136,7 +136,7 @@ class ConcurrentThreadExecutor:
136
136
  )
137
137
  )
138
138
 
139
- def submit(self, fn, *args, context: Optional[dict] = None, **kwargs) -> None:
139
+ def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None:
140
140
  if self._executor is None:
141
141
  raise RuntimeError("Executor is not initialized, this class should be used as a context manager.")
142
142
 
@@ -9,9 +9,9 @@ from copy import deepcopy
9
9
  from typing import Any
10
10
 
11
11
  from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
12
- from litellm.types.utils import ModelResponse
12
+ from litellm.types.utils import EmbeddingResponse, ModelResponse
13
13
 
14
- from data_designer.config.models import ModelConfig, ModelProvider
14
+ from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
15
15
  from data_designer.engine.model_provider import ModelProviderRegistry
16
16
  from data_designer.engine.models.errors import (
17
17
  GenerationValidationFailureError,
@@ -49,6 +49,10 @@ class ModelFacade:
49
49
  def model_provider(self) -> ModelProvider:
50
50
  return self._model_provider_registry.get_provider(self._model_config.provider)
51
51
 
52
+ @property
53
+ def model_generation_type(self) -> GenerationType:
54
+ return self._model_config.generation_type
55
+
52
56
  @property
53
57
  def model_provider_name(self) -> str:
54
58
  return self.model_provider.name
@@ -64,13 +68,12 @@ class ModelFacade:
64
68
  def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
65
69
  logger.debug(
66
70
  f"Prompting model {self.model_name!r}...",
67
- extra={"model": self.model_name, "messages": messages, "sensitive": True},
71
+ extra={"model": self.model_name, "messages": messages},
68
72
  )
69
73
  response = None
70
- if self.model_provider.extra_body:
71
- kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
74
+ kwargs = self.consolidate_kwargs(**kwargs)
72
75
  try:
73
- response = self._router.completion(self.model_name, messages, **kwargs)
76
+ response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
74
77
  logger.debug(
75
78
  f"Received completion from model {self.model_name!r}",
76
79
  extra={
@@ -84,9 +87,50 @@ class ModelFacade:
84
87
  except Exception as e:
85
88
  raise e
86
89
  finally:
87
- if not skip_usage_tracking:
90
+ if not skip_usage_tracking and response is not None:
88
91
  self._track_usage(response)
89
92
 
93
+ def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
94
+ # Remove purpose from kwargs to avoid passing it to the model
95
+ kwargs.pop("purpose", None)
96
+ kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
97
+ if self.model_provider.extra_body:
98
+ kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
99
+ return kwargs
100
+
101
+ @catch_llm_exceptions
102
+ def generate_text_embeddings(
103
+ self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
104
+ ) -> list[list[float]]:
105
+ logger.debug(
106
+ f"Generating embeddings with model {self.model_name!r}...",
107
+ extra={
108
+ "model": self.model_name,
109
+ "input_count": len(input_texts),
110
+ },
111
+ )
112
+ kwargs = self.consolidate_kwargs(**kwargs)
113
+ response = None
114
+ try:
115
+ response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
116
+ logger.debug(
117
+ f"Received embeddings from model {self.model_name!r}",
118
+ extra={
119
+ "model": self.model_name,
120
+ "embedding_count": len(response.data) if response.data else 0,
121
+ "usage": self._usage_stats.model_dump(),
122
+ },
123
+ )
124
+ if response.data and len(response.data) == len(input_texts):
125
+ return [data["embedding"] for data in response.data]
126
+ else:
127
+ raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
128
+ except Exception as e:
129
+ raise e
130
+ finally:
131
+ if not skip_usage_tracking and response is not None:
132
+ self._track_usage_from_embedding(response)
133
+
90
134
  @catch_llm_exceptions
91
135
  def generate(
92
136
  self,
@@ -218,8 +262,21 @@ class ModelFacade:
218
262
  ):
219
263
  self._usage_stats.extend(
220
264
  token_usage=TokenUsageStats(
221
- prompt_tokens=response.usage.prompt_tokens,
222
- completion_tokens=response.usage.completion_tokens,
265
+ input_tokens=response.usage.prompt_tokens,
266
+ output_tokens=response.usage.completion_tokens,
267
+ ),
268
+ request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
269
+ )
270
+
271
+ def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
272
+ if response is None:
273
+ self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
274
+ return
275
+ if response.usage is not None and response.usage.prompt_tokens is not None:
276
+ self._usage_stats.extend(
277
+ token_usage=TokenUsageStats(
278
+ input_tokens=response.usage.prompt_tokens,
279
+ output_tokens=0,
223
280
  ),
224
281
  request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
225
282
  )
@@ -5,7 +5,6 @@ from __future__ import annotations
5
5
 
6
6
  import random
7
7
  import threading
8
- from typing import Optional, Union
9
8
 
10
9
  import httpx
11
10
  import litellm
@@ -90,7 +89,7 @@ class CustomRouter(Router):
90
89
  self._initial_retry_after_s = initial_retry_after_s
91
90
  self._jitter_pct = jitter_pct
92
91
 
93
- def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, float]]:
92
+ def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
94
93
  """
95
94
  Most of this code logic was extracted directly from the parent
96
95
  `Router`'s `_time_to_sleep_before_retry` function. Our override
@@ -99,7 +98,7 @@ class CustomRouter(Router):
99
98
  return this info, we'll simply use that retry value returned here.
100
99
  """
101
100
 
102
- response_headers: Optional[httpx.Headers] = None
101
+ response_headers: httpx.Headers | None = None
103
102
  if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
104
103
  response_headers = e.response.headers # type: ignore
105
104
  if hasattr(e, "litellm_response_headers"):
@@ -119,9 +118,9 @@ class CustomRouter(Router):
119
118
  e: Exception,
120
119
  remaining_retries: int,
121
120
  num_retries: int,
122
- healthy_deployments: Optional[list] = None,
123
- all_deployments: Optional[list] = None,
124
- ) -> Union[int, float]:
121
+ healthy_deployments: list | None = None,
122
+ all_deployments: list | None = None,
123
+ ) -> int | float:
125
124
  """
126
125
  Implements exponential backoff for retries.
127
126
 
@@ -1,8 +1,6 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Optional
5
-
6
4
 
7
5
  class ParserException(Exception):
8
6
  """Identifies errors resulting from generic parser errors.
@@ -12,7 +10,7 @@ class ParserException(Exception):
12
10
  attempted to parse.
13
11
  """
14
12
 
15
- source: Optional[str]
13
+ source: str | None
16
14
 
17
15
  @staticmethod
18
16
  def _log_format(source: str) -> str:
@@ -24,7 +22,7 @@ class ParserException(Exception):
24
22
  # return f"<source>{source}</source>"
25
23
  return ""
26
24
 
27
- def __init__(self, msg: Optional[str] = None, source: Optional[str] = None):
25
+ def __init__(self, msg: str | None = None, source: str | None = None):
28
26
  msg = "" if msg is None else msg.strip()
29
27
 
30
28
  if source is not None:
@@ -2,7 +2,6 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  from functools import reduce
5
- from typing import Optional
6
5
 
7
6
  import marko
8
7
  from lxml import etree
@@ -105,8 +104,8 @@ class LLMResponseParser:
105
104
 
106
105
  def __init__(
107
106
  self,
108
- tag_parsers: Optional[dict[str, TagParser]] = None,
109
- postprocessors: Optional[list[PostProcessor]] = None,
107
+ tag_parsers: dict[str, TagParser] | None = None,
108
+ postprocessors: list[PostProcessor] | None = None,
110
109
  ):
111
110
  """
112
111
  Initializes the LLMResponseParser with optional tag parsers and post-processors.
@@ -1,7 +1,6 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Optional, Type
5
4
 
6
5
  import json_repair
7
6
  from pydantic import BaseModel, ValidationError
@@ -60,12 +59,12 @@ def deserialize_json_code(
60
59
 
61
60
 
62
61
  class RealizePydanticTypes:
63
- types: list[Type[BaseModel]]
62
+ types: list[type[BaseModel]]
64
63
 
65
- def __init__(self, types: list[Type[BaseModel]]):
64
+ def __init__(self, types: list[type[BaseModel]]):
66
65
  self.types = types
67
66
 
68
- def _fit_types(self, obj: dict) -> Optional[BaseModel]:
67
+ def _fit_types(self, obj: dict) -> BaseModel | None:
69
68
  final_obj = None
70
69
 
71
70
  for t in self.types:
@@ -1,7 +1,7 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Any, Optional, Protocol, Type, runtime_checkable
4
+ from typing import Any, Protocol, runtime_checkable
5
5
 
6
6
  from lxml.etree import _Element
7
7
  from pydantic import BaseModel, Field
@@ -30,7 +30,7 @@ class LLMStructuredResponse(BaseModel):
30
30
  out.parsed = out.parsed[-n:]
31
31
  return out
32
32
 
33
- def filter(self, block_types: list[Type[BaseModel]]) -> Self:
33
+ def filter(self, block_types: list[type[BaseModel]]) -> Self:
34
34
  out = self.model_copy()
35
35
  out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
36
36
  return out
@@ -44,7 +44,7 @@ class TagParser(Protocol):
44
44
  element, do some computation, and return some kind of structured
45
45
  output, represented as a subclass of Pydantic `BaseModel`.
46
46
  This protocol implementation can cover both classes as well
47
- as curried fuctions as parsers (e.g. `partial`).
47
+ as curried functions as parsers (e.g. `partial`).
48
48
  """
49
49
 
50
50
  def __call__(self, element: _Element) -> BaseModel: ...
@@ -69,7 +69,7 @@ class TextBlock(BaseModel):
69
69
 
70
70
  class CodeBlock(BaseModel):
71
71
  code: str
72
- code_lang: Optional[str] = None
72
+ code_lang: str | None = None
73
73
 
74
74
 
75
75
  class StructuredDataBlock(BaseModel):