letta-nightly 0.5.1.dev20241030104135__py3-none-any.whl → 0.5.1.dev20241101104122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic; consult the package registry's advisory page for more details.

Files changed (35) hide show
  1. letta/agent.py +41 -2
  2. letta/client/client.py +98 -23
  3. letta/client/streaming.py +3 -0
  4. letta/constants.py +3 -0
  5. letta/functions/functions.py +4 -5
  6. letta/functions/schema_generator.py +4 -3
  7. letta/helpers/__init__.py +1 -0
  8. letta/helpers/tool_rule_solver.py +115 -0
  9. letta/llm_api/helpers.py +3 -1
  10. letta/llm_api/llm_api_tools.py +1 -2
  11. letta/llm_api/openai.py +5 -0
  12. letta/metadata.py +43 -1
  13. letta/orm/__init__.py +4 -0
  14. letta/orm/tool.py +0 -3
  15. letta/schemas/agent.py +5 -5
  16. letta/schemas/letta_response.py +3 -3
  17. letta/schemas/tool.py +4 -6
  18. letta/schemas/tool_rule.py +25 -0
  19. letta/server/rest_api/app.py +5 -3
  20. letta/server/rest_api/routers/v1/agents.py +16 -3
  21. letta/server/rest_api/routers/v1/organizations.py +2 -2
  22. letta/server/server.py +7 -43
  23. letta/server/startup.sh +3 -0
  24. letta/server/static_files/assets/{index-d6b3669a.js → index-9fa459a2.js} +66 -69
  25. letta/server/static_files/index.html +1 -1
  26. letta/services/tool_manager.py +21 -4
  27. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241101104122.dist-info}/METADATA +1 -1
  28. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241101104122.dist-info}/RECORD +31 -32
  29. letta/server/rest_api/admin/__init__.py +0 -0
  30. letta/server/rest_api/admin/agents.py +0 -21
  31. letta/server/rest_api/admin/tools.py +0 -82
  32. letta/server/rest_api/admin/users.py +0 -98
  33. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241101104122.dist-info}/LICENSE +0 -0
  34. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241101104122.dist-info}/WHEEL +0 -0
  35. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241101104122.dist-info}/entry_points.txt +0 -0
letta/agent.py CHANGED
@@ -20,6 +20,7 @@ from letta.constants import (
20
20
  REQ_HEARTBEAT_MESSAGE,
21
21
  )
22
22
  from letta.errors import LLMError
23
+ from letta.helpers import ToolRulesSolver
23
24
  from letta.interface import AgentInterface
24
25
  from letta.llm_api.helpers import is_context_overflow_error
25
26
  from letta.llm_api.llm_api_tools import create
@@ -43,6 +44,7 @@ from letta.schemas.openai.chat_completion_response import (
43
44
  from letta.schemas.openai.chat_completion_response import UsageStatistics
44
45
  from letta.schemas.passage import Passage
45
46
  from letta.schemas.tool import Tool
47
+ from letta.schemas.tool_rule import TerminalToolRule
46
48
  from letta.schemas.usage import LettaUsageStatistics
47
49
  from letta.system import (
48
50
  get_heartbeat,
@@ -242,6 +244,14 @@ class Agent(BaseAgent):
242
244
  # link tools
243
245
  self.link_tools(tools)
244
246
 
247
+ # initialize a tool rules solver
248
+ if agent_state.tool_rules:
249
+ # if there are tool rules, print out a warning
250
+ warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.")
251
+ # add default rule for having send_message be a terminal tool
252
+ agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message"))
253
+ self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)
254
+
245
255
  # gpt-4, gpt-3.5-turbo, ...
246
256
  self.model = self.agent_state.llm_config.model
247
257
 
@@ -465,15 +475,26 @@ class Agent(BaseAgent):
465
475
  function_call: str = "auto",
466
476
  first_message: bool = False, # hint
467
477
  stream: bool = False, # TODO move to config?
478
+ fail_on_empty_response: bool = False,
479
+ empty_response_retry_limit: int = 3,
468
480
  ) -> ChatCompletionResponse:
