data-designer-engine 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. data_designer/engine/__init__.py +2 -0
  2. data_designer/engine/_version.py +34 -0
  3. data_designer/engine/analysis/column_profilers/base.py +49 -0
  4. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
  5. data_designer/engine/analysis/column_profilers/registry.py +22 -0
  6. data_designer/engine/analysis/column_statistics.py +145 -0
  7. data_designer/engine/analysis/dataset_profiler.py +149 -0
  8. data_designer/engine/analysis/errors.py +9 -0
  9. data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
  10. data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
  11. data_designer/engine/column_generators/__init__.py +2 -0
  12. data_designer/engine/column_generators/generators/__init__.py +2 -0
  13. data_designer/engine/column_generators/generators/base.py +122 -0
  14. data_designer/engine/column_generators/generators/embedding.py +35 -0
  15. data_designer/engine/column_generators/generators/expression.py +55 -0
  16. data_designer/engine/column_generators/generators/llm_completion.py +116 -0
  17. data_designer/engine/column_generators/generators/samplers.py +69 -0
  18. data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
  19. data_designer/engine/column_generators/generators/validation.py +140 -0
  20. data_designer/engine/column_generators/registry.py +60 -0
  21. data_designer/engine/column_generators/utils/errors.py +15 -0
  22. data_designer/engine/column_generators/utils/generator_classification.py +43 -0
  23. data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
  24. data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
  25. data_designer/engine/compiler.py +97 -0
  26. data_designer/engine/configurable_task.py +71 -0
  27. data_designer/engine/dataset_builders/artifact_storage.py +283 -0
  28. data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
  29. data_designer/engine/dataset_builders/errors.py +15 -0
  30. data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
  31. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  32. data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
  33. data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
  34. data_designer/engine/dataset_builders/utils/dag.py +62 -0
  35. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
  36. data_designer/engine/dataset_builders/utils/errors.py +15 -0
  37. data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
  38. data_designer/engine/errors.py +51 -0
  39. data_designer/engine/model_provider.py +77 -0
  40. data_designer/engine/models/__init__.py +2 -0
  41. data_designer/engine/models/errors.py +300 -0
  42. data_designer/engine/models/facade.py +284 -0
  43. data_designer/engine/models/factory.py +42 -0
  44. data_designer/engine/models/litellm_overrides.py +179 -0
  45. data_designer/engine/models/parsers/__init__.py +2 -0
  46. data_designer/engine/models/parsers/errors.py +34 -0
  47. data_designer/engine/models/parsers/parser.py +235 -0
  48. data_designer/engine/models/parsers/postprocessors.py +93 -0
  49. data_designer/engine/models/parsers/tag_parsers.py +62 -0
  50. data_designer/engine/models/parsers/types.py +84 -0
  51. data_designer/engine/models/recipes/base.py +81 -0
  52. data_designer/engine/models/recipes/response_recipes.py +293 -0
  53. data_designer/engine/models/registry.py +151 -0
  54. data_designer/engine/models/telemetry.py +362 -0
  55. data_designer/engine/models/usage.py +73 -0
  56. data_designer/engine/models/utils.py +101 -0
  57. data_designer/engine/processing/ginja/__init__.py +2 -0
  58. data_designer/engine/processing/ginja/ast.py +65 -0
  59. data_designer/engine/processing/ginja/environment.py +463 -0
  60. data_designer/engine/processing/ginja/exceptions.py +56 -0
  61. data_designer/engine/processing/ginja/record.py +32 -0
  62. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  63. data_designer/engine/processing/gsonschema/exceptions.py +15 -0
  64. data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
  65. data_designer/engine/processing/gsonschema/types.py +10 -0
  66. data_designer/engine/processing/gsonschema/validators.py +202 -0
  67. data_designer/engine/processing/processors/base.py +13 -0
  68. data_designer/engine/processing/processors/drop_columns.py +42 -0
  69. data_designer/engine/processing/processors/registry.py +25 -0
  70. data_designer/engine/processing/processors/schema_transform.py +71 -0
  71. data_designer/engine/processing/utils.py +169 -0
  72. data_designer/engine/registry/base.py +99 -0
  73. data_designer/engine/registry/data_designer_registry.py +39 -0
  74. data_designer/engine/registry/errors.py +12 -0
  75. data_designer/engine/resources/managed_dataset_generator.py +39 -0
  76. data_designer/engine/resources/managed_dataset_repository.py +197 -0
  77. data_designer/engine/resources/managed_storage.py +65 -0
  78. data_designer/engine/resources/resource_provider.py +77 -0
  79. data_designer/engine/resources/seed_reader.py +154 -0
  80. data_designer/engine/sampling_gen/column.py +91 -0
  81. data_designer/engine/sampling_gen/constraints.py +100 -0
  82. data_designer/engine/sampling_gen/data_sources/base.py +217 -0
  83. data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
  84. data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
  85. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  86. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  87. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
  88. data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
  89. data_designer/engine/sampling_gen/entities/errors.py +10 -0
  90. data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
  91. data_designer/engine/sampling_gen/entities/person.py +144 -0
  92. data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
  93. data_designer/engine/sampling_gen/errors.py +26 -0
  94. data_designer/engine/sampling_gen/generator.py +122 -0
  95. data_designer/engine/sampling_gen/jinja_utils.py +64 -0
  96. data_designer/engine/sampling_gen/people_gen.py +199 -0
  97. data_designer/engine/sampling_gen/person_constants.py +56 -0
  98. data_designer/engine/sampling_gen/schema.py +147 -0
  99. data_designer/engine/sampling_gen/schema_builder.py +61 -0
  100. data_designer/engine/sampling_gen/utils.py +46 -0
  101. data_designer/engine/secret_resolver.py +82 -0
  102. data_designer/engine/testing/__init__.py +12 -0
  103. data_designer/engine/testing/stubs.py +133 -0
  104. data_designer/engine/testing/utils.py +20 -0
  105. data_designer/engine/validation.py +367 -0
  106. data_designer/engine/validators/__init__.py +19 -0
  107. data_designer/engine/validators/base.py +38 -0
  108. data_designer/engine/validators/local_callable.py +39 -0
  109. data_designer/engine/validators/python.py +254 -0
  110. data_designer/engine/validators/remote.py +89 -0
  111. data_designer/engine/validators/sql.py +65 -0
  112. data_designer_engine-0.4.0.dist-info/METADATA +50 -0
  113. data_designer_engine-0.4.0.dist-info/RECORD +114 -0
  114. data_designer_engine-0.4.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,362 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Telemetry handler for NeMo products.
