oagi-core 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. oagi/__init__.py +108 -0
  2. oagi/agent/__init__.py +31 -0
  3. oagi/agent/default.py +75 -0
  4. oagi/agent/factories.py +50 -0
  5. oagi/agent/protocol.py +55 -0
  6. oagi/agent/registry.py +155 -0
  7. oagi/agent/tasker/__init__.py +35 -0
  8. oagi/agent/tasker/memory.py +184 -0
  9. oagi/agent/tasker/models.py +83 -0
  10. oagi/agent/tasker/planner.py +385 -0
  11. oagi/agent/tasker/taskee_agent.py +395 -0
  12. oagi/agent/tasker/tasker_agent.py +323 -0
  13. oagi/async_pyautogui_action_handler.py +44 -0
  14. oagi/async_screenshot_maker.py +47 -0
  15. oagi/async_single_step.py +85 -0
  16. oagi/cli/__init__.py +11 -0
  17. oagi/cli/agent.py +125 -0
  18. oagi/cli/main.py +77 -0
  19. oagi/cli/server.py +94 -0
  20. oagi/cli/utils.py +82 -0
  21. oagi/client/__init__.py +12 -0
  22. oagi/client/async_.py +293 -0
  23. oagi/client/base.py +465 -0
  24. oagi/client/sync.py +296 -0
  25. oagi/exceptions.py +118 -0
  26. oagi/logging.py +47 -0
  27. oagi/pil_image.py +102 -0
  28. oagi/pyautogui_action_handler.py +268 -0
  29. oagi/screenshot_maker.py +41 -0
  30. oagi/server/__init__.py +13 -0
  31. oagi/server/agent_wrappers.py +98 -0
  32. oagi/server/config.py +46 -0
  33. oagi/server/main.py +157 -0
  34. oagi/server/models.py +98 -0
  35. oagi/server/session_store.py +116 -0
  36. oagi/server/socketio_server.py +405 -0
  37. oagi/single_step.py +87 -0
  38. oagi/task/__init__.py +14 -0
  39. oagi/task/async_.py +97 -0
  40. oagi/task/async_short.py +64 -0
  41. oagi/task/base.py +121 -0
  42. oagi/task/short.py +64 -0
  43. oagi/task/sync.py +97 -0
  44. oagi/types/__init__.py +28 -0
  45. oagi/types/action_handler.py +30 -0
  46. oagi/types/async_action_handler.py +30 -0
  47. oagi/types/async_image_provider.py +37 -0
  48. oagi/types/image.py +17 -0
  49. oagi/types/image_provider.py +34 -0
  50. oagi/types/models/__init__.py +32 -0
  51. oagi/types/models/action.py +33 -0
  52. oagi/types/models/client.py +64 -0
  53. oagi/types/models/image_config.py +47 -0
  54. oagi/types/models/step.py +17 -0
  55. oagi/types/url_image.py +47 -0
  56. oagi_core-0.9.0.dist-info/METADATA +257 -0
  57. oagi_core-0.9.0.dist-info/RECORD +60 -0
  58. oagi_core-0.9.0.dist-info/WHEEL +4 -0
  59. oagi_core-0.9.0.dist-info/entry_points.txt +2 -0
  60. oagi_core-0.9.0.dist-info/licenses/LICENSE +21 -0
oagi/server/models.py ADDED
@@ -0,0 +1,98 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) OpenAGI Foundation
3
+ # All rights reserved.
4
+ #
5
+ # This file is part of the official API project.
6
+ # Licensed under the MIT License.
7
+ # -----------------------------------------------------------------------------
8
+
9
+ from typing import Literal
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+
14
+ # Client-to-server events
15
class InitEventData(BaseModel):
    """Payload of the client's ``init`` event.

    Only ``instruction`` is required. The other fields have defaults and also
    accept an explicit null.
    """

    instruction: str
    mode: str | None = Field("actor")
    model: str | None = Field("lux-v1")
    # Sampling temperature, validated to the closed range [0.0, 2.0].
    temperature: float | None = Field(0.1, ge=0.0, le=2.0)
20
+
21
+
22
+ # Server-to-client events
23
class BaseActionEventData(BaseModel):
    """Fields shared by every server-to-client action event."""

    # Position of this action in its batch (non-negative).
    index: int = Field(ge=0)
    # Number of actions in the batch (at least one).
    total: int = Field(ge=1)
