letta-nightly 0.5.5.dev20241122170833__py3-none-any.whl → 0.6.0.dev20241204051808__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic; see the package registry page for more details.

Files changed (70)
  1. letta/__init__.py +2 -2
  2. letta/agent.py +155 -166
  3. letta/agent_store/chroma.py +2 -0
  4. letta/agent_store/db.py +1 -1
  5. letta/cli/cli.py +12 -8
  6. letta/cli/cli_config.py +1 -1
  7. letta/client/client.py +765 -137
  8. letta/config.py +2 -2
  9. letta/constants.py +10 -14
  10. letta/errors.py +12 -0
  11. letta/functions/function_sets/base.py +38 -1
  12. letta/functions/functions.py +40 -57
  13. letta/functions/helpers.py +0 -4
  14. letta/functions/schema_generator.py +279 -18
  15. letta/helpers/tool_rule_solver.py +6 -5
  16. letta/llm_api/helpers.py +99 -5
  17. letta/llm_api/openai.py +8 -2
  18. letta/local_llm/utils.py +13 -6
  19. letta/log.py +7 -9
  20. letta/main.py +1 -1
  21. letta/metadata.py +53 -38
  22. letta/o1_agent.py +1 -4
  23. letta/orm/__init__.py +2 -0
  24. letta/orm/block.py +7 -3
  25. letta/orm/blocks_agents.py +32 -0
  26. letta/orm/errors.py +8 -0
  27. letta/orm/mixins.py +8 -0
  28. letta/orm/organization.py +8 -1
  29. letta/orm/sandbox_config.py +56 -0
  30. letta/orm/sqlalchemy_base.py +68 -10
  31. letta/persistence_manager.py +1 -0
  32. letta/schemas/agent.py +57 -52
  33. letta/schemas/block.py +85 -26
  34. letta/schemas/blocks_agents.py +32 -0
  35. letta/schemas/enums.py +14 -0
  36. letta/schemas/letta_base.py +10 -1
  37. letta/schemas/letta_request.py +11 -23
  38. letta/schemas/letta_response.py +1 -2
  39. letta/schemas/memory.py +41 -76
  40. letta/schemas/message.py +3 -3
  41. letta/schemas/sandbox_config.py +114 -0
  42. letta/schemas/tool.py +37 -1
  43. letta/schemas/tool_rule.py +13 -5
  44. letta/server/rest_api/app.py +5 -4
  45. letta/server/rest_api/interface.py +12 -19
  46. letta/server/rest_api/routers/openai/assistants/threads.py +2 -3
  47. letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +0 -2
  48. letta/server/rest_api/routers/v1/__init__.py +4 -9
  49. letta/server/rest_api/routers/v1/agents.py +145 -61
  50. letta/server/rest_api/routers/v1/blocks.py +50 -5
  51. letta/server/rest_api/routers/v1/sandbox_configs.py +127 -0
  52. letta/server/rest_api/routers/v1/sources.py +8 -1
  53. letta/server/rest_api/routers/v1/tools.py +139 -13
  54. letta/server/rest_api/utils.py +6 -0
  55. letta/server/server.py +397 -340
  56. letta/server/static_files/assets/index-9fa459a2.js +1 -1
  57. letta/services/block_manager.py +23 -2
  58. letta/services/blocks_agents_manager.py +106 -0
  59. letta/services/per_agent_lock_manager.py +18 -0
  60. letta/services/sandbox_config_manager.py +256 -0
  61. letta/services/tool_execution_sandbox.py +352 -0
  62. letta/services/tool_manager.py +16 -22
  63. letta/services/tool_sandbox_env/.gitkeep +0 -0
  64. letta/settings.py +4 -0
  65. letta/utils.py +0 -7
  66. {letta_nightly-0.5.5.dev20241122170833.dist-info → letta_nightly-0.6.0.dev20241204051808.dist-info}/METADATA +8 -6
  67. {letta_nightly-0.5.5.dev20241122170833.dist-info → letta_nightly-0.6.0.dev20241204051808.dist-info}/RECORD +70 -60
  68. {letta_nightly-0.5.5.dev20241122170833.dist-info → letta_nightly-0.6.0.dev20241204051808.dist-info}/LICENSE +0 -0
  69. {letta_nightly-0.5.5.dev20241122170833.dist-info → letta_nightly-0.6.0.dev20241204051808.dist-info}/WHEEL +0 -0
  70. {letta_nightly-0.5.5.dev20241122170833.dist-info → letta_nightly-0.6.0.dev20241204051808.dist-info}/entry_points.txt +0 -0