6
+
7
+ Environment variables:
8
+ - NEMO_TELEMETRY_ENABLED: Whether telemetry is enabled.
9
+ - NEMO_DEPLOYMENT_TYPE: The deployment type the event came from.
10
+ - NEMO_TELEMETRY_ENDPOINT: The endpoint to send the telemetry events to.
11
+ - NEMO_SESSION_PREFIX: Optional prefix to add to session IDs.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import os
18
+ import platform
19
+ from dataclasses import dataclass
20
+ from datetime import datetime, timezone
21
+ from enum import Enum
22
+ from typing import Any, ClassVar
23
+
24
+ from pydantic import BaseModel, Field
25
+
26
+ from data_designer.lazy_heavy_imports import httpx
27
+
28
# Master switch for telemetry; any of "1"/"true"/"yes" (case-insensitive) enables it.
TELEMETRY_ENABLED = os.getenv("NEMO_TELEMETRY_ENABLED", "true").lower() in ("1", "true", "yes")

# Fixed client identifier registered with the NVIDIA telemetry service.
CLIENT_ID = "184482118588404"

# Identifies the telemetry protocol implementation in outgoing payloads.
NEMO_TELEMETRY_VERSION = "nemo-telemetry/1.0"

# Maximum number of resend attempts for a queued event (see TelemetryHandler).
MAX_RETRIES = 3

# Endpoint to POST event batches to. NOTE: the value must NOT be lower-cased —
# per RFC 3986 only the URI scheme and host are case-insensitive, while the
# path is case-sensitive, so the previous `.lower()` call would corrupt a
# user-supplied endpoint containing a mixed-case path.
NEMO_TELEMETRY_ENDPOINT = os.getenv(
    "NEMO_TELEMETRY_ENDPOINT", "https://events.telemetry.data.nvidia.com/v1.1/events/json"
)

# Machine architecture (e.g. "x86_64", "arm64"), reported as-is in payloads.
CPU_ARCHITECTURE = platform.uname().machine