26
+
27
+
28
class ClickEventData(BaseActionEventData):
    """Single-point pointer action.

    The server uses this payload for click, double/triple click and right
    click. Both coordinates are integers bounded to 0..1000.
    """

    x: int = Field(ge=0, le=1000)
    y: int = Field(ge=0, le=1000)
31
+
32
+
33
class DragEventData(BaseActionEventData):
    """Drag between the points (x1, y1) and (x2, y2), each bounded to 0..1000."""

    x1: int = Field(ge=0, le=1000)
    y1: int = Field(ge=0, le=1000)
    x2: int = Field(ge=0, le=1000)
    y2: int = Field(ge=0, le=1000)
38
+
39
+
40
class HotkeyEventData(BaseActionEventData):
    """Key-combination event.

    ``combo`` is the combination string. ``count`` is a repetition count of
    at least 1 and defaults to 1.
    """

    combo: str
    count: int = Field(1, ge=1)
43
+
44
+
45
class TypeEventData(BaseActionEventData):
    """Text-entry event carrying the text to type."""

    text: str
47
+
48
+
49
class ScrollEventData(BaseActionEventData):
    """Vertical scroll at (x, y).

    Coordinates are bounded to 0..1000. Only ``"up"`` and ``"down"`` are
    accepted as directions, and ``count`` must be at least 1.
    """

    x: int = Field(ge=0, le=1000)
    y: int = Field(ge=0, le=1000)
    direction: Literal["up", "down"]
    count: int = Field(1, ge=1)
54
+
55
+
56
class WaitEventData(BaseActionEventData):
    """Pause event. ``duration_ms`` is non-negative and defaults to 1000."""

    duration_ms: int = Field(1000, ge=0)
58
+
59
+
60
class FinishEventData(BaseActionEventData):
    """Completion event; carries only the inherited ``index``/``total`` fields."""
62
+
63
+
64
+ # Screenshot request/response
65
class ScreenshotRequestData(BaseModel):
    """Screenshot request payload; all three fields are required strings.

    NOTE(review): presumably the client uploads a screenshot to
    ``presigned_url``. ``expires_at`` is presumably a timestamp string.
    Confirm both against the client implementation.
    """

    presigned_url: str
    uuid: str
    expires_at: str
69
+
70
+
71
class ScreenshotResponseData(BaseModel):
    """Client reply to a screenshot request, with an optional error message."""

    success: bool
    error: str | None = None
74
+
75
+
76
+ # Action acknowledgement
77
class ActionAckData(BaseModel):
    """Client acknowledgement for one action.

    ``error`` and ``execution_time_ms`` are optional.
    """

    action_index: int
    success: bool
    error: str | None = None
    execution_time_ms: int | None = None
82
+
83
+
84
+ # Session status
85
class SessionStatusData(BaseModel):
    """Snapshot of a session's state.

    ``status`` also admits ``"cancelled"``. The Socket.IO server assigns that
    value to ``Session.status`` when an agent run is cancelled, so a snapshot
    built from a live session must be able to represent it.
    """

    session_id: str = Field(...)
    status: Literal["initialized", "running", "completed", "failed", "cancelled"] = (
        Field(...)
    )
    instruction: str = Field(...)
    created_at: str = Field(...)
    actions_executed: int = Field(default=0)
    # NOTE(review): Session.last_activity is a float timestamp, while this
    # field is a string. Callers presumably format it; confirm.
    last_activity: str = Field(...)
92
+
93
+
94
+ # Error event
95
class ErrorEventData(BaseModel):
    """Error event sent to the client.

    ``code`` and ``details`` are optional extras alongside ``message``.
    """

    message: str
    code: str | None = None
    details: dict | None = None
