notte-sdk 0.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
notte_sdk/types.py ADDED
@@ -0,0 +1,851 @@
1
+ import datetime as dt
2
+ import json
3
+ from base64 import b64decode, b64encode
4
+ from collections.abc import Sequence
5
+ from enum import StrEnum
6
+ from pathlib import Path
7
+ from typing import Annotated, Any, Generic, Literal, Required, TypeVar
8
+
9
+ from notte_core.actions.base import Action, BrowserAction
10
+ from notte_core.browser.observation import Observation, TrajectoryProgress
11
+ from notte_core.browser.snapshot import SnapshotMetadata, TabsData
12
+ from notte_core.controller.actions import BaseAction
13
+ from notte_core.controller.space import BaseActionSpace
14
+ from notte_core.credentials.base import BaseVault, CredentialField, CredentialsDict
15
+ from notte_core.data.space import DataSpace
16
+ from notte_core.llms.engine import LlmModel
17
+ from notte_core.utils.pydantic_schema import create_model_from_schema
18
+ from notte_core.utils.url import get_root_domain
19
+ from pydantic import BaseModel, Field, field_validator, model_validator
20
+ from typing_extensions import TypedDict, override
21
+
22
+ # ############################################################
23
+ # Session Management
24
+ # ############################################################
25
+
26
+
27
# Default value of SessionStartRequest.timeout_minutes (minutes).
DEFAULT_OPERATION_SESSION_TIMEOUT_IN_MINUTES = 3
# Upper bound (`le=`) accepted for SessionStartRequest.timeout_minutes (minutes).
DEFAULT_GLOBAL_SESSION_TIMEOUT_IN_MINUTES = 30
# Default of PaginationParams.max_nb_actions.
DEFAULT_MAX_NB_ACTIONS = 100
# Default step budget (SessionStartRequest.max_steps, AgentCreateRequest.max_steps).
DEFAULT_MAX_NB_STEPS = 20
# Default `limit` for list-style requests (sessions, emails, SMS).
DEFAULT_LIMIT_LIST_ITEMS = 10
32
+
33
+
34
class ExecutionResponse(BaseModel):
    # Generic operation outcome: a success flag plus a human-readable message.
    success: bool = Field(description="Whether the operation was successful")
    message: str = Field(description="A message describing the operation")


class PlaywrightProxySettings(TypedDict, total=False):
    # Proxy settings in the dict shape produced by ProxySettings.to_playwright.
    server: str
    bypass: str | None
    username: str | None
    password: str | None


class BrowserType(StrEnum):
    # Browser engines a session can be started with.
    CHROMIUM = "chromium"
    CHROME = "chrome"
    FIREFOX = "firefox"


class ProxyGeolocation(BaseModel):
    """
    Geolocation settings for the proxy.
    E.g. "New York, NY, US"
    """

    city: str
    state: str
    country: str


class ProxyType(StrEnum):
    # Proxy origin; EXTERNAL proxies must provide a server (see ProxySettings).
    NOTTE = "notte"
    EXTERNAL = "external"
66
+
67
+
68
class ProxySettings(BaseModel):
    # Client-supplied proxy configuration. None of the fields has a default,
    # but everything except `type` may be None; `server` is mandatory for
    # EXTERNAL proxies.
    type: ProxyType
    server: str | None
    bypass: str | None
    username: str | None
    password: str | None
    # TODO: enable geolocation later on
    # geolocation: ProxyGeolocation | None

    @field_validator("server")
    @classmethod
    def validate_server(cls, v: str | None, info: Any) -> str | None:
        # `type` is declared before `server`, so it is already present in info.data.
        is_external = info.data.get("type") == ProxyType.EXTERNAL
        if v is None and is_external:
            raise ValueError("Server is required for external proxy type")
        return v

    def to_playwright(self) -> PlaywrightProxySettings:
        """Convert these settings to the dict shape expected by Playwright.

        Raises:
            ValueError: If no proxy server is configured.
        """
        server = self.server
        if server is None:
            raise ValueError("Proxy server is required")
        settings = PlaywrightProxySettings(
            server=server,
            bypass=self.bypass,
            username=self.username,
            password=self.password,
        )
        return settings
