nvidia-nat 1.4.0a20251102__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- nat/builder/builder.py +52 -0
- nat/builder/component_utils.py +7 -1
- nat/builder/context.py +17 -0
- nat/builder/framework_enum.py +1 -0
- nat/builder/function.py +74 -3
- nat/builder/workflow.py +4 -2
- nat/builder/workflow_builder.py +129 -0
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/register_workflow.py +50 -0
- nat/cli/type_registry.py +68 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +16 -0
- nat/data_models/function.py +14 -1
- nat/data_models/middleware.py +35 -0
- nat/data_models/runtime_enum.py +26 -0
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +11 -3
- nat/eval/utils/weave_eval.py +17 -3
- nat/front_ends/fastapi/fastapi_front_end_config.py +29 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +13 -7
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +144 -14
- nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
- nat/llm/aws_bedrock_llm.py +11 -9
- nat/llm/azure_openai_llm.py +12 -4
- nat/llm/litellm_llm.py +11 -4
- nat/llm/nim_llm.py +11 -9
- nat/llm/openai_llm.py +12 -9
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache_middleware.py +256 -0
- nat/middleware/function_middleware.py +186 -0
- nat/middleware/middleware.py +184 -0
- nat/middleware/register.py +35 -0
- nat/profiler/decorators/framework_wrapper.py +16 -0
- nat/retriever/milvus/register.py +11 -3
- nat/retriever/milvus/retriever.py +102 -40
- nat/runtime/runner.py +12 -1
- nat/runtime/session.py +10 -3
- nat/tool/code_execution/code_sandbox.py +4 -7
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +5 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +8 -4
- nat/utils/io/yaml_tools.py +73 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +11 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +54 -50
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
- nat/data_models/temperature_mixin.py +0 -44
- nat/data_models/top_p_mixin.py +0 -44
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
nat/retriever/milvus/register.py
CHANGED
@@ -48,6 +48,7 @@ class MilvusRetrieverConfig(RetrieverBaseConfig, name="milvus_retriever"):
     description: str | None = Field(default=None,
                                     description="If present it will be used as the tool description",
                                     alias="collection_description")
+    use_async_client: bool = Field(default=False, description="Use AsyncMilvusClient for async I/O operations. ")
 
 
 @register_retriever_provider(config_type=MilvusRetrieverConfig)
@@ -58,13 +59,20 @@ async def milvus_retriever(retriever_config: MilvusRetrieverConfig, builder: Builder):
 
 @register_retriever_client(config_type=MilvusRetrieverConfig, wrapper_type=None)
 async def milvus_retriever_client(config: MilvusRetrieverConfig, builder: Builder):
-    from pymilvus import MilvusClient
-
     from nat.retriever.milvus.retriever import MilvusRetriever
 
     embedder = await builder.get_embedder(embedder_name=config.embedding_model, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
 
-
+    # Create Milvus client based on use_async_client flag
+    if config.use_async_client:
+        from pymilvus import AsyncMilvusClient
+
+        milvus_client = AsyncMilvusClient(uri=str(config.uri), **config.connection_args)
+    else:
+        from pymilvus import MilvusClient
+
+        milvus_client = MilvusClient(uri=str(config.uri), **config.connection_args)
+
     retriever = MilvusRetriever(
         client=milvus_client,
         embedder=embedder,
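For context, the new use_async_client flag only selects which pymilvus client class gets instantiated; everything downstream receives the resulting client object. A minimal standalone sketch of the same selection logic (the config class and field defaults below are illustrative stand-ins, not the actual MilvusRetrieverConfig):

from pydantic import BaseModel, Field


class MilvusSettings(BaseModel):
    """Illustrative stand-in for the relevant MilvusRetrieverConfig fields."""
    uri: str = "http://localhost:19530"
    connection_args: dict = Field(default_factory=dict)
    use_async_client: bool = False


def make_client(cfg: MilvusSettings):
    # Mirrors the register.py change: import and construct the async client
    # only when the flag is set, otherwise fall back to the sync client.
    if cfg.use_async_client:
        from pymilvus import AsyncMilvusClient
        return AsyncMilvusClient(uri=cfg.uri, **cfg.connection_args)
    from pymilvus import MilvusClient
    return MilvusClient(uri=cfg.uri, **cfg.connection_args)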
nat/retriever/milvus/retriever.py
CHANGED
@@ -13,13 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 from functools import partial
+from typing import TYPE_CHECKING
 
 from langchain_core.embeddings import Embeddings
-from pymilvus import MilvusClient
 from pymilvus.client.abstract import Hit
 
+if TYPE_CHECKING:
+    from pymilvus import AsyncMilvusClient
+    from pymilvus import MilvusClient
+
 from nat.retriever.interface import Retriever
 from nat.retriever.models import Document
 from nat.retriever.models import RetrieverError
@@ -39,20 +44,27 @@ class MilvusRetriever(Retriever):
 
     def __init__(
         self,
-        client: MilvusClient,
+        client: "MilvusClient | AsyncMilvusClient",
         embedder: Embeddings,
         content_field: str = "text",
         use_iterator: bool = False,
     ) -> None:
         """
-        Initialize the Milvus Retriever using a preconfigured MilvusClient
+        Initialize the Milvus Retriever using a preconfigured MilvusClient or AsyncMilvusClient
 
         Args:
-            client (MilvusClient): Preinstantiate pymilvus.MilvusClient object.
         """
-        self._client = client
+        self._client: MilvusClient | AsyncMilvusClient = client
         self._embedder = embedder
 
+        # Detect if client is async by inspecting method capabilities
+        search_method = getattr(client, "search", None)
+        list_collections_method = getattr(client, "list_collections", None)
+        self._is_async = any(
+            inspect.iscoroutinefunction(method) for method in (search_method, list_collections_method)
+            if method is not None)
+        logger.info("Initialized Milvus Retriever with %s client", "async" if self._is_async else "sync")
+
         if use_iterator and "search_iterator" not in dir(self._client):
             raise ValueError("This version of the pymilvus.MilvusClient does not support the search iterator.")
 
@@ -60,7 +72,7 @@ class MilvusRetriever(Retriever):
         self._default_params = None
         self._bound_params = []
         self.content_field = content_field
-        logger.info("
+        logger.info("Milvus Retriever using %s for search.", self._search_func.__name__)
 
     def bind(self, **kwargs) -> None:
         """
@@ -81,8 +93,13 @@ class MilvusRetriever(Retriever):
         """
         return [param for param in ["query", "collection_name", "top_k", "filters"] if param not in self._bound_params]
 
-    def _validate_collection(self, collection_name: str) -> bool:
-
+    async def _validate_collection(self, collection_name: str) -> bool:
+        """Validate that a collection exists."""
+        if self._is_async:
+            collections = await self._client.list_collections()
+        else:
+            collections = self._client.list_collections()
+        return collection_name in collections
 
     async def search(self, query: str, **kwargs):
         return await self._search_func(query=query, **kwargs)
@@ -108,39 +125,64 @@ class MilvusRetriever(Retriever):
                     collection_name,
                     top_k)
 
-        if not self._validate_collection(collection_name):
+        if not await self._validate_collection(collection_name):
             raise CollectionNotFoundError(f"Collection: {collection_name} does not exist")
 
         # If no output fields are specified, return all of them
         if not output_fields:
-
+            if self._is_async:
+                collection_schema = await self._client.describe_collection(collection_name)
+            else:
+                collection_schema = self._client.describe_collection(collection_name)
             output_fields = [
                 field["name"] for field in collection_schema.get("fields") if field["name"] != vector_field_name
             ]
 
-        search_vector = self._embedder.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        search_vector = await self._embedder.aembed_query(query)
+
+        # Create search iterator
+        if self._is_async:
+            search_iterator = await self._client.search_iterator(
+                collection_name=collection_name,
+                data=[search_vector],
+                batch_size=kwargs.get("batch_size", 1000),
+                filter=filters,
+                limit=top_k,
+                output_fields=output_fields,
+                search_params=search_params if search_params else {"metric_type": "L2"},
+                timeout=timeout,
+                anns_field=vector_field_name,
+                round_decimal=kwargs.get("round_decimal", -1),
+                partition_names=kwargs.get("partition_names", None),
+            )
+        else:
+            search_iterator = self._client.search_iterator(
+                collection_name=collection_name,
+                data=[search_vector],
+                batch_size=kwargs.get("batch_size", 1000),
+                filter=filters,
+                limit=top_k,
+                output_fields=output_fields,
+                search_params=search_params if search_params else {"metric_type": "L2"},
+                timeout=timeout,
+                anns_field=vector_field_name,
+                round_decimal=kwargs.get("round_decimal", -1),
+                partition_names=kwargs.get("partition_names", None),
+            )
 
         results = []
         try:
             while True:
-
+                if self._is_async:
+                    _res = await search_iterator.next()
+                else:
+                    _res = search_iterator.next()
                 res = _res.get_res()
                 if len(_res) == 0:
-
+                    if self._is_async:
+                        await search_iterator.close()
+                    else:
+                        search_iterator.close()
                     break
 
                 if distance_cutoff and res[0][-1].distance > distance_cutoff:
@@ -176,10 +218,16 @@ class MilvusRetriever(Retriever):
                     collection_name,
                     top_k)
 
-        if not self._validate_collection(collection_name):
+        if not await self._validate_collection(collection_name):
             raise CollectionNotFoundError(f"Collection: {collection_name} does not exist")
 
-
+        # Get collection schema
+        if self._is_async:
+            collection_schema = await self._client.describe_collection(collection_name)
+        else:
+            collection_schema = self._client.describe_collection(collection_name)
+
+        available_fields = [v.get("name") for v in collection_schema.get("fields", [])]
 
         if self.content_field not in available_fields:
             raise ValueError(f"The specified content field: {self.content_field} is not part of the schema.")
@@ -194,17 +242,31 @@ class MilvusRetriever(Retriever):
         if self.content_field not in output_fields:
             output_fields.append(self.content_field)
 
-        search_vector = self._embedder.
-
-
-
-
-
-
-
-
-
-
+        search_vector = await self._embedder.aembed_query(query)
+
+        # Perform search
+        if self._is_async:
+            res = await self._client.search(
+                collection_name=collection_name,
+                data=[search_vector],
+                filter=filters,
+                output_fields=output_fields,
+                search_params=search_params if search_params else {"metric_type": "L2"},
+                timeout=timeout,
+                anns_field=vector_field_name,
+                limit=top_k,
+            )
+        else:
+            res = self._client.search(
+                collection_name=collection_name,
+                data=[search_vector],
+                filter=filters,
+                output_fields=output_fields,
+                search_params=search_params if search_params else {"metric_type": "L2"},
+                timeout=timeout,
+                anns_field=vector_field_name,
+                limit=top_k,
+            )
 
         return _wrap_milvus_results(res[0], content_field=self.content_field)
 
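The retriever now supports either client by probing the instance at construction time instead of branching on configuration. A small self-contained sketch of that detection approach, using only the standard library (the client classes here are dummies standing in for pymilvus clients):

import inspect


class SyncClient:
    def search(self, query):
        return [query.upper()]


class AsyncClient:
    async def search(self, query):
        return [query.upper()]


def is_async_client(client) -> bool:
    # Same idea as MilvusRetriever._is_async: look up the methods the retriever
    # will call and check whether any of them is a coroutine function.
    candidates = (getattr(client, "search", None), getattr(client, "list_collections", None))
    return any(inspect.iscoroutinefunction(m) for m in candidates if m is not None)


print(is_async_client(SyncClient()))   # False
print(is_async_client(AsyncClient()))  # True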
nat/runtime/runner.py
CHANGED
@@ -26,6 +26,7 @@ from nat.data_models.intermediate_step import IntermediateStepType
 from nat.data_models.intermediate_step import StreamEventData
 from nat.data_models.intermediate_step import TraceMetadata
 from nat.data_models.invocation_node import InvocationNode
+from nat.data_models.runtime_enum import RuntimeTypeEnum
 from nat.observability.exporter_manager import ExporterManager
 from nat.utils.reactive.subject import Subject
 
@@ -53,7 +54,8 @@ class Runner:
                  input_message: typing.Any,
                  entry_fn: Function,
                  context_state: ContextState,
-                 exporter_manager: ExporterManager
+                 exporter_manager: ExporterManager,
+                 runtime_type: RuntimeTypeEnum = RuntimeTypeEnum.RUN_OR_SERVE):
         """
         The Runner class is used to run a workflow. It handles converting input and output data types and running the
         workflow with the specified concurrency.
@@ -68,6 +70,8 @@ class Runner:
             The context state to use
         exporter_manager : ExporterManager
             The exporter manager to use
+        runtime_type : RuntimeTypeEnum
+            The runtime type (RUN_OR_SERVE, EVALUATE, OTHER)
         """
 
         if (entry_fn is None):
@@ -86,6 +90,9 @@ class Runner:
 
         self._exporter_manager = exporter_manager
 
+        self._runtime_type = runtime_type
+        self._runtime_type_token = None
+
     @property
     def context(self) -> Context:
         return self._context
@@ -105,6 +112,8 @@ class Runner:
             function_id="root",
         ))
 
+        self._runtime_type_token = self._context_state.runtime_type.set(self._runtime_type)
+
         if (self._state == RunnerState.UNINITIALIZED):
             self._state = RunnerState.INITIALIZED
         else:
@@ -119,6 +128,8 @@ class Runner:
 
         self._context_state.input_message.reset(self._input_message_token)
 
+        self._context_state.runtime_type.reset(self._runtime_type_token)
+
         if (self._state not in (RunnerState.COMPLETED, RunnerState.FAILED)):
             raise ValueError("Cannot exit the context without completing the workflow")
 
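The runner records the runtime type on the shared context state with the usual contextvars token pattern: set on entry, reset on exit. A minimal sketch of that pattern, assuming RuntimeTypeEnum carries members like those listed in the docstring:

from contextvars import ContextVar
from enum import Enum


class RuntimeType(Enum):
    # Illustrative stand-in for nat.data_models.runtime_enum.RuntimeTypeEnum
    RUN_OR_SERVE = "run_or_serve"
    EVALUATE = "evaluate"
    OTHER = "other"


runtime_type: ContextVar[RuntimeType] = ContextVar("runtime_type", default=RuntimeType.RUN_OR_SERVE)

# Set on entering the run, reset on leaving it, exactly like the
# input_message token the Runner already manages.
token = runtime_type.set(RuntimeType.EVALUATE)
try:
    assert runtime_type.get() is RuntimeType.EVALUATE
finally:
    runtime_type.reset(token)

assert runtime_type.get() is RuntimeType.RUN_OR_SERVE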
nat/runtime/session.py
CHANGED
@@ -35,6 +35,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
 from nat.data_models.config import Config
 from nat.data_models.interactive import HumanResponse
 from nat.data_models.interactive import InteractionPrompt
+from nat.data_models.runtime_enum import RuntimeTypeEnum
 
 _T = typing.TypeVar("_T")
 
@@ -45,7 +46,10 @@ class UserManagerBase:
 
 class SessionManager:
 
-    def __init__(self,
+    def __init__(self,
+                 workflow: Workflow,
+                 max_concurrency: int = 8,
+                 runtime_type: RuntimeTypeEnum = RuntimeTypeEnum.RUN_OR_SERVE):
         """
         The SessionManager class is used to run and manage a user workflow session. It runs and manages the context,
         and configuration of a workflow with the specified concurrency.
@@ -56,6 +60,8 @@ class SessionManager:
             The workflow to run
         max_concurrency : int, optional
             The maximum number of simultaneous workflow invocations, by default 8
+        runtime_type : RuntimeTypeEnum, optional
+            The type of runtime the session manager is operating in, by default RuntimeTypeEnum.RUN_OR_SERVE
         """
 
         if (workflow is None):
@@ -66,6 +72,7 @@ class SessionManager:
         self._max_concurrency = max_concurrency
         self._context_state = ContextState.get()
         self._context = Context(self._context_state)
+        self._runtime_type = runtime_type
 
         # We save the context because Uvicorn spawns a new process
         # for each request, and we need to restore the context vars
@@ -128,7 +135,7 @@ class SessionManager:
         self._context_state.user_auth_callback.reset(token_user_authentication)
 
     @asynccontextmanager
-    async def run(self, message):
+    async def run(self, message, runtime_type: RuntimeTypeEnum = RuntimeTypeEnum.RUN_OR_SERVE):
         """
         Start a workflow run
         """
@@ -137,7 +144,7 @@ class SessionManager:
         for k, v in self._saved_context.items():
             k.set(v)
 
-        async with self._workflow.run(message) as runner:
+        async with self._workflow.run(message, runtime_type=runtime_type) as runner:
             yield runner
 
     def set_metadata_from_http_request(self, request: Request) -> None:
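Putting the runner and session changes together, the runtime type can be pinned when the session manager is constructed or overridden per run. A hedged usage sketch only, assuming workflow is a previously built Workflow object; front ends may perform additional session setup before calling run():

from nat.data_models.runtime_enum import RuntimeTypeEnum
from nat.runtime.session import SessionManager


async def evaluate_once(workflow, prompt: str) -> str:
    manager = SessionManager(workflow, max_concurrency=4, runtime_type=RuntimeTypeEnum.EVALUATE)

    # run() is an async context manager; the per-call runtime_type overrides the default.
    async with manager.run(prompt, runtime_type=RuntimeTypeEnum.EVALUATE) as runner:
        return await runner.result(to_type=str)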
nat/tool/code_execution/code_sandbox.py
CHANGED
@@ -92,7 +92,9 @@ class Sandbox(abc.ABC):
             raise ValueError(f"Language {language} not supported")
 
         generated_code = generated_code.strip().strip("`")
-
+        # Use json.dumps to properly escape the generated_code instead of repr()
+        escaped_code = json.dumps(generated_code)
+        code_to_execute = textwrap.dedent(f"""
             import traceback
             import json
             import os
@@ -101,11 +103,6 @@ class Sandbox(abc.ABC):
             import io
             warnings.filterwarnings('ignore')
             os.environ['OPENBLAS_NUM_THREADS'] = '16'
-        """).strip()
-
-        # Use json.dumps to properly escape the generated_code instead of repr()
-        escaped_code = json.dumps(generated_code)
-        code_to_execute += textwrap.dedent(f"""
 
             generated_code = {escaped_code}
 
@@ -155,7 +152,7 @@ class LocalSandbox(Sandbox):
             output_json = output.json()
             assert isinstance(output_json, dict)
             return output_json
-        except
+        except (requests.exceptions.JSONDecodeError, AssertionError) as e:
             logger.exception("Error parsing output: %s. %s", output.text, e)
             return {'process_status': 'error', 'stdout': '', 'stderr': f'Unknown error: {e} \"{output.text}\"'}
 
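The escaping change matters because the user's code is spliced into another Python source string. json.dumps emits a double-quoted literal whose escape sequences are also valid Python, so quotes, backslashes, and newlines in the embedded code cannot break out of the template. A small illustration of the idea (not the sandbox's actual wrapper):

import json
import textwrap

generated_code = 'print("hi")\npath = "C:\\\\temp"  # quotes, backslashes and newlines'

# json.dumps yields a string literal that Python can also parse, so the
# user code is embedded verbatim instead of being re-escaped by hand.
escaped_code = json.dumps(generated_code)

code_to_execute = textwrap.dedent(f"""
    generated_code = {escaped_code}
    exec(generated_code)
""")

print(code_to_execute)
exec(code_to_execute)  # runs the wrapped program; prints "hi"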
nat/tool/code_execution/local_sandbox/Dockerfile.sandbox
CHANGED
@@ -12,43 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-# Create Lean project directory and initialize a new Lean project with Mathlib4
-RUN mkdir -p /lean4 && cd /lean4 && \
-    /root/.elan/bin/lake new my_project && \
-    cd my_project && \
-    echo 'leanprover/lean4:v4.12.0' > lean-toolchain && \
-    echo 'require mathlib from git "https://github.com/leanprover-community/mathlib4" @ "v4.12.0"' >> lakefile.lean
-
-# Download and cache Mathlib4 to avoid recompiling, then build the project
-RUN cd /lean4/my_project && \
-    /root/.elan/bin/lake exe cache get && \
-    /root/.elan/bin/lake build
-
-# Set environment variables to include Lean project path
-ENV LEAN_PATH="/lean4/my_project"
-ENV PATH="/lean4/my_project:$PATH"
+# UWSGI_CHEAPER sets the number of initial uWSGI worker processes
+# UWSGI_PROCESSES sets the maximum number of uWSGI worker processes
+ARG UWSGI_CHEAPER=5
+ARG UWSGI_PROCESSES=10
+
+# Use the base image with Python 3.13
+FROM python:3.13-slim-bookworm
+
+
+RUN apt update && \
+    apt upgrade && \
+    apt install -y --no-install-recommends libexpat1 && \
+    apt clean && \
+    rm -rf /var/lib/apt/lists/*
 
 # Set up application code and install Python dependencies
 COPY sandbox.requirements.txt /app/requirements.txt
 RUN pip install --no-cache-dir -r /app/requirements.txt
 COPY local_sandbox_server.py /app/main.py
-
-# Set the working directory to /app
-WORKDIR /app
+RUN mkdir /workspace
 
 # Set Flask app environment variables and ports
 ARG UWSGI_CHEAPER
@@ -58,3 +41,7 @@ ARG UWSGI_PROCESSES
 ENV UWSGI_PROCESSES=$UWSGI_PROCESSES
 
 ENV LISTEN_PORT=6000
+EXPOSE 6000
+
+WORKDIR /app
+CMD uwsgi --http 0.0.0.0:${LISTEN_PORT} --master -p ${UWSGI_PROCESSES} --force-cwd /workspace -w main:app
nat/tool/code_execution/local_sandbox/start_local_sandbox.sh
CHANGED
@@ -19,7 +19,11 @@
 
 DOCKER_COMMAND=${DOCKER_COMMAND:-"docker"}
 SANDBOX_NAME=${1:-'local-sandbox'}
-
+
+# UWSGI_CHEAPER sets the number of initial uWSGI worker processes
+# UWSGI_PROCESSES sets the maximum number of uWSGI worker processes
+UWSGI_CHEAPER=${UWSGI_CHEAPER:-5}
+UWSGI_PROCESSES=${UWSGI_PROCESSES:-10}
 
 # Get the output_data directory path for mounting
 # Priority: command line argument > environment variable > default path (current directory)
@@ -37,14 +41,16 @@ fi
 # Check if the Docker image already exists
 if ! ${DOCKER_COMMAND} images ${SANDBOX_NAME} | grep -q "${SANDBOX_NAME}"; then
     echo "Docker image not found locally. Building ${SANDBOX_NAME}..."
-    ${DOCKER_COMMAND} build --tag=${SANDBOX_NAME}
+    ${DOCKER_COMMAND} build --tag=${SANDBOX_NAME} \
+        --build-arg="UWSGI_PROCESSES=${UWSGI_PROCESSES}" \
+        --build-arg="UWSGI_CHEAPER=${UWSGI_CHEAPER}" \
+        -f Dockerfile.sandbox .
 else
     echo "Using existing Docker image: ${SANDBOX_NAME}"
 fi
 
 # Mount the output_data directory directly so files created in container appear in the local directory
-${DOCKER_COMMAND} run --rm --name=local-sandbox \
+${DOCKER_COMMAND} run --rm -ti --name=local-sandbox \
     --network=host \
     -v "${OUTPUT_DATA_PATH}:/workspace" \
-    -w /workspace \
     ${SANDBOX_NAME}
nat/tool/server_tools.py
CHANGED
@@ -32,14 +32,23 @@ class RequestAttributesTool(FunctionBaseConfig, name="current_request_attributes"):
 @register_function(config_type=RequestAttributesTool)
 async def current_request_attributes(config: RequestAttributesTool, builder: Builder):
 
+    from pydantic import RootModel
+    from pydantic.types import JsonValue
     from starlette.datastructures import Headers
     from starlette.datastructures import QueryParams
 
-
+    class RequestBody(RootModel[JsonValue]):
+        """
+        Data model that accepts a request body of any valid JSON type.
+        """
+        root: JsonValue
+
+    async def _get_request_attributes(request_body: RequestBody) -> str:
 
         from nat.builder.context import Context
         nat_context = Context.get()
 
+        # Access request attributes from context
         method: str | None = nat_context.metadata.method
         url_path: str | None = nat_context.metadata.url_path
         url_scheme: str | None = nat_context.metadata.url_scheme
@@ -51,6 +60,9 @@ async def current_request_attributes(config: RequestAttributesTool, builder: Builder):
         cookies: dict[str, str] | None = nat_context.metadata.cookies
         conversation_id: str | None = nat_context.conversation_id
 
+        # Access the request body data - can be any valid JSON type
+        request_body_data: JsonValue = request_body.root
+
         return (f"Method: {method}, "
                 f"URL Path: {url_path}, "
                 f"URL Scheme: {url_scheme}, "
@@ -60,7 +72,8 @@ async def current_request_attributes(config: RequestAttributesTool, builder: Builder):
                 f"Client Host: {client_host}, "
                 f"Client Port: {client_port}, "
                 f"Cookies: {cookies}, "
-                f"Conversation Id: {conversation_id}"
+                f"Conversation Id: {conversation_id}, "
+                f"Request Body: {request_body_data}")
 
     yield FunctionInfo.from_fn(_get_request_attributes,
                                description="Returns the acquired user defined request attributes.")
nat/utils/__init__.py
CHANGED
@@ -29,7 +29,8 @@ async def run_workflow(*,
                        config: "Config | None" = None,
                        config_file: "StrPath | None" = None,
                        prompt: str,
-                       to_type: type[_T] = str
+                       to_type: type[_T] = str,
+                       session_kwargs: dict[str, typing.Any] | None = None) -> _T:
     """
     Wrapper to run a workflow given either a config or a config file path and a prompt, returning the result in the
     type specified by the `to_type`.
@@ -66,7 +67,10 @@ async def run_workflow(*,
 
         config = load_config(config_file)
 
+    session_kwargs = session_kwargs or {}
+
     async with WorkflowBuilder.from_config(config=config) as workflow_builder:
-
-        async with
-
+        session_manager = SessionManager(await workflow_builder.build())
+        async with session_manager.session(**session_kwargs) as session:
+            async with session.run(prompt) as runner:
+                return await runner.result(to_type=to_type)