nvidia-nat 1.3a20250819__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. aiq/__init__.py +66 -0
  2. nat/agent/base.py +16 -0
  3. nat/agent/react_agent/agent.py +38 -13
  4. nat/agent/react_agent/prompt.py +4 -1
  5. nat/agent/react_agent/register.py +1 -1
  6. nat/agent/register.py +0 -1
  7. nat/agent/rewoo_agent/agent.py +6 -3
  8. nat/agent/rewoo_agent/prompt.py +3 -0
  9. nat/agent/rewoo_agent/register.py +4 -3
  10. nat/agent/tool_calling_agent/agent.py +92 -22
  11. nat/agent/tool_calling_agent/register.py +9 -13
  12. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  13. nat/authentication/register.py +0 -1
  14. nat/builder/builder.py +1 -1
  15. nat/builder/context.py +9 -1
  16. nat/builder/function_base.py +3 -3
  17. nat/builder/function_info.py +5 -7
  18. nat/builder/user_interaction_manager.py +2 -2
  19. nat/builder/workflow.py +3 -0
  20. nat/builder/workflow_builder.py +0 -1
  21. nat/cli/commands/evaluate.py +1 -1
  22. nat/cli/commands/info/list_components.py +7 -8
  23. nat/cli/commands/info/list_mcp.py +3 -4
  24. nat/cli/commands/registry/search.py +14 -16
  25. nat/cli/commands/start.py +0 -1
  26. nat/cli/commands/workflow/templates/pyproject.toml.j2 +3 -0
  27. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  28. nat/cli/commands/workflow/workflow_commands.py +0 -1
  29. nat/cli/type_registry.py +7 -9
  30. nat/data_models/config.py +1 -1
  31. nat/data_models/evaluate.py +1 -1
  32. nat/data_models/function_dependencies.py +6 -6
  33. nat/data_models/intermediate_step.py +3 -3
  34. nat/data_models/model_gated_field_mixin.py +125 -0
  35. nat/data_models/swe_bench_model.py +1 -1
  36. nat/data_models/temperature_mixin.py +36 -0
  37. nat/data_models/top_p_mixin.py +36 -0
  38. nat/embedder/azure_openai_embedder.py +46 -0
  39. nat/embedder/openai_embedder.py +1 -2
  40. nat/embedder/register.py +1 -1
  41. nat/eval/config.py +2 -0
  42. nat/eval/dataset_handler/dataset_handler.py +5 -6
  43. nat/eval/evaluate.py +64 -20
  44. nat/eval/rag_evaluator/register.py +2 -2
  45. nat/eval/register.py +0 -1
  46. nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
  47. nat/eval/utils/eval_trace_ctx.py +89 -0
  48. nat/eval/utils/weave_eval.py +14 -7
  49. nat/experimental/test_time_compute/models/strategy_base.py +3 -2
  50. nat/experimental/test_time_compute/register.py +0 -1
  51. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
  52. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
  53. nat/front_ends/fastapi/message_handler.py +13 -14
  54. nat/front_ends/fastapi/message_validator.py +4 -4
  55. nat/front_ends/fastapi/step_adaptor.py +1 -1
  56. nat/front_ends/register.py +0 -1
  57. nat/llm/aws_bedrock_llm.py +3 -3
  58. nat/llm/azure_openai_llm.py +49 -0
  59. nat/llm/nim_llm.py +4 -4
  60. nat/llm/openai_llm.py +4 -4
  61. nat/llm/register.py +1 -1
  62. nat/llm/utils/env_config_value.py +2 -3
  63. nat/meta/pypi.md +9 -9
  64. nat/object_store/models.py +2 -0
  65. nat/object_store/register.py +0 -1
  66. nat/observability/exporter/base_exporter.py +1 -1
  67. nat/observability/exporter/file_exporter.py +1 -1
  68. nat/observability/register.py +3 -3
  69. nat/profiler/callbacks/langchain_callback_handler.py +9 -2
  70. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  71. nat/profiler/data_frame_row.py +1 -1
  72. nat/profiler/decorators/framework_wrapper.py +1 -4
  73. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  74. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  75. nat/profiler/inference_optimization/data_models.py +3 -3
  76. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  77. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  78. nat/profiler/profile_runner.py +13 -8
  79. nat/registry_handlers/package_utils.py +0 -1
  80. nat/registry_handlers/pypi/pypi_handler.py +20 -23
  81. nat/registry_handlers/register.py +3 -4
  82. nat/registry_handlers/rest/rest_handler.py +8 -9
  83. nat/retriever/register.py +0 -1
  84. nat/runtime/session.py +23 -8
  85. nat/settings/global_settings.py +13 -2
  86. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  87. nat/tool/datetime_tools.py +49 -9
  88. nat/tool/document_search.py +1 -1
  89. nat/tool/mcp/mcp_tool.py +1 -1
  90. nat/tool/register.py +0 -1
  91. nat/utils/data_models/schema_validator.py +2 -2
  92. nat/utils/exception_handlers/automatic_retries.py +0 -2
  93. nat/utils/exception_handlers/schemas.py +1 -1
  94. nat/utils/reactive/base/observable_base.py +2 -2
  95. nat/utils/reactive/base/observer_base.py +1 -1
  96. nat/utils/reactive/observable.py +2 -2
  97. nat/utils/reactive/observer.py +2 -2
  98. nat/utils/reactive/subscription.py +1 -1
  99. nat/utils/settings/global_settings.py +4 -6
  100. nat/utils/type_utils.py +4 -4
  101. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +17 -15
  102. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +107 -100
  103. nvidia_nat-1.3.0a20250823.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  104. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +1 -0
  105. nvidia_nat-1.3a20250819.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
  106. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
  107. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
  108. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