oagi/server/session_store.py ADDED
@@ -0,0 +1,116 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) OpenAGI Foundation
3
+ # All rights reserved.
4
+ #
5
+ # This file is part of the official API project.
6
+ # Licensed under the MIT License.
7
+ # -----------------------------------------------------------------------------
8
+
9
+ import secrets
10
+ from datetime import datetime
11
+ from typing import Any
12
+ from uuid import uuid4
13
+
14
+
15
+ class Session:
16
+ def __init__(
17
+ self,
18
+ session_id: str,
19
+ instruction: str,
20
+ mode: str = "actor",
21
+ model: str = "lux-v1",
22
+ temperature: float = 0.0,
23
+ ):
24
+ self.session_id: str = session_id
25
+ self.instruction: str = instruction
26
+ self.mode: str = mode
27
+ self.model: str = model
28
+ self.temperature: float = temperature
29
+
30
+ # OAGI task state
31
+ self.task_id: str = uuid4().hex
32
+ self.message_history: list[dict[str, Any]] = []
33
+ self.current_screenshot_url: str | None = None
34
+
35
+ # Socket state
36
+ self.socket_id: str | None = None
37
+ self.namespace: str | None = None
38
+ self.last_activity: float = datetime.now().timestamp()
39
+
40
+ # Status tracking
41
+ self.status: str = "initialized"
42
+ self.created_at: str = datetime.now().isoformat()
43
+ self.actions_executed: int = 0
44
+
45
+ # OAGI client reference
46
+ self.oagi_client: Any | None = None
47
+
48
+
49
class SessionStore:
    """Process-local registry of ``Session`` objects keyed by session id."""

    def __init__(self):
        self.sessions: dict[str, Session] = {}

    def create_session(
        self,
        instruction: str,
        mode: str = "actor",
        model: str = "lux-v1",
        temperature: float = 0.0,
        session_id: str | None = None,
    ) -> str:
        """Create and register a session, returning its id.

        Without an explicit ``session_id``, a random ``ses_``-prefixed
        URL-safe id is generated. A session already stored under the same id
        is replaced.
        """
        new_id = f"ses_{secrets.token_urlsafe(16)}" if session_id is None else session_id
        self.sessions[new_id] = Session(new_id, instruction, mode, model, temperature)
        return new_id

    def get_session(self, session_id: str) -> Session | None:
        """Return the session stored under ``session_id``, or ``None``."""
        return self.sessions.get(session_id)

    def get_session_by_socket_id(self, socket_id: str) -> Session | None:
        """Return the first session bound to ``socket_id`` (linear scan), or ``None``."""
        return next(
            (candidate for candidate in self.sessions.values() if candidate.socket_id == socket_id),
            None,
        )

    def delete_session(self, session_id: str) -> bool:
        """Remove a session; report whether it existed."""
        try:
            del self.sessions[session_id]
        except KeyError:
            return False
        return True

    def update_activity(self, session_id: str) -> None:
        """Stamp the session's ``last_activity`` with the current time, if it exists."""
        if session := self.sessions.get(session_id):
            session.last_activity = datetime.now().timestamp()

    def list_sessions(self) -> list[dict[str, Any]]:
        """Summarize every stored session as a plain dict."""

        def summarize(item: Session) -> dict[str, Any]:
            return {
                "session_id": item.session_id,
                "status": item.status,
                "instruction": item.instruction,
                "created_at": item.created_at,
                "actions_executed": item.actions_executed,
                "connected": item.socket_id is not None,
            }

        return [summarize(item) for item in self.sessions.values()]

    def cleanup_inactive_sessions(self, timeout_seconds: float) -> int:
        """Delete sessions idle for longer than ``timeout_seconds``.

        Returns the number of sessions removed. Stale ids are collected first,
        so the dict is not mutated while it is being iterated.
        """
        now = datetime.now().timestamp()
        expired = [
            key
            for key, item in self.sessions.items()
            if now - item.last_activity > timeout_seconds
        ]
        for key in expired:
            self.delete_session(key)
        return len(expired)


# Module-wide store; the Socket.IO server imports this instance.
session_store = SessionStore()
oagi/server/socketio_server.py ADDED
@@ -0,0 +1,405 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) OpenAGI Foundation
3
+ # All rights reserved.
4
+ #
5
+ # This file is part of the official API project.
6
+ # Licensed under the MIT License.
7
+ # -----------------------------------------------------------------------------
8
+
9
+ import asyncio
10
+ import logging
11
+ from datetime import datetime
12
+ from typing import Any
13
+
14
+ from pydantic import ValidationError
15
+
16
+ from ..agent import AsyncDefaultAgent, create_agent
17
+ from ..client import AsyncClient
18
+ from ..exceptions import check_optional_dependency
19
+ from ..types.models.action import Action, ActionType
20
+ from .agent_wrappers import SocketIOActionHandler, SocketIOImageProvider
21
+ from .config import ServerConfig
22
+ from .models import (
23
+ BaseActionEventData,
24
+ ClickEventData,
25
+ DragEventData,
26
+ ErrorEventData,
27
+ FinishEventData,
28
+ HotkeyEventData,
29
+ InitEventData,
30
+ ScrollEventData,
31
+ TypeEventData,
32
+ WaitEventData,
33
+ )
34
+ from .session_store import Session, session_store
35
+
36
+ check_optional_dependency("socketio", "Server features", "server")
37
+ import socketio # noqa: E402
38
+
39
logger = logging.getLogger(__name__)

