nvidia-nat 1.3.0a20250822__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +0 -1
  3. nat/agent/react_agent/agent.py +21 -3
  4. nat/agent/react_agent/register.py +1 -1
  5. nat/agent/register.py +0 -1
  6. nat/agent/rewoo_agent/agent.py +0 -1
  7. nat/agent/rewoo_agent/register.py +1 -1
  8. nat/agent/tool_calling_agent/agent.py +0 -1
  9. nat/agent/tool_calling_agent/register.py +1 -1
  10. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  11. nat/authentication/register.py +0 -1
  12. nat/builder/builder.py +1 -1
  13. nat/builder/context.py +9 -1
  14. nat/builder/function_base.py +3 -3
  15. nat/builder/function_info.py +5 -7
  16. nat/builder/workflow_builder.py +0 -1
  17. nat/cli/commands/evaluate.py +1 -1
  18. nat/cli/commands/info/list_components.py +7 -8
  19. nat/cli/commands/info/list_mcp.py +3 -4
  20. nat/cli/commands/registry/search.py +14 -16
  21. nat/cli/commands/start.py +0 -1
  22. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  23. nat/cli/commands/workflow/workflow_commands.py +0 -1
  24. nat/cli/type_registry.py +3 -5
  25. nat/data_models/config.py +1 -1
  26. nat/data_models/evaluate.py +1 -1
  27. nat/data_models/function_dependencies.py +6 -6
  28. nat/data_models/intermediate_step.py +3 -3
  29. nat/data_models/model_gated_field_mixin.py +125 -0
  30. nat/data_models/swe_bench_model.py +1 -1
  31. nat/data_models/temperature_mixin.py +36 -0
  32. nat/data_models/top_p_mixin.py +36 -0
  33. nat/embedder/register.py +0 -1
  34. nat/eval/dataset_handler/dataset_handler.py +5 -6
  35. nat/eval/evaluate.py +7 -8
  36. nat/eval/rag_evaluator/register.py +2 -2
  37. nat/eval/register.py +0 -1
  38. nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
  39. nat/eval/utils/weave_eval.py +3 -3
  40. nat/experimental/test_time_compute/models/strategy_base.py +3 -2
  41. nat/experimental/test_time_compute/register.py +0 -1
  42. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
  43. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
  44. nat/front_ends/fastapi/message_handler.py +13 -14
  45. nat/front_ends/fastapi/message_validator.py +4 -4
  46. nat/front_ends/fastapi/step_adaptor.py +1 -1
  47. nat/front_ends/register.py +0 -1
  48. nat/llm/aws_bedrock_llm.py +3 -3
  49. nat/llm/azure_openai_llm.py +3 -4
  50. nat/llm/nim_llm.py +4 -4
  51. nat/llm/openai_llm.py +4 -4
  52. nat/llm/register.py +0 -1
  53. nat/llm/utils/env_config_value.py +2 -3
  54. nat/object_store/register.py +0 -1
  55. nat/observability/exporter/file_exporter.py +1 -1
  56. nat/observability/register.py +3 -3
  57. nat/profiler/callbacks/langchain_callback_handler.py +1 -1
  58. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  59. nat/profiler/data_frame_row.py +1 -1
  60. nat/profiler/decorators/framework_wrapper.py +1 -4
  61. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  62. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  63. nat/profiler/inference_optimization/data_models.py +3 -3
  64. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  65. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  66. nat/profiler/profile_runner.py +13 -8
  67. nat/registry_handlers/package_utils.py +0 -1
  68. nat/registry_handlers/pypi/pypi_handler.py +20 -23
  69. nat/registry_handlers/register.py +3 -4
  70. nat/registry_handlers/rest/rest_handler.py +8 -9
  71. nat/retriever/register.py +0 -1
  72. nat/runtime/session.py +23 -8
  73. nat/settings/global_settings.py +0 -1
  74. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  75. nat/tool/document_search.py +1 -1
  76. nat/tool/mcp/mcp_tool.py +1 -1
  77. nat/tool/register.py +0 -1
  78. nat/utils/data_models/schema_validator.py +2 -2
  79. nat/utils/exception_handlers/automatic_retries.py +0 -2
  80. nat/utils/exception_handlers/schemas.py +1 -1
  81. nat/utils/reactive/base/observable_base.py +2 -2
  82. nat/utils/reactive/base/observer_base.py +1 -1
  83. nat/utils/reactive/observable.py +2 -2
  84. nat/utils/reactive/observer.py +2 -2
  85. nat/utils/reactive/subscription.py +1 -1
  86. nat/utils/settings/global_settings.py +4 -6
  87. nat/utils/type_utils.py +4 -4
  88. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +1 -1
  89. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +94 -91
  90. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
  91. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
  92. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  93. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
  94. {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +0 -0
nat/llm/openai_llm.py CHANGED
@@ -22,9 +22,11 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.top_p_mixin import TopPMixin
 
 
-class OpenAIModelConfig(LLMBaseConfig, RetryMixin, name="openai"):
+class OpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="openai"):
     """An OpenAI LLM provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -34,13 +36,11 @@ class OpenAIModelConfig(LLMBaseConfig, RetryMixin, name="openai"):
     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                             serialization_alias="model",
                             description="The OpenAI hosted model name.")
-    temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
-    top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
     max_retries: int = Field(default=10, description="The max number of retries for the request.")
 
 
 @register_llm_provider(config_type=OpenAIModelConfig)
-async def openai_llm(config: OpenAIModelConfig, builder: Builder):
+async def openai_llm(config: OpenAIModelConfig, _builder: Builder):
 
     yield LLMProviderInfo(config=config, description="An OpenAI model for use with an LLM client.")
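The inline temperature and top_p fields move into the new shared mixins added in this release (nat/data_models/temperature_mixin.py and nat/data_models/top_p_mixin.py, alongside nat/data_models/model_gated_field_mixin.py). The mixin bodies are not shown in this diff; a minimal sketch of what such Pydantic mixins could look like, assuming they simply lift the field definitions removed above:

    # Hypothetical sketch -- the real mixin implementations are not part of this diff.
    from pydantic import BaseModel, Field

    class TemperatureMixin(BaseModel):
        # Same field the provider configs previously declared inline
        temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")

    class TopPMixin(BaseModel):
        # Same default the inline top_p field used
        top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")

The other providers touched in this release (aws_bedrock_llm.py, azure_openai_llm.py, nim_llm.py) show similar +/- counts and presumably pick up the same mixins.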
nat/llm/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file
 
nat/llm/utils/env_config_value.py CHANGED
@@ -72,9 +72,8 @@ class EnvConfigValue(ABC):
                     f"{message} Try passing a value to the constructor, or setting the `{self.__class__._ENV_KEY}` "
                     "environment variable.")
 
-        else:
-            if not self.__class__._ALLOW_NONE and value is None:
-                raise ValueError("value must not be none")
+        elif not self.__class__._ALLOW_NONE and value is None:
+            raise ValueError("value must not be none")
 
         assert isinstance(value, str) or value is None
 
nat/object_store/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file
 
nat/observability/exporter/file_exporter.py CHANGED
@@ -24,7 +24,7 @@ from nat.observability.processor.intermediate_step_serializer import Intermediat
 logger = logging.getLogger(__name__)
 
 
-class FileExporter(FileExportMixin, RawExporter[IntermediateStep, str]):  # pylint: disable=R0901
+class FileExporter(FileExportMixin, RawExporter[IntermediateStep, str]):
     """A File exporter that exports telemetry traces to a local file."""
 
     def __init__(self, context_state: ContextState | None = None, **file_kwargs):
nat/observability/register.py CHANGED
@@ -45,7 +45,7 @@ class FileTelemetryExporterConfig(TelemetryExporterBaseConfig, name="file"):
 
 
 @register_telemetry_exporter(config_type=FileTelemetryExporterConfig)
-async def file_telemetry_exporter(config: FileTelemetryExporterConfig, builder: Builder):  # pylint: disable=W0613
+async def file_telemetry_exporter(config: FileTelemetryExporterConfig, builder: Builder):
     """
     Build and return a FileExporter for file-based telemetry export with optional rolling.
     """
@@ -68,7 +68,7 @@ class ConsoleLoggingMethodConfig(LoggingBaseConfig, name="console"):
 
 
 @register_logging_method(config_type=ConsoleLoggingMethodConfig)
-async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):  # pylint: disable=W0613
+async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):
     """
     Build and return a StreamHandler for console-based logging.
     """
@@ -86,7 +86,7 @@ class FileLoggingMethod(LoggingBaseConfig, name="file"):
 
 
 @register_logging_method(config_type=FileLoggingMethod)
-async def file_logging_method(config: FileLoggingMethod, builder: Builder):  # pylint: disable=W0613
+async def file_logging_method(config: FileLoggingMethod, builder: Builder):
     """
     Build and return a FileHandler for file-based logging.
     """
nat/profiler/callbacks/langchain_callback_handler.py CHANGED
@@ -53,7 +53,7 @@ def _extract_tools_schema(invocation_params: dict) -> list:
     return tools_schema
 
 
-class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):  # pylint: disable=R0901
+class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
     """Callback Handler that tracks NIM info."""
 
     total_tokens: int = 0
nat/profiler/callbacks/semantic_kernel_callback_handler.py CHANGED
@@ -86,7 +86,7 @@ class SemanticKernelProfilerHandler(BaseProfilerCallback):
 
         # Gather the appropriate modules/functions based on your builder config
         for llm in self._builder_llms:
-            if self._builder_llms[llm].provider_type == 'openai':  # pylint: disable=consider-using-in
+            if self._builder_llms[llm].provider_type == 'openai':
                 functions_to_patch.extend(["openai_non_streaming", "openai_streaming"])
 
         # Grab original reference for the tool call
nat/profiler/data_frame_row.py CHANGED
@@ -42,7 +42,7 @@ class DataFrameRow(BaseModel):
     framework: str | None
 
     @field_validator('llm_text_input', 'llm_text_output', 'llm_new_token', mode='before')
-    def cast_to_str(cls, v):  # pylint: disable=no-self-argument
+    def cast_to_str(cls, v):
         if v is None:
             return v
         try:
nat/profiler/decorators/framework_wrapper.py CHANGED
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint disable=ungrouped-imports
-
 from __future__ import annotations
 
 import functools
@@ -75,8 +73,7 @@ def set_framework_profiler_handler(
         logger.debug("LlamaIndex callback handler registered")
 
     if LLMFrameworkEnum.CREWAI in frameworks and not _library_instrumented["crewai"]:
-        from nat.plugins.crewai.crewai_callback_handler import \
-            CrewAIProfilerHandler  # pylint: disable=ungrouped-imports,line-too-long # noqa E501
+        from nat.plugins.crewai.crewai_callback_handler import CrewAIProfilerHandler
 
         handler = CrewAIProfilerHandler()
         handler.instrument()
nat/profiler/forecasting/models/forecasting_base_model.py CHANGED
@@ -15,7 +15,9 @@
 
 # forecasting/models/base_model.py
 
-from abc import ABC, abstractmethod
+from abc import ABC
+from abc import abstractmethod
+
 import numpy as np
 
 
nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py CHANGED
@@ -195,7 +195,7 @@ def profile_workflow_bottlenecks(all_steps: list[list[IntermediateStep]]) -> Sim
         c_max = 0
         for ts, delta in events_sub:
             c_curr += delta
-            if c_curr > c_max:  # pylint: disable=consider-using-max-builtin
+            if c_curr > c_max:  # noqa: PLR1730 - don't use max built-in
                 c_max = c_curr
         max_concurrency_by_name[op_name] = c_max
 
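The loop above is a sweep over start/stop events that tracks peak concurrency; only the lint suppression changes (pylint's consider-using-max-builtin becomes ruff's PLR1730). A standalone illustration of the same pattern, with made-up events of the form (timestamp, +1 on start / -1 on end):

    # Illustrative data, not from the profiler
    events_sub = [(0.0, +1), (1.0, +1), (1.5, -1), (2.0, +1), (3.0, -1), (4.0, -1)]
    c_curr = 0
    c_max = 0
    for ts, delta in events_sub:
        c_curr += delta
        if c_curr > c_max:  # noqa: PLR1730
            c_max = c_curr
    print(c_max)  # 2: at most two operations were in flight at once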
nat/profiler/inference_optimization/data_models.py CHANGED
@@ -172,7 +172,7 @@ class CallNode(BaseModel):
         if not self.children:
             return self.duration
 
-        intervals = [(c.start_time, c.end_time) for c in self.children]  # pylint: disable=not-an-iterable
+        intervals = [(c.start_time, c.end_time) for c in self.children]
         # Sort by start time
         intervals.sort(key=lambda x: x[0])
 
@@ -204,7 +204,7 @@ class CallNode(BaseModel):
         This ensures no overlap double-counting among children.
         """
         total = self.compute_self_time()
-        for c in self.children:  # pylint: disable=not-an-iterable
+        for c in self.children:
             total += c.compute_subtree_time()
         return total
 
@@ -216,7 +216,7 @@ class CallNode(BaseModel):
         info = (f"{indent}- {self.operation_type} '{self.operation_name}' "
                 f"(uuid={self.uuid}, start={self.start_time:.2f}, "
                 f"end={self.end_time:.2f}, dur={self.duration:.2f})")
-        child_strs = [child._repr(level + 1) for child in self.children]  # pylint: disable=not-an-iterable
+        child_strs = [child._repr(level + 1) for child in self.children]
        return "\n".join([info] + child_strs)
 
 
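The first hunk sits inside an interval-merge computation: child (start, end) spans are sorted by start time and overlapping spans are unioned so self-time is not double-counted. A generic sketch of that technique (illustrative intervals, not the actual CallNode code):

    intervals = [(1.0, 3.0), (2.0, 4.0), (6.0, 7.0)]
    intervals.sort(key=lambda x: x[0])
    merged: list[tuple[float, float]] = []
    for start, end in intervals:
        if merged and start <= merged[-1][1]:
            # Overlaps the previous span: extend it instead of counting it twice
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    covered = sum(e - s for s, e in merged)
    print(covered)  # 4.0: (1, 3) and (2, 4) merge into (1, 4), plus (6, 7)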
nat/profiler/inference_optimization/experimental/prefix_span_analysis.py CHANGED
@@ -228,7 +228,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
     else:
         abs_min_support = min_support
 
-    freq_patterns = ps.frequent(abs_min_support)  # pylint: disable=not-callable
+    freq_patterns = ps.frequent(abs_min_support)
     # freq_patterns => [(count, [item1, item2, ...])]
 
     results = []
@@ -321,13 +321,12 @@ def compute_coverage_and_duration(sequences_map: dict[int, list[PrefixCallNode]]
 # --------------------------------------------------------------------------------
 
 
-def prefixspan_subworkflow_with_text(  # pylint: disable=too-many-positional-arguments
-        all_steps: list[list[IntermediateStep]],
-        min_support: int | float = 2,
-        top_k: int = 10,
-        min_coverage: float = 0.0,
-        max_text_len: int = 700,
-        prefix_list: list[str] = None) -> PrefixSpanSubworkflowResult:
+def prefixspan_subworkflow_with_text(all_steps: list[list[IntermediateStep]],
+                                     min_support: int | float = 2,
+                                     top_k: int = 10,
+                                     min_coverage: float = 0.0,
+                                     max_text_len: int = 700,
+                                     prefix_list: list[str] = None) -> PrefixSpanSubworkflowResult:
     """
     1) Build sequences of calls for each example (with llm_text_input).
     2) Convert to token lists, run PrefixSpan with min_support.
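The ps.frequent(...) call in the first hunk is the prefixspan library's frequent-pattern query; it returns (count, pattern) pairs, as the inline comment notes. A minimal usage sketch, assuming the prefixspan package is installed (pip install prefixspan):

    from prefixspan import PrefixSpan

    sequences = [[1, 2, 3], [1, 2], [2, 3]]
    ps = PrefixSpan(sequences)
    freq_patterns = ps.frequent(2)  # patterns appearing in >= 2 sequences
    print(freq_patterns)  # [(count, [item1, item2, ...]), ...]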
nat/profiler/inference_optimization/token_uniqueness.py CHANGED
@@ -66,7 +66,7 @@ def compute_inter_query_token_uniqueness_by_llm(all_steps: list[list[Intermediat
     # 2) Group by (llm_name, example_number), then sort each group
     grouped = cdf.groupby(['llm_name', 'example_number'], as_index=False, group_keys=True)
 
-    for (llm, ex_num), group_df in grouped:  # pylint: disable=unused-variable
+    for (llm, ex_num), group_df in grouped:
         # Sort by event_timestamp
         group_df = group_df.sort_values('event_timestamp', ascending=True)
 
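Iterating a pandas GroupBy over two key columns yields a (key_tuple, frame) pair per group, which is what the unpacking above relies on. A tiny self-contained illustration with made-up data:

    import pandas as pd

    cdf = pd.DataFrame({
        "llm_name": ["llama", "llama", "gpt"],
        "example_number": [0, 0, 1],
        "event_timestamp": [2.0, 1.0, 3.0],
    })
    for (llm, ex_num), group_df in cdf.groupby(["llm_name", "example_number"]):
        group_df = group_df.sort_values("event_timestamp", ascending=True)
        # per-(llm, example) uniqueness math would go here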
nat/profiler/profile_runner.py CHANGED
@@ -88,14 +88,19 @@ class ProfilerRunner:
         writes out combined requests JSON, then computes and saves additional metrics,
         and optionally fits a forecasting model.
         """
-        from nat.profiler.inference_optimization.bottleneck_analysis.nested_stack_analysis import \
-            multi_example_call_profiling
-        from nat.profiler.inference_optimization.bottleneck_analysis.simple_stack_analysis import \
-            profile_workflow_bottlenecks
-        from nat.profiler.inference_optimization.experimental.concurrency_spike_analysis import \
-            concurrency_spike_analysis
-        from nat.profiler.inference_optimization.experimental.prefix_span_analysis import \
-            prefixspan_subworkflow_with_text
+        # Yapf and ruff disagree on how to format long imports, disable yapf go with ruff
+        from nat.profiler.inference_optimization.bottleneck_analysis.nested_stack_analysis import (
+            multi_example_call_profiling,
+        )  # yapf: disable
+        from nat.profiler.inference_optimization.bottleneck_analysis.simple_stack_analysis import (
+            profile_workflow_bottlenecks,
+        )  # yapf: disable
+        from nat.profiler.inference_optimization.experimental.concurrency_spike_analysis import (
+            concurrency_spike_analysis,
+        )  # yapf: disable
+        from nat.profiler.inference_optimization.experimental.prefix_span_analysis import (
+            prefixspan_subworkflow_with_text,
+        )  # yapf: disable
         from nat.profiler.inference_optimization.llm_metrics import LLMMetrics
         from nat.profiler.inference_optimization.prompt_caching import get_common_prefixes
         from nat.profiler.inference_optimization.token_uniqueness import compute_inter_query_token_uniqueness_by_llm
nat/registry_handlers/package_utils.py CHANGED
@@ -29,7 +29,6 @@ from nat.registry_handlers.schemas.publish import Artifact
 from nat.runtime.loader import PluginTypes
 from nat.runtime.loader import discover_entrypoints
 
-# pylint: disable=redefined-outer-name
 logger = logging.getLogger(__name__)
 
 
nat/registry_handlers/pypi/pypi_handler.py CHANGED
@@ -44,13 +44,12 @@ class PypiRegistryHandler(AbstractRegistryHandler):
     https://github.com/pypiserver/pypiserver
     """
 
-    def __init__(  # pylint: disable=R0917
-            self,
-            endpoint: str,
-            token: str | None = None,
-            publish_route: str = "",
-            pull_route: str = "",
-            search_route: str = ""):
+    def __init__(self,
+                 endpoint: str,
+                 token: str | None = None,
+                 publish_route: str = "",
+                 pull_route: str = "",
+                 search_route: str = ""):
         super().__init__()
         self._endpoint = endpoint.rstrip("/")
         self._token = token
@@ -126,17 +125,16 @@ class PypiRegistryHandler(AbstractRegistryHandler):
 
         versioned_packages_str = " ".join(versioned_packages)
 
-        result = subprocess.run(
-            [
-                "uv",
-                "pip",
-                "install",
-                "--prerelease=allow",
-                "--index-url",
-                f"{self._endpoint}/{self._pull_route}/",
-                versioned_packages_str
-            ],  # pylint: disable=W0631
-            check=True)
+        result = subprocess.run([
+            "uv",
+            "pip",
+            "install",
+            "--prerelease=allow",
+            "--index-url",
+            f"{self._endpoint}/{self._pull_route}/",
+            versioned_packages_str
+        ],
+                                check=True)
 
         result.check_returncode()
 
@@ -171,11 +169,10 @@ class PypiRegistryHandler(AbstractRegistryHandler):
         """
 
         try:
-            completed_process = subprocess.run(
-                ["pip", "search", "--index", f"{self._endpoint}", query.query],  # pylint: disable=W0631
-                text=True,
-                capture_output=True,
-                check=True)
+            completed_process = subprocess.run(["pip", "search", "--index", f"{self._endpoint}", query.query],
+                                               text=True,
+                                               capture_output=True,
+                                               check=True)
             search_response_list = []
             search_results = completed_process.stdout
             package_results = search_results.split("\n")
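All three hunks here are formatting-only: the pylint R0917/W0631 suppressions are dropped and the calls are reflowed in yapf style, with no behavioral change. For orientation, a hedged construction sketch based on the signature shown above (endpoint and routes are invented values, not from the diff):

    from nat.registry_handlers.pypi.pypi_handler import PypiRegistryHandler

    # Points at a local pypiserver instance; __init__ strips the trailing slash
    handler = PypiRegistryHandler(endpoint="http://localhost:8080/",
                                  pull_route="simple",
                                  search_route="search")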
nat/registry_handlers/register.py CHANGED
@@ -13,9 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 
-from .local import register_local  # pylint: disable=E0611
-from .pypi import register_pypi  # pylint: disable=E0611
-from .rest import register_rest  # pylint: disable=E0611
+from .local import register_local
+from .pypi import register_pypi
+from .rest import register_rest
nat/registry_handlers/rest/rest_handler.py CHANGED
@@ -42,15 +42,14 @@ logger = logging.getLogger(__name__)
 class RestRegistryHandler(AbstractRegistryHandler):
     """A registry handler for interactions with a remote REST registry."""
 
-    def __init__(  # pylint: disable=R0917
-            self,
-            endpoint: str,
-            token: str,
-            timeout: int = 30,
-            publish_route: str = "",
-            pull_route: str = "",
-            search_route: str = "",
-            remove_route: str = ""):
+    def __init__(self,
+                 endpoint: str,
+                 token: str,
+                 timeout: int = 30,
+                 publish_route: str = "",
+                 pull_route: str = "",
+                 search_route: str = "",
+                 remove_route: str = ""):
         super().__init__()
         self._endpoint = endpoint.rstrip("/")
         self._timeout = timeout
nat/retriever/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file
 
nat/runtime/session.py CHANGED
@@ -21,7 +21,9 @@ from collections.abc import Callable
 from contextlib import asynccontextmanager
 from contextlib import nullcontext
 
+from fastapi import WebSocket
 from starlette.requests import HTTPConnection
+from starlette.requests import Request
 
 from nat.builder.context import Context
 from nat.builder.context import ContextState
@@ -89,7 +91,8 @@ class SessionManager:
     @asynccontextmanager
     async def session(self,
                       user_manager=None,
-                      request: HTTPConnection | None = None,
+                      http_connection: HTTPConnection | None = None,
+                      user_message_id: str | None = None,
                       conversation_id: str | None = None,
                       user_input_callback: Callable[[InteractionPrompt], Awaitable[HumanResponse]] = None,
                       user_authentication_callback: Callable[[AuthProviderBaseConfig, AuthFlowType],
@@ -107,10 +110,11 @@ class SessionManager:
         if user_authentication_callback is not None:
             token_user_authentication = self._context_state.user_auth_callback.set(user_authentication_callback)
 
-        if conversation_id is not None and request is None:
-            self._context_state.conversation_id.set(conversation_id)
+        if isinstance(http_connection, WebSocket):
+            self.set_metadata_from_websocket(user_message_id, conversation_id)
 
-        self.set_metadata_from_http_request(request)
+        if isinstance(http_connection, Request):
+            self.set_metadata_from_http_request(http_connection)
 
         try:
             yield self
@@ -135,14 +139,11 @@ class SessionManager:
         async with self._workflow.run(message) as runner:
             yield runner
 
-    def set_metadata_from_http_request(self, request: HTTPConnection | None) -> None:
+    def set_metadata_from_http_request(self, request: Request) -> None:
         """
         Extracts and sets user metadata request attributes from a HTTP request.
         If request is None, no attributes are set.
         """
-        if request is None:
-            return
-
         self._context.metadata._request.method = getattr(request, "method", None)
         self._context.metadata._request.url_path = request.url.path
         self._context.metadata._request.url_port = request.url.port
@@ -157,6 +158,20 @@ class SessionManager:
         if request.headers.get("conversation-id"):
             self._context_state.conversation_id.set(request.headers["conversation-id"])
 
+        if request.headers.get("user-message-id"):
+            self._context_state.user_message_id.set(request.headers["user-message-id"])
+
+    def set_metadata_from_websocket(self, user_message_id: str | None, conversation_id: str | None) -> None:
+        """
+        Extracts and sets user metadata for Websocket connections.
+        """
+
+        if conversation_id is not None:
+            self._context_state.conversation_id.set(conversation_id)
+
+        if user_message_id is not None:
+            self._context_state.user_message_id.set(user_message_id)
+
 
 # Compatibility aliases with previous releases
 AIQSessionManager = SessionManager
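The net effect: SessionManager.session() now takes the connection as http_connection and dispatches on its concrete type, reading request metadata (method, URL, headers, cookies, plus the conversation-id and new user-message-id headers) from an HTTP Request, while WebSocket sessions receive user_message_id and conversation_id as explicit arguments. A hedged caller sketch (route paths are invented, and session_manager is assumed to be a SessionManager built elsewhere):

    from fastapi import FastAPI, Request, WebSocket

    app = FastAPI()

    @app.post("/generate")
    async def generate(request: Request):
        # Metadata is pulled from the Request itself, including headers
        async with session_manager.session(http_connection=request) as session:
            ...

    @app.websocket("/ws")
    async def ws_endpoint(websocket: WebSocket):
        await websocket.accept()
        # WebSocket messages carry no per-request headers, so the IDs are passed explicitly
        async with session_manager.session(http_connection=websocket,
                                           user_message_id="msg-123",
                                           conversation_id="conv-456") as session:
            ...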
nat/settings/global_settings.py CHANGED
@@ -124,7 +124,6 @@ class Settings(HashableBaseModel):
             if (short_names[key.local_name] == 1):
                 type_list.append((key.local_name, key.config_type))
 
-        # pylint: disable=consider-alternative-union-syntax
         return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
 
     RegistryHandlerAnnotation = dict[
nat/tool/code_execution/local_sandbox/local_sandbox_server.py CHANGED
@@ -127,7 +127,7 @@ def execute_code_subprocess(generated_code: str, queue):
     stderr_capture = StringIO()
     try:
         with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
-            exec(generated_code, {})  # pylint: disable=W0122
+            exec(generated_code, {})
             logger.debug("execute_code_subprocess finished, PID: %s", os.getpid())
         queue.put(CodeExecutionResult(stdout=stdout_capture.getvalue(), stderr=stderr_capture.getvalue()))
     except Exception as e:
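Only the W0122 (exec-used) suppression is removed; the capture-around-exec pattern is unchanged. A standalone demonstration of that pattern:

    import contextlib
    from io import StringIO

    stdout_capture = StringIO()
    stderr_capture = StringIO()
    generated_code = "print('hello from generated code')"
    with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
        exec(generated_code, {})  # empty globals give the generated code a clean namespace
    print(stdout_capture.getvalue())  # -> hello from generated code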
nat/tool/document_search.py CHANGED
@@ -53,7 +53,7 @@ async def document_search(config: MilvusDocumentSearchToolConfig, builder: Build
     from langchain_core.messages import HumanMessage
     from langchain_core.messages import SystemMessage
     from langchain_core.pydantic_v1 import BaseModel
-    from langchain_core.pydantic_v1 import Field  # pylint: disable=redefined-outer-name, reimported
+    from langchain_core.pydantic_v1 import Field
 
     # define collection store
     # create a list of tuples using enumerate()
nat/tool/mcp/mcp_tool.py CHANGED
@@ -48,7 +48,7 @@ class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
 
 
 @register_function(config_type=MCPToolConfig)
-async def mcp_tool(config: MCPToolConfig, builder: Builder):  # pylint: disable=unused-argument
+async def mcp_tool(config: MCPToolConfig, builder: Builder):
     """
     Generate a NAT Function that wraps a tool provided by the MCP server.
     """
nat/tool/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 
 # Import any tools which need to be automatically registered here
nat/utils/data_models/schema_validator.py CHANGED
@@ -21,7 +21,7 @@ from ..exception_handlers.schemas import yaml_exception_handler
 
 
 @schema_exception_handler
-def validate_schema(metadata, Schema):  # pylint: disable=invalid-name
+def validate_schema(metadata, Schema):
 
     try:
         return Schema(**metadata)
@@ -31,7 +31,7 @@ def validate_schema(metadata, Schema):  # pylint: disable=invalid-name
 
 
 @yaml_exception_handler
-def validate_yaml(ctx, param, value):  # pylint: disable=unused-argument
+def validate_yaml(ctx, param, value):
     """
     Validate that the file is a valid YAML file
 
nat/utils/exception_handlers/automatic_retries.py CHANGED
@@ -26,8 +26,6 @@ from collections.abc import Sequence
 from typing import Any
 from typing import TypeVar
 
-# pylint: disable=inconsistent-return-statements
-
 T = TypeVar("T")
 Exc = tuple[type[BaseException], ...]  # exception classes
 CodePattern = int | str | range  # for retry_codes argument
nat/utils/exception_handlers/schemas.py CHANGED
@@ -21,7 +21,7 @@ from pydantic import ValidationError
 logger = logging.getLogger(__name__)
 
 
-def schema_exception_handler(func, **kwargs):  # pylint: disable=unused-argument
+def schema_exception_handler(func, **kwargs):
     """
     A decorator that handles `ValidationError` exceptions for schema validation functions.
 
nat/utils/reactive/base/observable_base.py CHANGED
@@ -25,8 +25,8 @@ from nat.utils.reactive.subscription import Subscription
 
 # Covariant type param: An Observable producing type X can also produce
 # a subtype of X.
-_T_out_co = TypeVar("_T_out_co", covariant=True)  # pylint: disable=invalid-name
-_T = TypeVar("_T")  # pylint: disable=invalid-name
+_T_out_co = TypeVar("_T_out_co", covariant=True)
+_T = TypeVar("_T")
 
 OnNext = Callable[[_T], None]
 OnError = Callable[[Exception], None]
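The covariant/contravariant TypeVars keep their PEP 484 spelling; only the invalid-name suppressions go away. A minimal illustration of why the producing side is covariant (illustrative types, not NAT classes):

    from typing import Generic, TypeVar

    _T_out_co = TypeVar("_T_out_co", covariant=True)

    class Animal: ...
    class Dog(Animal): ...

    class Producer(Generic[_T_out_co]):
        def get(self) -> _T_out_co:
            raise NotImplementedError

    def consume(p: Producer[Animal]) -> None: ...

    consume(Producer[Dog]())  # accepted by type checkers only because _T_out_co is covariant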
nat/utils/reactive/base/observer_base.py CHANGED
@@ -20,7 +20,7 @@ from typing import TypeVar
 
 # Contravariant type param: An Observer that can accept type X can also
 # accept any supertype of X.
-_T_in_contra = TypeVar("_T_in_contra", contravariant=True)  # pylint: disable=invalid-name
+_T_in_contra = TypeVar("_T_in_contra", contravariant=True)
 
 
 class ObserverBase(Generic[_T_in_contra], ABC):
nat/utils/reactive/observable.py CHANGED
@@ -24,8 +24,8 @@ from nat.utils.type_utils import override
 
 # Covariant type param: An Observable producing type X can also produce
 # a subtype of X.
-_T_out_co = TypeVar("_T_out_co", covariant=True)  # pylint: disable=invalid-name
-_T = TypeVar("_T")  # pylint: disable=invalid-name
+_T_out_co = TypeVar("_T_out_co", covariant=True)
+_T = TypeVar("_T")
 
 OnNext = Callable[[_T], None]
 OnError = Callable[[Exception], None]
nat/utils/reactive/observer.py CHANGED
@@ -23,8 +23,8 @@ logger = logging.getLogger(__name__)
 
 # Contravariant type param: An Observer that can accept type X can also
 # accept any supertype of X.
-_T_in_contra = TypeVar("_T_in_contra", contravariant=True)  # pylint: disable=invalid-name
-_T = TypeVar("_T")  # pylint: disable=invalid-name
+_T_in_contra = TypeVar("_T_in_contra", contravariant=True)
+_T = TypeVar("_T")
 
 OnNext = Callable[[_T], None]
 OnError = Callable[[Exception], None]
nat/utils/reactive/subscription.py CHANGED
@@ -21,7 +21,7 @@ from typing import TypeVar
 if typing.TYPE_CHECKING:
     from nat.utils.reactive.base.subject_base import SubjectBase
 
-_T = TypeVar("_T")  # pylint: disable=invalid-name
+_T = TypeVar("_T")
 
 OnNext = Callable[[_T], None]
 OnError = Callable[[Exception], None]