93
+
94
+
95
class Cookie(BaseModel):
    # A browser cookie. Either `expirationDate` or `expires` may be supplied;
    # the two are kept in sync.
    name: str
    domain: str
    path: str
    httpOnly: bool
    expirationDate: float | None = None
    hostOnly: bool | None = None
    sameSite: str | None = None
    secure: bool | None = None
    session: bool | None = None
    storeId: str | None = None
    value: str
    expires: float | None = Field(default=None)

    @model_validator(mode="before")
    @classmethod
    def validate_expiration(cls, data: Any) -> Any:
        """Mirror `expires` into `expirationDate` (or vice versa) when only one is given.

        Works on a shallow copy so the caller's dict is not mutated. Non-dict
        input (e.g. an existing Cookie instance) is passed through untouched and
        left for pydantic to handle.
        """
        if not isinstance(data, dict):
            return data
        data = dict(data)
        if data.get("expirationDate") is None and data.get("expires") is not None:
            data["expirationDate"] = float(data["expires"])
        elif data.get("expires") is None and data.get("expirationDate") is not None:
            data["expires"] = float(data["expirationDate"])
        return data

    @override
    def model_post_init(self, __context: Any) -> None:
        """Keep `expires`/`expirationDate` in sync and normalize `sameSite` casing."""
        # Set expires if expirationDate is provided but expires is not
        if self.expirationDate is not None and self.expires is None:
            self.expires = float(self.expirationDate)
        # Set expirationDate if expires is provided but expirationDate is not
        elif self.expires is not None and self.expirationDate is None:
            self.expirationDate = float(self.expires)

        if self.sameSite is not None:
            # e.g. "lax" / "LAX" -> "Lax". Slicing (instead of indexing [0])
            # keeps an empty string from raising IndexError.
            lowered = self.sameSite.lower()
            self.sameSite = lowered[:1].upper() + lowered[1:]

    @staticmethod
    def from_json(path: str | Path) -> list["Cookie"]:
        """Load cookies from a JSON file containing a list of cookie objects.

        Raises:
            FileNotFoundError: If `path` does not exist.
            ValueError: If the top-level JSON value is not a list, or a cookie is invalid.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Cookies file not found at {path}")
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            cookies_json = json.load(f)
        if not isinstance(cookies_json, list):
            raise ValueError(f"Cookies file {path} must contain a JSON list, got {type(cookies_json).__name__}")
        return [Cookie.model_validate(cookie) for cookie in cookies_json]
141
+
142
+
143
class UploadCookiesRequest(BaseModel):
    # Request body carrying cookies to upload.
    cookies: list[Cookie]

    @staticmethod
    def from_json(path: str | Path) -> "UploadCookiesRequest":
        """Build an upload request from a JSON cookie file (parsed by Cookie.from_json)."""
        return UploadCookiesRequest(cookies=Cookie.from_json(path))


class UploadCookiesResponse(BaseModel):
    # Outcome of a cookie upload.
    success: bool
    message: str
155
+
156
+
157
class ReplayResponse(BaseModel):
    # Carries an optional session replay as raw `.webp` bytes. The bytes are
    # base64-encoded when dumped and decoded from base64 strings on input.
    replay: Annotated[bytes | None, Field(description="The session replay in `.webp` format", repr=False)] = None

    model_config = {  # type: ignore[reportUnknownMemberType]
        "json_encoders": {
            bytes: lambda v: b64encode(v).decode("utf-8") if v else None,
        }
    }

    @field_validator("replay", mode="before")
    @classmethod
    def decode_replay(cls, value: str | None) -> bytes | None:
        """Accept raw bytes as-is, or decode a base64 string into bytes.

        Raises:
            ValueError: If the value is neither bytes, a string, nor None.
        """
        if value is None:
            return None
        if isinstance(value, bytes):
            return value
        if not isinstance(value, str):  # pyright: ignore[reportUnnecessaryIsInstance]
            raise ValueError("replay must be a bytes or a base64 encoded string")  # pyright: ignore[reportUnreachable]
        return b64decode(value.encode("utf-8"))

    @override
    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Dump the model with `replay` base64-encoded as a UTF-8 string.

        The field is only rewritten when it is present in the dump. Callers that
        drop it via `exclude`/`include` therefore no longer get it re-added.
        """
        data = super().model_dump(*args, **kwargs)
        if self.replay is not None and "replay" in data:
            data["replay"] = b64encode(self.replay).decode("utf-8")
        return data
183
+
184
+
185
# TypedDict counterpart of SessionStartRequest (same keys, all optional).
class SessionStartRequestDict(TypedDict, total=False):
    timeout_minutes: int
    max_steps: int
    proxies: list[ProxySettings] | bool
    browser_type: BrowserType
    chrome_args: list[str] | None


