nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
nat/data_models/api_server.py
CHANGED
|
@@ -36,6 +36,15 @@ from nat.utils.type_converter import GlobalTypeConverter
|
|
|
36
36
|
FINISH_REASONS = frozenset({'stop', 'length', 'tool_calls', 'content_filter', 'function_call'})
|
|
37
37
|
|
|
38
38
|
|
|
39
|
+
class UserMessageContentRoleType(str, Enum):
|
|
40
|
+
"""
|
|
41
|
+
Enum representing chat message roles in API requests and responses.
|
|
42
|
+
"""
|
|
43
|
+
USER = "user"
|
|
44
|
+
ASSISTANT = "assistant"
|
|
45
|
+
SYSTEM = "system"
|
|
46
|
+
|
|
47
|
+
|
|
39
48
|
class Request(BaseModel):
|
|
40
49
|
"""
|
|
41
50
|
Request is a data model that represents HTTP request attributes.
|
|
@@ -108,7 +117,7 @@ UserContent = typing.Annotated[TextContent | ImageContent | AudioContent, Discri
|
|
|
108
117
|
|
|
109
118
|
class Message(BaseModel):
|
|
110
119
|
content: str | list[UserContent]
|
|
111
|
-
role:
|
|
120
|
+
role: UserMessageContentRoleType
|
|
112
121
|
|
|
113
122
|
|
|
114
123
|
class ChatRequest(BaseModel):
|
|
@@ -164,7 +173,7 @@ class ChatRequest(BaseModel):
|
|
|
164
173
|
max_tokens: int | None = None,
|
|
165
174
|
top_p: float | None = None) -> "ChatRequest":
|
|
166
175
|
|
|
167
|
-
return ChatRequest(messages=[Message(content=data, role=
|
|
176
|
+
return ChatRequest(messages=[Message(content=data, role=UserMessageContentRoleType.USER)],
|
|
168
177
|
model=model,
|
|
169
178
|
temperature=temperature,
|
|
170
179
|
max_tokens=max_tokens,
|
|
@@ -178,7 +187,7 @@ class ChatRequest(BaseModel):
|
|
|
178
187
|
max_tokens: int | None = None,
|
|
179
188
|
top_p: float | None = None) -> "ChatRequest":
|
|
180
189
|
|
|
181
|
-
return ChatRequest(messages=[Message(content=content, role=
|
|
190
|
+
return ChatRequest(messages=[Message(content=content, role=UserMessageContentRoleType.USER)],
|
|
182
191
|
model=model,
|
|
183
192
|
temperature=temperature,
|
|
184
193
|
max_tokens=max_tokens,
|
|
@@ -187,29 +196,40 @@ class ChatRequest(BaseModel):
|
|
|
187
196
|
|
|
188
197
|
class ChoiceMessage(BaseModel):
|
|
189
198
|
content: str | None = None
|
|
190
|
-
role:
|
|
199
|
+
role: UserMessageContentRoleType | None = None
|
|
191
200
|
|
|
192
201
|
|
|
193
202
|
class ChoiceDelta(BaseModel):
|
|
194
203
|
"""Delta object for streaming responses (OpenAI-compatible)"""
|
|
195
204
|
content: str | None = None
|
|
196
|
-
role:
|
|
205
|
+
role: UserMessageContentRoleType | None = None
|
|
197
206
|
|
|
198
207
|
|
|
199
|
-
class
|
|
208
|
+
class ChoiceBase(BaseModel):
|
|
209
|
+
"""Base choice model with common fields for both streaming and non-streaming responses"""
|
|
200
210
|
model_config = ConfigDict(extra="allow")
|
|
201
|
-
|
|
202
|
-
message: ChoiceMessage | None = None
|
|
203
|
-
delta: ChoiceDelta | None = None
|
|
204
211
|
finish_reason: typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None = None
|
|
205
212
|
index: int
|
|
206
|
-
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class ChatResponseChoice(ChoiceBase):
|
|
216
|
+
"""Choice model for non-streaming responses - contains message field"""
|
|
217
|
+
message: ChoiceMessage
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class ChatResponseChunkChoice(ChoiceBase):
|
|
221
|
+
"""Choice model for streaming responses - contains delta field"""
|
|
222
|
+
delta: ChoiceDelta
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
# Backward compatibility alias
|
|
226
|
+
Choice = ChatResponseChoice
|
|
207
227
|
|
|
208
228
|
|
|
209
229
|
class Usage(BaseModel):
|
|
210
|
-
prompt_tokens: int
|
|
211
|
-
completion_tokens: int
|
|
212
|
-
total_tokens: int
|
|
230
|
+
prompt_tokens: int | None = None
|
|
231
|
+
completion_tokens: int | None = None
|
|
232
|
+
total_tokens: int | None = None
|
|
213
233
|
|
|
214
234
|
|
|
215
235
|
class ResponseSerializable(abc.ABC):
|
|
@@ -245,10 +265,10 @@ class ChatResponse(ResponseBaseModelOutput):
|
|
|
245
265
|
model_config = ConfigDict(extra="allow")
|
|
246
266
|
id: str
|
|
247
267
|
object: str = "chat.completion"
|
|
248
|
-
model: str = ""
|
|
268
|
+
model: str = "unknown-model"
|
|
249
269
|
created: datetime.datetime
|
|
250
|
-
choices: list[
|
|
251
|
-
usage: Usage
|
|
270
|
+
choices: list[ChatResponseChoice]
|
|
271
|
+
usage: Usage
|
|
252
272
|
system_fingerprint: str | None = None
|
|
253
273
|
service_tier: typing.Literal["scale", "default"] | None = None
|
|
254
274
|
|
|
@@ -264,22 +284,27 @@ class ChatResponse(ResponseBaseModelOutput):
|
|
|
264
284
|
object_: str | None = None,
|
|
265
285
|
model: str | None = None,
|
|
266
286
|
created: datetime.datetime | None = None,
|
|
267
|
-
usage: Usage
|
|
287
|
+
usage: Usage) -> "ChatResponse":
|
|
268
288
|
|
|
269
289
|
if id_ is None:
|
|
270
290
|
id_ = str(uuid.uuid4())
|
|
271
291
|
if object_ is None:
|
|
272
292
|
object_ = "chat.completion"
|
|
273
293
|
if model is None:
|
|
274
|
-
model = ""
|
|
294
|
+
model = "unknown-model"
|
|
275
295
|
if created is None:
|
|
276
|
-
created = datetime.datetime.now(datetime.
|
|
296
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
277
297
|
|
|
278
298
|
return ChatResponse(id=id_,
|
|
279
299
|
object=object_,
|
|
280
300
|
model=model,
|
|
281
301
|
created=created,
|
|
282
|
-
choices=[
|
|
302
|
+
choices=[
|
|
303
|
+
ChatResponseChoice(index=0,
|
|
304
|
+
message=ChoiceMessage(content=data,
|
|
305
|
+
role=UserMessageContentRoleType.ASSISTANT),
|
|
306
|
+
finish_reason="stop")
|
|
307
|
+
],
|
|
283
308
|
usage=usage)
|
|
284
309
|
|
|
285
310
|
|
|
@@ -293,9 +318,9 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
293
318
|
model_config = ConfigDict(extra="allow")
|
|
294
319
|
|
|
295
320
|
id: str
|
|
296
|
-
choices: list[
|
|
321
|
+
choices: list[ChatResponseChunkChoice]
|
|
297
322
|
created: datetime.datetime
|
|
298
|
-
model: str = ""
|
|
323
|
+
model: str = "unknown-model"
|
|
299
324
|
object: str = "chat.completion.chunk"
|
|
300
325
|
system_fingerprint: str | None = None
|
|
301
326
|
service_tier: typing.Literal["scale", "default"] | None = None
|
|
@@ -317,14 +342,20 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
317
342
|
if id_ is None:
|
|
318
343
|
id_ = str(uuid.uuid4())
|
|
319
344
|
if created is None:
|
|
320
|
-
created = datetime.datetime.now(datetime.
|
|
345
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
321
346
|
if model is None:
|
|
322
|
-
model = ""
|
|
347
|
+
model = "unknown-model"
|
|
323
348
|
if object_ is None:
|
|
324
349
|
object_ = "chat.completion.chunk"
|
|
325
350
|
|
|
326
351
|
return ChatResponseChunk(id=id_,
|
|
327
|
-
choices=[
|
|
352
|
+
choices=[
|
|
353
|
+
ChatResponseChunkChoice(index=0,
|
|
354
|
+
delta=ChoiceDelta(
|
|
355
|
+
content=data,
|
|
356
|
+
role=UserMessageContentRoleType.ASSISTANT),
|
|
357
|
+
finish_reason="stop")
|
|
358
|
+
],
|
|
328
359
|
created=created,
|
|
329
360
|
model=model,
|
|
330
361
|
object=object_)
|
|
@@ -335,7 +366,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
335
366
|
id_: str | None = None,
|
|
336
367
|
created: datetime.datetime | None = None,
|
|
337
368
|
model: str | None = None,
|
|
338
|
-
role:
|
|
369
|
+
role: UserMessageContentRoleType | None = None,
|
|
339
370
|
finish_reason: str | None = None,
|
|
340
371
|
usage: Usage | None = None,
|
|
341
372
|
system_fingerprint: str | None = None) -> "ChatResponseChunk":
|
|
@@ -343,9 +374,9 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
343
374
|
if id_ is None:
|
|
344
375
|
id_ = str(uuid.uuid4())
|
|
345
376
|
if created is None:
|
|
346
|
-
created = datetime.datetime.now(datetime.
|
|
377
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
347
378
|
if model is None:
|
|
348
|
-
model = ""
|
|
379
|
+
model = "unknown-model"
|
|
349
380
|
|
|
350
381
|
delta = ChoiceDelta(content=content, role=role) if content is not None or role is not None else ChoiceDelta()
|
|
351
382
|
|
|
@@ -353,7 +384,14 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
353
384
|
|
|
354
385
|
return ChatResponseChunk(
|
|
355
386
|
id=id_,
|
|
356
|
-
choices=[
|
|
387
|
+
choices=[
|
|
388
|
+
ChatResponseChunkChoice(
|
|
389
|
+
index=0,
|
|
390
|
+
delta=delta,
|
|
391
|
+
finish_reason=typing.cast(
|
|
392
|
+
typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None,
|
|
393
|
+
final_finish_reason))
|
|
394
|
+
],
|
|
357
395
|
created=created,
|
|
358
396
|
model=model,
|
|
359
397
|
object="chat.completion.chunk",
|
|
@@ -398,11 +436,6 @@ class GenerateResponse(BaseModel):
|
|
|
398
436
|
value: str | None = "default"
|
|
399
437
|
|
|
400
438
|
|
|
401
|
-
class UserMessageContentRoleType(str, Enum):
|
|
402
|
-
USER = "user"
|
|
403
|
-
ASSISTANT = "assistant"
|
|
404
|
-
|
|
405
|
-
|
|
406
439
|
class WebSocketMessageType(str, Enum):
|
|
407
440
|
"""
|
|
408
441
|
WebSocketMessageType is an Enum that represents WebSocket Message types.
|
|
@@ -485,7 +518,7 @@ class WebSocketUserMessage(BaseModel):
|
|
|
485
518
|
security: Security = Security()
|
|
486
519
|
error: Error = Error()
|
|
487
520
|
schema_version: str = "1.0.0"
|
|
488
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
521
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
489
522
|
|
|
490
523
|
|
|
491
524
|
class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
@@ -501,7 +534,7 @@ class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
|
501
534
|
security: Security = Security()
|
|
502
535
|
error: Error = Error()
|
|
503
536
|
schema_version: str = "1.0.0"
|
|
504
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
537
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
505
538
|
|
|
506
539
|
|
|
507
540
|
class SystemIntermediateStepContent(BaseModel):
|
|
@@ -527,7 +560,7 @@ class WebSocketSystemIntermediateStepMessage(BaseModel):
|
|
|
527
560
|
conversation_id: str | None = None
|
|
528
561
|
content: SystemIntermediateStepContent
|
|
529
562
|
status: WebSocketMessageStatus
|
|
530
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
563
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
531
564
|
|
|
532
565
|
|
|
533
566
|
class SystemResponseContent(BaseModel):
|
|
@@ -551,7 +584,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
|
|
|
551
584
|
conversation_id: str | None = None
|
|
552
585
|
content: SystemResponseContent | Error | GenerateResponse
|
|
553
586
|
status: WebSocketMessageStatus
|
|
554
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
587
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
555
588
|
|
|
556
589
|
@field_validator("content")
|
|
557
590
|
@classmethod
|
|
@@ -560,7 +593,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
|
|
|
560
593
|
raise ValueError(f"Field: content must be 'Error' when type is {WebSocketMessageType.ERROR_MESSAGE}")
|
|
561
594
|
|
|
562
595
|
if info.data.get("type") == WebSocketMessageType.RESPONSE_MESSAGE and not isinstance(
|
|
563
|
-
value,
|
|
596
|
+
value, SystemResponseContent | GenerateResponse):
|
|
564
597
|
raise ValueError(
|
|
565
598
|
f"Field: content must be 'SystemResponseContent' when type is {WebSocketMessageType.RESPONSE_MESSAGE}")
|
|
566
599
|
return value
|
|
@@ -582,7 +615,7 @@ class WebSocketSystemInteractionMessage(BaseModel):
|
|
|
582
615
|
conversation_id: str | None = None
|
|
583
616
|
content: HumanPrompt
|
|
584
617
|
status: WebSocketMessageStatus
|
|
585
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
618
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
586
619
|
|
|
587
620
|
|
|
588
621
|
# ======== GenerateResponse Converters ========
|
|
@@ -622,7 +655,7 @@ GlobalTypeConverter.register_converter(_nat_chat_request_to_string)
|
|
|
622
655
|
|
|
623
656
|
|
|
624
657
|
def _string_to_nat_chat_request(data: str) -> ChatRequest:
|
|
625
|
-
return ChatRequest.from_string(data, model="")
|
|
658
|
+
return ChatRequest.from_string(data, model="unknown-model")
|
|
626
659
|
|
|
627
660
|
|
|
628
661
|
GlobalTypeConverter.register_converter(_string_to_nat_chat_request)
|
|
@@ -654,22 +687,12 @@ def _string_to_nat_chat_response(data: str) -> ChatResponse:
|
|
|
654
687
|
GlobalTypeConverter.register_converter(_string_to_nat_chat_response)
|
|
655
688
|
|
|
656
689
|
|
|
657
|
-
def _chat_response_to_chat_response_chunk(data: ChatResponse) -> ChatResponseChunk:
|
|
658
|
-
# Preserve original message structure for backward compatibility
|
|
659
|
-
return ChatResponseChunk(id=data.id, choices=data.choices, created=data.created, model=data.model)
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk)
|
|
663
|
-
|
|
664
|
-
|
|
665
690
|
# ======== ChatResponseChunk Converters ========
|
|
666
691
|
def _chat_response_chunk_to_string(data: ChatResponseChunk) -> str:
|
|
667
692
|
if data.choices and len(data.choices) > 0:
|
|
668
693
|
choice = data.choices[0]
|
|
669
694
|
if choice.delta and choice.delta.content:
|
|
670
695
|
return choice.delta.content
|
|
671
|
-
if choice.message and choice.message.content:
|
|
672
|
-
return choice.message.content
|
|
673
696
|
return ""
|
|
674
697
|
|
|
675
698
|
|
|
@@ -685,21 +708,6 @@ def _string_to_nat_chat_response_chunk(data: str) -> ChatResponseChunk:
|
|
|
685
708
|
|
|
686
709
|
GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk)
|
|
687
710
|
|
|
688
|
-
|
|
689
|
-
# ======== AINodeMessageChunk Converters ========
|
|
690
|
-
def _ai_message_chunk_to_nat_chat_response_chunk(data) -> ChatResponseChunk:
|
|
691
|
-
'''Converts LangChain AINodeMessageChunk to ChatResponseChunk'''
|
|
692
|
-
content = ""
|
|
693
|
-
if hasattr(data, 'content') and data.content is not None:
|
|
694
|
-
content = str(data.content)
|
|
695
|
-
elif hasattr(data, 'text') and data.text is not None:
|
|
696
|
-
content = str(data.text)
|
|
697
|
-
elif hasattr(data, 'message') and data.message is not None:
|
|
698
|
-
content = str(data.message)
|
|
699
|
-
|
|
700
|
-
return ChatResponseChunk.create_streaming_chunk(content=content, role="assistant", finish_reason=None)
|
|
701
|
-
|
|
702
|
-
|
|
703
711
|
# Compatibility aliases with previous releases
|
|
704
712
|
AIQChatRequest = ChatRequest
|
|
705
713
|
AIQChoiceMessage = ChoiceMessage
|
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
|
+
from datetime import UTC
|
|
17
18
|
from datetime import datetime
|
|
18
|
-
from datetime import timezone
|
|
19
19
|
from enum import Enum
|
|
20
20
|
|
|
21
21
|
import httpx
|
|
@@ -166,17 +166,31 @@ class BearerTokenCred(_CredBase):
|
|
|
166
166
|
|
|
167
167
|
|
|
168
168
|
Credential = typing.Annotated[
|
|
169
|
-
|
|
170
|
-
HeaderCred,
|
|
171
|
-
QueryCred,
|
|
172
|
-
CookieCred,
|
|
173
|
-
BasicAuthCred,
|
|
174
|
-
BearerTokenCred,
|
|
175
|
-
],
|
|
169
|
+
HeaderCred | QueryCred | CookieCred | BasicAuthCred | BearerTokenCred,
|
|
176
170
|
Field(discriminator="kind"),
|
|
177
171
|
]
|
|
178
172
|
|
|
179
173
|
|
|
174
|
+
class TokenValidationResult(BaseModel):
|
|
175
|
+
"""
|
|
176
|
+
Standard result for Bearer Token Validation.
|
|
177
|
+
"""
|
|
178
|
+
model_config = ConfigDict(extra="forbid")
|
|
179
|
+
|
|
180
|
+
client_id: str | None = Field(description="OAuth2 client identifier")
|
|
181
|
+
scopes: list[str] | None = Field(default=None, description="List of granted scopes (introspection only)")
|
|
182
|
+
expires_at: int | None = Field(default=None, description="Token expiration time (Unix timestamp)")
|
|
183
|
+
audience: list[str] | None = Field(default=None, description="Token audiences (aud claim)")
|
|
184
|
+
subject: str | None = Field(default=None, description="Token subject (sub claim)")
|
|
185
|
+
issuer: str | None = Field(default=None, description="Token issuer (iss claim)")
|
|
186
|
+
token_type: str = Field(description="Token type")
|
|
187
|
+
active: bool | None = Field(default=True, description="Token active status")
|
|
188
|
+
nbf: int | None = Field(default=None, description="Not before time (Unix timestamp)")
|
|
189
|
+
iat: int | None = Field(default=None, description="Issued at time (Unix timestamp)")
|
|
190
|
+
jti: str | None = Field(default=None, description="JWT ID")
|
|
191
|
+
username: str | None = Field(default=None, description="Username (introspection only)")
|
|
192
|
+
|
|
193
|
+
|
|
180
194
|
class AuthResult(BaseModel):
|
|
181
195
|
"""
|
|
182
196
|
Represents the result of an authentication process.
|
|
@@ -193,7 +207,7 @@ class AuthResult(BaseModel):
|
|
|
193
207
|
"""
|
|
194
208
|
Checks if the authentication token has expired.
|
|
195
209
|
"""
|
|
196
|
-
return bool(self.token_expires_at and datetime.now(
|
|
210
|
+
return bool(self.token_expires_at and datetime.now(UTC) >= self.token_expires_at)
|
|
197
211
|
|
|
198
212
|
def as_requests_kwargs(self) -> dict[str, typing.Any]:
|
|
199
213
|
"""
|
nat/data_models/common.py
CHANGED
|
@@ -160,7 +160,7 @@ class TypedBaseModel(BaseModel):
|
|
|
160
160
|
|
|
161
161
|
@staticmethod
|
|
162
162
|
def discriminator(v: typing.Any) -> str | None:
|
|
163
|
-
# If
|
|
163
|
+
# If it's serialized, then we use the alias
|
|
164
164
|
if isinstance(v, dict):
|
|
165
165
|
return v.get("_type", v.get("type"))
|
|
166
166
|
|
nat/data_models/component.py
CHANGED
|
@@ -27,6 +27,7 @@ class ComponentEnum(StrEnum):
|
|
|
27
27
|
EVALUATOR = "evaluator"
|
|
28
28
|
FRONT_END = "front_end"
|
|
29
29
|
FUNCTION = "function"
|
|
30
|
+
FUNCTION_GROUP = "function_group"
|
|
30
31
|
TTC_STRATEGY = "ttc_strategy"
|
|
31
32
|
LLM_CLIENT = "llm_client"
|
|
32
33
|
LLM_PROVIDER = "llm_provider"
|
|
@@ -47,6 +48,7 @@ class ComponentGroup(StrEnum):
|
|
|
47
48
|
AUTHENTICATION = "authentication"
|
|
48
49
|
EMBEDDERS = "embedders"
|
|
49
50
|
FUNCTIONS = "functions"
|
|
51
|
+
FUNCTION_GROUPS = "function_groups"
|
|
50
52
|
TTC_STRATEGIES = "ttc_strategies"
|
|
51
53
|
LLMS = "llms"
|
|
52
54
|
MEMORY = "memory"
|
nat/data_models/component_ref.py
CHANGED
|
@@ -102,6 +102,17 @@ class FunctionRef(ComponentRef):
|
|
|
102
102
|
return ComponentGroup.FUNCTIONS
|
|
103
103
|
|
|
104
104
|
|
|
105
|
+
class FunctionGroupRef(ComponentRef):
|
|
106
|
+
"""
|
|
107
|
+
A reference to a function group in a NAT configuration object.
|
|
108
|
+
"""
|
|
109
|
+
|
|
110
|
+
@property
|
|
111
|
+
@override
|
|
112
|
+
def component_group(self):
|
|
113
|
+
return ComponentGroup.FUNCTION_GROUPS
|
|
114
|
+
|
|
115
|
+
|
|
105
116
|
class LLMRef(ComponentRef):
|
|
106
117
|
"""
|
|
107
118
|
A reference to an LLM in a NAT configuration object.
|
nat/data_models/config.py
CHANGED
|
@@ -20,6 +20,7 @@ import typing
|
|
|
20
20
|
from pydantic import BaseModel
|
|
21
21
|
from pydantic import ConfigDict
|
|
22
22
|
from pydantic import Discriminator
|
|
23
|
+
from pydantic import Field
|
|
23
24
|
from pydantic import ValidationError
|
|
24
25
|
from pydantic import ValidationInfo
|
|
25
26
|
from pydantic import ValidatorFunctionWrapHandler
|
|
@@ -29,7 +30,9 @@ from nat.data_models.evaluate import EvalConfig
|
|
|
29
30
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
30
31
|
from nat.data_models.function import EmptyFunctionConfig
|
|
31
32
|
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
32
34
|
from nat.data_models.logging import LoggingBaseConfig
|
|
35
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
33
36
|
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
34
37
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
35
38
|
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
@@ -47,7 +50,7 @@ logger = logging.getLogger(__name__)
|
|
|
47
50
|
|
|
48
51
|
|
|
49
52
|
def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
|
50
|
-
from nat.cli.type_registry import GlobalTypeRegistry
|
|
53
|
+
from nat.cli.type_registry import GlobalTypeRegistry
|
|
51
54
|
|
|
52
55
|
new_errors = []
|
|
53
56
|
logged_once = False
|
|
@@ -57,9 +60,10 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
57
60
|
error_type = e['type']
|
|
58
61
|
if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
|
|
59
62
|
requested_type = e["ctx"]["tag"]
|
|
60
|
-
|
|
61
63
|
if (info.field_name in ('workflow', 'functions')):
|
|
62
64
|
registered_keys = GlobalTypeRegistry.get().get_registered_functions()
|
|
65
|
+
elif (info.field_name == "function_groups"):
|
|
66
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_function_groups()
|
|
63
67
|
elif (info.field_name == "authentication"):
|
|
64
68
|
registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
|
|
65
69
|
elif (info.field_name == "llms"):
|
|
@@ -135,8 +139,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
135
139
|
|
|
136
140
|
class TelemetryConfig(BaseModel):
|
|
137
141
|
|
|
138
|
-
logging: dict[str, LoggingBaseConfig] =
|
|
139
|
-
tracing: dict[str, TelemetryExporterBaseConfig] =
|
|
142
|
+
logging: dict[str, LoggingBaseConfig] = Field(default_factory=dict)
|
|
143
|
+
tracing: dict[str, TelemetryExporterBaseConfig] = Field(default_factory=dict)
|
|
140
144
|
|
|
141
145
|
@field_validator("logging", "tracing", mode="wrap")
|
|
142
146
|
@classmethod
|
|
@@ -185,10 +189,14 @@ class GeneralConfig(BaseModel):
|
|
|
185
189
|
|
|
186
190
|
model_config = ConfigDict(protected_namespaces=())
|
|
187
191
|
|
|
188
|
-
use_uvloop: bool =
|
|
192
|
+
use_uvloop: bool | None = Field(
|
|
193
|
+
default=None,
|
|
194
|
+
deprecated=
|
|
195
|
+
"`use_uvloop` field is deprecated and will be removed in a future release. The use of `uvloop` is now " +
|
|
196
|
+
"automatically determined based on platform")
|
|
189
197
|
"""
|
|
190
|
-
|
|
191
|
-
|
|
198
|
+
This field is deprecated and ignored. It previously controlled whether to use uvloop as the event loop. uvloop
|
|
199
|
+
usage is now determined automatically based on the platform.
|
|
192
200
|
"""
|
|
193
201
|
|
|
194
202
|
telemetry: TelemetryConfig = TelemetryConfig()
|
|
@@ -240,31 +248,37 @@ class Config(HashableBaseModel):
|
|
|
240
248
|
general: GeneralConfig = GeneralConfig()
|
|
241
249
|
|
|
242
250
|
# Functions Configuration
|
|
243
|
-
functions: dict[str, FunctionBaseConfig] =
|
|
251
|
+
functions: dict[str, FunctionBaseConfig] = Field(default_factory=dict)
|
|
252
|
+
|
|
253
|
+
# Function Groups Configuration
|
|
254
|
+
function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
|
|
244
255
|
|
|
245
256
|
# LLMs Configuration
|
|
246
|
-
llms: dict[str, LLMBaseConfig] =
|
|
257
|
+
llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
|
|
247
258
|
|
|
248
259
|
# Embedders Configuration
|
|
249
|
-
embedders: dict[str, EmbedderBaseConfig] =
|
|
260
|
+
embedders: dict[str, EmbedderBaseConfig] = Field(default_factory=dict)
|
|
250
261
|
|
|
251
262
|
# Memory Configuration
|
|
252
|
-
memory: dict[str, MemoryBaseConfig] =
|
|
263
|
+
memory: dict[str, MemoryBaseConfig] = Field(default_factory=dict)
|
|
253
264
|
|
|
254
265
|
# Object Stores Configuration
|
|
255
|
-
object_stores: dict[str, ObjectStoreBaseConfig] =
|
|
266
|
+
object_stores: dict[str, ObjectStoreBaseConfig] = Field(default_factory=dict)
|
|
267
|
+
|
|
268
|
+
# Optimizer Configuration
|
|
269
|
+
optimizer: OptimizerConfig = OptimizerConfig()
|
|
256
270
|
|
|
257
271
|
# Retriever Configuration
|
|
258
|
-
retrievers: dict[str, RetrieverBaseConfig] =
|
|
272
|
+
retrievers: dict[str, RetrieverBaseConfig] = Field(default_factory=dict)
|
|
259
273
|
|
|
260
274
|
# TTC Strategies
|
|
261
|
-
ttc_strategies: dict[str, TTCStrategyBaseConfig] =
|
|
275
|
+
ttc_strategies: dict[str, TTCStrategyBaseConfig] = Field(default_factory=dict)
|
|
262
276
|
|
|
263
277
|
# Workflow Configuration
|
|
264
278
|
workflow: FunctionBaseConfig = EmptyFunctionConfig()
|
|
265
279
|
|
|
266
280
|
# Authentication Configuration
|
|
267
|
-
authentication: dict[str, AuthProviderBaseConfig] =
|
|
281
|
+
authentication: dict[str, AuthProviderBaseConfig] = Field(default_factory=dict)
|
|
268
282
|
|
|
269
283
|
# Evaluation Options
|
|
270
284
|
eval: EvalConfig = EvalConfig()
|
|
@@ -278,6 +292,7 @@ class Config(HashableBaseModel):
|
|
|
278
292
|
stream.write(f"Workflow Type: {self.workflow.type}\n")
|
|
279
293
|
|
|
280
294
|
stream.write(f"Number of Functions: {len(self.functions)}\n")
|
|
295
|
+
stream.write(f"Number of Function Groups: {len(self.function_groups)}\n")
|
|
281
296
|
stream.write(f"Number of LLMs: {len(self.llms)}\n")
|
|
282
297
|
stream.write(f"Number of Embedders: {len(self.embedders)}\n")
|
|
283
298
|
stream.write(f"Number of Memory: {len(self.memory)}\n")
|
|
@@ -287,6 +302,7 @@ class Config(HashableBaseModel):
|
|
|
287
302
|
stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
|
|
288
303
|
|
|
289
304
|
@field_validator("functions",
|
|
305
|
+
"function_groups",
|
|
290
306
|
"llms",
|
|
291
307
|
"embedders",
|
|
292
308
|
"memory",
|
|
@@ -328,6 +344,10 @@ class Config(HashableBaseModel):
|
|
|
328
344
|
typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
329
345
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
330
346
|
|
|
347
|
+
FunctionGroupsAnnotation = dict[str,
|
|
348
|
+
typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
|
|
349
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
350
|
+
|
|
331
351
|
MemoryAnnotation = dict[str,
|
|
332
352
|
typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
|
|
333
353
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -335,7 +355,6 @@ class Config(HashableBaseModel):
|
|
|
335
355
|
ObjectStoreAnnotation = dict[str,
|
|
336
356
|
typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
|
|
337
357
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
338
|
-
|
|
339
358
|
RetrieverAnnotation = dict[str,
|
|
340
359
|
typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
|
|
341
360
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -344,7 +363,7 @@ class Config(HashableBaseModel):
|
|
|
344
363
|
typing.Annotated[type_registry.compute_annotation(TTCStrategyBaseConfig),
|
|
345
364
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
346
365
|
|
|
347
|
-
WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
366
|
+
WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
|
|
348
367
|
Discriminator(TypedBaseModel.discriminator)]
|
|
349
368
|
|
|
350
369
|
should_rebuild = False
|
|
@@ -369,6 +388,11 @@ class Config(HashableBaseModel):
|
|
|
369
388
|
functions_field.annotation = FunctionsAnnotation
|
|
370
389
|
should_rebuild = True
|
|
371
390
|
|
|
391
|
+
function_groups_field = cls.model_fields.get("function_groups")
|
|
392
|
+
if function_groups_field is not None and function_groups_field.annotation != FunctionGroupsAnnotation:
|
|
393
|
+
function_groups_field.annotation = FunctionGroupsAnnotation
|
|
394
|
+
should_rebuild = True
|
|
395
|
+
|
|
372
396
|
memory_field = cls.model_fields.get("memory")
|
|
373
397
|
if memory_field is not None and memory_field.annotation != MemoryAnnotation:
|
|
374
398
|
memory_field.annotation = MemoryAnnotation
|
|
@@ -80,7 +80,7 @@ class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
|
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
def read_jsonl(file_path: FilePath):
|
|
83
|
-
with open(file_path,
|
|
83
|
+
with open(file_path, encoding='utf-8') as f:
|
|
84
84
|
data = [json.loads(line) for line in f]
|
|
85
85
|
return pd.DataFrame(data)
|
|
86
86
|
|
|
@@ -177,7 +177,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
177
177
|
logger.warning("Package metadata not found for %s", distro_name)
|
|
178
178
|
version = ""
|
|
179
179
|
except Exception as e:
|
|
180
|
-
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e
|
|
180
|
+
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e)
|
|
181
181
|
return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
|
|
182
182
|
|
|
183
183
|
description = generate_config_type_docs(config_type=config_type)
|
|
@@ -217,7 +217,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
217
217
|
logger.warning("Package metadata not found for %s", distro_name)
|
|
218
218
|
version = ""
|
|
219
219
|
except Exception as e:
|
|
220
|
-
logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e
|
|
220
|
+
logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e)
|
|
221
221
|
return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
|
|
222
222
|
|
|
223
223
|
if isinstance(wrapper_type, LLMFrameworkEnum):
|
|
@@ -252,7 +252,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
252
252
|
description = ""
|
|
253
253
|
package_version = package_version or ""
|
|
254
254
|
except Exception as e:
|
|
255
|
-
logger.exception("Encountered issue extracting module metadata for %s: %s", package_name, e
|
|
255
|
+
logger.exception("Encountered issue extracting module metadata for %s: %s", package_name, e)
|
|
256
256
|
return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
|
|
257
257
|
|
|
258
258
|
return DiscoveryMetadata(package=package_name,
|
|
@@ -290,7 +290,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
290
290
|
logger.warning("Package metadata not found for %s", distro_name)
|
|
291
291
|
version = ""
|
|
292
292
|
except Exception as e:
|
|
293
|
-
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e
|
|
293
|
+
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e)
|
|
294
294
|
return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
|
|
295
295
|
|
|
296
296
|
wrapper_type = wrapper_type.value if isinstance(wrapper_type, LLMFrameworkEnum) else wrapper_type
|
nat/data_models/evaluate.py
CHANGED
|
@@ -57,6 +57,9 @@ class EvalOutputConfig(BaseModel):
|
|
|
57
57
|
dir: Path = Path("./.tmp/nat/examples/default/")
|
|
58
58
|
# S3 prefix for the workflow and evaluation results
|
|
59
59
|
remote_dir: str | None = None
|
|
60
|
+
# Custom function that pre-processes the eval input before evaluation
|
|
61
|
+
# Format: "module.path.function_name"
|
|
62
|
+
custom_pre_eval_process_function: str | None = None
|
|
60
63
|
# Custom scripts to run after the workflow and evaluation results are saved
|
|
61
64
|
custom_scripts: dict[str, EvalCustomScriptConfig] = {}
|
|
62
65
|
# S3 config for uploading the contents of the output directory
|
|
@@ -108,7 +111,7 @@ class EvalConfig(BaseModel):
|
|
|
108
111
|
@classmethod
|
|
109
112
|
def rebuild_annotations(cls):
|
|
110
113
|
|
|
111
|
-
from nat.cli.type_registry import GlobalTypeRegistry
|
|
114
|
+
from nat.cli.type_registry import GlobalTypeRegistry
|
|
112
115
|
|
|
113
116
|
type_registry = GlobalTypeRegistry.get()
|
|
114
117
|
|