data-designer 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +36 -27
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +8 -0
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +6 -8
- data_designer/config/processors.py +63 -2
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +50 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +7 -8
- data_designer/config/utils/visualization.py +32 -17
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +1 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +58 -15
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +47 -12
- data_designer/engine/models/telemetry.py +355 -0
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/METADATA +32 -1
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/RECORD +84 -77
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/WHEEL +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
from data_designer.config.base import ConfigBase
|
|
5
5
|
from data_designer.config.column_configs import (
|
|
6
|
+
EmbeddingColumnConfig,
|
|
6
7
|
ExpressionColumnConfig,
|
|
7
8
|
LLMCodeColumnConfig,
|
|
8
9
|
LLMJudgeColumnConfig,
|
|
@@ -12,8 +13,9 @@ from data_designer.config.column_configs import (
|
|
|
12
13
|
)
|
|
13
14
|
from data_designer.config.column_types import DataDesignerColumnType
|
|
14
15
|
from data_designer.engine.column_generators.generators.base import ColumnGenerator
|
|
16
|
+
from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
|
|
15
17
|
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
|
|
16
|
-
from data_designer.engine.column_generators.generators.llm_generators import (
|
|
18
|
+
from data_designer.engine.column_generators.generators.llm_completion import (
|
|
17
19
|
LLMCodeCellGenerator,
|
|
18
20
|
LLMJudgeCellGenerator,
|
|
19
21
|
LLMStructuredCellGenerator,
|
|
@@ -40,11 +42,11 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
|
|
|
40
42
|
registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
|
|
41
43
|
registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
|
|
42
44
|
registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
|
|
45
|
+
registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
|
|
43
46
|
registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
|
|
44
47
|
registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
|
|
45
48
|
registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
|
|
46
49
|
registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
|
|
47
|
-
|
|
48
50
|
if with_plugins:
|
|
49
51
|
for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
|
|
50
52
|
registry.register(
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from enum import Enum
|
|
5
|
-
from typing import Type
|
|
6
5
|
|
|
7
6
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
8
7
|
|
|
@@ -19,7 +18,7 @@ class BaseJudgeResponse(BaseModel):
|
|
|
19
18
|
reasoning: str = Field(..., description="Reasoning for the assigned score.")
|
|
20
19
|
|
|
21
20
|
|
|
22
|
-
def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
|
|
21
|
+
def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
|
|
23
22
|
"""Convert score descriptions into a single text block."""
|
|
24
23
|
list_block = "\n".join(
|
|
25
24
|
[SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
|
|
@@ -27,7 +26,7 @@ def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
|
|
|
27
26
|
return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
|
|
28
27
|
|
|
29
28
|
|
|
30
|
-
def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
|
|
29
|
+
def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
|
|
31
30
|
"""Create a JudgeResponse data type."""
|
|
32
31
|
enum_members = {}
|
|
33
32
|
for option in score.options.keys():
|
|
@@ -46,12 +45,12 @@ def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
|
|
|
46
45
|
|
|
47
46
|
|
|
48
47
|
def create_judge_structured_output_model(
|
|
49
|
-
judge_responses: list[
|
|
50
|
-
) ->
|
|
48
|
+
judge_responses: list[type[BaseJudgeResponse]],
|
|
49
|
+
) -> type[BaseModel]:
|
|
51
50
|
"""Create a JudgeStructuredOutput class dynamically."""
|
|
52
51
|
return create_model(
|
|
53
52
|
"JudgeStructuredOutput",
|
|
54
53
|
__doc__=f"Response schema for scores with the following names: {[response.__name__ for response in judge_responses]}.",
|
|
55
54
|
__base__=BaseModel,
|
|
56
|
-
**{response.__name__
|
|
55
|
+
**{response.__name__: (response, ...) for response in judge_responses},
|
|
57
56
|
)
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Generic,
|
|
6
|
+
from typing import Generic, TypeVar, get_origin
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
|
|
@@ -30,7 +30,7 @@ class ConfigurableTask(ABC, Generic[TaskConfigT]):
|
|
|
30
30
|
self._initialize()
|
|
31
31
|
|
|
32
32
|
@classmethod
|
|
33
|
-
def get_config_type(cls) ->
|
|
33
|
+
def get_config_type(cls) -> type[TaskConfigT]:
|
|
34
34
|
for base in cls.__orig_bases__:
|
|
35
35
|
if hasattr(base, "__args__") and len(base.__args__) == 1:
|
|
36
36
|
arg = base.__args__[0]
|
|
@@ -7,7 +7,6 @@ import shutil
|
|
|
7
7
|
from datetime import datetime
|
|
8
8
|
from functools import cached_property
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from typing import Union
|
|
11
10
|
|
|
12
11
|
import pandas as pd
|
|
13
12
|
from pydantic import BaseModel, field_validator, model_validator
|
|
@@ -77,7 +76,7 @@ class ArtifactStorage(BaseModel):
|
|
|
77
76
|
return self.base_dataset_path / self.processors_outputs_folder_name
|
|
78
77
|
|
|
79
78
|
@field_validator("artifact_path")
|
|
80
|
-
def validate_artifact_path(cls, v:
|
|
79
|
+
def validate_artifact_path(cls, v: Path | str) -> Path:
|
|
81
80
|
v = Path(v)
|
|
82
81
|
if not v.is_dir():
|
|
83
82
|
raise ArtifactStorageError("Artifact path must exist and be a directory")
|
|
@@ -1,24 +1,30 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
from __future__ import annotations
|
|
3
4
|
|
|
4
5
|
import functools
|
|
6
|
+
import importlib.metadata
|
|
5
7
|
import json
|
|
6
8
|
import logging
|
|
7
9
|
import time
|
|
10
|
+
import uuid
|
|
8
11
|
from pathlib import Path
|
|
9
|
-
from typing import Callable
|
|
12
|
+
from typing import TYPE_CHECKING, Callable
|
|
10
13
|
|
|
11
14
|
import pandas as pd
|
|
12
15
|
|
|
13
|
-
from data_designer.config.column_types import ColumnConfigT,
|
|
16
|
+
from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated
|
|
14
17
|
from data_designer.config.dataset_builders import BuildStage
|
|
15
18
|
from data_designer.config.processors import (
|
|
16
19
|
DropColumnsProcessorConfig,
|
|
17
20
|
ProcessorConfig,
|
|
18
21
|
ProcessorType,
|
|
19
22
|
)
|
|
20
|
-
from data_designer.engine.column_generators.generators.base import
|
|
21
|
-
|
|
23
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
24
|
+
ColumnGenerator,
|
|
25
|
+
GenerationStrategy,
|
|
26
|
+
WithModelGeneration,
|
|
27
|
+
)
|
|
22
28
|
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
|
|
23
29
|
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
|
|
24
30
|
from data_designer.engine.dataset_builders.multi_column_configs import (
|
|
@@ -32,14 +38,21 @@ from data_designer.engine.dataset_builders.utils.concurrency import (
|
|
|
32
38
|
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import (
|
|
33
39
|
DatasetBatchManager,
|
|
34
40
|
)
|
|
41
|
+
from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
|
|
35
42
|
from data_designer.engine.processing.processors.base import Processor
|
|
36
43
|
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
|
|
37
44
|
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
|
|
38
45
|
from data_designer.engine.resources.resource_provider import ResourceProvider
|
|
39
46
|
|
|
47
|
+
if TYPE_CHECKING:
|
|
48
|
+
from data_designer.engine.models.usage import ModelUsageStats
|
|
49
|
+
|
|
40
50
|
logger = logging.getLogger(__name__)
|
|
41
51
|
|
|
42
52
|
|
|
53
|
+
_CLIENT_VERSION: str = importlib.metadata.version("data_designer")
|
|
54
|
+
|
|
55
|
+
|
|
43
56
|
class ColumnWiseDatasetBuilder:
|
|
44
57
|
def __init__(
|
|
45
58
|
self,
|
|
@@ -72,7 +85,7 @@ class ColumnWiseDatasetBuilder:
|
|
|
72
85
|
|
|
73
86
|
@functools.cached_property
|
|
74
87
|
def llm_generated_column_configs(self) -> list[ColumnConfigT]:
|
|
75
|
-
return [config for config in self.single_column_configs if
|
|
88
|
+
return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
|
|
76
89
|
|
|
77
90
|
def build(
|
|
78
91
|
self,
|
|
@@ -86,11 +99,12 @@ class ColumnWiseDatasetBuilder:
|
|
|
86
99
|
|
|
87
100
|
generators = self._initialize_generators()
|
|
88
101
|
start_time = time.perf_counter()
|
|
102
|
+
group_id = uuid.uuid4().hex
|
|
89
103
|
|
|
90
104
|
self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
|
|
91
105
|
for batch_idx in range(self.batch_manager.num_batches):
|
|
92
106
|
logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}")
|
|
93
|
-
self._run_batch(generators)
|
|
107
|
+
self._run_batch(generators, batch_mode="batch", group_id=group_id)
|
|
94
108
|
df_batch = self._run_processors(
|
|
95
109
|
stage=BuildStage.POST_BATCH,
|
|
96
110
|
dataframe=self.batch_manager.get_current_batch(as_dataframe=True),
|
|
@@ -111,10 +125,10 @@ class ColumnWiseDatasetBuilder:
|
|
|
111
125
|
self._run_model_health_check_if_needed()
|
|
112
126
|
|
|
113
127
|
generators = self._initialize_generators()
|
|
114
|
-
|
|
128
|
+
group_id = uuid.uuid4().hex
|
|
115
129
|
start_time = time.perf_counter()
|
|
116
130
|
self.batch_manager.start(num_records=num_records, buffer_size=num_records)
|
|
117
|
-
self._run_batch(generators, save_partial_results=False)
|
|
131
|
+
self._run_batch(generators, batch_mode="preview", save_partial_results=False, group_id=group_id)
|
|
118
132
|
dataset = self.batch_manager.get_current_batch(as_dataframe=True)
|
|
119
133
|
self.batch_manager.reset()
|
|
120
134
|
|
|
@@ -140,7 +154,10 @@ class ColumnWiseDatasetBuilder:
|
|
|
140
154
|
for config in self._column_configs
|
|
141
155
|
]
|
|
142
156
|
|
|
143
|
-
def _run_batch(
|
|
157
|
+
def _run_batch(
|
|
158
|
+
self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str
|
|
159
|
+
) -> None:
|
|
160
|
+
pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot()
|
|
144
161
|
for generator in generators:
|
|
145
162
|
generator.log_pre_generation()
|
|
146
163
|
try:
|
|
@@ -163,16 +180,20 @@ class ColumnWiseDatasetBuilder:
|
|
|
163
180
|
)
|
|
164
181
|
raise DatasetGenerationError(f"🛑 Failed to process {column_error_str}:\n{e}")
|
|
165
182
|
|
|
183
|
+
try:
|
|
184
|
+
usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot)
|
|
185
|
+
self._emit_batch_inference_events(batch_mode, usage_deltas, group_id)
|
|
186
|
+
except Exception:
|
|
187
|
+
pass
|
|
188
|
+
|
|
166
189
|
def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None:
|
|
167
190
|
df = generator.generate_from_scratch(self.batch_manager.num_records_batch)
|
|
168
191
|
self.batch_manager.add_records(df.to_dict(orient="records"))
|
|
169
192
|
|
|
170
193
|
def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
|
|
171
194
|
max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
|
|
172
|
-
if isinstance(generator,
|
|
195
|
+
if isinstance(generator, WithModelGeneration):
|
|
173
196
|
max_workers = generator.inference_parameters.max_parallel_requests
|
|
174
|
-
elif hasattr(generator.config, "max_parallel_requests"):
|
|
175
|
-
max_workers = generator.config.max_parallel_requests
|
|
176
197
|
self._fan_out_with_threads(generator, max_workers=max_workers)
|
|
177
198
|
|
|
178
199
|
def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
|
|
@@ -180,12 +201,12 @@ class ColumnWiseDatasetBuilder:
|
|
|
180
201
|
self.batch_manager.update_records(df.to_dict(orient="records"))
|
|
181
202
|
|
|
182
203
|
def _run_model_health_check_if_needed(self) -> bool:
|
|
183
|
-
if any(
|
|
204
|
+
if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs):
|
|
184
205
|
self._resource_provider.model_registry.run_health_check(
|
|
185
|
-
set(config.model_alias for config in self.llm_generated_column_configs)
|
|
206
|
+
list(set(config.model_alias for config in self.llm_generated_column_configs))
|
|
186
207
|
)
|
|
187
208
|
|
|
188
|
-
def _fan_out_with_threads(self, generator:
|
|
209
|
+
def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None:
|
|
189
210
|
if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
|
|
190
211
|
raise DatasetGenerationError(
|
|
191
212
|
f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
|
|
@@ -288,3 +309,25 @@ class ColumnWiseDatasetBuilder:
|
|
|
288
309
|
json_file_name="model_configs.json",
|
|
289
310
|
configs=self._resource_provider.model_registry.model_configs.values(),
|
|
290
311
|
)
|
|
312
|
+
|
|
313
|
+
def _emit_batch_inference_events(
|
|
314
|
+
self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str
|
|
315
|
+
) -> None:
|
|
316
|
+
if not usage_deltas:
|
|
317
|
+
return
|
|
318
|
+
|
|
319
|
+
events = [
|
|
320
|
+
InferenceEvent(
|
|
321
|
+
nemo_source=NemoSourceEnum.DATADESIGNER,
|
|
322
|
+
task=batch_mode,
|
|
323
|
+
task_status=TaskStatusEnum.SUCCESS,
|
|
324
|
+
model=model_name,
|
|
325
|
+
input_tokens=delta.token_usage.input_tokens,
|
|
326
|
+
output_tokens=delta.token_usage.output_tokens,
|
|
327
|
+
)
|
|
328
|
+
for model_name, delta in usage_deltas.items()
|
|
329
|
+
]
|
|
330
|
+
|
|
331
|
+
with TelemetryHandler(source_client_version=_CLIENT_VERSION, session_id=group_id) as telemetry_handler:
|
|
332
|
+
for event in events:
|
|
333
|
+
telemetry_handler.enqueue(event)
|
|
@@ -8,7 +8,7 @@ import json
|
|
|
8
8
|
import logging
|
|
9
9
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
10
10
|
from threading import Lock, Semaphore
|
|
11
|
-
from typing import Any,
|
|
11
|
+
from typing import Any, Protocol
|
|
12
12
|
|
|
13
13
|
from pydantic import BaseModel, Field
|
|
14
14
|
|
|
@@ -46,13 +46,13 @@ class ExecutorResults(BaseModel):
|
|
|
46
46
|
class CallbackWithContext(Protocol):
|
|
47
47
|
"""Executor callback functions must accept a context kw argument."""
|
|
48
48
|
|
|
49
|
-
def __call__(self, result: Any, *, context:
|
|
49
|
+
def __call__(self, result: Any, *, context: dict | None = None) -> Any: ...
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
class ErrorCallbackWithContext(Protocol):
|
|
53
53
|
"""Error callbacks take the Exception instance and context."""
|
|
54
54
|
|
|
55
|
-
def __call__(self, exc: Exception, *, context:
|
|
55
|
+
def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ...
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class ConcurrentThreadExecutor:
|
|
@@ -92,8 +92,8 @@ class ConcurrentThreadExecutor:
|
|
|
92
92
|
*,
|
|
93
93
|
max_workers: int,
|
|
94
94
|
column_name: str,
|
|
95
|
-
result_callback:
|
|
96
|
-
error_callback:
|
|
95
|
+
result_callback: CallbackWithContext | None = None,
|
|
96
|
+
error_callback: ErrorCallbackWithContext | None = None,
|
|
97
97
|
shutdown_error_rate: float = 0.50,
|
|
98
98
|
shutdown_error_window: int = 10,
|
|
99
99
|
):
|
|
@@ -136,7 +136,7 @@ class ConcurrentThreadExecutor:
|
|
|
136
136
|
)
|
|
137
137
|
)
|
|
138
138
|
|
|
139
|
-
def submit(self, fn, *args, context:
|
|
139
|
+
def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None:
|
|
140
140
|
if self._executor is None:
|
|
141
141
|
raise RuntimeError("Executor is not initialized, this class should be used as a context manager.")
|
|
142
142
|
|
|
@@ -9,9 +9,9 @@ from copy import deepcopy
|
|
|
9
9
|
from typing import Any
|
|
10
10
|
|
|
11
11
|
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
|
|
12
|
-
from litellm.types.utils import ModelResponse
|
|
12
|
+
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
|
13
13
|
|
|
14
|
-
from data_designer.config.models import ModelConfig, ModelProvider
|
|
14
|
+
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
|
|
15
15
|
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
16
16
|
from data_designer.engine.models.errors import (
|
|
17
17
|
GenerationValidationFailureError,
|
|
@@ -49,6 +49,10 @@ class ModelFacade:
|
|
|
49
49
|
def model_provider(self) -> ModelProvider:
|
|
50
50
|
return self._model_provider_registry.get_provider(self._model_config.provider)
|
|
51
51
|
|
|
52
|
+
@property
|
|
53
|
+
def model_generation_type(self) -> GenerationType:
|
|
54
|
+
return self._model_config.generation_type
|
|
55
|
+
|
|
52
56
|
@property
|
|
53
57
|
def model_provider_name(self) -> str:
|
|
54
58
|
return self.model_provider.name
|
|
@@ -64,13 +68,12 @@ class ModelFacade:
|
|
|
64
68
|
def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
|
|
65
69
|
logger.debug(
|
|
66
70
|
f"Prompting model {self.model_name!r}...",
|
|
67
|
-
extra={"model": self.model_name, "messages": messages
|
|
71
|
+
extra={"model": self.model_name, "messages": messages},
|
|
68
72
|
)
|
|
69
73
|
response = None
|
|
70
|
-
|
|
71
|
-
kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
|
|
74
|
+
kwargs = self.consolidate_kwargs(**kwargs)
|
|
72
75
|
try:
|
|
73
|
-
response = self._router.completion(self.model_name, messages, **kwargs)
|
|
76
|
+
response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
|
|
74
77
|
logger.debug(
|
|
75
78
|
f"Received completion from model {self.model_name!r}",
|
|
76
79
|
extra={
|
|
@@ -84,9 +87,50 @@ class ModelFacade:
|
|
|
84
87
|
except Exception as e:
|
|
85
88
|
raise e
|
|
86
89
|
finally:
|
|
87
|
-
if not skip_usage_tracking:
|
|
90
|
+
if not skip_usage_tracking and response is not None:
|
|
88
91
|
self._track_usage(response)
|
|
89
92
|
|
|
93
|
+
def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
|
|
94
|
+
# Remove purpose from kwargs to avoid passing it to the model
|
|
95
|
+
kwargs.pop("purpose", None)
|
|
96
|
+
kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
|
|
97
|
+
if self.model_provider.extra_body:
|
|
98
|
+
kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
|
|
99
|
+
return kwargs
|
|
100
|
+
|
|
101
|
+
@catch_llm_exceptions
|
|
102
|
+
def generate_text_embeddings(
|
|
103
|
+
self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
|
|
104
|
+
) -> list[list[float]]:
|
|
105
|
+
logger.debug(
|
|
106
|
+
f"Generating embeddings with model {self.model_name!r}...",
|
|
107
|
+
extra={
|
|
108
|
+
"model": self.model_name,
|
|
109
|
+
"input_count": len(input_texts),
|
|
110
|
+
},
|
|
111
|
+
)
|
|
112
|
+
kwargs = self.consolidate_kwargs(**kwargs)
|
|
113
|
+
response = None
|
|
114
|
+
try:
|
|
115
|
+
response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
|
|
116
|
+
logger.debug(
|
|
117
|
+
f"Received embeddings from model {self.model_name!r}",
|
|
118
|
+
extra={
|
|
119
|
+
"model": self.model_name,
|
|
120
|
+
"embedding_count": len(response.data) if response.data else 0,
|
|
121
|
+
"usage": self._usage_stats.model_dump(),
|
|
122
|
+
},
|
|
123
|
+
)
|
|
124
|
+
if response.data and len(response.data) == len(input_texts):
|
|
125
|
+
return [data["embedding"] for data in response.data]
|
|
126
|
+
else:
|
|
127
|
+
raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
|
|
128
|
+
except Exception as e:
|
|
129
|
+
raise e
|
|
130
|
+
finally:
|
|
131
|
+
if not skip_usage_tracking and response is not None:
|
|
132
|
+
self._track_usage_from_embedding(response)
|
|
133
|
+
|
|
90
134
|
@catch_llm_exceptions
|
|
91
135
|
def generate(
|
|
92
136
|
self,
|
|
@@ -218,8 +262,21 @@ class ModelFacade:
|
|
|
218
262
|
):
|
|
219
263
|
self._usage_stats.extend(
|
|
220
264
|
token_usage=TokenUsageStats(
|
|
221
|
-
|
|
222
|
-
|
|
265
|
+
input_tokens=response.usage.prompt_tokens,
|
|
266
|
+
output_tokens=response.usage.completion_tokens,
|
|
267
|
+
),
|
|
268
|
+
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
|
|
272
|
+
if response is None:
|
|
273
|
+
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
|
|
274
|
+
return
|
|
275
|
+
if response.usage is not None and response.usage.prompt_tokens is not None:
|
|
276
|
+
self._usage_stats.extend(
|
|
277
|
+
token_usage=TokenUsageStats(
|
|
278
|
+
input_tokens=response.usage.prompt_tokens,
|
|
279
|
+
output_tokens=0,
|
|
223
280
|
),
|
|
224
281
|
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
|
|
225
282
|
)
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import random
|
|
7
7
|
import threading
|
|
8
|
-
from typing import Optional, Union
|
|
9
8
|
|
|
10
9
|
import httpx
|
|
11
10
|
import litellm
|
|
@@ -90,7 +89,7 @@ class CustomRouter(Router):
|
|
|
90
89
|
self._initial_retry_after_s = initial_retry_after_s
|
|
91
90
|
self._jitter_pct = jitter_pct
|
|
92
91
|
|
|
93
|
-
def _extract_retry_delay_from_headers(self, e: Exception) ->
|
|
92
|
+
def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
|
|
94
93
|
"""
|
|
95
94
|
Most of this code logic was extracted directly from the parent
|
|
96
95
|
`Router`'s `_time_to_sleep_before_retry` function. Our override
|
|
@@ -99,7 +98,7 @@ class CustomRouter(Router):
|
|
|
99
98
|
return this info, we'll simply use that retry value returned here.
|
|
100
99
|
"""
|
|
101
100
|
|
|
102
|
-
response_headers:
|
|
101
|
+
response_headers: httpx.Headers | None = None
|
|
103
102
|
if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
|
|
104
103
|
response_headers = e.response.headers # type: ignore
|
|
105
104
|
if hasattr(e, "litellm_response_headers"):
|
|
@@ -119,9 +118,9 @@ class CustomRouter(Router):
|
|
|
119
118
|
e: Exception,
|
|
120
119
|
remaining_retries: int,
|
|
121
120
|
num_retries: int,
|
|
122
|
-
healthy_deployments:
|
|
123
|
-
all_deployments:
|
|
124
|
-
) ->
|
|
121
|
+
healthy_deployments: list | None = None,
|
|
122
|
+
all_deployments: list | None = None,
|
|
123
|
+
) -> int | float:
|
|
125
124
|
"""
|
|
126
125
|
Implements exponential backoff for retries.
|
|
127
126
|
|
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from typing import Optional
|
|
5
|
-
|
|
6
4
|
|
|
7
5
|
class ParserException(Exception):
|
|
8
6
|
"""Identifies errors resulting from generic parser errors.
|
|
@@ -12,7 +10,7 @@ class ParserException(Exception):
|
|
|
12
10
|
attempted to parse.
|
|
13
11
|
"""
|
|
14
12
|
|
|
15
|
-
source:
|
|
13
|
+
source: str | None
|
|
16
14
|
|
|
17
15
|
@staticmethod
|
|
18
16
|
def _log_format(source: str) -> str:
|
|
@@ -24,7 +22,7 @@ class ParserException(Exception):
|
|
|
24
22
|
# return f"<source>{source}</source>"
|
|
25
23
|
return ""
|
|
26
24
|
|
|
27
|
-
def __init__(self, msg:
|
|
25
|
+
def __init__(self, msg: str | None = None, source: str | None = None):
|
|
28
26
|
msg = "" if msg is None else msg.strip()
|
|
29
27
|
|
|
30
28
|
if source is not None:
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from functools import reduce
|
|
5
|
-
from typing import Optional
|
|
6
5
|
|
|
7
6
|
import marko
|
|
8
7
|
from lxml import etree
|
|
@@ -105,8 +104,8 @@ class LLMResponseParser:
|
|
|
105
104
|
|
|
106
105
|
def __init__(
|
|
107
106
|
self,
|
|
108
|
-
tag_parsers:
|
|
109
|
-
postprocessors:
|
|
107
|
+
tag_parsers: dict[str, TagParser] | None = None,
|
|
108
|
+
postprocessors: list[PostProcessor] | None = None,
|
|
110
109
|
):
|
|
111
110
|
"""
|
|
112
111
|
Initializes the LLMResponseParser with optional tag parsers and post-processors.
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from typing import Optional, Type
|
|
5
4
|
|
|
6
5
|
import json_repair
|
|
7
6
|
from pydantic import BaseModel, ValidationError
|
|
@@ -60,12 +59,12 @@ def deserialize_json_code(
|
|
|
60
59
|
|
|
61
60
|
|
|
62
61
|
class RealizePydanticTypes:
|
|
63
|
-
types: list[
|
|
62
|
+
types: list[type[BaseModel]]
|
|
64
63
|
|
|
65
|
-
def __init__(self, types: list[
|
|
64
|
+
def __init__(self, types: list[type[BaseModel]]):
|
|
66
65
|
self.types = types
|
|
67
66
|
|
|
68
|
-
def _fit_types(self, obj: dict) ->
|
|
67
|
+
def _fit_types(self, obj: dict) -> BaseModel | None:
|
|
69
68
|
final_obj = None
|
|
70
69
|
|
|
71
70
|
for t in self.types:
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from typing import Any,
|
|
4
|
+
from typing import Any, Protocol, runtime_checkable
|
|
5
5
|
|
|
6
6
|
from lxml.etree import _Element
|
|
7
7
|
from pydantic import BaseModel, Field
|
|
@@ -30,7 +30,7 @@ class LLMStructuredResponse(BaseModel):
|
|
|
30
30
|
out.parsed = out.parsed[-n:]
|
|
31
31
|
return out
|
|
32
32
|
|
|
33
|
-
def filter(self, block_types: list[
|
|
33
|
+
def filter(self, block_types: list[type[BaseModel]]) -> Self:
|
|
34
34
|
out = self.model_copy()
|
|
35
35
|
out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
|
|
36
36
|
return out
|
|
@@ -44,7 +44,7 @@ class TagParser(Protocol):
|
|
|
44
44
|
element, do some computation, and return some kind of structured
|
|
45
45
|
output, represented as a subclass of Pydantic `BaseModel`.
|
|
46
46
|
This protocol implementation can cover both classes as well
|
|
47
|
-
as curried
|
|
47
|
+
as curried functions as parsers (e.g. `partial`).
|
|
48
48
|
"""
|
|
49
49
|
|
|
50
50
|
def __call__(self, element: _Element) -> BaseModel: ...
|
|
@@ -69,7 +69,7 @@ class TextBlock(BaseModel):
|
|
|
69
69
|
|
|
70
70
|
class CodeBlock(BaseModel):
|
|
71
71
|
code: str
|
|
72
|
-
code_lang:
|
|
72
|
+
code_lang: str | None = None
|
|
73
73
|
|
|
74
74
|
|
|
75
75
|
class StructuredDataBlock(BaseModel):
|