469
481
  """Get response from LLM API"""
482
+ # Get the allowed tools based on the ToolRulesSolver state
483
+ allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
484
+
485
+ if not allowed_tool_names:
486
+ # if it's empty, any available tools are fair game
487
+ allowed_functions = self.functions
488
+ else:
489
+ allowed_functions = [func for func in self.functions if func["name"] in allowed_tool_names]
490
+
470
491
  try:
471
492
  response = create(
472
493
  # agent_state=self.agent_state,
473
494
  llm_config=self.agent_state.llm_config,
474
495
  messages=message_sequence,
475
496
  user_id=self.agent_state.user_id,
476
- functions=self.functions,
497
+ functions=allowed_functions,
477
498
  functions_python=self.functions_python,
478
499
  function_call=function_call,
479
500
  # hint
@@ -484,7 +505,15 @@ class Agent(BaseAgent):
484
505
  )
485
506
 
486
507
  if len(response.choices) == 0 or response.choices[0] is None:
487
- raise Exception(f"API call didn't return a message: {response}")
508
+ empty_api_err_message = f"API call didn't return a message: {response}"
509
+ if fail_on_empty_response or empty_response_retry_limit == 0:
510
+ raise Exception(empty_api_err_message)
511
+ else:
512
+ # Decrement retry limit and try again
513
+ warnings.warn(empty_api_err_message)
514
+ return self._get_ai_reply(
515
+ message_sequence, function_call, first_message, stream, fail_on_empty_response, empty_response_retry_limit - 1
516
+ )
488
517
 
489
518
  # special case for 'length'
490
519
  if response.choices[0].finish_reason == "length":
@@ -515,6 +544,7 @@ class Agent(BaseAgent):
515
544
  assert response_message_id.startswith("message-"), response_message_id
516
545
 
517
546
  messages = [] # append these to the history when done
547
+ function_name = None
518
548
 
519
549
  # Step 2: check if LLM wanted to call a function
520
550
  if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0):
@@ -724,6 +754,15 @@ class Agent(BaseAgent):
724
754
  # TODO: @charles please check this
725
755
  self.rebuild_memory()
726
756
 
757
+ # Update ToolRulesSolver state with last called function
758
+ self.tool_rules_solver.update_tool_usage(function_name)
759
+
760
+ # Update heartbeat request according to provided tool rules
761
+ if self.tool_rules_solver.has_children_tools(function_name):
762
+ heartbeat_request = True
763
+ elif self.tool_rules_solver.is_terminal_tool(function_name):
764
+ heartbeat_request = False
765
+
727
766
  return messages, heartbeat_request, function_failed
728
767
 
729
768
  def step(
letta/client/client.py CHANGED
@@ -5,7 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union
5
5
  import requests
6
6
 
7
7
  import letta.utils
8
- from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
8
+ from letta.constants import ADMIN_PREFIX, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
9
9
  from letta.data_sources.connectors import DataConnector
10
10
  from letta.functions.functions import parse_source_code
11
11
  from letta.memory import get_memory_functions
@@ -39,9 +39,11 @@ from letta.schemas.memory import (
39
39
  )
40
40
  from letta.schemas.message import Message, MessageCreate, UpdateMessage
41
41
  from letta.schemas.openai.chat_completions import ToolCall
42
+ from letta.schemas.organization import Organization
42
43
  from letta.schemas.passage import Passage
43
44
  from letta.schemas.source import Source, SourceCreate, SourceUpdate
44
45
  from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
46
+ from letta.schemas.tool_rule import BaseToolRule
45
47
  from letta.server.rest_api.interface import QueuingInterface
46
48
  from letta.server.server import SyncServer
47
49
  from letta.utils import get_human_text, get_persona_text
@@ -139,6 +141,8 @@ class AbstractClient(object):
139
141
  agent_id: Optional[str] = None,
140
142
  name: Optional[str] = None,
141
143
  stream: Optional[bool] = False,
144
+ stream_steps: bool = False,
145
+ stream_tokens: bool = False,
142
146
  include_full_message: Optional[bool] = False,
143
147
  ) -> LettaResponse:
