clerk-sdk 0.1.9__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clerk/base.py +94 -0
- clerk/client.py +3 -104
- clerk/decorator/models.py +1 -0
- clerk/decorator/task_decorator.py +1 -0
- clerk/gui_automation/__init__.py +0 -0
- clerk/gui_automation/action_model/__init__.py +0 -0
- clerk/gui_automation/action_model/model.py +126 -0
- clerk/gui_automation/action_model/utils.py +26 -0
- clerk/gui_automation/client.py +144 -0
- clerk/gui_automation/client_actor/__init__.py +4 -0
- clerk/gui_automation/client_actor/client_actor.py +178 -0
- clerk/gui_automation/client_actor/exception.py +22 -0
- clerk/gui_automation/client_actor/model.py +192 -0
- clerk/gui_automation/decorators/__init__.py +1 -0
- clerk/gui_automation/decorators/gui_automation.py +109 -0
- clerk/gui_automation/exceptions/__init__.py +0 -0
- clerk/gui_automation/exceptions/modality/__init__.py +0 -0
- clerk/gui_automation/exceptions/modality/exc.py +46 -0
- clerk/gui_automation/exceptions/websocket.py +6 -0
- clerk/gui_automation/ui_actions/__init__.py +1 -0
- clerk/gui_automation/ui_actions/actions.py +781 -0
- clerk/gui_automation/ui_actions/base.py +200 -0
- clerk/gui_automation/ui_actions/support.py +68 -0
- clerk/gui_automation/ui_state_inspector/__init__.py +0 -0
- clerk/gui_automation/ui_state_inspector/gui_vision.py +184 -0
- clerk/gui_automation/ui_state_inspector/models.py +184 -0
- clerk/gui_automation/ui_state_machine/__init__.py +11 -0
- clerk/gui_automation/ui_state_machine/ai_recovery.py +110 -0
- clerk/gui_automation/ui_state_machine/decorators.py +71 -0
- clerk/gui_automation/ui_state_machine/exceptions.py +42 -0
- clerk/gui_automation/ui_state_machine/models.py +40 -0
- clerk/gui_automation/ui_state_machine/state_machine.py +838 -0
- clerk/models/remote_device.py +7 -0
- clerk/utils/__init__.py +0 -0
- clerk/utils/logger.py +118 -0
- clerk/utils/save_artifact.py +35 -0
- {clerk_sdk-0.1.9.dist-info → clerk_sdk-0.2.0.dist-info}/METADATA +11 -1
- clerk_sdk-0.2.0.dist-info/RECORD +48 -0
- clerk_sdk-0.1.9.dist-info/RECORD +0 -15
- {clerk_sdk-0.1.9.dist-info → clerk_sdk-0.2.0.dist-info}/WHEEL +0 -0
- {clerk_sdk-0.1.9.dist-info → clerk_sdk-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {clerk_sdk-0.1.9.dist-info → clerk_sdk-0.2.0.dist-info}/top_level.txt +0 -0
clerk/base.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import requests
|
|
4
|
+
import backoff
|
|
5
|
+
from typing import Dict, Optional, Self
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, model_validator, Field
|
|
9
|
+
|
|
10
|
+
from .models.response_model import StandardResponse
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def giveup_handler(e):
    """Tell ``backoff`` whether to stop retrying after exception *e*.

    Only an ``HTTPError`` that carries a response whose status code is below
    500 is treated as permanent. Every other failure (connection errors,
    timeouts, 5xx responses, HTTP errors without a response) stays retryable.

    Args:
        e: The exception raised by the decorated call.

    Returns:
        bool: ``True`` to give up, ``False`` to keep retrying.
    """
    if not isinstance(e, requests.exceptions.HTTPError):
        return False
    response = e.response
    if response is None:
        return False
    return response.status_code < 500
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseClerk(BaseModel):
    """Shared HTTP plumbing for the Clerk API clients.

    On construction the API key is taken from ``api_key``. If that is empty,
    it comes from the ``CLERK_API_KEY`` environment variable. A bearer
    ``Authorization`` header is then merged into ``headers``.

    ``get_request`` and ``post_request`` send to
    ``base_url + (root_endpoint or "") + endpoint``. They retry failures with
    exponential backoff (at most 3 attempts; HTTP responses below 500 are not
    retried) and wrap the JSON response in a :class:`StandardResponse`.

    Attributes:
        api_key: Clerk API key; falls back to ``CLERK_API_KEY``.
        headers: Default headers sent with every request.
        base_url: API origin; defaults to ``CLERK_BASE_URL`` or
            ``https://api.clerk-app.com``.
        root_endpoint: Optional path prefix placed between ``base_url`` and
            each request's endpoint (subclasses set this).
        request_timeout: Seconds to wait for the server, forwarded to
            ``requests`` as ``timeout``. ``None`` (the default) waits
            indefinitely, which matches the previous behaviour.
    """

    api_key: Optional[str] = Field(default=None, min_length=1)
    headers: Dict[str, str] = Field(default_factory=dict)
    base_url: str = Field(
        default_factory=lambda: os.getenv("CLERK_BASE_URL", "https://api.clerk-app.com")
    )
    root_endpoint: Optional[str] = None
    request_timeout: Optional[float] = None

    @model_validator(mode="after")
    def validate_api_key(self) -> Self:
        """Resolve the API key and install the ``Authorization`` header.

        Raises:
            ValueError: If no key was given and ``CLERK_API_KEY`` is unset or
                empty.
        """
        if not self.api_key:
            self.api_key = os.getenv("CLERK_API_KEY")

        if not self.api_key:
            raise ValueError("API key has not been provided.")

        # Merge instead of replacing so that headers supplied by the caller at
        # construction time are kept; Authorization always reflects the key.
        self.headers = {**self.headers, "Authorization": f"Bearer {self.api_key}"}
        return self

    def _build_url(self, endpoint: str) -> str:
        """Return the absolute URL for *endpoint*, honouring ``root_endpoint``."""
        return f"{self.base_url}{self.root_endpoint or ''}{endpoint}"

    def _send_request(
        self,
        send,
        endpoint: str,
        headers: Optional[Dict[str, str]],
        json: Optional[Dict],
        params: Optional[Dict],
    ) -> StandardResponse:
        """Issue a single request using *send* (``requests.get``/``requests.post``).

        Per-call *headers* override the instance defaults. ``None`` for *json*
        or *params* is normalised to ``{}``, so the request on the wire is the
        same as with the historical ``{}`` defaults.

        Raises:
            requests.exceptions.HTTPError: For 4xx/5xx responses.
        """
        merged_headers = {**self.headers, **(headers or {})}
        response = send(
            self._build_url(endpoint),
            headers=merged_headers,
            json={} if json is None else json,
            params={} if params is None else params,
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        return StandardResponse(**response.json())

    @backoff.on_exception(
        backoff.expo,
        (requests.exceptions.RequestException,),
        max_tries=3,
        jitter=None,
        giveup=giveup_handler,
    )
    def get_request(
        self,
        endpoint: str,
        headers: Optional[Dict[str, str]] = None,
        json: Optional[Dict] = None,
        params: Optional[Dict] = None,
    ) -> StandardResponse:
        """Send a GET request to *endpoint* and parse the standard response.

        Args:
            endpoint: Path appended to ``base_url`` (and ``root_endpoint``).
            headers: Extra headers overriding the instance defaults.
            json: JSON body; ``None`` sends ``{}`` as before.
            params: Query-string parameters.

        Raises:
            requests.exceptions.RequestException: When retries are exhausted,
                or immediately for HTTP errors with status below 500.
        """
        return self._send_request(requests.get, endpoint, headers, json, params)

    @backoff.on_exception(
        backoff.expo,
        (requests.exceptions.RequestException,),
        max_tries=3,
        jitter=None,
        giveup=giveup_handler,
    )
    def post_request(
        self,
        endpoint: str,
        headers: Optional[Dict[str, str]] = None,
        json: Optional[Dict] = None,
        params: Optional[Dict] = None,
    ) -> StandardResponse:
        """Send a POST request to *endpoint* and parse the standard response.

        Args:
            endpoint: Path appended to ``base_url`` (and ``root_endpoint``).
            headers: Extra headers overriding the instance defaults.
            json: JSON body; ``None`` sends ``{}`` as before.
            params: Query-string parameters.

        Raises:
            requests.exceptions.RequestException: When retries are exhausted,
                or immediately for HTTP errors with status below 500.
        """
        return self._send_request(requests.post, endpoint, headers, json, params)
|
clerk/client.py
CHANGED
|
@@ -1,112 +1,11 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
# import logging
|
|
4
|
-
import requests
|
|
5
|
-
import backoff
|
|
6
|
-
from typing import Dict, List, Optional, Self
|
|
1
|
+
from typing import List
|
|
7
2
|
from xml.dom.minidom import Document
|
|
8
3
|
|
|
9
|
-
|
|
10
|
-
from pydantic import BaseModel, model_validator, Field
|
|
11
|
-
|
|
4
|
+
from clerk.base import BaseClerk
|
|
12
5
|
from .models.file import ParsedFile
|
|
13
|
-
from .models.response_model import StandardResponse
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
# logger = logging.getLogger(__name__)
|
|
17
|
-
# logger.setLevel(logging.INFO)
|
|
18
|
-
|
|
19
|
-
# if not logger.handlers:
|
|
20
|
-
# handler = logging.StreamHandler()
|
|
21
|
-
# formatter = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
|
|
22
|
-
# handler.setFormatter(formatter)
|
|
23
|
-
# logger.addHandler(handler)
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
# def backoff_handler(details):
|
|
27
|
-
# logger.warning(
|
|
28
|
-
# f"Retrying {details['target'].__name__} after {details['tries']} tries..."
|
|
29
|
-
# )
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def giveup_handler(e):
|
|
33
|
-
return (
|
|
34
|
-
isinstance(e, requests.exceptions.HTTPError)
|
|
35
|
-
and e.response is not None
|
|
36
|
-
and e.response.status_code < 500
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class Clerk(BaseModel):
|
|
41
|
-
api_key: Optional[str] = Field(default=None, min_length=1)
|
|
42
|
-
headers: Dict[str, str] = Field(default_factory=dict)
|
|
43
|
-
base_url: str = Field(
|
|
44
|
-
default_factory=lambda: os.getenv("CLERK_BASE_URL", "https://api.clerk-app.com")
|
|
45
|
-
)
|
|
46
|
-
|
|
47
|
-
@model_validator(mode="after")
|
|
48
|
-
def validate_api_key(self) -> Self:
|
|
49
|
-
if not self.api_key:
|
|
50
|
-
self.api_key = os.getenv("CLERK_API_KEY")
|
|
51
|
-
|
|
52
|
-
if not self.api_key:
|
|
53
|
-
raise ValueError("API key has not been provided.")
|
|
54
|
-
|
|
55
|
-
self.headers = {"Authorization": f"Bearer {self.api_key}"}
|
|
56
|
-
return self
|
|
57
|
-
|
|
58
|
-
@backoff.on_exception(
|
|
59
|
-
backoff.expo,
|
|
60
|
-
(requests.exceptions.RequestException,),
|
|
61
|
-
max_tries=3,
|
|
62
|
-
jitter=None,
|
|
63
|
-
# on_backoff=backoff_handler,
|
|
64
|
-
giveup=giveup_handler,
|
|
65
|
-
)
|
|
66
|
-
def get_request(
|
|
67
|
-
self,
|
|
68
|
-
endpoint: str,
|
|
69
|
-
headers: Dict[str, str] = {},
|
|
70
|
-
json: Dict = {},
|
|
71
|
-
params: Dict = {},
|
|
72
|
-
) -> StandardResponse:
|
|
73
|
-
|
|
74
|
-
merged_headers = {**self.headers, **headers}
|
|
75
|
-
url = f"{self.base_url}{endpoint}"
|
|
76
|
-
|
|
77
|
-
# logger.info(f"GET {url} | params={params}")
|
|
78
|
-
|
|
79
|
-
response = requests.get(url, headers=merged_headers, json=json, params=params)
|
|
80
|
-
response.raise_for_status()
|
|
81
|
-
|
|
82
|
-
return StandardResponse(**response.json())
|
|
83
|
-
|
|
84
|
-
@backoff.on_exception(
|
|
85
|
-
backoff.expo,
|
|
86
|
-
(requests.exceptions.RequestException,),
|
|
87
|
-
max_tries=3,
|
|
88
|
-
jitter=None,
|
|
89
|
-
# on_backoff=backoff_handler,
|
|
90
|
-
giveup=giveup_handler,
|
|
91
|
-
)
|
|
92
|
-
def post_request(
|
|
93
|
-
self,
|
|
94
|
-
endpoint: str,
|
|
95
|
-
headers: Dict[str, str] = {},
|
|
96
|
-
json: Dict = {},
|
|
97
|
-
params: Dict = {},
|
|
98
|
-
) -> StandardResponse:
|
|
99
|
-
|
|
100
|
-
merged_headers = {**self.headers, **headers}
|
|
101
|
-
url = f"{self.base_url}{endpoint}"
|
|
102
|
-
|
|
103
|
-
# logger.info(f"POST {url} | body={json} | params={params}")
|
|
104
|
-
|
|
105
|
-
response = requests.post(url, headers=merged_headers, json=json, params=params)
|
|
106
|
-
response.raise_for_status()
|
|
107
6
|
|
|
108
|
-
return StandardResponse(**response.json())
|
|
109
7
|
|
|
8
|
+
class Clerk(BaseClerk):
|
|
110
9
|
def get_document(self, document_id: str) -> Document:
|
|
111
10
|
endpoint = f"/document/{document_id}"
|
|
112
11
|
res = self.get_request(endpoint=endpoint)
|
clerk/decorator/models.py
CHANGED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import os
|
|
3
|
+
from typing import List, Literal, Optional, Union
|
|
4
|
+
from pydantic import BaseModel, Field, validator
|
|
5
|
+
|
|
6
|
+
# Coordinates as a flat list of numbers (floats or ints). Length and ordering
# are not enforced by this alias -- NOTE(review): presumably one of the
# PredictionsFormat layouts; confirm.
CoordsType = Union[List[float], List[int]]

# Allowed coordinate-format names. Not referenced elsewhere in this module.
# NOTE(review): by common convention "xyxy"/"xywh" are corner vs. corner+size
# boxes and the "n" suffix means normalised -- not defined here, confirm.
PredictionsFormat = Union[
    Literal["xyxy"], Literal["xyxyn"], Literal["xywh"], Literal["xywhn"]
]

# Allowed values for an anchor's relation to its target; "" is the empty
# (unspecified) relation and is the default Anchor uses.
RelationsType = Union[
    Literal["above"], Literal["below"], Literal["left"], Literal["right"], Literal[""]
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ImageB64(BaseModel):
    """
    An image carried as a base64-encoded string.

    Attributes:
        id (Optional[str]): Identifier of the image. ``from_path`` sets it to
            the file's base name. Defaults to None.
        value (str): The base64-encoded image content. Defaults to "".

    Methods:
        from_path(value: Union[str, "ImageB64"]) -> "ImageB64":
            Build an instance from a file path, or return an existing
            ImageB64 unchanged. Encoding is delegated to the module-level
            ``to_b64`` function.
    """

    id: Optional[str] = None
    value: str = ""

    @classmethod
    def from_path(cls, value: Union[str, "ImageB64"]) -> "ImageB64":
        """Create an image from a file path.

        Args:
            value: Path to an image file, or an ``ImageB64`` instance, which is
                returned unchanged.

        Returns:
            ImageB64: An instance of ``cls`` whose ``id`` is the file's base
            name and whose ``value`` is the file content in base64.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        if isinstance(value, ImageB64):
            return value
        # Instantiate ``cls`` (not ImageB64) so subclasses get their own type.
        return cls(
            id=os.path.basename(value),
            value=to_b64(value),
        )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def to_b64(path: str) -> str:
    """Return the bytes of the file at *path* encoded as base64 text.

    Args:
        path: Location of the file to encode.

    Returns:
        str: The base64 representation, decoded to a UTF-8 string.
    """
    with open(path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Anchor(BaseModel):
    """
    An anchor attached to a screenshot request (see ``Screenshot.anchors``).

    Attributes:
        value (Union[str, ImageB64]): The anchor itself, given either as a
            string or as an ImageB64 image. Defaults to "".
        relation (RelationsType): The relation of the anchor to the target:
            "above", "below", "left", "right", or "" (the default).
    """

    value: Union[str, ImageB64] = Field(default="")
    relation: RelationsType = Field(default="")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class Screenshot(BaseModel):
    """
    A screenshot together with the target to locate on it.

    Attributes:
        screen_b64 (ImageB64): The screenshot, base64-encoded.
        target (Union[str, ImageB64]): The target to look for, given either as
            a string or as an ImageB64 image.
        anchors (List[Anchor]): Optional anchors accompanying the target.
            Defaults to an empty list.
        is_awaited (bool): Flag signalling whether the target should appear
            immediately or is awaited. Defaults to False.
        target_name (Optional[str]): Optional human-readable name for the
            target. No validator in this class sets it, so it stays None unless
            a caller provides it. NOTE(review): earlier docs said it is filled
            in automatically during target validation and used by the action
            model for logging -- confirm where (if anywhere) that happens.
    """

    screen_b64: ImageB64
    target: Union[str, ImageB64]
    anchors: List[Anchor] = []
    is_awaited: bool = False
    target_name: Optional[str] = None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class Coords(BaseModel):
    """
    A set of coordinates together with an integer score.

    Attributes:
        value (CoordsType): The coordinates, as a list of floats or a list of
            integers.
        score (int): Score attached to the coordinates. Defaults to 0.
    """

    value: CoordsType
    score: int = Field(default=0)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class RouterOutput(BaseModel):
    """
    The result returned by a router.

    Field names are PascalCase -- NOTE(review): presumably to mirror the keys
    of an external payload; confirm before renaming.

    Attributes:
        Resources (List[Coords]): Coordinates of the resources. Defaults to an
            empty list.
        StatusMessage (Union[Literal["Success"], Literal["Failure"], None]):
            Outcome status, or None (the default) when not reported.
        ErrorMessage (str): Error description. Defaults to "".
    """

    Resources: List[Coords] = Field(default=[])
    StatusMessage: Union[Literal["Success"], Literal["Failure"], None] = Field(
        default=None
    )
    ErrorMessage: str = Field(default="")
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from .model import Coords, Screenshot
|
|
2
|
+
from ..decorators.gui_automation import clerk_client
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def get_coordinates(payload: Screenshot) -> Coords:
    """
    Resolve the coordinates of a screenshot target via the action model.

    The payload is serialised with ``model_dump()`` and passed to the shared
    ``clerk_client`` (imported from ``..decorators.gui_automation``). Its
    ``get_coordinates`` method performs the request and builds the result;
    this function reads no environment variables itself.

    Parameters:
        payload (Screenshot): The screenshot, target and optional anchors.

    Returns:
        Coords: The coordinates returned by ``clerk_client.get_coordinates``.

    Raises:
        Exception: Whatever ``clerk_client.get_coordinates`` raises.
            NOTE(review): if ``clerk_client`` is an ``RPAClerk``, this includes
            ``requests`` errors for failed HTTP calls and ``RuntimeError``
            when the response contains no coordinates -- confirm.

    Example:
        payload = Screenshot(
            screen_b64=ImageB64(value="<base64-encoded image>"),
            target="target_image",
        )
        coordinates = get_coordinates(payload)
    """

    return clerk_client.get_coordinates(payload.model_dump())
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from typing import Dict, List, Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
from clerk.base import BaseClerk
|
|
5
|
+
from clerk.gui_automation.action_model.model import Coords
|
|
6
|
+
from clerk.gui_automation.ui_state_inspector.models import (
|
|
7
|
+
ActionString,
|
|
8
|
+
BaseState,
|
|
9
|
+
States,
|
|
10
|
+
TargetWithAnchor,
|
|
11
|
+
)
|
|
12
|
+
from clerk.models.remote_device import RemoteDevice
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class RPAClerk(BaseClerk):
    """Client for the ``/gui_automation`` API.

    Covers remote-device allocation/deallocation and action-model coordinate
    lookup. All requests go through ``BaseClerk.post_request``.
    """

    root_endpoint: str = "/gui_automation"

    def allocate_remote_device(self, group_name: str, run_id: str) -> RemoteDevice:
        """Allocate a remote device from *group_name* for run *run_id*.

        Returns:
            RemoteDevice: Built from the first element of the response ``data``.
        """
        endpoint = "/remote_device/allocate"
        res = self.post_request(
            endpoint=endpoint, json={"group_name": group_name, "run_id": run_id}
        )

        return RemoteDevice(**res.data[0])

    def deallocate_remote_device(
        self,
        remote_device: RemoteDevice,
        run_id: str,
    ) -> None:
        """Release *remote_device* previously allocated for run *run_id*.

        The device's ``id`` and ``name`` are sent along with *run_id*; the
        response body is ignored.
        """
        endpoint = "/remote_device/deallocate"
        self.post_request(
            endpoint=endpoint,
            json={"id": remote_device.id, "name": remote_device.name, "run_id": run_id},
        )

    def get_coordinates(self, payload: Dict) -> Coords:
        """Ask the action model for the coordinates described by *payload*.

        Args:
            payload: Request body, e.g. a dumped ``Screenshot``.

        Returns:
            Coords: Built from the first element of the response ``data``.

        Raises:
            RuntimeError: If the response carries no coordinates.
        """
        endpoint = "/action_model/get_coordinates"
        res = self.post_request(endpoint=endpoint, json=payload)
        # Treat an empty ``data`` list like an explicit null so callers get the
        # documented RuntimeError instead of an IndexError.
        if not res.data or res.data[0] is None:
            raise RuntimeError("No coordinates found in the response.")
        return Coords(**res.data[0])
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class GUIVisionClerk(BaseClerk):
    """Client for the ``/gui_automation/vision`` API.

    Every call posts a base64 screenshot (``screen_b64``) and a ``use_ocr``
    flag plus call-specific fields, then builds its result from the first
    element of the response ``data``.
    """

    root_endpoint: str = "/gui_automation/vision"

    def _post_screen(self, endpoint: str, screen_b64: str, use_ocr: bool, **fields):
        """POST the screenshot, OCR flag and *fields* to *endpoint*.

        The JSON keys are ordered ``screen_b64``, ``use_ocr``, then *fields* in
        the order given.

        Returns:
            The first element of the response ``data``.
        """
        res = self.post_request(
            endpoint=endpoint,
            json={"screen_b64": screen_b64, "use_ocr": use_ocr, **fields},
        )
        return res.data[0]

    def find_target(
        self, screen_b64: str, use_ocr: bool, target_prompt: str
    ) -> TargetWithAnchor:
        """Locate the element described by *target_prompt* on the screenshot."""
        data = self._post_screen(
            "/find_target", screen_b64, use_ocr, target_prompt=target_prompt
        )
        return TargetWithAnchor(**data)

    def verify_state(
        self, screen_b64: str, use_ocr: bool, possible_states: States
    ) -> BaseState:
        """Ask the service which of *possible_states* the screenshot shows.

        NOTE(review): *possible_states* is placed in the JSON body as-is, so it
        must be JSON-serialisable by ``requests`` -- confirm for ``States``.
        """
        data = self._post_screen(
            "/verify_state", screen_b64, use_ocr, possible_states=possible_states
        )
        return BaseState(**data)

    def answer(
        self,
        screen_b64: str,
        use_ocr: bool,
        question: str,
        output_model: type[BaseModel],
    ) -> BaseModel:
        """Answer *question* about the screenshot in the shape of *output_model*.

        Args:
            output_model: A pydantic model *class*. Its JSON schema is sent to
                the service, and the reply is validated into an instance of it.

        Returns:
            BaseModel: An instance of *output_model*.
        """
        data = self._post_screen(
            "/answer",
            screen_b64,
            use_ocr,
            question=question,
            output_model=output_model.model_json_schema(),
        )
        return output_model(**data)

    def classify_state(
        self, screen_b64: str, use_ocr: bool, possible_states: List[Dict[str, str]]
    ) -> BaseState:
        """Classify the screenshot into one of *possible_states*."""
        data = self._post_screen(
            "/classify_state", screen_b64, use_ocr, possible_states=possible_states
        )
        return BaseState(**data)

    def write_action_string(
        self, screen_b64: str, use_ocr: bool, action_prompt: str
    ) -> ActionString:
        """Request an action string for *action_prompt* on the screenshot."""
        # NOTE(review): every other endpoint here uses underscores only; the
        # hyphen in "/write_action-string" looks like a typo, but it is kept
        # verbatim because the server route is not visible -- confirm.
        data = self._post_screen(
            "/write_action-string", screen_b64, use_ocr, action_prompt=action_prompt
        )
        return ActionString(**data)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class CourseCorrectorClerk(BaseClerk):
    """Client for the ``/gui_automation/course_corrector`` API."""

    root_endpoint: str = "/gui_automation/course_corrector"

    def get_corrective_actions(
        self,
        screen_b64: str,
        use_ocr: bool,
        goal: str,
        custom_instructions: Optional[str] = None,
    ) -> ActionString:
        """Request corrective actions for reaching *goal* from the given screen.

        The semantics of the returned actions are defined server-side.

        Args:
            screen_b64: Base64-encoded screenshot.
            use_ocr: Whether OCR should be used. This is a boolean flag, as in
                the other GUI clerks; the previous ``str`` annotation was
                inaccurate.
            goal: Description of the desired outcome.
            custom_instructions: Extra guidance; sent as JSON ``null`` when
                omitted.

        Returns:
            ActionString: Built from the first element of the response ``data``.
        """
        endpoint = "/get_corrective_actions"
        res = self.post_request(
            endpoint=endpoint,
            json={
                "screen_b64": screen_b64,
                "use_ocr": use_ocr,
                "goal": goal,
                "custom_instructions": custom_instructions,
            },
        )

        return ActionString(**res.data[0])
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, Dict, Union
|
|
5
|
+
|
|
6
|
+
import pydantic
|
|
7
|
+
import requests
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from .model import (
|
|
11
|
+
ExecutePayload,
|
|
12
|
+
DeleteFilesExecutePayload,
|
|
13
|
+
ApplicationExecutePayload,
|
|
14
|
+
SaveFilesExecutePayload,
|
|
15
|
+
WindowExecutePayload,
|
|
16
|
+
GetFileExecutePayload,
|
|
17
|
+
)
|
|
18
|
+
import backoff
|
|
19
|
+
|
|
20
|
+
from .model import PerformActionResponse, ActionStates
|
|
21
|
+
from .exception import PerformActionException, GetScreenError
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
async def _perform_action_ws(payload: Dict, timeout: float = 90) -> PerformActionResponse:
    """Perform an action over the shared WebSocket connection.

    The exchange has three steps. First the JSON-encoded *payload* is sent.
    Then an acknowledgement frame, which must be exactly ``"OK"``, is awaited.
    Finally a JSON frame describing the action outcome is awaited.

    Args:
        payload (Dict): JSON-serialisable request to send.
        timeout (float): Seconds to wait for each of the two frames
            (default 90).

    Returns:
        PerformActionResponse: The parsed outcome frame.

    Raises:
        RuntimeError: If the WebSocket has not been initiated, if the ACK is
            not ``"OK"``, or if either frame fails to arrive within *timeout*.
            Both timeout cases report "The ack message did not arrive.".
            ``_get_screen_async`` matches these exact messages, so keep them
            stable.
    """
    # Imported at call time so the module's current ``global_ws`` value is
    # read (presumably also to avoid a circular import -- confirm).
    from ..decorators.gui_automation import global_ws

    if not global_ws:
        raise RuntimeError("The Websocket has not been initiated.")

    # 1. Send the payload request.
    await global_ws.send(json.dumps(payload))

    # 2. Wait for the ack message, then for the action result.
    try:
        ack = await asyncio.wait_for(global_ws.recv(), timeout)
        if ack != "OK":
            raise RuntimeError("Received ACK != OK")
        action_info = await asyncio.wait_for(global_ws.recv(), timeout)
    except asyncio.TimeoutError as exc:
        # Chain the cause so the original timeout stays in the traceback.
        raise RuntimeError("The ack message did not arrive.") from exc

    return PerformActionResponse(**json.loads(action_info))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
async def _get_screen_async() -> str:
    """
    Capture the remote screen over the WebSocket connection.

    The request reads ``PROC_ID``, ``REMOTE_DEVICE_NAME`` and ``HEADLESS``
    from the environment. Headless is on unless ``HEADLESS`` is set to
    something other than "true" (case-insensitive).

    Returns:
        str: The base64-encoded screenshot.

    Raises:
        GetScreenError: If the ACK is missing or not ``"OK"``, or if the
            response carries no ``screen_b64``.
        RuntimeError: If the WebSocket has not been initiated.
    """
    payload = {
        "proc_inst_id": os.getenv("PROC_ID"),
        "client_name": os.getenv("REMOTE_DEVICE_NAME"),
        "headless": os.getenv("HEADLESS", "True").lower() == "true",
        "action": {"action_type": "screenshot"},
    }
    try:
        action_info = await _perform_action_ws(payload)
    except RuntimeError as e:
        # _perform_action_ws reports ACK problems only via these messages.
        if str(e) in (
            "The ack message did not arrive.",
            "Received ACK != OK",
        ):
            raise GetScreenError("The ack message did not arrive.") from e
        raise  # any other RuntimeError propagates unchanged

    if action_info.screen_b64 is not None:
        return action_info.screen_b64
    raise GetScreenError()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@backoff.on_exception(
    backoff.expo,
    (requests.RequestException, pydantic.ValidationError, GetScreenError),
    max_time=120,
)
def get_screen() -> str:
    """
    Capture the VDI screen and return it as a base64 string.

    Runs the asynchronous capture to completion on the event loop returned by
    ``asyncio.get_event_loop()``. Request errors, response validation errors
    and ``GetScreenError`` are retried with exponential backoff for up to 120
    seconds.

    Returns:
        str: The base64-encoded screenshot.

    Raises:
        GetScreenError: If capturing still fails once the retry budget is spent.
        RuntimeError: If the WebSocket has not been initiated.
    """
    event_loop = asyncio.get_event_loop()
    capture_task = event_loop.create_task(_get_screen_async())
    return event_loop.run_until_complete(capture_task)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
async def _perform_action_async(
    payload: Union[
        ExecutePayload,
        ApplicationExecutePayload,
        WindowExecutePayload,
        SaveFilesExecutePayload,
        DeleteFilesExecutePayload,
        GetFileExecutePayload,
    ],
) -> Any:
    """
    Send one action to the remote client and return what it reports back.

    The dumped *payload* is wrapped together with the process id
    (``PROC_ID``), the client name (``REMOTE_DEVICE_NAME``) and the headless
    flag (``HEADLESS``, on unless set to something other than "true").

    Args:
        payload: The action to perform, as one of the execute-payload models.

    Returns:
        Any: The ``return_value`` from the client's response.

    Raises:
        PerformActionException: If the client reports the action as failed.
    """
    request_body: Dict = dict(
        proc_inst_id=os.getenv("PROC_ID"),
        client_name=os.getenv("REMOTE_DEVICE_NAME"),
        headless=os.getenv("HEADLESS", "True").lower() == "true",
        action=payload.model_dump(),
    )
    outcome = await _perform_action_ws(request_body)

    if outcome.state == ActionStates.failed:
        raise PerformActionException(outcome.message)
    return outcome.return_value
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def perform_action(
    payload: Union[
        ExecutePayload,
        ApplicationExecutePayload,
        WindowExecutePayload,
        SaveFilesExecutePayload,
        DeleteFilesExecutePayload,
        GetFileExecutePayload,
    ],
) -> Any:
    """
    Perform an action on the VDI client and block until it completes.

    The asynchronous WebSocket exchange is run to completion on the event loop
    returned by ``asyncio.get_event_loop()``.

    Args:
        payload: Any of the six execute-payload models in the signature,
            describing the action to perform.

    Returns:
        Any: The value the client reports for the action.

    Raises:
        PerformActionException: If the client reports the action as failed.
        RuntimeError: If the WebSocket is not initiated, or if the ACK or the
            result does not arrive in time.
    """
    event_loop = asyncio.get_event_loop()
    action_task = event_loop.create_task(_perform_action_async(payload))
    return event_loop.run_until_complete(action_task)
|