letta-nightly 0.5.1.dev20241030104135__py3-none-any.whl → 0.5.1.dev20241031104107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (32)
  1. letta/agent.py +41 -2
  2. letta/client/client.py +85 -15
  3. letta/client/streaming.py +3 -0
  4. letta/constants.py +3 -0
  5. letta/functions/functions.py +4 -5
  6. letta/functions/schema_generator.py +4 -3
  7. letta/helpers/__init__.py +1 -0
  8. letta/helpers/tool_rule_solver.py +115 -0
  9. letta/llm_api/helpers.py +3 -1
  10. letta/llm_api/llm_api_tools.py +1 -2
  11. letta/llm_api/openai.py +5 -0
  12. letta/metadata.py +43 -1
  13. letta/orm/tool.py +0 -3
  14. letta/schemas/agent.py +5 -5
  15. letta/schemas/letta_response.py +1 -1
  16. letta/schemas/tool.py +4 -6
  17. letta/schemas/tool_rule.py +25 -0
  18. letta/server/rest_api/app.py +1 -5
  19. letta/server/rest_api/routers/v1/organizations.py +2 -2
  20. letta/server/server.py +2 -38
  21. letta/server/static_files/assets/{index-d6b3669a.js → index-b82c8d7c.js} +1 -1
  22. letta/server/static_files/index.html +1 -1
  23. letta/services/tool_manager.py +21 -4
  24. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241031104107.dist-info}/METADATA +1 -1
  25. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241031104107.dist-info}/RECORD +28 -29
  26. letta/server/rest_api/admin/__init__.py +0 -0
  27. letta/server/rest_api/admin/agents.py +0 -21
  28. letta/server/rest_api/admin/tools.py +0 -82
  29. letta/server/rest_api/admin/users.py +0 -98
  30. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241031104107.dist-info}/LICENSE +0 -0
  31. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241031104107.dist-info}/WHEEL +0 -0
  32. {letta_nightly-0.5.1.dev20241030104135.dist-info → letta_nightly-0.5.1.dev20241031104107.dist-info}/entry_points.txt +0 -0
letta/agent.py CHANGED
@@ -20,6 +20,7 @@ from letta.constants import (
     REQ_HEARTBEAT_MESSAGE,
 )
 from letta.errors import LLMError
+from letta.helpers import ToolRulesSolver
 from letta.interface import AgentInterface
 from letta.llm_api.helpers import is_context_overflow_error
 from letta.llm_api.llm_api_tools import create
@@ -43,6 +44,7 @@ from letta.schemas.openai.chat_completion_response import (
 from letta.schemas.openai.chat_completion_response import UsageStatistics
 from letta.schemas.passage import Passage
 from letta.schemas.tool import Tool
+from letta.schemas.tool_rule import TerminalToolRule
 from letta.schemas.usage import LettaUsageStatistics
 from letta.system import (
     get_heartbeat,
@@ -242,6 +244,14 @@ class Agent(BaseAgent):
         # link tools
         self.link_tools(tools)

+        # initialize a tool rules solver
+        if agent_state.tool_rules:
+            # if there are tool rules, print out a warning
+            warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.")
+        # add default rule for having send_message be a terminal tool
+        agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message"))
+        self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)
+
         # gpt-4, gpt-3.5-turbo, ...
         self.model = self.agent_state.llm_config.model

@@ -465,15 +475,26 @@ class Agent(BaseAgent):
         function_call: str = "auto",
         first_message: bool = False,  # hint
         stream: bool = False,  # TODO move to config?
+        fail_on_empty_response: bool = False,
+        empty_response_retry_limit: int = 3,
     ) -> ChatCompletionResponse:
         """Get response from LLM API"""
+        # Get the allowed tools based on the ToolRulesSolver state
+        allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
+
+        if not allowed_tool_names:
+            # if it's empty, any available tools are fair game
+            allowed_functions = self.functions
+        else:
+            allowed_functions = [func for func in self.functions if func["name"] in allowed_tool_names]
+
         try:
             response = create(
                 # agent_state=self.agent_state,
                 llm_config=self.agent_state.llm_config,
                 messages=message_sequence,
                 user_id=self.agent_state.user_id,
-                functions=self.functions,
+                functions=allowed_functions,
                 functions_python=self.functions_python,
                 function_call=function_call,
                 # hint
@@ -484,7 +505,15 @@ class Agent(BaseAgent):
             )

             if len(response.choices) == 0 or response.choices[0] is None:
-                raise Exception(f"API call didn't return a message: {response}")
+                empty_api_err_message = f"API call didn't return a message: {response}"
+                if fail_on_empty_response or empty_response_retry_limit == 0:
+                    raise Exception(empty_api_err_message)
+                else:
+                    # Decrement retry limit and try again
+                    warnings.warn(empty_api_err_message)
+                    return self._get_ai_reply(
+                        message_sequence, function_call, first_message, stream, fail_on_empty_response, empty_response_retry_limit - 1
+                    )

             # special case for 'length'
             if response.choices[0].finish_reason == "length":
@@ -515,6 +544,7 @@ class Agent(BaseAgent):
         assert response_message_id.startswith("message-"), response_message_id

         messages = []  # append these to the history when done
+        function_name = None

         # Step 2: check if LLM wanted to call a function
         if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0):
@@ -724,6 +754,15 @@ class Agent(BaseAgent):
             # TODO: @charles please check this
             self.rebuild_memory()

+        # Update ToolRulesSolver state with last called function
+        self.tool_rules_solver.update_tool_usage(function_name)
+
+        # Update heartbeat request according to provided tool rules
+        if self.tool_rules_solver.has_children_tools(function_name):
+            heartbeat_request = True
+        elif self.tool_rules_solver.is_terminal_tool(function_name):
+            heartbeat_request = False
+
         return messages, heartbeat_request, function_failed

     def step(
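Taken together, the agent.py changes constrain each LLM call to the tools the ToolRulesSolver currently allows, retry empty completions up to a limit, and let tool rules drive the heartbeat decision. A minimal sketch of what this enables at agent-creation time (the tool names are invented; the rule classes come from letta/schemas/tool_rule.py, whose +25-line diff is not shown above):

from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule

# Start with `plan`, allow only `search` after it, and end the step loop on `send_message`.
tool_rules = [
    InitToolRule(tool_name="plan"),
    ToolRule(tool_name="plan", children=["search"]),
    TerminalToolRule(tool_name="send_message"),
]
agent_state = client.create_agent(name="constrained_agent", tool_rules=tool_rules)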
letta/client/client.py CHANGED
@@ -5,7 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union
 import requests

 import letta.utils
-from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
+from letta.constants import ADMIN_PREFIX, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
 from letta.data_sources.connectors import DataConnector
 from letta.functions.functions import parse_source_code
 from letta.memory import get_memory_functions
@@ -39,9 +39,11 @@ from letta.schemas.memory import (
 )
 from letta.schemas.message import Message, MessageCreate, UpdateMessage
 from letta.schemas.openai.chat_completions import ToolCall
+from letta.schemas.organization import Organization
 from letta.schemas.passage import Passage
 from letta.schemas.source import Source, SourceCreate, SourceUpdate
 from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
+from letta.schemas.tool_rule import BaseToolRule
 from letta.server.rest_api.interface import QueuingInterface
 from letta.server.server import SyncServer
 from letta.utils import get_human_text, get_persona_text
@@ -139,6 +141,8 @@ class AbstractClient(object):
         agent_id: Optional[str] = None,
         name: Optional[str] = None,
         stream: Optional[bool] = False,
+        stream_steps: bool = False,
+        stream_tokens: bool = False,
         include_full_message: Optional[bool] = False,
     ) -> LettaResponse:
         raise NotImplementedError
@@ -195,7 +199,6 @@ class AbstractClient(object):
         self,
         func,
         name: Optional[str] = None,
-        update: Optional[bool] = True,
         tags: Optional[List[str]] = None,
     ) -> Tool:
         raise NotImplementedError
@@ -204,6 +207,7 @@ class AbstractClient(object):
         self,
         id: str,
         name: Optional[str] = None,
+        description: Optional[str] = None,
         func: Optional[Callable] = None,
         tags: Optional[List[str]] = None,
     ) -> Tool:
@@ -282,6 +286,15 @@ class AbstractClient(object):
     def list_embedding_configs(self) -> List[EmbeddingConfig]:
         raise NotImplementedError

+    def create_org(self, name: Optional[str] = None) -> Organization:
+        raise NotImplementedError
+
+    def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
+        raise NotImplementedError
+
+    def delete_org(self, org_id: str) -> Organization:
+        raise NotImplementedError
+

 class RESTClient(AbstractClient):
     """
@@ -394,7 +407,7 @@ class RESTClient(AbstractClient):
         # add memory tools
         memory_functions = get_memory_functions(memory)
         for func_name, func in memory_functions.items():
-            tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"], update=True)
+            tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"])
             tool_names.append(tool.name)

         # check if default configs are provided
@@ -1257,7 +1270,6 @@ class RESTClient(AbstractClient):
         self,
         func: Callable,
         name: Optional[str] = None,
-        update: Optional[bool] = True,  # TODO: actually use this
         tags: Optional[List[str]] = None,
     ) -> Tool:
         """
@@ -1267,7 +1279,6 @@ class RESTClient(AbstractClient):
             func (callable): The function to create a tool for.
             name: (str): Name of the tool (must be unique per-user.)
             tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
-            update (bool, optional): Update the tool if it already exists. Defaults to True.

         Returns:
             tool (Tool): The created tool.
@@ -1292,6 +1303,7 @@ class RESTClient(AbstractClient):
         self,
         id: str,
         name: Optional[str] = None,
+        description: Optional[str] = None,
         func: Optional[Callable] = None,
         tags: Optional[List[str]] = None,
     ) -> Tool:
@@ -1314,7 +1326,7 @@ class RESTClient(AbstractClient):

         source_type = "python"

-        request = ToolUpdate(source_type=source_type, source_code=source_code, tags=tags, name=name)
+        request = ToolUpdate(description=description, source_type=source_type, source_code=source_code, tags=tags, name=name)
         response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
         if response.status_code != 200:
             raise ValueError(f"Failed to update tool: {response.text}")
@@ -1464,6 +1476,54 @@ class RESTClient(AbstractClient):
             raise ValueError(f"Failed to list embedding configs: {response.text}")
         return [EmbeddingConfig(**config) for config in response.json()]

+    def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
+        """
+        Retrieves a list of all organizations in the database, with optional pagination.
+
+        @param cursor: the pagination cursor, if any
+        @param limit: the maximum number of organizations to retrieve
+        @return: a list of Organization objects
+        """
+        params = {"cursor": cursor, "limit": limit}
+        response = requests.get(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
+        if response.status_code != 200:
+            raise ValueError(f"Failed to retrieve organizations: {response.text}")
+        return [Organization(**org_data) for org_data in response.json()]
+
+    def create_org(self, name: Optional[str] = None) -> Organization:
+        """
+        Creates an organization with the given name. If not provided, we generate a random one.
+
+        @param name: the name of the organization
+        @return: the created Organization
+        """
+        payload = {"name": name}
+        response = requests.post(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, json=payload)
+        if response.status_code != 200:
+            raise ValueError(f"Failed to create org: {response.text}")
+        return Organization(**response.json())
+
+    def delete_org(self, org_id: str) -> Organization:
+        """
+        Deletes an organization by its ID.
+
+        @param org_id: the ID of the organization to delete
+        @return: the deleted Organization object
+        """
+        # Define query parameters with org_id
+        params = {"org_id": org_id}
+
+        # Make the DELETE request with query parameters
+        response = requests.delete(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
+
+        if response.status_code == 404:
+            raise ValueError(f"Organization with ID '{org_id}' does not exist")
+        elif response.status_code != 200:
+            raise ValueError(f"Failed to delete organization: {response.text}")
+
+        # Parse and return the deleted organization
+        return Organization(**response.json())
+

 class LocalClient(AbstractClient):
     """
@@ -1568,6 +1628,7 @@ class LocalClient(AbstractClient):
         system: Optional[str] = None,
         # tools
         tools: Optional[List[str]] = None,
+        tool_rules: Optional[List[BaseToolRule]] = None,
         include_base_tools: Optional[bool] = True,
         # metadata
         metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
@@ -1582,6 +1643,7 @@ class LocalClient(AbstractClient):
             memory (Memory): Memory configuration
             system (str): System configuration
             tools (List[str]): List of tools
+            tool_rules (Optional[List[BaseToolRule]]): List of tool rules
             include_base_tools (bool): Include base tools
             metadata (Dict): Metadata
             description (str): Description
@@ -1603,7 +1665,7 @@ class LocalClient(AbstractClient):
         # add memory tools
         memory_functions = get_memory_functions(memory)
         for func_name, func in memory_functions.items():
-            tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"], update=True)
+            tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"])
             tool_names.append(tool.name)

         self.interface.clear()
@@ -1620,6 +1682,7 @@ class LocalClient(AbstractClient):
             metadata_=metadata,
             memory=memory,
             tools=tool_names,
+            tool_rules=tool_rules,
             system=system,
             agent_type=agent_type,
             llm_config=llm_config if llm_config else self._default_llm_config,
@@ -2175,7 +2238,6 @@ class LocalClient(AbstractClient):
     def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
         tool_create = ToolCreate.from_langchain(
             langchain_tool=langchain_tool,
-            organization_id=self.org_id,
             additional_imports_module_attr_map=additional_imports_module_attr_map,
         )
         return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
@@ -2184,12 +2246,11 @@ class LocalClient(AbstractClient):
         tool_create = ToolCreate.from_crewai(
             crewai_tool=crewai_tool,
             additional_imports_module_attr_map=additional_imports_module_attr_map,
-            organization_id=self.org_id,
         )
         return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)

     def load_composio_tool(self, action: "ActionType") -> Tool:
-        tool_create = ToolCreate.from_composio(action=action, organization_id=self.org_id)
+        tool_create = ToolCreate.from_composio(action=action)
         return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)

     # TODO: Use the above function `add_tool` here as there is duplicate logic
@@ -2197,9 +2258,8 @@ class LocalClient(AbstractClient):
         self,
         func,
         name: Optional[str] = None,
-        update: Optional[bool] = True,  # TODO: actually use this
         tags: Optional[List[str]] = None,
-        terminal: Optional[bool] = False,
+        description: Optional[str] = None,
     ) -> Tool:
         """
         Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
@@ -2208,8 +2268,7 @@ class LocalClient(AbstractClient):
             func (callable): The function to create a tool for.
             name: (str): Name of the tool (must be unique per-user.)
             tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
-            update (bool, optional): Update the tool if it already exists. Defaults to True.
-            terminal (bool, optional): Whether the tool is a terminal tool (no more agent steps). Defaults to False.
+            description (str, optional): The description.

         Returns:
             tool (Tool): The created tool.
@@ -2229,7 +2288,7 @@ class LocalClient(AbstractClient):
                 source_code=source_code,
                 name=name,
                 tags=tags,
-                terminal=terminal,
+                description=description,
             ),
             actor=self.user,
         )
@@ -2238,6 +2297,7 @@ class LocalClient(AbstractClient):
         self,
         id: str,
         name: Optional[str] = None,
+        description: Optional[str] = None,
         func: Optional[callable] = None,
         tags: Optional[List[str]] = None,
     ) -> Tool:
@@ -2258,6 +2318,7 @@ class LocalClient(AbstractClient):
             "source_code": parse_source_code(func) if func else None,
             "tags": tags,
             "name": name,
+            "description": description,
         }

         # Filter out any None values from the dictionary
@@ -2648,3 +2709,12 @@ class LocalClient(AbstractClient):
             configs (List[EmbeddingConfig]): List of embedding configurations
         """
         return self.server.list_embedding_models()
+
+    def create_org(self, name: Optional[str] = None) -> Organization:
+        return self.server.organization_manager.create_organization(name=name)
+
+    def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
+        return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit)
+
+    def delete_org(self, org_id: str) -> Organization:
+        return self.server.organization_manager.delete_organization_by_id(org_id=org_id)
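The org-management surface is symmetric across both clients: RESTClient calls the admin routes under ADMIN_PREFIX ("/v1/admin"), while LocalClient delegates to the server's organization_manager. A rough usage sketch (assumes a running Letta server; the constructor arguments are illustrative):

client = RESTClient(base_url="http://localhost:8283", token="...")

org = client.create_org("research-team")   # POST /v1/admin/orgs
orgs = client.list_orgs(limit=10)          # GET /v1/admin/orgs
client.delete_org(org_id=org.id)           # DELETE /v1/admin/orgs?org_id=...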
letta/client/streaming.py CHANGED
@@ -13,6 +13,7 @@ from letta.schemas.letta_message import (
     InternalMonologue,
 )
 from letta.schemas.letta_response import LettaStreamingResponse
+from letta.schemas.usage import LettaUsageStatistics


 def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
@@ -58,6 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
                     yield FunctionCallMessage(**chunk_data)
                 elif "function_return" in chunk_data:
                     yield FunctionReturn(**chunk_data)
+                elif "usage" in chunk_data:
+                    yield LettaUsageStatistics(**chunk_data["usage"])
                 else:
                     raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
 
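With this change a streaming consumer can now receive a LettaUsageStatistics chunk in addition to the message types, so iteration code should handle both. Roughly (the send_message call and its flags are illustrative):

from letta.schemas.usage import LettaUsageStatistics

for chunk in client.send_message(agent_id=agent_id, role="user", message="hi", stream_tokens=True):
    if isinstance(chunk, LettaUsageStatistics):
        print("total tokens:", chunk.total_tokens)  # final usage summary
    else:
        print(chunk)  # an InternalMonologue / FunctionCallMessage / FunctionReturn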
letta/constants.py CHANGED
@@ -3,6 +3,9 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING

 LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")

+ADMIN_PREFIX = "/v1/admin"
+API_PREFIX = "/v1"
+OPENAI_API_PREFIX = "/openai"

 # String in the error message for when the context window is too large
 # Example full message:
letta/functions/functions.py CHANGED
@@ -7,10 +7,9 @@ from typing import Optional

 from letta.constants import CLI_WARNING_PREFIX
 from letta.functions.schema_generator import generate_schema
-from letta.schemas.tool import ToolCreate


-def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
+def derive_openai_json_schema(source_code: str, name: Optional[str]) -> dict:
     # auto-generate openai schema
     try:
         # Define a custom environment with necessary imports
@@ -19,14 +18,14 @@ def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
         }

         env.update(globals())
-        exec(tool_create.source_code, env)
+        exec(source_code, env)

         # get available functions
         functions = [f for f in env if callable(env[f])]

         # TODO: not sure if this always works
         func = env[functions[-1]]
-        json_schema = generate_schema(func, terminal=tool_create.terminal, name=tool_create.name)
+        json_schema = generate_schema(func, name=name)
         return json_schema
     except Exception as e:
         raise RuntimeError(f"Failed to execute source code: {e}")
@@ -51,7 +50,7 @@ def load_function_set(module: ModuleType) -> dict:
         if attr_name in function_dict:
             raise ValueError(f"Found a duplicate of function name '{attr_name}'")

-        generated_schema = generate_schema(attr, terminal=False)
+        generated_schema = generate_schema(attr)
         function_dict[attr_name] = {
             "module": inspect.getsource(module),
             "python_function": attr,
letta/functions/schema_generator.py CHANGED
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Type, Union, get_args, get_origin

 from docstring_parser import parse
 from pydantic import BaseModel
+from pydantic.v1 import BaseModel as V1BaseModel


 def is_optional(annotation):
@@ -74,7 +75,7 @@ def pydantic_model_to_open_ai(model):
     }


-def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None) -> dict:
+def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None) -> dict:
     # Get the signature of the function
     sig = inspect.signature(function)

@@ -128,7 +129,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No

     # append the heartbeat
     # TODO: don't hard-code
-    if function.__name__ not in ["send_message", "pause_heartbeats"] and not terminal:
+    if function.__name__ not in ["send_message", "pause_heartbeats"]:
         schema["parameters"]["properties"]["request_heartbeat"] = {
             "type": "boolean",
             "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
@@ -139,7 +140,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No


 def generate_schema_from_args_schema(
-    args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
+    args_schema: Type[V1BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
 ) -> Dict[str, Any]:
     properties = {}
     required = []
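Dropping the terminal parameter means generate_schema now appends the request_heartbeat property for every function except the hard-coded send_message and pause_heartbeats; terminal behavior is instead expressed through tool rules. A quick sketch of the resulting schema shape (the example function is invented; generate_schema expects a docstring with per-argument descriptions):

from letta.functions.schema_generator import generate_schema

def roll_dice(sides: int) -> int:
    """Roll a die.

    Args:
        sides (int): Number of faces on the die.

    Returns:
        int: The rolled value.
    """
    import random

    return random.randint(1, sides)

schema = generate_schema(roll_dice)
assert "request_heartbeat" in schema["parameters"]["properties"]  # appended automatically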
letta/helpers/__init__.py ADDED
@@ -0,0 +1 @@
+from letta.helpers.tool_rule_solver import ToolRulesSolver
letta/helpers/tool_rule_solver.py ADDED
@@ -0,0 +1,115 @@
+import warnings
+from typing import Dict, List, Optional, Set
+
+from pydantic import BaseModel, Field
+
+from letta.schemas.tool_rule import (
+    BaseToolRule,
+    InitToolRule,
+    TerminalToolRule,
+    ToolRule,
+)
+
+
+class ToolRuleValidationError(Exception):
+    """Custom exception for tool rule validation errors in ToolRulesSolver."""
+
+    def __init__(self, message: str):
+        super().__init__(f"ToolRuleValidationError: {message}")
+
+
+class ToolRulesSolver(BaseModel):
+    init_tool_rules: List[InitToolRule] = Field(
+        default_factory=list, description="Initial tool rules to be used at the start of tool execution."
+    )
+    tool_rules: List[ToolRule] = Field(
+        default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
+    )
+    terminal_tool_rules: List[TerminalToolRule] = Field(
+        default_factory=list, description="Terminal tool rules that end the agent loop if called."
+    )
+    last_tool_name: Optional[str] = Field(None, description="The most recent tool used, updated with each tool call.")
+
+    def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
+        super().__init__(**kwargs)
+        # Separate the provided tool rules into init, standard, and terminal categories
+        for rule in tool_rules:
+            if isinstance(rule, InitToolRule):
+                self.init_tool_rules.append(rule)
+            elif isinstance(rule, ToolRule):
+                self.tool_rules.append(rule)
+            elif isinstance(rule, TerminalToolRule):
+                self.terminal_tool_rules.append(rule)
+
+        # Validate the tool rules to ensure they form a DAG
+        if not self.validate_tool_rules():
+            raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.")
+
+    def update_tool_usage(self, tool_name: str):
+        """Update the internal state to track the last tool called."""
+        self.last_tool_name = tool_name
+
+    def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
+        """Get a list of tool names allowed based on the last tool called."""
+        if self.last_tool_name is None:
+            # Use initial tool rules if no tool has been called yet
+            return [rule.tool_name for rule in self.init_tool_rules]
+        else:
+            # Find a matching ToolRule for the last tool used
+            current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)
+
+            # Return children which must exist on ToolRule
+            if current_rule:
+                return current_rule.children
+
+            # Default to empty if no rule matches
+            message = "User provided tool rules and execution state resolved to no more possible tool calls."
+            if error_on_empty:
+                raise RuntimeError(message)
+            else:
+                warnings.warn(message)
+                return []
+
+    def is_terminal_tool(self, tool_name: str) -> bool:
+        """Check if the tool is defined as a terminal tool in the terminal tool rules."""
+        return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
+
+    def has_children_tools(self, tool_name):
+        """Check if the tool has children tools"""
+        return any(rule.tool_name == tool_name for rule in self.tool_rules)
+
+    def validate_tool_rules(self) -> bool:
+        """
+        Validate that the tool rules define a directed acyclic graph (DAG).
+        Returns True if valid (no cycles), otherwise False.
+        """
+        # Build adjacency list for the tool graph
+        adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules}
+
+        # Track visited nodes
+        visited: Set[str] = set()
+        path_stack: Set[str] = set()
+
+        # Define DFS helper function
+        def dfs(tool_name: str) -> bool:
+            if tool_name in path_stack:
+                return False  # Cycle detected
+            if tool_name in visited:
+                return True  # Already validated
+
+            # Mark the node as visited in the current path
+            path_stack.add(tool_name)
+            for child in adjacency_list.get(tool_name, []):
+                if not dfs(child):
+                    return False  # Cycle detected in DFS
+            path_stack.remove(tool_name)  # Remove from current path
+            visited.add(tool_name)
+            return True
+
+        # Run DFS from each tool in `tool_rules`
+        for rule in self.tool_rules:
+            if rule.tool_name not in visited:
+                if not dfs(rule.tool_name):
+                    return False  # Cycle found, invalid tool rules
+
+        return True  # No cycles, valid DAG
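Since ToolRulesSolver is a plain pydantic model, the state machine can be exercised in isolation. A small sketch of the intended transitions (tool names are invented):

from letta.helpers import ToolRulesSolver
from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule

solver = ToolRulesSolver(
    tool_rules=[
        InitToolRule(tool_name="plan"),
        ToolRule(tool_name="plan", children=["search"]),
        TerminalToolRule(tool_name="send_message"),
    ]
)

assert solver.get_allowed_tool_names() == ["plan"]    # init rules apply before any call
solver.update_tool_usage("plan")
assert solver.get_allowed_tool_names() == ["search"]  # children of the last tool called
assert solver.is_terminal_tool("send_message")        # would end the agent loop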
letta/llm_api/helpers.py CHANGED
@@ -16,9 +16,11 @@ def convert_to_structured_output(openai_function: dict) -> dict:

     See: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
     """
+    description = openai_function["description"] if "description" in openai_function else ""
+
     structured_output = {
         "name": openai_function["name"],
-        "description": openai_function["description"],
+        "description": description,
         "strict": True,
         "parameters": {"type": "object", "properties": {}, "additionalProperties": False, "required": []},
     }
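The new guard matters because tool schemas without a top-level description previously raised a KeyError when converted for OpenAI's structured-output mode; they now fall back to an empty string. A rough check (assumes the remainder of the conversion logic, not shown in this hunk, tolerates this input):

fn = {
    "name": "ping",
    "parameters": {
        "type": "object",
        "properties": {"host": {"type": "string", "description": "Host to ping."}},
        "required": ["host"],
    },
}
out = convert_to_structured_output(fn)  # fn has no "description" key
assert out["description"] == ""
assert out["strict"] is True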
letta/llm_api/llm_api_tools.py CHANGED
@@ -106,7 +106,7 @@ def create(
     messages: List[Message],
     user_id: Optional[str] = None,  # option UUID to associate request with
     functions: Optional[list] = None,
-    functions_python: Optional[list] = None,
+    functions_python: Optional[dict] = None,
     function_call: str = "auto",
     # hint
     first_message: bool = False,
@@ -140,7 +140,6 @@ def create(
             raise ValueError(f"OpenAI key is missing from letta config file")

         data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
-
         if stream:  # Client requested token streaming
             data.stream = True
             assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
letta/llm_api/openai.py CHANGED
@@ -530,7 +530,12 @@ def openai_chat_completions_request(
         data.pop("tools")
         data.pop("tool_choice", None)  # extra safe, should exist always (default="auto")

+    if "tools" in data:
+        for tool in data["tools"]:
+            tool["function"] = convert_to_structured_output(tool["function"])
+
     response_json = make_post_request(url, headers, data)
+
     return ChatCompletionResponse(**response_json)