144
148
  raise NotImplementedError
@@ -195,7 +199,6 @@ class AbstractClient(object):
195
199
  self,
196
200
  func,
197
201
  name: Optional[str] = None,
198
- update: Optional[bool] = True,
199
202
  tags: Optional[List[str]] = None,
200
203
  ) -> Tool:
201
204
  raise NotImplementedError
@@ -204,6 +207,7 @@ class AbstractClient(object):
204
207
  self,
205
208
  id: str,
206
209
  name: Optional[str] = None,
210
+ description: Optional[str] = None,
207
211
  func: Optional[Callable] = None,
208
212
  tags: Optional[List[str]] = None,
209
213
  ) -> Tool:
@@ -282,6 +286,15 @@ class AbstractClient(object):
282
286
  def list_embedding_configs(self) -> List[EmbeddingConfig]:
283
287
  raise NotImplementedError
284
288
 
289
+ def create_org(self, name: Optional[str] = None) -> Organization:
290
+ raise NotImplementedError
291
+
292
+ def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
293
+ raise NotImplementedError
294
+
295
+ def delete_org(self, org_id: str) -> Organization:
296
+ raise NotImplementedError
297
+
285
298
 
286
299
  class RESTClient(AbstractClient):
287
300
  """
@@ -394,7 +407,7 @@ class RESTClient(AbstractClient):
394
407
  # add memory tools
395
408
  memory_functions = get_memory_functions(memory)
396
409
  for func_name, func in memory_functions.items():
397
- tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"], update=True)
410
+ tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"])
398
411
  tool_names.append(tool.name)
399
412
 
400
413
  # check if default configs are provided
@@ -778,7 +791,7 @@ class RESTClient(AbstractClient):
778
791
  name: Optional[str] = None,
779
792
  stream_steps: bool = False,
780
793
  stream_tokens: bool = False,
781
- include_full_message: Optional[bool] = False,
794
+ include_full_message: bool = False,
782
795
  ) -> Union[LettaResponse, Generator[LettaStreamingResponse, None, None]]:
783
796
  """
784
797
  Send a message to an agent
@@ -799,7 +812,12 @@ class RESTClient(AbstractClient):
799
812
  # TODO: figure out how to handle stream_steps and stream_tokens
800
813
 
801
814
  # When streaming steps is True, stream_tokens must be False
802
- request = LettaRequest(messages=messages, stream_steps=stream_steps, stream_tokens=stream_tokens, return_message_object=True)
815
+ request = LettaRequest(
816
+ messages=messages,
817
+ stream_steps=stream_steps,
818
+ stream_tokens=stream_tokens,
819
+ return_message_object=include_full_message,
820
+ )
803
821
  if stream_tokens or stream_steps:
804
822
  from letta.client.streaming import _sse_post
805
823
 
@@ -814,12 +832,12 @@ class RESTClient(AbstractClient):
814
832
  response = LettaResponse(**response.json())
815
833
 
816
834
  # simplify messages
817
- if not include_full_message:
818
- messages = []
819
- for m in response.messages:
820
- assert isinstance(m, Message)
821
- messages += m.to_letta_message()
822
- response.messages = messages
835
+ # if not include_full_message:
836
+ # messages = []
837
+ # for m in response.messages:
838
+ # assert isinstance(m, Message)
839
+ # messages += m.to_letta_message()
840
+ # response.messages = messages
823
841
 
824
842
  return response
825
843
 
@@ -1257,7 +1275,6 @@ class RESTClient(AbstractClient):
1257
1275
  self,
1258
1276
  func: Callable,
1259
1277
  name: Optional[str] = None,
1260
- update: Optional[bool] = True, # TODO: actually use this
1261
1278
  tags: Optional[List[str]] = None,