letta/config.py CHANGED
@@ -16,7 +16,7 @@ from letta.constants import (
16
16
  LETTA_DIR,
17
17
  )
18
18
  from letta.log import get_logger
19
- from letta.schemas.agent import AgentState
19
+ from letta.schemas.agent import PersistedAgentState
20
20
  from letta.schemas.embedding_config import EmbeddingConfig
21
21
  from letta.schemas.llm_config import LLMConfig
22
22
 
@@ -434,7 +434,7 @@ class AgentConfig:
434
434
  json.dump(vars(self), f, indent=4)
435
435
 
436
436
  def to_agent_state(self):
437
- return AgentState(
437
+ return PersistedAgentState(
438
438
  name=self.name,
439
439
  preset=self.preset,
440
440
  persona=self.persona,
letta/constants.py CHANGED
@@ -36,14 +36,10 @@ DEFAULT_PERSONA = "sam_pov"
36
36
  DEFAULT_HUMAN = "basic"
37
37
  DEFAULT_PRESET = "memgpt_chat"
38
38
 
39
- # Tools
40
- BASE_TOOLS = [
41
- "send_message",
42
- "conversation_search",
43
- "conversation_search_date",
44
- "archival_memory_insert",
45
- "archival_memory_search",
46
- ]
39
+ # Base tools that cannot be edited, as they access agent state directly
40
+ BASE_TOOLS = ["send_message", "conversation_search", "conversation_search_date", "archival_memory_insert", "archival_memory_search"]
41
+ # Base memory tools CAN be edited, and are added by default by the server
42
+ BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
47
43
 
48
44
  # The name of the tool used to send message to the user
49
45
  # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...)
@@ -133,9 +129,13 @@ MESSAGE_SUMMARY_REQUEST_ACK = "Understood, I will respond with a summary of the
133
129
  # These serve as in-context examples of how to use functions / what user messages look like
134
130
  MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST = 3
135
131
 
132
+ # Maximum length of an error message
133
+ MAX_ERROR_MESSAGE_CHAR_LIMIT = 500
134
+
136
135
  # Default memory limits
137
- CORE_MEMORY_PERSONA_CHAR_LIMIT = 2000
138
- CORE_MEMORY_HUMAN_CHAR_LIMIT = 2000
136
+ CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 5000
137
+ CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 5000
138
+ CORE_MEMORY_BLOCK_CHAR_LIMIT: int = 5000
139
139
 
140
140
  # Function return limits
141
141
  FUNCTION_RETURN_CHAR_LIMIT = 6000 # ~300 words
@@ -155,9 +155,5 @@ FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed, ret
155
155
 
156
156
  RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
157
157
 
158
- # TODO Is this config or constant?
159
- CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 2000
160
- CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 2000
161
-
162
158
  MAX_FILENAME_LENGTH = 255
163
159
  RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"}
letta/errors.py CHANGED
@@ -10,6 +10,18 @@ class LettaError(Exception):
10
10
  """Base class for all Letta related errors."""
11
11
 
12
12
 
13
+ class LettaToolCreateError(LettaError):
14
+ """Error raised when a tool cannot be created."""
15
+
16
+ default_error_message = "Error creating tool."
17
+
18
+ def __init__(self, message=None):
19
+ if message is None:
20
+ message = self.default_error_message
21
+ self.message = message
22
+ super().__init__(self.message)
23
+
24
+
13
25
  class LLMError(LettaError):
14
26
  pass
15
27
 
@@ -11,7 +11,7 @@ from letta.constants import MAX_PAUSE_HEARTBEATS
11
11
  # If the function fails, throw an exception
12
12
 
13
13
 
14
- def send_message(self: Agent, message: str) -> Optional[str]:
14
+ def send_message(self: "Agent", message: str) -> Optional[str]:
15
15
  """
16
16
  Sends a message to the human user.
17
17
 
@@ -172,3 +172,40 @@ def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) ->
172
172
  results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results]
173
173
  results_str = f"{results_pref} {json_dumps(results_formatted)}"
174
174
  return results_str
175
+
176
+
177
+ def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore
178
+ """
179
+ Append to the contents of core memory.
180
+
181
+ Args:
182
+ label (str): Section of the memory to be edited (persona or human).
183
+ content (str): Content to write to the memory. All unicode (including emojis) are supported.
184
+
185
+ Returns:
186
+ Optional[str]: None is always returned as this function does not produce a response.
187
+ """
188
+ current_value = str(agent_state.memory.get_block(label).value)
189
+ new_value = current_value + "\n" + str(content)
190
+ agent_state.memory.update_block_value(label=label, value=new_value)
191
+ return None
192
+
193
+
194
+ def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
195
+ """
196
+ Replace the contents of core memory. To delete memories, use an empty string for new_content.
197
+
198
+ Args:
199
+ label (str): Section of the memory to be edited (persona or human).
200
+ old_content (str): String to replace. Must be an exact match.
201
+ new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
202
+
203
+ Returns:
204
+ Optional[str]: None is always returned as this function does not produce a response.
205
+ """
206
+ current_value = str(agent_state.memory.get_block(label).value)
207
+ if old_content not in current_value:
208
+ raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
209
+ new_value = current_value.replace(str(old_content), str(new_content))
210
+ agent_state.memory.update_block_value(label=label, value=new_value)
211
+ return None
@@ -1,35 +1,61 @@
1
- import importlib
2
1
  import inspect
3
- import os
4
2
  from textwrap import dedent # remove indentation
5
3
  from types import ModuleType
6
- from typing import Optional, List
4
+ from typing import Dict, List, Optional
7
5
 
8
- from letta.constants import CLI_WARNING_PREFIX
6
+ from letta.errors import LettaToolCreateError
9
7
  from letta.functions.schema_generator import generate_schema
10
8
 
11
9
 
12
10
  def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict:
13
- # auto-generate openai schema
11
+ """Derives the OpenAI JSON schema for a given function source code.
12
+
13
+ First, attempts to execute the source code in a custom environment with only the necessary imports.
14
+ Then, it generates the schema from the function's docstring and signature.
15
+ """
14
16
  try:
15
17
  # Define a custom environment with necessary imports
16
18
  env = {
17
- "Optional": Optional, # Add any other required imports here
18
- "List": List
19
+ "Optional": Optional,
20
+ "List": List,
21
+ "Dict": Dict,
22
+ # To support Pydantic models
23
+ # "BaseModel": BaseModel,
24
+ # "Field": Field,
19
25
  }
20
-
21
26
  env.update(globals())
27
+
28
+ # print("About to execute source code...")
22
29
  exec(source_code, env)
30
+ # print("Source code executed successfully")
23
31
 
24
- # get available functions
25
- functions = [f for f in env if callable(env[f])]
32
+ functions = [f for f in env if callable(env[f]) and not f.startswith("__")]
33
+ if not functions:
34
+ raise LettaToolCreateError("No callable functions found in source code")
26
35
 
27
- # TODO: not sure if this always works
36
+ # print(f"Found functions: {functions}")
28
37
  func = env[functions[-1]]
29
- json_schema = generate_schema(func, name=name)
30
- return json_schema
38
+
39
+ if not hasattr(func, "__doc__") or not func.__doc__:
40
+ raise LettaToolCreateError(f"Function {func.__name__} missing docstring")
41
+
42
+ # print("About to generate schema...")
43
+ try:
44
+ schema = generate_schema(func, name=name)
45
+ # print("Schema generated successfully")
46
+ return schema
47
+ except TypeError as e:
48
+ raise LettaToolCreateError(f"Type error in schema generation: {str(e)}")
49
+ except ValueError as e:
50
+ raise LettaToolCreateError(f"Value error in schema generation: {str(e)}")
51
+ except Exception as e:
52
+ raise LettaToolCreateError(f"Unexpected error in schema generation: {str(e)}")
53
+
31
54
  except Exception as e:
32
- raise RuntimeError(f"Failed to execute source code: {e}")
55
+ import traceback
56
+
57
+ traceback.print_exc()
58
+ raise LettaToolCreateError(f"Schema generation failed: {str(e)}") from e
33
59
 
34
60
 
35
61
  def parse_source_code(func) -> str:
@@ -61,46 +87,3 @@ def load_function_set(module: ModuleType) -> dict:
61
87
  if len(function_dict) == 0:
62
88
  raise ValueError(f"No functions found in module {module}")
63
89
  return function_dict
64
-
65
-
66
- def validate_function(module_name, module_full_path):
67
- try:
68
- file = os.path.basename(module_full_path)
69
- spec = importlib.util.spec_from_file_location(module_name, module_full_path)
70
- module = importlib.util.module_from_spec(spec)
71
- spec.loader.exec_module(module)
72
- except ModuleNotFoundError as e:
73
- # Handle missing module imports
74
- missing_package = str(e).split("'")[1] # Extract the name of the missing package
75
- print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
76
- return (
77
- False,
78
- f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to Letta.",
79
- )
80
- except SyntaxError as e:
81
- # Handle syntax errors in the module
82
- return False, f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}"
83
- except Exception as e:
84
- # Handle other general exceptions
85
- return False, f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}"
86
-
87
- return True, None
88
-
89
-
90
- def load_function_file(filepath: str) -> dict:
91
- file = os.path.basename(filepath)
92
- module_name = file[:-3] # Remove '.py' from filename
93
- try:
94
- spec = importlib.util.spec_from_file_location(module_name, filepath)
95
- module = importlib.util.module_from_spec(spec)
96
- spec.loader.exec_module(module)
97
- except ModuleNotFoundError as e:
98
- # Handle missing module imports
99
- missing_package = str(e).split("'")[1] # Extract the name of the missing package
100
- print(f"{CLI_WARNING_PREFIX}skipped loading python file '{filepath}'!")
101
- print(
102
- f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to Letta."
103
- )
104
- # load all functions in the module
105
- function_dict = load_function_set(module)
106
- return function_dict
@@ -13,8 +13,6 @@ def generate_composio_tool_wrapper(action: "ActionType") -> tuple[str, str]:
13
13
 
14
14
  wrapper_function_str = f"""
15
15
  def {func_name}(**kwargs):
16
- if 'self' in kwargs:
17
- del kwargs['self']
18
16
  from composio import Action, App, Tag
19
17
  from composio_langchain import ComposioToolSet
20
18
 
@@ -46,8 +44,6 @@ def generate_langchain_tool_wrapper(
46
44
  # Combine all parts into the wrapper function
47
45
  wrapper_function_str = f"""
48
46
  def {func_name}(**kwargs):
49
- if 'self' in kwargs:
50
- del kwargs['self']
51
47
  import importlib
52
48
  {import_statement}
53
49
  {extra_module_imports}
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Any, Dict, Optional, Type, Union, get_args, get_origin
2
+ from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin
3
3
 
4
4
  from docstring_parser import parse
5
5
  from pydantic import BaseModel
@@ -22,7 +22,7 @@ def optional_length(annotation):
22
22
  raise ValueError("The annotation is not an Optional type")
23
23
 
24
24
 
25
- def type_to_json_schema_type(py_type):
25
+ def type_to_json_schema_type(py_type) -> dict:
26
26
  """
27
27
  Maps a Python type to a JSON schema type.
28
28
  Specifically handles typing.Optional and common Python types.
@@ -36,22 +36,87 @@ def type_to_json_schema_type(py_type):
36
36
  # Extract and map the inner type
37
37
  return type_to_json_schema_type(type_args[0])
38
38
 
39
+ # Handle Union types (except Optional which is handled above)
40
+ if get_origin(py_type) is Union:
41
+ # TODO support mapping Unions to anyOf
42
+ raise NotImplementedError("General Union types are not yet supported")
43
+
44
+ # Handle array types
45
+ origin = get_origin(py_type)
46
+ if py_type == list or origin in (list, List):
47
+ args = get_args(py_type)
48
+
49
+ if args and inspect.isclass(args[0]) and issubclass(args[0], BaseModel):
50
+ # If it's a list of Pydantic models, return an array with the model schema as items
51
+ return {
52
+ "type": "array",
53
+ "items": pydantic_model_to_json_schema(args[0]),
54
+ }
55
+
56
+ # Otherwise, recursively call the basic type checker
57
+ return {
58
+ "type": "array",
59
+ # get the type of the items in the list
60
+ "items": type_to_json_schema_type(args[0]),
61
+ }
62
+
63
+ # Handle object types
64
+ if py_type == dict or origin in (dict, Dict):
65
+ args = get_args(py_type)
66
+ if not args:
67
+ # Generic dict without type arguments
68
+ return {
69
+ "type": "object",
70
+ # "properties": {}
71
+ }
72
+ else:
73
+ raise ValueError(
74
+ f"Dictionary types {py_type} with nested type arguments are not supported (consider using a Pydantic model instead)"
75
+ )
76
+
77
+ # NOTE: the below code works for generic JSON schema parsing, but there's a problem with the key inference
78
+ # when it comes to OpenAI function schema generation so it doesn't make sense to allow for dict[str, Any] type hints
79
+ # key_type, value_type = args
80
+
81
+ # # Ensure dict keys are strings
82
+ # # Otherwise there's no JSON schema equivalent
83
+ # if key_type != str:
84
+ # raise ValueError("Dictionary keys must be strings for OpenAI function schema compatibility")
85
+
86
+ # # Handle value type to determine property schema
87
+ # value_schema = {}
88
+ # if inspect.isclass(value_type) and issubclass(value_type, BaseModel):
89
+ # value_schema = pydantic_model_to_json_schema(value_type)
90
+ # else:
91
+ # value_schema = type_to_json_schema_type(value_type)
92
+
93
+ # # NOTE: the problem lies here - the key is always "key_placeholder"
94
+ # return {"type": "object", "properties": {"key_placeholder": value_schema}}
95
+
96
+ # Handle direct Pydantic models
97
+ if inspect.isclass(py_type) and issubclass(py_type, BaseModel):
98
+ return pydantic_model_to_json_schema(py_type)
99
+
39
100
  # Mapping of Python types to JSON schema types
40
101
  type_map = {
102
+ # Basic types
103
+ # Optional, Union, and collections are handled above ^
41
104
  int: "integer",
42
105
  str: "string",
43
106
  bool: "boolean",
44
107
  float: "number",
45
- list[str]: "array",
46
- # Add more mappings as needed
108
+ None: "null",
47
109
  }
48
110
  if py_type not in type_map:
49
- raise ValueError(f"Python type {py_type} has no corresponding JSON schema type")
50
-
51
- return type_map.get(py_type, "string") # Default to "string" if type not in map
111
+ raise ValueError(f"Python type {py_type} has no corresponding JSON schema type - full map: {type_map}")
112
+ else:
113
+ return {"type": type_map[py_type]}
52
114
 
53
115
 
54
- def pydantic_model_to_open_ai(model):
116
+ def pydantic_model_to_open_ai(model: Type[BaseModel]) -> dict:
117
+ """
118
+ Converts a Pydantic model as a singular arg to a JSON schema object for use in OpenAI function calling.
119
+ """
55
120
  schema = model.model_json_schema()
56
121
  docstring = parse(model.__doc__ or "")
57
122
  parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
@@ -66,7 +131,7 @@ def pydantic_model_to_open_ai(model):
66
131
  if docstring.short_description:
67
132
  schema["description"] = docstring.short_description
68
133
  else:
69
- raise
134
+ raise ValueError(f"No description found in docstring or description field (model: {model}, docstring: {docstring})")
70
135
 
71
136
  return {
72
137
  "name": schema["title"],
@@ -75,6 +140,159 @@ def pydantic_model_to_open_ai(model):
75
140
  }
76
141
 
77
142
 
143
+ def pydantic_model_to_json_schema(model: Type[BaseModel]) -> dict:
144
+ """
145
+ Converts a Pydantic model (as an arg that already is annotated) to a JSON schema object for use in OpenAI function calling.
146
+
147
+ An example of a Pydantic model as an arg:
148
+
149
+ class Step(BaseModel):
150
+ name: str = Field(
151
+ ...,
152
+ description="Name of the step.",
153
+ )
154
+ key: str = Field(
155
+ ...,
156
+ description="Unique identifier for the step.",
157
+ )
158
+ description: str = Field(
159
+ ...,
160
+ description="An exhaustic description of what this step is trying to achieve and accomplish.",
161
+ )
162
+
163
+ def create_task_plan(steps: list[Step]):
164
+ '''
165
+ Creates a task plan for the current task.
166
+
167
+ Args:
168
+ steps: List of steps to add to the task plan.
169
+ ...
170
+
171
+ Should result in:
172
+ {
173
+ "name": "create_task_plan",
174
+ "description": "Creates a task plan for the current task.",
175
+ "parameters": {
176
+ "type": "object",
177
+ "properties": {
178
+ "steps": { # <= this is the name of the arg
179
+ "type": "object",
180
+ "description": "List of steps to add to the task plan.",
181
+ "properties": {
182
+ "name": {
183
+ "type": "str",
184
+ "description": "Name of the step.",
185
+ },
186
+ "key": {
187
+ "type": "str",
188
+ "description": "Unique identifier for the step.",
189
+ },
190
+ "description": {
191
+ "type": "str",
192
+ "description": "An exhaustic description of what this step is trying to achieve and accomplish.",
193
+ },
194
+ },
195
+ "required": ["name", "key", "description"],
196
+ }
197
+ },
198
+ "required": ["steps"],
199
+ }
200
+ }
201
+
202
+ Specifically, the result of pydantic_model_to_json_schema(steps) (where `steps` is an instance of BaseModel) is:
203
+ {
204
+ "type": "object",
205
+ "properties": {
206
+ "name": {
207
+ "type": "str",
208
+ "description": "Name of the step."
209
+ },
210
+ "key": {
211
+ "type": "str",
212
+ "description": "Unique identifier for the step."
213
+ },
214
+ "description": {
215
+ "type": "str",
216
+ "description": "An exhaustic description of what this step is trying to achieve and accomplish."
217
+ },
218
+ },
219
+ "required": ["name", "key", "description"],
220
+ }
221
+ """
222
+ schema = model.model_json_schema()
223
+
224
+ def clean_property(prop: dict) -> dict:
225
+ """Clean up a property schema to match desired format"""
226
+
227
+ if "description" not in prop:
228
+ raise ValueError(f"Property {prop} lacks a 'description' key")
229
+
230
+ return {
231
+ "type": "string" if prop["type"] == "string" else prop["type"],
232
+ "description": prop["description"],
233
+ }
234
+
235
+ def resolve_ref(ref: str, schema: dict) -> dict:
236
+ """Resolve a $ref reference in the schema"""
237
+ if not ref.startswith("#/$defs/"):
238
+ raise ValueError(f"Unexpected reference format: {ref}")
239
+
240
+ model_name = ref.split("/")[-1]
241
+ if model_name not in schema.get("$defs", {}):
242
+ raise ValueError(f"Reference {model_name} not found in schema definitions")
243
+
244
+ return schema["$defs"][model_name]
245
+
246
+ def clean_schema(schema_part: dict, full_schema: dict) -> dict:
247
+ """Clean up a schema part, handling references and nested structures"""
248
+ # Handle $ref
249
+ if "$ref" in schema_part:
250
+ schema_part = resolve_ref(schema_part["$ref"], full_schema)
251
+
252
+ if "type" not in schema_part:
253
+ raise ValueError(f"Schema part lacks a 'type' key: {schema_part}")
254
+
255
+ # Handle array type
256
+ if schema_part["type"] == "array":
257
+ items_schema = schema_part["items"]
258
+ if "$ref" in items_schema:
259
+ items_schema = resolve_ref(items_schema["$ref"], full_schema)
260
+ return {"type": "array", "items": clean_schema(items_schema, full_schema), "description": schema_part.get("description", "")}
261
+
262
+ # Handle object type
263
+ if schema_part["type"] == "object":
264
+ if "properties" not in schema_part:
265
+ raise ValueError(f"Object schema lacks 'properties' key: {schema_part}")
266
+
267
+ properties = {}
268
+ for name, prop in schema_part["properties"].items():
269
+ if "items" in prop: # Handle arrays
270
+ if "description" not in prop:
271
+ raise ValueError(f"Property {prop} lacks a 'description' key")
272
+ properties[name] = {
273
+ "type": "array",
274
+ "items": clean_schema(prop["items"], full_schema),
275
+ "description": prop["description"],
276
+ }
277
+ else:
278
+ properties[name] = clean_property(prop)
279
+
280
+ pydantic_model_schema_dict = {
281
+ "type": "object",
282
+ "properties": properties,
283
+ "required": schema_part.get("required", []),
284
+ }
285
+ if "description" in schema_part:
286
+ pydantic_model_schema_dict["description"] = schema_part["description"]
287
+
288
+ return pydantic_model_schema_dict
289
+
290
+ # Handle primitive types
291
+ return clean_property(schema_part)
292
+
293
+ return clean_schema(schema_part=schema, full_schema=schema)
294
+
295
+
78
296
  def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None) -> dict:
79
297
  # Get the signature of the function
80
298
  sig = inspect.signature(function)
@@ -93,9 +311,14 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
93
311
 
94
312
  for param in sig.parameters.values():
95
313
  # Exclude 'self' parameter
314
+ # TODO: eventually remove this (only applies to BASE_TOOLS)
96
315
  if param.name == "self":
97
316
  continue
98
317
 
318
+ # exclude 'agent_state' parameter
319
+ if param.name == "agent_state":
320
+ continue
321
+
99
322
  # Assert that the parameter has a type annotation
100
323
  if param.annotation == inspect.Parameter.empty:
101
324
  raise TypeError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a type annotation")
@@ -107,28 +330,66 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
107
330
  if not param_doc or not param_doc.description:
108
331
  raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")
109
332
 
110
- if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel):
111
- schema["parameters"]["properties"][param.name] = pydantic_model_to_open_ai(param.annotation)
333
+ # If the parameter is a pydantic model, we need to unpack the Pydantic model type into a JSON schema object
334
+ # if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel):
335
+ if (
336
+ (inspect.isclass(param.annotation) or inspect.isclass(get_origin(param.annotation) or param.annotation))
337
+ and not get_origin(param.annotation)
338
+ and issubclass(param.annotation, BaseModel)
339
+ ):
340
+ # print("Generating schema for pydantic model:", param.annotation)
341
+ # Extract the properties from the pydantic model
342
+ schema["parameters"]["properties"][param.name] = pydantic_model_to_json_schema(param.annotation)
343
+ schema["parameters"]["properties"][param.name]["description"] = param_doc.description
344
+
345
+ # Otherwise, we convert the Python typing to JSON schema types
346
+ # NOTE: important - if a dict or list, the internal type can be a Pydantic model itself
347
+ # however in that
112
348
  else:
113
- # Add parameter details to the schema
349
+ # print("Generating schema for non-pydantic model:", param.annotation)
350
+ # Grab the description for the parameter from the extended docstring
351
+ # If it doesn't exist, we should raise an error
114
352
  param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
115
- schema["parameters"]["properties"][param.name] = {
116
- # "type": "string" if param.annotation == str else str(param.annotation),
117
- "type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
118
- "description": param_doc.description,
119
- }
120
- if param.default == inspect.Parameter.empty:
353
+ if not param_doc:
354
+ raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")
355
+ elif not isinstance(param_doc.description, str):
356
+ raise ValueError(
357
+ f"Parameter '{param.name}' in function '{function.__name__}' has a description in the docstring that is not a string (type: {type(param_doc.description)})"
358
+ )
359
+ else:
360
+ # If it's a string or a basic type, then all you need is: (1) type, (2) description
361
+ # If it's a more complex type, then you also need either:
362
+ # - for array, you need "items", each of which has "type"
363
+ # - for a dict, you need "properties", which has keys which each have "type"
364
+ if param.annotation != inspect.Parameter.empty:
365
+ param_generated_schema = type_to_json_schema_type(param.annotation)
366
+ else:
367
+ # TODO why are we inferring here?
368
+ param_generated_schema = {"type": "string"}
369
+
370
+ # Add in the description
371
+ param_generated_schema["description"] = param_doc.description
372
+
373
+ # Add the schema to the function arg key
374
+ schema["parameters"]["properties"][param.name] = param_generated_schema
375
+
376
+ # If the parameter doesn't have a default value, it is required (so we need to add it to the required list)
377
+ if param.default == inspect.Parameter.empty and not is_optional(param.annotation):
121
378
  schema["parameters"]["required"].append(param.name)
122
379
 
380
+ # TODO what's going on here?
381
+ # If the parameter is a list of strings we need to hard cast to "string" instead of `str`
123
382
  if get_origin(param.annotation) is list:
124
383
  if get_args(param.annotation)[0] is str:
125
384
  schema["parameters"]["properties"][param.name]["items"] = {"type": "string"}
126
385
 
386
+ # TODO is this not duplicating the other append directly above?
127
387
  if param.annotation == inspect.Parameter.empty:
128
388
  schema["parameters"]["required"].append(param.name)
129
389
 
130
390
  # append the heartbeat
131
391
  # TODO: don't hard-code
392
+ # TODO: if terminal, don't include this
132
393
  if function.__name__ not in ["send_message", "pause_heartbeats"]:
133
394
  schema["parameters"]["properties"]["request_heartbeat"] = {
134
395
  "type": "boolean",