# Shared Socket.IO server, wrapped by socketio.ASGIApp at the bottom of this
# module. Socket.IO and Engine.IO internal logging are both switched off.
# NOTE(review): cors_allowed_origins="*" accepts any origin; consider making
# this configurable for non-local deployments.
sio = socketio.AsyncServer(
    engineio_logger=False,
    logger=False,
    cors_allowed_origins="*",
    async_mode="asgi",
)
47
+
48
+
49
+ class SessionNamespace(socketio.AsyncNamespace):
50
+ def __init__(self, namespace: str, config: ServerConfig):
51
+ super().__init__(namespace)
52
+ self.config = config
53
+ self.background_tasks: dict[str, asyncio.Task] = {}
54
+
55
+ async def on_connect(self, sid: str, environ: dict, auth: dict | None) -> bool:
56
+ session_id = self.namespace.split("/")[-1]
57
+ logger.info(f"Client {sid} connected to session {session_id}")
58
+
59
+ session = session_store.get_session(session_id)
60
+ if session:
61
+ session.socket_id = sid
62
+ session.namespace = self.namespace
63
+ session_store.update_activity(session_id)
64
+
65
+ # Create OAGI client if not exists
66
+ if not session.oagi_client:
67
+ session.oagi_client = AsyncClient(
68
+ base_url=self.config.oagi_base_url,
69
+ api_key=self.config.oagi_api_key,
70
+ )
71
+ else:
72
+ logger.warning(f"Connection to non-existent session {session_id}")
73
+ # Create session on connect if it doesn't exist
74
+ session = Session(
75
+ session_id=session_id,
76
+ instruction="",
77
+ mode="actor", # Default mode
78
+ model=self.config.default_model,
79
+ temperature=self.config.default_temperature,
80
+ )
81
+ session.socket_id = sid
82
+ session.namespace = self.namespace
83
+ session.oagi_client = AsyncClient(
84
+ base_url=self.config.oagi_base_url,
85
+ api_key=self.config.oagi_api_key,
86
+ )
87
+ session_store.sessions[session_id] = session
88
+
89
+ return True
90
+
91
+ async def on_disconnect(self, sid: str) -> None:
92
+ session_id = self.namespace.split("/")[-1]
93
+ logger.info(f"Client {sid} disconnected from session {session_id}")
94
+
95
+ # Cancel any background tasks
96
+ if sid in self.background_tasks:
97
+ self.background_tasks[sid].cancel()
98
+ del self.background_tasks[sid]
99
+
100
+ # Start cleanup task
101
+ asyncio.create_task(self._cleanup_after_timeout(session_id))
102
+
103
+ async def _cleanup_after_timeout(self, session_id: str) -> None:
104
+ await asyncio.sleep(self.config.session_timeout_seconds)
105
+
106
+ session = session_store.get_session(session_id)
107
+ if session:
108
+ current_time = datetime.now().timestamp()
109
+ if (
110
+ current_time - session.last_activity
111
+ >= self.config.session_timeout_seconds
112
+ ):
113
+ logger.info(f"Session {session_id} timed out, cleaning up")
114
+
115
+ # Close OAGI client
116
+ if session.oagi_client:
117
+ await session.oagi_client.close()
118
+
119
+ session_store.delete_session(session_id)
120
+
121
+ async def on_init(self, sid: str, data: dict) -> None:
122
+ try:
123
+ session_id = self.namespace.split("/")[-1]
124
+ logger.info(f"Initializing session {session_id}")
125
+
126
+ # Validate input
127
+ event_data = InitEventData(**data)
128
+
129
+ # Get or create session
130
+ session = session_store.get_session(session_id)
131
+ if not session:
132
+ logger.error(f"Session {session_id} not found")
133
+ await self.emit(
134
+ "error",
135
+ ErrorEventData(
136
+ message=f"Session {session_id} not found"
137
+ ).model_dump(),
138
+ room=sid,
139
+ )
140
+ return
141
+
142
+ # Update session with init data
143
+ session.instruction = event_data.instruction
144
+ if event_data.mode:
145
+ session.mode = event_data.mode
146
+ if event_data.model:
147
+ session.model = event_data.model
148
+ if event_data.temperature is not None:
149
+ session.temperature = event_data.temperature
150
+ session.status = "running"
151
+ session_store.update_activity(session_id)
152
+
153
+ logger.info(
154
+ f"Session {session_id} initialized with: {event_data.instruction} "
155
+ f"(mode={event_data.mode}, model={event_data.model})"
156
+ )
157
+
158
+ # Create agent and wrappers
159
+ agent = create_agent(
160
+ mode=session.mode,
161
+ api_key=self.config.oagi_api_key,
162
+ base_url=self.config.oagi_base_url,
163
+ max_steps=self.config.max_steps,
164
+ model=event_data.model,
165
+ temperature=event_data.temperature,
166
+ )
167
+
168
+ action_handler = SocketIOActionHandler(self, session)
169
+ image_provider = SocketIOImageProvider(self, session, session.oagi_client)
170
+
171
+ # Start execution in background using agent
172
+ task = asyncio.create_task(
173
+ self._run_agent_task(
174
+ agent,
175
+ session,
176
+ action_handler,
177
+ image_provider,
178
+ event_data.instruction,
179
+ )
180
+ )
181
+ self.background_tasks[sid] = task
182
+
183
+ except ValidationError as e:
184
+ logger.error(f"Invalid init data: {e}")
185
+ await self.emit(
186
+ "error",
187
+ ErrorEventData(
188
+ message="Invalid init data",
189
+ details={"validation_errors": e.errors()},
190
+ ).model_dump(),
191
+ room=sid,
192
+ )
193
+ except Exception as e:
194
+ logger.error(f"Error in init: {e}", exc_info=True)
195
+ await self.emit(
196
+ "error",
197
+ ErrorEventData(message=str(e)).model_dump(),
198
+ room=sid,
199
+ )
200
+
201
+ async def _run_agent_task(
202
+ self,
203
+ agent: AsyncDefaultAgent,
204
+ session: Session,
205
+ action_handler: SocketIOActionHandler,
206
+ image_provider: SocketIOImageProvider,
207
+ instruction: str,
208
+ ) -> None:
209
+ try:
210
+ # Execute task using agent
211
+ success = await agent.execute(
212
+ instruction=instruction,
213
+ action_handler=action_handler,
214
+ image_provider=image_provider,
215
+ )
216
+
217
+ # Update session status
218
+ if success:
219
+ session.status = "completed"
220
+ logger.info(
221
+ f"Task completed successfully for session {session.session_id}"
222
+ )
223
+
224
+ # Emit finish event
225
+ await self.call(
226
+ "finish",
227
+ FinishEventData(action_index=0, total_actions=1).model_dump(),
228
+ to=session.socket_id,
229
+ timeout=self.config.socketio_timeout,
230
+ )
231
+ else:
232
+ session.status = "failed"
233
+ logger.warning(f"Task failed for session {session.session_id}")
234
+
235
+ session_store.update_activity(session.session_id)
236
+
237
+ except asyncio.CancelledError:
238
+ logger.info(f"Agent task cancelled for session {session.session_id}")
239
+ session.status = "cancelled"
240
+ except Exception as e:
241
+ logger.error(f"Error in agent task: {e}", exc_info=True)
242
+ session.status = "failed"
243
+ if session.socket_id:
244
+ await self.emit(
245
+ "error",
246
+ ErrorEventData(message=f"Execution failed: {str(e)}").model_dump(),
247
+ room=session.socket_id,
248
+ )
249
+
250
+ async def _emit_actions(self, session: Session, actions: list[Action]) -> None:
251
+ total = len(actions)
252
+
253
+ for i, action in enumerate(actions):
254
+ try:
255
+ ack = await self._emit_single_action(session, action, i, total)
256
+ session.actions_executed += 1
257
+
258
+ if ack and not ack.get("success"):
259
+ logger.warning(f"Action {i} failed: {ack.get('error')}")
260
+
261
+ except Exception as e:
262
+ logger.error(f"Error emitting action {i}: {e}", exc_info=True)
263
+
264
+ async def _emit_single_action(
265
+ self, session: Session, action: Action, index: int, total: int
266
+ ) -> dict | None:
267
+ arg = action.argument.strip("()")
268
+ common = BaseActionEventData(index=index, total=total).model_dump()
269
+
270
+ logger.info(f"Emitting action {index + 1}/{total}: {action.type.value} {arg}")
271
+ match action.type:
272
+ case (
273
+ ActionType.CLICK
274
+ | ActionType.LEFT_DOUBLE
275
+ | ActionType.LEFT_TRIPLE
276
+ | ActionType.RIGHT_SINGLE
277
+ ):
278
+ coords = arg.split(",")
279
+ if len(coords) >= 2:
280
+ x, y = int(coords[0]), int(coords[1])
281
+ else:
282
+ logger.warning(f"Invalid action coordinates: {arg}")
283
+ return None
284
+
285
+ return await self.call(
286
+ action.type.value,
287
+ ClickEventData(**common, x=x, y=y).model_dump(),
288
+ to=session.socket_id,
289
+ timeout=self.config.socketio_timeout,
290
+ )
291
+
292
+ case ActionType.DRAG:
293
+ coords = arg.split(",")
294
+ if len(coords) >= 4:
295
+ x1, y1, x2, y2 = (int(coords[i]) for i in range(4))
296
+ else:
297
+ logger.warning(f"Invalid drag coordinates: {arg}")
298
+ return None
299
+
300
+ return await self.call(
301
+ "drag",
302
+ DragEventData(**common, x1=x1, y1=y1, x2=x2, y2=y2).model_dump(),
303
+ to=session.socket_id,
304
+ timeout=self.config.socketio_timeout,
305
+ )
306
+
307
+ case ActionType.HOTKEY:
308
+ combo = arg.strip()
309
+ count = action.count or 1
310
+
311
+ return await self.call(
312
+ "hotkey",
313
+ HotkeyEventData(**common, combo=combo, count=count).model_dump(),
314
+ to=session.socket_id,
315
+ timeout=self.config.socketio_timeout,
316
+ )
317
+
318
+ case ActionType.TYPE:
319
+ text = arg.strip()
320
+
321
+ return await self.call(
322
+ "type",
323
+ TypeEventData(**common, text=text).model_dump(),
324
+ to=session.socket_id,
325
+ timeout=self.config.socketio_timeout,
326
+ )
327
+
328
+ case ActionType.SCROLL:
329
+ parts = arg.split(",")
330
+ if len(parts) >= 3:
331
+ x, y = int(parts[0]), int(parts[1])
332
+ direction = parts[2].strip().lower()
333
+ else:
334
+ logger.warning(f"Invalid scroll coordinates: {arg}")
335
+ return None
336
+
337
+ count = action.count or 1
338
+
339
+ return await self.call(
340
+ "scroll",
341
+ ScrollEventData(
342
+ **common,
343
+ x=x,
344
+ y=y,
345
+ direction=direction,
346
+ count=count, # type: ignore
347
+ ).model_dump(),
348
+ to=session.socket_id,
349
+ timeout=self.config.socketio_timeout,
350
+ )
351
+
352
+ case ActionType.WAIT:
353
+ try:
354
+ duration_ms = int(arg) if arg else 1000
355
+ except (ValueError, TypeError):
356
+ duration_ms = 1000
357
+
358
+ return await self.call(
359
+ "wait",
360
+ WaitEventData(**common, duration_ms=duration_ms).model_dump(),
361
+ to=session.socket_id,
362
+ timeout=self.config.socketio_timeout,
363
+ )
364
+
365
+ case ActionType.FINISH:
366
+ return await self.call(
367
+ "finish",
368
+ FinishEventData(**common).model_dump(),
369
+ to=session.socket_id,
370
+ timeout=self.config.socketio_timeout,
371
+ )
372
+
373
+ case _:
374
+ logger.warning(f"Unknown action type: {action.type}")
375
+ return None
376
+
377
+
378
# Namespace handlers created on demand, keyed by namespace path.
_registered_namespaces: dict[str, SessionNamespace] = {}