1262
1279
  ) -> Tool:
1263
1280
  """
@@ -1267,7 +1284,6 @@ class RESTClient(AbstractClient):
1267
1284
  func (callable): The function to create a tool for.
1268
1285
  name: (str): Name of the tool (must be unique per-user.)
1269
1286
  tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
1270
- update (bool, optional): Update the tool if it already exists. Defaults to True.
1271
1287
 
1272
1288
  Returns:
1273
1289
  tool (Tool): The created tool.
@@ -1292,6 +1308,7 @@ class RESTClient(AbstractClient):
1292
1308
  self,
1293
1309
  id: str,
1294
1310
  name: Optional[str] = None,
1311
+ description: Optional[str] = None,
1295
1312
  func: Optional[Callable] = None,
1296
1313
  tags: Optional[List[str]] = None,
1297
1314
  ) -> Tool:
@@ -1314,7 +1331,7 @@ class RESTClient(AbstractClient):
1314
1331
 
1315
1332
  source_type = "python"
1316
1333
 
1317
- request = ToolUpdate(source_type=source_type, source_code=source_code, tags=tags, name=name)
1334
+ request = ToolUpdate(description=description, source_type=source_type, source_code=source_code, tags=tags, name=name)
1318
1335
  response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
1319
1336
  if response.status_code != 200:
1320
1337
  raise ValueError(f"Failed to update tool: {response.text}")
@@ -1464,6 +1481,54 @@ class RESTClient(AbstractClient):
1464
1481
  raise ValueError(f"Failed to list embedding configs: {response.text}")
1465
1482
  return [EmbeddingConfig(**config) for config in response.json()]
1466
1483
 
1484
+ def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
1485
+ """
1486
+ Retrieves a list of all organizations in the database, with optional pagination.
1487
+
1488
+ @param cursor: the pagination cursor, if any
1489
+ @param limit: the maximum number of organizations to retrieve
1490
+ @return: a list of Organization objects
1491
+ """
1492
+ params = {"cursor": cursor, "limit": limit}
1493
+ response = requests.get(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
1494
+ if response.status_code != 200:
1495
+ raise ValueError(f"Failed to retrieve organizations: {response.text}")
1496
+ return [Organization(**org_data) for org_data in response.json()]
1497
+
1498
+ def create_org(self, name: Optional[str] = None) -> Organization:
1499
+ """
1500
+ Creates an organization with the given name. If not provided, we generate a random one.
1501
+
1502
+ @param name: the name of the organization
1503
+ @return: the created Organization
1504
+ """
1505
+ payload = {"name": name}
1506
+ response = requests.post(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, json=payload)
1507
+ if response.status_code != 200:
1508
+ raise ValueError(f"Failed to create org: {response.text}")
1509
+ return Organization(**response.json())
1510
+
1511
+ def delete_org(self, org_id: str) -> Organization:
1512
+ """
1513
+ Deletes an organization by its ID.
1514
+
1515
+ @param org_id: the ID of the organization to delete
1516
+ @return: the deleted Organization object
1517
+ """
1518
+ # Define query parameters with org_id
1519
+ params = {"org_id": org_id}
1520
+
1521
+ # Make the DELETE request with query parameters
1522
+ response = requests.delete(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
1523
+
1524
+ if response.status_code == 404:
1525
+ raise ValueError(f"Organization with ID '{org_id}' does not exist")
1526
+ elif response.status_code != 200:
1527
+ raise ValueError(f"Failed to delete organization: {response.text}")
1528
+
1529
+ # Parse and return the deleted organization
1530
+ return Organization(**response.json())
1531
+
1467
1532
 
1468
1533
  class LocalClient(AbstractClient):
1469
1534
  """
@@ -1568,6 +1633,7 @@ class LocalClient(AbstractClient):
1568
1633
  system: Optional[str] = None,
1569
1634
  # tools
1570
1635
  tools: Optional[List[str]] = None,