# Optional prefix prepended to session IDs (see TelemetryHandler.__init__).
SESSION_PREFIX = os.getenv("NEMO_SESSION_PREFIX")
37
+
38
+
39
class NemoSourceEnum(str, Enum):
    """The NeMo product that emitted a telemetry event (see InferenceEvent.nemo_source)."""

    INFERENCE = "inference"
    AUDITOR = "auditor"
    DATADESIGNER = "datadesigner"
    EVALUATOR = "evaluator"
    GUARDRAILS = "guardrails"
    UNDEFINED = "undefined"
46
+
47
+
48
class DeploymentTypeEnum(str, Enum):
    """How the emitting product is deployed; resolved from NEMO_DEPLOYMENT_TYPE at import time."""

    LIBRARY = "library"
    API = "api"
    UNDEFINED = "undefined"
52
+
53
+
54
# Resolve the deployment type once at import time from NEMO_DEPLOYMENT_TYPE
# (defaults to "library"); the comparison is case-insensitive.
_deployment_type_raw = os.getenv("NEMO_DEPLOYMENT_TYPE", "library").lower()
try:
    DEPLOYMENT_TYPE = DeploymentTypeEnum(_deployment_type_raw)
except ValueError:
    # Fail fast with the list of accepted values; `from None` suppresses the
    # less informative enum-lookup traceback.
    valid_values = [e.value for e in DeploymentTypeEnum]
    raise ValueError(
        f"Invalid NEMO_DEPLOYMENT_TYPE: {_deployment_type_raw!r}. Must be one of: {valid_values}"
    ) from None
62
+
63
+
64
class TaskStatusEnum(str, Enum):
    """Terminal status of the task that produced an event (see InferenceEvent.task_status)."""

    SUCCESS = "success"
    FAILURE = "failure"
    UNDEFINED = "undefined"
68
+
69
+
70
class TelemetryEvent(BaseModel):
    """Base class for all telemetry events.

    Every concrete subclass must declare its own ``_event_name`` class
    variable; this is enforced at class-creation time so that a misconfigured
    event type fails immediately rather than at send time.
    """

    _event_name: ClassVar[str]  # Subclasses must define this
    _schema_version: ClassVar[str] = "1.3"

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Inspect the subclass's own namespace (not inherited attributes) so
        # each subclass is forced to declare its own event name.
        if "_event_name" not in vars(cls):
            raise TypeError(f"{cls.__name__} must define '_event_name' class variable")
78
+
79
+
80
class InferenceEvent(TelemetryEvent):
    """Telemetry event describing a single model-inference task.

    Field aliases are the camelCase names required by the telemetry endpoint;
    ``populate_by_name`` additionally allows construction via the snake_case
    Python names.
    """

    _event_name: ClassVar[str] = "inference_event"

    nemo_source: NemoSourceEnum = Field(
        ...,
        alias="nemoSource",
        description="The NeMo product that created the event (i.e. data-designer).",
    )
    task: str = Field(
        ...,
        description="The type of task that was performed that generated the inference event (i.e. preview-job, batch-job).",
    )
    task_status: TaskStatusEnum = Field(
        ...,
        alias="taskStatus",
        description="The status of the task.",
    )
    deployment_type: DeploymentTypeEnum = Field(
        default=DEPLOYMENT_TYPE,
        alias="deploymentType",
        description="The deployment type the event came from.",
    )
    model: str = Field(
        ...,
        description="The name of the model that was used.",
    )
    model_group: str = Field(
        default="undefined",
        alias="modelGroup",
        description="An optional identifier to group models together.",
    )
    # The byte/token counters below default to -1 ("not available") and are
    # bounded to the signed 64-bit integer range expected by the endpoint.
    input_bytes: int = Field(
        default=-1,
        alias="inputBytes",
        description="Number of bytes provided as input to the model. -1 if not available.",
        ge=-9223372036854775808,
        le=9223372036854775807,
    )
    input_tokens: int = Field(
        default=-1,
        alias="inputTokens",
        description="Number of tokens provided as input to the model. -1 if not available.",
        ge=-9223372036854775808,
        le=9223372036854775807,
    )
    output_bytes: int = Field(
        default=-1,
        alias="outputBytes",
        description="Number of bytes returned by the model. -1 if not available.",
        ge=-9223372036854775808,
        le=9223372036854775807,
    )
    output_tokens: int = Field(
        default=-1,
        alias="outputTokens",
        description="Number of tokens returned by the model. -1 if not available.",
        ge=-9223372036854775808,
        le=9223372036854775807,
    )

    # Allow constructing by Python field name as well as by wire alias.
    model_config = {"populate_by_name": True}