# TypedDict counterpart of SessionRequest.
class SessionRequestDict(TypedDict, total=False):
    session_id: str | None
195
+
196
+
197
class SessionStartRequest(BaseModel):
    # Options for starting a new browser session.
    timeout_minutes: Annotated[
        int,
        Field(
            description="Session timeout in minutes. Cannot exceed the global timeout.",
            gt=0,
            le=DEFAULT_GLOBAL_SESSION_TIMEOUT_IN_MINUTES,
        ),
    ] = DEFAULT_OPERATION_SESSION_TIMEOUT_IN_MINUTES

    max_steps: Annotated[
        int,
        Field(
            gt=0,
            description="Maximum number of steps in the trajectory. An error will be raised if this limit is reached.",
        ),
    ] = DEFAULT_MAX_NB_STEPS

    proxies: Annotated[
        list[ProxySettings] | bool,
        Field(
            description="List of custom proxies to use for the session. If True, the default proxies will be used.",
        ),
    ] = False
    browser_type: BrowserType = BrowserType.CHROMIUM
    chrome_args: Annotated[list[str] | None, Field(description="Override the chrome instance arguments")] = None

    @model_validator(mode="after")
    def validate_global_timeout(self) -> "SessionStartRequest":
        """Run the global-timeout check as part of pydantic validation.

        pydantic `BaseModel` never calls `__post_init__` (that hook belongs to
        dataclasses), so the check was previously dead code. The `le=` bound on
        `timeout_minutes` normally rejects oversized values first; this keeps the
        explicit check wired in as a safeguard if that bound is ever relaxed.
        """
        self.__post_init__()
        return self

    def __post_init__(self):
        """
        Validate that the session timeout does not exceed the allowed global limit.

        Not invoked automatically by pydantic; it is called from
        `validate_global_timeout` and kept under this name for compatibility.

        Raises:
            ValueError: If the session's timeout_minutes exceeds DEFAULT_GLOBAL_SESSION_TIMEOUT_IN_MINUTES.
        """
        if self.timeout_minutes > DEFAULT_GLOBAL_SESSION_TIMEOUT_IN_MINUTES:
            raise ValueError(
                (
                    "Session timeout cannot be greater than global timeout: "
                    f"{self.timeout_minutes} > {DEFAULT_GLOBAL_SESSION_TIMEOUT_IN_MINUTES}"
                )
            )
238
+
239
+
240
class SessionRequest(BaseModel):
    # Request that optionally targets an existing session by ID.
    session_id: str | None = Field(
        default=None,
        description="The ID of the session. A new session is created when not provided.",
    )


class SessionStatusRequest(BaseModel):
    # Request for a session's status, optionally including its video replay.
    # NOTE(review): the session_id description is shared with SessionRequest;
    # confirm "a new session is created" really applies to status lookups.
    session_id: str | None = Field(
        default=None,
        description="The ID of the session. A new session is created when not provided.",
    )

    replay: bool = Field(
        default=False,
        description="Whether to include the video replay in the response (`.webp` format).",
    )


class ListRequestDict(TypedDict, total=False):
    # TypedDict counterpart of SessionListRequest.
    only_active: bool
    limit: int


class SessionListRequest(BaseModel):
    # Filtering/paging options for listing sessions.
    only_active: bool = True
    limit: int = DEFAULT_LIMIT_LIST_ITEMS