def get_or_create_namespace(namespace: str, config: ServerConfig) -> SessionNamespace:
    """Return the handler for ``namespace``, registering a new one on first use.

    ``config`` is only consulted when the handler is created; later calls
    return the cached instance.
    """
    existing = _registered_namespaces.get(namespace)
    if existing is not None:
        return existing

    handler = SessionNamespace(namespace, config)
    sio.register_namespace(handler)
    _registered_namespaces[namespace] = handler
    logger.info(f"Registered namespace: {namespace}")
    return handler
389
+
390
+
391
# HACK: wrap the server's private ``_handle_connect`` so that a
# SessionNamespace is registered for any ``/session/...`` path just before the
# connect is processed. NOTE(review): presumably needed because python-socketio
# rejects connects to unregistered namespaces. This relies on a private API and
# may break on library upgrades.
original_connect = sio._handle_connect


async def _patched_handle_connect(eio_sid: str, namespace: str, data: Any) -> Any:
    """Lazily register ``/session/<id>`` namespaces, then delegate to the original."""
    targets_session = bool(namespace) and namespace.startswith("/session/")
    if targets_session:
        get_or_create_namespace(namespace, ServerConfig())
    return await original_connect(eio_sid, namespace, data)


sio._handle_connect = _patched_handle_connect

