oagi-core 0.11.0__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oagi/client/async_.py CHANGED
@@ -9,17 +9,19 @@
  from functools import wraps
 
  import httpx
+ from httpx import AsyncHTTPTransport
+ from openai import AsyncOpenAI
 
  from ..constants import (
-     API_HEALTH_ENDPOINT,
      API_V1_FILE_UPLOAD_ENDPOINT,
      API_V1_GENERATE_ENDPOINT,
-     API_V2_MESSAGE_ENDPOINT,
+     DEFAULT_MAX_RETRIES,
      HTTP_CLIENT_TIMEOUT,
  )
  from ..logging import get_logger
  from ..types import Image
- from ..types.models import GenerateResponse, LLMResponse, UploadFileResponse
+ from ..types.models import GenerateResponse, UploadFileResponse, Usage
+ from ..types.models.step import Step
  from .base import BaseClient
 
  logger = get_logger("async_client")
@@ -35,8 +37,7 @@ def async_log_trace_on_failure(func):
          except Exception as e:
              # Try to get response from the exception if it has one
              if (response := getattr(e, "response", None)) is not None:
-                 logger.error(f"Request Id: {response.headers.get('x-request-id', '')}")
-                 logger.error(f"Trace Id: {response.headers.get('x-trace-id', '')}")
+                 BaseClient._log_trace_id(response)
              raise
 
      return wrapper
@@ -45,115 +46,72 @@ def async_log_trace_on_failure(func):
  class AsyncClient(BaseClient[httpx.AsyncClient]):
      """Asynchronous HTTP client for the OAGI API."""
 
-     def __init__(self, base_url: str | None = None, api_key: str | None = None):
-         super().__init__(base_url, api_key)
-         self.client = httpx.AsyncClient(base_url=self.base_url)
-         self.upload_client = httpx.AsyncClient(timeout=HTTP_CLIENT_TIMEOUT)
+     def __init__(
+         self,
+         base_url: str | None = None,
+         api_key: str | None = None,
+         max_retries: int = DEFAULT_MAX_RETRIES,
+     ):
+         super().__init__(base_url, api_key, max_retries)
+ 
+         # OpenAI client for chat completions (with retries)
+         self.openai_client = AsyncOpenAI(
+             api_key=self.api_key,
+             base_url=f"{self.base_url}/v1",
+             max_retries=self.max_retries,
+         )
+ 
+         # httpx clients for S3 uploads and other endpoints (with retries)
+         transport = AsyncHTTPTransport(retries=self.max_retries)
+         self.http_client = httpx.AsyncClient(
+             transport=transport, base_url=self.base_url
+         )
+         self.upload_client = httpx.AsyncClient(
+             transport=transport, timeout=HTTP_CLIENT_TIMEOUT
+         )
+ 
          logger.info(f"AsyncClient initialized with base_url: {self.base_url}")
 
      async def __aenter__(self):
          return self
 
      async def __aexit__(self, exc_type, exc_val, exc_tb):
-         await self.client.aclose()
-         await self.upload_client.aclose()
+         await self.close()
 
      async def close(self):
-         """Close the underlying httpx async clients."""
-         await self.client.aclose()
+         """Close the underlying async clients."""
+         await self.openai_client.close()
+         await self.http_client.aclose()
          await self.upload_client.aclose()
 
-     @async_log_trace_on_failure
-     async def create_message(
+     async def chat_completion(
          self,
          model: str,
-         screenshot: bytes | None = None,
-         screenshot_url: str | None = None,
-         task_description: str | None = None,
-         task_id: str | None = None,
-         instruction: str | None = None,
-         messages_history: list | None = None,
+         messages: list,
          temperature: float | None = None,
-         api_version: str | None = None,
-     ) -> "LLMResponse":
+         task_id: str | None = None,
+     ) -> tuple[Step, str, Usage | None]:
          """
-         Call the /v2/message endpoint to analyze task and screenshot
+         Call OpenAI-compatible /v1/chat/completions endpoint.
 
          Args:
-             model: The model to use for task analysis
-             screenshot: Screenshot image bytes (mutually exclusive with screenshot_url)
-             screenshot_url: Direct URL to screenshot (mutually exclusive with screenshot)
-             task_description: Description of the task (required for new sessions)
-             task_id: Task ID for continuing existing task
-             instruction: Additional instruction when continuing a session
-             messages_history: OpenAI-compatible chat message history
-             temperature: Sampling temperature (0.0-2.0) for LLM inference
-             api_version: API version header
+             model: Model to use for inference
+             messages: Full message history (OpenAI-compatible format)
+             temperature: Sampling temperature (0.0-2.0)
+             task_id: Optional task ID for multi-turn conversations
 
          Returns:
-             LLMResponse: The response from the API
- 
-         Raises:
-             ValueError: If both or neither screenshot and screenshot_url are provided
-             httpx.HTTPStatusError: For HTTP error responses
+             Tuple of (Step, raw_output, Usage)
+             - Step: Parsed actions and reasoning
+             - raw_output: Raw model output string (for message history)
+             - Usage: Token usage statistics (or None if not available)
          """
-         # Validate that exactly one is provided
-         if (screenshot is None) == (screenshot_url is None):
-             raise ValueError(
-                 "Exactly one of 'screenshot' or 'screenshot_url' must be provided"
-             )
- 
-         self._log_request_info(model, task_description, task_id)
- 
-         # Upload screenshot to S3 if bytes provided, otherwise use URL directly
-         upload_file_response = None
-         if screenshot is not None:
-             upload_file_response = await self.put_s3_presigned_url(
-                 screenshot, api_version
-             )
- 
-         # Prepare message payload
-         headers, payload = self._prepare_message_payload(
-             model=model,
-             upload_file_response=upload_file_response,
-             task_description=task_description,
-             task_id=task_id,
-             instruction=instruction,
-             messages_history=messages_history,
-             temperature=temperature,
-             api_version=api_version,
-             screenshot_url=screenshot_url,
+         logger.info(f"Making async chat completion request with model: {model}")
+         kwargs = self._build_chat_completion_kwargs(
+             model, messages, temperature, task_id
          )
- 
-         # Make request
-         try:
-             response = await self.client.post(
-                 API_V2_MESSAGE_ENDPOINT,
-                 json=payload,
-                 headers=headers,
-                 timeout=self.timeout,
-             )
-             return self._process_response(response)
-         except (httpx.TimeoutException, httpx.NetworkError) as e:
-             self._handle_upload_http_errors(e)
- 
-     async def health_check(self) -> dict:
-         """
-         Call the /health endpoint for health check
- 
-         Returns:
-             dict: Health check response
-         """
-         logger.debug("Making async health check request")
-         try:
-             response = await self.client.get(API_HEALTH_ENDPOINT)
-             response.raise_for_status()
-             result = response.json()
-             logger.debug("Async health check successful")
-             return result
-         except httpx.HTTPStatusError as e:
-             logger.warning(f"Async health check failed: {e}")
-             raise
+         response = await self.openai_client.chat.completions.create(**kwargs)
+         return self._parse_chat_completion_response(response)
 
      async def get_s3_presigned_url(
          self,
@@ -172,7 +130,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient]):
 
          try:
              headers = self._build_headers(api_version)
-             response = await self.client.get(
+             response = await self.http_client.get(
                  API_V1_FILE_UPLOAD_ENDPOINT, headers=headers, timeout=self.timeout
              )
              return self._process_upload_response(response)
@@ -292,7 +250,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient]):
 
          # Make request
          try:
-             response = await self.client.post(
+             response = await self.http_client.post(
                  API_V1_GENERATE_ENDPOINT,
                  json=payload,
                  headers=headers,
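
Taken together, the async_.py changes replace the bespoke create_message / health_check surface (and the /v2/message endpoint) with an OpenAI-compatible chat_completion call, so 0.12.x is a breaking change for callers of the old methods. Below is a minimal usage sketch of the new surface, not a definitive recipe: it assumes OAGI_API_KEY is set in the environment, imports AsyncClient from the module path shown in this diff (the package may also re-export it at the top level), and "oagi-model" is a placeholder model name.

import asyncio

from oagi.client.async_ import AsyncClient


async def main() -> None:
    # max_retries (new in this release) is shared by the AsyncOpenAI client
    # and the retrying httpx transports built in __init__.
    async with AsyncClient(max_retries=3) as client:
        step, raw_output, usage = await client.chat_completion(
            model="oagi-model",  # placeholder model name
            messages=[{"role": "user", "content": "Open the settings menu."}],
            temperature=0.2,
        )
        # raw_output is the raw model text; callers append it to `messages`
        # as the assistant turn when continuing the conversation.
        print(step.actions, step.stop)
        if usage is not None:
            print(f"tokens: {usage.prompt_tokens}+{usage.completion_tokens}")


asyncio.run(main())

Note that chat_completion no longer uploads screenshots itself; the S3 upload helpers (get_s3_presigned_url and friends) remain available separately, and image content goes into messages in standard OpenAI format.
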
oagi/client/base.py CHANGED
@@ -11,7 +11,12 @@ from typing import Any, Generic, TypeVar
 
  import httpx
 
- from ..constants import API_KEY_HELP_URL, DEFAULT_BASE_URL, HTTP_CLIENT_TIMEOUT
+ from ..constants import (
+     API_KEY_HELP_URL,
+     DEFAULT_BASE_URL,
+     DEFAULT_MAX_RETRIES,
+     HTTP_CLIENT_TIMEOUT,
+ )
  from ..exceptions import (
      APIError,
      AuthenticationError,
@@ -27,9 +32,11 @@ from ..logging import get_logger
  from ..types.models import (
      ErrorResponse,
      GenerateResponse,
-     LLMResponse,
      UploadFileResponse,
+     Usage,
  )
+ from ..types.models.step import Step
+ from ..utils.output_parser import parse_raw_output
 
  logger = get_logger("client.base")
 
@@ -40,7 +47,12 @@ HttpClientT = TypeVar("HttpClientT")
  class BaseClient(Generic[HttpClientT]):
      """Base class with shared business logic for sync/async clients."""
 
-     def __init__(self, base_url: str | None = None, api_key: str | None = None):
+     def __init__(
+         self,
+         base_url: str | None = None,
+         api_key: str | None = None,
+         max_retries: int = DEFAULT_MAX_RETRIES,
+     ):
          # Get from environment if not provided
          self.base_url = base_url or os.getenv("OAGI_BASE_URL") or DEFAULT_BASE_URL
          self.api_key = api_key or os.getenv("OAGI_API_KEY")
@@ -55,6 +67,7 @@ class BaseClient(Generic[HttpClientT]):
 
          self.base_url = self.base_url.rstrip("/")
          self.timeout = HTTP_CLIENT_TIMEOUT
+         self.max_retries = max_retries
          self.client: HttpClientT  # Will be set by subclasses
 
          logger.info(f"Client initialized with base_url: {self.base_url}")
@@ -67,39 +80,77 @@ class BaseClient(Generic[HttpClientT]):
          headers["x-api-key"] = self.api_key
          return headers
 
-     def _build_payload(
+     @staticmethod
+     def _log_trace_id(response) -> None:
+         """Log trace IDs from response headers for debugging."""
+         logger.error(f"Request Id: {response.headers.get('x-request-id', '')}")
+         logger.error(f"Trace Id: {response.headers.get('x-trace-id', '')}")
+ 
+     def _build_chat_completion_kwargs(
          self,
          model: str,
-         messages_history: list,
-         task_description: str | None = None,
-         task_id: str | None = None,
+         messages: list,
          temperature: float | None = None,
-     ) -> dict[str, Any]:
-         """Build OpenAI-compatible request payload.
+         task_id: str | None = None,
+     ) -> dict:
+         """Build kwargs dict for OpenAI chat completion call.
 
          Args:
-             model: Model to use
-             messages_history: OpenAI-compatible message history
-             task_description: Task description
-             task_id: Task ID for continuing session
-             temperature: Sampling temperature
+             model: Model to use for inference
+             messages: Full message history (OpenAI-compatible format)
+             temperature: Sampling temperature (0.0-2.0)
+             task_id: Optional task ID for multi-turn conversations
 
          Returns:
-             OpenAI-compatible request payload
+             Dict of kwargs for chat.completions.create()
          """
-         payload: dict[str, Any] = {
-             "model": model,
-             "messages": messages_history,
-         }
- 
-         if task_description is not None:
-             payload["task_description"] = task_description
-         if task_id is not None:
-             payload["task_id"] = task_id
+         kwargs: dict = {"model": model, "messages": messages}
          if temperature is not None:
-             payload["temperature"] = temperature
+             kwargs["temperature"] = temperature
+         if task_id is not None:
+             kwargs["extra_body"] = {"task_id": task_id}
+         return kwargs
+ 
+     def _parse_chat_completion_response(
+         self, response
+     ) -> tuple[Step, str, Usage | None]:
+         """Extract and parse OpenAI chat completion response, and log success.
+ 
+         This is sync/async agnostic as it only processes the response object.
+ 
+         Args:
+             response: OpenAI ChatCompletion response object
 
-         return payload
+         Returns:
+             Tuple of (Step, raw_output, Usage)
+         """
+         raw_output = response.choices[0].message.content or ""
+         step = parse_raw_output(raw_output)
+ 
+         # Extract task_id from response (custom field from OAGI API)
+         task_id = getattr(response, "task_id", None)
+ 
+         usage = None
+         if response.usage:
+             usage = Usage(
+                 prompt_tokens=response.usage.prompt_tokens,
+                 completion_tokens=response.usage.completion_tokens,
+                 total_tokens=response.usage.total_tokens,
+             )
+ 
+         # Log success with task_id and usage
+         usage_str = (
+             f", tokens: {usage.prompt_tokens}+{usage.completion_tokens}"
+             if usage
+             else ""
+         )
+         task_str = f"task_id: {task_id}, " if task_id else ""
+         logger.info(
+             f"Chat completion successful - {task_str}actions: {len(step.actions)}, "
+             f"stop: {step.stop}{usage_str}"
+         )
+ 
+         return step, raw_output, usage
 
      def _handle_response_error(
          self, response: httpx.Response, response_data: dict
@@ -141,84 +192,6 @@ class BaseClient(Generic[HttpClientT]):
 
          return status_map.get(status_code, APIError)
 
-     def _log_request_info(self, model: str, task_description: Any, task_id: Any):
-         logger.info(f"Making API request to /v2/message with model: {model}")
-         logger.debug(
-             f"Request includes task_description: {task_description is not None}, "
-             f"task_id: {task_id is not None}"
-         )
- 
-     def _build_user_message(
-         self, screenshot_url: str, instruction: str | None
-     ) -> dict[str, Any]:
-         """Build OpenAI-compatible user message with screenshot and optional instruction.
- 
-         Args:
-             screenshot_url: URL of uploaded screenshot
-             instruction: Optional text instruction
- 
-         Returns:
-             User message dict
-         """
-         content = [{"type": "image_url", "image_url": {"url": screenshot_url}}]
-         if instruction:
-             content.append({"type": "text", "text": instruction})
-         return {"role": "user", "content": content}
- 
-     def _prepare_message_payload(
-         self,
-         model: str,
-         upload_file_response: UploadFileResponse | None,
-         task_description: str | None,
-         task_id: str | None,
-         instruction: str | None,
-         messages_history: list | None,
-         temperature: float | None,
-         api_version: str | None,
-         screenshot_url: str | None = None,
-     ) -> tuple[dict[str, str], dict[str, Any]]:
-         """Prepare headers and payload for /v2/message request.
- 
-         Args:
-             model: Model to use
-             upload_file_response: Response from S3 upload (if screenshot was uploaded)
-             task_description: Task description
-             task_id: Task ID
-             instruction: Optional instruction
-             messages_history: Message history
-             temperature: Sampling temperature
-             api_version: API version
-             screenshot_url: Direct screenshot URL (alternative to upload_file_response)
- 
-         Returns:
-             Tuple of (headers, payload)
-         """
-         # Use provided screenshot_url or get from upload_file_response
-         if screenshot_url is None:
-             if upload_file_response is None:
-                 raise ValueError(
-                     "Either screenshot_url or upload_file_response must be provided"
-                 )
-             screenshot_url = upload_file_response.download_url
- 
-         # Build user message and append to history
-         if messages_history is None:
-             messages_history = []
-         user_message = self._build_user_message(screenshot_url, instruction)
-         messages_history.append(user_message)
- 
-         # Build payload and headers
-         headers = self._build_headers(api_version)
-         payload = self._build_payload(
-             model=model,
-             messages_history=messages_history,
-             task_description=task_description,
-             task_id=task_id,
-             temperature=temperature,
-         )
- 
-         return headers, payload
- 
      def _parse_response_json(self, response: httpx.Response) -> dict[str, Any]:
          try:
              return response.json()
@@ -230,35 +203,6 @@ class BaseClient(Generic[HttpClientT]):
              response=response,
          )
 
-     def _process_response(self, response: httpx.Response) -> "LLMResponse":
-         response_data = self._parse_response_json(response)
- 
-         # Check if it's an error response (non-200 status)
-         if response.status_code != 200:
-             self._handle_response_error(response, response_data)
- 
-         # Parse successful response
-         result = LLMResponse(**response_data)
- 
-         # Check if the response contains an error (even with 200 status)
-         if result.error:
-             logger.error(
-                 f"API Error in response: [{result.error.code}]: {result.error.message}"
-             )
-             raise APIError(
-                 result.error.message,
-                 code=result.error.code,
-                 status_code=200,
-                 response=response,
-             )
- 
-         logger.info(
-             f"API request successful - task_id: {result.task_id}, "
-             f"complete: {result.is_complete}"
-         )
-         logger.debug(f"Response included {len(result.actions)} actions")
-         return result
- 
      def _process_upload_response(self, response: httpx.Response) -> UploadFileResponse:
          """Process response from /v1/file/upload endpoint.
 
@@ -449,7 +393,11 @@ class BaseClient(Generic[HttpClientT]):
          # Parse successful response
          result = GenerateResponse(**response_data)
 
+         # Capture request_id from response header
+         result.request_id = response.headers.get("X-Request-ID")
+ 
          logger.info(
              f"Generate request successful - tokens: {result.prompt_tokens}+{result.completion_tokens}, "
+             f"request_id: {result.request_id}"
          )
          return result
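
One detail worth noting in base.py: _build_chat_completion_kwargs keeps the request OpenAI-compatible by tunnelling the OAGI-specific task_id through the openai SDK's extra_body parameter, which the SDK merges into the request JSON alongside the typed fields. The sketch below pokes a private helper purely to illustrate the shape it produces; it assumes OAGI_API_KEY is set, and "oagi-model" and "task-123" are hypothetical placeholders.

from oagi.client.async_ import AsyncClient

client = AsyncClient()  # reads OAGI_API_KEY / OAGI_BASE_URL from the environment

# Positional order matches the internal call sites:
# model, messages, temperature, task_id.
kwargs = client._build_chat_completion_kwargs(
    "oagi-model", [{"role": "user", "content": "Continue."}], 0.2, "task-123"
)
assert kwargs == {
    "model": "oagi-model",
    "messages": [{"role": "user", "content": "Continue."}],
    "temperature": 0.2,
    # Merged into the request body by the openai SDK, so the server
    # receives task_id alongside the standard chat-completion fields.
    "extra_body": {"task_id": "task-123"},
}

Because extra_body rides outside the typed parameters, the stock openai client passes task_id through without validation, and the server echoes it back as a custom response field that _parse_chat_completion_response reads via getattr(response, "task_id", None).
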