nvidia-nat 1.3.0a20250822__py3-none-any.whl → 1.3.0a20250824__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +0 -1
- nat/agent/react_agent/agent.py +21 -3
- nat/agent/react_agent/register.py +1 -1
- nat/agent/register.py +0 -1
- nat/agent/rewoo_agent/agent.py +0 -1
- nat/agent/rewoo_agent/register.py +1 -1
- nat/agent/tool_calling_agent/agent.py +0 -1
- nat/agent/tool_calling_agent/register.py +1 -1
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +1 -1
- nat/builder/context.py +9 -1
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +5 -7
- nat/builder/workflow_builder.py +0 -1
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/info/list_mcp.py +3 -4
- nat/cli/commands/registry/search.py +14 -16
- nat/cli/commands/start.py +0 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +0 -1
- nat/cli/type_registry.py +3 -5
- nat/data_models/config.py +1 -1
- nat/data_models/evaluate.py +1 -1
- nat/data_models/function_dependencies.py +6 -6
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/model_gated_field_mixin.py +125 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +36 -0
- nat/data_models/top_p_mixin.py +36 -0
- nat/embedder/register.py +0 -1
- nat/eval/dataset_handler/dataset_handler.py +5 -6
- nat/eval/evaluate.py +7 -8
- nat/eval/rag_evaluator/register.py +2 -2
- nat/eval/register.py +0 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
- nat/eval/utils/weave_eval.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +3 -2
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/register.py +0 -1
- nat/llm/aws_bedrock_llm.py +3 -3
- nat/llm/azure_openai_llm.py +3 -4
- nat/llm/nim_llm.py +4 -4
- nat/llm/openai_llm.py +4 -4
- nat/llm/register.py +0 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/object_store/register.py +0 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/register.py +3 -3
- nat/profiler/callbacks/langchain_callback_handler.py +1 -1
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +1 -4
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/profile_runner.py +13 -8
- nat/registry_handlers/package_utils.py +0 -1
- nat/registry_handlers/pypi/pypi_handler.py +20 -23
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +8 -9
- nat/retriever/register.py +0 -1
- nat/runtime/session.py +23 -8
- nat/settings/global_settings.py +0 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/document_search.py +1 -1
- nat/tool/mcp/mcp_tool.py +1 -1
- nat/tool/register.py +0 -1
- nat/utils/data_models/schema_validator.py +2 -2
- nat/utils/exception_handlers/automatic_retries.py +0 -2
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +2 -2
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +4 -6
- nat/utils/type_utils.py +4 -4
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/METADATA +1 -1
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/RECORD +94 -91
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/top_level.txt +0 -0
nat/llm/openai_llm.py
CHANGED
|
@@ -22,9 +22,11 @@ from nat.builder.llm import LLMProviderInfo
|
|
|
22
22
|
from nat.cli.register_workflow import register_llm_provider
|
|
23
23
|
from nat.data_models.llm import LLMBaseConfig
|
|
24
24
|
from nat.data_models.retry_mixin import RetryMixin
|
|
25
|
+
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
26
|
+
from nat.data_models.top_p_mixin import TopPMixin
|
|
25
27
|
|
|
26
28
|
|
|
27
|
-
class OpenAIModelConfig(LLMBaseConfig, RetryMixin, name="openai"):
|
|
29
|
+
class OpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="openai"):
|
|
28
30
|
"""An OpenAI LLM provider to be used with an LLM client."""
|
|
29
31
|
|
|
30
32
|
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
@@ -34,13 +36,11 @@ class OpenAIModelConfig(LLMBaseConfig, RetryMixin, name="openai"):
|
|
|
34
36
|
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
35
37
|
serialization_alias="model",
|
|
36
38
|
description="The OpenAI hosted model name.")
|
|
37
|
-
temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
|
|
38
|
-
top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
|
|
39
39
|
seed: int | None = Field(default=None, description="Random seed to set for generation.")
|
|
40
40
|
max_retries: int = Field(default=10, description="The max number of retries for the request.")
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
@register_llm_provider(config_type=OpenAIModelConfig)
|
|
44
|
-
async def openai_llm(config: OpenAIModelConfig,
|
|
44
|
+
async def openai_llm(config: OpenAIModelConfig, _builder: Builder):
|
|
45
45
|
|
|
46
46
|
yield LLMProviderInfo(config=config, description="An OpenAI model for use with an LLM client.")
|
nat/llm/register.py
CHANGED
|
@@ -72,9 +72,8 @@ class EnvConfigValue(ABC):
|
|
|
72
72
|
f"{message} Try passing a value to the constructor, or setting the `{self.__class__._ENV_KEY}` "
|
|
73
73
|
"environment variable.")
|
|
74
74
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
raise ValueError("value must not be none")
|
|
75
|
+
elif not self.__class__._ALLOW_NONE and value is None:
|
|
76
|
+
raise ValueError("value must not be none")
|
|
78
77
|
|
|
79
78
|
assert isinstance(value, str) or value is None
|
|
80
79
|
|
nat/object_store/register.py
CHANGED
|
@@ -24,7 +24,7 @@ from nat.observability.processor.intermediate_step_serializer import Intermediat
|
|
|
24
24
|
logger = logging.getLogger(__name__)
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
class FileExporter(FileExportMixin, RawExporter[IntermediateStep, str]):
|
|
27
|
+
class FileExporter(FileExportMixin, RawExporter[IntermediateStep, str]):
|
|
28
28
|
"""A File exporter that exports telemetry traces to a local file."""
|
|
29
29
|
|
|
30
30
|
def __init__(self, context_state: ContextState | None = None, **file_kwargs):
|
nat/observability/register.py
CHANGED
|
@@ -45,7 +45,7 @@ class FileTelemetryExporterConfig(TelemetryExporterBaseConfig, name="file"):
|
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
@register_telemetry_exporter(config_type=FileTelemetryExporterConfig)
|
|
48
|
-
async def file_telemetry_exporter(config: FileTelemetryExporterConfig, builder: Builder):
|
|
48
|
+
async def file_telemetry_exporter(config: FileTelemetryExporterConfig, builder: Builder):
|
|
49
49
|
"""
|
|
50
50
|
Build and return a FileExporter for file-based telemetry export with optional rolling.
|
|
51
51
|
"""
|
|
@@ -68,7 +68,7 @@ class ConsoleLoggingMethodConfig(LoggingBaseConfig, name="console"):
|
|
|
68
68
|
|
|
69
69
|
|
|
70
70
|
@register_logging_method(config_type=ConsoleLoggingMethodConfig)
|
|
71
|
-
async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):
|
|
71
|
+
async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):
|
|
72
72
|
"""
|
|
73
73
|
Build and return a StreamHandler for console-based logging.
|
|
74
74
|
"""
|
|
@@ -86,7 +86,7 @@ class FileLoggingMethod(LoggingBaseConfig, name="file"):
|
|
|
86
86
|
|
|
87
87
|
|
|
88
88
|
@register_logging_method(config_type=FileLoggingMethod)
|
|
89
|
-
async def file_logging_method(config: FileLoggingMethod, builder: Builder):
|
|
89
|
+
async def file_logging_method(config: FileLoggingMethod, builder: Builder):
|
|
90
90
|
"""
|
|
91
91
|
Build and return a FileHandler for file-based logging.
|
|
92
92
|
"""
|
|
@@ -53,7 +53,7 @@ def _extract_tools_schema(invocation_params: dict) -> list:
|
|
|
53
53
|
return tools_schema
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
|
|
56
|
+
class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
|
|
57
57
|
"""Callback Handler that tracks NIM info."""
|
|
58
58
|
|
|
59
59
|
total_tokens: int = 0
|
|
@@ -86,7 +86,7 @@ class SemanticKernelProfilerHandler(BaseProfilerCallback):
|
|
|
86
86
|
|
|
87
87
|
# Gather the appropriate modules/functions based on your builder config
|
|
88
88
|
for llm in self._builder_llms:
|
|
89
|
-
if self._builder_llms[llm].provider_type == 'openai':
|
|
89
|
+
if self._builder_llms[llm].provider_type == 'openai':
|
|
90
90
|
functions_to_patch.extend(["openai_non_streaming", "openai_streaming"])
|
|
91
91
|
|
|
92
92
|
# Grab original reference for the tool call
|
nat/profiler/data_frame_row.py
CHANGED
|
@@ -42,7 +42,7 @@ class DataFrameRow(BaseModel):
|
|
|
42
42
|
framework: str | None
|
|
43
43
|
|
|
44
44
|
@field_validator('llm_text_input', 'llm_text_output', 'llm_new_token', mode='before')
|
|
45
|
-
def cast_to_str(cls, v):
|
|
45
|
+
def cast_to_str(cls, v):
|
|
46
46
|
if v is None:
|
|
47
47
|
return v
|
|
48
48
|
try:
|
|
@@ -13,8 +13,6 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint disable=ungrouped-imports
|
|
17
|
-
|
|
18
16
|
from __future__ import annotations
|
|
19
17
|
|
|
20
18
|
import functools
|
|
@@ -75,8 +73,7 @@ def set_framework_profiler_handler(
|
|
|
75
73
|
logger.debug("LlamaIndex callback handler registered")
|
|
76
74
|
|
|
77
75
|
if LLMFrameworkEnum.CREWAI in frameworks and not _library_instrumented["crewai"]:
|
|
78
|
-
from nat.plugins.crewai.crewai_callback_handler import
|
|
79
|
-
CrewAIProfilerHandler # pylint: disable=ungrouped-imports,line-too-long # noqa E501
|
|
76
|
+
from nat.plugins.crewai.crewai_callback_handler import CrewAIProfilerHandler
|
|
80
77
|
|
|
81
78
|
handler = CrewAIProfilerHandler()
|
|
82
79
|
handler.instrument()
|
|
@@ -195,7 +195,7 @@ def profile_workflow_bottlenecks(all_steps: list[list[IntermediateStep]]) -> Sim
|
|
|
195
195
|
c_max = 0
|
|
196
196
|
for ts, delta in events_sub:
|
|
197
197
|
c_curr += delta
|
|
198
|
-
if c_curr > c_max: #
|
|
198
|
+
if c_curr > c_max: # noqa: PLR1730 - don't use max built-in
|
|
199
199
|
c_max = c_curr
|
|
200
200
|
max_concurrency_by_name[op_name] = c_max
|
|
201
201
|
|
|
@@ -172,7 +172,7 @@ class CallNode(BaseModel):
|
|
|
172
172
|
if not self.children:
|
|
173
173
|
return self.duration
|
|
174
174
|
|
|
175
|
-
intervals = [(c.start_time, c.end_time) for c in self.children]
|
|
175
|
+
intervals = [(c.start_time, c.end_time) for c in self.children]
|
|
176
176
|
# Sort by start time
|
|
177
177
|
intervals.sort(key=lambda x: x[0])
|
|
178
178
|
|
|
@@ -204,7 +204,7 @@ class CallNode(BaseModel):
|
|
|
204
204
|
This ensures no overlap double-counting among children.
|
|
205
205
|
"""
|
|
206
206
|
total = self.compute_self_time()
|
|
207
|
-
for c in self.children:
|
|
207
|
+
for c in self.children:
|
|
208
208
|
total += c.compute_subtree_time()
|
|
209
209
|
return total
|
|
210
210
|
|
|
@@ -216,7 +216,7 @@ class CallNode(BaseModel):
|
|
|
216
216
|
info = (f"{indent}- {self.operation_type} '{self.operation_name}' "
|
|
217
217
|
f"(uuid={self.uuid}, start={self.start_time:.2f}, "
|
|
218
218
|
f"end={self.end_time:.2f}, dur={self.duration:.2f})")
|
|
219
|
-
child_strs = [child._repr(level + 1) for child in self.children]
|
|
219
|
+
child_strs = [child._repr(level + 1) for child in self.children]
|
|
220
220
|
return "\n".join([info] + child_strs)
|
|
221
221
|
|
|
222
222
|
|
|
@@ -228,7 +228,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
|
|
|
228
228
|
else:
|
|
229
229
|
abs_min_support = min_support
|
|
230
230
|
|
|
231
|
-
freq_patterns = ps.frequent(abs_min_support)
|
|
231
|
+
freq_patterns = ps.frequent(abs_min_support)
|
|
232
232
|
# freq_patterns => [(count, [item1, item2, ...])]
|
|
233
233
|
|
|
234
234
|
results = []
|
|
@@ -321,13 +321,12 @@ def compute_coverage_and_duration(sequences_map: dict[int, list[PrefixCallNode]]
|
|
|
321
321
|
# --------------------------------------------------------------------------------
|
|
322
322
|
|
|
323
323
|
|
|
324
|
-
def prefixspan_subworkflow_with_text(
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
prefix_list: list[str] = None) -> PrefixSpanSubworkflowResult:
|
|
324
|
+
def prefixspan_subworkflow_with_text(all_steps: list[list[IntermediateStep]],
|
|
325
|
+
min_support: int | float = 2,
|
|
326
|
+
top_k: int = 10,
|
|
327
|
+
min_coverage: float = 0.0,
|
|
328
|
+
max_text_len: int = 700,
|
|
329
|
+
prefix_list: list[str] = None) -> PrefixSpanSubworkflowResult:
|
|
331
330
|
"""
|
|
332
331
|
1) Build sequences of calls for each example (with llm_text_input).
|
|
333
332
|
2) Convert to token lists, run PrefixSpan with min_support.
|
|
@@ -66,7 +66,7 @@ def compute_inter_query_token_uniqueness_by_llm(all_steps: list[list[Intermediat
|
|
|
66
66
|
# 2) Group by (llm_name, example_number), then sort each group
|
|
67
67
|
grouped = cdf.groupby(['llm_name', 'example_number'], as_index=False, group_keys=True)
|
|
68
68
|
|
|
69
|
-
for (llm, ex_num), group_df in grouped:
|
|
69
|
+
for (llm, ex_num), group_df in grouped:
|
|
70
70
|
# Sort by event_timestamp
|
|
71
71
|
group_df = group_df.sort_values('event_timestamp', ascending=True)
|
|
72
72
|
|
nat/profiler/profile_runner.py
CHANGED
|
@@ -88,14 +88,19 @@ class ProfilerRunner:
|
|
|
88
88
|
writes out combined requests JSON, then computes and saves additional metrics,
|
|
89
89
|
and optionally fits a forecasting model.
|
|
90
90
|
"""
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
from nat.profiler.inference_optimization.
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
91
|
+
# Yapf and ruff disagree on how to format long imports, disable yapf go with ruff
|
|
92
|
+
from nat.profiler.inference_optimization.bottleneck_analysis.nested_stack_analysis import (
|
|
93
|
+
multi_example_call_profiling,
|
|
94
|
+
) # yapf: disable
|
|
95
|
+
from nat.profiler.inference_optimization.bottleneck_analysis.simple_stack_analysis import (
|
|
96
|
+
profile_workflow_bottlenecks,
|
|
97
|
+
) # yapf: disable
|
|
98
|
+
from nat.profiler.inference_optimization.experimental.concurrency_spike_analysis import (
|
|
99
|
+
concurrency_spike_analysis,
|
|
100
|
+
) # yapf: disable
|
|
101
|
+
from nat.profiler.inference_optimization.experimental.prefix_span_analysis import (
|
|
102
|
+
prefixspan_subworkflow_with_text,
|
|
103
|
+
) # yapf: disable
|
|
99
104
|
from nat.profiler.inference_optimization.llm_metrics import LLMMetrics
|
|
100
105
|
from nat.profiler.inference_optimization.prompt_caching import get_common_prefixes
|
|
101
106
|
from nat.profiler.inference_optimization.token_uniqueness import compute_inter_query_token_uniqueness_by_llm
|
|
@@ -44,13 +44,12 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
44
44
|
https://github.com/pypiserver/pypiserver
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
|
-
def __init__(
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
search_route: str = ""):
|
|
47
|
+
def __init__(self,
|
|
48
|
+
endpoint: str,
|
|
49
|
+
token: str | None = None,
|
|
50
|
+
publish_route: str = "",
|
|
51
|
+
pull_route: str = "",
|
|
52
|
+
search_route: str = ""):
|
|
54
53
|
super().__init__()
|
|
55
54
|
self._endpoint = endpoint.rstrip("/")
|
|
56
55
|
self._token = token
|
|
@@ -126,17 +125,16 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
126
125
|
|
|
127
126
|
versioned_packages_str = " ".join(versioned_packages)
|
|
128
127
|
|
|
129
|
-
result = subprocess.run(
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
check=True)
|
|
128
|
+
result = subprocess.run([
|
|
129
|
+
"uv",
|
|
130
|
+
"pip",
|
|
131
|
+
"install",
|
|
132
|
+
"--prerelease=allow",
|
|
133
|
+
"--index-url",
|
|
134
|
+
f"{self._endpoint}/{self._pull_route}/",
|
|
135
|
+
versioned_packages_str
|
|
136
|
+
],
|
|
137
|
+
check=True)
|
|
140
138
|
|
|
141
139
|
result.check_returncode()
|
|
142
140
|
|
|
@@ -171,11 +169,10 @@ class PypiRegistryHandler(AbstractRegistryHandler):
|
|
|
171
169
|
"""
|
|
172
170
|
|
|
173
171
|
try:
|
|
174
|
-
completed_process = subprocess.run(
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
check=True)
|
|
172
|
+
completed_process = subprocess.run(["pip", "search", "--index", f"{self._endpoint}", query.query],
|
|
173
|
+
text=True,
|
|
174
|
+
capture_output=True,
|
|
175
|
+
check=True)
|
|
179
176
|
search_response_list = []
|
|
180
177
|
search_results = completed_process.stdout
|
|
181
178
|
package_results = search_results.split("\n")
|
|
@@ -13,9 +13,8 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint: disable=unused-import
|
|
17
16
|
# flake8: noqa
|
|
18
17
|
|
|
19
|
-
from .local import register_local
|
|
20
|
-
from .pypi import register_pypi
|
|
21
|
-
from .rest import register_rest
|
|
18
|
+
from .local import register_local
|
|
19
|
+
from .pypi import register_pypi
|
|
20
|
+
from .rest import register_rest
|
|
@@ -42,15 +42,14 @@ logger = logging.getLogger(__name__)
|
|
|
42
42
|
class RestRegistryHandler(AbstractRegistryHandler):
|
|
43
43
|
"""A registry handler for interactions with a remote REST registry."""
|
|
44
44
|
|
|
45
|
-
def __init__(
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
remove_route: str = ""):
|
|
45
|
+
def __init__(self,
|
|
46
|
+
endpoint: str,
|
|
47
|
+
token: str,
|
|
48
|
+
timeout: int = 30,
|
|
49
|
+
publish_route: str = "",
|
|
50
|
+
pull_route: str = "",
|
|
51
|
+
search_route: str = "",
|
|
52
|
+
remove_route: str = ""):
|
|
54
53
|
super().__init__()
|
|
55
54
|
self._endpoint = endpoint.rstrip("/")
|
|
56
55
|
self._timeout = timeout
|
nat/retriever/register.py
CHANGED
nat/runtime/session.py
CHANGED
|
@@ -21,7 +21,9 @@ from collections.abc import Callable
|
|
|
21
21
|
from contextlib import asynccontextmanager
|
|
22
22
|
from contextlib import nullcontext
|
|
23
23
|
|
|
24
|
+
from fastapi import WebSocket
|
|
24
25
|
from starlette.requests import HTTPConnection
|
|
26
|
+
from starlette.requests import Request
|
|
25
27
|
|
|
26
28
|
from nat.builder.context import Context
|
|
27
29
|
from nat.builder.context import ContextState
|
|
@@ -89,7 +91,8 @@ class SessionManager:
|
|
|
89
91
|
@asynccontextmanager
|
|
90
92
|
async def session(self,
|
|
91
93
|
user_manager=None,
|
|
92
|
-
|
|
94
|
+
http_connection: HTTPConnection | None = None,
|
|
95
|
+
user_message_id: str | None = None,
|
|
93
96
|
conversation_id: str | None = None,
|
|
94
97
|
user_input_callback: Callable[[InteractionPrompt], Awaitable[HumanResponse]] = None,
|
|
95
98
|
user_authentication_callback: Callable[[AuthProviderBaseConfig, AuthFlowType],
|
|
@@ -107,10 +110,11 @@ class SessionManager:
|
|
|
107
110
|
if user_authentication_callback is not None:
|
|
108
111
|
token_user_authentication = self._context_state.user_auth_callback.set(user_authentication_callback)
|
|
109
112
|
|
|
110
|
-
if
|
|
111
|
-
self.
|
|
113
|
+
if isinstance(http_connection, WebSocket):
|
|
114
|
+
self.set_metadata_from_websocket(user_message_id, conversation_id)
|
|
112
115
|
|
|
113
|
-
|
|
116
|
+
if isinstance(http_connection, Request):
|
|
117
|
+
self.set_metadata_from_http_request(http_connection)
|
|
114
118
|
|
|
115
119
|
try:
|
|
116
120
|
yield self
|
|
@@ -135,14 +139,11 @@ class SessionManager:
|
|
|
135
139
|
async with self._workflow.run(message) as runner:
|
|
136
140
|
yield runner
|
|
137
141
|
|
|
138
|
-
def set_metadata_from_http_request(self, request:
|
|
142
|
+
def set_metadata_from_http_request(self, request: Request) -> None:
|
|
139
143
|
"""
|
|
140
144
|
Extracts and sets user metadata request attributes from a HTTP request.
|
|
141
145
|
If request is None, no attributes are set.
|
|
142
146
|
"""
|
|
143
|
-
if request is None:
|
|
144
|
-
return
|
|
145
|
-
|
|
146
147
|
self._context.metadata._request.method = getattr(request, "method", None)
|
|
147
148
|
self._context.metadata._request.url_path = request.url.path
|
|
148
149
|
self._context.metadata._request.url_port = request.url.port
|
|
@@ -157,6 +158,20 @@ class SessionManager:
|
|
|
157
158
|
if request.headers.get("conversation-id"):
|
|
158
159
|
self._context_state.conversation_id.set(request.headers["conversation-id"])
|
|
159
160
|
|
|
161
|
+
if request.headers.get("user-message-id"):
|
|
162
|
+
self._context_state.user_message_id.set(request.headers["user-message-id"])
|
|
163
|
+
|
|
164
|
+
def set_metadata_from_websocket(self, user_message_id: str | None, conversation_id: str | None) -> None:
|
|
165
|
+
"""
|
|
166
|
+
Extracts and sets user metadata for Websocket connections.
|
|
167
|
+
"""
|
|
168
|
+
|
|
169
|
+
if conversation_id is not None:
|
|
170
|
+
self._context_state.conversation_id.set(conversation_id)
|
|
171
|
+
|
|
172
|
+
if user_message_id is not None:
|
|
173
|
+
self._context_state.user_message_id.set(user_message_id)
|
|
174
|
+
|
|
160
175
|
|
|
161
176
|
# Compatibility aliases with previous releases
|
|
162
177
|
AIQSessionManager = SessionManager
|
nat/settings/global_settings.py
CHANGED
|
@@ -124,7 +124,6 @@ class Settings(HashableBaseModel):
|
|
|
124
124
|
if (short_names[key.local_name] == 1):
|
|
125
125
|
type_list.append((key.local_name, key.config_type))
|
|
126
126
|
|
|
127
|
-
# pylint: disable=consider-alternative-union-syntax
|
|
128
127
|
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
129
128
|
|
|
130
129
|
RegistryHandlerAnnotation = dict[
|
|
@@ -127,7 +127,7 @@ def execute_code_subprocess(generated_code: str, queue):
|
|
|
127
127
|
stderr_capture = StringIO()
|
|
128
128
|
try:
|
|
129
129
|
with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
|
|
130
|
-
exec(generated_code, {})
|
|
130
|
+
exec(generated_code, {})
|
|
131
131
|
logger.debug("execute_code_subprocess finished, PID: %s", os.getpid())
|
|
132
132
|
queue.put(CodeExecutionResult(stdout=stdout_capture.getvalue(), stderr=stderr_capture.getvalue()))
|
|
133
133
|
except Exception as e:
|
nat/tool/document_search.py
CHANGED
|
@@ -53,7 +53,7 @@ async def document_search(config: MilvusDocumentSearchToolConfig, builder: Build
|
|
|
53
53
|
from langchain_core.messages import HumanMessage
|
|
54
54
|
from langchain_core.messages import SystemMessage
|
|
55
55
|
from langchain_core.pydantic_v1 import BaseModel
|
|
56
|
-
from langchain_core.pydantic_v1 import Field
|
|
56
|
+
from langchain_core.pydantic_v1 import Field
|
|
57
57
|
|
|
58
58
|
# define collection store
|
|
59
59
|
# create a list of tuples using enumerate()
|
nat/tool/mcp/mcp_tool.py
CHANGED
|
@@ -48,7 +48,7 @@ class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
|
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
@register_function(config_type=MCPToolConfig)
|
|
51
|
-
async def mcp_tool(config: MCPToolConfig, builder: Builder):
|
|
51
|
+
async def mcp_tool(config: MCPToolConfig, builder: Builder):
|
|
52
52
|
"""
|
|
53
53
|
Generate a NAT Function that wraps a tool provided by the MCP server.
|
|
54
54
|
"""
|
nat/tool/register.py
CHANGED
|
@@ -21,7 +21,7 @@ from ..exception_handlers.schemas import yaml_exception_handler
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
@schema_exception_handler
|
|
24
|
-
def validate_schema(metadata, Schema):
|
|
24
|
+
def validate_schema(metadata, Schema):
|
|
25
25
|
|
|
26
26
|
try:
|
|
27
27
|
return Schema(**metadata)
|
|
@@ -31,7 +31,7 @@ def validate_schema(metadata, Schema): # pylint: disable=invalid-name
|
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
@yaml_exception_handler
|
|
34
|
-
def validate_yaml(ctx, param, value):
|
|
34
|
+
def validate_yaml(ctx, param, value):
|
|
35
35
|
"""
|
|
36
36
|
Validate that the file is a valid YAML file
|
|
37
37
|
|
|
@@ -26,8 +26,6 @@ from collections.abc import Sequence
|
|
|
26
26
|
from typing import Any
|
|
27
27
|
from typing import TypeVar
|
|
28
28
|
|
|
29
|
-
# pylint: disable=inconsistent-return-statements
|
|
30
|
-
|
|
31
29
|
T = TypeVar("T")
|
|
32
30
|
Exc = tuple[type[BaseException], ...] # exception classes
|
|
33
31
|
CodePattern = int | str | range # for retry_codes argument
|
|
@@ -21,7 +21,7 @@ from pydantic import ValidationError
|
|
|
21
21
|
logger = logging.getLogger(__name__)
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def schema_exception_handler(func, **kwargs):
|
|
24
|
+
def schema_exception_handler(func, **kwargs):
|
|
25
25
|
"""
|
|
26
26
|
A decorator that handles `ValidationError` exceptions for schema validation functions.
|
|
27
27
|
|
|
@@ -25,8 +25,8 @@ from nat.utils.reactive.subscription import Subscription
|
|
|
25
25
|
|
|
26
26
|
# Covariant type param: An Observable producing type X can also produce
|
|
27
27
|
# a subtype of X.
|
|
28
|
-
_T_out_co = TypeVar("_T_out_co", covariant=True)
|
|
29
|
-
_T = TypeVar("_T")
|
|
28
|
+
_T_out_co = TypeVar("_T_out_co", covariant=True)
|
|
29
|
+
_T = TypeVar("_T")
|
|
30
30
|
|
|
31
31
|
OnNext = Callable[[_T], None]
|
|
32
32
|
OnError = Callable[[Exception], None]
|
|
@@ -20,7 +20,7 @@ from typing import TypeVar
|
|
|
20
20
|
|
|
21
21
|
# Contravariant type param: An Observer that can accept type X can also
|
|
22
22
|
# accept any supertype of X.
|
|
23
|
-
_T_in_contra = TypeVar("_T_in_contra", contravariant=True)
|
|
23
|
+
_T_in_contra = TypeVar("_T_in_contra", contravariant=True)
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class ObserverBase(Generic[_T_in_contra], ABC):
|
nat/utils/reactive/observable.py
CHANGED
|
@@ -24,8 +24,8 @@ from nat.utils.type_utils import override
|
|
|
24
24
|
|
|
25
25
|
# Covariant type param: An Observable producing type X can also produce
|
|
26
26
|
# a subtype of X.
|
|
27
|
-
_T_out_co = TypeVar("_T_out_co", covariant=True)
|
|
28
|
-
_T = TypeVar("_T")
|
|
27
|
+
_T_out_co = TypeVar("_T_out_co", covariant=True)
|
|
28
|
+
_T = TypeVar("_T")
|
|
29
29
|
|
|
30
30
|
OnNext = Callable[[_T], None]
|
|
31
31
|
OnError = Callable[[Exception], None]
|
nat/utils/reactive/observer.py
CHANGED
|
@@ -23,8 +23,8 @@ logger = logging.getLogger(__name__)
|
|
|
23
23
|
|
|
24
24
|
# Contravariant type param: An Observer that can accept type X can also
|
|
25
25
|
# accept any supertype of X.
|
|
26
|
-
_T_in_contra = TypeVar("_T_in_contra", contravariant=True)
|
|
27
|
-
_T = TypeVar("_T")
|
|
26
|
+
_T_in_contra = TypeVar("_T_in_contra", contravariant=True)
|
|
27
|
+
_T = TypeVar("_T")
|
|
28
28
|
|
|
29
29
|
OnNext = Callable[[_T], None]
|
|
30
30
|
OnError = Callable[[Exception], None]
|
|
@@ -21,7 +21,7 @@ from typing import TypeVar
|
|
|
21
21
|
if typing.TYPE_CHECKING:
|
|
22
22
|
from nat.utils.reactive.base.subject_base import SubjectBase
|
|
23
23
|
|
|
24
|
-
_T = TypeVar("_T")
|
|
24
|
+
_T = TypeVar("_T")
|
|
25
25
|
|
|
26
26
|
OnNext = Callable[[_T], None]
|
|
27
27
|
OnError = Callable[[Exception], None]
|