1636
+ tool_rules: Optional[List[BaseToolRule]] = None,
1571
1637
  include_base_tools: Optional[bool] = True,
1572
1638
  # metadata
1573
1639
  metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
@@ -1582,6 +1648,7 @@ class LocalClient(AbstractClient):
1582
1648
  memory (Memory): Memory configuration
1583
1649
  system (str): System configuration
1584
1650
  tools (List[str]): List of tools
1651
+ tool_rules (Optional[List[BaseToolRule]]): List of tool rules
1585
1652
  include_base_tools (bool): Include base tools
1586
1653
  metadata (Dict): Metadata
1587
1654
  description (str): Description
@@ -1603,7 +1670,7 @@ class LocalClient(AbstractClient):
1603
1670
  # add memory tools
1604
1671
  memory_functions = get_memory_functions(memory)
1605
1672
  for func_name, func in memory_functions.items():
1606
- tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"], update=True)
1673
+ tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"])
1607
1674
  tool_names.append(tool.name)
1608
1675
 
1609
1676
  self.interface.clear()
@@ -1620,6 +1687,7 @@ class LocalClient(AbstractClient):
1620
1687
  metadata_=metadata,
1621
1688
  memory=memory,
1622
1689
  tools=tool_names,
1690
+ tool_rules=tool_rules,
1623
1691
  system=system,
1624
1692
  agent_type=agent_type,
1625
1693
  llm_config=llm_config if llm_config else self._default_llm_config,
@@ -2175,7 +2243,6 @@ class LocalClient(AbstractClient):
2175
2243
  def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
2176
2244
  tool_create = ToolCreate.from_langchain(
2177
2245
  langchain_tool=langchain_tool,
2178
- organization_id=self.org_id,
2179
2246
  additional_imports_module_attr_map=additional_imports_module_attr_map,
2180
2247
  )
2181
2248
  return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
@@ -2184,12 +2251,11 @@ class LocalClient(AbstractClient):
2184
2251
  tool_create = ToolCreate.from_crewai(
2185
2252
  crewai_tool=crewai_tool,
2186
2253
  additional_imports_module_attr_map=additional_imports_module_attr_map,
2187
- organization_id=self.org_id,
2188
2254
  )
2189
2255
  return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
2190
2256
 
2191
2257
  def load_composio_tool(self, action: "ActionType") -> Tool:
2192
- tool_create = ToolCreate.from_composio(action=action, organization_id=self.org_id)
2258
+ tool_create = ToolCreate.from_composio(action=action)
2193
2259
  return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
2194
2260
 
2195
2261
  # TODO: Use the above function `add_tool` here as there is duplicate logic
@@ -2197,9 +2263,8 @@ class LocalClient(AbstractClient):
2197
2263
  self,
2198
2264
  func,
2199
2265
  name: Optional[str] = None,
2200
- update: Optional[bool] = True, # TODO: actually use this
2201
2266
  tags: Optional[List[str]] = None,
2202
- terminal: Optional[bool] = False,
2267
+ description: Optional[str] = None,
2203
2268
  ) -> Tool:
