letta-nightly 0.5.4.dev20241128000451__py3-none-any.whl → 0.6.0.dev20241204051808__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic; see this release's page on the package registry for more details.

letta/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.5.4"
1
+ __version__ = "0.6.0"
2
2
 
3
3
  # import clients
4
4
  from letta.client.client import LocalClient, RESTClient, create_client
letta/agent.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import datetime
2
2
  import inspect
3
+ import time
3
4
  import traceback
4
5
  import warnings
5
6
  from abc import ABC, abstractmethod
@@ -566,60 +567,60 @@ class Agent(BaseAgent):
566
567
  self,
567
568
  message_sequence: List[Message],
568
569
  function_call: str = "auto",
569
- first_message: bool = False, # hint
570
+ first_message: bool = False,
570
571
  stream: bool = False, # TODO move to config?
571
- fail_on_empty_response: bool = False,
572
572
  empty_response_retry_limit: int = 3,
573
+ backoff_factor: float = 0.5, # delay multiplier for exponential backoff
574
+ max_delay: float = 10.0, # max delay between retries
573
575
  ) -> ChatCompletionResponse:
574
- """Get response from LLM API"""
575
- # Get the allowed tools based on the ToolRulesSolver state
576
+ """Get response from LLM API with robust retry mechanism."""
577
+
576
578
  allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
579
+ allowed_functions = (
580
+ self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names]
581
+ )
577
582
 
578
- if not allowed_tool_names:
579
- # if it's empty, any available tools are fair game
580
- allowed_functions = self.functions
581
- else:
582
- allowed_functions = [func for func in self.functions if func["name"] in allowed_tool_names]
583
+ for attempt in range(1, empty_response_retry_limit + 1):
584
+ try:
585
+ response = create(
586
+ llm_config=self.agent_state.llm_config,
587
+ messages=message_sequence,
588
+ user_id=self.agent_state.user_id,
589
+ functions=allowed_functions,
590
+ functions_python=self.functions_python,
591
+ function_call=function_call,
592
+ first_message=first_message,
593
+ stream=stream,
594
+ stream_interface=self.interface,
595
+ )
583
596
 
584
- try:
585
- response = create(
586
- # agent_state=self.agent_state,
587
- llm_config=self.agent_state.llm_config,
588
- messages=message_sequence,
589
- user_id=self.agent_state.user_id,
590
- functions=allowed_functions,
591
- functions_python=self.functions_python,
592
- function_call=function_call,
593
- # hint
594
- first_message=first_message,
595
- # streaming
596
- stream=stream,
597
- stream_interface=self.interface,
598
- )
597
+ # These bottom two are retryable
598
+ if len(response.choices) == 0 or response.choices[0] is None:
599
+ raise ValueError(f"API call returned an empty message: {response}")
599
600
 
600
- if len(response.choices) == 0 or response.choices[0] is None:
601
- empty_api_err_message = f"API call didn't return a message: {response}"
602
- if fail_on_empty_response or empty_response_retry_limit == 0:
603
- raise Exception(empty_api_err_message)
604
- else:
605
- # Decrement retry limit and try again
606
- warnings.warn(empty_api_err_message)
607
- return self._get_ai_reply(
608
- message_sequence, function_call, first_message, stream, fail_on_empty_response, empty_response_retry_limit - 1
609
- )
601
+ if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]:
602
+ if response.choices[0].finish_reason == "length":
603
+ # This is not retryable, hence RuntimeError v.s. ValueError
604
+ raise RuntimeError("Finish reason was length (maximum context length)")
605
+ else:
606
+ raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}")
607
+
608
+ return response
610
609
 
611
- # special case for 'length'
612
- if response.choices[0].finish_reason == "length":
613
- raise Exception("Finish reason was length (maximum context length)")
610
+ except ValueError as ve:
611
+ if attempt >= empty_response_retry_limit:
612
+ warnings.warn(f"Retry limit reached. Final error: {ve}")
613
+ break
614
+ else:
615
+ delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay)
616
+ warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...")
617
+ time.sleep(delay)
614
618
 
615
- # catches for soft errors
616
- if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]:
617
- raise Exception(f"API call finish with bad finish reason: {response}")
619
+ except Exception as e:
620
+ # For non-retryable errors, exit immediately
621
+ raise e
618
622
 
619
- # unpack with response.choices[0].message.content
620
- return response
621
- except Exception as e:
622
- raise e
623
+ raise Exception("Retries exhausted and no valid response received.")
623
624
 