nat/eval/evaluate.py CHANGED
@@ -42,7 +42,7 @@ from nat.runtime.session import SessionManager
42
42
  logger = logging.getLogger(__name__)
43
43
 
44
44
 
45
- class EvaluationRun: # pylint: disable=too-many-public-methods
45
+ class EvaluationRun:
46
46
  """
47
47
  Instantiated for each evaluation run and used to store data for that single run.
48
48
 
@@ -63,7 +63,16 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
63
63
 
64
64
  # Helpers
65
65
  self.intermediate_step_adapter: IntermediateStepAdapter = IntermediateStepAdapter()
66
- self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration()
66
+
67
+ # Create evaluation trace context
68
+ try:
69
+ from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
70
+ self.eval_trace_context = WeaveEvalTraceContext()
71
+ except Exception:
72
+ from nat.eval.utils.eval_trace_ctx import EvalTraceContext
73
+ self.eval_trace_context = EvalTraceContext()
74
+
75
+ self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration(self.eval_trace_context)
67
76
  # Metadata
68
77
  self.eval_input: EvalInput | None = None
69
78
  self.workflow_interrupted: bool = False
@@ -310,7 +319,7 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
310
319
  except Exception as e:
311
320
  logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e, exc_info=True)
312
321
 
313
- def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults): # pylint: disable=unused-argument # noqa: E501
322
+ def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
314
323
  workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
315
324
  workflow_output_file.parent.mkdir(parents=True, exist_ok=True)
316
325
 
@@ -401,6 +410,33 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
401
410
 
402
411
  return workflow_type
403
412
 
413
+ async def wait_for_all_export_tasks_local(self, session_manager: SessionManager, timeout: float) -> None:
414
+ """Wait for all trace export tasks to complete for local workflows.
415
+
416
+ This only works for local workflows where we have direct access to the
417
+ SessionManager and its underlying workflow with exporter manager.
418
+ """
419
+ try:
420
+ workflow = session_manager.workflow
421
+ all_exporters = await workflow.get_all_exporters()
422
+ if not all_exporters:
423
+ logger.debug("No exporters to wait for")
424
+ return
425
+
426
+ logger.info("Waiting for export tasks from %d local exporters (timeout: %ds)", len(all_exporters), timeout)
427
+
428
+ for name, exporter in all_exporters.items():
429
+ try:
430
+ await exporter.wait_for_tasks(timeout=timeout)
431
+ logger.info("Export tasks completed for exporter: %s", name)
432
+ except Exception as e:
433
+ logger.warning("Error waiting for export tasks from %s: %s", name, e)
434
+
435
+ logger.info("All local export task waiting completed")
436
+
437
+ except Exception as e:
438
+ logger.warning("Failed to wait for local export tasks: %s", e)
439
+
404
440
  async def run_and_evaluate(self,
405
441
  session_manager: SessionManager | None = None,
406
442
  job_id: str | None = None) -> EvaluationRunOutput:
@@ -442,11 +478,13 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
442
478
  dataset_config = self.eval_config.general.dataset # Currently only one dataset is supported
443
479
  if not dataset_config:
444
480
  logger.info("No dataset found, nothing to evaluate")
445
- return EvaluationRunOutput(
446
- workflow_output_file=self.workflow_output_file,
447
- evaluator_output_files=self.evaluator_output_files,
448
- workflow_interrupted=self.workflow_interrupted,
449
- )
481
+ return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
482
+ evaluator_output_files=self.evaluator_output_files,
483
+ workflow_interrupted=self.workflow_interrupted,
484
+ eval_input=EvalInput(eval_input_items=[]),
485
+ evaluation_results=[],
486
+ usage_stats=UsageStats(),
487
+ profiler_results=ProfilerResults())
450
488
 
451
489
  dataset_handler = DatasetHandler(dataset_config=dataset_config,
452
490
  reps=self.config.reps,
@@ -456,11 +494,13 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
456
494
  self.eval_input = dataset_handler.get_eval_input_from_dataset(self.config.dataset)
457
495
  if not self.eval_input.eval_input_items:
458
496
  logger.info("Dataset is empty. Nothing to evaluate.")
459
- return EvaluationRunOutput(
460
- workflow_output_file=self.workflow_output_file,
461
- evaluator_output_files=self.evaluator_output_files,
462
- workflow_interrupted=self.workflow_interrupted,
463
- )
497
+ return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
498
+ evaluator_output_files=self.evaluator_output_files,
499
+ workflow_interrupted=self.workflow_interrupted,
500
+ eval_input=self.eval_input,
501
+ evaluation_results=self.evaluation_results,
502
+ usage_stats=self.usage_stats,
503
+ profiler_results=ProfilerResults())
464
504
 
465
505
  # Run workflow and evaluate
466
506
  async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
@@ -468,18 +508,22 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
468
508
  self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
469
509
 
470
510
  # Run workflow
471
- if self.config.endpoint:
472
- await self.run_workflow_remote()
473
- else:
474
- if not self.config.skip_workflow:
511
+ with self.eval_trace_context.evaluation_context():
512
+ if self.config.endpoint:
513
+ await self.run_workflow_remote()
514
+ elif not self.config.skip_workflow:
475
515
  if session_manager is None:
476
516
  session_manager = SessionManager(eval_workflow.build(),
477
517
  max_concurrency=self.eval_config.general.max_concurrency)
478
518
  await self.run_workflow_local(session_manager)
479
519
 
480
- # Evaluate
481
- evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
482
- await self.run_evaluators(evaluators)
520
+ # Evaluate
521
+ evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
522
+ await self.run_evaluators(evaluators)
523
+
524
+ # Wait for all trace export tasks to complete (local workflows only)
525
+ if session_manager and not self.config.endpoint:
526
+ await self.wait_for_all_export_tasks_local(session_manager, timeout=self.config.export_timeout)
483
527
 
484
528
  # Profile the workflow
485
529
  profiler_results = await self.profile_workflow()
@@ -73,7 +73,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
73
73
  if isinstance(self.metric, str):
74
74
  return self.metric
75
75
  if isinstance(self.metric, dict) and self.metric:
76
- return next(iter(self.metric.keys())) # pylint: disable=no-member
76
+ return next(iter(self.metric.keys()))
77
77
  return ""
78
78
 
79
79
  @property
@@ -82,7 +82,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
82
82
  if isinstance(self.metric, str):
83
83
  return RagasMetricConfig() # Default config when only a metric name is given
84
84
  if isinstance(self.metric, dict) and self.metric:
85
- return next(iter(self.metric.values())) # pylint: disable=no-member
85
+ return next(iter(self.metric.values()))
86
86
  return RagasMetricConfig() # Default config when an invalid type is provided
87
87
 
88
88
 
nat/eval/register.py CHANGED
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
 
16
16
  # flake8: noqa
17
- # pylint: disable=unused-import
18
17
 
19
18
  # Import evaluators which need to be automatically registered here
20
19
  from .rag_evaluator.register import register_ragas_evaluator
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import asyncio
17
16
  import logging
18
17
  from typing import Callable
19
18
 
@@ -23,7 +22,6 @@ from langchain.schema import HumanMessage
23
22
  from langchain.schema import SystemMessage
24
23
  from langchain_core.language_models import BaseChatModel
25
24
  from langchain_core.runnables import RunnableLambda
26
- from tqdm import tqdm
27
25
 
28
26
  from nat.eval.evaluator.base_evaluator import BaseEvaluator
29
27
  from nat.eval.evaluator.evaluator_model import EvalInputItem
@@ -31,7 +29,6 @@ from nat.eval.evaluator.evaluator_model import EvalOutputItem
31
29
 
32
30
  logger = logging.getLogger(__name__)
33
31
 
34
- # pylint: disable=line-too-long
35
32
  # flake8: noqa: E501
36
33
 
37
34
 
@@ -0,0 +1,89 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from collections.abc import Callable
18
+ from contextlib import contextmanager
19
+ from typing import Any
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Type alias for evaluation call objects that have an optional 'id' attribute
24
+ EvalCallType = Any # Could be Weave Call object or other tracing framework objects
25
+
26
+
27
+ class EvalTraceContext:
28
+ """
29
+ Evaluation trace context manager for coordinating traces.
30
+
31
+ This class provides a framework-agnostic way to:
32
+ 1. Track evaluation calls/contexts
33
+ 2. Ensure proper parent-child relationships in traces
34
+ """
35
+
36
+ def __init__(self):
37
+ self.eval_call: EvalCallType | None = None # Store the evaluation call/context for propagation
38
+
39
+ def set_eval_call(self, eval_call: EvalCallType | None) -> None:
40
+ """Set the evaluation call/context for propagation to traces."""
41
+ self.eval_call = eval_call
42
+ if eval_call:
43
+ logger.debug("Set evaluation call context: %s", getattr(eval_call, 'id', str(eval_call)))
44
+
45
+ def get_eval_call(self) -> EvalCallType | None:
46
+ """Get the current evaluation call/context."""
47
+ return self.eval_call
48
+
49
+ @contextmanager
50
+ def evaluation_context(self):
51
+ """
52
+ Context manager that can be overridden by framework-specific implementations.
53
+ Default implementation is a no-op.
54
+ """
55
+ yield
56
+
57
+
58
+ class WeaveEvalTraceContext(EvalTraceContext):
59
+ """
60
+ Weave-specific implementation of evaluation trace context.
61
+ """
62
+
63
+ def __init__(self):
64
+ super().__init__()
65
+ self.available = False
66
+ self.set_call_stack: Callable[[list[EvalCallType]], Any] | None = None
67
+
68
+ try:
69
+ from weave.trace.context.call_context import set_call_stack
70
+ self.set_call_stack = set_call_stack
71
+ self.available = True
72
+ except ImportError:
73
+ self.available = False
74
+ logger.debug("Weave not available for trace context")
75
+
76
+ @contextmanager
77
+ def evaluation_context(self):
78
+ """Set the evaluation call as active context for Weave traces."""
79
+ if self.available and self.eval_call and self.set_call_stack:
80
+ try:
81
+ with self.set_call_stack([self.eval_call]):
82
+ logger.debug("Set Weave evaluation call context: %s",
83
+ getattr(self.eval_call, 'id', str(self.eval_call)))
84
+ yield
85
+ except Exception as e:
86
+ logger.warning("Failed to set Weave evaluation call context: %s", e)
87
+ yield
88
+ else:
89
+ yield
@@ -15,6 +15,7 @@
15
15
 
16
16
  import asyncio
17
17
  import logging
18
+ from typing import TYPE_CHECKING
18
19
  from typing import Any
19
20
 
20
21
  from nat.eval.evaluator.evaluator_model import EvalInput
@@ -24,26 +25,30 @@ from nat.eval.usage_stats import UsageStats
24
25
  from nat.eval.usage_stats import UsageStatsItem
25
26
  from nat.profiler.data_models import ProfilerResults
26
27
 
28
+ if TYPE_CHECKING:
29
+ from nat.eval.utils.eval_trace_ctx import EvalTraceContext
30
+
27
31
  logger = logging.getLogger(__name__)
28
32
 
29
33
 
30
- class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
34
+ class WeaveEvaluationIntegration:
31
35
  """
