nvidia-nat 1.3.0rc1__py3-none-any.whl → 1.3.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. nat/agent/prompt_optimizer/register.py +2 -2
  2. nat/agent/react_agent/register.py +20 -21
  3. nat/agent/rewoo_agent/register.py +18 -20
  4. nat/agent/tool_calling_agent/register.py +7 -3
  5. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +31 -18
  6. nat/builder/component_utils.py +1 -1
  7. nat/builder/context.py +22 -6
  8. nat/builder/function.py +3 -2
  9. nat/builder/workflow_builder.py +46 -3
  10. nat/cli/commands/mcp/mcp.py +6 -6
  11. nat/cli/commands/workflow/templates/config.yml.j2 +14 -12
  12. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  13. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  14. nat/cli/commands/workflow/workflow_commands.py +54 -10
  15. nat/cli/entrypoint.py +9 -1
  16. nat/cli/main.py +3 -0
  17. nat/data_models/api_server.py +143 -66
  18. nat/data_models/config.py +1 -1
  19. nat/data_models/span.py +41 -3
  20. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  21. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +2 -2
  22. nat/front_ends/console/console_front_end_plugin.py +11 -2
  23. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  24. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +5 -35
  25. nat/front_ends/fastapi/message_validator.py +3 -1
  26. nat/observability/exporter/span_exporter.py +34 -14
  27. nat/observability/register.py +16 -0
  28. nat/profiler/decorators/framework_wrapper.py +1 -1
  29. nat/profiler/forecasting/models/linear_model.py +1 -1
  30. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  31. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  32. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  33. nat/runtime/runner.py +103 -6
  34. nat/runtime/session.py +27 -1
  35. nat/tool/memory_tools/add_memory_tool.py +3 -3
  36. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  37. nat/tool/memory_tools/get_memory_tool.py +4 -4
  38. nat/utils/decorators.py +210 -0
  39. nat/utils/type_converter.py +8 -0
  40. nvidia_nat-1.3.0rc3.dist-info/METADATA +195 -0
  41. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/RECORD +46 -45
  42. nvidia_nat-1.3.0rc1.dist-info/METADATA +0 -391
  43. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/WHEEL +0 -0
  44. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/entry_points.txt +0 -0
  45. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  46. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE.md +0 -0
  47. {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc3.dist-info}/top_level.txt +0 -0
nat/cli/commands/workflow/workflow_commands.py CHANGED
@@ -27,6 +27,50 @@ from jinja2 import FileSystemLoader
 logger = logging.getLogger(__name__)


+def _get_nat_version() -> str | None:
+    """
+    Get the current NAT version.
+
+    Returns:
+        str: The NAT version intended for use in a dependency string.
+        None: If the NAT version is not found.
+    """
+    from nat.cli.entrypoint import get_version
+
+    current_version = get_version()
+    if current_version == "unknown":
+        return None
+
+    version_parts = current_version.split(".")
+    if len(version_parts) < 3:
+        # If the version somehow doesn't have three parts, return the full version
+        return current_version
+
+    patch = version_parts[2]
+    try:
+        # If the patch is a number, keep only the major and minor parts
+        # Useful for stable releases and adheres to semantic versioning
+        _ = int(patch)
+        digits_to_keep = 2
+    except ValueError:
+        # If the patch is not a number, keep all three digits
+        # Useful for pre-release versions (and nightly builds)
+        digits_to_keep = 3
+
+    return ".".join(version_parts[:digits_to_keep])
+
+
+def _is_nat_version_prerelease() -> bool:
+    """
+    Check if the NAT version is a prerelease.
+    """
+    version = _get_nat_version()
+    if version is None:
+        return False
+
+    return len(version.split(".")) >= 3
+
+
 def _get_nat_dependency(versioned: bool = True) -> str:
     """
     Get the NAT dependency string with version.
@@ -44,16 +88,12 @@ def _get_nat_dependency(versioned: bool = True) -> str:
         logger.debug("Using unversioned NAT dependency: %s", dependency)
         return dependency

-    # Get the current NAT version
-    from nat.cli.entrypoint import get_version
-    current_version = get_version()
-    if current_version == "unknown":
-        logger.warning("Could not detect NAT version, using unversioned dependency")
+    version = _get_nat_version()
+    if version is None:
+        logger.debug("Could not detect NAT version, using unversioned dependency: %s", dependency)
         return dependency

-    # Extract major.minor (e.g., "1.2.3" -> "1.2")
-    major_minor = ".".join(current_version.split(".")[:2])
-    dependency += f"~={major_minor}"
+    dependency += f"~={version}"
     logger.debug("Using NAT dependency: %s", dependency)
     return dependency
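
As a quick illustration of what the pinning scheme above produces, here is a minimal standalone sketch (a replica for illustration, not the packaged helpers; the nvidia-nat base name is taken from this diff's package):

# Sketch of the version-pinning rule implemented by _get_nat_version()/_get_nat_dependency().
def pin(version: str) -> str:
    parts = version.split(".")
    if len(parts) < 3:
        return f"nvidia-nat~={version}"
    try:
        int(parts[2])  # numeric patch -> stable release
        keep = 2       # pin to major.minor
    except ValueError:
        keep = 3       # pre-release or nightly patch -> keep all three parts
    return "nvidia-nat~=" + ".".join(parts[:keep])

assert pin("1.3.0") == "nvidia-nat~=1.3"          # ~=1.3 allows any 1.x release >= 1.3
assert pin("1.3.0rc3") == "nvidia-nat~=1.3.0rc3"  # ~=1.3.0rc3 allows 1.3.* >= rc3
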
@@ -219,12 +259,16 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
         install_cmd = ['uv', 'pip', 'install', '-e', str(new_workflow_dir)]
     else:
         install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
+    if _is_nat_version_prerelease():
+        install_cmd.insert(2, "--pre")
+
+    python_safe_workflow_name = workflow_name.replace("-", "_")

     # List of templates and their destinations
     files_to_render = {
         'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
         'register.py.j2': base_dir / 'register.py',
-        'workflow.py.j2': base_dir / f'{workflow_name}_function.py',
+        'workflow.py.j2': base_dir / f'{python_safe_workflow_name}.py',
         '__init__.py.j2': base_dir / '__init__.py',
         'config.yml.j2': configs_dir / 'config.yml',
     }
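
Index 2 falls in a different spot on each branch; a quick runnable sketch of the resulting argument lists (directory placeholder abbreviated):

# Where "--pre" lands after install_cmd.insert(2, "--pre") on each branch.
uv_cmd = ['uv', 'pip', 'install', '-e', '<workflow_dir>']
uv_cmd.insert(2, "--pre")
assert uv_cmd == ['uv', 'pip', '--pre', 'install', '-e', '<workflow_dir>']

pip_cmd = ['pip', 'install', '-e', '<workflow_dir>']
pip_cmd.insert(2, "--pre")
assert pip_cmd == ['pip', 'install', '--pre', '-e', '<workflow_dir>']
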
@@ -233,7 +277,7 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
     context = {
         'editable': editable,
         'workflow_name': workflow_name,
-        'python_safe_workflow_name': workflow_name.replace("-", "_"),
+        'python_safe_workflow_name': python_safe_workflow_name,
         'package_name': package_name,
         'rel_path_to_repo_root': rel_path_to_repo_root,
         'workflow_class_name': f"{_generate_valid_classname(workflow_name)}FunctionConfig",
nat/cli/entrypoint.py CHANGED
@@ -29,6 +29,7 @@ import time

 import click
 import nest_asyncio
+from dotenv import load_dotenv

 from nat.utils.log_levels import LOG_LEVELS

@@ -45,6 +46,9 @@ from .commands.uninstall import uninstall_command
 from .commands.validate import validate_command
 from .commands.workflow.workflow import workflow_command

+# Load environment variables from .env file, if it exists
+load_dotenv()
+
 # Apply at the beginning of the file to avoid issues with asyncio
 nest_asyncio.apply()

@@ -52,7 +56,11 @@ nest_asyncio.apply()
 def setup_logging(log_level: str):
     """Configure logging with the specified level"""
     numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
-    logging.basicConfig(level=numeric_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    logging.basicConfig(
+        level=numeric_level,
+        format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
     return numeric_level
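
For reference, a record emitted through the reworked configuration renders along these lines (illustrative only; the logger name and line number are whatever the call site happens to be):

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logging.getLogger("nat.cli").info("starting workflow")
# 2025-01-01 12:00:00 - INFO     - nat.cli:9 - starting workflow
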
nat/cli/main.py CHANGED
@@ -30,6 +30,9 @@ def run_cli():
     import os
     import sys

+    # Suppress warnings from transformers
+    os.environ["TRANSFORMERS_VERBOSITY"] = "error"
+
     parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

     if (parent_dir not in sys.path):
nat/data_models/api_server.py CHANGED
@@ -28,6 +28,7 @@ from pydantic import HttpUrl
 from pydantic import conlist
 from pydantic import field_serializer
 from pydantic import field_validator
+from pydantic import model_validator
 from pydantic_core.core_schema import ValidationInfo

 from nat.data_models.interactive import HumanPrompt
@@ -36,6 +37,15 @@ from nat.utils.type_converter import GlobalTypeConverter
 FINISH_REASONS = frozenset({'stop', 'length', 'tool_calls', 'content_filter', 'function_call'})


+class UserMessageContentRoleType(str, Enum):
+    """
+    Enum representing chat message roles in API requests and responses.
+    """
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
 class Request(BaseModel):
     """
     Request is a data model that represents HTTP request attributes.
@@ -108,18 +118,10 @@ UserContent = typing.Annotated[TextContent | ImageContent | AudioContent, Discri

 class Message(BaseModel):
     content: str | list[UserContent]
-    role: str
-
-
-class ChatRequest(BaseModel):
-    """
-    ChatRequest is a data model that represents a request to the NAT chat API.
-    Fully compatible with OpenAI Chat Completions API specification.
-    """
+    role: UserMessageContentRoleType

-    # Required fields
-    messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]

+class ChatRequestOptionals(BaseModel):
     # Optional fields (OpenAI Chat Completions API compatible)
     model: str | None = Field(default=None, description="name of the model to use")
     frequency_penalty: float | None = Field(default=0.0,
@@ -144,6 +146,16 @@ class ChatRequest(BaseModel):
     parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
     user: str | None = Field(default=None, description="Unique identifier representing end-user")

+
+class ChatRequest(ChatRequestOptionals):
+    """
+    ChatRequest is a data model that represents a request to the NAT chat API.
+    Fully compatible with OpenAI Chat Completions API specification.
+    """
+
+    # Required fields
+    messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
+
     model_config = ConfigDict(extra="allow",
                               json_schema_extra={
                                   "example": {
@@ -164,7 +176,7 @@ class ChatRequest(BaseModel):
                     max_tokens: int | None = None,
                     top_p: float | None = None) -> "ChatRequest":

-        return ChatRequest(messages=[Message(content=data, role="user")],
+        return ChatRequest(messages=[Message(content=data, role=UserMessageContentRoleType.USER)],
                            model=model,
                            temperature=temperature,
                            max_tokens=max_tokens,
@@ -178,38 +190,85 @@ class ChatRequest(BaseModel):
                     max_tokens: int | None = None,
                     top_p: float | None = None) -> "ChatRequest":

-        return ChatRequest(messages=[Message(content=content, role="user")],
+        return ChatRequest(messages=[Message(content=content, role=UserMessageContentRoleType.USER)],
                            model=model,
                            temperature=temperature,
                            max_tokens=max_tokens,
                            top_p=top_p)


+class ChatRequestOrMessage(ChatRequestOptionals):
+    """
+    ChatRequestOrMessage is a data model that represents either a conversation or a string input.
+    This is useful for functions that can handle either type of input.
+
+    `messages` is compatible with the OpenAI Chat Completions API specification.
+
+    `input_string` is a string input that can be used for functions that do not require a conversation.
+    """
+
+    messages: typing.Annotated[list[Message] | None, conlist(Message, min_length=1)] = Field(
+        default=None, description="The conversation messages to process.")
+
+    input_string: str | None = Field(default=None, alias="input_message", description="The input message to process.")
+
+    @property
+    def is_string(self) -> bool:
+        return self.input_string is not None
+
+    @property
+    def is_conversation(self) -> bool:
+        return self.messages is not None
+
+    @model_validator(mode="after")
+    def validate_messages_or_input_string(self):
+        if self.messages is not None and self.input_string is not None:
+            raise ValueError("Either messages or input_message/input_string must be provided, not both")
+        if self.messages is None and self.input_string is None:
+            raise ValueError("Either messages or input_message/input_string must be provided")
+        if self.input_string is not None:
+            extra_fields = self.model_dump(exclude={"input_string"}, exclude_none=True, exclude_unset=True)
+            if len(extra_fields) > 0:
+                raise ValueError("no extra fields are permitted when input_message/input_string is provided")
+        return self
+
+
 class ChoiceMessage(BaseModel):
     content: str | None = None
-    role: str | None = None
+    role: UserMessageContentRoleType | None = None


 class ChoiceDelta(BaseModel):
     """Delta object for streaming responses (OpenAI-compatible)"""
     content: str | None = None
-    role: str | None = None
+    role: UserMessageContentRoleType | None = None


-class Choice(BaseModel):
+class ChoiceBase(BaseModel):
+    """Base choice model with common fields for both streaming and non-streaming responses"""
     model_config = ConfigDict(extra="allow")
-
-    message: ChoiceMessage | None = None
-    delta: ChoiceDelta | None = None
     finish_reason: typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None = None
     index: int
-    # logprobs: ChoiceLogprobs | None = None
+
+
+class ChatResponseChoice(ChoiceBase):
+    """Choice model for non-streaming responses - contains message field"""
+    message: ChoiceMessage
+
+
+class ChatResponseChunkChoice(ChoiceBase):
+    """Choice model for streaming responses - contains delta field"""
+    delta: ChoiceDelta
+
+
+# Backward compatibility alias
+Choice = ChatResponseChoice


 class Usage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
+    prompt_tokens: int | None = None
+    completion_tokens: int | None = None
+    total_tokens: int | None = None


 class ResponseSerializable(abc.ABC):
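
A short sketch of how ChatRequestOrMessage behaves once validated (assuming the models above are importable from nat.data_models.api_server; note that callers populate the field through its input_message alias):

from pydantic import ValidationError

from nat.data_models.api_server import ChatRequestOrMessage, Message, UserMessageContentRoleType

# String form, supplied through the alias:
m = ChatRequestOrMessage(input_message="hello")
assert m.is_string and not m.is_conversation

# Conversation form:
c = ChatRequestOrMessage(messages=[Message(content="hello", role=UserMessageContentRoleType.USER)])
assert c.is_conversation

# Supplying both (or neither) trips the model validator:
try:
    ChatRequestOrMessage(input_message="hi",
                         messages=[Message(content="hi", role=UserMessageContentRoleType.USER)])
except ValidationError:
    pass  # "... must be provided, not both"
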
@@ -245,10 +304,10 @@ class ChatResponse(ResponseBaseModelOutput):
     model_config = ConfigDict(extra="allow")
     id: str
     object: str = "chat.completion"
-    model: str = ""
+    model: str = "unknown-model"
     created: datetime.datetime
-    choices: list[Choice]
-    usage: Usage | None = None
+    choices: list[ChatResponseChoice]
+    usage: Usage
     system_fingerprint: str | None = None
     service_tier: typing.Literal["scale", "default"] | None = None

@@ -264,14 +323,14 @@ class ChatResponse(ResponseBaseModelOutput):
                     object_: str | None = None,
                     model: str | None = None,
                     created: datetime.datetime | None = None,
-                    usage: Usage | None = None) -> "ChatResponse":
+                    usage: Usage) -> "ChatResponse":

         if id_ is None:
             id_ = str(uuid.uuid4())
         if object_ is None:
             object_ = "chat.completion"
         if model is None:
-            model = ""
+            model = "unknown-model"
         if created is None:
             created = datetime.datetime.now(datetime.UTC)

@@ -279,7 +338,12 @@ class ChatResponse(ResponseBaseModelOutput):
                             object=object_,
                             model=model,
                             created=created,
-                            choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")],
+                            choices=[
+                                ChatResponseChoice(index=0,
+                                                   message=ChoiceMessage(content=data,
+                                                                         role=UserMessageContentRoleType.ASSISTANT),
+                                                   finish_reason="stop")
+                            ],
                             usage=usage)

@@ -293,9 +357,9 @@ class ChatResponseChunk(ResponseBaseModelOutput):
     model_config = ConfigDict(extra="allow")

     id: str
-    choices: list[Choice]
+    choices: list[ChatResponseChunkChoice]
     created: datetime.datetime
-    model: str = ""
+    model: str = "unknown-model"
     object: str = "chat.completion.chunk"
     system_fingerprint: str | None = None
     service_tier: typing.Literal["scale", "default"] | None = None
@@ -319,12 +383,18 @@ class ChatResponseChunk(ResponseBaseModelOutput):
         if created is None:
             created = datetime.datetime.now(datetime.UTC)
         if model is None:
-            model = ""
+            model = "unknown-model"
         if object_ is None:
             object_ = "chat.completion.chunk"

         return ChatResponseChunk(id=id_,
-                                 choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")],
+                                 choices=[
+                                     ChatResponseChunkChoice(index=0,
+                                                             delta=ChoiceDelta(
+                                                                 content=data,
+                                                                 role=UserMessageContentRoleType.ASSISTANT),
+                                                             finish_reason="stop")
+                                 ],
                                  created=created,
                                  model=model,
                                  object=object_)
@@ -335,7 +405,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
                                id_: str | None = None,
                                created: datetime.datetime | None = None,
                                model: str | None = None,
-                               role: str | None = None,
+                               role: UserMessageContentRoleType | None = None,
                                finish_reason: str | None = None,
                                usage: Usage | None = None,
                                system_fingerprint: str | None = None) -> "ChatResponseChunk":
@@ -345,7 +415,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
         if created is None:
             created = datetime.datetime.now(datetime.UTC)
         if model is None:
-            model = ""
+            model = "unknown-model"

         delta = ChoiceDelta(content=content, role=role) if content is not None or role is not None else ChoiceDelta()

@@ -353,7 +423,14 @@ class ChatResponseChunk(ResponseBaseModelOutput):

         return ChatResponseChunk(
             id=id_,
-            choices=[Choice(index=0, message=None, delta=delta, finish_reason=final_finish_reason)],
+            choices=[
+                ChatResponseChunkChoice(
+                    index=0,
+                    delta=delta,
+                    finish_reason=typing.cast(
+                        typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None,
+                        final_finish_reason))
+            ],
             created=created,
             model=model,
             object="chat.completion.chunk",
@@ -398,11 +475,6 @@ class GenerateResponse(BaseModel):
     value: str | None = "default"


-class UserMessageContentRoleType(str, Enum):
-    USER = "user"
-    ASSISTANT = "assistant"
-
-
 class WebSocketMessageType(str, Enum):
     """
     WebSocketMessageType is an Enum that represents WebSocket Message types.
@@ -622,12 +694,42 @@ GlobalTypeConverter.register_converter(_nat_chat_request_to_string)


 def _string_to_nat_chat_request(data: str) -> ChatRequest:
-    return ChatRequest.from_string(data, model="")
+    return ChatRequest.from_string(data, model="unknown-model")


 GlobalTypeConverter.register_converter(_string_to_nat_chat_request)


+def _chat_request_or_message_to_chat_request(data: ChatRequestOrMessage) -> ChatRequest:
+    if data.input_string is not None:
+        return _string_to_nat_chat_request(data.input_string)
+    return ChatRequest(**data.model_dump(exclude={"input_string"}))
+
+
+GlobalTypeConverter.register_converter(_chat_request_or_message_to_chat_request)
+
+
+def _chat_request_to_chat_request_or_message(data: ChatRequest) -> ChatRequestOrMessage:
+    return ChatRequestOrMessage(**data.model_dump(by_alias=True))
+
+
+GlobalTypeConverter.register_converter(_chat_request_to_chat_request_or_message)
+
+
+def _chat_request_or_message_to_string(data: ChatRequestOrMessage) -> str:
+    return data.input_string or ""
+
+
+GlobalTypeConverter.register_converter(_chat_request_or_message_to_string)
+
+
+def _string_to_chat_request_or_message(data: str) -> ChatRequestOrMessage:
+    return ChatRequestOrMessage(input_message=data)
+
+
+GlobalTypeConverter.register_converter(_string_to_chat_request_or_message)
+
+
 # ======== ChatResponse Converters ========
 def _nat_chat_response_to_string(data: ChatResponse) -> str:
     if data.choices and data.choices[0].message:
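
Taken together, the four registrations above give ChatRequestOrMessage both string and conversation round-trips. A sketch of the expected behavior, calling the module-private converter functions directly for illustration:

from nat.data_models.api_server import (_chat_request_or_message_to_chat_request,
                                        _chat_request_to_chat_request_or_message,
                                        _string_to_chat_request_or_message)

req_or_msg = _string_to_chat_request_or_message("hello")
assert req_or_msg.is_string

chat_req = _chat_request_or_message_to_chat_request(req_or_msg)  # promotes to a full ChatRequest
assert chat_req.messages[0].content == "hello"

back = _chat_request_to_chat_request_or_message(chat_req)  # comes back in conversation form
assert back.is_conversation and not back.is_string
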
@@ -654,22 +756,12 @@ def _string_to_nat_chat_response(data: str) -> ChatResponse:
 GlobalTypeConverter.register_converter(_string_to_nat_chat_response)


-def _chat_response_to_chat_response_chunk(data: ChatResponse) -> ChatResponseChunk:
-    # Preserve original message structure for backward compatibility
-    return ChatResponseChunk(id=data.id, choices=data.choices, created=data.created, model=data.model)
-
-
-GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk)
-
-
 # ======== ChatResponseChunk Converters ========
 def _chat_response_chunk_to_string(data: ChatResponseChunk) -> str:
     if data.choices and len(data.choices) > 0:
         choice = data.choices[0]
         if choice.delta and choice.delta.content:
             return choice.delta.content
-        if choice.message and choice.message.content:
-            return choice.message.content
     return ""

@@ -685,21 +777,6 @@ def _string_to_nat_chat_response_chunk(data: str) -> ChatResponseChunk:

 GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk)

-
-# ======== AINodeMessageChunk Converters ========
-def _ai_message_chunk_to_nat_chat_response_chunk(data) -> ChatResponseChunk:
-    '''Converts LangChain/LangGraph AINodeMessageChunk to ChatResponseChunk'''
-    content = ""
-    if hasattr(data, 'content') and data.content is not None:
-        content = str(data.content)
-    elif hasattr(data, 'text') and data.text is not None:
-        content = str(data.text)
-    elif hasattr(data, 'message') and data.message is not None:
-        content = str(data.message)
-
-    return ChatResponseChunk.create_streaming_chunk(content=content, role="assistant", finish_reason=None)
-
-
 # Compatibility aliases with previous releases
 AIQChatRequest = ChatRequest
 AIQChoiceMessage = ChoiceMessage
nat/data_models/config.py CHANGED
@@ -187,7 +187,7 @@ class TelemetryConfig(BaseModel):

 class GeneralConfig(BaseModel):

-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=(), extra="forbid")

     use_uvloop: bool | None = Field(
         default=None,
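
The practical effect of extra="forbid" is that misspelled or unknown keys at the general level now fail validation instead of being silently ignored; a minimal pydantic sketch (standalone model, not the NAT config):

from pydantic import BaseModel, ConfigDict, ValidationError

class Strict(BaseModel):
    model_config = ConfigDict(extra="forbid")
    use_uvloop: bool | None = None

try:
    Strict(use_uvlop=True)  # note the typo'd key
except ValidationError as err:
    print(err)  # "Extra inputs are not permitted"
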
nat/data_models/span.py CHANGED
@@ -128,10 +128,48 @@ class SpanStatus(BaseModel):
     message: str | None = Field(default=None, description="The status message of the span.")


+def _generate_nonzero_trace_id() -> int:
+    """Generate a non-zero 128-bit trace ID."""
+    return uuid.uuid4().int
+
+
+def _generate_nonzero_span_id() -> int:
+    """Generate a non-zero 64-bit span ID."""
+    return uuid.uuid4().int >> 64
+
+
 class SpanContext(BaseModel):
-    trace_id: int = Field(default_factory=lambda: uuid.uuid4().int, description="The 128-bit trace ID of the span.")
-    span_id: int = Field(default_factory=lambda: uuid.uuid4().int & ((1 << 64) - 1),
-                         description="The 64-bit span ID of the span.")
+    trace_id: int = Field(default_factory=_generate_nonzero_trace_id,
+                          description="The OTel-style 128-bit trace ID of the span.")
+    span_id: int = Field(default_factory=_generate_nonzero_span_id,
+                         description="The OTel-style 64-bit span ID of the span.")
+
+    @field_validator("trace_id", mode="before")
+    @classmethod
+    def _validate_trace_id(cls, v: int | str | None) -> int:
+        """Regenerate if trace_id is None; raise an exception if trace_id is invalid."""
+        if isinstance(v, str):
+            v = uuid.UUID(v).int
+        if v is None:
+            v = _generate_nonzero_trace_id()
+        if v <= 0 or v >> 128:
+            raise ValueError(f"Invalid trace_id: must be a non-zero 128-bit integer, got {v}")
+        return v
+
+    @field_validator("span_id", mode="before")
+    @classmethod
+    def _validate_span_id(cls, v: int | str | None) -> int:
+        """Regenerate if span_id is None; raise an exception if span_id is invalid."""
+        if isinstance(v, str):
+            try:
+                v = int(v, 16)
+            except ValueError:
+                raise ValueError(f"span_id unable to be parsed: {v}")
+        if v is None:
+            v = _generate_nonzero_span_id()
+        if v <= 0 or v >> 64:
+            raise ValueError(f"Invalid span_id: must be a non-zero 64-bit integer, got {v}")
+        return v


 class Span(BaseModel):
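
The two generators carve both ID widths out of single v4 UUIDs; a quick sketch of the invariants the validators then enforce:

import uuid

trace_id = uuid.uuid4().int        # full 128 bits; never zero for a v4 UUID (version bits are set)
assert 0 < trace_id < (1 << 128)

span_id = uuid.uuid4().int >> 64   # keep the top 64 bits, which include the version nibble
assert 0 < span_id < (1 << 64)
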
nat/experimental/test_time_compute/functions/execute_score_select_function.py CHANGED
@@ -46,7 +46,7 @@ async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig

     from pydantic import BaseModel

-    executable_fn: Function = builder.get_function(name=config.augmented_fn)
+    executable_fn: Function = await builder.get_function(name=config.augmented_fn)

     if config.scorer:
         scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py CHANGED
@@ -98,8 +98,8 @@ async def register_ttc_tool_wrapper_function(

     augmented_function_desc = config.tool_description

-    fn_input_schema: BaseModel = augmented_function.input_schema
-    fn_output_schema: BaseModel = augmented_function.single_output_schema
+    fn_input_schema: type[BaseModel] = augmented_function.input_schema
+    fn_output_schema: type[BaseModel] | type[None] = augmented_function.single_output_schema

     runnable_llm = input_llm.with_structured_output(schema=fn_input_schema)
nat/front_ends/console/console_front_end_plugin.py CHANGED
@@ -95,5 +95,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
         else:
             assert False, "Should not reach here. Should have been caught by pre_run"

-        # Print result
-        logger.info(f"\n{'-' * 50}\n{Fore.GREEN}Workflow Result:\n%s{Fore.RESET}\n{'-' * 50}", runner_outputs)
+        line = f"{'-' * 50}"
+        prefix = f"{line}\n{Fore.GREEN}Workflow Result:\n"
+        suffix = f"{Fore.RESET}\n{line}"
+
+        logger.info(f"{prefix}%s{suffix}", runner_outputs)
+
+        # (handler is a stream handler) => (level > INFO)
+        effective_level_too_high = all(
+            type(h) is not logging.StreamHandler or h.level > logging.INFO for h in logging.getLogger().handlers)
+        if effective_level_too_high:
+            print(f"{prefix}{runner_outputs}{suffix}")
nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py CHANGED
@@ -24,4 +24,4 @@ class HTTPAuthenticationFlowHandler(FlowHandlerBase):
     async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:

         raise NotImplementedError(f"Authentication method '{method}' is not supported by the HTTP frontend."
-                                  f" Do you have Websockets enabled?")
+                                  f" Do you have WebSockets enabled?")