624
625
  def _handle_ai_response(
625
626
  self,
letta/cli/cli.py CHANGED
@@ -10,7 +10,12 @@ import letta.utils as utils
10
10
  from letta import create_client
11
11
  from letta.agent import Agent, save_agent
12
12
  from letta.config import LettaConfig
13
- from letta.constants import CLI_WARNING_PREFIX, LETTA_DIR, MIN_CONTEXT_WINDOW
13
+ from letta.constants import (
14
+ CLI_WARNING_PREFIX,
15
+ CORE_MEMORY_BLOCK_CHAR_LIMIT,
16
+ LETTA_DIR,
17
+ MIN_CONTEXT_WINDOW,
18
+ )
14
19
  from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
15
20
  from letta.log import get_logger
16
21
  from letta.metadata import MetadataStore
@@ -91,7 +96,7 @@ def run(
91
96
  ] = None,
92
97
  core_memory_limit: Annotated[
93
98
  Optional[int], typer.Option(help="The character limit to each core-memory section (human/persona).")
94
- ] = 2000,
99
+ ] = CORE_MEMORY_BLOCK_CHAR_LIMIT,
95
100
  # other
96
101
  first: Annotated[bool, typer.Option(help="Use --first to send the first message in the sequence")] = False,
97
102
  strip_ui: Annotated[bool, typer.Option(help="Remove all the bells and whistles in CLI output (helpful for testing)")] = False,
@@ -220,7 +225,8 @@ def run(
220
225
 
221
226
  # create agent
222
227
  tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tool_names]
223
- letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools, user=client.user)
228
+ agent_state.tools = tools
229
+ letta_agent = Agent(agent_state=agent_state, interface=interface(), user=client.user)
224
230
 
225
231
  else: # create new agent
226
232
  # create new agent config: override defaults with args if provided
letta/client/client.py CHANGED
@@ -434,6 +434,7 @@ class RESTClient(AbstractClient):
434
434
  debug: bool = False,
435
435
  default_llm_config: Optional[LLMConfig] = None,
436
436
  default_embedding_config: Optional[EmbeddingConfig] = None,
437
+ headers: Optional[Dict] = None,
437
438
  ):
438
439
  """
439
440
  Initializes a new instance of Client class.
@@ -442,12 +443,16 @@ class RESTClient(AbstractClient):
442
443
  auto_save (bool): Whether to automatically save changes.
443
444
  user_id (str): The user ID.
444
445
  debug (bool): Whether to print debug information.
445
- default
446
+ default_llm_config (Optional[LLMConfig]): The default LLM configuration.
447
+ default_embedding_config (Optional[EmbeddingConfig]): The default embedding configuration.
448
+ headers (Optional[Dict]): The additional headers for the REST API.
446
449
  """
447
450
  super().__init__(debug=debug)
448
451
  self.base_url = base_url
449
452
  self.api_prefix = api_prefix
450
453
  self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
454
+ if headers:
455
+ self.headers.update(headers)
451
456
  self._default_llm_config = default_llm_config
452
457
  self._default_embedding_config = default_embedding_config
453
458
 
@@ -1,33 +1,61 @@
1
- import importlib
2
1
  import inspect
3
- import os
4
2
  from textwrap import dedent # remove indentation
5
3
  from types import ModuleType
6
4
  from typing import Dict, List, Optional
7
5
 
8
- from letta.constants import CLI_WARNING_PREFIX
9
6
  from letta.errors import LettaToolCreateError
10
7
  from letta.functions.schema_generator import generate_schema
11
8
 
12
9
 
13
10
  def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict:
14
- # auto-generate openai schema
11
+ """Derives the OpenAI JSON schema for a given function source code.
12
+
13
+ First, attempts to execute the source code in a custom environment with only the necessary imports.
14
+ Then, it generates the schema from the function's docstring and signature.
15
+ """
15
16
  try:
16
17
  # Define a custom environment with necessary imports
17
- env = {"Optional": Optional, "List": List, "Dict": Dict} # Add any other required imports here
18
-
18
+ env = {
19
+ "Optional": Optional,
20
+ "List": List,
21
+ "Dict": Dict,
22
+ # To support Pydantic models
23
+ # "BaseModel": BaseModel,
24
+ # "Field": Field,
25
+ }
19
26
  env.update(globals())
27
+
28
+ # print("About to execute source code...")
20
29
  exec(source_code, env)