267
+
268
+
269
class SessionResponse(BaseModel):
    # Public view of a browser session: identity, timing, status and configuration.
    session_id: Annotated[
        str,
        Field(
            description=(
                "The ID of the session (created or existing). "
                "Use this ID to interact with the session for the next operation."
            )
        ),
    ]
    timeout_minutes: Annotated[
        int,
        Field(description="Session timeout in minutes. Will timeout if now() > last access time + timeout_minutes"),
    ]
    created_at: Annotated[dt.datetime, Field(description="Session creation time")]
    closed_at: Annotated[dt.datetime | None, Field(description="Session closing time")] = None
    last_accessed_at: Annotated[dt.datetime, Field(description="Last access time")]
    duration: Annotated[dt.timedelta, Field(description="Session duration")] = Field(
        default_factory=lambda: dt.timedelta(0)
    )
    status: Annotated[
        Literal["active", "closed", "error", "timed_out"],
        Field(description="Session status"),
    ]
    # TODO: discuss if this is the best way to handle errors
    error: Annotated[str | None, Field(description="Error message if the operation failed to complete")] = None
    proxies: Annotated[
        bool,
        Field(
            description="Whether proxies were used for the session. True if any proxy was applied during session creation."
        ),
    ] = False
    browser_type: BrowserType = BrowserType.CHROMIUM

    @model_validator(mode="after")
    def validate_closed_at(self) -> "SessionResponse":
        """Require `closed_at` when the session status is "closed".

        This runs after all fields are validated. As a field validator on
        `closed_at` it could never see `status`, which is declared later, so the
        check silently never fired.

        Raises:
            ValueError: If status is "closed" and closed_at is None.
        """
        if self.status == "closed" and self.closed_at is None:
            raise ValueError("closed_at must be provided if status is closed")
        return self

    @field_validator("duration", mode="before")
    @classmethod
    def compute_duration(cls, value: dt.timedelta | None, info: Any) -> dt.timedelta:
        """Derive an explicitly-null `duration` from the session timestamps.

        This only runs when `duration` is supplied, because pydantic does not
        validate defaults. `status` is declared after `duration` and is not
        available here, so a known `closed_at` is used as the end time;
        otherwise the current time is used.
        """
        if value is not None:
            return value
        data = info.data
        created_at = data.get("created_at")
        if created_at is None:
            # created_at failed validation: let pydantic report that error
            # instead of surfacing a raw KeyError from here.
            return dt.timedelta(0)
        closed_at = data.get("closed_at")
        if closed_at is not None:
            return closed_at - created_at
        # Match created_at's tz-awareness so naive/aware subtraction cannot fail.
        return dt.datetime.now(tz=created_at.tzinfo) - created_at
320
+
321
+
322
# Session details combined with the optional `.webp` replay field inherited from ReplayResponse.
class SessionStatusResponse(SessionResponse, ReplayResponse):
    pass


# TypedDict counterpart of SessionResponse (all keys optional).
class SessionResponseDict(TypedDict, total=False):
    session_id: str
    timeout_minutes: int
    created_at: dt.datetime
    last_accessed_at: dt.datetime
    duration: dt.timedelta
    closed_at: dt.datetime | None
    status: Literal["active", "closed", "error", "timed_out"]
    error: str | None
    proxies: bool
    browser_type: BrowserType
337
+
338
+
339
+ # ############################################################
340
+ # Session debug endpoints
341
+ # ############################################################
342
+
343
+
344
class TabSessionDebugRequest(BaseModel):
    # Selects a browser tab (by index) to debug.
    tab_idx: int


class TabSessionDebugResponse(BaseModel):
    # Debug endpoints and metadata for a single tab.
    metadata: TabsData
    debug_url: str
    ws_url: str


class WebSocketUrls(BaseModel):
    # WebSocket endpoints exposed for a live session.
    cdp: Annotated[str, Field(description="WebSocket URL to connect using CDP protocol")]
    recording: Annotated[str, Field(description="WebSocket URL for live session recording (screenshot stream)")]
    # Description typo fixed ("obsveration" -> "observation"); it is user-facing schema text.
    logs: Annotated[str, Field(description="WebSocket URL for live logs (observation / actions events)")]


class SessionDebugResponse(BaseModel):
    # Session-level debug URL, WebSocket endpoints, and per-tab debug info.
    debug_url: str
    ws: WebSocketUrls
    tabs: list[TabSessionDebugResponse]
364
+
365
+
366
class SessionDebugRecordingEvent(BaseModel):
    """Model for events that can be sent over the recording WebSocket"""

    type: Literal["action", "observation", "error"]
    data: BaseAction | Observation | str
    # Defaults to the (naive, local) time at which the event object is created.
    timestamp: dt.datetime = Field(default_factory=dt.datetime.now)

    @staticmethod
    def session_closed() -> "SessionDebugRecordingEvent":
        """Build the error event announcing that the user closed the session."""
        closed_message = "Session closed by user. No more actions will be recorded."
        return SessionDebugRecordingEvent(type="error", data=closed_message)
379
+
380
+
381
+ # ############################################################
382
+ # Persona
383
+ # ############################################################
384
+
385
+
386
class EmailsReadRequestDict(TypedDict, total=False):
    # TypedDict counterpart of EmailsReadRequest.
    limit: int
    timedelta: dt.timedelta | None
    unread_only: bool