# Create ASGI app
socket_app = socketio.ASGIApp(sio, socketio_path="socket.io")
oagi/single_step.py ADDED
@@ -0,0 +1,87 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Copyright (c) OpenAGI Foundation
3
+ # All rights reserved.
4
+ #
5
+ # This file is part of the official API project.
6
+ # Licensed under the MIT License.
7
+ # -----------------------------------------------------------------------------
8
+
9
+ from pathlib import Path
10
+
11
+ from .task import Task
12
+ from .types import Image, Step
13
+
14
+
15
def single_step(
    task_description: str,
    screenshot: str | bytes | Path | Image,
    instruction: str | None = None,
    api_key: str | None = None,
    base_url: str | None = None,
    temperature: float | None = None,
) -> Step:
    """
    Perform a single-step inference without maintaining task state.

    This is useful for one-off analyses where you don't need to maintain
    a conversation or task context across multiple steps.

    Args:
        task_description: Description of the task to perform
        screenshot: Screenshot as an Image, raw bytes, or a file path (str or Path)
        instruction: Optional additional instruction for the task
        api_key: OAGI API key (uses environment variable if not provided)
        base_url: OAGI base URL (uses environment variable if not provided)
        temperature: Sampling temperature (0.0-2.0) for LLM inference

    Returns:
        Step: Object containing reasoning, actions, and completion status

    Raises:
        FileNotFoundError: If ``screenshot`` is a path that does not exist.
        ValueError: If ``screenshot`` is not an Image, bytes, str, or Path.

    Example:
        >>> # Using with bytes
        >>> with open("screenshot.png", "rb") as f:
        ...     image_bytes = f.read()
        >>> step = single_step(
        ...     task_description="Click the submit button",
        ...     screenshot=image_bytes
        ... )

        >>> # Using with file path
        >>> step = single_step(
        ...     task_description="Fill in the form",
        ...     screenshot=Path("screenshot.png"),
        ...     instruction="Use test@example.com for email"
        ... )

        >>> # Using with Image object
        >>> from oagi.types import Image
        >>> image = Image(...)
        >>> step = single_step(
        ...     task_description="Navigate to settings",
        ...     screenshot=image
        ... )
    """
    if isinstance(screenshot, (str, Path)):
        path = Path(screenshot)
        if not path.exists():
            raise FileNotFoundError(f"Screenshot file not found: {path}")
        # Lazy import: PILImage is only needed to load an image file, so
        # callers passing bytes or an Image never import it.
        from .pil_image import PILImage  # noqa: PLC0415

        screenshot_bytes = PILImage.from_file(str(path)).read()
    elif isinstance(screenshot, bytes):
        screenshot_bytes = screenshot
    elif isinstance(screenshot, Image):
        screenshot_bytes = screenshot.read()
    else:
        raise ValueError(
            f"screenshot must be Image, bytes, str, or Path, got {type(screenshot)}"
        )

    # Use Task to perform single step
    with Task(api_key=api_key, base_url=base_url, temperature=temperature) as task:
        task.init_task(task_description)
        return task.step(screenshot_bytes, instruction=instruction)