30
+ # print("Source code executed successfully")
21
31
 
22
- # get available functions
23
- functions = [f for f in env if callable(env[f])]
32
+ functions = [f for f in env if callable(env[f]) and not f.startswith("__")]
33
+ if not functions:
34
+ raise LettaToolCreateError("No callable functions found in source code")
24
35
 
25
- # TODO: not sure if this always works
36
+ # print(f"Found functions: {functions}")
26
37
  func = env[functions[-1]]
27
- json_schema = generate_schema(func, name=name)
28
- return json_schema
38
+
39
+ if not hasattr(func, "__doc__") or not func.__doc__:
40
+ raise LettaToolCreateError(f"Function {func.__name__} missing docstring")
41
+
42
+ # print("About to generate schema...")
43
+ try:
44
+ schema = generate_schema(func, name=name)
45
+ # print("Schema generated successfully")
46
+ return schema
47
+ except TypeError as e:
48
+ raise LettaToolCreateError(f"Type error in schema generation: {str(e)}")
49
+ except ValueError as e:
50
+ raise LettaToolCreateError(f"Value error in schema generation: {str(e)}")
51
+ except Exception as e:
52
+ raise LettaToolCreateError(f"Unexpected error in schema generation: {str(e)}")
53
+
29
54
  except Exception as e:
30
- raise LettaToolCreateError(f"Failed to derive JSON schema from source code: {e}")
55
+ import traceback
56
+
57
+ traceback.print_exc()
58
+ raise LettaToolCreateError(f"Schema generation failed: {str(e)}") from e
31
59
 
32
60
 
33
61
  def parse_source_code(func) -> str:
@@ -59,46 +87,3 @@ def load_function_set(module: ModuleType) -> dict:
59
87
  if len(function_dict) == 0:
60
88
  raise ValueError(f"No functions found in module {module}")
61
89
  return function_dict
62
-
63
-
64
- def validate_function(module_name, module_full_path):
65
- try:
66
- file = os.path.basename(module_full_path)
67
- spec = importlib.util.spec_from_file_location(module_name, module_full_path)
68
- module = importlib.util.module_from_spec(spec)
69
- spec.loader.exec_module(module)
70
- except ModuleNotFoundError as e:
71
- # Handle missing module imports
72
- missing_package = str(e).split("'")[1] # Extract the name of the missing package
73
- print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
74
- return (
75
- False,
76
- f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to Letta.",
77
- )
78
- except SyntaxError as e:
79
- # Handle syntax errors in the module
80
- return False, f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}"
81
- except Exception as e:
82
- # Handle other general exceptions
83
- return False, f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}"
84
-
85
- return True, None
86
-
87
-
88
- def load_function_file(filepath: str) -> dict:
89
- file = os.path.basename(filepath)
90
- module_name = file[:-3] # Remove '.py' from filename
91
- try:
92
- spec = importlib.util.spec_from_file_location(module_name, filepath)
93
- module = importlib.util.module_from_spec(spec)
94
- spec.loader.exec_module(module)
95
- except ModuleNotFoundError as e:
96
- # Handle missing module imports
97
- missing_package = str(e).split("'")[1] # Extract the name of the missing package
98
- print(f"{CLI_WARNING_PREFIX}skipped loading python file '{filepath}'!")
99
- print(
100
- f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to Letta."
101
- )
102
- # load all functions in the module
103
- function_dict = load_function_set(module)
104
- return function_dict
letta/functions/schema_generator.py CHANGED
@@ -22,7 +22,7 @@ def optional_length(annotation):
22
22
  raise ValueError("The annotation is not an Optional type")
23
23
 
24
24
 