2204
2269
  """
2205
2270
  Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
@@ -2208,8 +2273,7 @@ class LocalClient(AbstractClient):
2208
2273
  func (callable): The function to create a tool for.
2209
2274
  name: (str): Name of the tool (must be unique per-user.)
2210
2275
  tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
2211
- update (bool, optional): Update the tool if it already exists. Defaults to True.
2212
- terminal (bool, optional): Whether the tool is a terminal tool (no more agent steps). Defaults to False.
2276
+ description (str, optional): The description.
2213
2277
 
2214
2278
  Returns:
2215
2279
  tool (Tool): The created tool.
@@ -2229,7 +2293,7 @@ class LocalClient(AbstractClient):
2229
2293
  source_code=source_code,
2230
2294
  name=name,
2231
2295
  tags=tags,
2232
- terminal=terminal,
2296
+ description=description,
2233
2297
  ),
2234
2298
  actor=self.user,
2235
2299
  )
@@ -2238,6 +2302,7 @@ class LocalClient(AbstractClient):
2238
2302
  self,
2239
2303
  id: str,
2240
2304
  name: Optional[str] = None,
2305
+ description: Optional[str] = None,
2241
2306
  func: Optional[callable] = None,
2242
2307
  tags: Optional[List[str]] = None,
2243
2308
  ) -> Tool:
@@ -2258,6 +2323,7 @@ class LocalClient(AbstractClient):
2258
2323
  "source_code": parse_source_code(func) if func else None,
2259
2324
  "tags": tags,
2260
2325
  "name": name,
2326
+ "description": description,
2261
2327
  }
2262
2328
 
2263
2329
  # Filter out any None values from the dictionary
@@ -2648,3 +2714,12 @@ class LocalClient(AbstractClient):
2648
2714
  configs (List[EmbeddingConfig]): List of embedding configurations
2649
2715
  """
2650
2716
  return self.server.list_embedding_models()
2717
+
2718
+ def create_org(self, name: Optional[str] = None) -> Organization:
2719
+ return self.server.organization_manager.create_organization(name=name)
2720
+
2721
+ def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
2722
+ return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit)
2723
+
2724
+ def delete_org(self, org_id: str) -> Organization:
2725
+ return self.server.organization_manager.delete_organization_by_id(org_id=org_id)
letta/client/streaming.py CHANGED
@@ -13,6 +13,7 @@ from letta.schemas.letta_message import (
13
13
  InternalMonologue,
14
14
  )
15
15
  from letta.schemas.letta_response import LettaStreamingResponse
16
+ from letta.schemas.usage import LettaUsageStatistics
16
17
 
17
18
 
18
19
  def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
@@ -58,6 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
58
59
  yield FunctionCallMessage(**chunk_data)
59
60
  elif "function_return" in chunk_data:
60
61
  yield FunctionReturn(**chunk_data)
62
+ elif "usage" in chunk_data:
63
+ yield LettaUsageStatistics(**chunk_data["usage"])
61
64
  else:
62
65
  raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
63
66
 
letta/constants.py CHANGED
@@ -3,6 +3,9 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
3
3
 
4
4
  LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
5
5
 
6
+ ADMIN_PREFIX = "/v1/admin"
7
+ API_PREFIX = "/v1"
8
+ OPENAI_API_PREFIX = "/openai"
6
9
 
7
10
  # String in the error message for when the context window is too large
8
11
  # Example full message:
@@ -7,10 +7,9 @@ from typing import Optional
7
7
 
8
8
  from letta.constants import CLI_WARNING_PREFIX
9
9
  from letta.functions.schema_generator import generate_schema
10
- from letta.schemas.tool import ToolCreate
11
10
 
12
11
 
13
- def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
12
+ def derive_openai_json_schema(source_code: str, name: Optional[str]) -> dict:
14
13
  # auto-generate openai schema
15
14
  try:
16
15
  # Define a custom environment with necessary imports
@@ -19,14 +18,14 @@ def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
19
18
  }
20
19
 
21
20
  env.update(globals())
22
- exec(tool_create.source_code, env)
21
+ exec(source_code, env)
23
22
 
24
23
  # get available functions
25
24
  functions = [f for f in env if callable(env[f])]
26
25
 
27
26
  # TODO: not sure if this always works
28
27
  func = env[functions[-1]]
29
- json_schema = generate_schema(func, terminal=tool_create.terminal, name=tool_create.name)
28
+ json_schema = generate_schema(func, name=name)
30
29
  return json_schema
31
30
  except Exception as e:
32
31
  raise RuntimeError(f"Failed to execute source code: {e}")
@@ -51,7 +50,7 @@ def load_function_set(module: ModuleType) -> dict:
51
50
  if attr_name in function_dict:
52
51
  raise ValueError(f"Found a duplicate of function name '{attr_name}'")
53
52
 
