nvidia-nat 1.3.0rc1__py3-none-any.whl → 1.3.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/prompt_optimizer/register.py +2 -2
- nat/agent/react_agent/register.py +20 -21
- nat/agent/rewoo_agent/register.py +18 -20
- nat/agent/tool_calling_agent/register.py +7 -3
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +31 -18
- nat/builder/component_utils.py +1 -1
- nat/builder/context.py +22 -6
- nat/builder/function.py +3 -2
- nat/builder/workflow_builder.py +46 -3
- nat/cli/commands/mcp/mcp.py +6 -6
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -12
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +54 -10
- nat/cli/entrypoint.py +9 -1
- nat/cli/main.py +3 -0
- nat/data_models/api_server.py +143 -66
- nat/data_models/config.py +1 -1
- nat/data_models/span.py +41 -3
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +2 -2
- nat/front_ends/console/console_front_end_plugin.py +11 -2
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +5 -35
- nat/front_ends/fastapi/message_validator.py +3 -1
- nat/observability/exporter/span_exporter.py +34 -14
- nat/observability/register.py +16 -0
- nat/profiler/decorators/framework_wrapper.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/runtime/runner.py +103 -6
- nat/runtime/session.py +27 -1
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/utils/decorators.py +210 -0
- nat/utils/type_converter.py +8 -0
- nvidia_nat-1.3.0rc3.dist-info/METADATA +195 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/RECORD +46 -45
- nvidia_nat-1.3.0rc1.dist-info/METADATA +0 -391
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/top_level.txt +0 -0
|
@@ -27,6 +27,50 @@ from jinja2 import FileSystemLoader
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
28
28
|
|
|
29
29
|
|
|
30
|
+
def _get_nat_version() -> str | None:
|
|
31
|
+
"""
|
|
32
|
+
Get the current NAT version.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
str: The NAT version intended for use in a dependency string.
|
|
36
|
+
None: If the NAT version is not found.
|
|
37
|
+
"""
|
|
38
|
+
from nat.cli.entrypoint import get_version
|
|
39
|
+
|
|
40
|
+
current_version = get_version()
|
|
41
|
+
if current_version == "unknown":
|
|
42
|
+
return None
|
|
43
|
+
|
|
44
|
+
version_parts = current_version.split(".")
|
|
45
|
+
if len(version_parts) < 3:
|
|
46
|
+
# If the version somehow doesn't have three parts, return the full version
|
|
47
|
+
return current_version
|
|
48
|
+
|
|
49
|
+
patch = version_parts[2]
|
|
50
|
+
try:
|
|
51
|
+
# If the patch is a number, keep only the major and minor parts
|
|
52
|
+
# Useful for stable releases and adheres to semantic versioning
|
|
53
|
+
_ = int(patch)
|
|
54
|
+
digits_to_keep = 2
|
|
55
|
+
except ValueError:
|
|
56
|
+
# If the patch is not a number, keep all three digits
|
|
57
|
+
# Useful for pre-release versions (and nightly builds)
|
|
58
|
+
digits_to_keep = 3
|
|
59
|
+
|
|
60
|
+
return ".".join(version_parts[:digits_to_keep])
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _is_nat_version_prerelease() -> bool:
|
|
64
|
+
"""
|
|
65
|
+
Check if the NAT version is a prerelease.
|
|
66
|
+
"""
|
|
67
|
+
version = _get_nat_version()
|
|
68
|
+
if version is None:
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
return len(version.split(".")) >= 3
|
|
72
|
+
|
|
73
|
+
|
|
30
74
|
def _get_nat_dependency(versioned: bool = True) -> str:
|
|
31
75
|
"""
|
|
32
76
|
Get the NAT dependency string with version.
|
|
@@ -44,16 +88,12 @@ def _get_nat_dependency(versioned: bool = True) -> str:
|
|
|
44
88
|
logger.debug("Using unversioned NAT dependency: %s", dependency)
|
|
45
89
|
return dependency
|
|
46
90
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if current_version == "unknown":
|
|
51
|
-
logger.warning("Could not detect NAT version, using unversioned dependency")
|
|
91
|
+
version = _get_nat_version()
|
|
92
|
+
if version is None:
|
|
93
|
+
logger.debug("Could not detect NAT version, using unversioned dependency: %s", dependency)
|
|
52
94
|
return dependency
|
|
53
95
|
|
|
54
|
-
|
|
55
|
-
major_minor = ".".join(current_version.split(".")[:2])
|
|
56
|
-
dependency += f"~={major_minor}"
|
|
96
|
+
dependency += f"~={version}"
|
|
57
97
|
logger.debug("Using NAT dependency: %s", dependency)
|
|
58
98
|
return dependency
|
|
59
99
|
|
|
@@ -219,12 +259,16 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
219
259
|
install_cmd = ['uv', 'pip', 'install', '-e', str(new_workflow_dir)]
|
|
220
260
|
else:
|
|
221
261
|
install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
|
|
262
|
+
if _is_nat_version_prerelease():
|
|
263
|
+
install_cmd.insert(2, "--pre")
|
|
264
|
+
|
|
265
|
+
python_safe_workflow_name = workflow_name.replace("-", "_")
|
|
222
266
|
|
|
223
267
|
# List of templates and their destinations
|
|
224
268
|
files_to_render = {
|
|
225
269
|
'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
|
|
226
270
|
'register.py.j2': base_dir / 'register.py',
|
|
227
|
-
'workflow.py.j2': base_dir / f'{
|
|
271
|
+
'workflow.py.j2': base_dir / f'{python_safe_workflow_name}.py',
|
|
228
272
|
'__init__.py.j2': base_dir / '__init__.py',
|
|
229
273
|
'config.yml.j2': configs_dir / 'config.yml',
|
|
230
274
|
}
|
|
@@ -233,7 +277,7 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
233
277
|
context = {
|
|
234
278
|
'editable': editable,
|
|
235
279
|
'workflow_name': workflow_name,
|
|
236
|
-
'python_safe_workflow_name':
|
|
280
|
+
'python_safe_workflow_name': python_safe_workflow_name,
|
|
237
281
|
'package_name': package_name,
|
|
238
282
|
'rel_path_to_repo_root': rel_path_to_repo_root,
|
|
239
283
|
'workflow_class_name': f"{_generate_valid_classname(workflow_name)}FunctionConfig",
|
nat/cli/entrypoint.py
CHANGED
|
@@ -29,6 +29,7 @@ import time
|
|
|
29
29
|
|
|
30
30
|
import click
|
|
31
31
|
import nest_asyncio
|
|
32
|
+
from dotenv import load_dotenv
|
|
32
33
|
|
|
33
34
|
from nat.utils.log_levels import LOG_LEVELS
|
|
34
35
|
|
|
@@ -45,6 +46,9 @@ from .commands.uninstall import uninstall_command
|
|
|
45
46
|
from .commands.validate import validate_command
|
|
46
47
|
from .commands.workflow.workflow import workflow_command
|
|
47
48
|
|
|
49
|
+
# Load environment variables from .env file, if it exists
|
|
50
|
+
load_dotenv()
|
|
51
|
+
|
|
48
52
|
# Apply at the beginning of the file to avoid issues with asyncio
|
|
49
53
|
nest_asyncio.apply()
|
|
50
54
|
|
|
@@ -52,7 +56,11 @@ nest_asyncio.apply()
|
|
|
52
56
|
def setup_logging(log_level: str):
|
|
53
57
|
"""Configure logging with the specified level"""
|
|
54
58
|
numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
|
|
55
|
-
logging.basicConfig(
|
|
59
|
+
logging.basicConfig(
|
|
60
|
+
level=numeric_level,
|
|
61
|
+
format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
62
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
63
|
+
)
|
|
56
64
|
return numeric_level
|
|
57
65
|
|
|
58
66
|
|
nat/cli/main.py
CHANGED
nat/data_models/api_server.py
CHANGED
|
@@ -28,6 +28,7 @@ from pydantic import HttpUrl
|
|
|
28
28
|
from pydantic import conlist
|
|
29
29
|
from pydantic import field_serializer
|
|
30
30
|
from pydantic import field_validator
|
|
31
|
+
from pydantic import model_validator
|
|
31
32
|
from pydantic_core.core_schema import ValidationInfo
|
|
32
33
|
|
|
33
34
|
from nat.data_models.interactive import HumanPrompt
|
|
@@ -36,6 +37,15 @@ from nat.utils.type_converter import GlobalTypeConverter
|
|
|
36
37
|
FINISH_REASONS = frozenset({'stop', 'length', 'tool_calls', 'content_filter', 'function_call'})
|
|
37
38
|
|
|
38
39
|
|
|
40
|
+
class UserMessageContentRoleType(str, Enum):
|
|
41
|
+
"""
|
|
42
|
+
Enum representing chat message roles in API requests and responses.
|
|
43
|
+
"""
|
|
44
|
+
USER = "user"
|
|
45
|
+
ASSISTANT = "assistant"
|
|
46
|
+
SYSTEM = "system"
|
|
47
|
+
|
|
48
|
+
|
|
39
49
|
class Request(BaseModel):
|
|
40
50
|
"""
|
|
41
51
|
Request is a data model that represents HTTP request attributes.
|
|
@@ -108,18 +118,10 @@ UserContent = typing.Annotated[TextContent | ImageContent | AudioContent, Discri
|
|
|
108
118
|
|
|
109
119
|
class Message(BaseModel):
|
|
110
120
|
content: str | list[UserContent]
|
|
111
|
-
role:
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
class ChatRequest(BaseModel):
|
|
115
|
-
"""
|
|
116
|
-
ChatRequest is a data model that represents a request to the NAT chat API.
|
|
117
|
-
Fully compatible with OpenAI Chat Completions API specification.
|
|
118
|
-
"""
|
|
121
|
+
role: UserMessageContentRoleType
|
|
119
122
|
|
|
120
|
-
# Required fields
|
|
121
|
-
messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
|
|
122
123
|
|
|
124
|
+
class ChatRequestOptionals(BaseModel):
|
|
123
125
|
# Optional fields (OpenAI Chat Completions API compatible)
|
|
124
126
|
model: str | None = Field(default=None, description="name of the model to use")
|
|
125
127
|
frequency_penalty: float | None = Field(default=0.0,
|
|
@@ -144,6 +146,16 @@ class ChatRequest(BaseModel):
|
|
|
144
146
|
parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
|
|
145
147
|
user: str | None = Field(default=None, description="Unique identifier representing end-user")
|
|
146
148
|
|
|
149
|
+
|
|
150
|
+
class ChatRequest(ChatRequestOptionals):
|
|
151
|
+
"""
|
|
152
|
+
ChatRequest is a data model that represents a request to the NAT chat API.
|
|
153
|
+
Fully compatible with OpenAI Chat Completions API specification.
|
|
154
|
+
"""
|
|
155
|
+
|
|
156
|
+
# Required fields
|
|
157
|
+
messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
|
|
158
|
+
|
|
147
159
|
model_config = ConfigDict(extra="allow",
|
|
148
160
|
json_schema_extra={
|
|
149
161
|
"example": {
|
|
@@ -164,7 +176,7 @@ class ChatRequest(BaseModel):
|
|
|
164
176
|
max_tokens: int | None = None,
|
|
165
177
|
top_p: float | None = None) -> "ChatRequest":
|
|
166
178
|
|
|
167
|
-
return ChatRequest(messages=[Message(content=data, role=
|
|
179
|
+
return ChatRequest(messages=[Message(content=data, role=UserMessageContentRoleType.USER)],
|
|
168
180
|
model=model,
|
|
169
181
|
temperature=temperature,
|
|
170
182
|
max_tokens=max_tokens,
|
|
@@ -178,38 +190,85 @@ class ChatRequest(BaseModel):
|
|
|
178
190
|
max_tokens: int | None = None,
|
|
179
191
|
top_p: float | None = None) -> "ChatRequest":
|
|
180
192
|
|
|
181
|
-
return ChatRequest(messages=[Message(content=content, role=
|
|
193
|
+
return ChatRequest(messages=[Message(content=content, role=UserMessageContentRoleType.USER)],
|
|
182
194
|
model=model,
|
|
183
195
|
temperature=temperature,
|
|
184
196
|
max_tokens=max_tokens,
|
|
185
197
|
top_p=top_p)
|
|
186
198
|
|
|
187
199
|
|
|
200
|
+
class ChatRequestOrMessage(ChatRequestOptionals):
|
|
201
|
+
"""
|
|
202
|
+
ChatRequestOrMessage is a data model that represents either a conversation or a string input.
|
|
203
|
+
This is useful for functions that can handle either type of input.
|
|
204
|
+
|
|
205
|
+
`messages` is compatible with the OpenAI Chat Completions API specification.
|
|
206
|
+
|
|
207
|
+
`input_string` is a string input that can be used for functions that do not require a conversation.
|
|
208
|
+
"""
|
|
209
|
+
|
|
210
|
+
messages: typing.Annotated[list[Message] | None, conlist(Message, min_length=1)] = Field(
|
|
211
|
+
default=None, description="The conversation messages to process.")
|
|
212
|
+
|
|
213
|
+
input_string: str | None = Field(default=None, alias="input_message", description="The input message to process.")
|
|
214
|
+
|
|
215
|
+
@property
|
|
216
|
+
def is_string(self) -> bool:
|
|
217
|
+
return self.input_string is not None
|
|
218
|
+
|
|
219
|
+
@property
|
|
220
|
+
def is_conversation(self) -> bool:
|
|
221
|
+
return self.messages is not None
|
|
222
|
+
|
|
223
|
+
@model_validator(mode="after")
|
|
224
|
+
def validate_messages_or_input_string(self):
|
|
225
|
+
if self.messages is not None and self.input_string is not None:
|
|
226
|
+
raise ValueError("Either messages or input_message/input_string must be provided, not both")
|
|
227
|
+
if self.messages is None and self.input_string is None:
|
|
228
|
+
raise ValueError("Either messages or input_message/input_string must be provided")
|
|
229
|
+
if self.input_string is not None:
|
|
230
|
+
extra_fields = self.model_dump(exclude={"input_string"}, exclude_none=True, exclude_unset=True)
|
|
231
|
+
if len(extra_fields) > 0:
|
|
232
|
+
raise ValueError("no extra fields are permitted when input_message/input_string is provided")
|
|
233
|
+
return self
|
|
234
|
+
|
|
235
|
+
|
|
188
236
|
class ChoiceMessage(BaseModel):
|
|
189
237
|
content: str | None = None
|
|
190
|
-
role:
|
|
238
|
+
role: UserMessageContentRoleType | None = None
|
|
191
239
|
|
|
192
240
|
|
|
193
241
|
class ChoiceDelta(BaseModel):
|
|
194
242
|
"""Delta object for streaming responses (OpenAI-compatible)"""
|
|
195
243
|
content: str | None = None
|
|
196
|
-
role:
|
|
244
|
+
role: UserMessageContentRoleType | None = None
|
|
197
245
|
|
|
198
246
|
|
|
199
|
-
class
|
|
247
|
+
class ChoiceBase(BaseModel):
|
|
248
|
+
"""Base choice model with common fields for both streaming and non-streaming responses"""
|
|
200
249
|
model_config = ConfigDict(extra="allow")
|
|
201
|
-
|
|
202
|
-
message: ChoiceMessage | None = None
|
|
203
|
-
delta: ChoiceDelta | None = None
|
|
204
250
|
finish_reason: typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None = None
|
|
205
251
|
index: int
|
|
206
|
-
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class ChatResponseChoice(ChoiceBase):
|
|
255
|
+
"""Choice model for non-streaming responses - contains message field"""
|
|
256
|
+
message: ChoiceMessage
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class ChatResponseChunkChoice(ChoiceBase):
|
|
260
|
+
"""Choice model for streaming responses - contains delta field"""
|
|
261
|
+
delta: ChoiceDelta
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
# Backward compatibility alias
|
|
265
|
+
Choice = ChatResponseChoice
|
|
207
266
|
|
|
208
267
|
|
|
209
268
|
class Usage(BaseModel):
|
|
210
|
-
prompt_tokens: int
|
|
211
|
-
completion_tokens: int
|
|
212
|
-
total_tokens: int
|
|
269
|
+
prompt_tokens: int | None = None
|
|
270
|
+
completion_tokens: int | None = None
|
|
271
|
+
total_tokens: int | None = None
|
|
213
272
|
|
|
214
273
|
|
|
215
274
|
class ResponseSerializable(abc.ABC):
|
|
@@ -245,10 +304,10 @@ class ChatResponse(ResponseBaseModelOutput):
|
|
|
245
304
|
model_config = ConfigDict(extra="allow")
|
|
246
305
|
id: str
|
|
247
306
|
object: str = "chat.completion"
|
|
248
|
-
model: str = ""
|
|
307
|
+
model: str = "unknown-model"
|
|
249
308
|
created: datetime.datetime
|
|
250
|
-
choices: list[
|
|
251
|
-
usage: Usage
|
|
309
|
+
choices: list[ChatResponseChoice]
|
|
310
|
+
usage: Usage
|
|
252
311
|
system_fingerprint: str | None = None
|
|
253
312
|
service_tier: typing.Literal["scale", "default"] | None = None
|
|
254
313
|
|
|
@@ -264,14 +323,14 @@ class ChatResponse(ResponseBaseModelOutput):
|
|
|
264
323
|
object_: str | None = None,
|
|
265
324
|
model: str | None = None,
|
|
266
325
|
created: datetime.datetime | None = None,
|
|
267
|
-
usage: Usage
|
|
326
|
+
usage: Usage) -> "ChatResponse":
|
|
268
327
|
|
|
269
328
|
if id_ is None:
|
|
270
329
|
id_ = str(uuid.uuid4())
|
|
271
330
|
if object_ is None:
|
|
272
331
|
object_ = "chat.completion"
|
|
273
332
|
if model is None:
|
|
274
|
-
model = ""
|
|
333
|
+
model = "unknown-model"
|
|
275
334
|
if created is None:
|
|
276
335
|
created = datetime.datetime.now(datetime.UTC)
|
|
277
336
|
|
|
@@ -279,7 +338,12 @@ class ChatResponse(ResponseBaseModelOutput):
|
|
|
279
338
|
object=object_,
|
|
280
339
|
model=model,
|
|
281
340
|
created=created,
|
|
282
|
-
choices=[
|
|
341
|
+
choices=[
|
|
342
|
+
ChatResponseChoice(index=0,
|
|
343
|
+
message=ChoiceMessage(content=data,
|
|
344
|
+
role=UserMessageContentRoleType.ASSISTANT),
|
|
345
|
+
finish_reason="stop")
|
|
346
|
+
],
|
|
283
347
|
usage=usage)
|
|
284
348
|
|
|
285
349
|
|
|
@@ -293,9 +357,9 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
293
357
|
model_config = ConfigDict(extra="allow")
|
|
294
358
|
|
|
295
359
|
id: str
|
|
296
|
-
choices: list[
|
|
360
|
+
choices: list[ChatResponseChunkChoice]
|
|
297
361
|
created: datetime.datetime
|
|
298
|
-
model: str = ""
|
|
362
|
+
model: str = "unknown-model"
|
|
299
363
|
object: str = "chat.completion.chunk"
|
|
300
364
|
system_fingerprint: str | None = None
|
|
301
365
|
service_tier: typing.Literal["scale", "default"] | None = None
|
|
@@ -319,12 +383,18 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
319
383
|
if created is None:
|
|
320
384
|
created = datetime.datetime.now(datetime.UTC)
|
|
321
385
|
if model is None:
|
|
322
|
-
model = ""
|
|
386
|
+
model = "unknown-model"
|
|
323
387
|
if object_ is None:
|
|
324
388
|
object_ = "chat.completion.chunk"
|
|
325
389
|
|
|
326
390
|
return ChatResponseChunk(id=id_,
|
|
327
|
-
choices=[
|
|
391
|
+
choices=[
|
|
392
|
+
ChatResponseChunkChoice(index=0,
|
|
393
|
+
delta=ChoiceDelta(
|
|
394
|
+
content=data,
|
|
395
|
+
role=UserMessageContentRoleType.ASSISTANT),
|
|
396
|
+
finish_reason="stop")
|
|
397
|
+
],
|
|
328
398
|
created=created,
|
|
329
399
|
model=model,
|
|
330
400
|
object=object_)
|
|
@@ -335,7 +405,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
335
405
|
id_: str | None = None,
|
|
336
406
|
created: datetime.datetime | None = None,
|
|
337
407
|
model: str | None = None,
|
|
338
|
-
role:
|
|
408
|
+
role: UserMessageContentRoleType | None = None,
|
|
339
409
|
finish_reason: str | None = None,
|
|
340
410
|
usage: Usage | None = None,
|
|
341
411
|
system_fingerprint: str | None = None) -> "ChatResponseChunk":
|
|
@@ -345,7 +415,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
345
415
|
if created is None:
|
|
346
416
|
created = datetime.datetime.now(datetime.UTC)
|
|
347
417
|
if model is None:
|
|
348
|
-
model = ""
|
|
418
|
+
model = "unknown-model"
|
|
349
419
|
|
|
350
420
|
delta = ChoiceDelta(content=content, role=role) if content is not None or role is not None else ChoiceDelta()
|
|
351
421
|
|
|
@@ -353,7 +423,14 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
353
423
|
|
|
354
424
|
return ChatResponseChunk(
|
|
355
425
|
id=id_,
|
|
356
|
-
choices=[
|
|
426
|
+
choices=[
|
|
427
|
+
ChatResponseChunkChoice(
|
|
428
|
+
index=0,
|
|
429
|
+
delta=delta,
|
|
430
|
+
finish_reason=typing.cast(
|
|
431
|
+
typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None,
|
|
432
|
+
final_finish_reason))
|
|
433
|
+
],
|
|
357
434
|
created=created,
|
|
358
435
|
model=model,
|
|
359
436
|
object="chat.completion.chunk",
|
|
@@ -398,11 +475,6 @@ class GenerateResponse(BaseModel):
|
|
|
398
475
|
value: str | None = "default"
|
|
399
476
|
|
|
400
477
|
|
|
401
|
-
class UserMessageContentRoleType(str, Enum):
|
|
402
|
-
USER = "user"
|
|
403
|
-
ASSISTANT = "assistant"
|
|
404
|
-
|
|
405
|
-
|
|
406
478
|
class WebSocketMessageType(str, Enum):
|
|
407
479
|
"""
|
|
408
480
|
WebSocketMessageType is an Enum that represents WebSocket Message types.
|
|
@@ -622,12 +694,42 @@ GlobalTypeConverter.register_converter(_nat_chat_request_to_string)
|
|
|
622
694
|
|
|
623
695
|
|
|
624
696
|
def _string_to_nat_chat_request(data: str) -> ChatRequest:
|
|
625
|
-
return ChatRequest.from_string(data, model="")
|
|
697
|
+
return ChatRequest.from_string(data, model="unknown-model")
|
|
626
698
|
|
|
627
699
|
|
|
628
700
|
GlobalTypeConverter.register_converter(_string_to_nat_chat_request)
|
|
629
701
|
|
|
630
702
|
|
|
703
|
+
def _chat_request_or_message_to_chat_request(data: ChatRequestOrMessage) -> ChatRequest:
|
|
704
|
+
if data.input_string is not None:
|
|
705
|
+
return _string_to_nat_chat_request(data.input_string)
|
|
706
|
+
return ChatRequest(**data.model_dump(exclude={"input_string"}))
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
GlobalTypeConverter.register_converter(_chat_request_or_message_to_chat_request)
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
def _chat_request_to_chat_request_or_message(data: ChatRequest) -> ChatRequestOrMessage:
|
|
713
|
+
return ChatRequestOrMessage(**data.model_dump(by_alias=True))
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
GlobalTypeConverter.register_converter(_chat_request_to_chat_request_or_message)
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
def _chat_request_or_message_to_string(data: ChatRequestOrMessage) -> str:
|
|
720
|
+
return data.input_string or ""
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
GlobalTypeConverter.register_converter(_chat_request_or_message_to_string)
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def _string_to_chat_request_or_message(data: str) -> ChatRequestOrMessage:
|
|
727
|
+
return ChatRequestOrMessage(input_message=data)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
GlobalTypeConverter.register_converter(_string_to_chat_request_or_message)
|
|
731
|
+
|
|
732
|
+
|
|
631
733
|
# ======== ChatResponse Converters ========
|
|
632
734
|
def _nat_chat_response_to_string(data: ChatResponse) -> str:
|
|
633
735
|
if data.choices and data.choices[0].message:
|
|
@@ -654,22 +756,12 @@ def _string_to_nat_chat_response(data: str) -> ChatResponse:
|
|
|
654
756
|
GlobalTypeConverter.register_converter(_string_to_nat_chat_response)
|
|
655
757
|
|
|
656
758
|
|
|
657
|
-
def _chat_response_to_chat_response_chunk(data: ChatResponse) -> ChatResponseChunk:
|
|
658
|
-
# Preserve original message structure for backward compatibility
|
|
659
|
-
return ChatResponseChunk(id=data.id, choices=data.choices, created=data.created, model=data.model)
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk)
|
|
663
|
-
|
|
664
|
-
|
|
665
759
|
# ======== ChatResponseChunk Converters ========
|
|
666
760
|
def _chat_response_chunk_to_string(data: ChatResponseChunk) -> str:
|
|
667
761
|
if data.choices and len(data.choices) > 0:
|
|
668
762
|
choice = data.choices[0]
|
|
669
763
|
if choice.delta and choice.delta.content:
|
|
670
764
|
return choice.delta.content
|
|
671
|
-
if choice.message and choice.message.content:
|
|
672
|
-
return choice.message.content
|
|
673
765
|
return ""
|
|
674
766
|
|
|
675
767
|
|
|
@@ -685,21 +777,6 @@ def _string_to_nat_chat_response_chunk(data: str) -> ChatResponseChunk:
|
|
|
685
777
|
|
|
686
778
|
GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk)
|
|
687
779
|
|
|
688
|
-
|
|
689
|
-
# ======== AINodeMessageChunk Converters ========
|
|
690
|
-
def _ai_message_chunk_to_nat_chat_response_chunk(data) -> ChatResponseChunk:
|
|
691
|
-
'''Converts LangChain/LangGraph AINodeMessageChunk to ChatResponseChunk'''
|
|
692
|
-
content = ""
|
|
693
|
-
if hasattr(data, 'content') and data.content is not None:
|
|
694
|
-
content = str(data.content)
|
|
695
|
-
elif hasattr(data, 'text') and data.text is not None:
|
|
696
|
-
content = str(data.text)
|
|
697
|
-
elif hasattr(data, 'message') and data.message is not None:
|
|
698
|
-
content = str(data.message)
|
|
699
|
-
|
|
700
|
-
return ChatResponseChunk.create_streaming_chunk(content=content, role="assistant", finish_reason=None)
|
|
701
|
-
|
|
702
|
-
|
|
703
780
|
# Compatibility aliases with previous releases
|
|
704
781
|
AIQChatRequest = ChatRequest
|
|
705
782
|
AIQChoiceMessage = ChoiceMessage
|
nat/data_models/config.py
CHANGED
|
@@ -187,7 +187,7 @@ class TelemetryConfig(BaseModel):
|
|
|
187
187
|
|
|
188
188
|
class GeneralConfig(BaseModel):
|
|
189
189
|
|
|
190
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
190
|
+
model_config = ConfigDict(protected_namespaces=(), extra="forbid")
|
|
191
191
|
|
|
192
192
|
use_uvloop: bool | None = Field(
|
|
193
193
|
default=None,
|
nat/data_models/span.py
CHANGED
|
@@ -128,10 +128,48 @@ class SpanStatus(BaseModel):
|
|
|
128
128
|
message: str | None = Field(default=None, description="The status message of the span.")
|
|
129
129
|
|
|
130
130
|
|
|
131
|
+
def _generate_nonzero_trace_id() -> int:
|
|
132
|
+
"""Generate a non-zero 128-bit trace ID."""
|
|
133
|
+
return uuid.uuid4().int
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _generate_nonzero_span_id() -> int:
|
|
137
|
+
"""Generate a non-zero 64-bit span ID."""
|
|
138
|
+
return uuid.uuid4().int >> 64
|
|
139
|
+
|
|
140
|
+
|
|
131
141
|
class SpanContext(BaseModel):
|
|
132
|
-
trace_id: int = Field(default_factory=
|
|
133
|
-
|
|
134
|
-
|
|
142
|
+
trace_id: int = Field(default_factory=_generate_nonzero_trace_id,
|
|
143
|
+
description="The OTel-syle 128-bit trace ID of the span.")
|
|
144
|
+
span_id: int = Field(default_factory=_generate_nonzero_span_id,
|
|
145
|
+
description="The OTel-syle 64-bit span ID of the span.")
|
|
146
|
+
|
|
147
|
+
@field_validator("trace_id", mode="before")
|
|
148
|
+
@classmethod
|
|
149
|
+
def _validate_trace_id(cls, v: int | str | None) -> int:
|
|
150
|
+
"""Regenerate if trace_id is None; raise an exception if trace_id is invalid;"""
|
|
151
|
+
if isinstance(v, str):
|
|
152
|
+
v = uuid.UUID(v).int
|
|
153
|
+
if isinstance(v, type(None)):
|
|
154
|
+
v = _generate_nonzero_trace_id()
|
|
155
|
+
if v <= 0 or v >> 128:
|
|
156
|
+
raise ValueError(f"Invalid trace_id: must be a non-zero 128-bit integer, got {v}")
|
|
157
|
+
return v
|
|
158
|
+
|
|
159
|
+
@field_validator("span_id", mode="before")
|
|
160
|
+
@classmethod
|
|
161
|
+
def _validate_span_id(cls, v: int | str | None) -> int:
|
|
162
|
+
"""Regenerate if span_id is None; raise an exception if span_id is invalid;"""
|
|
163
|
+
if isinstance(v, str):
|
|
164
|
+
try:
|
|
165
|
+
v = int(v, 16)
|
|
166
|
+
except ValueError:
|
|
167
|
+
raise ValueError(f"span_id unable to be parsed: {v}")
|
|
168
|
+
if isinstance(v, type(None)):
|
|
169
|
+
v = _generate_nonzero_span_id()
|
|
170
|
+
if v <= 0 or v >> 64:
|
|
171
|
+
raise ValueError(f"Invalid span_id: must be a non-zero 64-bit integer, got {v}")
|
|
172
|
+
return v
|
|
135
173
|
|
|
136
174
|
|
|
137
175
|
class Span(BaseModel):
|
|
@@ -46,7 +46,7 @@ async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig
|
|
|
46
46
|
|
|
47
47
|
from pydantic import BaseModel
|
|
48
48
|
|
|
49
|
-
executable_fn: Function = builder.get_function(name=config.augmented_fn)
|
|
49
|
+
executable_fn: Function = await builder.get_function(name=config.augmented_fn)
|
|
50
50
|
|
|
51
51
|
if config.scorer:
|
|
52
52
|
scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
|
|
@@ -98,8 +98,8 @@ async def register_ttc_tool_wrapper_function(
|
|
|
98
98
|
|
|
99
99
|
augmented_function_desc = config.tool_description
|
|
100
100
|
|
|
101
|
-
fn_input_schema: BaseModel = augmented_function.input_schema
|
|
102
|
-
fn_output_schema: BaseModel = augmented_function.single_output_schema
|
|
101
|
+
fn_input_schema: type[BaseModel] = augmented_function.input_schema
|
|
102
|
+
fn_output_schema: type[BaseModel] | type[None] = augmented_function.single_output_schema
|
|
103
103
|
|
|
104
104
|
runnable_llm = input_llm.with_structured_output(schema=fn_input_schema)
|
|
105
105
|
|
|
@@ -95,5 +95,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
95
95
|
else:
|
|
96
96
|
assert False, "Should not reach here. Should have been caught by pre_run"
|
|
97
97
|
|
|
98
|
-
|
|
99
|
-
|
|
98
|
+
line = f"{'-' * 50}"
|
|
99
|
+
prefix = f"{line}\n{Fore.GREEN}Workflow Result:\n"
|
|
100
|
+
suffix = f"{Fore.RESET}\n{line}"
|
|
101
|
+
|
|
102
|
+
logger.info(f"{prefix}%s{suffix}", runner_outputs)
|
|
103
|
+
|
|
104
|
+
# (handler is a stream handler) => (level > INFO)
|
|
105
|
+
effective_level_too_high = all(
|
|
106
|
+
type(h) is not logging.StreamHandler or h.level > logging.INFO for h in logging.getLogger().handlers)
|
|
107
|
+
if effective_level_too_high:
|
|
108
|
+
print(f"{prefix}{runner_outputs}{suffix}")
|
|
@@ -24,4 +24,4 @@ class HTTPAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
24
24
|
async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
|
|
25
25
|
|
|
26
26
|
raise NotImplementedError(f"Authentication method '{method}' is not supported by the HTTP frontend."
|
|
27
|
-
f" Do you have
|
|
27
|
+
f" Do you have WebSockets enabled?")
|