25
- def type_to_json_schema_type(py_type):
25
+ def type_to_json_schema_type(py_type) -> dict:
26
26
  """
27
27
  Maps a Python type to a JSON schema type.
28
28
  Specifically handles typing.Optional and common Python types.
@@ -36,36 +36,87 @@ def type_to_json_schema_type(py_type):
36
36
  # Extract and map the inner type
37
37
  return type_to_json_schema_type(type_args[0])
38
38
 
39
+ # Handle Union types (except Optional which is handled above)
40
+ if get_origin(py_type) is Union:
41
+ # TODO support mapping Unions to anyOf
42
+ raise NotImplementedError("General Union types are not yet supported")
43
+
44
+ # Handle array types
45
+ origin = get_origin(py_type)
46
+ if py_type == list or origin in (list, List):
47
+ args = get_args(py_type)
48
+
49
+ if args and inspect.isclass(args[0]) and issubclass(args[0], BaseModel):
50
+ # If it's a list of Pydantic models, return an array with the model schema as items
51
+ return {
52
+ "type": "array",
53
+ "items": pydantic_model_to_json_schema(args[0]),
54
+ }
55
+
56
+ # Otherwise, recursively call the basic type checker
57
+ return {
58
+ "type": "array",
59
+ # get the type of the items in the list
60
+ "items": type_to_json_schema_type(args[0]),
61
+ }
62
+
63
+ # Handle object types
64
+ if py_type == dict or origin in (dict, Dict):
65
+ args = get_args(py_type)
66
+ if not args:
67
+ # Generic dict without type arguments
68
+ return {
69
+ "type": "object",
70
+ # "properties": {}
71
+ }
72
+ else:
73
+ raise ValueError(
74
+ f"Dictionary types {py_type} with nested type arguments are not supported (consider using a Pydantic model instead)"
75
+ )
76
+
77
+ # NOTE: the below code works for generic JSON schema parsing, but there's a problem with the key inference
78
+ # when it comes to OpenAI function schema generation so it doesn't make sense to allow for dict[str, Any] type hints
79
+ # key_type, value_type = args
80
+
81
+ # # Ensure dict keys are strings
82
+ # # Otherwise there's no JSON schema equivalent
83
+ # if key_type != str:
84
+ # raise ValueError("Dictionary keys must be strings for OpenAI function schema compatibility")
85
+
86
+ # # Handle value type to determine property schema
87
+ # value_schema = {}
88
+ # if inspect.isclass(value_type) and issubclass(value_type, BaseModel):
89
+ # value_schema = pydantic_model_to_json_schema(value_type)
90
+ # else:
91
+ # value_schema = type_to_json_schema_type(value_type)
92
+
93
+ # # NOTE: the problem lies here - the key is always "key_placeholder"
94
+ # return {"type": "object", "properties": {"key_placeholder": value_schema}}
95
+
96
+ # Handle direct Pydantic models
97
+ if inspect.isclass(py_type) and issubclass(py_type, BaseModel):
98
+ return pydantic_model_to_json_schema(py_type)
99
+
39
100
  # Mapping of Python types to JSON schema types
40
101
  type_map = {
41
102
  # Basic types
103
+ # Optional, Union, and collections are handled above ^
42
104
  int: "integer",
43
105
  str: "string",
44
106
  bool: "boolean",
45
107
  float: "number",
46
- # Collections
47
- List[str]: "array",
48
- List[int]: "array",
49
- list: "array",
50
- tuple: "array",
51
- set: "array",
52
- # Dictionaries
53
- dict: "object",
54
- Dict[str, Any]: "object",
55
- # Special types
56
108
  None: "null",
57
- type(None): "null",
58
- # Optional types
59
- # Optional[str]: "string", # NOTE: caught above ^
60
- Union[str, None]: "string",
61
109
  }
62
110
  if py_type not in type_map:
63
111
  raise ValueError(f"Python type {py_type} has no corresponding JSON schema type - full map: {type_map}")
64
-
65
- return type_map.get(py_type, "string") # Default to "string" if type not in map
112
+ else:
113
+ return {"type": type_map[py_type]}
66
114
 
67
115
 
68
- def pydantic_model_to_open_ai(model):
116
+ def pydantic_model_to_open_ai(model: Type[BaseModel]) -> dict:
117
+ """
118
+ Converts a Pydantic model as a singular arg to a JSON schema object for use in OpenAI function calling.
119
+ """
69
120
  schema = model.model_json_schema()
70
121
  docstring = parse(model.__doc__ or "")
71
122
  parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
@@ -80,7 +131,7 @@ def pydantic_model_to_open_ai(model):
80
131
  if docstring.short_description:
81
132
  schema["description"] = docstring.short_description
82
133
  else:
83
- raise
134
+ raise ValueError(f"No description found in docstring or description field (model: {model}, docstring: {docstring})")
84
135
 
85
136
  return {
86
137
  "name": schema["title"],
@@ -89,6 +140,159 @@ def pydantic_model_to_open_ai(model):
89
140
  }
90
141
 
91
142
 
143
+ def pydantic_model_to_json_schema(model: Type[BaseModel]) -> dict:
144
+ """
145
+ Converts a Pydantic model (as an arg that already is annotated) to a JSON schema object for use in OpenAI function calling.
146
+
147
+ An example of a Pydantic model as an arg:
148
+
149
+ class Step(BaseModel):
150
+ name: str = Field(
151
+ ...,
152
+ description="Name of the step.",
153
+ )
154
+ key: str = Field(
155
+ ...,
156
+ description="Unique identifier for the step.",
157
+ )
158
+ description: str = Field(
159
+ ...,
160
+ description="An exhaustic description of what this step is trying to achieve and accomplish.",
161
+ )
162
+
163
+ def create_task_plan(steps: list[Step]):
164
+ '''
165
+ Creates a task plan for the current task.
166
+
167
+ Args:
168
+ steps: List of steps to add to the task plan.
169
+ ...
170
+
171
+ Should result in:
172
+ {
173
+ "name": "create_task_plan",
174
+ "description": "Creates a task plan for the current task.",
175
+ "parameters": {
176
+ "type": "object",
177
+ "properties": {
178
+ "steps": { # <= this is the name of the arg
179
+ "type": "object",
180
+ "description": "List of steps to add to the task plan.",
181
+ "properties": {
182
+ "name": {
183
+ "type": "str",
184
+ "description": "Name of the step.",
185
+ },
186
+ "key": {
187
+ "type": "str",
188
+ "description": "Unique identifier for the step.",
189
+ },
190
+ "description": {
191
+ "type": "str",
192
+ "description": "An exhaustic description of what this step is trying to achieve and accomplish.",
193
+ },
194
+ },
195
+ "required": ["name", "key", "description"],
196
+ }
197
+ },
198
+ "required": ["steps"],
199
+ }
200
+ }
201
+
202
+ Specifically, the result of pydantic_model_to_json_schema(steps) (where `steps` is an instance of BaseModel) is:
203
+ {
204
+ "type": "object",
205
+ "properties": {
206
+ "name": {
207
+ "type": "str",
208
+ "description": "Name of the step."
209
+ },
210
+ "key": {
211
+ "type": "str",
212
+ "description": "Unique identifier for the step."
213
+ },
214
+ "description": {
215
+ "type": "str",
216
+ "description": "An exhaustic description of what this step is trying to achieve and accomplish."
217
+ },
218
+ },
219
+ "required": ["name", "key", "description"],
220
+ }
221
+ """
222
+ schema = model.model_json_schema()
223
+
224
+ def clean_property(prop: dict) -> dict:
225
+ """Clean up a property schema to match desired format"""
226
+
227
+ if "description" not in prop:
228
+ raise ValueError(f"Property {prop} lacks a 'description' key")
229
+
230
+ return {
231
+ "type": "string" if prop["type"] == "string" else prop["type"],
232
+ "description": prop["description"],
233
+ }
234
+
235
+ def resolve_ref(ref: str, schema: dict) -> dict:
236
+ """Resolve a $ref reference in the schema"""
237
+ if not ref.startswith("#/$defs/"):
238
+ raise ValueError(f"Unexpected reference format: {ref}")
239
+
240
+ model_name = ref.split("/")[-1]
241
+ if model_name not in schema.get("$defs", {}):
242
+ raise ValueError(f"Reference {model_name} not found in schema definitions")
243
+
244
+ return schema["$defs"][model_name]
245
+
246
+ def clean_schema(schema_part: dict, full_schema: dict) -> dict:
247
+ """Clean up a schema part, handling references and nested structures"""
248
+ # Handle $ref
249
+ if "$ref" in schema_part:
250
+ schema_part = resolve_ref(schema_part["$ref"], full_schema)
251
+
252
+ if "type" not in schema_part:
253
+ raise ValueError(f"Schema part lacks a 'type' key: {schema_part}")
254
+
255
+ # Handle array type
256
+ if schema_part["type"] == "array":
257
+ items_schema = schema_part["items"]
258
+ if "$ref" in items_schema:
259
+ items_schema = resolve_ref(items_schema["$ref"], full_schema)
260
+ return {"type": "array", "items": clean_schema(items_schema, full_schema), "description": schema_part.get("description", "")}
261
+
262
+ # Handle object type
263
+ if schema_part["type"] == "object":
264
+ if "properties" not in schema_part:
265
+ raise ValueError(f"Object schema lacks 'properties' key: {schema_part}")
266
+
267
+ properties = {}
268
+ for name, prop in schema_part["properties"].items():
269
+ if "items" in prop: # Handle arrays
270
+ if "description" not in prop:
271
+ raise ValueError(f"Property {prop} lacks a 'description' key")
272
+ properties[name] = {
273
+ "type": "array",
274
+ "items": clean_schema(prop["items"], full_schema),
275
+ "description": prop["description"],
276
+ }
277
+ else:
278
+ properties[name] = clean_property(prop)
279
+
280
+ pydantic_model_schema_dict = {
281
+ "type": "object",
282
+ "properties": properties,
283
+ "required": schema_part.get("required", []),
284
+ }
285
+ if "description" in schema_part:
286
+ pydantic_model_schema_dict["description"] = schema_part["description"]
287
+
288
+ return pydantic_model_schema_dict
289
+
290
+ # Handle primitive types
291
+ return clean_property(schema_part)
292
+
293
+ return clean_schema(schema_part=schema, full_schema=schema)
294
+
295
+
92
296
  def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None) -> dict:
93
297
  # Get the signature of the function
94
298
  sig = inspect.signature(function)
@@ -126,24 +330,60 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
126
330
  if not param_doc or not param_doc.description:
127
331
  raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")
128
332
 
129
- if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel):
130
- schema["parameters"]["properties"][param.name] = pydantic_model_to_open_ai(param.annotation)
333
+ # If the parameter is a pydantic model, we need to unpack the Pydantic model type into a JSON schema object
334
+ # if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel):
335
+ if (
336
+ (inspect.isclass(param.annotation) or inspect.isclass(get_origin(param.annotation) or param.annotation))
337
+ and not get_origin(param.annotation)
338
+ and issubclass(param.annotation, BaseModel)
339
+ ):
340
+ # print("Generating schema for pydantic model:", param.annotation)
341
+ # Extract the properties from the pydantic model
342
+ schema["parameters"]["properties"][param.name] = pydantic_model_to_json_schema(param.annotation)
343
+ schema["parameters"]["properties"][param.name]["description"] = param_doc.description
344
+
345
+ # Otherwise, we convert the Python typing to JSON schema types
346
+ # NOTE: important - if a dict or list, the internal type can be a Pydantic model itself
347
+ # however in that
131
348
  else:
132
- # Add parameter details to the schema
349
+ # print("Generating schema for non-pydantic model:", param.annotation)
350
+ # Grab the description for the parameter from the extended docstring
351
+ # If it doesn't exist, we should raise an error
133
352
  param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
134
- if param_doc:
135
- schema["parameters"]["properties"][param.name] = {
136
- # "type": "string" if param.annotation == str else str(param.annotation),
137
- "type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
138
- "description": param_doc.description,
139
- }
140
- if param.default == inspect.Parameter.empty:
353
+ if not param_doc:
354
+ raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")
355
+ elif not isinstance(param_doc.description, str):
356
+ raise ValueError(
357
+ f"Parameter '{param.name}' in function '{function.__name__}' has a description in the docstring that is not a string (type: {type(param_doc.description)})"
358
+ )
359
+ else:
360
+ # If it's a string or a basic type, then all you need is: (1) type, (2) description
361
+ # If it's a more complex type, then you also need either:
362
+ # - for array, you need "items", each of which has "type"
363
+ # - for a dict, you need "properties", which has keys which each have "type"
364
+ if param.annotation != inspect.Parameter.empty:
365
+ param_generated_schema = type_to_json_schema_type(param.annotation)
366
+ else:
367
+ # TODO why are we inferring here?
368
+ param_generated_schema = {"type": "string"}
369
+
370
+ # Add in the description
371
+ param_generated_schema["description"] = param_doc.description
372
+
373
+ # Add the schema to the function arg key
374
+ schema["parameters"]["properties"][param.name] = param_generated_schema
375
+
376
+ # If the parameter doesn't have a default value, it is required (so we need to add it to the required list)
377
+ if param.default == inspect.Parameter.empty and not is_optional(param.annotation):
141
378
  schema["parameters"]["required"].append(param.name)
142
379
 
380
+ # TODO what's going on here?
381
+ # If the parameter is a list of strings we need to hard cast to "string" instead of `str`
143
382
  if get_origin(param.annotation) is list:
144
383
  if get_args(param.annotation)[0] is str:
145
384
  schema["parameters"]["properties"][param.name]["items"] = {"type": "string"}
146
385
 
386
+ # TODO is this not duplicating the other append directly above?
147
387
  if param.annotation == inspect.Parameter.empty:
148
388
  schema["parameters"]["required"].append(param.name)
149
389