54
- generated_schema = generate_schema(attr, terminal=False)
53
+ generated_schema = generate_schema(attr)
55
54
  function_dict[attr_name] = {
56
55
  "module": inspect.getsource(module),
57
56
  "python_function": attr,
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Type, Union, get_args, get_origin
3
3
 
4
4
  from docstring_parser import parse
5
5
  from pydantic import BaseModel
6
+ from pydantic.v1 import BaseModel as V1BaseModel
6
7
 
7
8
 
8
9
  def is_optional(annotation):
@@ -74,7 +75,7 @@ def pydantic_model_to_open_ai(model):
74
75
  }
75
76
 
76
77
 
77
- def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None) -> dict:
78
+ def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None) -> dict:
78
79
  # Get the signature of the function
79
80
  sig = inspect.signature(function)
80
81
 
@@ -128,7 +129,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No
128
129
 
129
130
  # append the heartbeat
130
131
  # TODO: don't hard-code
131
- if function.__name__ not in ["send_message", "pause_heartbeats"] and not terminal:
132
+ if function.__name__ not in ["send_message", "pause_heartbeats"]:
132
133
  schema["parameters"]["properties"]["request_heartbeat"] = {
133
134
  "type": "boolean",
134
135
  "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
@@ -139,7 +140,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No
139
140
 
140
141
 
141
142
  def generate_schema_from_args_schema(
142
- args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
143
+ args_schema: Type[V1BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
143
144
  ) -> Dict[str, Any]:
144
145
  properties = {}
145
146
  required = []
@@ -0,0 +1 @@
1
+ from letta.helpers.tool_rule_solver import ToolRulesSolver
@@ -0,0 +1,115 @@
1
+ import warnings
2
+ from typing import Dict, List, Optional, Set
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+ from letta.schemas.tool_rule import (
7
+ BaseToolRule,
8
+ InitToolRule,
9
+ TerminalToolRule,
10
+ ToolRule,
11
+ )
12
+
13
+
14
class ToolRuleValidationError(Exception):
    """Raised when a set of tool rules fails validation in ToolRulesSolver.

    The supplied message is prefixed with the exception's class name so the
    failure is self-identifying in logs and tracebacks.
    """

    def __init__(self, message: str):
        # Prefix the human-readable message with the error type for clarity.
        super().__init__("ToolRuleValidationError: " + message)
19
+
20
+
21
class ToolRulesSolver(BaseModel):
    """Tracks tool-rule state for an agent and answers which tools may run next.

    Rules are partitioned at construction time into three buckets:
    init rules (valid before any tool has run), standard rules (which define
    allowed child transitions after a given tool), and terminal rules (which
    end the agent loop). The standard rules must form a DAG; a cycle raises
    ToolRuleValidationError.
    """

    # Rules that apply before the first tool call of a run.
    init_tool_rules: List[InitToolRule] = Field(
        default_factory=list, description="Initial tool rules to be used at the start of tool execution."
    )
    # Transition rules: after `tool_name` runs, only its `children` are allowed.
    tool_rules: List[ToolRule] = Field(
        default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
    )
    # Tools whose invocation terminates the agent loop.
    terminal_tool_rules: List[TerminalToolRule] = Field(
        default_factory=list, description="Terminal tool rules that end the agent loop if called."
    )
    # Name of the most recently executed tool; None before the first call.
    last_tool_name: Optional[str] = Field(None, description="The most recent tool used, updated with each tool call.")

    def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
        super().__init__(**kwargs)
        # Bucket each provided rule by its concrete type. Rules of an
        # unrecognized type are ignored, matching the original behavior.
        for candidate in tool_rules:
            if isinstance(candidate, InitToolRule):
                self.init_tool_rules.append(candidate)
            elif isinstance(candidate, ToolRule):
                self.tool_rules.append(candidate)
            elif isinstance(candidate, TerminalToolRule):
                self.terminal_tool_rules.append(candidate)

        # Reject configurations whose transition graph contains a cycle.
        if not self.validate_tool_rules():
            raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.")

    def update_tool_usage(self, tool_name: str):
        """Record `tool_name` as the most recently called tool."""
        self.last_tool_name = tool_name

    def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
        """Return the tool names permitted given the last tool called.

        Before any tool has run, the init rules apply. Afterwards, the
        matching standard rule's children apply. If no rule matches, either
        raise (when `error_on_empty`) or warn and return an empty list.
        """
        if self.last_tool_name is None:
            # Nothing has run yet: only init rules constrain the choice.
            return [rule.tool_name for rule in self.init_tool_rules]

        # Look for the standard rule governing transitions out of the last tool.
        matched = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)
        if matched is not None:
            return matched.children

        # No rule matched: the rule set has no further reachable tools.
        message = "User provided tool rules and execution state resolved to no more possible tool calls."
        if error_on_empty:
            raise RuntimeError(message)
        warnings.warn(message)
        return []

    def is_terminal_tool(self, tool_name: str) -> bool:
        """Return True if `tool_name` is covered by a terminal rule."""
        return any(tool_name == rule.tool_name for rule in self.terminal_tool_rules)

    def has_children_tools(self, tool_name):
        """Return True if `tool_name` has a standard rule (i.e. declared children)."""
        return any(tool_name == rule.tool_name for rule in self.tool_rules)

    def validate_tool_rules(self) -> bool:
        """Check that the standard tool rules form a DAG.

        Runs a depth-first search with an "in progress" set for cycle
        detection. Returns True when no cycle exists, False otherwise.
        """
        # Adjacency: tool name -> names of tools allowed to follow it.
        graph: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules}

        finished: Set[str] = set()  # nodes fully explored, known cycle-free
        active: Set[str] = set()  # nodes on the current DFS path

        def explore(node: str) -> bool:
            # Re-entering a node on the active path means a cycle.
            if node in active:
                return False
            if node in finished:
                return True

            active.add(node)
            for successor in graph.get(node, []):
                if not explore(successor):
                    return False
            active.remove(node)
            finished.add(node)
            return True

        # Every rule's tool must be reachable cycle-free.
        for rule in self.tool_rules:
            if rule.tool_name not in finished and not explore(rule.tool_name):
                return False
        return True
