grasp_agents-0.2.11-py3-none-any.whl → grasp_agents-0.3.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -273
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +194 -0
  23. grasp_agents/prompt_builder.py +175 -192
  24. grasp_agents/run_context.py +20 -37
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.1.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -134
  47. grasp_agents/workflow/sequential_agent.py +0 -72
  48. grasp_agents/workflow/workflow_agent.py +0 -88
  49. grasp_agents-0.2.11.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
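
Beyond the OpenAI adapter changes shown below, the file list reflects a broad "agent" → "processor" restructuring (comm_agent → comm_processor, tool_orchestrator → llm_policy_executor, agent_message → packet). A hedged migration sketch follows; the imported class names are assumptions inferred from the new file names, not confirmed exports:

# Assumed import moves for 0.2.11 -> 0.3.1, inferred from the renamed files
# above; verify the actual exports in grasp_agents/__init__.py before use.
from grasp_agents.packet import Packet                          # was agent_message.py
from grasp_agents.packet_pool import PacketPool                 # was agent_message_pool.py
from grasp_agents.comm_processor import CommProcessor           # was comm_agent.py
from grasp_agents.llm_policy_executor import LLMPolicyExecutor  # was tool_orchestrator.py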

grasp_agents/openai/converters.py

@@ -1,31 +1,31 @@
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import Iterable
 from typing import Any

 from pydantic import BaseModel

-from ..typing.completion import Completion, CompletionChunk
+from ..typing.completion import Completion, Usage
+from ..typing.completion_chunk import CompletionChunk
 from ..typing.content import Content
 from ..typing.converters import Converters
 from ..typing.message import AssistantMessage, SystemMessage, ToolMessage, UserMessage
 from ..typing.tool import BaseTool, ToolChoice
 from . import (
-    ChatCompletion,
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionAsyncStream,  # type: ignore[import]
-    ChatCompletionChunk,
-    ChatCompletionContentPartParam,
-    ChatCompletionMessage,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionToolChoiceOptionParam,
-    ChatCompletionToolMessageParam,
-    ChatCompletionToolParam,
-    ChatCompletionUsage,
-    ChatCompletionUserMessageParam,
+    OpenAIAssistantMessageParam,
+    OpenAICompletion,
+    OpenAICompletionChunk,
+    OpenAICompletionUsage,
+    OpenAIContentPartParam,
+    OpenAIMessage,
+    OpenAISystemMessageParam,
+    OpenAIToolChoiceOptionParam,
+    OpenAIToolMessageParam,
+    OpenAIToolParam,
+    OpenAIUserMessageParam,
 )
+from .completion_chunk_converters import from_api_completion_chunk
 from .completion_converters import (
     from_api_completion,
-    from_api_completion_chunk,
-    from_api_completion_chunk_iterator,
+    from_api_completion_usage,
     to_api_completion,
 )
 from .content_converters import from_api_content, to_api_content
@@ -46,119 +46,95 @@ class OpenAIConverters(Converters):
     @staticmethod
     def to_system_message(
         system_message: SystemMessage, **kwargs: Any
-    ) -> ChatCompletionSystemMessageParam:
+    ) -> OpenAISystemMessageParam:
         return to_api_system_message(system_message, **kwargs)

     @staticmethod
     def from_system_message(
-        raw_message: ChatCompletionSystemMessageParam,
-        model_id: str | None = None,
-        **kwargs: Any,
+        raw_message: OpenAISystemMessageParam, name: str | None = None, **kwargs: Any
     ) -> SystemMessage:
-        return from_api_system_message(raw_message, model_id=model_id, **kwargs)
+        return from_api_system_message(raw_message, name=name, **kwargs)

     @staticmethod
     def to_user_message(
         user_message: UserMessage, **kwargs: Any
-    ) -> ChatCompletionUserMessageParam:
+    ) -> OpenAIUserMessageParam:
         return to_api_user_message(user_message, **kwargs)

     @staticmethod
     def from_user_message(
-        raw_message: ChatCompletionUserMessageParam,
-        model_id: str | None = None,
-        **kwargs: Any,
+        raw_message: OpenAIUserMessageParam, name: str | None = None, **kwargs: Any
     ) -> UserMessage:
-        return from_api_user_message(raw_message, model_id=model_id, **kwargs)
+        return from_api_user_message(raw_message, name=name, **kwargs)

     @staticmethod
     def to_assistant_message(
         assistant_message: AssistantMessage, **kwargs: Any
-    ) -> ChatCompletionAssistantMessageParam:
+    ) -> OpenAIAssistantMessageParam:
         return to_api_assistant_message(assistant_message, **kwargs)

+    @staticmethod
+    def from_completion_usage(raw_usage: OpenAICompletionUsage, **kwargs: Any) -> Usage:
+        return from_api_completion_usage(raw_usage, **kwargs)
+
     @staticmethod
     def from_assistant_message(
-        raw_message: ChatCompletionMessage,
-        raw_usage: ChatCompletionUsage,
-        model_id: str | None = None,
-        **kwargs: Any,
+        raw_message: OpenAIMessage, name: str | None = None, **kwargs: Any
     ) -> AssistantMessage:
-        return from_api_assistant_message(
-            raw_message, raw_usage, model_id=model_id, **kwargs
-        )
+        return from_api_assistant_message(raw_message, name=name, **kwargs)

     @staticmethod
     def to_tool_message(
         tool_message: ToolMessage, **kwargs: Any
-    ) -> ChatCompletionToolMessageParam:
+    ) -> OpenAIToolMessageParam:
         return to_api_tool_message(tool_message, **kwargs)

     @staticmethod
     def from_tool_message(
-        raw_message: ChatCompletionToolMessageParam,
-        model_id: str | None = None,
-        **kwargs: Any,
+        raw_message: OpenAIToolMessageParam, name: str | None = None, **kwargs: Any
     ) -> ToolMessage:
-        return from_api_tool_message(raw_message, model_id=model_id, **kwargs)
+        return from_api_tool_message(raw_message, name=name, **kwargs)

     @staticmethod
     def to_tool(
-        tool: BaseTool[BaseModel, Any, Any], **kwargs: Any
-    ) -> ChatCompletionToolParam:
-        return to_api_tool(tool, **kwargs)
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)

     @staticmethod
     def to_tool_choice(
         tool_choice: ToolChoice, **kwargs: Any
-    ) -> ChatCompletionToolChoiceOptionParam:
+    ) -> OpenAIToolChoiceOptionParam:
         return to_api_tool_choice(tool_choice, **kwargs)

     @staticmethod
-    def to_content(
-        content: Content, **kwargs: Any
-    ) -> Iterable[ChatCompletionContentPartParam]:
+    def to_content(content: Content, **kwargs: Any) -> Iterable[OpenAIContentPartParam]:
         return to_api_content(content, **kwargs)

     @staticmethod
     def from_content(
-        raw_content: str | Iterable[ChatCompletionContentPartParam],
-        **kwargs: Any,
+        raw_content: str | Iterable[OpenAIContentPartParam], **kwargs: Any
     ) -> Content:
         return from_api_content(raw_content, **kwargs)

     @staticmethod
-    def to_completion(completion: Completion, **kwargs: Any) -> ChatCompletion:
+    def to_completion(completion: Completion, **kwargs: Any) -> OpenAICompletion:
         return to_api_completion(completion, **kwargs)

     @staticmethod
     def from_completion(
-        raw_completion: ChatCompletion,
-        model_id: str | None = None,
-        **kwargs: Any,
+        raw_completion: OpenAICompletion, name: str | None = None, **kwargs: Any
     ) -> Completion:
-        return from_api_completion(raw_completion, model_id=model_id, **kwargs)
+        return from_api_completion(raw_completion, name=name, **kwargs)

     @staticmethod
     def to_completion_chunk(
         chunk: CompletionChunk, **kwargs: Any
-    ) -> ChatCompletionChunk:
+    ) -> OpenAICompletionChunk:
         raise NotImplementedError

     @staticmethod
     def from_completion_chunk(
-        raw_chunk: ChatCompletionChunk,
-        model_id: str | None = None,
-        **kwargs: Any,
+        raw_chunk: OpenAICompletionChunk, name: str | None = None, **kwargs: Any
     ) -> CompletionChunk:
-        return from_api_completion_chunk(raw_chunk, model_id=model_id, **kwargs)
-
-    @staticmethod
-    def from_completion_chunk_iterator(  # type: ignore[override]
-        raw_chunk_iterator: ChatCompletionAsyncStream[ChatCompletionChunk],
-        model_id: str | None = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[CompletionChunk]:
-        return from_api_completion_chunk_iterator(
-            raw_chunk_iterator, model_id=model_id, **kwargs
-        )
+        return from_api_completion_chunk(raw_chunk, name=name, **kwargs)
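
Call sites track this converter change in three visible ways: the ChatCompletion* parameter types become OpenAI* aliases, the model_id keyword becomes name, and usage extraction moves out of from_assistant_message into a dedicated from_completion_usage hook. A hedged sketch of the 0.3.1 surface (raw and my_tool are hypothetical placeholders):

# Sketch of the 0.3.1 converter surface shown in the diff above; `raw` stands
# in for an OpenAICompletion returned by the client and `my_tool` for a
# BaseTool instance -- both are hypothetical placeholders.
from grasp_agents.openai.converters import OpenAIConverters

def convert(raw, my_tool):
    # 0.2.11 spelled this: from_completion(raw, model_id="gpt-4o")
    completion = OpenAIConverters.from_completion(raw, name="gpt-4o")

    # Usage is no longer folded into from_assistant_message; it now has its
    # own converter entry point.
    usage = OpenAIConverters.from_completion_usage(raw.usage)

    # Tool schemas can now opt into strict argument validation.
    api_tool = OpenAIConverters.to_tool(my_tool, strict=True)

    return completion, usage, api_tool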

grasp_agents/openai/message_converters.py

@@ -1,76 +1,62 @@
 from typing import TypeAlias

+from ..typing.content import Content
 from ..typing.message import (
     AssistantMessage,
     SystemMessage,
     ToolMessage,
-    Usage,
     UserMessage,
 )
 from ..typing.tool import ToolCall
 from . import (
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionDeveloperMessageParam,
-    ChatCompletionFunctionMessageParam,
-    ChatCompletionMessage,
-    ChatCompletionMessageToolCallParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionToolCallFunction,
-    ChatCompletionToolMessageParam,
-    ChatCompletionUsage,
-    ChatCompletionUserMessageParam,
+    OpenAIAssistantMessageParam,
+    OpenAIDeveloperMessageParam,
+    OpenAIFunctionMessageParam,
+    OpenAIMessage,
+    OpenAISystemMessageParam,
+    OpenAIToolCallFunction,
+    OpenAIToolCallParam,
+    OpenAIToolMessageParam,
+    OpenAIUserMessageParam,
 )
 from .content_converters import from_api_content, to_api_content

-OpenAIMessage: TypeAlias = (
-    ChatCompletionAssistantMessageParam
-    | ChatCompletionToolMessageParam
-    | ChatCompletionUserMessageParam
-    | ChatCompletionDeveloperMessageParam
-    | ChatCompletionSystemMessageParam
-    | ChatCompletionFunctionMessageParam
+OpenAIMessageType: TypeAlias = (
+    OpenAIAssistantMessageParam
+    | OpenAIToolMessageParam
+    | OpenAIUserMessageParam
+    | OpenAIDeveloperMessageParam
+    | OpenAISystemMessageParam
+    | OpenAIFunctionMessageParam
 )


 def from_api_user_message(
-    api_message: ChatCompletionUserMessageParam, model_id: str | None = None
+    api_message: OpenAIUserMessageParam, name: str | None = None
 ) -> UserMessage:
     content = from_api_content(api_message["content"])
+    name = api_message.get("name")

-    return UserMessage(content=content, model_id=model_id)
+    return UserMessage(content=content, name=name)


-def to_api_user_message(message: UserMessage) -> ChatCompletionUserMessageParam:
-    api_content = to_api_content(message.content)
+def to_api_user_message(message: UserMessage) -> OpenAIUserMessageParam:
+    api_content = (
+        to_api_content(message.content)
+        if isinstance(message.content, Content)
+        else message.content
+    )
+    api_name = message.name
+    api_message = OpenAIUserMessageParam(role="user", content=api_content)
+    if api_name is not None:
+        api_message["name"] = api_name

-    return ChatCompletionUserMessageParam(role="user", content=api_content)
+    return api_message


 def from_api_assistant_message(
-    api_message: ChatCompletionMessage,
-    api_usage: ChatCompletionUsage | None = None,
-    model_id: str | None = None,
+    api_message: OpenAIMessage, name: str | None = None
 ) -> AssistantMessage:
-    usage = None
-    if api_usage is not None:
-        reasoning_tokens = None
-        cached_tokens = None
-
-        if api_usage.completion_tokens_details is not None:
-            reasoning_tokens = api_usage.completion_tokens_details.reasoning_tokens
-        if api_usage.prompt_tokens_details is not None:
-            cached_tokens = api_usage.prompt_tokens_details.cached_tokens
-
-        input_tokens = api_usage.prompt_tokens - (cached_tokens or 0)
-        output_tokens = api_usage.completion_tokens - (reasoning_tokens or 0)
-
-        usage = Usage(
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-            reasoning_tokens=reasoning_tokens,
-            cached_tokens=cached_tokens,
-        )
-
     tool_calls = None
     if api_message.tool_calls is not None:
         tool_calls = [
@@ -84,23 +70,22 @@ def from_api_assistant_message(

     return AssistantMessage(
         content=api_message.content,
-        usage=usage,
         tool_calls=tool_calls,
         refusal=api_message.refusal,
-        model_id=model_id,
+        name=name,
     )


 def to_api_assistant_message(
     message: AssistantMessage,
-) -> ChatCompletionAssistantMessageParam:
+) -> OpenAIAssistantMessageParam:
     api_tool_calls = None
     if message.tool_calls is not None:
         api_tool_calls = [
-            ChatCompletionMessageToolCallParam(
+            OpenAIToolCallParam(
                 type="function",
                 id=tool_call.id,
-                function=ChatCompletionToolCallFunction(
+                function=OpenAIToolCallFunction(
                     name=tool_call.tool_name,
                     arguments=tool_call.tool_arguments,
                 ),
@@ -108,48 +93,49 @@ def to_api_assistant_message(
             for tool_call in message.tool_calls
         ]

-    api_message = ChatCompletionAssistantMessageParam(
-        role="assistant",
-        content=message.content,
-        tool_calls=api_tool_calls or [],
-        refusal=message.refusal,
-    )
+    api_message = OpenAIAssistantMessageParam(role="assistant", content=message.content)
+
+    if message.name is not None:
+        api_message["name"] = message.name
+    if api_tool_calls is not None:
+        api_message["tool_calls"] = api_tool_calls or []
+    if message.refusal is not None:
+        api_message["refusal"] = message.refusal
+
+    # TODO: hack
     if message.content is None:
         # Some API providers return None in the generated content without errors,
         # even though None in the input content is not accepted.
         api_message["content"] = "<empty>"
-    if api_tool_calls is None:
-        api_message.pop("tool_calls")
-    if message.refusal is None:
-        api_message.pop("refusal")

     return api_message


 def from_api_system_message(
-    api_message: ChatCompletionSystemMessageParam,
-    model_id: str | None = None,
+    api_message: OpenAISystemMessageParam, name: str | None = None
 ) -> SystemMessage:
-    return SystemMessage(content=api_message["content"], model_id=model_id)  # type: ignore
+    return SystemMessage(content=api_message["content"], name=name)  # type: ignore


-def to_api_system_message(
-    message: SystemMessage,
-) -> ChatCompletionSystemMessageParam:
-    return ChatCompletionSystemMessageParam(role="system", content=message.content)
+def to_api_system_message(message: SystemMessage) -> OpenAISystemMessageParam:
+    api_message = OpenAISystemMessageParam(role="system", content=message.content)
+    if message.name is not None:
+        api_message["name"] = message.name
+
+    return api_message


 def from_api_tool_message(
-    api_message: ChatCompletionToolMessageParam, model_id: str | None = None
+    api_message: OpenAIToolMessageParam, name: str | None = None
 ) -> ToolMessage:
     return ToolMessage(
         content=api_message["content"],  # type: ignore
         tool_call_id=api_message["tool_call_id"],
-        model_id=model_id,
+        name=name,
     )


-def to_api_tool_message(message: ToolMessage) -> ChatCompletionToolMessageParam:
-    return ChatCompletionToolMessageParam(
+def to_api_tool_message(message: ToolMessage) -> OpenAIToolMessageParam:
+    return OpenAIToolMessageParam(
         role="tool", content=message.content, tool_call_id=message.tool_call_id
     )
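
The rewritten to_api_* converters now emit optional keys (name, tool_calls, refusal) only when set, instead of building the full dict and popping the gaps afterwards. A small sketch of the resulting behavior, assuming UserMessage accepts plain-string content and an optional name, as the isinstance branch above implies:

# Assumes UserMessage(content=..., name=...) as implied by the converters
# above; the exact constructor signature is not shown in this diff.
from grasp_agents.typing.message import UserMessage
from grasp_agents.openai.message_converters import to_api_user_message

named = to_api_user_message(UserMessage(content="hello", name="alice"))
# -> {"role": "user", "content": "hello", "name": "alice"}

anonymous = to_api_user_message(UserMessage(content="hi"))
assert "name" not in anonymous  # the key is omitted, not set to None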

grasp_agents/openai/openai_llm.py

@@ -1,45 +1,45 @@
 import logging
-from collections.abc import Iterable, Mapping
+from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
-from typing import Any, Literal
+from typing import Any, Literal, NamedTuple

 from openai import AsyncOpenAI
 from openai._types import NOT_GIVEN  # type: ignore[import]
+from openai.lib.streaming.chat import (
+    AsyncChatCompletionStreamManager as OpenAIAsyncChatCompletionStreamManager,
+)
+from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel

 from ..cloud_llm import CloudLLM, CloudLLMSettings
 from ..http_client import AsyncHTTPClientParams
 from ..rate_limiting.rate_limiter_chunked import RateLimiterC
-from ..typing.message import AssistantMessage, Conversation
+from ..typing.message import AssistantMessage, Messages
 from ..typing.tool import BaseTool
 from . import (
-    ChatCompletion,
-    ChatCompletionAsyncStream,  # type: ignore[import]
-    ChatCompletionChunk,
-    ChatCompletionMessageParam,
-    ChatCompletionPredictionContentParam,
-    ChatCompletionStreamOptionsParam,
-    ChatCompletionToolChoiceOptionParam,
-    ChatCompletionToolParam,
-    ParsedChatCompletion,
-    # ResponseFormatJSONObject,
-    # ResponseFormatJSONSchema,
-    # ResponseFormatText,
+    OpenAICompletion,
+    OpenAICompletionChunk,
+    OpenAIMessageParam,
+    OpenAIParsedCompletion,
+    OpenAIPredictionContentParam,
+    OpenAIStreamOptionsParam,
+    OpenAIToolChoiceOptionParam,
+    OpenAIToolParam,
 )
 from .converters import OpenAIConverters

 logger = logging.getLogger(__name__)


+class ToolCallSettings(NamedTuple):
+    strict: bool | None = None
+
+
 class OpenAILLMSettings(CloudLLMSettings, total=False):
     reasoning_effort: Literal["low", "medium", "high"] | None

     parallel_tool_calls: bool

-    # response_format: (
-    #     ResponseFormatText | ResponseFormatJSONSchema | ResponseFormatJSONObject
-    # )
-
     modalities: list[Literal["text", "audio"]] | None

     frequency_penalty: float | None
@@ -48,16 +48,23 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
     stop: str | list[str] | None
     logprobs: bool | None
     top_logprobs: int | None
-    n: int | None

-    prediction: ChatCompletionPredictionContentParam | None
+    prediction: OpenAIPredictionContentParam | None

-    stream_options: ChatCompletionStreamOptionsParam | None
+    stream_options: OpenAIStreamOptionsParam | None

     metadata: dict[str, str] | None
     store: bool | None
     user: str

+    strict_tool_args: bool
+
+    # response_format: (
+    #     OpenAIResponseFormatText
+    #     | OpenAIResponseFormatJSONSchema
+    #     | OpenAIResponseFormatJSONObject
+    # )
+

 class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
     def __init__(
@@ -74,7 +81,7 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         ) = None,
         async_openai_client_params: dict[str, Any] | None = None,
         # Rate limiting
-        rate_limiter: (RateLimiterC[Conversation, AssistantMessage] | None) = None,
+        rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
         rate_limiter_rpm: float | None = None,
         rate_limiter_chunk_size: int = 1000,
         rate_limiter_max_concurrency: int = 300,
@@ -101,46 +108,55 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             **kwargs,
         )

-        async_openai_client_params_ = deepcopy(async_openai_client_params or {})
+        self._tool_call_settings = {
+            "strict": self._llm_settings.pop("strict_tool_args", None)
+        }
+
+        _async_openai_client_params = deepcopy(async_openai_client_params or {})
         if self._async_http_client is not None:
-            async_openai_client_params_["http_client"] = self._async_http_client
+            _async_openai_client_params["http_client"] = self._async_http_client

         # TODO: context manager for async client
         self._client: AsyncOpenAI = AsyncOpenAI(
             base_url=self._base_url,
             api_key=self._api_key,
-            **async_openai_client_params_,
+            **_async_openai_client_params,
         )

     async def _get_completion(
         self,
-        api_messages: Iterable[ChatCompletionMessageParam],
-        api_tools: list[ChatCompletionToolParam] | None = None,
-        api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
+        api_messages: Iterable[OpenAIMessageParam],
+        api_tools: list[OpenAIToolParam] | None = None,
+        api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+        n_choices: int | None = None,
         **api_llm_settings: Any,
-    ) -> ChatCompletion:
+    ) -> OpenAICompletion:
         tools = api_tools or NOT_GIVEN
         tool_choice = api_tool_choice or NOT_GIVEN
+        n = n_choices or NOT_GIVEN

         return await self._client.chat.completions.create(
             model=self._api_model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
+            n=n,
             stream=False,
             **api_llm_settings,
         )

     async def _get_parsed_completion(
         self,
-        api_messages: Iterable[ChatCompletionMessageParam],
-        api_tools: list[ChatCompletionToolParam] | None = None,
-        api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
+        api_messages: Iterable[OpenAIMessageParam],
+        api_tools: list[OpenAIToolParam] | None = None,
+        api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
         api_response_format: type | None = None,
+        n_choices: int | None = None,
         **api_llm_settings: Any,
-    ) -> ParsedChatCompletion[Any]:
+    ) -> OpenAIParsedCompletion[Any]:
         tools = api_tools or NOT_GIVEN
         tool_choice = api_tool_choice or NOT_GIVEN
+        n = n_choices or NOT_GIVEN
         response_format = api_response_format or NOT_GIVEN

         return await self._client.beta.chat.completions.parse(
@@ -148,27 +164,70 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
-            response_format=response_format,  # type: ignore[arg-type]
+            response_format=response_format,
+            n=n,
             **api_llm_settings,
         )

     async def _get_completion_stream(
         self,
-        api_messages: Iterable[ChatCompletionMessageParam],
-        api_tools: list[ChatCompletionToolParam] | None = None,
-        api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
+        api_messages: Iterable[OpenAIMessageParam],
+        api_tools: list[OpenAIToolParam] | None = None,
+        api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+        n_choices: int | None = None,
         **api_llm_settings: Any,
-    ) -> ChatCompletionAsyncStream[ChatCompletionChunk]:
-        assert not api_tools, "Tool use is not supported in streaming mode"
-
+    ) -> AsyncIterator[OpenAICompletionChunk]:
         tools = api_tools or NOT_GIVEN
         tool_choice = api_tool_choice or NOT_GIVEN
+        n = n_choices or NOT_GIVEN

-        return await self._client.chat.completions.create(
+        stream_generator = await self._client.chat.completions.create(
             model=self._api_model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
             stream=True,
+            n=n,
             **api_llm_settings,
         )
+
+        async def iterate() -> AsyncIterator[OpenAICompletionChunk]:
+            async with stream_generator as stream:
+                async for completion_chunk in stream:
+                    yield completion_chunk
+
+        return iterate()
+
+    async def _get_parsed_completion_stream(
+        self,
+        api_messages: Iterable[OpenAIMessageParam],
+        api_tools: list[OpenAIToolParam] | None = None,
+        api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+        api_response_format: type | None = None,
+        n_choices: int | None = None,
+        **api_llm_settings: Any,
+    ) -> AsyncIterator[OpenAICompletionChunk]:
+        tools = api_tools or NOT_GIVEN
+        tool_choice = api_tool_choice or NOT_GIVEN
+        response_format = api_response_format or NOT_GIVEN
+        n = n_choices or NOT_GIVEN
+
+        stream_manager: OpenAIAsyncChatCompletionStreamManager[
+            OpenAICompletionChunk
+        ] = self._client.beta.chat.completions.stream(
+            model=self._api_model_name,
+            messages=api_messages,
+            tools=tools,
+            tool_choice=tool_choice,
+            response_format=response_format,
+            n=n,
+            **api_llm_settings,
+        )
+
+        async def iterate() -> AsyncIterator[OpenAICompletionChunk]:
+            async with stream_manager as stream:
+                async for chunk_event in stream:
+                    if isinstance(chunk_event, OpenAIChunkEvent):
+                        yield chunk_event.chunk
+
+        return iterate()
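
The streaming path changes shape in 0.3.1: _get_completion_stream no longer forbids tools (the assert is gone), both stream methods return a plain AsyncIterator of chunks from an inner async generator that owns the underlying context manager, and n_choices maps onto the API's n parameter. A hedged consumer sketch; llm and api_messages are hypothetical stand-ins, and these are private methods normally driven by CloudLLM internals rather than user code:

# Consumer sketch for the new AsyncIterator-based streaming surface; `llm`
# is a hypothetical OpenAILLM instance and `api_messages` a prepared list of
# OpenAIMessageParam dicts.
async def collect_text(llm, api_messages) -> str:
    parts: list[str] = []
    stream = await llm._get_completion_stream(api_messages, n_choices=1)
    async for chunk in stream:  # each chunk is an OpenAICompletionChunk
        for choice in chunk.choices:
            if choice.delta.content:
                parts.append(choice.delta.content)
    return "".join(parts)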