class EmailsReadRequest(BaseModel):
    # Filters applied when reading emails.
    limit: int = Field(default=DEFAULT_LIMIT_LIST_ITEMS, description="Max number of emails to return")
    timedelta: dt.timedelta | None = Field(
        default=None,
        description="Return only emails that are not older than <timedelta>",
    )
    unread_only: bool = Field(default=False, description="Return only previously unread emails")


class EmailResponse(BaseModel):
    # A single email. sender_email/sender_name are required keys but may be None.
    subject: str = Field(description="Subject of the email")
    email_id: str = Field(description="Email UUID")
    created_at: dt.datetime = Field(description="Creation date")
    sender_email: str | None = Field(description="Email address of the sender")
    sender_name: str | None = Field(description="Name (if available) of the sender")
    text_content: str | None = Field(
        default=None,
        description="Raw textual body, can be uncorrelated with html content",
    )
    html_content: str | None = Field(default=None, description="HTML body, can be uncorrelated with raw content")


class SMSReadRequestDict(TypedDict, total=False):
    # TypedDict counterpart of SMSReadRequest.
    limit: int
    timedelta: dt.timedelta | None
    unread_only: bool


class SMSReadRequest(BaseModel):
    # Filters applied when reading SMS messages.
    limit: int = Field(default=DEFAULT_LIMIT_LIST_ITEMS, description="Max number of messages to return")
    timedelta: dt.timedelta | None = Field(
        default=None,
        description="Return only messages that are not older than <timedelta>",
    )
    unread_only: bool = Field(default=False, description="Return only previously unread messages")


class SMSResponse(BaseModel):
    # A single SMS message; `sender` is a required key but may be None.
    body: str = Field(description="SMS message body")
    sms_id: str = Field(description="SMS UUID")
    created_at: dt.datetime = Field(description="Creation date")
    sender: str | None = Field(description="SMS sender phone number")
431
+
432
+
433
class PersonaCreateRequestDict(TypedDict, total=False):
    # No options are defined yet.
    pass


class PersonaCreateRequest(BaseModel):
    # Empty request body: persona creation currently takes no options.
    pass


class PersonaCreateResponse(BaseModel):
    persona_id: str = Field(description="ID of the created persona")


class VirtualNumberRequestDict(TypedDict, total=False):
    # No options are defined yet.
    pass


class VirtualNumberRequest(BaseModel):
    # Empty request body: virtual-number creation currently takes no options.
    pass


class VirtualNumberResponse(BaseModel):
    status: str = Field(description="Status of the created virtual number")
455
+
456
+
457
class AddCredentialsRequestDict(CredentialsDict, total=False):
    # Credential keys (inherited from CredentialsDict) plus the target URL.
    url: str | None


def validate_url(value: str | None) -> str | None:
    """Reduce `value` to its root domain via `get_root_domain`; None passes through.

    Raises:
        ValueError: If `get_root_domain` returns an empty domain for `value`.
    """
    if value is not None:
        root_domain = get_root_domain(value)
        if len(root_domain) == 0:
            raise ValueError(f"Invalid URL: {value}. Please provide a valid URL with a domain name.")
        return root_domain
    return None
468
+
469
+
470
class AddCredentialsRequest(BaseModel):
    # Credentials to store for a site; `url` is normalized to its root domain.
    url: str | None
    credentials: list[CredentialField] = Field(description="Credentials to add")

    @field_validator("url", mode="before")
    @classmethod
    def validate_url(cls, value: str | None) -> str | None:
        # Inside a method the bare name resolves to the module-level
        # validate_url helper, not to this classmethod.
        return validate_url(value)

    @staticmethod
    def load(body: dict[str, Any]) -> "AddCredentialsRequest":
        """Build a request from a raw body holding `url` (optional) and a `credentials` list."""
        target_url = body.get("url")
        fields = [CredentialField.from_dict(raw_field) for raw_field in body["credentials"]]
        return AddCredentialsRequest(url=target_url, credentials=fields)

    @classmethod
    def from_request_dict(cls, dic: AddCredentialsRequestDict):
        """Build a request from a flat dict of `url` plus credential keys.

        Raises:
            ValueError: If `url` is missing from `dic`.
        """
        if "url" not in dic:
            raise ValueError("Invalid credentials request dict")

        credentials_only = dic.copy()
        credentials_only.pop("url")
        fields = BaseVault.credentials_dict_to_field(credentials_only)

        return AddCredentialsRequest(url=dic["url"], credentials=fields)