141
+
142
+
143
@dataclass
class QueuedEvent:
    """A telemetry event waiting to be sent, with its enqueue time and retry bookkeeping."""

    event: TelemetryEvent
    timestamp: datetime  # when the event was enqueued (UTC); serialized as the event "ts"
    retry_count: int = 0  # number of failed send attempts so far (see _add_to_dlq)
148
+
149
+
150
def _get_iso_timestamp(dt: datetime | None = None) -> str:
    """Format *dt* (default: current UTC time) as ISO-8601 with millisecond precision and a 'Z' suffix."""
    moment = datetime.now(timezone.utc) if dt is None else dt
    # Truncate microseconds to milliseconds, zero-padded to three digits.
    millis = moment.microsecond // 1000
    return f"{moment:%Y-%m-%dT%H:%M:%S}.{millis:03d}Z"
154
+
155
+
156
def build_payload(
    events: list[QueuedEvent], *, source_client_version: str, session_id: str = "undefined"
) -> dict[str, Any]:
    """Assemble the JSON payload expected by the NVIDIA telemetry endpoint.

    Args:
        events: Queued events to serialize. Must be non-empty — the schema
            version is read from the first event.
        source_client_version: Version of the NeMo product sending the events.
        session_id: Identifier used to group events into a session.

    Returns:
        A JSON-serializable dict with the fixed envelope fields plus one
        entry per event under "events".
    """
    serialized_events = []
    for queued in events:
        serialized_events.append(
            {
                "ts": _get_iso_timestamp(queued.timestamp),
                "parameters": queued.event.model_dump(by_alias=True),
                "name": queued.event._event_name,
            }
        )
    payload: dict[str, Any] = {
        "browserType": "undefined",  # do not change
        "clientId": CLIENT_ID,
        "clientType": "Native",  # do not change
        "clientVariant": "Release",  # do not change
        "clientVer": source_client_version,
        "cpuArchitecture": CPU_ARCHITECTURE,
        "deviceGdprBehOptIn": "None",  # do not change
        "deviceGdprFuncOptIn": "None",  # do not change
        "deviceGdprTechOptIn": "None",  # do not change
        "deviceId": "undefined",  # do not change
        "deviceMake": "undefined",  # do not change
        "deviceModel": "undefined",  # do not change
        "deviceOS": "undefined",  # do not change
        "deviceOSVersion": "undefined",  # do not change
        "deviceType": "undefined",  # do not change
        "eventProtocol": "1.6",  # do not change
        "eventSchemaVer": events[0].event._schema_version,
        "eventSysVer": NEMO_TELEMETRY_VERSION,
        "externalUserId": "undefined",  # do not change
        "gdprBehOptIn": "None",  # do not change
        "gdprFuncOptIn": "None",  # do not change
        "gdprTechOptIn": "None",  # do not change
        "idpId": "undefined",  # do not change
        "integrationId": "undefined",  # do not change
        "productName": "undefined",  # do not change
        "productVersion": "undefined",  # do not change
        "sentTs": _get_iso_timestamp(),
        "sessionId": session_id,
        "userId": "undefined",  # do not change
        "events": serialized_events,
    }
    return payload