32
36
  Class to handle all Weave integration functionality.
33
37
  """
34
38
 
35
- def __init__(self):
39
+ def __init__(self, eval_trace_context: "EvalTraceContext"):
36
40
  self.available = False
37
41
  self.client = None
38
42
  self.eval_logger = None
39
43
  self.pred_loggers = {}
44
+ self.eval_trace_context = eval_trace_context
40
45
 
41
46
  try:
42
47
  from weave.flow.eval_imperative import EvaluationLogger
43
48
  from weave.flow.eval_imperative import ScoreLogger
44
49
  from weave.trace.context import weave_client_context
45
- self.EvaluationLogger = EvaluationLogger
46
- self.ScoreLogger = ScoreLogger
50
+ self.evaluation_logger_cls = EvaluationLogger
51
+ self.score_logger_cls = ScoreLogger
47
52
  self.weave_client_context = weave_client_context
48
53
  self.available = True
49
54
  except ImportError:
@@ -89,9 +94,12 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
89
94
  weave_dataset = self._get_weave_dataset(eval_input)
90
95
  config_dict = config.model_dump(mode="json")
91
96
  config_dict["name"] = workflow_alias
92
- self.eval_logger = self.EvaluationLogger(model=config_dict, dataset=weave_dataset)
97
+ self.eval_logger = self.evaluation_logger_cls(model=config_dict, dataset=weave_dataset)
93
98
  self.pred_loggers = {}
94
99
 
100
+ # Capture the current evaluation call for context propagation
101
+ self.eval_trace_context.set_eval_call(self.eval_logger._evaluate_call)
102
+
95
103
  return True
96
104
  except Exception as e:
97
105
  self.eval_logger = None
@@ -137,7 +145,7 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
137
145
  await asyncio.gather(*coros)
138
146
 
139
147
  async def afinish_loggers(self):
140
- """Finish all prediction loggers."""
148
+ """Finish all prediction loggers and wait for exports."""
141
149
  if not self.eval_logger:
142
150
  return
143
151
 
@@ -157,7 +165,6 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
157
165
  if profiler_results.workflow_runtime_metrics:
158
166
  profile_metrics["wf_runtime_p95"] = profiler_results.workflow_runtime_metrics.p95
159
167
 
160
- # TODO:get the LLM tokens from the usage stats and log them
161
168
  profile_metrics["total_runtime"] = usage_stats.total_runtime
162
169
 
163
170
  return profile_metrics
@@ -17,9 +17,10 @@ from abc import ABC
17
17
  from abc import abstractmethod
18
18
 
19
19
  from nat.builder.builder import Builder
20
- from nat.experimental.test_time_compute.models.ttc_item import TTCItem
21
- from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum, PipelineTypeEnum
22
20
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
21
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
22
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
23
+ from nat.experimental.test_time_compute.models.ttc_item import TTCItem
23
24
 
24
25
 
25
26
  class StrategyBase(ABC):
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
 
19
18
  from .editing import iterative_plan_refinement_editor
@@ -135,8 +135,6 @@ class LLMBasedOutputMergingSelector(StrategyBase):
135
135
  except Exception as e:
136
136
  logger.error(f"Error parsing merged output: {e}")
137
137
  raise ValueError("Failed to parse merged output.")
138
- else:
139
- merged_output = merged_output
140
138
 
141
139
  logger.info("Merged output: %s", str(merged_output))
142
140
 
@@ -307,7 +307,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
307
307
  async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
308
308
  """Handle evaluation requests."""
309
309
 
310
- async with session_manager.session(request=http_request):
310
+ async with session_manager.session(http_connection=http_request):
311
311
 
312
312
  # if job_id is present and already exists return the job info
313
313
  if request.job_id:
@@ -336,7 +336,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
336
336
  """Get the status of an evaluation job."""
337
337
  logger.info("Getting status for job %s", job_id)
338
338
 
339
- async with session_manager.session(request=http_request):
339
+ async with session_manager.session(http_connection=http_request):
340
340
 
341
341
  job = job_store.get_job(job_id)
342
342
  if not job:
@@ -349,7 +349,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
349
349
  """Get the status of the last created evaluation job."""
350
350
  logger.info("Getting last job status")
351
351
 
352
- async with session_manager.session(request=http_request):
352
+ async with session_manager.session(http_connection=http_request):
353
353
 
354
354
  job = job_store.get_last_job()
355
355
  if not job:
@@ -361,7 +361,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
361
361
  async def get_jobs(http_request: Request, status: str | None = None) -> list[EvaluateStatusResponse]:
362
362
  """Get all jobs, optionally filtered by status."""
363
363
 
364
- async with session_manager.session(request=http_request):
364
+ async with session_manager.session(http_connection=http_request):
365
365
 
366
366
  if status is None:
367
367
  logger.info("Getting all jobs")
@@ -522,9 +522,9 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
522
522
 
523
523
  workflow = session_manager.workflow
524
524
 
525
- GenerateBodyType = workflow.input_schema # pylint: disable=invalid-name
526
- GenerateStreamResponseType = workflow.streaming_output_schema # pylint: disable=invalid-name
527
- GenerateSingleResponseType = workflow.single_output_schema # pylint: disable=invalid-name
525
+ GenerateBodyType = workflow.input_schema
526
+ GenerateStreamResponseType = workflow.streaming_output_schema
527
+ GenerateSingleResponseType = workflow.single_output_schema
528
528
 
529
529
  # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
530
530
  # Consider prefixing these with "nat_" to avoid conflicts
@@ -572,7 +572,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
572
572
 
573
573
  response.headers["Content-Type"] = "application/json"
574
574
 
575
- async with session_manager.session(request=request,
575
+ async with session_manager.session(http_connection=request,
576
576
  user_authentication_callback=self._http_flow_handler.authenticate):
577
577
 
578
578
  return await generate_single_response(None, session_manager, result_type=result_type)
@@ -583,7 +583,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
583
583
 
584
584
  async def get_stream(request: Request):
585
585
 
586
- async with session_manager.session(request=request,
586
+ async with session_manager.session(http_connection=request,
587
587
  user_authentication_callback=self._http_flow_handler.authenticate):
588
588
 
589
589
  return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -618,7 +618,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
618
618
 
619
619
  response.headers["Content-Type"] = "application/json"
620
620
 
621
- async with session_manager.session(request=request,
621
+ async with session_manager.session(http_connection=request,
622
622
  user_authentication_callback=self._http_flow_handler.authenticate):
623
623
 
624
624
  return await generate_single_response(payload, session_manager, result_type=result_type)
@@ -632,7 +632,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
632
632
 
633
633
  async def post_stream(request: Request, payload: request_type):
634
634
 
635
- async with session_manager.session(request=request,
635
+ async with session_manager.session(http_connection=request,
636
636
  user_authentication_callback=self._http_flow_handler.authenticate):
637
637
 
638
638
  return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -677,7 +677,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
677
677
  # Check if streaming is requested
678
678
  stream_requested = getattr(payload, 'stream', False)
679
679
 
680
- async with session_manager.session(request=request):
680
+ async with session_manager.session(http_connection=request):
681
681
  if stream_requested:
682
682
  # Return streaming response
683
683
  return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -688,42 +688,41 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
688
688
  step_adaptor=self.get_step_adaptor(),
689
689
  result_type=ChatResponseChunk,
690
690
  output_type=ChatResponseChunk))
691
- else:
692
- # Return single response - check if workflow supports non-streaming
693
- try:
691
+
692
+ # Return single response - check if workflow supports non-streaming
693
+ try:
694
+ response.headers["Content-Type"] = "application/json"
695
+ return await generate_single_response(payload, session_manager, result_type=ChatResponse)
696
+ except ValueError as e:
697
+ if "Cannot get a single output value for streaming workflows" in str(e):
698
+ # Workflow only supports streaming, but client requested non-streaming
699
+ # Fall back to streaming and collect the result
700
+ chunks = []
701
+ async for chunk_str in generate_streaming_response_as_str(
702
+ payload,
703
+ session_manager=session_manager,
704
+ streaming=True,
705
+ step_adaptor=self.get_step_adaptor(),
706
+ result_type=ChatResponseChunk,
707
+ output_type=ChatResponseChunk):
708
+ if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
709
+ chunk_data = chunk_str[6:].strip() # Remove "data: " prefix
710
+ if chunk_data:
711
+ try:
712
+ chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
713
+ if (chunk_json.choices and len(chunk_json.choices) > 0
714
+ and chunk_json.choices[0].delta
715
+ and chunk_json.choices[0].delta.content is not None):
716
+ chunks.append(chunk_json.choices[0].delta.content)
717
+ except Exception:
718
+ continue
719
+
720
+ # Create a single response from collected chunks
721
+ content = "".join(chunks)
722
+ single_response = ChatResponse.from_string(content)
694
723
  response.headers["Content-Type"] = "application/json"
695
- return await generate_single_response(payload, session_manager, result_type=ChatResponse)
696
- except ValueError as e:
697
- if "Cannot get a single output value for streaming workflows" in str(e):
698
- # Workflow only supports streaming, but client requested non-streaming
699
- # Fall back to streaming and collect the result
700
- chunks = []
701
- async for chunk_str in generate_streaming_response_as_str(
702
- payload,
703
- session_manager=session_manager,
704
- streaming=True,
705
- step_adaptor=self.get_step_adaptor(),
706
- result_type=ChatResponseChunk,
707
- output_type=ChatResponseChunk):
708
- if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
709
- chunk_data = chunk_str[6:].strip() # Remove "data: " prefix
710
- if chunk_data:
711
- try:
712
- chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
713
- if (chunk_json.choices and len(chunk_json.choices) > 0
714
- and chunk_json.choices[0].delta
715
- and chunk_json.choices[0].delta.content is not None):
716
- chunks.append(chunk_json.choices[0].delta.content)
717
- except Exception:
718
- continue
719
-
720
- # Create a single response from collected chunks
721
- content = "".join(chunks)
722
- single_response = ChatResponse.from_string(content)
723
- response.headers["Content-Type"] = "application/json"
724
- return single_response
725
- else:
726
- raise
724
+ return single_response
725
+ raise
727
726
 
728
727
  return post_openai_api_compatible
729
728
 
@@ -758,7 +757,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
758
757
  http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
759
758
  """Handle async generation requests."""
760
759
 
761
- async with session_manager.session(request=http_request):
760
+ async with session_manager.session(http_connection=http_request):
762
761
 
763
762
  # if job_id is present and already exists return the job info
764
763
  if request.job_id:
@@ -804,7 +803,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
804
803
  """Get the status of an async job."""
805
804
  logger.info("Getting status for job %s", job_id)
806
805
 
807
- async with session_manager.session(request=http_request):
806
+ async with session_manager.session(http_connection=http_request):
808
807
 
809
808
  job = job_store.get_job(job_id)
810
809
  if not job:
@@ -86,7 +86,7 @@ class WebSocketMessageHandler:
86
86
 
87
87
  async def __aexit__(self, exc_type, exc_value, traceback) -> None:
88
88
 
89
- # TODO: Handle the exit # pylint: disable=fixme
89
+ # TODO: Handle the exit
90
90
  pass
91
91
 
92
92
  async def run(self) -> None:
@@ -105,12 +105,10 @@ class WebSocketMessageHandler:
105
105
  if (isinstance(validated_message, WebSocketUserMessage)):
106
106
  await self.process_workflow_request(validated_message)
107
107
 
108
- elif isinstance(
109
- validated_message,
110
- ( # noqa: E131
111
- WebSocketSystemResponseTokenMessage,
112
- WebSocketSystemIntermediateStepMessage,
113
- WebSocketSystemInteractionMessage)):
108
+ elif isinstance(validated_message,
109
+ (WebSocketSystemResponseTokenMessage,
110
+ WebSocketSystemIntermediateStepMessage,
111
+ WebSocketSystemInteractionMessage)):
114
112
  # These messages are already handled by self.create_websocket_message(data_model=value, …)
115
113
  # No further processing is needed here.
116
114
  pass
@@ -119,11 +117,9 @@ class WebSocketMessageHandler:
119
117
  user_content = await self.process_user_message_content(validated_message)
120
118
  self._user_interaction_response.set_result(user_content)
121
119
  except (asyncio.CancelledError, WebSocketDisconnect):
122
- # TODO: Handle the disconnect # pylint: disable=fixme
120
+ # TODO: Handle the disconnect
123
121
  break
124
122
 
125
- return None
126
-
127
123
  async def process_user_message_content(
128
124
  self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
129
125
  """