495
+
496
+
497
class AddCredentialsResponse(BaseModel):
    status: str = Field(description="Status of the created credentials")


class GetCredentialsRequestDict(TypedDict, total=False):
    # TypedDict counterpart of GetCredentialsRequest.
    url: str | None


class GetCredentialsRequest(BaseModel):
    # Look up stored credentials; `url` is normalized to its root domain.
    url: str | None

    @field_validator("url", mode="before")
    @classmethod
    def validate_url(cls, value: str | None) -> str | None:
        # Delegates to the module-level validate_url helper.
        return validate_url(value)


class GetCredentialsResponse(BaseModel):
    credentials: list[CredentialField] = Field(description="Retrieved credentials")


class DeleteCredentialsRequestDict(TypedDict, total=False):
    # TypedDict counterpart of DeleteCredentialsRequest.
    url: str | None


class DeleteCredentialsRequest(BaseModel):
    # Delete stored credentials; `url` is normalized to its root domain.
    url: str | None

    @field_validator("url", mode="before")
    @classmethod
    def validate_url(cls, value: str | None) -> str | None:
        # Delegates to the module-level validate_url helper.
        return validate_url(value)


class DeleteCredentialsResponse(BaseModel):
    status: str = Field(description="Status of the deletion")
533
+
534
+
535
+ # ############################################################
536
+ # Environment endpoints
537
+ # ############################################################
538
+
539
+
540
class PaginationParamsDict(TypedDict, total=False):
    # TypedDict counterpart of PaginationParams.
    min_nb_actions: int | None
    max_nb_actions: int


class PaginationParams(BaseModel):
    # Bounds on how many actions are listed.
    min_nb_actions: int | None = Field(
        default=None,
        description=(
            "The minimum number of actions to list before stopping. "
            "If not provided, the listing will continue until the maximum number of actions is reached."
        ),
    )
    max_nb_actions: int = Field(
        default=DEFAULT_MAX_NB_ACTIONS,
        description=(
            "The maximum number of actions to list after which the listing will stop. "
            "Used when min_nb_actions is not provided."
        ),
    )


class ObserveRequest(PaginationParams):
    url: str | None = Field(
        default=None,
        description="The URL to observe. If not provided, uses the current page URL.",
    )


class ObserveRequestDict(PaginationParamsDict, total=False):
    # TypedDict counterpart of ObserveRequest.
    url: str | None
575
+
576
+
577
# TypedDict counterpart of ScrapeParams (all keys optional).
class ScrapeParamsDict(TypedDict, total=False):
    scrape_links: bool
    only_main_content: bool
    response_format: type[BaseModel] | None
    instructions: str | None
    use_llm: bool | None


# TypedDict counterpart of ScrapeRequest: ScrapeParamsDict plus the target URL.
class ScrapeRequestDict(ScrapeParamsDict, total=False):
    url: str | None
