truefoundry 0.7.0rc6__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry might be problematic. Click here for more details.
- truefoundry/cli/__main__.py +24 -8
- truefoundry/cli/config.py +8 -8
- truefoundry/cli/display_util.py +66 -4
- truefoundry/cli/util.py +2 -2
- truefoundry/common/constants.py +28 -0
- truefoundry/common/utils.py +3 -3
- truefoundry/deploy/cli/commands/__init__.py +1 -3
- truefoundry/deploy/cli/commands/ask_command.py +152 -0
- truefoundry/deploy/cli/commands/delete_command.py +2 -2
- truefoundry/deploy/cli/commands/kubeconfig_command.py +2 -32
- truefoundry/deploy/cli/commands/login_command.py +7 -1
- truefoundry/deploy/cli/commands/utils.py +30 -0
- truefoundry/deploy/lib/clients/_mcp_streamable_http.py +264 -0
- truefoundry/deploy/lib/clients/ask_client.py +371 -0
- truefoundry/workflow/__init__.py +4 -1
- {truefoundry-0.7.0rc6.dist-info → truefoundry-0.7.1.dist-info}/METADATA +4 -3
- {truefoundry-0.7.0rc6.dist-info → truefoundry-0.7.1.dist-info}/RECORD +19 -19
- truefoundry/deploy/cli/commands/create_command.py +0 -75
- truefoundry/deploy/cli/commands/list_command.py +0 -171
- truefoundry/deploy/cli/commands/redeploy_command.py +0 -41
- {truefoundry-0.7.0rc6.dist-info → truefoundry-0.7.1.dist-info}/WHEEL +0 -0
- {truefoundry-0.7.0rc6.dist-info → truefoundry-0.7.1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""
|
|
2
|
+
StreamableHTTP Client Transport Module
|
|
3
|
+
# From https://github.com/modelcontextprotocol/python-sdk/pull/573
|
|
4
|
+
|
|
5
|
+
This module implements the StreamableHTTP transport for MCP clients,
|
|
6
|
+
providing support for HTTP POST requests with optional SSE streaming responses
|
|
7
|
+
and session management.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from contextlib import asynccontextmanager
|
|
12
|
+
from datetime import timedelta
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import anyio
|
|
16
|
+
import httpx
|
|
17
|
+
from httpx_sse import EventSource, aconnect_sse
|
|
18
|
+
from mcp.types import (
|
|
19
|
+
ErrorData,
|
|
20
|
+
JSONRPCError,
|
|
21
|
+
JSONRPCMessage,
|
|
22
|
+
JSONRPCNotification,
|
|
23
|
+
JSONRPCRequest,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(name=__name__)

# --- HTTP header names ---------------------------------------------------
# Header carrying the session id assigned by the server; it is attached to
# the POST/GET/DELETE requests made after the session id is received.
MCP_SESSION_ID_HEADER = "mcp-session-id"
# SSE `Last-Event-ID` header name (not referenced elsewhere in this module).
LAST_EVENT_ID_HEADER = "last-event-id"

# --- Media types ---------------------------------------------------------
# Body type for JSON-RPC requests and single JSON responses.
CONTENT_TYPE_JSON = "application/json"
# Response type for Server-Sent Events streams.
CONTENT_TYPE_SSE = "text/event-stream"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@asynccontextmanager
async def streamablehttp_client(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: timedelta = timedelta(seconds=30),
    sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
):
    """
    Client transport for StreamableHTTP.

    `sse_read_timeout` determines how long the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.

    Args:
        url: MCP endpoint. Client messages are sent to it via POST; once a
            session id is known, an optional server-to-client SSE stream is
            opened on it via GET.
        headers: Extra HTTP headers sent with every request (e.g. auth headers).
            They override the default `Accept` / `Content-Type` headers.
        timeout: Timeout used for all HTTP operations except reads.
        sse_read_timeout: Read timeout (maximum wait for new data).

    Yields:
        Tuple of (read_stream, write_stream, terminate_callback):
        - read_stream yields parsed `JSONRPCMessage`s or `Exception`s.
        - write_stream accepts `JSONRPCMessage`s, each POSTed to `url`.
        - terminate_callback is an async callable that sends a DELETE for the
          current session; it does nothing when no session id was received.
    """
    read_stream_writer, read_stream = anyio.create_memory_object_stream[
        JSONRPCMessage | Exception
    ](0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream[
        JSONRPCMessage
    ](0)

    # JSON-RPC 2.0 reserved error codes (the specification uses negative values).
    invalid_request_code = -32600
    internal_error_code = -32603

    # `timedelta.seconds` is only the seconds *component* (days and
    # microseconds are dropped), so use `total_seconds()` for the real duration.
    http_timeout = httpx.Timeout(
        timeout.total_seconds(), read=sse_read_timeout.total_seconds()
    )

    # Headers sent with every request; caller-provided headers take precedence.
    request_headers = {
        "Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}",
        "Content-Type": CONTENT_TYPE_JSON,
        **(headers or {}),
    }
    # Session id returned by the server in response to `initialize`, if any.
    session_id: str | None = None

    async def send_request_error(
        message: JSONRPCMessage, code: int, error_message: str
    ) -> None:
        """Reply to a client *request* with a JSON-RPC error on the read stream.

        This gives the pending request an error response instead of no
        response at all. Notifications carry no id, so nothing is sent for them.
        """
        if isinstance(message.root, JSONRPCRequest):
            jsonrpc_error = JSONRPCError(
                jsonrpc="2.0",
                id=message.root.id,
                error=ErrorData(code=code, message=error_message),
            )
            await read_stream_writer.send(JSONRPCMessage(jsonrpc_error))

    async def get_stream():
        """
        Optional GET stream for server-initiated messages.

        Any error is logged at debug level and not propagated, because this
        stream is optional.
        """
        try:
            # Only attempt GET if we have a session ID
            if not session_id:
                return

            get_headers = request_headers.copy()
            get_headers[MCP_SESSION_ID_HEADER] = session_id

            async with aconnect_sse(
                client,
                "GET",
                url,
                headers=get_headers,
                timeout=http_timeout,
            ) as event_source:
                event_source.response.raise_for_status()
                logger.debug("GET SSE connection established")

                async for sse in event_source.aiter_sse():
                    if sse.event == "message":
                        try:
                            message = JSONRPCMessage.model_validate_json(sse.data)
                            logger.debug(f"GET message: {message}")
                            await read_stream_writer.send(message)
                        except Exception as exc:
                            logger.error(f"Error parsing GET message: {exc}")
                            await read_stream_writer.send(exc)
                    else:
                        logger.warning(f"Unknown SSE event from GET: {sse.event}")
        except Exception as exc:
            # GET stream is optional, so don't propagate errors
            logger.debug(f"GET stream error (non-fatal): {exc}")

    async def post_message(client: httpx.AsyncClient, message: JSONRPCMessage):
        """POST one client message and forward the server's reply to the read stream."""
        nonlocal session_id

        # Add session ID to headers if we have one
        post_headers = request_headers.copy()
        if session_id:
            post_headers[MCP_SESSION_ID_HEADER] = session_id

        logger.debug(f"Sending client message: {message}")

        # The session id is only read from the response to `initialize`.
        is_initialization = (
            isinstance(message.root, JSONRPCRequest)
            and message.root.method == "initialize"
        )
        # Once the client confirms initialization, open the optional GET
        # stream for server-initiated messages.
        if (
            isinstance(message.root, JSONRPCNotification)
            and message.root.method == "notifications/initialized"
        ):
            tg.start_soon(get_stream)

        async with client.stream(
            "POST",
            url,
            json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
            headers=post_headers,
        ) as response:
            if response.status_code == 202:
                logger.debug("Received 202 Accepted")
                return

            # Check for 404 (session expired/invalid)
            if response.status_code == 404:
                await send_request_error(
                    message, invalid_request_code, "Session terminated"
                )
                return

            if not response.is_success:
                _response_content = await response.aread()
                logger.error(f"Response: {response.status_code} {_response_content}")
                # Report the failure to the pending request (if any) and keep
                # serving later messages, instead of raising and tearing down
                # the whole transport.
                await send_request_error(
                    message,
                    internal_error_code,
                    f"MCP server returned HTTP {response.status_code} "
                    f"{response.reason_phrase}",
                )
                return

            # Extract session ID from response headers
            if is_initialization:
                new_session_id = response.headers.get(MCP_SESSION_ID_HEADER)
                if new_session_id:
                    session_id = new_session_id
                    logger.info(f"Received session ID: {session_id}")

            # Handle different response types
            content_type = response.headers.get("content-type", "").lower()

            if content_type.startswith(CONTENT_TYPE_JSON):
                try:
                    content = await response.aread()
                    json_message = JSONRPCMessage.model_validate_json(content)
                    await read_stream_writer.send(json_message)
                except Exception as exc:
                    logger.error(f"Error parsing JSON response: {exc}")
                    await read_stream_writer.send(exc)

            elif content_type.startswith(CONTENT_TYPE_SSE):
                # Parse SSE events from the response
                try:
                    event_source = EventSource(response)
                    async for sse in event_source.aiter_sse():
                        if sse.event == "message":
                            try:
                                await read_stream_writer.send(
                                    JSONRPCMessage.model_validate_json(sse.data)
                                )
                            except Exception as exc:
                                logger.exception("Error parsing message")
                                await read_stream_writer.send(exc)
                        else:
                            logger.warning(f"Unknown event: {sse.event}")
                except Exception as e:
                    logger.exception("Error reading SSE stream:")
                    await read_stream_writer.send(e)

            else:
                error_msg = f"Unexpected content type: {content_type}"
                logger.error(error_msg)
                await read_stream_writer.send(ValueError(error_msg))

    async def post_writer(client: httpx.AsyncClient):
        """POST every message written to `write_stream`; close both streams when done."""
        try:
            async with write_stream_reader:
                async for message in write_stream_reader:
                    try:
                        await post_message(client, message)
                    except httpx.HTTPError as exc:
                        # Network/HTTP failure for this one message: report it to
                        # the pending request (if any) and keep the transport alive.
                        logger.error(f"Error sending message: {exc}")
                        await send_request_error(
                            message,
                            internal_error_code,
                            f"Request to MCP server failed: {exc}",
                        )
        except Exception as exc:
            logger.error(f"Error in post_writer: {exc}")
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()

    async def terminate_session():
        """
        Terminate the session by sending a DELETE request.

        Best effort: failures are logged and never raised.
        """
        if not session_id:
            return  # No session to terminate

        try:
            delete_headers = request_headers.copy()
            delete_headers[MCP_SESSION_ID_HEADER] = session_id

            response = await client.delete(url, headers=delete_headers)

            if response.status_code == 405:
                # Server doesn't allow client-initiated termination
                logger.debug("Server does not allow session termination")
            elif not response.is_success:
                # Any 2xx (e.g. 200 or 204 No Content) counts as success.
                logger.warning(f"Session termination failed: {response.status_code}")
        except Exception as exc:
            logger.warning(f"Session termination failed: {exc}")

    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
            async with httpx.AsyncClient(
                headers=request_headers,
                timeout=http_timeout,
                follow_redirects=True,
            ) as client:
                tg.start_soon(post_writer, client)
                try:
                    yield read_stream, write_stream, terminate_session
                finally:
                    tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()
|
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
try:
    from mcp import ClientSession
    from mcp.types import TextContent

    from truefoundry.deploy.lib.clients._mcp_streamable_http import (
        streamablehttp_client,
    )
except ImportError as _import_error:
    import sys

    python_version = sys.version_info
    if python_version < (3, 10):
        # The `mcp` dependency is only installed for Python >= 3.10.
        raise ImportError(
            f"This feature requires Python 3.10 or higher. Your current Python version is '{python_version.major}.{python_version.minor}.{python_version.micro}'. "
            "Please upgrade to a supported version."
        ) from None
    # On a supported Python version the import failed for some other reason
    # (e.g. a missing or broken dependency), so surface the real cause instead
    # of a misleading "upgrade Python" message.
    raise ImportError(
        f"Failed to import dependencies required by this feature: {_import_error}. "
        "Please reinstall truefoundry to make sure all of its dependencies are installed."
    ) from _import_error
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
from contextlib import AsyncExitStack
|
|
19
|
+
from typing import List, Optional, Union
|
|
20
|
+
|
|
21
|
+
import rich_click as click
|
|
22
|
+
from openai import NOT_GIVEN, AsyncOpenAI
|
|
23
|
+
from openai.types.chat import (
|
|
24
|
+
ChatCompletionAssistantMessageParam,
|
|
25
|
+
ChatCompletionMessageParam,
|
|
26
|
+
ChatCompletionSystemMessageParam,
|
|
27
|
+
ChatCompletionToolMessageParam,
|
|
28
|
+
ChatCompletionToolParam,
|
|
29
|
+
ChatCompletionUserMessageParam,
|
|
30
|
+
)
|
|
31
|
+
from pydantic import BaseModel
|
|
32
|
+
from rich.console import Console
|
|
33
|
+
from rich.status import Status
|
|
34
|
+
|
|
35
|
+
from truefoundry.cli.display_util import log_chat_completion_message
|
|
36
|
+
from truefoundry.common.constants import ENV_VARS
|
|
37
|
+
from truefoundry.logger import logger
|
|
38
|
+
|
|
39
|
+
# Module-wide Rich console used for all chat output, status spinners and
# error messages printed by AskClient and ask_client_main.
console = Console(soft_wrap=False)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AskClient:
    """Handles the chat session lifecycle between the user and the assistant via OpenAI and MCP.

    Typical use: construct, ``await connect(...)``, then ``await chat_loop()``,
    and always ``await cleanup()`` when done.
    """

    def __init__(
        self,
        cluster: str,
        token: str,
        openai_model: str,
        debug: bool = False,
        openai_client: Optional[AsyncOpenAI] = None,
    ):
        """
        Args:
            cluster: Sent to the MCP server in the `X-TFY-Cluster-Id` header.
            token: Sent to the MCP server as a Bearer token.
            openai_model: Model name used for chat completions.
            debug: When True, internal status messages and history entries are printed.
            openai_client: Client to use; a default `AsyncOpenAI()` is created when omitted.
        """
        self.cluster = cluster
        self.token = token
        self.debug = debug

        self.async_openai_client = openai_client or AsyncOpenAI()
        self.openai_model = openai_model
        self._log_message(
            f"\nInitialize OpenAI client with model: {self.openai_model!r}, base_url: {str(self.async_openai_client.base_url)!r}\n"
        )

        # Owns the MCP transport and session contexts entered in `connect`.
        # Only contexts that were entered successfully get registered, so
        # `cleanup` never exits a context that was not entered.
        self.exit_stack = AsyncExitStack()
        self.session: Optional[ClientSession] = None
        # Async callable from the transport that DELETEs the MCP session.
        self._terminate_cb = None
        self.history: List[ChatCompletionMessageParam] = []
        self._cached_tools: Optional[List[ChatCompletionToolParam]] = None

    async def connect(self, server_url: str, prompt_name: Optional[str] = None):
        """Initialize connection to the MCP server and prepare the chat session.

        On any failure, everything opened so far is cleaned up and the
        exception is re-raised.
        """
        try:
            logger.debug(f"Starting a new client for {server_url}")
            (
                read_stream,
                write_stream,
                self._terminate_cb,
            ) = await self.exit_stack.enter_async_context(
                streamablehttp_client(url=server_url, headers=self._auth_headers())
            )
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(read_stream=read_stream, write_stream=write_stream)
            )
            await self.session.initialize()
            self._log_message("Connected and session initialized.")
            await self._list_tools()  # Pre-load tool definitions for tool-calling
            self._log_message(
                "\nTFY ASK is ready. Type 'exit' to quit.", log=self.debug
            )

            await self._load_initial_prompt(prompt_name)

        except Exception as e:
            self._log_message(f"❌ Connection error: {e}")
            await self.cleanup()
            raise

    async def cleanup(self):
        """Terminate the MCP session (best effort) and close every context opened in `connect`.

        Safe to call more than once, and after a partially failed `connect`.
        """
        terminate_cb, self._terminate_cb = self._terminate_cb, None
        if terminate_cb is not None:
            try:
                await terminate_cb()
            except Exception as e:
                logger.debug(f"Failed to terminate MCP session: {e}")
        self.session = None
        # Unwinds in reverse order of entry (session first, then transport);
        # a no-op once the stack has already been closed.
        await self.exit_stack.aclose()

    async def chat_loop(self):
        """Interactive loop: accepts user queries and returns responses until interrupted or 'exit' is typed."""
        await self.process_query()  # Optional greeting message from assistant

        while True:
            try:
                query = click.prompt(click.style("User", fg="yellow"), type=str)
                if not query:
                    self._log_message("Empty query. Type 'exit' to quit.", log=True)
                    continue

                if query.lower() in ("exit", "quit"):
                    self._log_message("Exiting chat...")
                    break

                await self.process_query(query)

            except (KeyboardInterrupt, EOFError, click.Abort):
                self._log_message("\nChat interrupted.")
                break

    async def process_query(self, query: Optional[str] = None, max_turns: int = 50):
        """Handles sending user input to the assistant and processing the assistant’s reply.

        Args:
            query: User text to append to history. When empty/None, the
                assistant simply responds to the existing history.
            max_turns: Maximum number of OpenAI calls made for this query.
        """
        # Remember where history stood *before* this query. If an OpenAI call
        # fails, the query and any partial tool exchange are dropped, so a
        # query that itself causes the failure cannot break every later call.
        _checkpoint_idx = len(self.history)

        if query:
            self._append_message(
                ChatCompletionUserMessageParam(role="user", content=query),
                log=self.debug,
            )

        tools = await self._list_tools()  # Fetch or use cached tool list

        turn: int = 0
        with console.status(status="Thinking...", spinner="dots") as spinner:
            while True:
                try:
                    if turn >= max_turns:
                        self._log_message("Max turns reached. Exiting.")
                        break
                    spinner.update("Thinking...", spinner="dots")
                    response = await self._call_openai(
                        model=self.openai_model, tools=tools
                    )
                    turn += 1
                    message = response.choices[0].message

                    if message.tool_calls:
                        await self._handle_tool_calls(message, spinner)
                    elif message.content:
                        self._append_message(
                            ChatCompletionAssistantMessageParam(
                                role="assistant", content=message.content
                            )
                        )
                        break
                    else:
                        self._log_message("No assistant response.")
                        break
                except Exception as e:
                    self._log_message(f"OpenAI call failed: {e}", log=self.debug)
                    console.print(
                        "Something went wrong. Please try rephrasing your query."
                    )
                    # Revert to the state before this query
                    self.history = self.history[:_checkpoint_idx]
                    break

    async def _list_tools(self) -> Optional[List[ChatCompletionToolParam]]:
        """Fetch and cache the list of available tools from the MCP session."""
        # `is not None` so that an empty tool list is cached too, instead of
        # being re-fetched on every query.
        if self._cached_tools is not None:
            return self._cached_tools

        self._cached_tools = [
            {
                "type": "function",
                "function": {
                    "name": t.name,
                    "description": t.description,
                    "parameters": t.inputSchema,
                },
            }
            for t in (await self.session.list_tools()).tools
        ]

        self._log_message("\nAvailable tools:")
        for tool in self._cached_tools or []:
            self._log_message(
                f"  - {tool['function']['name']}: {tool['function']['description']}"
            )
        return self._cached_tools

    async def _load_initial_prompt(self, prompt_name: Optional[str]) -> None:
        """Load an MCP prompt by name and append its messages to history as system messages."""
        if not (self.session and prompt_name):
            return

        result = await self.session.get_prompt(name=prompt_name)
        if not result:
            self._log_message("Failed to get initial system prompt.")
            return

        for message in result.messages:
            data = message.model_dump() if isinstance(message, BaseModel) else message
            content = None
            if isinstance(data, dict):
                # Content is either a dict (use its "text") or a plain value.
                content = (
                    data.get("content", {}).get("text")
                    if isinstance(data.get("content"), dict)
                    else data.get("content")
                )
            else:
                content = data

            # Every non-empty prompt message is added with the "system" role.
            if content:
                self._append_message(
                    ChatCompletionSystemMessageParam(role="system", content=content),
                    log=self.debug,
                )

    async def _call_openai(
        self, model: str, tools: Optional[List[ChatCompletionToolParam]]
    ):
        """Make a chat completion request to OpenAI with optional tool support."""
        return await self.async_openai_client.chat.completions.create(
            model=model,
            messages=self.history,
            tools=tools or NOT_GIVEN,
            temperature=0.0,  # Set to 0 for deterministic behavior
            top_p=1,
        )

    async def _handle_tool_calls(self, message, spinner: Status):
        """Execute the tool calls requested by the assistant and append each call and its result to history.

        A failing tool call does not raise; its error text is recorded as the
        tool result so the assistant can react to it.
        """
        for tool_call in message.tool_calls:
            try:
                spinner.update(
                    f"Executing tool: {tool_call.function.name}", spinner="aesthetic"
                )
                # Tools without parameters may come back with empty arguments.
                args = json.loads(tool_call.function.arguments or "{}")
                result = await self.session.call_tool(tool_call.function.name, args)
                content = getattr(result, "content", result)
                result_content = self._format_tool_result(content)
            except Exception as e:
                result_content = f"Tool `{tool_call.function.name}` call failed: {e}"

            # Log assistant's tool call
            self._append_message(
                ChatCompletionAssistantMessageParam(
                    role="assistant", content=None, tool_calls=[tool_call]
                ),
                log=self.debug,
            )

            # Log tool response
            self._append_message(
                ChatCompletionToolMessageParam(
                    role="tool", tool_call_id=tool_call.id, content=result_content
                ),
                log=self.debug,
            )

    def _format_tool_result(self, content) -> str:
        """Format tool result into a readable string or a fenced JSON block."""
        if isinstance(content, list):
            # A single text item is unwrapped to its plain text.
            content = (
                content[0].text
                if len(content) == 1 and isinstance(content[0], TextContent)
                else content
            )
        if isinstance(content, list):
            return (
                "```\n"
                + "\n".join(
                    (
                        item.model_dump_json(indent=2)
                        if isinstance(item, BaseModel)
                        else str(item)
                    )
                    for item in content
                )
                + "\n```"
            )

        if isinstance(content, (BaseModel, dict)):
            return (
                "```\n"
                + json.dumps(
                    content.model_dump() if isinstance(content, BaseModel) else content,
                    indent=2,
                )
                + "\n```"
            )

        if isinstance(content, str):
            # Pretty-print strings that are valid JSON; return others unchanged.
            try:
                return "```\n" + json.dumps(json.loads(content), indent=2) + "\n```"
            except Exception:
                return content

        return str(content)

    def _append_message(self, message: ChatCompletionMessageParam, log: bool = True):
        """Append a message to history and optionally log it."""
        self._log_message(message, log)
        self.history.append(message)

    def _auth_headers(self):
        """Build the authorization/cluster headers sent to the MCP server."""
        return {
            "Authorization": f"Bearer {self.token}",
            "X-TFY-Cluster-Id": self.cluster,
        }

    def _log_message(
        self,
        message: Union[str, ChatCompletionMessageParam],
        log: bool = False,
    ):
        """Print a message with the Rich console when `debug` is on or `log` is True."""
        if not self.debug and not log:
            return
        if isinstance(message, str):
            console.print(message)
        else:
            log_chat_completion_message(message, console_=console)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
async def ask_client_main(
    cluster: str,
    server_url: str,
    token: str,
    openai_model: str,
    debug: bool = False,
    openai_client: Optional[AsyncOpenAI] = None,
):
    """Main entrypoint for launching the AskClient chat loop.

    Exceptions and keyboard interrupts are reported on the console instead of
    being raised, and the client is always cleaned up before returning.
    """
    ask_client = AskClient(
        cluster=cluster,
        token=token,
        debug=debug,
        openai_client=openai_client,
        openai_model=openai_model,
    )
    try:
        await ask_client.connect(
            server_url=server_url, prompt_name=ENV_VARS.TFY_ASK_SYSTEM_PROMPT_NAME
        )
        await ask_client.chat_loop()
    except Exception as e:
        console.print(
            f"[red]An unexpected error occurred while running the assistant: {e}[/red], Check with TrueFoundry support for more details."
        )
    except KeyboardInterrupt:
        console.print("[yellow]Chat interrupted.[/yellow]")
    finally:
        # `connect` already cleans up after its own failures, so this can be a
        # second cleanup of a partially initialized client. Never let a cleanup
        # error escape and hide the message printed above.
        try:
            await ask_client.cleanup()
        except Exception as cleanup_error:
            logger.debug(f"Error while cleaning up the assistant: {cleanup_error}")
|
truefoundry/workflow/__init__.py
CHANGED
|
@@ -2,7 +2,10 @@ try:
|
|
|
2
2
|
import fsspec
|
|
3
3
|
from flytekit import task as _
|
|
4
4
|
except ImportError:
|
|
5
|
-
print(
|
|
5
|
+
print(
|
|
6
|
+
"To use workflows, please run 'pip install truefoundry[workflow]'. "
|
|
7
|
+
"Note: The `workflow` feature is only available for Python 3.9 to 3.12"
|
|
8
|
+
)
|
|
6
9
|
|
|
7
10
|
from flytekit import conditional
|
|
8
11
|
from flytekit.types.directory import FlyteDirectory
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: truefoundry
|
|
3
|
-
Version: 0.7.0rc6
|
|
3
|
+
Version: 0.7.1
|
|
4
4
|
Summary: TrueFoundry CLI
|
|
5
5
|
Author-email: TrueFoundry Team <abhishek@truefoundry.com>
|
|
6
6
|
Requires-Python: <3.14,>=3.8.1
|
|
@@ -14,9 +14,10 @@ Requires-Dist: gitpython<4.0.0,>=3.1.43
|
|
|
14
14
|
Requires-Dist: importlib-metadata<9.0.0,>=4.11.3
|
|
15
15
|
Requires-Dist: importlib-resources<7.0.0,>=5.2.0
|
|
16
16
|
Requires-Dist: mako<2.0.0,>=1.1.6
|
|
17
|
+
Requires-Dist: mcp==1.6.0; python_version >= '3.10'
|
|
17
18
|
Requires-Dist: numpy<3.0.0,>=1.23.0
|
|
18
19
|
Requires-Dist: openai<2.0.0,>=1.16.2
|
|
19
|
-
Requires-Dist: packaging<
|
|
20
|
+
Requires-Dist: packaging<26.0,>=20.0
|
|
20
21
|
Requires-Dist: pydantic<3.0.0,>=1.8.2
|
|
21
22
|
Requires-Dist: pygments<3.0.0,>=2.12.0
|
|
22
23
|
Requires-Dist: pyjwt<3.0.0,>=2.0.0
|
|
@@ -30,7 +31,7 @@ Requires-Dist: requirements-parser<0.12.0,>=0.11.0
|
|
|
30
31
|
Requires-Dist: rich-click<2.0.0,>=1.2.1
|
|
31
32
|
Requires-Dist: rich<14.0.0,>=13.7.1
|
|
32
33
|
Requires-Dist: tqdm<5.0.0,>=4.0.0
|
|
33
|
-
Requires-Dist: truefoundry-sdk==0.0.
|
|
34
|
+
Requires-Dist: truefoundry-sdk==0.0.16
|
|
34
35
|
Requires-Dist: typing-extensions>=4.0
|
|
35
36
|
Requires-Dist: urllib3<3,>=1.26.18
|
|
36
37
|
Requires-Dist: yq<4.0.0,>=3.1.0
|