nvidia-nat 1.3.0a20250822__py3-none-any.whl → 1.3.0a20250824__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +0 -1
- nat/agent/react_agent/agent.py +21 -3
- nat/agent/react_agent/register.py +1 -1
- nat/agent/register.py +0 -1
- nat/agent/rewoo_agent/agent.py +0 -1
- nat/agent/rewoo_agent/register.py +1 -1
- nat/agent/tool_calling_agent/agent.py +0 -1
- nat/agent/tool_calling_agent/register.py +1 -1
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +1 -1
- nat/builder/context.py +9 -1
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +5 -7
- nat/builder/workflow_builder.py +0 -1
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/info/list_mcp.py +3 -4
- nat/cli/commands/registry/search.py +14 -16
- nat/cli/commands/start.py +0 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +0 -1
- nat/cli/type_registry.py +3 -5
- nat/data_models/config.py +1 -1
- nat/data_models/evaluate.py +1 -1
- nat/data_models/function_dependencies.py +6 -6
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/model_gated_field_mixin.py +125 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +36 -0
- nat/data_models/top_p_mixin.py +36 -0
- nat/embedder/register.py +0 -1
- nat/eval/dataset_handler/dataset_handler.py +5 -6
- nat/eval/evaluate.py +7 -8
- nat/eval/rag_evaluator/register.py +2 -2
- nat/eval/register.py +0 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
- nat/eval/utils/weave_eval.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +3 -2
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/register.py +0 -1
- nat/llm/aws_bedrock_llm.py +3 -3
- nat/llm/azure_openai_llm.py +3 -4
- nat/llm/nim_llm.py +4 -4
- nat/llm/openai_llm.py +4 -4
- nat/llm/register.py +0 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/object_store/register.py +0 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/register.py +3 -3
- nat/profiler/callbacks/langchain_callback_handler.py +1 -1
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +1 -4
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/profile_runner.py +13 -8
- nat/registry_handlers/package_utils.py +0 -1
- nat/registry_handlers/pypi/pypi_handler.py +20 -23
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +8 -9
- nat/retriever/register.py +0 -1
- nat/runtime/session.py +23 -8
- nat/settings/global_settings.py +0 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/document_search.py +1 -1
- nat/tool/mcp/mcp_tool.py +1 -1
- nat/tool/register.py +0 -1
- nat/utils/data_models/schema_validator.py +2 -2
- nat/utils/exception_handlers/automatic_retries.py +0 -2
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +2 -2
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +4 -6
- nat/utils/type_utils.py +4 -4
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/METADATA +1 -1
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/RECORD +94 -91
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250824.dist-info}/top_level.txt +0 -0
nat/data_models/temperature_mixin.py
ADDED
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
+
+_UNSUPPORTED_TEMPERATURE_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
+
+
+class TemperatureMixin(
+        BaseModel,
+        ModelGatedFieldMixin[float],
+        field_name="temperature",
+        default_if_supported=0.0,
+        unsupported_models=_UNSUPPORTED_TEMPERATURE_MODELS,
+):
+    """
+    Mixin class for temperature configuration.
+    """
+    temperature: float | None = Field(default=None, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
nat/data_models/top_p_mixin.py
ADDED
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
+
+_UNSUPPORTED_TOP_P_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
+
+
+class TopPMixin(
+        BaseModel,
+        ModelGatedFieldMixin[float],
+        field_name="top_p",
+        default_if_supported=1.0,
+        unsupported_models=_UNSUPPORTED_TOP_P_MODELS,
+):
+    """
+    Mixin class for top-p configuration.
+    """
+    top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="Top-p for distribution sampling.")
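Both mixins are thin wrappers over the new ModelGatedFieldMixin (added in nat/data_models/model_gated_field_mixin.py, +125 lines, not expanded above). The class parameters imply the gating: when the configured model does not match an unsupported_models pattern, an unset field is filled with default_if_supported; when it does match (for example gpt-5), the field stays unset. A minimal standalone sketch of that implied behavior, assuming a model_name field and a validator-based implementation (GatedTemperature and _gate are illustrative names, not package API):

import re

from pydantic import BaseModel, ConfigDict, Field, model_validator

_UNSUPPORTED = (re.compile(r"gpt-?5", re.IGNORECASE), )


class GatedTemperature(BaseModel):
    """Hypothetical stand-in for TemperatureMixin + ModelGatedFieldMixin."""
    model_config = ConfigDict(protected_namespaces=())

    model_name: str
    temperature: float | None = Field(default=None, ge=0.0, le=1.0)

    @model_validator(mode="after")
    def _gate(self):
        supported = not any(p.search(self.model_name) for p in _UNSUPPORTED)
        if self.temperature is None and supported:
            self.temperature = 0.0  # default_if_supported
        elif self.temperature is not None and not supported:
            # Assumed behavior: reject an explicit value for a gated model
            raise ValueError(f"temperature is not supported by {self.model_name}")
        return self


assert GatedTemperature(model_name="meta/llama-3.1-8b-instruct").temperature == 0.0
assert GatedTemperature(model_name="gpt-5").temperature is None

Whether the real mixin raises on an explicitly set unsupported value, warns, or silently drops it is not visible in this diff.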
nat/embedder/register.py
CHANGED
nat/eval/dataset_handler/dataset_handler.py
CHANGED
@@ -146,13 +146,12 @@ class DatasetHandler:
             # When num_passes is specified, always use concurrency * num_passes
             # This respects the user's intent for exact number of passes
             target_size = self.concurrency * self.num_passes
+        # When num_passes = 0, use the largest multiple of concurrency <= original_size
+        # If original_size < concurrency, we need at least concurrency rows
+        elif original_size >= self.concurrency:
+            target_size = (original_size // self.concurrency) * self.concurrency
         else:
-            # When num_passes = 0, use the largest multiple of concurrency <= original_size
-            # If original_size < concurrency, we need at least concurrency rows
-            if original_size >= self.concurrency:
-                target_size = (original_size // self.concurrency) * self.concurrency
-            else:
-                target_size = self.concurrency
+            target_size = self.concurrency
 
         if target_size == 0:
             raise ValueError("Input dataset too small for even one batch at given concurrency.")
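Restating the flattened branch structure as a standalone function makes the batching arithmetic easy to check (the function and the values below are illustrative, not package API):

def target_size(original_size: int, concurrency: int, num_passes: int) -> int:
    """Restatement of DatasetHandler's target-size rule, for illustration."""
    if num_passes:
        # Exact number of passes requested by the user
        return concurrency * num_passes
    if original_size >= concurrency:
        # Largest multiple of concurrency that fits in the dataset
        return (original_size // concurrency) * concurrency
    # Too few rows for one batch: pad up to a single full batch
    return concurrency


assert target_size(10, 4, 0) == 8   # 10 rows trimmed to two full batches of 4
assert target_size(3, 4, 0) == 4    # 3 rows padded to one full batch of 4
assert target_size(10, 4, 2) == 8   # 2 passes at concurrency 4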
nat/eval/evaluate.py
CHANGED
@@ -42,7 +42,7 @@ from nat.runtime.session import SessionManager
 logger = logging.getLogger(__name__)
 
 
-class EvaluationRun:  # pylint: disable=too-many-public-methods
+class EvaluationRun:
     """
     Instantiated for each evaluation run and used to store data for that single run.
 
@@ -319,7 +319,7 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         except Exception as e:
             logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e, exc_info=True)
 
-    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
+    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
         workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
         workflow_output_file.parent.mkdir(parents=True, exist_ok=True)
 
@@ -511,12 +511,11 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         with self.eval_trace_context.evaluation_context():
             if self.config.endpoint:
                 await self.run_workflow_remote()
-
-            if not self.config.skip_workflow:
-                if session_manager is None:
-                    session_manager = SessionManager(eval_workflow.build(),
-                                                     max_concurrency=self.eval_config.general.max_concurrency)
-                await self.run_workflow_local(session_manager)
+            elif not self.config.skip_workflow:
+                if session_manager is None:
+                    session_manager = SessionManager(eval_workflow.build(),
+                                                     max_concurrency=self.eval_config.general.max_concurrency)
+                await self.run_workflow_local(session_manager)
 
             # Evaluate
             evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
nat/eval/rag_evaluator/register.py
CHANGED
@@ -73,7 +73,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return self.metric
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.keys()))
+            return next(iter(self.metric.keys()))
         return ""
 
     @property
@@ -82,7 +82,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return RagasMetricConfig()  # Default config when only a metric name is given
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.values()))
+            return next(iter(self.metric.values()))
         return RagasMetricConfig()  # Default config when an invalid type is provided
 
 
nat/eval/register.py
CHANGED
nat/eval/tunable_rag_evaluator/evaluate.py
CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 import logging
 from typing import Callable
 
@@ -23,7 +22,6 @@ from langchain.schema import HumanMessage
 from langchain.schema import SystemMessage
 from langchain_core.language_models import BaseChatModel
 from langchain_core.runnables import RunnableLambda
-from tqdm import tqdm
 
 from nat.eval.evaluator.base_evaluator import BaseEvaluator
 from nat.eval.evaluator.evaluator_model import EvalInputItem
@@ -31,7 +29,6 @@ from nat.eval.evaluator.evaluator_model import EvalOutputItem
 
 logger = logging.getLogger(__name__)
 
-# pylint: disable=line-too-long
 # flake8: noqa: E501
 
 
nat/eval/utils/weave_eval.py
CHANGED
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class WeaveEvaluationIntegration:  # pylint: disable=too-many-public-methods
+class WeaveEvaluationIntegration:
     """
    Class to handle all Weave integration functionality.
     """
@@ -47,8 +47,8 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
             from weave.flow.eval_imperative import EvaluationLogger
             from weave.flow.eval_imperative import ScoreLogger
             from weave.trace.context import weave_client_context
-            self.evaluation_logger_cls = EvaluationLogger
-            self.score_logger_cls = ScoreLogger
+            self.evaluation_logger_cls = EvaluationLogger
+            self.score_logger_cls = ScoreLogger
             self.weave_client_context = weave_client_context
             self.available = True
         except ImportError:
nat/experimental/test_time_compute/models/strategy_base.py
CHANGED
@@ -17,9 +17,10 @@ from abc import ABC
 from abc import abstractmethod
 
 from nat.builder.builder import Builder
-from nat.experimental.test_time_compute.models.ttc_item import TTCItem
-from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum, PipelineTypeEnum
 from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
+from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
+from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
+from nat.experimental.test_time_compute.models.ttc_item import TTCItem
 
 
 class StrategyBase(ABC):
nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py
CHANGED
@@ -135,8 +135,6 @@ class LLMBasedOutputMergingSelector(StrategyBase):
         except Exception as e:
             logger.error(f"Error parsing merged output: {e}")
             raise ValueError("Failed to parse merged output.")
-        else:
-            merged_output = merged_output
 
         logger.info("Merged output: %s", str(merged_output))
 
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py
CHANGED
@@ -307,7 +307,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
             """Handle evaluation requests."""
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 # if job_id is present and already exists return the job info
                 if request.job_id:
@@ -336,7 +336,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of an evaluation job."""
             logger.info("Getting status for job %s", job_id)
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_job(job_id)
                 if not job:
@@ -349,7 +349,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of the last created evaluation job."""
             logger.info("Getting last job status")
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_last_job()
                 if not job:
@@ -361,7 +361,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         async def get_jobs(http_request: Request, status: str | None = None) -> list[EvaluateStatusResponse]:
             """Get all jobs, optionally filtered by status."""
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 if status is None:
                     logger.info("Getting all jobs")
@@ -522,9 +522,9 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         workflow = session_manager.workflow
 
-        GenerateBodyType = workflow.input_schema
-        GenerateStreamResponseType = workflow.streaming_output_schema
-        GenerateSingleResponseType = workflow.single_output_schema
+        GenerateBodyType = workflow.input_schema
+        GenerateStreamResponseType = workflow.streaming_output_schema
+        GenerateSingleResponseType = workflow.single_output_schema
 
         # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
         # Consider prefixing these with "nat_" to avoid conflicts
@@ -572,7 +572,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             response.headers["Content-Type"] = "application/json"
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return await generate_single_response(None, session_manager, result_type=result_type)
@@ -583,7 +583,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         async def get_stream(request: Request):
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -618,7 +618,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             response.headers["Content-Type"] = "application/json"
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return await generate_single_response(payload, session_manager, result_type=result_type)
@@ -632,7 +632,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         async def post_stream(request: Request, payload: request_type):
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -677,7 +677,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             # Check if streaming is requested
             stream_requested = getattr(payload, 'stream', False)
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request):
                 if stream_requested:
                     # Return streaming response
                     return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -688,42 +688,41 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                              step_adaptor=self.get_step_adaptor(),
                                              result_type=ChatResponseChunk,
                                              output_type=ChatResponseChunk))
-
-
-
+
+            # Return single response - check if workflow supports non-streaming
+            try:
+                response.headers["Content-Type"] = "application/json"
+                return await generate_single_response(payload, session_manager, result_type=ChatResponse)
+            except ValueError as e:
+                if "Cannot get a single output value for streaming workflows" in str(e):
+                    # Workflow only supports streaming, but client requested non-streaming
+                    # Fall back to streaming and collect the result
+                    chunks = []
+                    async for chunk_str in generate_streaming_response_as_str(
+                            payload,
+                            session_manager=session_manager,
+                            streaming=True,
+                            step_adaptor=self.get_step_adaptor(),
+                            result_type=ChatResponseChunk,
+                            output_type=ChatResponseChunk):
+                        if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
+                            chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
+                            if chunk_data:
+                                try:
+                                    chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
+                                    if (chunk_json.choices and len(chunk_json.choices) > 0
+                                            and chunk_json.choices[0].delta
+                                            and chunk_json.choices[0].delta.content is not None):
+                                        chunks.append(chunk_json.choices[0].delta.content)
+                                except Exception:
+                                    continue
+
+                    # Create a single response from collected chunks
+                    content = "".join(chunks)
+                    single_response = ChatResponse.from_string(content)
                     response.headers["Content-Type"] = "application/json"
-                return
-
-                if "Cannot get a single output value for streaming workflows" in str(e):
-                    # Workflow only supports streaming, but client requested non-streaming
-                    # Fall back to streaming and collect the result
-                    chunks = []
-                    async for chunk_str in generate_streaming_response_as_str(
-                            payload,
-                            session_manager=session_manager,
-                            streaming=True,
-                            step_adaptor=self.get_step_adaptor(),
-                            result_type=ChatResponseChunk,
-                            output_type=ChatResponseChunk):
-                        if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
-                            chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
-                            if chunk_data:
-                                try:
-                                    chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
-                                    if (chunk_json.choices and len(chunk_json.choices) > 0
-                                            and chunk_json.choices[0].delta
-                                            and chunk_json.choices[0].delta.content is not None):
-                                        chunks.append(chunk_json.choices[0].delta.content)
-                                except Exception:
-                                    continue
-
-                    # Create a single response from collected chunks
-                    content = "".join(chunks)
-                    single_response = ChatResponse.from_string(content)
-                    response.headers["Content-Type"] = "application/json"
-                    return single_response
-                else:
-                    raise
+                    return single_response
+                raise
 
         return post_openai_api_compatible
 
@@ -758,7 +757,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                  http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
             """Handle async generation requests."""
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 # if job_id is present and already exists return the job info
                 if request.job_id:
@@ -804,7 +803,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of an async job."""
             logger.info("Getting status for job %s", job_id)
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_job(job_id)
                 if not job:
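The fallback branch above re-parses its own server-sent-event stream to satisfy a non-streaming client of a streaming-only workflow. Stripped of FastAPI and the pydantic chunk models, the collection logic reduces to the sketch below (collect_sse_content is an illustrative name; the real handler validates each chunk with ChatResponseChunk.model_validate_json rather than plain json):

import json


def collect_sse_content(chunk_strs: list[str]) -> str:
    """Collect delta content from OpenAI-style SSE lines, as the fallback above does."""
    chunks: list[str] = []
    for chunk_str in chunk_strs:
        if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
            chunk_data = chunk_str[6:].strip()  # drop the "data: " prefix
            if not chunk_data:
                continue
            try:
                delta = json.loads(chunk_data)["choices"][0]["delta"]
                if delta.get("content") is not None:
                    chunks.append(delta["content"])
            except Exception:
                continue  # skip malformed chunks, mirroring the handler
    return "".join(chunks)


assert collect_sse_content([
    'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    'data: {"choices": [{"delta": {"content": "lo"}}]}',
    "data: [DONE]",
]) == "Hello"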
nat/front_ends/fastapi/message_handler.py
CHANGED
@@ -86,7 +86,7 @@ class WebSocketMessageHandler:
 
     async def __aexit__(self, exc_type, exc_value, traceback) -> None:
 
-        # TODO: Handle the exit
+        # TODO: Handle the exit
         pass
 
     async def run(self) -> None:
@@ -105,12 +105,10 @@ class WebSocketMessageHandler:
             if (isinstance(validated_message, WebSocketUserMessage)):
                 await self.process_workflow_request(validated_message)
 
-            elif isinstance(
-
-
-
-                            WebSocketSystemIntermediateStepMessage,
-                            WebSocketSystemInteractionMessage)):
+            elif isinstance(validated_message,
+                            (WebSocketSystemResponseTokenMessage,
+                             WebSocketSystemIntermediateStepMessage,
+                             WebSocketSystemInteractionMessage)):
                 # These messages are already handled by self.create_websocket_message(data_model=value, …)
                 # No further processing is needed here.
                 pass
@@ -119,11 +117,9 @@ class WebSocketMessageHandler:
                     user_content = await self.process_user_message_content(validated_message)
                     self._user_interaction_response.set_result(user_content)
             except (asyncio.CancelledError, WebSocketDisconnect):
-                # TODO: Handle the disconnect
+                # TODO: Handle the disconnect
                 break
 
-        return None
-
     async def process_user_message_content(
             self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
         """
@@ -162,12 +158,13 @@ class WebSocketMessageHandler:
 
         if isinstance(content, TextContent) and (self._running_workflow_task is None):
 
-            def _done_callback(task: asyncio.Task):
+            def _done_callback(task: asyncio.Task):
                 self._running_workflow_task = None
 
             self._running_workflow_task = asyncio.create_task(
-                self._run_workflow(content.text,
-                                   self.
+                self._run_workflow(payload=content.text,
+                                   user_message_id=self._message_parent_id,
+                                   conversation_id=self._conversation_id,
                                    result_type=self._schema_output_mapping[self._workflow_schema_type],
                                    output_type=self._schema_output_mapping[
                                        self._workflow_schema_type])).add_done_callback(_done_callback)
@@ -290,14 +287,16 @@ class WebSocketMessageHandler:
 
     async def _run_workflow(self,
                             payload: typing.Any,
+                            user_message_id: str | None = None,
                             conversation_id: str | None = None,
                             result_type: type | None = None,
                             output_type: type | None = None) -> None:
 
         try:
             async with self._session_manager.session(
+                    user_message_id=user_message_id,
                     conversation_id=conversation_id,
-
+                    http_connection=self._socket,
                     user_input_callback=self.human_interaction_callback,
                     user_authentication_callback=(self._flow_handler.authenticate
                                                   if self._flow_handler else None)) as session:
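The handler above admits only one in-flight workflow per websocket: the task handle is stored in _running_workflow_task, and a done-callback clears the slot when the task finishes. That guard pattern, reduced to a self-contained sketch with illustrative names:

import asyncio


class SingleFlight:
    """Run at most one task at a time; later submissions are refused until it finishes."""

    def __init__(self):
        self._running: asyncio.Task | None = None

    def submit(self, coro) -> bool:
        """Start coro only if nothing is running; returns True if it was started."""
        if self._running is not None:
            return False

        def _done_callback(task: asyncio.Task):
            # Clear the slot so the next message can start a workflow
            self._running = None

        task = asyncio.create_task(coro)
        task.add_done_callback(_done_callback)
        self._running = task
        return True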
nat/front_ends/fastapi/message_validator.py
CHANGED
@@ -232,7 +232,7 @@ class MessageValidator:
         """
         return data_model.parent_id or "root"
 
-    async def create_system_response_token_message(
+    async def create_system_response_token_message(
             self,
             message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
                                   WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
@@ -272,7 +272,7 @@ class MessageValidator:
             logger.error("Error creating system response token message: %s", str(e), exc_info=True)
             return None
 
-    async def create_system_intermediate_step_message(
+    async def create_system_intermediate_step_message(
             self,
             message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
                 WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
@@ -311,7 +311,7 @@ class MessageValidator:
             logger.error("Error creating system intermediate step message: %s", str(e), exc_info=True)
             return None
 
-    async def create_system_interaction_message(
+    async def create_system_interaction_message(
             self,
             *,
             message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
@@ -323,7 +323,7 @@ class MessageValidator:
             content: HumanPrompt,
             status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
             timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
-    ) -> WebSocketSystemInteractionMessage | None:
+    ) -> WebSocketSystemInteractionMessage | None:
         """
         Creates a system interaction message with default values.
 
nat/front_ends/fastapi/step_adaptor.py
CHANGED
@@ -289,7 +289,7 @@ class StepAdaptor:
 
         return event
 
-    def process(self, step: IntermediateStep) -> ResponseSerializable | None:
+    def process(self, step: IntermediateStep) -> ResponseSerializable | None:
 
         # Track the chunk
         self._history.append(step)
nat/front_ends/register.py
CHANGED
nat/llm/aws_bedrock_llm.py
CHANGED
@@ -22,9 +22,10 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
 
 
-class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
+class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, name="aws_bedrock"):
     """An AWS Bedrock llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=())
@@ -33,7 +34,6 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                             serialization_alias="model",
                             description="The model name for the hosted AWS Bedrock.")
-    temperature: float = Field(default=0.0, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
     max_tokens: int | None = Field(default=1024,
                                    gt=0,
                                    description="Maximum number of tokens to generate."
@@ -52,6 +52,6 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
 
 
 @register_llm_provider(config_type=AWSBedrockModelConfig)
-async def aws_bedrock_model(llm_config: AWSBedrockModelConfig,
+async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, _builder: Builder):
 
     yield LLMProviderInfo(config=llm_config, description="A AWS Bedrock model for use with an LLM client.")
nat/llm/azure_openai_llm.py
CHANGED
@@ -22,9 +22,11 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.top_p_mixin import TopPMixin
 
 
-class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
+class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="azure_openai"):
     """An Azure OpenAI LLM provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -38,10 +40,7 @@ class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
     azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
                                   serialization_alias="azure_deployment",
                                   description="The Azure OpenAI hosted model/deployment name.")
-    temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
-    top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
     seed: int | None = Field(default=None, description="Random seed to set for generation.")
-    max_retries: int = Field(default=10, description="The max number of retries for the request.")
 
 
 @register_llm_provider(config_type=AzureOpenAIModelConfig)
nat/llm/nim_llm.py
CHANGED
@@ -23,9 +23,11 @@ from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.top_p_mixin import TopPMixin
 
 
-class NIMModelConfig(LLMBaseConfig, RetryMixin, name="nim"):
+class NIMModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="nim"):
     """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=())
@@ -35,12 +37,10 @@ class NIMModelConfig(LLMBaseConfig, RetryMixin, name="nim"):
     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                             serialization_alias="model",
                             description="The model name for the hosted NIM.")
-    temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
-    top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
     max_tokens: PositiveInt = Field(default=300, description="Maximum number of tokens to generate.")
 
 
 @register_llm_provider(config_type=NIMModelConfig)
-async def nim_model(llm_config: NIMModelConfig,
+async def nim_model(llm_config: NIMModelConfig, _builder: Builder):
 
     yield LLMProviderInfo(config=llm_config, description="A NIM model for use with an LLM client.")
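With TemperatureMixin and TopPMixin in place, temperature and top_p become optional, model-gated fields on these configs rather than hard 0.0/1.0 defaults. A hedged sketch of the caller-visible effect (direct instantiation shown purely for illustration; the exact gated-default behavior lives in ModelGatedFieldMixin, which this diff adds but does not display):

from nat.llm.nim_llm import NIMModelConfig

# Explicit values are still validated against the [0, 1] bounds from the mixins.
cfg = NIMModelConfig(model_name="meta/llama-3.1-70b-instruct",
                     temperature=0.7,
                     top_p=0.95)

# Left unset, the fields default to None; the mixin is expected to fill in
# 0.0 / 1.0 for models that support them and leave them unset for models
# matching the gpt-5 pattern (assumed semantics, not shown in this diff).
cfg_default = NIMModelConfig(model_name="meta/llama-3.1-70b-instruct")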
|