587
+
588
+
589
class ScrapeParams(BaseModel):
    # Options controlling what a scrape returns and whether structured extraction is requested.
    scrape_links: Annotated[
        bool,
        Field(description="Whether to scrape links from the page. Links are scraped by default."),
    ] = True

    only_main_content: Annotated[
        bool,
        Field(
            description=(
                "Whether to only scrape the main content of the page. If True, navbars, footers, etc. are excluded."
            ),
        ),
    ] = True

    response_format: Annotated[
        type[BaseModel] | None,
        Field(description="The response format to use for the scrape."),
    ] = None
    instructions: Annotated[str | None, Field(description="The instructions to use for the scrape.")] = None

    use_llm: Annotated[
        bool | None,
        Field(
            description=(
                "Whether to use an LLM for the extraction process. This will result in a longer response time but a"
                " better accuracy. If not provided, the default value is the same as the NotteSession config."
            )
        ),
    ] = None

    def requires_schema(self) -> bool:
        """Return True when structured extraction was requested (a response format or instructions)."""
        return self.response_format is not None or self.instructions is not None

    @field_validator("response_format", mode="before")
    @classmethod
    def convert_response_format(cls, value: dict[str, Any] | type[BaseModel] | None) -> type[BaseModel] | None:
        """
        Normalize `response_format` into a Pydantic model class.

        Args:
            value: None, a `BaseModel` subclass, or a JSON Schema dict
                (e.g. the output of `model_json_schema()` sent over the wire).

        Returns:
            The class unchanged, a model built from the schema with
            `create_model_from_schema`, or None for None or an empty dict.

        Raises:
            ValueError: If `value` is neither a `BaseModel` subclass nor a dict.
        """
        if value is None:
            return None
        if isinstance(value, type) and issubclass(value, BaseModel):  # type: ignore[arg-type]
            return value
        if not isinstance(value, dict):  # type: ignore[arg-type]
            raise ValueError(f"response_format must be a BaseModel or a dict but got: {type(value)} : {value}")  # type: ignore[unreachable]
        if len(value.keys()) == 0:
            return None

        return create_model_from_schema(value)

    @override
    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Dump the params, replacing the `response_format` class with its JSON Schema.

        The key is only rewritten when present in the dump, so an explicit
        `exclude`/`include` from the caller is honored.
        """
        dump = super().model_dump(*args, **kwargs)
        if (
            "response_format" in dump
            and isinstance(self.response_format, type)
            and issubclass(self.response_format, BaseModel)  # pyright: ignore[reportUnnecessaryIsInstance]
        ):
            dump["response_format"] = self.response_format.model_json_schema()
        return dump

    @override
    def model_dump_json(self, *args: Any, **kwargs: Any) -> str:
        """Serialize to a JSON string via `model_dump`.

        `indent` is a JSON-only option that `model_dump` does not accept, so it
        is forwarded to `json.dumps` instead of causing a TypeError.
        """
        indent = kwargs.pop("indent", None)
        return json.dumps(self.model_dump(*args, **kwargs), indent=indent)
660
+
661
+
662
class ScrapeRequest(ScrapeParams):
    url: str | None = Field(
        default=None,
        description="The URL to scrape. If not provided, uses the current page URL.",
    )


class StepRequest(PaginationParams):
    # Execute one action (by ID), optionally typing a value and pressing enter.
    action_id: str = Field(description="The ID of the action to execute")

    value: str | None = Field(default=None, description="The value to input for form actions")

    enter: bool | None = Field(default=None, description="Whether to press enter after inputting the value")


class StepRequestDict(PaginationParamsDict, total=False):
    # TypedDict counterpart of StepRequest.
    action_id: str
    value: str | None
    enter: bool | None
684
+
685
+
686
class ActionSpaceResponse(BaseModel):
    # API view of an action space: page actions, browser actions and descriptive text.
    markdown: str | None = Field(default=None, description="Markdown representation of the action space")
    actions: Sequence[Action] = Field(description="List of available actions in the current state")
    browser_actions: Sequence[BrowserAction] = Field(description="List of special actions, i.e browser actions")
    # TODO: ActionSpaceResponse should be a subclass of ActionSpace
    description: str
    category: str | None = None

    @staticmethod
    def from_space(space: BaseActionSpace) -> "ActionSpaceResponse":
        """Snapshot `space` into its API representation."""
        fields: dict[str, Any] = {
            "markdown": space.markdown(),
            "description": space.description,
            "category": space.category,
            "actions": space.actions(),
            "browser_actions": space.browser_actions(),
        }
        return ActionSpaceResponse(**fields)


class ScrapeResponse(BaseModel):
    session: SessionResponse = Field(description="Browser session information")
    data: DataSpace = Field(description="Data extracted from the current page")
714
+
715
+
716
class ObserveResponse(BaseModel):
    # Result of an observation: session info, available actions, page metadata,
    # optional screenshot/data, and trajectory progress.
    session: Annotated[SessionResponse, Field(description="Browser session information")]
    space: Annotated[
        ActionSpaceResponse,
        Field(description="Available actions in the current state"),
    ]
    metadata: SnapshotMetadata
    # NOTE(review): unlike ReplayResponse.replay, there is no validator decoding a
    # base64 string back into bytes here; confirm how clients read this field.
    screenshot: bytes | None = Field(repr=False)
    data: DataSpace | None
    progress: TrajectoryProgress | None

    model_config = {  # type: ignore[attr-defined]
        "json_encoders": {
            bytes: lambda v: b64encode(v).decode("utf-8") if v else None,
        }
    }

    @override
    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Dump the model with `screenshot` base64-encoded as a UTF-8 string.

        The field is only rewritten when present in the dump, so callers that
        exclude it (e.g. `exclude={"screenshot"}`) no longer get it re-added.
        """
        data = super().model_dump(*args, **kwargs)
        if self.screenshot is not None and "screenshot" in data:
            data["screenshot"] = b64encode(self.screenshot).decode("utf-8")
        return data

    @staticmethod
    def from_obs(
        obs: Observation,
        session: SessionResponse,
    ) -> "ObserveResponse":
        """Build the API response from an `Observation` and the session it belongs to."""
        return ObserveResponse(
            session=session,
            metadata=obs.metadata,
            screenshot=obs.screenshot,
            data=obs.data,
            space=ActionSpaceResponse.from_space(obs.space),
            progress=obs.progress,
        )