letta/llm_api/helpers.py CHANGED
@@ -16,9 +16,11 @@ def convert_to_structured_output(openai_function: dict) -> dict:
16
16
 
17
17
  See: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
18
18
  """
19
+ description = openai_function["description"] if "description" in openai_function else ""
20
+
19
21
  structured_output = {
20
22
  "name": openai_function["name"],
21
- "description": openai_function["description"],
23
+ "description": description,
22
24
  "strict": True,
23
25
  "parameters": {"type": "object", "properties": {}, "additionalProperties": False, "required": []},
24
26
  }
@@ -106,7 +106,7 @@ def create(
106
106
  messages: List[Message],
107
107
  user_id: Optional[str] = None, # option UUID to associate request with
108
108
  functions: Optional[list] = None,
109
- functions_python: Optional[list] = None,
109
+ functions_python: Optional[dict] = None,
110
110
  function_call: str = "auto",
111
111
  # hint
112
112
  first_message: bool = False,
@@ -140,7 +140,6 @@ def create(
140
140
  raise ValueError(f"OpenAI key is missing from letta config file")
141
141
 
142
142
  data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
143
-
144
143
  if stream: # Client requested token streaming
145
144
  data.stream = True
146
145
  assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
letta/llm_api/openai.py CHANGED
@@ -530,7 +530,12 @@ def openai_chat_completions_request(
530
530
  data.pop("tools")
531
531
  data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
532
532
 
533
+ if "tools" in data:
534
+ for tool in data["tools"]:
535
+ tool["function"] = convert_to_structured_output(tool["function"])
536
+
533
537
  response_json = make_post_request(url, headers, data)
538
+
534
539
  return ChatCompletionResponse(**response_json)
535
540
 
536
541