nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +50 -22
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +54 -27
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +68 -17
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  53. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  54. nat/cli/commands/workflow/workflow_commands.py +62 -22
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/main.py +3 -0
  57. nat/cli/register_workflow.py +38 -4
  58. nat/cli/type_registry.py +75 -6
  59. nat/control_flow/__init__.py +0 -0
  60. nat/control_flow/register.py +20 -0
  61. nat/control_flow/router_agent/__init__.py +0 -0
  62. nat/control_flow/router_agent/agent.py +329 -0
  63. nat/control_flow/router_agent/prompt.py +48 -0
  64. nat/control_flow/router_agent/register.py +91 -0
  65. nat/control_flow/sequential_executor.py +166 -0
  66. nat/data_models/agent.py +34 -0
  67. nat/data_models/api_server.py +74 -66
  68. nat/data_models/authentication.py +23 -9
  69. nat/data_models/common.py +1 -1
  70. nat/data_models/component.py +2 -0
  71. nat/data_models/component_ref.py +11 -0
  72. nat/data_models/config.py +41 -17
  73. nat/data_models/dataset_handler.py +1 -1
  74. nat/data_models/discovery_metadata.py +4 -4
  75. nat/data_models/evaluate.py +4 -1
  76. nat/data_models/function.py +34 -0
  77. nat/data_models/function_dependencies.py +14 -6
  78. nat/data_models/gated_field_mixin.py +242 -0
  79. nat/data_models/intermediate_step.py +3 -3
  80. nat/data_models/optimizable.py +119 -0
  81. nat/data_models/optimizer.py +149 -0
  82. nat/data_models/span.py +41 -3
  83. nat/data_models/swe_bench_model.py +1 -1
  84. nat/data_models/temperature_mixin.py +44 -0
  85. nat/data_models/thinking_mixin.py +86 -0
  86. nat/data_models/top_p_mixin.py +44 -0
  87. nat/embedder/nim_embedder.py +1 -1
  88. nat/embedder/openai_embedder.py +1 -1
  89. nat/embedder/register.py +0 -1
  90. nat/eval/config.py +3 -1
  91. nat/eval/dataset_handler/dataset_handler.py +71 -7
  92. nat/eval/evaluate.py +86 -31
  93. nat/eval/evaluator/base_evaluator.py +1 -1
  94. nat/eval/evaluator/evaluator_model.py +13 -0
  95. nat/eval/intermediate_step_adapter.py +1 -1
  96. nat/eval/rag_evaluator/evaluate.py +2 -2
  97. nat/eval/rag_evaluator/register.py +3 -3
  98. nat/eval/register.py +4 -1
  99. nat/eval/remote_workflow.py +3 -3
  100. nat/eval/runtime_evaluator/__init__.py +14 -0
  101. nat/eval/runtime_evaluator/evaluate.py +123 -0
  102. nat/eval/runtime_evaluator/register.py +100 -0
  103. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  104. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  105. nat/eval/trajectory_evaluator/register.py +1 -1
  106. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  107. nat/eval/utils/eval_trace_ctx.py +89 -0
  108. nat/eval/utils/weave_eval.py +18 -9
  109. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  110. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  112. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  113. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  114. nat/experimental/test_time_compute/register.py +0 -1
  115. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  116. nat/front_ends/console/authentication_flow_handler.py +82 -30
  117. nat/front_ends/console/console_front_end_plugin.py +8 -5
  118. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  119. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  120. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  121. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  123. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
  124. nat/front_ends/fastapi/job_store.py +518 -99
  125. nat/front_ends/fastapi/main.py +11 -19
  126. nat/front_ends/fastapi/message_handler.py +13 -14
  127. nat/front_ends/fastapi/message_validator.py +19 -19
  128. nat/front_ends/fastapi/response_helpers.py +4 -4
  129. nat/front_ends/fastapi/step_adaptor.py +2 -2
  130. nat/front_ends/fastapi/utils.py +57 -0
  131. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  132. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  133. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  134. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  135. nat/front_ends/mcp/tool_converter.py +44 -14
  136. nat/front_ends/register.py +0 -1
  137. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  138. nat/llm/aws_bedrock_llm.py +24 -12
  139. nat/llm/azure_openai_llm.py +13 -6
  140. nat/llm/litellm_llm.py +69 -0
  141. nat/llm/nim_llm.py +20 -8
  142. nat/llm/openai_llm.py +14 -6
  143. nat/llm/register.py +4 -1
  144. nat/llm/utils/env_config_value.py +2 -3
  145. nat/llm/utils/thinking.py +215 -0
  146. nat/meta/pypi.md +9 -9
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +35 -15
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/forecasting/models/linear_model.py +1 -1
  178. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  179. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  180. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  181. nat/profiler/inference_optimization/data_models.py +3 -3
  182. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  183. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  184. nat/profiler/parameter_optimization/__init__.py +0 -0
  185. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  186. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  187. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  188. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  189. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  190. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  191. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  192. nat/profiler/profile_runner.py +14 -9
  193. nat/profiler/utils.py +4 -2
  194. nat/registry_handlers/local/local_handler.py +2 -2
  195. nat/registry_handlers/package_utils.py +1 -2
  196. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  197. nat/registry_handlers/register.py +3 -4
  198. nat/registry_handlers/rest/rest_handler.py +12 -13
  199. nat/retriever/milvus/retriever.py +2 -2
  200. nat/retriever/nemo_retriever/retriever.py +1 -1
  201. nat/retriever/register.py +0 -1
  202. nat/runtime/loader.py +2 -2
  203. nat/runtime/runner.py +106 -8
  204. nat/runtime/session.py +69 -8
  205. nat/settings/global_settings.py +16 -5
  206. nat/tool/chat_completion.py +5 -2
  207. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  208. nat/tool/datetime_tools.py +49 -9
  209. nat/tool/document_search.py +2 -2
  210. nat/tool/github_tools.py +450 -0
  211. nat/tool/memory_tools/get_memory_tool.py +1 -1
  212. nat/tool/nvidia_rag.py +1 -1
  213. nat/tool/register.py +2 -9
  214. nat/tool/retriever.py +3 -2
  215. nat/utils/callable_utils.py +70 -0
  216. nat/utils/data_models/schema_validator.py +3 -3
  217. nat/utils/decorators.py +210 -0
  218. nat/utils/exception_handlers/automatic_retries.py +104 -51
  219. nat/utils/exception_handlers/schemas.py +1 -1
  220. nat/utils/io/yaml_tools.py +2 -2
  221. nat/utils/log_levels.py +25 -0
  222. nat/utils/reactive/base/observable_base.py +2 -2
  223. nat/utils/reactive/base/observer_base.py +1 -1
  224. nat/utils/reactive/observable.py +2 -2
  225. nat/utils/reactive/observer.py +4 -4
  226. nat/utils/reactive/subscription.py +1 -1
  227. nat/utils/settings/global_settings.py +6 -8
  228. nat/utils/type_converter.py +4 -3
  229. nat/utils/type_utils.py +9 -5
  230. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
  231. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
  232. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
  233. nat/cli/commands/info/list_mcp.py +0 -304
  234. nat/tool/github_tools/create_github_commit.py +0 -133
  235. nat/tool/github_tools/create_github_issue.py +0 -87
  236. nat/tool/github_tools/create_github_pr.py +0 -106
  237. nat/tool/github_tools/get_github_file.py +0 -106
  238. nat/tool/github_tools/get_github_issue.py +0 -166
  239. nat/tool/github_tools/get_github_pr.py +0 -256
  240. nat/tool/github_tools/update_github_issue.py +0 -100
  241. nat/tool/mcp/exceptions.py +0 -142
  242. nat/tool/mcp/mcp_client.py +0 -255
  243. nat/tool/mcp/mcp_tool.py +0 -96
  244. nat/utils/exception_handlers/mcp.py +0 -211
  245. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  246. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  247. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
  248. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  249. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  250. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -45,7 +45,7 @@ class FileTelemetryExporterConfig(TelemetryExporterBaseConfig, name="file"):
 
 
 @register_telemetry_exporter(config_type=FileTelemetryExporterConfig)
-async def file_telemetry_exporter(config: FileTelemetryExporterConfig, builder: Builder):  # pylint: disable=W0613
+async def file_telemetry_exporter(config: FileTelemetryExporterConfig, builder: Builder):
     """
     Build and return a FileExporter for file-based telemetry export with optional rolling.
     """
@@ -68,12 +68,14 @@ class ConsoleLoggingMethodConfig(LoggingBaseConfig, name="console"):
 
 
 @register_logging_method(config_type=ConsoleLoggingMethodConfig)
-async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):  # pylint: disable=W0613
+async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):
     """
     Build and return a StreamHandler for console-based logging.
     """
+    import sys
+
     level = getattr(logging, config.level.upper(), logging.INFO)
-    handler = logging.StreamHandler()
+    handler = logging.StreamHandler(stream=sys.stdout)
     handler.setLevel(level)
     yield handler
 
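A note on the console handler change above: logging.StreamHandler() writes to sys.stderr by default, so passing stream=sys.stdout is what actually moves console log output to standard output. A minimal standalone sketch of the same behavior (the logger name is illustrative, not part of NAT):

import logging
import sys

# StreamHandler defaults to sys.stderr; pass sys.stdout explicitly
# to route records to standard output instead.
handler = logging.StreamHandler(stream=sys.stdout)
handler.setLevel(logging.INFO)

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("this record goes to stdout, not stderr")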
@@ -86,7 +88,7 @@ class FileLoggingMethod(LoggingBaseConfig, name="file"):
 
 
 @register_logging_method(config_type=FileLoggingMethod)
-async def file_logging_method(config: FileLoggingMethod, builder: Builder):  # pylint: disable=W0613
+async def file_logging_method(config: FileLoggingMethod, builder: Builder):
     """
     Build and return a FileHandler for file-based logging.
     """
@@ -442,7 +442,7 @@ class CalcRunner:
                 runtime_fit=self.linear_analyzer.wf_runtime_fit  # May be None
             )
         except Exception as e:
-            logger.exception("Failed to plot concurrency vs. time metrics: %s", e, exc_info=True)
+            logger.exception("Failed to plot concurrency vs. time metrics: %s", e)
             logger.warning("Skipping plot of concurrency vs. time metrics")
 
     def write_output(self, output_dir: Path, calc_runner_output: CalcRunnerOutput):
@@ -506,11 +506,10 @@ class CalcRunner:
                 continue
             try:
                 calc_output = CalcRunnerOutput.model_validate_json(calc_runner_output_path.read_text())
-            except ValidationError as e:
+            except ValidationError:
                 logger.exception("Failed to validate calc runner output file %s. Skipping job %s.",
                                  calc_runner_output_path,
-                                 e,
-                                 exc_info=True)
+                                 job_dir.name)
                 continue
 
             # Extract sizing metrics from calc_data
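The exc_info=True removals in this hunk and in the similar logging hunks below are redundancy fixes rather than behavior changes: logging.Logger.exception already logs at ERROR level with the active traceback attached. A quick illustration of the equivalence:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

try:
    1 / 0
except ZeroDivisionError as err:
    # logger.exception(...) is shorthand for
    # logger.error(..., exc_info=True); the traceback is appended
    # automatically, so passing exc_info=True again is redundant.
    logger.exception("Failed to divide: %s", err)

This also explains the switches from logger.exception to logger.error in the handlers below that re-raise: logging the traceback at the point of re-raise would print it twice once the caller logs it.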
@@ -144,7 +144,7 @@ class AgnoProfilerHandler(BaseProfilerCallback):
                 return result
 
             except Exception as e:
-                logger.exception("Tool execution error: %s", e)
+                logger.error("Tool execution error: %s", e)
                 raise
 
         return wrapped_tool_execute
@@ -53,7 +53,7 @@ def _extract_tools_schema(invocation_params: dict) -> list:
     return tools_schema
 
 
-class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):  # pylint: disable=R0901
+class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
     """Callback Handler that tracks NIM info."""
 
     total_tokens: int = 0
@@ -106,7 +106,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):  # pylint: disable=R0901
         try:
            model_name = kwargs.get("metadata")["ls_model_name"]
         except Exception as e:
-            logger.exception("Error getting model name: %s", e, exc_info=True)
+            logger.exception("Error getting model name: %s", e)
 
         run_id = str(kwargs.get("run_id", str(uuid4())))
         self._run_id_to_model_name[run_id] = model_name
@@ -144,7 +144,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):  # pylint: disable=R0901
         try:
             model_name = metadata["ls_model_name"] if metadata else kwargs.get("metadata")["ls_model_name"]
         except Exception as e:
-            logger.exception("Error getting model name: %s", e, exc_info=True)
+            logger.exception("Error getting model name: %s", e)
 
         run_id = str(run_id)
         self._run_id_to_model_name[run_id] = model_name
@@ -173,13 +173,13 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):  # pylint: disable=R0901
         try:
             model_name = self._run_id_to_model_name.get(str(kwargs.get("run_id", "")), "")
         except Exception as e:
-            logger.exception("Error getting model name: %s", e, exc_info=True)
+            logger.exception("Error getting model name: %s", e)
 
         usage_metadata = {}
         try:
             usage_metadata = kwargs.get("chunk").message.usage_metadata if kwargs.get("chunk") else {}
         except Exception as e:
-            logger.exception("Error getting usage metadata: %s", e, exc_info=True)
+            logger.exception("Error getting usage metadata: %s", e)
 
         stats = IntermediateStepPayload(
             event_type=IntermediateStepType.LLM_NEW_TOKEN,
@@ -206,7 +206,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):  # pylint: disable=R0901
             try:
                 model_name = self._run_id_to_model_name.get(str(kwargs.get("run_id", "")), "")
             except Exception as e_inner:
-                logger.exception("Error getting model name: %s from outer error %s", e_inner, e, exc_info=True)
+                logger.exception("Error getting model name: %s from outer error %s", e_inner, e)
 
             try:
                 generation = response.generations[0][0]
@@ -94,7 +94,7 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
         try:
             model_name = payload.get(EventPayload.SERIALIZED)['model']
         except Exception as e:
-            logger.exception("Error getting model name: %s", e, exc_info=True)
+            logger.exception("Error getting model name: %s", e)
 
         llm_text_input = " ".join(prompts_or_messages) if prompts_or_messages else ""
 
@@ -159,13 +159,13 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
                 for block in response.message.blocks:
                     llm_text_output += block.text
         except Exception as e:
-            logger.exception("Error getting LLM text output: %s", e, exc_info=True)
+            logger.exception("Error getting LLM text output: %s", e)
 
         model_name = ""
         try:
             model_name = response.raw.model
         except Exception as e:
-            logger.exception("Error getting model name: %s", e, exc_info=True)
+            logger.exception("Error getting model name: %s", e)
 
         # Append usage data to NAT usage stats
         with self._lock:
@@ -86,7 +86,7 @@ class SemanticKernelProfilerHandler(BaseProfilerCallback):
 
         # Gather the appropriate modules/functions based on your builder config
         for llm in self._builder_llms:
-            if self._builder_llms[llm].provider_type == 'openai':  # pylint: disable=consider-using-in
+            if self._builder_llms[llm].provider_type == 'openai':
                 functions_to_patch.extend(["openai_non_streaming", "openai_streaming"])
 
         # Grab original reference for the tool call
@@ -132,7 +132,7 @@ class SemanticKernelProfilerHandler(BaseProfilerCallback):
                     if "text" in item:
                         model_input += item["text"]
         except Exception as e:
-            logger.exception("Error in getting model input: %s", e, exc_info=True)
+            logger.exception("Error in getting model input: %s", e)
 
         input_stats = IntermediateStepPayload(event_type=IntermediateStepType.LLM_START,
                                               framework=LLMFrameworkEnum.SEMANTIC_KERNEL,
@@ -232,7 +232,7 @@ class SemanticKernelProfilerHandler(BaseProfilerCallback):
                 return result
 
             except Exception as e:
-                logger.exception("ToolUsage._use error: %s", e)
+                logger.error("ToolUsage._use error: %s", e)
                 raise
 
         return patched_tool_call
@@ -42,7 +42,7 @@ class DataFrameRow(BaseModel):
     framework: str | None
 
     @field_validator('llm_text_input', 'llm_text_output', 'llm_new_token', mode='before')
-    def cast_to_str(cls, v):  # pylint: disable=no-self-argument
+    def cast_to_str(cls, v):
         if v is None:
             return v
         try:
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint disable=ungrouped-imports
-
 from __future__ import annotations
 
 import functools
 import logging
+from collections.abc import AsyncIterator
 from collections.abc import Callable
 from contextlib import AbstractAsyncContextManager as AsyncContextManager
 from contextlib import asynccontextmanager
@@ -34,35 +33,55 @@ _library_instrumented = {
     "crewai": False,
     "semantic_kernel": False,
     "agno": False,
+    "adk": False,
 }
 
 callback_handler_var: ContextVar[Any | None] = ContextVar("callback_handler_var", default=None)
 
 
 def set_framework_profiler_handler(
-    workflow_llms: dict = None,
-    frameworks: list[LLMFrameworkEnum] = None,
+    workflow_llms: dict | None = None,
+    frameworks: list[LLMFrameworkEnum] | None = None,
 ) -> Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
     """
     Decorator that wraps an async context manager function to set up framework-specific profiling.
+
+    Args:
+        workflow_llms (dict | None): A dictionary of workflow LLM configurations.
+        frameworks (list[LLMFrameworkEnum] | None): A list of LLM frameworks used in the workflow functions.
+
+    Returns:
+        Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
+            A decorator that wraps the original function with profiling setup.
     """
 
     def decorator(func: Callable[..., AsyncContextManager[Any]]) -> Callable[..., AsyncContextManager[Any]]:
+        """The actual decorator that wraps the function.
+
+        Args:
+            func (Callable[..., AsyncContextManager[Any]]): The function to wrap.
+
+        Returns:
+            Callable[..., AsyncContextManager[Any]]: The wrapped function.
+        """
 
         @functools.wraps(func)
         @asynccontextmanager
        async def wrapper(workflow_config, builder):
 
-            if LLMFrameworkEnum.LANGCHAIN in frameworks and not _library_instrumented["langchain"]:
-                from langchain_core.tracers.context import register_configure_hook
-
+            if LLMFrameworkEnum.LANGCHAIN in frameworks:
+                # Always set a fresh handler in the current context so callbacks
+                # route to the active run. Only register the hook once globally.
                 from nat.profiler.callbacks.langchain_callback_handler import LangchainProfilerHandler
 
                 handler = LangchainProfilerHandler()
                 callback_handler_var.set(handler)
-                register_configure_hook(callback_handler_var, inheritable=True)
-                _library_instrumented["langchain"] = True
-                logger.debug("Langchain callback handler registered")
+
+                if not _library_instrumented["langchain"]:
+                    from langchain_core.tracers.context import register_configure_hook
+                    register_configure_hook(callback_handler_var, inheritable=True)
+                    _library_instrumented["langchain"] = True
+                    logger.debug("LangChain/LangGraph callback hook registered")
 
             if LLMFrameworkEnum.LLAMA_INDEX in frameworks:
                 from llama_index.core import Settings
@@ -75,8 +94,7 @@ def set_framework_profiler_handler(
                 logger.debug("LlamaIndex callback handler registered")
 
             if LLMFrameworkEnum.CREWAI in frameworks and not _library_instrumented["crewai"]:
-                from nat.plugins.crewai.crewai_callback_handler import \
-                    CrewAIProfilerHandler  # pylint: disable=ungrouped-imports,line-too-long  # noqa E501
+                from nat.plugins.crewai.crewai_callback_handler import CrewAIProfilerHandler
 
                 handler = CrewAIProfilerHandler()
                 handler.instrument()
@@ -99,6 +117,20 @@ def set_framework_profiler_handler(
                 _library_instrumented["agno"] = True
                 logger.info("Agno callback handler registered")
 
+            if LLMFrameworkEnum.ADK in frameworks and not _library_instrumented["adk"]:
+                try:
+                    from nat.plugins.adk.adk_callback_handler import ADKProfilerHandler
+                except ImportError as e:
+                    logger.warning(
+                        "ADK profiler not available. " +
+                        "Install NAT with ADK extras: pip install \"nvidia-nat[adk]\". Error: %s",
+                        e)
+                else:
+                    handler = ADKProfilerHandler()
+                    handler.instrument()
+                    _library_instrumented["adk"] = True
+                    logger.debug("ADK callback handler registered")
+
             # IMPORTANT: actually call the wrapped function as an async context manager
             async with func(workflow_config, builder) as result:
                 yield result
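The restructured LangChain branch above separates per-run state from one-time global state: the ContextVar receives a fresh handler on every entry, while register_configure_hook (process-wide) runs only once. The reason a ContextVar gives per-run isolation is that each asyncio task executes in its own copy of the current context. A minimal sketch of that property, independent of LangChain or NAT:

import asyncio
from contextvars import ContextVar

handler_var: ContextVar[str | None] = ContextVar("handler_var", default=None)

async def run(name: str) -> None:
    # Each task runs in a copy of the current context, so this set()
    # is invisible to concurrently running tasks.
    handler_var.set(f"handler-for-{name}")
    await asyncio.sleep(0)
    assert handler_var.get() == f"handler-for-{name}"

async def main() -> None:
    await asyncio.gather(run("a"), run("b"))

asyncio.run(main())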
@@ -117,11 +149,28 @@ def chain_wrapped_build_fn(
     Convert an original build function into an async context manager that
     wraps it with a single call to set_framework_profiler_handler, passing
     all frameworks at once.
+
+    Args:
+        original_build_fn (Callable[..., AsyncContextManager]): The original build function to wrap.
+        workflow_llms (dict): A dictionary of workflow LLM configurations.
+        function_frameworks (list[LLMFrameworkEnum]): A list of LLM frameworks used in the workflow functions.
+
+    Returns:
+        Callable[..., AsyncContextManager]: The wrapped build function.
     """
 
     # Define a base async context manager that simply calls the original build function.
     @asynccontextmanager
-    async def base_fn(*args, **kwargs):
+    async def base_fn(*args, **kwargs) -> AsyncIterator[Any]:
+        """Base async context manager that calls the original build function.
+
+        Args:
+            *args: Positional arguments to pass to the original build function.
+            **kwargs: Keyword arguments to pass to the original build function.
+
+        Yields:
+            The result of the original build function.
+        """
         async with original_build_fn(*args, **kwargs) as w:
             yield w
 
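For readers unfamiliar with the pattern in base_fn above: @asynccontextmanager turns an async generator into an async context manager, and re-yielding from inside another async with is how a build function gets wrapped without changing its call interface. A self-contained sketch under that assumption (build and the printed value are illustrative, not NAT code):

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

@asynccontextmanager
async def build() -> AsyncIterator[str]:
    # Stands in for the original build function.
    yield "workflow"

@asynccontextmanager
async def wrapped(*args, **kwargs) -> AsyncIterator[str]:
    # Setup (e.g. profiler registration) would happen here.
    async with build(*args, **kwargs) as w:
        yield w
    # Teardown runs here after the caller's block exits.

async def main() -> None:
    async with wrapped() as w:
        print(w)  # -> workflow

asyncio.run(main())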
@@ -16,7 +16,11 @@
 import functools
 import inspect
 import uuid
+from collections.abc import Callable
 from typing import Any
+from typing import TypeVar
+from typing import cast
+from typing import overload
 
 from pydantic import BaseModel
 
@@ -36,10 +40,10 @@ def _serialize_data(obj: Any) -> Any:
 
     if isinstance(obj, dict):
         return {str(k): _serialize_data(v) for k, v in obj.items()}
-    if isinstance(obj, (list, tuple, set)):
+    if isinstance(obj, list | tuple | set):
         return [_serialize_data(item) for item in obj]
 
-    if isinstance(obj, (str, int, float, bool, type(None))):
+    if isinstance(obj, str | int | float | bool | type(None)):
         return obj
 
     # Fallback
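The isinstance changes above swap the tuple form for a PEP 604 union, which isinstance accepts on Python 3.10 and later; the two spellings are equivalent at runtime:

# Python 3.10+: X | Y unions are valid second arguments to isinstance()
assert isinstance([1, 2], list | tuple | set)
assert isinstance(None, str | int | float | bool | type(None))
# Identical in effect to the older tuple form:
assert isinstance([1, 2], (list, tuple, set))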
@@ -75,7 +79,24 @@ def push_intermediate_step(step_manager: IntermediateStepManager,
     step_manager.push_intermediate_step(payload)
 
 
-def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
+# Type variable for overloads
+F = TypeVar('F', bound=Callable[..., Any])
+
+
+# Overloads for different function types
+@overload
+def track_function(func: F, *, metadata: dict[str, Any] | None = None) -> F:
+    """Overload for when a function is passed directly."""
+    ...
+
+
+@overload
+def track_function(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
+    """Overload for decorator factory usage (when called with parentheses)."""
+    ...
+
+
+def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
     """
     Decorator that can wrap any type of function (sync, async, generator,
     async generator) and executes "tracking logic" around it.
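The new @overload stubs exist purely for static type checkers: used bare, @track_function returns the decorated function's own type F; called as a factory, it returns a Callable[[F], F]. A minimal sketch of the same dual-use decorator typing, outside NAT (deco and tag are illustrative names):

from collections.abc import Callable
from typing import Any, TypeVar, overload

F = TypeVar("F", bound=Callable[..., Any])

@overload
def deco(func: F) -> F: ...  # bare usage: @deco

@overload
def deco(*, tag: str = "") -> Callable[[F], F]: ...  # factory usage: @deco(tag="x")

def deco(func: Any = None, *, tag: str = "") -> Any:
    def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
        f.tag = tag  # illustrative side effect
        return f
    return wrap if func is None else wrap(func)

@deco
def a() -> int:
    return 1

@deco(tag="x")
def b() -> int:
    return 2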
@@ -252,3 +273,139 @@ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None)
             return result
 
     return sync_wrapper
+
+
+# Overloads for track_unregistered_function
+@overload
+def track_unregistered_function(func: F, *, name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
+    """Overload for when a function is passed directly."""
+    ...
+
+
+@overload
+def track_unregistered_function(*, name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
+    """Overload for decorator factory usage (when called with parentheses)."""
+    ...
+
+
+def track_unregistered_function(func: Callable[..., Any] | None = None,
+                                *,
+                                name: str | None = None,
+                                metadata: dict[str, Any] | None = None) -> Callable[..., Any]:
+    """
+    Decorator that wraps any function with scope management and automatic tracking.
+
+    - Sets active function context using the function name
+    - Leverages Context.push_active_function for built-in tracking
+    - Avoids duplicate tracking entries by relying on the library's built-in systems
+    - Supports sync/async functions and generators
+
+    Args:
+        func: The function to wrap (auto-detected when used without parentheses)
+        name: Custom name to use for tracking instead of func.__name__
+        metadata: Additional metadata to include in tracking
+    """
+
+    # If called with parameters: @track_unregistered_function(name="...", metadata={...})
+    if func is None:
+
+        def decorator_wrapper(actual_func: Callable[..., Any]) -> Callable[..., Any]:
+            # Cast to ensure type checker understands this returns a callable
+            return cast(Callable[..., Any], track_unregistered_function(actual_func, name=name, metadata=metadata))
+
+        return decorator_wrapper
+
+    # Direct decoration: @track_unregistered_function or recursive call with actual function
+    function_name: str = name if name else func.__name__
+
+    # --- Validate metadata ---
+    if metadata is not None:
+        if not isinstance(metadata, dict):
+            raise TypeError("metadata must be a dict[str, Any].")
+        if any(not isinstance(k, str) for k in metadata.keys()):
+            raise TypeError("All metadata keys must be strings.")
+
+    trace_metadata = TraceMetadata(provided_metadata=metadata)
+
+    # --- Now detect the function type and wrap accordingly ---
+    if inspect.isasyncgenfunction(func):
+        # ---------------------
+        # ASYNC GENERATOR
+        # ---------------------
+
+        @functools.wraps(func)
+        async def async_gen_wrapper(*args, **kwargs):
+            context = Context.get()
+            input_data = (
+                *args,
+                kwargs,
+            )
+            # Only do context management - let push_active_function handle tracking
+            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
+                final_outputs = []
+                async for item in func(*args, **kwargs):
+                    final_outputs.append(item)
+                    yield item
+
+                manager.set_output(final_outputs)
+
+        return async_gen_wrapper
+
+    if inspect.iscoroutinefunction(func):
+        # ---------------------
+        # ASYNC FUNCTION
+        # ---------------------
+        @functools.wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            context = Context.get()
+            input_data = (
+                *args,
+                kwargs,
+            )
+
+            # Only do context management - let push_active_function handle tracking
+            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
+                result = await func(*args, **kwargs)
+                manager.set_output(result)
+                return result
+
+        return async_wrapper
+
+    if inspect.isgeneratorfunction(func):
+        # ---------------------
+        # SYNC GENERATOR
+        # ---------------------
+        @functools.wraps(func)
+        def sync_gen_wrapper(*args, **kwargs):
+            context = Context.get()
+            input_data = (
+                *args,
+                kwargs,
+            )
+
+            # Only do context management - let push_active_function handle tracking
+            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
+                final_outputs = []
+                for item in func(*args, **kwargs):
+                    final_outputs.append(item)
+                    yield item
+
+                manager.set_output(final_outputs)
+
+        return sync_gen_wrapper
+
+    @functools.wraps(func)
+    def sync_wrapper(*args, **kwargs):
+        context = Context.get()
+        input_data = (
+            *args,
+            kwargs,
+        )
+
+        # Only do context management - let push_active_function handle tracking
+        with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
+            result = func(*args, **kwargs)
+            manager.set_output(result)
+            return result
+
+    return sync_wrapper
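Assuming the decorator behaves as the added code reads, usage is either bare or with keyword arguments; both routes end up in Context.push_active_function, which records inputs and outputs. A hedged sketch (the decorated functions are hypothetical, and the import path is inferred from the file list above):

from nat.profiler.decorators.function_tracking import track_unregistered_function

@track_unregistered_function
def tokenize(text: str) -> list[str]:
    return text.split()

@track_unregistered_function(name="summarize_chunks", metadata={"stage": "post"})
async def summarize(chunks: list[str]) -> str:
    return " ".join(chunks)[:80]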
@@ -15,7 +15,9 @@
 
 # forecasting/models/base_model.py
 
-from abc import ABC, abstractmethod
+from abc import ABC
+from abc import abstractmethod
+
 import numpy as np
 
 
@@ -36,7 +36,7 @@ class LinearModel(ForecastingBaseModel):
         except ImportError:
             logger.error(
                 "scikit-learn is not installed. Please install scikit-learn to use the LinearModel "
-                "profiling model or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
+                "profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
 
             raise
 
@@ -36,7 +36,7 @@ class RandomForestModel(ForecastingBaseModel):
         except ImportError:
             logger.error(
                 "scikit-learn is not installed. Please install scikit-learn to use the RandomForest "
-                "profiling model or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
+                "profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
 
             raise
 
@@ -304,7 +304,7 @@ def save_gantt_chart(all_nodes: list[CallNode], output_path: str) -> None:
         import matplotlib.pyplot as plt
     except ImportError:
         logger.error("matplotlib is not installed. Please install matplotlib to use generate plots for the profiler "
-                     "or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
+                     "or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
 
         raise
 
@@ -195,7 +195,7 @@ def profile_workflow_bottlenecks(all_steps: list[list[IntermediateStep]]) -> Sim
         c_max = 0
         for ts, delta in events_sub:
             c_curr += delta
-            if c_curr > c_max:  # pylint: disable=consider-using-max-builtin
+            if c_curr > c_max:  # noqa: PLR1730 - don't use max built-in
                 c_max = c_curr
         max_concurrency_by_name[op_name] = c_max
 
@@ -172,7 +172,7 @@ class CallNode(BaseModel):
         if not self.children:
             return self.duration
 
-        intervals = [(c.start_time, c.end_time) for c in self.children]  # pylint: disable=not-an-iterable
+        intervals = [(c.start_time, c.end_time) for c in self.children]
         # Sort by start time
         intervals.sort(key=lambda x: x[0])
 
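The sort in this hunk feeds a standard merge-overlapping-intervals pass: self time is the node's duration minus the unioned time covered by its children, which is what the next hunk's "no overlap double-counting" docstring refers to. A standalone sketch of that union computation (not NAT's exact code):

def covered_time(intervals: list[tuple[float, float]]) -> float:
    # Total length of the union of (start, end) intervals.
    intervals = sorted(intervals, key=lambda x: x[0])
    total = 0.0
    cur_start = cur_end = None
    for start, end in intervals:
        if cur_end is None or start > cur_end:
            # Disjoint: bank the previous merged interval.
            if cur_end is not None:
                total += cur_end - cur_start
            cur_start, cur_end = start, end
        else:
            # Overlapping: extend the merged interval.
            cur_end = max(cur_end, end)
    if cur_end is not None:
        total += cur_end - cur_start
    return total

# Union of (0, 2) and (1, 3) is (0, 3): 3.0 + 1.0 = 4.0
assert covered_time([(0.0, 2.0), (1.0, 3.0), (5.0, 6.0)]) == 4.0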
@@ -204,7 +204,7 @@ class CallNode(BaseModel):
         This ensures no overlap double-counting among children.
         """
         total = self.compute_self_time()
-        for c in self.children:  # pylint: disable=not-an-iterable
+        for c in self.children:
             total += c.compute_subtree_time()
         return total
 
@@ -216,7 +216,7 @@ class CallNode(BaseModel):
         info = (f"{indent}- {self.operation_type} '{self.operation_name}' "
                 f"(uuid={self.uuid}, start={self.start_time:.2f}, "
                 f"end={self.end_time:.2f}, dur={self.duration:.2f})")
-        child_strs = [child._repr(level + 1) for child in self.children]  # pylint: disable=not-an-iterable
+        child_strs = [child._repr(level + 1) for child in self.children]
         return "\n".join([info] + child_strs)
 
 
@@ -212,7 +212,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
         from prefixspan import PrefixSpan
     except ImportError:
         logger.error("prefixspan is not installed. Please install prefixspan to run the prefix analysis in the "
-                     "profiler or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
+                     "profiler or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
 
         raise
 
@@ -228,7 +228,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
     else:
         abs_min_support = min_support
 
-    freq_patterns = ps.frequent(abs_min_support)  # pylint: disable=not-callable
+    freq_patterns = ps.frequent(abs_min_support)
     # freq_patterns => [(count, [item1, item2, ...])]
 
     results = []
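For context on ps.frequent(...): the prefixspan package mines frequent subsequences from a list of token sequences and returns (count, pattern) pairs, which is the shape the comment in the hunk describes. A toy example (the token names are illustrative):

from prefixspan import PrefixSpan

# Each inner list is one sequence of operation tokens.
db = [
    ["llm_start", "tool_call", "llm_end"],
    ["llm_start", "llm_end"],
    ["llm_start", "tool_call", "tool_call", "llm_end"],
]
ps = PrefixSpan(db)
# Patterns occurring in at least 2 sequences -> [(count, [token, ...]), ...]
for count, pattern in ps.frequent(2):
    print(count, pattern)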
@@ -321,13 +321,12 @@ def compute_coverage_and_duration(sequences_map: dict[int, list[PrefixCallNode]]
 # --------------------------------------------------------------------------------
 
 
-def prefixspan_subworkflow_with_text(  # pylint: disable=too-many-positional-arguments
-        all_steps: list[list[IntermediateStep]],
-        min_support: int | float = 2,
-        top_k: int = 10,
-        min_coverage: float = 0.0,
-        max_text_len: int = 700,
-        prefix_list: list[str] = None) -> PrefixSpanSubworkflowResult:
+def prefixspan_subworkflow_with_text(all_steps: list[list[IntermediateStep]],
+                                     min_support: int | float = 2,
+                                     top_k: int = 10,
+                                     min_coverage: float = 0.0,
+                                     max_text_len: int = 700,
+                                     prefix_list: list[str] = None) -> PrefixSpanSubworkflowResult:
     """
     1) Build sequences of calls for each example (with llm_text_input).
     2) Convert to token lists, run PrefixSpan with min_support.
@@ -66,7 +66,7 @@ def compute_inter_query_token_uniqueness_by_llm(all_steps: list[list[Intermediat
     # 2) Group by (llm_name, example_number), then sort each group
     grouped = cdf.groupby(['llm_name', 'example_number'], as_index=False, group_keys=True)
 
-    for (llm, ex_num), group_df in grouped:  # pylint: disable=unused-variable
+    for (llm, ex_num), group_df in grouped:
         # Sort by event_timestamp
         group_df = group_df.sort_values('event_timestamp', ascending=True)
 