@@ -162,12 +158,13 @@ class WebSocketMessageHandler:
162
158
 
163
159
  if isinstance(content, TextContent) and (self._running_workflow_task is None):
164
160
 
165
- def _done_callback(task: asyncio.Task): # pylint: disable=unused-argument
161
+ def _done_callback(task: asyncio.Task):
166
162
  self._running_workflow_task = None
167
163
 
168
164
  self._running_workflow_task = asyncio.create_task(
169
- self._run_workflow(content.text,
170
- self._conversation_id,
165
+ self._run_workflow(payload=content.text,
166
+ user_message_id=self._message_parent_id,
167
+ conversation_id=self._conversation_id,
171
168
  result_type=self._schema_output_mapping[self._workflow_schema_type],
172
169
  output_type=self._schema_output_mapping[
173
170
  self._workflow_schema_type])).add_done_callback(_done_callback)
@@ -290,14 +287,16 @@ class WebSocketMessageHandler:
290
287
 
291
288
  async def _run_workflow(self,
292
289
  payload: typing.Any,
290
+ user_message_id: str | None = None,
293
291
  conversation_id: str | None = None,
294
292
  result_type: type | None = None,
295
293
  output_type: type | None = None) -> None:
296
294
 
297
295
  try:
298
296
  async with self._session_manager.session(
297
+ user_message_id=user_message_id,
299
298
  conversation_id=conversation_id,
300
- request=self._socket,
299
+ http_connection=self._socket,
301
300
  user_input_callback=self.human_interaction_callback,
302
301
  user_authentication_callback=(self._flow_handler.authenticate
303
302
  if self._flow_handler else None)) as session: