data-designer 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +36 -27
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +8 -0
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +6 -8
- data_designer/config/processors.py +63 -2
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +50 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +7 -8
- data_designer/config/utils/visualization.py +32 -17
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +1 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +58 -15
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +47 -12
- data_designer/engine/models/telemetry.py +355 -0
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/METADATA +32 -1
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/RECORD +84 -77
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/WHEEL +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/licenses/LICENSE +0 -0
data_designer/engine/models/registry.py

```diff
@@ -5,10 +5,11 @@ from __future__ import annotations
 
 import logging
 
-from data_designer.config.models import ModelConfig
+from data_designer.config.models import GenerationType, ModelConfig
 from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
 from data_designer.engine.models.facade import ModelFacade
 from data_designer.engine.models.litellm_overrides import apply_litellm_patches
+from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
 from data_designer.engine.secret_resolver import SecretResolver
 
 logger = logging.getLogger(__name__)
@@ -25,7 +26,7 @@ class ModelRegistry:
         self._secret_resolver = secret_resolver
         self._model_provider_registry = model_provider_registry
         self._model_configs = {}
-        self._models = {}
+        self._models: dict[str, ModelFacade] = {}
         self._set_model_configs(model_configs)
 
     @property
@@ -69,11 +70,36 @@ class ModelRegistry:
             if model.usage_stats.has_usage
         }
 
+    def get_model_usage_snapshot(self) -> dict[str, ModelUsageStats]:
+        return {
+            model.model_name: model.usage_stats.model_copy(deep=True)
+            for model in self._models.values()
+            if model.usage_stats.has_usage
+        }
+
+    def get_usage_deltas(self, snapshot: dict[str, ModelUsageStats]) -> dict[str, ModelUsageStats]:
+        deltas = {}
+        for model_name, current in self.get_model_usage_snapshot().items():
+            prev = snapshot.get(model_name)
+            delta_input = current.token_usage.input_tokens - (prev.token_usage.input_tokens if prev else 0)
+            delta_output = current.token_usage.output_tokens - (prev.token_usage.output_tokens if prev else 0)
+            delta_successful = current.request_usage.successful_requests - (
+                prev.request_usage.successful_requests if prev else 0
+            )
+            delta_failed = current.request_usage.failed_requests - (prev.request_usage.failed_requests if prev else 0)
+
+            if delta_input > 0 or delta_output > 0 or delta_successful > 0 or delta_failed > 0:
+                deltas[model_name] = ModelUsageStats(
+                    token_usage=TokenUsageStats(input_tokens=delta_input, output_tokens=delta_output),
+                    request_usage=RequestUsageStats(successful_requests=delta_successful, failed_requests=delta_failed),
+                )
+        return deltas
+
     def get_model_provider(self, *, model_alias: str) -> ModelProvider:
         model_config = self.get_model_config(model_alias=model_alias)
         return self._model_provider_registry.get_provider(model_config.provider)
 
-    def run_health_check(self, model_aliases:
+    def run_health_check(self, model_aliases: list[str]) -> None:
         logger.info("🩺 Running health checks for models...")
         for model_alias in model_aliases:
             model = self.get_model(model_alias=model_alias)
@@ -81,15 +107,24 @@ class ModelRegistry:
                 f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
             )
             try:
-                model.
-
-
-
-
-
-
-
-
+                if model.model_generation_type == GenerationType.EMBEDDING:
+                    model.generate_text_embeddings(
+                        input_texts=["Hello!"],
+                        skip_usage_tracking=True,
+                        purpose="running health checks",
+                    )
+                elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
+                    model.generate(
+                        prompt="Hello!",
+                        parser=lambda x: x,
+                        system_prompt="You are a helpful assistant.",
+                        max_correction_steps=0,
+                        max_conversation_restarts=0,
+                        skip_usage_tracking=True,
+                        purpose="running health checks",
+                    )
+                else:
+                    raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
                 logger.info(" |-- ✅ Passed!")
             except Exception as e:
                 logger.error(" |-- ❌ Failed!")
```
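Because `get_model_usage_snapshot` returns deep copies, a caller can capture usage before a unit of work and later ask `get_usage_deltas` for just the difference; models whose counters did not move are filtered out. A minimal sketch of that pattern, assuming an already-configured `ModelRegistry`; `registry` and `run_generation_job` are placeholders, not names from this diff:

```python
# Hypothetical usage sketch: `registry` stands in for a configured ModelRegistry.
before = registry.get_model_usage_snapshot()  # deep copies, so later mutation is safe

run_generation_job(registry)  # placeholder for any work that drives model calls

# Only models whose counters actually moved appear in the result.
for model_name, delta in registry.get_usage_deltas(before).items():
    print(
        f"{model_name}: +{delta.token_usage.input_tokens} in / "
        f"+{delta.token_usage.output_tokens} out tokens, "
        f"{delta.request_usage.successful_requests} ok / "
        f"{delta.request_usage.failed_requests} failed requests"
    )
```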
data_designer/engine/models/telemetry.py (new file)

```diff
@@ -0,0 +1,355 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Telemetry handler for NeMo products.
+
+Environment variables:
+- NEMO_TELEMETRY_ENABLED: Whether telemetry is enabled.
+- NEMO_DEPLOYMENT_TYPE: The deployment type the event came from.
+- NEMO_TELEMETRY_ENDPOINT: The endpoint to send the telemetry events to.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import platform
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, ClassVar
+
+import httpx
+from pydantic import BaseModel, Field
+
+TELEMETRY_ENABLED = os.getenv("NEMO_TELEMETRY_ENABLED", "true").lower() in ("1", "true", "yes")
+CLIENT_ID = "184482118588404"
+NEMO_TELEMETRY_VERSION = "nemo-telemetry/1.0"
+MAX_RETRIES = 3
+NEMO_TELEMETRY_ENDPOINT = os.getenv(
+    "NEMO_TELEMETRY_ENDPOINT", "https://events.telemetry.data.nvidia.com/v1.1/events/json"
+).lower()
+CPU_ARCHITECTURE = platform.uname().machine
+
+
+class NemoSourceEnum(str, Enum):
+    INFERENCE = "inference"
+    AUDITOR = "auditor"
+    DATADESIGNER = "datadesigner"
+    EVALUATOR = "evaluator"
+    GUARDRAILS = "guardrails"
+    UNDEFINED = "undefined"
+
+
+class DeploymentTypeEnum(str, Enum):
+    LIBRARY = "library"
+    API = "api"
+    UNDEFINED = "undefined"
+
+
+_deployment_type_raw = os.getenv("NEMO_DEPLOYMENT_TYPE", "library").lower()
+try:
+    DEPLOYMENT_TYPE = DeploymentTypeEnum(_deployment_type_raw)
+except ValueError:
+    valid_values = [e.value for e in DeploymentTypeEnum]
+    raise ValueError(
+        f"Invalid NEMO_DEPLOYMENT_TYPE: {_deployment_type_raw!r}. Must be one of: {valid_values}"
+    ) from None
+
+
+class TaskStatusEnum(str, Enum):
+    SUCCESS = "success"
+    FAILURE = "failure"
+    UNDEFINED = "undefined"
+
+
+class TelemetryEvent(BaseModel):
+    _event_name: ClassVar[str]  # Subclasses must define this
+    _schema_version: ClassVar[str] = "1.3"
+
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        super().__init_subclass__(**kwargs)
+        if "_event_name" not in cls.__dict__:
+            raise TypeError(f"{cls.__name__} must define '_event_name' class variable")
+
+
+class InferenceEvent(TelemetryEvent):
+    _event_name: ClassVar[str] = "inference_event"
+
+    nemo_source: NemoSourceEnum = Field(
+        ...,
+        alias="nemoSource",
+        description="The NeMo product that created the event (i.e. data-designer).",
+    )
+    task: str = Field(
+        ...,
+        description="The type of task that was performed that generated the inference event (i.e. preview-job, batch-job).",
+    )
+    task_status: TaskStatusEnum = Field(
+        ...,
+        alias="taskStatus",
+        description="The status of the task.",
+    )
+    deployment_type: DeploymentTypeEnum = Field(
+        default=DEPLOYMENT_TYPE,
+        alias="deploymentType",
+        description="The deployment type the event came from.",
+    )
+    model: str = Field(
+        ...,
+        description="The name of the model that was used.",
+    )
+    model_group: str = Field(
+        default="undefined",
+        alias="modelGroup",
+        description="An optional identifier to group models together.",
+    )
+    input_bytes: int = Field(
+        default=-1,
+        alias="inputBytes",
+        description="Number of bytes provided as input to the model. -1 if not available.",
+        ge=-9223372036854775808,
+        le=9223372036854775807,
+    )
+    input_tokens: int = Field(
+        default=-1,
+        alias="inputTokens",
+        description="Number of tokens provided as input to the model. -1 if not available.",
+        ge=-9223372036854775808,
+        le=9223372036854775807,
+    )
+    output_bytes: int = Field(
+        default=-1,
+        alias="outputBytes",
+        description="Number of bytes returned by the model. -1 if not available.",
+        ge=-9223372036854775808,
+        le=9223372036854775807,
+    )
+    output_tokens: int = Field(
+        default=-1,
+        alias="outputTokens",
+        description="Number of tokens returned by the model. -1 if not available.",
+        ge=-9223372036854775808,
+        le=9223372036854775807,
+    )
+
+    model_config = {"populate_by_name": True}
+
+
+@dataclass
+class QueuedEvent:
+    event: TelemetryEvent
+    timestamp: datetime
+    retry_count: int = 0
+
+
+def _get_iso_timestamp(dt: datetime | None = None) -> str:
+    if dt is None:
+        dt = datetime.now(timezone.utc)
+    return dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{dt.microsecond // 1000:03d}Z"
+
+
+def build_payload(
+    events: list[QueuedEvent], *, source_client_version: str, session_id: str = "undefined"
+) -> dict[str, Any]:
+    return {
+        "browserType": "undefined",  # do not change
+        "clientId": CLIENT_ID,
+        "clientType": "Native",  # do not change
+        "clientVariant": "Release",  # do not change
+        "clientVer": source_client_version,
+        "cpuArchitecture": CPU_ARCHITECTURE,
+        "deviceGdprBehOptIn": "None",  # do not change
+        "deviceGdprFuncOptIn": "None",  # do not change
+        "deviceGdprTechOptIn": "None",  # do not change
+        "deviceId": "undefined",  # do not change
+        "deviceMake": "undefined",  # do not change
+        "deviceModel": "undefined",  # do not change
+        "deviceOS": "undefined",  # do not change
+        "deviceOSVersion": "undefined",  # do not change
+        "deviceType": "undefined",  # do not change
+        "eventProtocol": "1.6",  # do not change
+        "eventSchemaVer": events[0].event._schema_version,
+        "eventSysVer": NEMO_TELEMETRY_VERSION,
+        "externalUserId": "undefined",  # do not change
+        "gdprBehOptIn": "None",  # do not change
+        "gdprFuncOptIn": "None",  # do not change
+        "gdprTechOptIn": "None",  # do not change
+        "idpId": "undefined",  # do not change
+        "integrationId": "undefined",  # do not change
+        "productName": "undefined",  # do not change
+        "productVersion": "undefined",  # do not change
+        "sentTs": _get_iso_timestamp(),
+        "sessionId": session_id,
+        "userId": "undefined",  # do not change
+        "events": [
+            {
+                "ts": _get_iso_timestamp(queued.timestamp),
+                "parameters": queued.event.model_dump(by_alias=True),
+                "name": queued.event._event_name,
+            }
+            for queued in events
+        ],
+    }
+
+
+class TelemetryHandler:
+    """
+    Handles telemetry event batching, flushing, and retry logic for NeMo products.
+
+    Args:
+        flush_interval_seconds (float): The interval in seconds to flush the events.
+        max_queue_size (int): The maximum number of events to queue before flushing.
+        max_retries (int): The maximum number of times to retry sending an event.
+        source_client_version (str): The version of the source client. This should be the version of
+            the actual NeMo product that is sending the events, typically the same as the version of
+            a PyPi package that a user would install.
+        session_id (str): An optional session ID to associate with the events.
+            This should be a unique identifier for the session, such as a UUID.
+            It is used to group events together.
+    """
+
+    def __init__(
+        self,
+        flush_interval_seconds: float = 120.0,
+        max_queue_size: int = 50,
+        max_retries: int = MAX_RETRIES,
+        source_client_version: str = "undefined",
+        session_id: str = "undefined",
+    ):
+        self._flush_interval = flush_interval_seconds
+        self._max_queue_size = max_queue_size
+        self._max_retries = max_retries
+        self._events: list[QueuedEvent] = []
+        self._dlq: list[QueuedEvent] = []  # Dead letter queue for retry
+        self._flush_signal = asyncio.Event()
+        self._timer_task: asyncio.Task | None = None
+        self._running = False
+        self._source_client_version = source_client_version
+        self._session_id = session_id
+
+    async def astart(self) -> None:
+        if self._running:
+            return
+        self._running = True
+        self._timer_task = asyncio.create_task(self._timer_loop())
+
+    async def astop(self) -> None:
+        self._running = False
+        self._flush_signal.set()
+        if self._timer_task:
+            self._timer_task.cancel()
+            try:
+                await self._timer_task
+            except asyncio.CancelledError:
+                pass
+            self._timer_task = None
+        await self._flush_events()
+
+    async def aflush(self) -> None:
+        self._flush_signal.set()
+
+    def start(self) -> None:
+        self._run_sync(self.astart())
+
+    def stop(self) -> None:
+        self._run_sync(self.astop())
+
+    def flush(self) -> None:
+        self._flush_signal.set()
+
+    def enqueue(self, event: TelemetryEvent) -> None:
+        if not TELEMETRY_ENABLED:
+            return
+        if not isinstance(event, TelemetryEvent):
+            # Silently fail as we prioritize not disrupting upstream call sites and telemetry is best effort
+            return
+        queued = QueuedEvent(event=event, timestamp=datetime.now(timezone.utc))
+        self._events.append(queued)
+        if len(self._events) >= self._max_queue_size:
+            self._flush_signal.set()
+
+    def _run_sync(self, coro: Any) -> Any:
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        if loop and loop.is_running():
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                future = pool.submit(asyncio.run, coro)
+                return future.result()
+        else:
+            return asyncio.run(coro)
+
+    def __enter__(self) -> TelemetryHandler:
+        self.start()
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.stop()
+
+    async def __aenter__(self) -> TelemetryHandler:
+        await self.astart()
+        return self
+
+    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        await self.astop()
+
+    async def _timer_loop(self) -> None:
+        while self._running:
+            try:
+                await asyncio.wait_for(
+                    self._flush_signal.wait(),
+                    timeout=self._flush_interval,
+                )
+            except asyncio.TimeoutError:
+                pass
+            self._flush_signal.clear()
+            await self._flush_events()
+
+    async def _flush_events(self) -> None:
+        dlq_events, self._dlq = self._dlq, []
+        new_events, self._events = self._events, []
+        events_to_send = dlq_events + new_events
+        if events_to_send:
+            await self._send_events(events_to_send)
+
+    async def _send_events(self, events: list[QueuedEvent]) -> None:
+        async with httpx.AsyncClient() as client:
+            await self._send_events_with_client(client, events)
+
+    async def _send_events_with_client(self, client: httpx.AsyncClient, events: list[QueuedEvent]) -> None:
+        if not events:
+            return
+
+        payload = build_payload(events, source_client_version=self._source_client_version, session_id=self._session_id)
+        try:
+            response = await client.post(NEMO_TELEMETRY_ENDPOINT, json=payload)
+            # 2xx, 400, 422 are all considered complete (no retry)
+            # 400/422 indicate bad payload which retrying won't fix
+            if response.status_code in (400, 422) or response.is_success:
+                return
+            # 413 (payload too large) - split and retry
+            if response.status_code == 413:
+                if len(events) == 1:
+                    # Can't split further, drop the event
+                    return
+                mid = len(events) // 2
+                await self._send_events_with_client(client, events[:mid])
+                await self._send_events_with_client(client, events[mid:])
+                return
+            if response.status_code == 408 or response.status_code >= 500:
+                self._add_to_dlq(events)
+        except httpx.HTTPError:
+            self._add_to_dlq(events)
+
+    def _add_to_dlq(self, events: list[QueuedEvent]) -> None:
+        for queued in events:
+            queued.retry_count += 1
+            if queued.retry_count > self._max_retries:
+                continue
+            self._dlq.append(queued)
```
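The handler's sync and async context-manager protocols pair `astart`'s timer loop with `astop`'s final flush, so a caller only has to enqueue events. A hedged usage sketch, assuming the module is imported from its new location; the session ID, version string, and event values here are invented for illustration:

```python
import asyncio

from data_designer.engine.models.telemetry import (
    InferenceEvent,
    NemoSourceEnum,
    TaskStatusEnum,
    TelemetryHandler,
)


async def main() -> None:
    # Events queue locally and flush on the timer, on queue pressure, or at exit.
    async with TelemetryHandler(source_client_version="0.2.1", session_id="example-session") as handler:
        handler.enqueue(
            InferenceEvent(
                nemo_source=NemoSourceEnum.DATADESIGNER,
                task="preview-job",
                task_status=TaskStatusEnum.SUCCESS,
                model="example-model",
                input_tokens=128,
                output_tokens=256,
            )
        )
        await handler.aflush()  # request an early flush; astop() at exit flushes the rest


asyncio.run(main())
```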
data_designer/engine/models/usage.py

```diff
@@ -11,20 +11,20 @@ logger = logging.getLogger(__name__)
 
 
 class TokenUsageStats(BaseModel):
-
-
+    input_tokens: int = 0
+    output_tokens: int = 0
 
     @computed_field
     def total_tokens(self) -> int:
-        return self.
+        return self.input_tokens + self.output_tokens
 
     @property
     def has_usage(self) -> bool:
         return self.total_tokens > 0
 
-    def extend(self, *,
-        self.
-        self.
+    def extend(self, *, input_tokens: int, output_tokens: int) -> None:
+        self.input_tokens += input_tokens
+        self.output_tokens += output_tokens
 
 
 class RequestUsageStats(BaseModel):
@@ -56,9 +56,7 @@ class ModelUsageStats(BaseModel):
         self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
     ) -> None:
         if token_usage is not None:
-            self.token_usage.extend(
-                prompt_tokens=token_usage.prompt_tokens, completion_tokens=token_usage.completion_tokens
-            )
+            self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
         if request_usage is not None:
             self.request_usage.extend(
                 successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests
```
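The counters move from prompt/completion naming (still visible in the removed `extend` call inside `ModelUsageStats`) to symmetric input/output naming. A short sketch of the new accumulation semantics, assuming the fields behave as the diff shows:

```python
from data_designer.engine.models.usage import TokenUsageStats

usage = TokenUsageStats()  # both counters default to 0
usage.extend(input_tokens=100, output_tokens=40)
usage.extend(input_tokens=25, output_tokens=10)

assert usage.input_tokens == 125
assert usage.output_tokens == 50
assert usage.total_tokens == 175  # computed field
assert usage.has_usage            # True once any tokens accumulate
```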
data_designer/engine/processing/ginja/ast.py

```diff
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections import deque
-from typing import Optional, Type
 
 from jinja2 import nodes as j_nodes
 
@@ -33,7 +32,7 @@ def ast_max_depth(node: j_nodes.Node) -> int:
     return max_depth
 
 
-def ast_descendant_count(ast: j_nodes.Node, only_type:
+def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int:
     """Count the number of nodes which descend from the given node.
 
     Args:
```
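With the sharpened signature, `only_type` narrows the count to one jinja2 node class. A hedged sketch of calling it against a parsed template; the template string is invented for illustration:

```python
from jinja2 import Environment
from jinja2 import nodes as j_nodes

from data_designer.engine.processing.ginja.ast import ast_descendant_count

# Parse a template into a jinja2 AST (Environment.parse returns a nodes.Template).
template_ast = Environment().parse("{{ user.name }} has {{ orders | length }} orders")

print(ast_descendant_count(template_ast))                            # all descendant nodes
print(ast_descendant_count(template_ast, only_type=j_nodes.Filter))  # just the | length filter
```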
data_designer/engine/processing/utils.py

```diff
@@ -1,9 +1,11 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+import ast
 import json
 import logging
-
+import re
+from typing import Any, TypeVar, overload
 
 import pandas as pd
 
@@ -25,7 +27,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
 # Overloads to help static type checker better understand
 # the input/output types of the deserialize_json_values function.
 @overload
-def deserialize_json_values(data: str) ->
+def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ...
 
 
 @overload
@@ -100,6 +102,42 @@ def deserialize_json_values(data):
     return data
 
 
+def parse_list_string(text: str) -> list[str]:
+    """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas."""
+    text = text.strip()
+
+    # Try JSON first
+    try:
+        list_obj = json.loads(text)
+        if isinstance(list_obj, list):
+            return _clean_whitespace(list_obj)
+    except json.JSONDecodeError:
+        pass
+
+    # Remove trailing commas before closing brackets (common in JSON-like strings)
+    text_cleaned = re.sub(r",\s*]", "]", text)
+    text_cleaned = re.sub(r",\s*}", "}", text_cleaned)
+
+    # Try JSON again with cleaned text
+    try:
+        return _clean_whitespace(json.loads(text_cleaned))
+    except json.JSONDecodeError:
+        pass
+
+    # Try Python literal eval (handles single quotes)
+    try:
+        return _clean_whitespace(ast.literal_eval(text_cleaned))
+    except (ValueError, SyntaxError):
+        pass
+
+    # If all else fails, return the original text
+    return [text.strip()]
+
+
+def _clean_whitespace(texts: list[str]) -> list[str]:
+    return [text.strip() for text in texts]
+
+
 def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
     joined_columns = set()
     for df in datasets:
```
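`parse_list_string` tries strict JSON, then JSON with trailing commas stripped, then a Python literal, and finally wraps the raw text in a single-element list. A few illustrative calls, assuming the fallback chain shown in the diff:

```python
from data_designer.engine.processing.utils import parse_list_string

parse_list_string('["a", "b", "c"]')   # valid JSON         -> ["a", "b", "c"]
parse_list_string('["a", "b", "c",]')  # trailing comma     -> ["a", "b", "c"]
parse_list_string("['a', 'b']")        # Python-style list  -> ["a", "b"]
parse_list_string("not a list")        # fallback           -> ["not a list"]
```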
data_designer/engine/registry/base.py

```diff
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import threading
-from typing import Any, Generic,
+from typing import Any, Generic, TypeVar
 
 from data_designer.config.base import ConfigBase
 from data_designer.config.utils.type_helpers import StrEnum
@@ -16,14 +16,14 @@ TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
 
 class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
     # registered type name -> type
-    _registry: dict[EnumNameT,
+    _registry: dict[EnumNameT, type[TaskT]] = {}
     # type -> registered type name
-    _reverse_registry: dict[
+    _reverse_registry: dict[type[TaskT], EnumNameT] = {}
 
     # registered type name -> config type
-    _config_registry: dict[EnumNameT,
+    _config_registry: dict[EnumNameT, type[TaskConfigT]] = {}
     # config type -> registered type name
-    _reverse_config_registry: dict[
+    _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {}
 
     # all registries are singletons
     _instance = None
@@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
     def register(
         cls,
         name: EnumNameT,
-        task:
-        config:
+        task: type[TaskT],
+        config: type[TaskConfigT],
         raise_on_collision: bool = False,
     ) -> None:
         if cls._has_been_registered(name):
@@ -52,22 +52,22 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
         cls._reverse_config_registry[config] = name
 
     @classmethod
-    def get_task_type(cls, name: EnumNameT) ->
+    def get_task_type(cls, name: EnumNameT) -> type[TaskT]:
         cls._raise_if_not_registered(name, cls._registry)
         return cls._registry[name]
 
     @classmethod
-    def get_config_type(cls, name: EnumNameT) ->
+    def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]:
         cls._raise_if_not_registered(name, cls._config_registry)
         return cls._config_registry[name]
 
     @classmethod
-    def get_registered_name(cls, task:
+    def get_registered_name(cls, task: type[TaskT]) -> EnumNameT:
         cls._raise_if_not_registered(task, cls._reverse_registry)
         return cls._reverse_registry[task]
 
     @classmethod
-    def get_for_config_type(cls, config:
+    def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]:
         cls._raise_if_not_registered(config, cls._reverse_config_registry)
         name = cls._reverse_config_registry[config]
         return cls.get_task_type(name)
@@ -77,7 +77,7 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
         return name in cls._registry
 
     @classmethod
-    def _raise_if_not_registered(cls, key: EnumNameT |
+    def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None:
         if not (isinstance(key, StrEnum) or isinstance(key, str)):
             cls._raise_if_not_type(key)
         if key not in mapping:
```
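These annotations make the registry a typed round trip between names, task classes, and config classes. A hedged, illustrative sketch of that contract; `ExampleName`, `EchoTask`, and `EchoTaskConfig` are hypothetical, and the exact subclassing pattern the package expects may differ:

```python
from data_designer.config.base import ConfigBase
from data_designer.config.utils.type_helpers import StrEnum
from data_designer.engine.registry.base import TaskRegistry


class ExampleName(StrEnum):
    ECHO = "echo"


class EchoTask:
    """Hypothetical task implementation."""


class EchoTaskConfig(ConfigBase):
    """Hypothetical task config."""


class ExampleRegistry(TaskRegistry[ExampleName, EchoTask, EchoTaskConfig]):
    pass


# register() wires up both forward and reverse mappings at once.
ExampleRegistry.register(ExampleName.ECHO, task=EchoTask, config=EchoTaskConfig)
assert ExampleRegistry.get_task_type(ExampleName.ECHO) is EchoTask
assert ExampleRegistry.get_for_config_type(EchoTaskConfig) is EchoTask
```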
data_designer/engine/sampling_gen/constraints.py

```diff
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC, abstractmethod
-from typing import Type
 
 import numpy as np
 import pandas as pd
@@ -91,5 +90,5 @@ CONSTRAINT_TYPE_TO_CHECKER = {
 }
 
 
-def get_constraint_checker(constraint_type: ConstraintType) ->
+def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
     return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]
```