198
+
199
+
200
class TelemetryHandler:
    """
    Handles telemetry event batching, flushing, and retry logic for NeMo products.

    Events are accumulated via enqueue() and sent in batches, either when the
    background timer fires, when the queue reaches max_queue_size, or when a
    flush is explicitly requested. Failed sends go to a dead-letter queue and
    are retried on the next flush, up to max_retries times per event.

    Args:
        flush_interval_seconds (float): The interval in seconds to flush the events.
        max_queue_size (int): The maximum number of events to queue before flushing.
        max_retries (int): The maximum number of times to retry sending an event.
        source_client_version (str): The version of the source client. This should be the version of
            the actual NeMo product that is sending the events, typically the same as the version of
            a PyPi package that a user would install.
        session_id (str): An optional session ID to associate with the events.
            This should be a unique identifier for the session, such as a UUID.
            It is used to group events together.
    """

    def __init__(
        self,
        flush_interval_seconds: float = 120.0,
        max_queue_size: int = 50,
        max_retries: int = MAX_RETRIES,
        source_client_version: str = "undefined",
        session_id: str = "undefined",
    ):
        self._flush_interval = flush_interval_seconds
        self._max_queue_size = max_queue_size
        self._max_retries = max_retries
        self._events: list[QueuedEvent] = []
        self._dlq: list[QueuedEvent] = []  # Dead letter queue for retry
        self._flush_signal = asyncio.Event()
        self._timer_task: asyncio.Task | None = None
        self._running = False
        self._source_client_version = source_client_version
        # Apply session prefix if environment variable is set
        if SESSION_PREFIX:
            self._session_id = f"{SESSION_PREFIX}{session_id}"
        else:
            self._session_id = session_id

    async def astart(self) -> None:
        """Start the background flush timer. Idempotent if already running."""
        if self._running:
            return
        self._running = True
        self._timer_task = asyncio.create_task(self._timer_loop())

    async def astop(self) -> None:
        """Stop the timer task and perform a final flush of pending events."""
        self._running = False
        self._flush_signal.set()
        if self._timer_task:
            self._timer_task.cancel()
            try:
                await self._timer_task
            except asyncio.CancelledError:
                pass
            self._timer_task = None
        # Final flush so events enqueued after the last timer tick are not lost.
        await self._flush_events()

    async def aflush(self) -> None:
        """Request a flush; the actual send happens in the timer loop (not awaited here)."""
        self._flush_signal.set()

    def start(self) -> None:
        """Synchronous wrapper around astart()."""
        self._run_sync(self.astart())

    def stop(self) -> None:
        """Synchronous wrapper around astop()."""
        self._run_sync(self.astop())

    def flush(self) -> None:
        """Request a flush from synchronous code; the send happens in the timer loop."""
        self._flush_signal.set()

    def enqueue(self, event: TelemetryEvent) -> None:
        """Queue an event for sending; triggers a flush once the queue is full.

        No-op when telemetry is disabled or when given a non-TelemetryEvent.
        """
        if not TELEMETRY_ENABLED:
            return
        if not isinstance(event, TelemetryEvent):
            # Silently fail as we prioritize not disrupting upstream call sites and telemetry is best effort
            return
        queued = QueuedEvent(event=event, timestamp=datetime.now(timezone.utc))
        self._events.append(queued)
        if len(self._events) >= self._max_queue_size:
            self._flush_signal.set()

    def _run_sync(self, coro: Any) -> Any:
        """Run *coro* to completion from synchronous code.

        If called while an event loop is already running, the coroutine is
        executed via asyncio.run() on a worker thread to avoid re-entering
        the running loop; otherwise asyncio.run() is used directly.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop and loop.is_running():
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as pool:
                future = pool.submit(asyncio.run, coro)
                return future.result()
        else:
            return asyncio.run(coro)

    def __enter__(self) -> TelemetryHandler:
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.stop()

    async def __aenter__(self) -> TelemetryHandler:
        await self.astart()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self.astop()

    async def _timer_loop(self) -> None:
        """Flush on each interval tick or as soon as the flush signal is set."""
        while self._running:
            try:
                await asyncio.wait_for(
                    self._flush_signal.wait(),
                    timeout=self._flush_interval,
                )
            except asyncio.TimeoutError:
                # Interval elapsed without an explicit flush request; flush anyway.
                pass
            self._flush_signal.clear()
            await self._flush_events()

    async def _flush_events(self) -> None:
        """Send everything pending: previously failed (DLQ) events first, then new ones."""
        dlq_events, self._dlq = self._dlq, []
        new_events, self._events = self._events, []
        events_to_send = dlq_events + new_events
        if events_to_send:
            await self._send_events(events_to_send)

    async def _send_events(self, events: list[QueuedEvent]) -> None:
        """Send *events* using a short-lived HTTP client."""
        async with httpx.AsyncClient() as client:
            await self._send_events_with_client(client, events)

    async def _send_events_with_client(self, client: httpx.AsyncClient, events: list[QueuedEvent]) -> None:
        """POST one batch; on oversized payloads, bisect and retry recursively.

        Retryable failures (408, 5xx, transport errors) are routed to the DLQ.
        """
        if not events:
            return

        payload = build_payload(events, source_client_version=self._source_client_version, session_id=self._session_id)
        try:
            response = await client.post(NEMO_TELEMETRY_ENDPOINT, json=payload)
            # 2xx, 400, 422 are all considered complete (no retry)
            # 400/422 indicate bad payload which retrying won't fix
            if response.status_code in (400, 422) or response.is_success:
                return
            # 413 (payload too large) - split and retry
            if response.status_code == 413:
                if len(events) == 1:
                    # Can't split further, drop the event
                    return
                mid = len(events) // 2
                await self._send_events_with_client(client, events[:mid])
                await self._send_events_with_client(client, events[mid:])
                return
            if response.status_code == 408 or response.status_code >= 500:
                self._add_to_dlq(events)
        except httpx.HTTPError:
            self._add_to_dlq(events)

    def _add_to_dlq(self, events: list[QueuedEvent]) -> None:
        """Record a failed attempt for each event; drop events that exhausted max_retries."""
        for queued in events:
            queued.retry_count += 1
            if queued.retry_count > self._max_retries:
                continue
            self._dlq.append(queued)
@@ -0,0 +1,73 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import logging
7
+
8
+ from pydantic import BaseModel, computed_field
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class TokenUsageStats(BaseModel):
    """Running totals of prompt (input) and completion (output) tokens."""

    input_tokens: int = 0
    output_tokens: int = 0

    @computed_field
    def total_tokens(self) -> int:
        """Combined input and output token count (included in serialized output)."""
        return self.output_tokens + self.input_tokens

    @property
    def has_usage(self) -> bool:
        """Whether any tokens have been recorded so far."""
        return self.total_tokens > 0

    def extend(self, *, input_tokens: int, output_tokens: int) -> None:
        """Add the given token counts to the running totals."""
        self.input_tokens = self.input_tokens + input_tokens
        self.output_tokens = self.output_tokens + output_tokens
28
+
29
+
30
class RequestUsageStats(BaseModel):
    """Running counts of successful and failed model requests."""

    successful_requests: int = 0
    failed_requests: int = 0

    @computed_field
    def total_requests(self) -> int:
        """Combined successful and failed request count (included in serialized output)."""
        return self.failed_requests + self.successful_requests

    @property
    def has_usage(self) -> bool:
        """Whether any requests have been recorded so far."""
        return self.total_requests > 0

    def extend(self, *, successful_requests: int, failed_requests: int) -> None:
        """Add the given request counts to the running totals."""
        self.successful_requests = self.successful_requests + successful_requests
        self.failed_requests = self.failed_requests + failed_requests
45
+
46
+
47
class ModelUsageStats(BaseModel):
    """Aggregated token and request usage for a single model."""

    token_usage: TokenUsageStats = TokenUsageStats()
    request_usage: RequestUsageStats = RequestUsageStats()

    @property
    def has_usage(self) -> bool:
        """True only when both token usage AND request usage have been recorded."""
        return all((self.token_usage.has_usage, self.request_usage.has_usage))

    def extend(
        self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
    ) -> None:
        """Merge the given usage deltas into this object; None arguments are skipped."""
        if token_usage is not None:
            self.token_usage.extend(
                input_tokens=token_usage.input_tokens,
                output_tokens=token_usage.output_tokens,
            )
        if request_usage is not None:
            self.request_usage.extend(
                successful_requests=request_usage.successful_requests,
                failed_requests=request_usage.failed_requests,
            )

    def get_usage_stats(self, *, total_time_elapsed: float) -> dict:
        """Return the serialized stats plus throughput rates over *total_time_elapsed* seconds.

        Rates are reported as 0 when no time has elapsed, avoiding division by zero.
        """
        if total_time_elapsed > 0:
            tokens_per_second = int(self.token_usage.total_tokens / total_time_elapsed)
            requests_per_minute = int(self.request_usage.total_requests / total_time_elapsed * 60)
        else:
            tokens_per_second = 0
            requests_per_minute = 0
        return self.model_dump() | {
            "tokens_per_second": tokens_per_second,
            "requests_per_minute": requests_per_minute,
        }
@@ -0,0 +1,101 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Literal
8
+
9
+
10
@dataclass
class ChatMessage:
    """A single message in an LLM chat conversation.

    Covers user prompts, assistant responses, system instructions, and tool
    interactions.

    Attributes:
        role: Sender role — one of 'user', 'assistant', 'system', or 'tool'.
        content: Message body; either plain text or a list of content blocks
            for multimodal messages (e.g. text plus images).
        reasoning_content: Optional reasoning/thinking text from the assistant,
            typically from extended thinking or chain-of-thought models.
        tool_calls: Tool invocations requested by the assistant; each entry
            carries 'id', 'type', and 'function' keys.
        tool_call_id: Links a role='tool' message back to the tool call it
            answers; required for tool messages.
    """

    role: Literal["user", "assistant", "system", "tool"]
    content: str | list[dict[str, Any]] = ""
    reasoning_content: str | None = None
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    tool_call_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the dict shape expected by chat-completion APIs.

        Optional fields are included only when truthy, keeping the payload
        minimal.
        """
        payload: dict[str, Any] = {"role": self.role, "content": self.content}
        optional_fields = (
            ("reasoning_content", self.reasoning_content),
            ("tool_calls", self.tool_calls),
            ("tool_call_id", self.tool_call_id),
        )
        for key, value in optional_fields:
            if value:
                payload[key] = value
        return payload

    @classmethod
    def as_user(cls, content: str | list[dict[str, Any]]) -> ChatMessage:
        """Build a user message."""
        return cls(role="user", content=content)

    @classmethod
    def as_assistant(
        cls,
        content: str = "",
        reasoning_content: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
    ) -> ChatMessage:
        """Build an assistant message; a missing/empty tool_calls becomes a fresh empty list."""
        calls = tool_calls if tool_calls else []
        return cls(
            role="assistant",
            content=content,
            reasoning_content=reasoning_content,
            tool_calls=calls,
        )

    @classmethod
    def as_system(cls, content: str) -> ChatMessage:
        """Build a system message."""
        return cls(role="system", content=content)

    @classmethod
    def as_tool(cls, content: str, tool_call_id: str) -> ChatMessage:
        """Build a tool-response message tied to *tool_call_id*."""
        return cls(role="tool", content=content, tool_call_id=tool_call_id)