753
+
754
+
755
+ # ############################################################
756
+ # Agent endpoints
757
+ # ############################################################
758
+
759
+
760
class AgentStatus(StrEnum):
    # Lifecycle state reported for an agent.
    active = "active"
    closed = "closed"


class AgentSessionRequest(BaseModel):
    # Request addressing an existing agent by ID.
    agent_id: str = Field(description="The ID of the agent to run")


class AgentCreateRequestDict(SessionRequestDict, total=False):
    # TypedDict counterpart of AgentCreateRequest (session_id comes from SessionRequestDict).
    reasoning_model: LlmModel
    use_vision: bool
    max_steps: int
    persona_id: str | None
    vault_id: str | None


class AgentRunRequestDict(TypedDict, total=False):
    # TypedDict counterpart of AgentRunRequest; `task` is the only required key.
    task: Required[str]
    url: str | None


class AgentStartRequestDict(AgentCreateRequestDict, AgentRunRequestDict, total=False):
    # Combination of the create and run keys, mirroring AgentStartRequest.
    pass
784
+
785
+
786
class AgentCreateRequest(SessionRequest):
    # Agent configuration; the optional session_id is inherited from SessionRequest.
    reasoning_model: LlmModel = Field(default=LlmModel.default(), description="The reasoning model to use")
    use_vision: bool = Field(
        default=True,
        description="Whether to use vision for the agent. Not all reasoning models support vision.",
    )
    max_steps: int = Field(
        default=DEFAULT_MAX_NB_STEPS,
        description="The maximum number of steps the agent should take",
    )
    persona_id: str | None = Field(default=None, description="The persona to use for the agent")
    vault_id: str | None = Field(default=None, description="The vault to use for the agent")


class AgentRunRequest(BaseModel):
    # The task for the agent plus an optional starting URL.
    task: str = Field(description="The task that the agent should perform")
    url: str | None = Field(default=None, description="The URL that the agent should start on (optional)")


class AgentStartRequest(AgentCreateRequest, AgentRunRequest):
    # Create-and-run in a single request: the fields of both parents.
    pass
805
+
806
+
807
class AgentStatusRequestDict(TypedDict, total=False):
    # TypedDict counterpart of AgentStatusRequest; agent_id is required.
    agent_id: Required[Annotated[str, Field(description="The ID of the agent for which to get the status")]]
    replay: bool


class AgentStatusRequest(AgentSessionRequest):
    replay: bool = Field(default=False, description="Whether to include the replay in the response")


class AgentListRequest(SessionListRequest):
    # Same filters as listing sessions.
    pass


class AgentStopRequest(AgentSessionRequest, ReplayResponse):
    # Stop an agent, recording its outcome (and an optional replay via ReplayResponse).
    success: bool = Field(default=False, description="Whether the agent task was successful")
    answer: str = Field(default="Agent manually stopped by user", description="The answer to the agent task")


class AgentResponse(BaseModel):
    # Core agent state returned by the API.
    agent_id: str = Field(description="The ID of the agent")
    created_at: dt.datetime = Field(description="The creation time of the agent")
    session_id: str = Field(description="The ID of the session")
    status: AgentStatus = Field(description="The status of the agent (active or closed)")
    closed_at: dt.datetime | None = Field(default=None, description="The closing time of the agent")
831
+
832
+
833
# Type of the per-step records carried by AgentStatusResponse.steps.
TStepOutput = TypeVar("TStepOutput", bound=BaseModel)


class AgentStatusResponse(AgentResponse, ReplayResponse, Generic[TStepOutput]):
    # Agent state plus its task, outcome (None while running) and the steps taken so far.
    task: str = Field(description="The task that the agent is currently running")
    url: str | None = Field(default=None, description="The URL that the agent started on")

    success: bool | None = Field(
        default=None,
        description="Whether the agent task was successful. None if the agent is still running",
    )
    answer: str | None = Field(
        default=None,
        description="The answer to the agent task. None if the agent is still running",
    )
    steps: list[TStepOutput] = Field(
        default_factory=list,
        description="The steps that the agent has currently taken",
    )