letta-nightly 0.11.3.dev20250819104229__py3-none-any.whl → 0.11.4.dev20250820213507__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- letta/__init__.py +1 -1
- letta/agents/helpers.py +4 -0
- letta/agents/letta_agent.py +142 -5
- letta/constants.py +10 -7
- letta/data_sources/connectors.py +70 -53
- letta/embeddings.py +3 -240
- letta/errors.py +28 -0
- letta/functions/function_sets/base.py +4 -4
- letta/functions/functions.py +287 -32
- letta/functions/mcp_client/types.py +11 -0
- letta/functions/schema_validator.py +187 -0
- letta/functions/typescript_parser.py +196 -0
- letta/helpers/datetime_helpers.py +8 -4
- letta/helpers/tool_execution_helper.py +25 -2
- letta/llm_api/anthropic_client.py +23 -18
- letta/llm_api/azure_client.py +73 -0
- letta/llm_api/bedrock_client.py +8 -4
- letta/llm_api/google_vertex_client.py +14 -5
- letta/llm_api/llm_api_tools.py +2 -217
- letta/llm_api/llm_client.py +15 -1
- letta/llm_api/llm_client_base.py +32 -1
- letta/llm_api/openai.py +1 -0
- letta/llm_api/openai_client.py +18 -28
- letta/llm_api/together_client.py +55 -0
- letta/orm/provider.py +1 -0
- letta/orm/step_metrics.py +40 -1
- letta/otel/db_pool_monitoring.py +1 -1
- letta/schemas/agent.py +3 -4
- letta/schemas/agent_file.py +2 -0
- letta/schemas/block.py +11 -5
- letta/schemas/embedding_config.py +4 -5
- letta/schemas/enums.py +1 -1
- letta/schemas/job.py +2 -3
- letta/schemas/llm_config.py +79 -7
- letta/schemas/mcp.py +0 -24
- letta/schemas/message.py +0 -108
- letta/schemas/openai/chat_completion_request.py +1 -0
- letta/schemas/providers/__init__.py +0 -2
- letta/schemas/providers/anthropic.py +106 -8
- letta/schemas/providers/azure.py +102 -8
- letta/schemas/providers/base.py +10 -3
- letta/schemas/providers/bedrock.py +28 -16
- letta/schemas/providers/letta.py +3 -3
- letta/schemas/providers/ollama.py +2 -12
- letta/schemas/providers/openai.py +4 -4
- letta/schemas/providers/together.py +14 -2
- letta/schemas/sandbox_config.py +2 -1
- letta/schemas/tool.py +46 -22
- letta/server/rest_api/routers/v1/agents.py +179 -38
- letta/server/rest_api/routers/v1/folders.py +13 -8
- letta/server/rest_api/routers/v1/providers.py +10 -3
- letta/server/rest_api/routers/v1/sources.py +14 -8
- letta/server/rest_api/routers/v1/steps.py +17 -1
- letta/server/rest_api/routers/v1/tools.py +96 -5
- letta/server/rest_api/streaming_response.py +91 -45
- letta/server/server.py +27 -38
- letta/services/agent_manager.py +92 -20
- letta/services/agent_serialization_manager.py +11 -7
- letta/services/context_window_calculator/context_window_calculator.py +40 -2
- letta/services/helpers/agent_manager_helper.py +73 -12
- letta/services/mcp_manager.py +109 -15
- letta/services/passage_manager.py +28 -109
- letta/services/provider_manager.py +24 -0
- letta/services/step_manager.py +68 -0
- letta/services/summarizer/summarizer.py +1 -4
- letta/services/tool_executor/core_tool_executor.py +1 -1
- letta/services/tool_executor/sandbox_tool_executor.py +26 -9
- letta/services/tool_manager.py +82 -5
- letta/services/tool_sandbox/base.py +3 -11
- letta/services/tool_sandbox/modal_constants.py +17 -0
- letta/services/tool_sandbox/modal_deployment_manager.py +242 -0
- letta/services/tool_sandbox/modal_sandbox.py +218 -3
- letta/services/tool_sandbox/modal_sandbox_v2.py +429 -0
- letta/services/tool_sandbox/modal_version_manager.py +273 -0
- letta/services/tool_sandbox/safe_pickle.py +193 -0
- letta/settings.py +5 -3
- letta/templates/sandbox_code_file.py.j2 +2 -4
- letta/templates/sandbox_code_file_async.py.j2 +2 -4
- letta/utils.py +1 -1
- {letta_nightly-0.11.3.dev20250819104229.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/METADATA +2 -2
- {letta_nightly-0.11.3.dev20250819104229.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/RECORD +84 -81
- letta/llm_api/anthropic.py +0 -1206
- letta/llm_api/aws_bedrock.py +0 -104
- letta/llm_api/azure_openai.py +0 -118
- letta/llm_api/azure_openai_constants.py +0 -11
- letta/llm_api/cohere.py +0 -391
- letta/schemas/providers/cohere.py +0 -18
- {letta_nightly-0.11.3.dev20250819104229.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/LICENSE +0 -0
- {letta_nightly-0.11.3.dev20250819104229.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/WHEEL +0 -0
- {letta_nightly-0.11.3.dev20250819104229.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/entry_points.txt +0 -0
letta/embeddings.py
CHANGED
@@ -1,13 +1,9 @@
-import uuid
-from typing import Any, List, Optional
+from typing import List
 
-import numpy as np
 import tiktoken
-from openai import OpenAI
 
-from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP
-from letta.schemas.embedding_config import EmbeddingConfig
-from letta.utils import is_valid_url, printd
+from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP
+from letta.utils import printd
 
 
 def parse_and_chunk_text(text: str, chunk_size: int) -> List[str]:
@@ -55,236 +51,3 @@ def check_and_split_text(text: str, embedding_model: str) -> List[str]:
         text = truncate_text(formatted_text, max_length, encoding)
 
     return [text]
-
-
-class EmbeddingEndpoint:
-    """Implementation for OpenAI compatible endpoint"""
-
-    # """ Based off llama index https://github.com/run-llama/llama_index/blob/a98bdb8ecee513dc2e880f56674e7fd157d1dc3a/llama_index/embeddings/text_embeddings_inference.py """
-
-    # _user: str = PrivateAttr()
-    # _timeout: float = PrivateAttr()
-    # _base_url: str = PrivateAttr()
-
-    def __init__(
-        self,
-        model: str,
-        base_url: str,
-        user: str,
-        timeout: float = 60.0,
-        **kwargs: Any,
-    ):
-        if not is_valid_url(base_url):
-            raise ValueError(
-                f"Embeddings endpoint was provided an invalid URL (set to: '{base_url}'). Make sure embedding_endpoint is set correctly in your Letta config."
-            )
-        # TODO: find a neater solution - re-mapping for letta endpoint
-        if model == "letta-free":
-            model = "BAAI/bge-large-en-v1.5"
-        self.model_name = model
-        self._user = user
-        self._base_url = base_url
-        self._timeout = timeout
-
-    def _call_api(self, text: str) -> List[float]:
-        if not is_valid_url(self._base_url):
-            raise ValueError(
-                f"Embeddings endpoint does not have a valid URL (set to: '{self._base_url}'). Make sure embedding_endpoint is set correctly in your Letta config."
-            )
-        import httpx
-
-        headers = {"Content-Type": "application/json"}
-        json_data = {"input": text, "model": self.model_name, "user": self._user}
-
-        with httpx.Client() as client:
-            response = client.post(
-                f"{self._base_url}/embeddings",
-                headers=headers,
-                json=json_data,
-                timeout=self._timeout,
-            )
-
-        response_json = response.json()
-
-        if isinstance(response_json, list):
-            # embedding directly in response
-            embedding = response_json
-        elif isinstance(response_json, dict):
-            # TEI embedding packaged inside openai-style response
-            try:
-                embedding = response_json["data"][0]["embedding"]
-            except (KeyError, IndexError):
-                raise TypeError(f"Got back an unexpected payload from text embedding function, response=\n{response_json}")
-        else:
-            # unknown response, can't parse
-            raise TypeError(f"Got back an unexpected payload from text embedding function, response=\n{response_json}")
-
-        return embedding
-
-    def get_text_embedding(self, text: str) -> List[float]:
-        return self._call_api(text)
-
-
-class AzureOpenAIEmbedding:
-    def __init__(self, api_endpoint: str, api_key: str, api_version: str, model: str):
-        from openai import AzureOpenAI
-
-        self.client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_endpoint)
-        self.model = model
-
-    def get_text_embedding(self, text: str):
-        embeddings = self.client.embeddings.create(input=[text], model=self.model).data[0].embedding
-        return embeddings
-
-
-class OllamaEmbeddings:
-
-    # Uses OpenAI API standard
-    # Format:
-    #  curl http://localhost:11434/v1/embeddings -d '{
-    #    "model": "mxbai-embed-large",
-    #    "input": "Llamas are members of the camelid family"
-    #  }'
-
-    def __init__(self, model: str, base_url: str, ollama_additional_kwargs: dict):
-        self.model = model
-        self.base_url = base_url
-        self.ollama_additional_kwargs = ollama_additional_kwargs
-
-    def get_text_embedding(self, text: str):
-        import httpx
-
-        headers = {"Content-Type": "application/json"}
-        json_data = {"model": self.model, "input": text}
-        json_data.update(self.ollama_additional_kwargs)
-
-        with httpx.Client() as client:
-            response = client.post(
-                f"{self.base_url}/embeddings",
-                headers=headers,
-                json=json_data,
-            )
-
-        response_json = response.json()
-        return response_json["data"][0]["embedding"]
-
-
-class GoogleEmbeddings:
-    def __init__(self, api_key: str, model: str, base_url: str):
-        self.api_key = api_key
-        self.model = model
-        self.base_url = base_url  # Expected to be "https://generativelanguage.googleapis.com"
-
-    def get_text_embedding(self, text: str):
-        import httpx
-
-        headers = {"Content-Type": "application/json"}
-        # Build the URL based on the provided base_url, model, and API key.
-        url = f"{self.base_url}/v1beta/models/{self.model}:embedContent?key={self.api_key}"
-        payload = {"model": self.model, "content": {"parts": [{"text": text}]}}
-        with httpx.Client() as client:
-            response = client.post(url, headers=headers, json=payload)
-            # Raise an error for non-success HTTP status codes.
-            response.raise_for_status()
-            response_json = response.json()
-        return response_json["embedding"]["values"]
-
-
-class GoogleVertexEmbeddings:
-    def __init__(self, model: str, project_id: str, region: str):
-        from google import genai
-
-        self.client = genai.Client(vertexai=True, project=project_id, location=region, http_options={"api_version": "v1"})
-        self.model = model
-
-    def get_text_embedding(self, text: str):
-        response = self.client.generate_embeddings(content=text, model=self.model)
-        return response.embeddings[0].embedding
-
-
-class OpenAIEmbeddings:
-    def __init__(self, api_key: str, model: str, base_url: str):
-        if base_url:
-            self.client = OpenAI(api_key=api_key, base_url=base_url)
-        else:
-            self.client = OpenAI(api_key=api_key)
-        self.model = model
-
-    def get_text_embedding(self, text: str):
-        response = self.client.embeddings.create(input=text, model=self.model)
-
-        return response.data[0].embedding
-
-
-def query_embedding(embedding_model, query_text: str):
-    """Generate padded embedding for querying database"""
-    query_vec = embedding_model.get_text_embedding(query_text)
-    query_vec = np.array(query_vec)
-    query_vec = np.pad(query_vec, (0, MAX_EMBEDDING_DIM - query_vec.shape[0]), mode="constant").tolist()
-    return query_vec
-
-
-def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None):
-    """Return LlamaIndex embedding model to use for embeddings"""
-
-    endpoint_type = config.embedding_endpoint_type
-
-    # TODO: refactor to pass in settings from server
-    from letta.settings import model_settings
-
-    if endpoint_type == "openai":
-        return OpenAIEmbeddings(
-            api_key=model_settings.openai_api_key,
-            model=config.embedding_model,
-            base_url=config.embedding_endpoint or model_settings.openai_api_base,
-        )
-
-    elif endpoint_type == "azure":
-        assert all(
-            [
-                model_settings.azure_api_key is not None,
-                model_settings.azure_base_url is not None,
-                model_settings.azure_api_version is not None,
-            ]
-        )
-        return AzureOpenAIEmbedding(
-            api_endpoint=model_settings.azure_base_url,
-            api_key=model_settings.azure_api_key,
-            api_version=model_settings.azure_api_version,
-            model=config.embedding_model,
-        )
-
-    elif endpoint_type == "hugging-face":
-        return EmbeddingEndpoint(
-            model=config.embedding_model,
-            base_url=config.embedding_endpoint,
-            user=user_id,
-        )
-    elif endpoint_type == "ollama":
-
-        model = OllamaEmbeddings(
-            model=config.embedding_model,
-            base_url=config.embedding_endpoint,
-            ollama_additional_kwargs={},
-        )
-        return model
-
-    elif endpoint_type == "google_ai":
-        assert all([model_settings.gemini_api_key is not None, model_settings.gemini_base_url is not None])
-        model = GoogleEmbeddings(
-            model=config.embedding_model,
-            api_key=model_settings.gemini_api_key,
-            base_url=model_settings.gemini_base_url,
-        )
-        return model
-
-    elif endpoint_type == "google_vertex":
-        model = GoogleVertexEmbeddings(
-            model=config.embedding_model,
-            api_key=model_settings.gemini_api_key,
-            base_url=model_settings.gemini_base_url,
-        )
-        return model
-
-    else:
-        raise ValueError(f"Unknown endpoint type {endpoint_type}")
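Note: the removed query_embedding helper padded every vector to a fixed width before querying the database. A minimal self-contained sketch of that padding behavior (the MAX_EMBEDDING_DIM value below is an assumption; the real constant is defined in letta/constants.py):

import numpy as np

MAX_EMBEDDING_DIM = 4096  # assumed value; the real constant lives in letta/constants.py

def pad_query_embedding(query_vec: list) -> list:
    # Pad with trailing zeros so vectors from models with different output
    # sizes all fit the same fixed-width database column.
    vec = np.array(query_vec)
    return np.pad(vec, (0, MAX_EMBEDDING_DIM - vec.shape[0]), mode="constant").tolist()

padded = pad_query_embedding([0.1, 0.2, 0.3])
assert len(padded) == MAX_EMBEDDING_DIM and padded[3] == 0.0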
letta/errors.py
CHANGED
@@ -236,5 +236,33 @@ class AgentFileExportError(Exception):
     """Exception raised during agent file export operations"""
 
 
+class AgentNotFoundForExportError(AgentFileExportError):
+    """Exception raised when requested agents are not found during export"""
+
+    def __init__(self, missing_ids: List[str]):
+        self.missing_ids = missing_ids
+        super().__init__(f"The following agent IDs were not found: {missing_ids}")
+
+
+class AgentExportIdMappingError(AgentFileExportError):
+    """Exception raised when ID mapping fails during export conversion"""
+
+    def __init__(self, db_id: str, entity_type: str):
+        self.db_id = db_id
+        self.entity_type = entity_type
+        super().__init__(
+            f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. "
+            f"All IDs should have been mapped during agent processing."
+        )
+
+
+class AgentExportProcessingError(AgentFileExportError):
+    """Exception raised when general export processing fails"""
+
+    def __init__(self, message: str, original_error: Optional[Exception] = None):
+        self.original_error = original_error
+        super().__init__(f"Export failed: {message}")
+
+
 class AgentFileImportError(Exception):
     """Exception raised during agent file import operations"""
letta/functions/function_sets/base.py
CHANGED
@@ -63,7 +63,7 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
     return results_str
 
 
-def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
+async def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
     """
     Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
 
@@ -73,7 +73,7 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
     Returns:
         Optional[str]: None is always returned as this function does not produce a response.
     """
-    self.passage_manager.insert_passage(
+    await self.passage_manager.insert_passage(
         agent_state=self.agent_state,
         text=content,
         actor=self.user,
@@ -82,7 +82,7 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
     return None
 
 
-def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]:
+async def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]:
     """
     Search archival memory using semantic (embedding-based) search.
 
@@ -107,7 +107,7 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]:
 
     try:
         # Get results using passage manager
-        all_results = self.agent_manager.list_passages(
+        all_results = await self.agent_manager.list_passages_async(
             actor=self.user,
             agent_id=self.agent_state.id,
             query_text=query,
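Note: archival_memory_insert and archival_memory_search are now coroutines, so any direct caller must await them. A self-contained sketch of the sync-to-async pattern this diff applies (insert_passage below is a stub standing in for PassageManager.insert_passage):

import asyncio

async def insert_passage(text: str) -> None:
    # Stub standing in for PassageManager.insert_passage, which is awaited in the new code.
    await asyncio.sleep(0)

async def archival_memory_insert(content: str) -> None:
    # Mirrors the diff: the tool body awaits the manager call instead of
    # invoking it synchronously.
    await insert_passage(content)

asyncio.run(archival_memory_insert("remember to check the nightly build"))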
letta/functions/functions.py
CHANGED
@@ -1,3 +1,4 @@
+import ast
 import importlib
 import inspect
 from collections.abc import Callable
@@ -8,45 +9,299 @@ from typing import Any, Dict, List, Literal, Optional
 from letta.errors import LettaToolCreateError
 from letta.functions.schema_generator import generate_schema
 
+# NOTE: THIS FILE WILL BE DEPRECATED
+
+
+class MockFunction:
+    """A mock function object that mimics the attributes expected by generate_schema."""
+
+    def __init__(self, name: str, docstring: str, signature: inspect.Signature):
+        self.__name__ = name
+        self.__doc__ = docstring
+        self.__signature__ = signature
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError("This is a mock function and cannot be called")
+
+
+def _parse_type_annotation(annotation_node: ast.AST, imports_map: Dict[str, Any]) -> Any:
+    """Parse an AST type annotation node back into a Python type object."""
+    if annotation_node is None:
+        return inspect.Parameter.empty
+
+    if isinstance(annotation_node, ast.Name):
+        type_name = annotation_node.id
+        return imports_map.get(type_name, type_name)
+
+    elif isinstance(annotation_node, ast.Subscript):
+        # Generic type like 'List[str]', 'Optional[int]'
+        value_name = annotation_node.value.id if isinstance(annotation_node.value, ast.Name) else str(annotation_node.value)
+        origin_type = imports_map.get(value_name, value_name)
+
+        # Parse the slice (the part inside the brackets)
+        if isinstance(annotation_node.slice, ast.Name):
+            slice_type = _parse_type_annotation(annotation_node.slice, imports_map)
+            if hasattr(origin_type, "__getitem__"):
+                try:
+                    return origin_type[slice_type]
+                except (TypeError, AttributeError):
+                    pass
+            return f"{origin_type}[{slice_type}]"
+        else:
+            slice_type = _parse_type_annotation(annotation_node.slice, imports_map)
+            if hasattr(origin_type, "__getitem__"):
+                try:
+                    return origin_type[slice_type]
+                except (TypeError, AttributeError):
+                    pass
+            return f"{origin_type}[{slice_type}]"
+
+    else:
+        # Fallback - return string representation
+        return ast.unparse(annotation_node)
+
+
+def _build_imports_map(tree: ast.AST) -> Dict[str, Any]:
+    """Build a mapping of imported names to their Python objects."""
+    imports_map = {
+        "Optional": Optional,
+        "List": List,
+        "Dict": Dict,
+        "Literal": Literal,
+        # Built-in types
+        "str": str,
+        "int": int,
+        "bool": bool,
+        "float": float,
+        "list": list,
+        "dict": dict,
+    }
+
+    # Try to resolve Pydantic imports if they exist in the source
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ImportFrom):
+            if node.module == "pydantic":
+                for alias in node.names:
+                    if alias.name == "BaseModel":
+                        try:
+                            from pydantic import BaseModel
+
+                            imports_map["BaseModel"] = BaseModel
+                        except ImportError:
+                            pass
+                    elif alias.name == "Field":
+                        try:
+                            from pydantic import Field
+
+                            imports_map["Field"] = Field
+                        except ImportError:
+                            pass
+        elif isinstance(node, ast.Import):
+            for alias in node.names:
+                if alias.name == "typing":
+                    imports_map.update(
+                        {
+                            "typing.Optional": Optional,
+                            "typing.List": List,
+                            "typing.Dict": Dict,
+                            "typing.Literal": Literal,
+                        }
+                    )
+
+    return imports_map
+
+
+def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dict[str, Any]:
+    """Extract Pydantic model classes from the AST and create them dynamically."""
+    pydantic_classes = {}
+
+    # Check if BaseModel is available
+    if "BaseModel" not in imports_map:
+        return pydantic_classes
+
+    BaseModel = imports_map["BaseModel"]
+    Field = imports_map.get("Field")
+
+    # First pass: collect all class definitions
+    class_definitions = []
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ClassDef):
+            # Check if this class inherits from BaseModel
+            inherits_basemodel = False
+            for base in node.bases:
+                if isinstance(base, ast.Name) and base.id == "BaseModel":
+                    inherits_basemodel = True
+                    break
+
+            if inherits_basemodel:
+                class_definitions.append(node)
+
+    # Create classes in order, handling dependencies
+    created_classes = {}
+    remaining_classes = class_definitions.copy()
+
+    while remaining_classes:
+        progress_made = False
+
+        for node in remaining_classes.copy():
+            class_name = node.name
+
+            # Try to create this class
+            try:
+                fields = {}
+                annotations = {}
+
+                # Parse class body for field definitions
+                for stmt in node.body:
+                    if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
+                        field_name = stmt.target.id
+
+                        # Update imports_map with already created classes for type resolution
+                        current_imports = {**imports_map, **created_classes}
+                        field_annotation = _parse_type_annotation(stmt.annotation, current_imports)
+                        annotations[field_name] = field_annotation
+
+                        # Handle Field() definitions
+                        if stmt.value and isinstance(stmt.value, ast.Call):
+                            if isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == "Field" and Field:
+                                # Parse Field arguments
+                                field_kwargs = {}
+                                for keyword in stmt.value.keywords:
+                                    if keyword.arg == "description":
+                                        if isinstance(keyword.value, ast.Constant):
+                                            field_kwargs["description"] = keyword.value.value
+
+                                # Handle positional args for required fields
+                                if stmt.value.args:
+                                    try:
+                                        default_val = ast.literal_eval(stmt.value.args[0])
+                                        if default_val == ...:  # Ellipsis means required
+                                            pass  # Field is required, no default
+                                        else:
+                                            field_kwargs["default"] = default_val
+                                    except:
+                                        pass
+
+                                fields[field_name] = Field(**field_kwargs)
+                            else:
+                                # Not a Field call, try to evaluate the default value
+                                try:
+                                    default_val = ast.literal_eval(stmt.value)
+                                    fields[field_name] = default_val
+                                except:
+                                    pass
+
+                # Create the dynamic Pydantic model
+                model_dict = {"__annotations__": annotations, **fields}
+
+                DynamicModel = type(class_name, (BaseModel,), model_dict)
+                created_classes[class_name] = DynamicModel
+                remaining_classes.remove(node)
+                progress_made = True
+
+            except Exception:
+                # This class might depend on others, try later
+                continue
+
+        if not progress_made:
+            # If we can't make progress, create remaining classes without proper field types
+            for node in remaining_classes:
+                class_name = node.name
+                # Create a minimal mock class
+                MockModel = type(class_name, (BaseModel,), {})
+                created_classes[class_name] = MockModel
+            break
+
+    return created_classes
+
+
+def _parse_function_from_source(source_code: str, desired_name: Optional[str] = None) -> MockFunction:
+    """Parse a function from source code without executing it."""
+    try:
+        tree = ast.parse(source_code)
+    except SyntaxError as e:
+        raise LettaToolCreateError(f"Failed to parse source code: {e}")
+
+    # Build imports mapping and find pydantic classes
+    imports_map = _build_imports_map(tree)
+    pydantic_classes = _extract_pydantic_classes(tree, imports_map)
+    imports_map.update(pydantic_classes)
+
+    # Find function definitions
+    functions = []
+    for node in ast.walk(tree):
+        if isinstance(node, ast.FunctionDef):
+            functions.append(node)
+
+    if not functions:
+        raise LettaToolCreateError("No functions found in source code")
+
+    # Use the last function (matching original behavior)
+    func_node = functions[-1]
+
+    # Extract function name
+    func_name = func_node.name
+
+    # Extract docstring
+    docstring = None
+    if (
+        func_node.body
+        and isinstance(func_node.body[0], ast.Expr)
+        and isinstance(func_node.body[0].value, ast.Constant)
+        and isinstance(func_node.body[0].value.value, str)
+    ):
+        docstring = func_node.body[0].value.value
+
+    if not docstring:
+        raise LettaToolCreateError(f"Function {func_name} missing docstring")
+
+    # Build function signature
+    parameters = []
+    for arg in func_node.args.args:
+        param_name = arg.arg
+        param_annotation = _parse_type_annotation(arg.annotation, imports_map)
+
+        # Handle default values
+        defaults_offset = len(func_node.args.args) - len(func_node.args.defaults)
+        param_index = func_node.args.args.index(arg)
+
+        if param_index >= defaults_offset:
+            default_index = param_index - defaults_offset
+            try:
+                default_value = ast.literal_eval(func_node.args.defaults[default_index])
+            except (ValueError, TypeError):
+                # Can't evaluate the default, use Parameter.empty
+                default_value = inspect.Parameter.empty
+            param = inspect.Parameter(
+                param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=param_annotation, default=default_value
+            )
+        else:
+            param = inspect.Parameter(param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=param_annotation)
+        parameters.append(param)
+
+    signature = inspect.Signature(parameters)
+
+    return MockFunction(func_name, docstring, signature)
+
 
 
 def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict:
     """Derives the OpenAI JSON schema for a given function source code.
 
-
-
-
-
+    Parses the source code statically to extract function signature and docstring,
+    then generates the schema without executing any code.
+
+    Limitations:
+    - Complex nested Pydantic models with forward references may not be fully supported
+    - Only basic Pydantic Field definitions are parsed (description, ellipsis for required)
+    - Simple types (str, int, bool, float, list, dict) and basic Pydantic models work well
     """
     try:
-        #
-        env = {
-            "Optional": Optional,
-            "List": List,
-            "Dict": Dict,
-            "Literal": Literal,
-            # To support Pydantic models
-            # "BaseModel": BaseModel,
-            # "Field": Field,
-        }
-        env.update(globals())
-        # print("About to execute source code...")
-        exec(source_code, env)
-        # print("Source code executed successfully")
-
-        functions = [f for f in env if callable(env[f]) and not f.startswith("__")]
-        if not functions:
-            raise LettaToolCreateError("No callable functions found in source code")
-
-        # print(f"Found functions: {functions}")
-        func = env[functions[-1]]
-
-        if not hasattr(func, "__doc__") or not func.__doc__:
-            raise LettaToolCreateError(f"Function {func.__name__} missing docstring")
-
-        # print("About to generate schema...")
+        # Parse the function from source code without executing it
+        mock_func = _parse_function_from_source(source_code, name)
+
+        # Generate schema using the mock function
         try:
-            schema = generate_schema(func, name=name)
-            # print("Schema generated successfully")
+            schema = generate_schema(mock_func, name=name)
             return schema
         except TypeError as e:
             raise LettaToolCreateError(f"Type error in schema generation: {str(e)}")
letta/functions/mcp_client/types.py
CHANGED
@@ -18,9 +18,20 @@ TEMPLATED_VARIABLE_REGEX = (
 logger = get_logger(__name__)
 
 
+class MCPToolHealth(BaseModel):
+    """Health status for an MCP tool's schema."""
+
+    # TODO: @jnjpng use the enum provided in schema_validator.py
+    status: str = Field(..., description="Schema health status: STRICT_COMPLIANT, NON_STRICT_ONLY, or INVALID")
+    reasons: List[str] = Field(default_factory=list, description="List of reasons for the health status")
+
+
 class MCPTool(Tool):
     """A simple wrapper around MCP's tool definition (to avoid conflict with our own)"""
 
+    # Optional health information added at runtime
+    health: Optional[MCPToolHealth] = Field(None, description="Schema health status for OpenAI strict mode")
+
 
 class MCPServerType(str, Enum):
     SSE = "sse"