81
+
82
+
83
def prompt_to_messages(
    *,
    user_prompt: str,
    system_prompt: str | None = None,
    multi_modal_context: list[dict[str, Any]] | None = None,
) -> list[ChatMessage]:
    """Build the ChatMessage list for a single-turn prompt.

    Args:
        user_prompt (str): A user prompt.
        system_prompt (str, optional): An optional system prompt, prepended
            as a system message when provided.
        multi_modal_context (list, optional): Content blocks (e.g. images)
            placed before the text block inside the user message.

    Returns:
        One or two messages: an optional system message followed by the
        user message.
    """
    if multi_modal_context:
        content: str | list[dict[str, Any]] = [*multi_modal_context, {"type": "text", "text": user_prompt}]
    else:
        content = user_prompt

    messages = [ChatMessage.as_user(content)]
    if system_prompt:
        messages.insert(0, ChatMessage.as_system(system_prompt))
    return messages
@@ -0,0 +1,2 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,65 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from collections import deque
7
+
8
+ from jinja2 import nodes as j_nodes
9
+
10
+
11
def ast_max_depth(node: j_nodes.Node) -> int:
    """Calculate the depth of a Jinja AST from a given node.

    Args:
        node (jinja2.nodes.Node): The starting Jinja2 AST node

    Returns:
        int: The maximum depth of the tree (the starting node alone is depth 1)
    """
    deepest = 0
    # Depth-first traversal with an explicit stack of (node, depth) pairs;
    # only the maximum depth matters, so visit order is irrelevant.
    stack = [(node, 1)]
    while stack:
        current, depth = stack.pop()
        if depth > deepest:
            deepest = depth
        stack.extend((child, depth + 1) for child in current.iter_child_nodes())
    return deepest
35
+
36
+
37
def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int:
    """Count the number of nodes which descend from the given node.

    Args:
        ast (jinja2.nodes.Node): The starting Jinja2 AST node
        only_type (type[jinja2.nodes.Node] | None): If specified, then only
            nodes of this type will be counted.

    Returns:
        int: The number of nodes descended from the given node.
    """
    if only_type is None:
        # Counting every node type is the same as counting Node instances.
        only_type = j_nodes.Node
    # find_all yields matching descendants lazily; count them without
    # materializing an intermediate list.
    return sum(1 for _ in ast.find_all(only_type))
52
+
53
+
54
def ast_count_name_references(ast: j_nodes.Node, name: str) -> int:
    """Count the nodes descended from the given node that refer to *name*.

    Args:
        ast (jinja2.nodes.Node): The starting Jinja2 AST node
        name (str): The identifier to look for.

    Returns:
        int: The number of jinja2.nodes.Name descendants whose name
            field matches the given name.
    """
    # Count lazily instead of building a throwaway list of matching names.
    return sum(1 for node in ast.find_all(j_nodes.Name) if node.name == name)