dojo-sdk-core 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ .venv/
3
+ .pytest_cache/
4
+ screenshots/
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.4
2
+ Name: dojo-sdk-core
3
+ Version: 0.1.0
4
+ Summary: Core functionality of dojo
5
+ Requires-Python: >=3.9
6
+ Requires-Dist: datasets>=4.1.1
7
+ Requires-Dist: pydantic>=2.11.9
8
+ Requires-Dist: python-dotenv>=1.1.1
9
+ Provides-Extra: dev
10
+ Requires-Dist: black; extra == 'dev'
11
+ Requires-Dist: isort; extra == 'dev'
12
+ Requires-Dist: mypy; extra == 'dev'
13
+ Requires-Dist: pytest-cov; extra == 'dev'
14
+ Requires-Dist: pytest>=7.0.0; extra == 'dev'
15
+ Requires-Dist: ruff; extra == 'dev'
16
+ Description-Content-Type: text/markdown
17
+
18
+ ## dojo-core
19
+
20
+ Core functionality of dojo.
21
+ Contains models shared by other parts of the dojo project.
@@ -0,0 +1,4 @@
1
+ ## dojo-core
2
+
3
+ Core functionality of dojo.
4
+ Contains models shared by other parts of the dojo project.
@@ -0,0 +1 @@
1
# Keep in sync with the `version` field of the package metadata (0.1.0);
# the previous value "0.0.1" disagreed with the released version.
__version__ = "0.1.0"
@@ -0,0 +1,46 @@
1
+ from core.__about__ import __version__
2
+ from core.models import EnvironmentConfig, InstructionsConfig, TaskDefinition
3
+ from core.tasks import PyTaskLoader, RemoteTaskLoader
4
+ from core.types import (
5
+ Action,
6
+ ActionType,
7
+ ClickAction,
8
+ DoneAction,
9
+ DoubleClickAction,
10
+ DragAction,
11
+ FailAction,
12
+ HotkeyAction,
13
+ KeyAction,
14
+ MiddleClickAction,
15
+ MoveToAction,
16
+ PressAction,
17
+ RightClickAction,
18
+ ScrollAction,
19
+ TypeAction,
20
+ WaitAction,
21
+ )
22
+
23
# Public API of the `core` package: exactly the names re-exported by the
# imports at the top of this module.
__all__ = [
    "__version__",
    "EnvironmentConfig",
    "InstructionsConfig",
    "TaskDefinition",
    "PyTaskLoader",
    "RemoteTaskLoader",
    "KeyAction",
    "ClickAction",
    "RightClickAction",
    "ScrollAction",
    "TypeAction",
    "DoubleClickAction",
    "DragAction",
    "MoveToAction",
    "PressAction",
    "HotkeyAction",
    "MiddleClickAction",
    "DoneAction",
    "WaitAction",
    "FailAction",
    "Action",
    "ActionType",
]
@@ -0,0 +1,4 @@
1
+ from core.dojos.dojos_loader import collect_tasks, load_dojos
2
+ from core.dojos.model import Dojo
3
+
4
# Names re-exported by the `core.dojos` package.
__all__ = ["Dojo", "collect_tasks", "load_dojos"]
@@ -0,0 +1,25 @@
1
+ from core.dojos.model import Dojo
2
+
3
+
4
+ # Add new dojos here
5
def load_dojos() -> list[Dojo]:
    """Return the list of registered dojos (currently none are registered)."""
    registered = []
    return registered
7
+
8
+
9
def collect_tasks(dojos: list[Dojo]) -> list[str]:
    """Return the ``"<dojo id>/<task id>"`` path of every task of every dojo.

    Paths are produced in dojo order, then in the order of each dojo's
    ``get_tasks_list()``.
    """
    task_paths = []
    for dojo in dojos:
        for task in dojo.get_tasks_list():
            task_paths.append(dojo.get_id() + "/" + task.id)
    return task_paths
11
+
12
+
13
+ # TODO: For now this is just a sanity check to make sure the task name is valid.
14
def run_tasks_by_name(task_names: list[str]) -> list[str]:
    """Check that every requested task path exists and return the input list.

    Despite its name, this function does not run anything yet: it only
    verifies that each entry matches ``"<dojo id>/<task id>"`` for some task
    of a dojo returned by ``load_dojos()``.

    Args:
        task_names: Task paths of the form ``"<dojo id>/<task id>"``.

    Returns:
        ``task_names`` unchanged, when every entry is known.

    Raises:
        ValueError: For the first entry that matches no registered task.
    """
    # Build the set of known paths once instead of rescanning every dojo for
    # every requested name (the old inner `break` only left the innermost loop).
    known_paths = {
        dojo.get_id() + "/" + task.id
        for dojo in load_dojos()
        for task in dojo.get_tasks_list()
    }
    for task_name in task_names:
        if task_name not in known_paths:
            raise ValueError(f"Task {task_name} not found")
    return task_names
@@ -0,0 +1,14 @@
1
+ import string
2
+ from abc import ABC, abstractmethod
3
+
4
+ from core.models import TaskDefinition
5
+
6
+
7
class Dojo(ABC):
    """Abstract interface of a dojo: an identified collection of tasks."""

    @abstractmethod
    def get_id(self) -> str:
        """Return the identifier of this dojo.

        The annotation is ``str``; the previous ``string`` referred to the
        standard-library *module* of that name, which is not a type.
        """
        pass

    @abstractmethod
    def get_tasks_list(self) -> list[TaskDefinition]:
        """Return the task definitions that belong to this dojo."""
        pass
@@ -0,0 +1,131 @@
1
+ """
2
+ Centralized reward functions for all dojos.
3
+
4
+ This module contains all reward validation functions used across different dojos.
5
+ Each function takes (initial_state, final_state) and returns (score, reason).
6
+ """
7
+
8
+ import logging
9
+ from typing import Any, Dict, Tuple
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def _validate_get_2048(initial_state: Dict[str, Any], final_state: Dict[str, Any]) -> Tuple[float, str]:
    """Validate that a 2048 tile is present on the board.

    Args:
        initial_state: Unused; present to match the reward-function signature.
        final_state: State dict expected to carry a ``"board"`` entry.

    Returns:
        ``(1.0, reason)`` when a 2048 tile is found, otherwise ``(0.0, reason)``.
    """
    if "board" not in final_state:
        return 0.0, "No board in final state"

    logger.debug(f"running reward function on state: {final_state}")
    board = final_state["board"]
    # NOTE(review): the board schema is not visible here. A plain
    # `2048 in board` only matches a flat list of tiles; if the board is a
    # list of rows it can never succeed, so rows are searched as well.
    tile_found = 2048 in board or any(
        isinstance(row, (list, tuple)) and 2048 in row for row in board
    )
    if tile_found:
        return 1.0, "A 2048 tile is present."
    return 0.0, "No 2048 tile is present."
23
+
24
+
25
def _validate_search_for_dzaka(initial_state: Dict[str, Any], final_state: Dict[str, Any]) -> Tuple[float, str]:
    """Validate that the user successfully searched for Dzaka Athif.

    Three checks must pass: ``currentView`` is ``"search"``, ``searchQuery``
    contains ``"dzaka"`` (case-insensitive), and ``searchResults.allPeople``
    holds an entry whose ``name`` is exactly ``"Dzaka Athif"``.

    Args:
        initial_state: Unused; present to match the reward-function signature.
        final_state: State dict to inspect.

    Returns:
        ``(1.0, reason)`` on success, otherwise ``(0.0, reason)`` naming the
        first failed check.
    """
    logger.debug(f"Running reward function on state: {final_state}")

    # Check 1: currentView is "search"
    current_view = final_state.get("currentView")
    if current_view != "search":
        return 0.0, f"Not on search page, current view: {current_view}"

    # Check 2: searchQuery contains "dzaka" (case insensitive).
    # `or ""` also covers a key that is present but null, which previously
    # raised AttributeError on `.lower()`.
    raw_query = final_state.get("searchQuery")
    if "dzaka" not in (raw_query or "").lower():
        return 0.0, f"Search query doesn't contain 'dzaka': {raw_query}"

    # Check 3: Dzaka Athif in search results (null containers count as empty)
    search_results = final_state.get("searchResults") or {}
    people = search_results.get("allPeople") or []

    if any(user.get("name") == "Dzaka Athif" for user in people):
        return 1.0, "Successfully searched for Dzaka Athif"

    return 0.0, f"Dzaka Athif not found in search results. Found {len(people)} people."
51
+
52
+
53
def _validate_drag_to_different_column(initial_state: Dict[str, Any], final_state: Dict[str, Any]) -> Tuple[float, str]:
    """Validate that issue VSS-101 ended up in progress and assigned to user "1".

    Returns ``(1.0, reason)`` when the first issue with identifier
    ``"VSS-101"`` has ``status == "in_progress"`` and ``assigneeId == "1"``;
    otherwise ``(0.0, reason)``.
    """
    if "issues" not in final_state:
        return 0.0, "No issues in final state"

    logger.debug(f"Running reward function on state: {final_state}")

    # Locate the first issue carrying the tracked identifier.
    tracked = None
    for candidate in final_state["issues"]:
        if candidate.get("identifier") == "VSS-101":
            tracked = candidate
            break

    if not tracked:
        return 0.0, "Target issue VSS-101 not found in final state"

    status = tracked.get("status")
    assignee = tracked.get("assigneeId")
    if status != "in_progress" or assignee != "1":
        return 0.0, f"Issue VSS-101 has status={status}, assigneeId={assignee}, expected status=in_progress, assigneeId=1"

    return 1.0, "Issue VSS-101 successfully moved to In Progress column for user 1"
74
+
75
+
76
def _validate_drag_two_issues_same_user(initial_state: Dict[str, Any], final_state: Dict[str, Any]) -> Tuple[float, str]:
    """Validate that VSS-101 and VSS-106 reached their target columns for user "1".

    VSS-101 must have status ``"in_progress"`` and VSS-106 status ``"queued"``,
    both with ``assigneeId == "1"``. Returns ``(1.0, reason)`` when both hold,
    otherwise ``(0.0, reason)`` listing every issue that is off target.
    """
    if "issues" not in final_state:
        return 0.0, "No issues in final state"

    logger.debug(f"Running reward function on state: {final_state}")

    def first_with_identifier(identifier):
        # First issue whose "identifier" equals the given one, else None.
        for entry in final_state["issues"]:
            if entry.get("identifier") == identifier:
                return entry
        return None

    issue_101 = first_with_identifier("VSS-101")
    issue_106 = first_with_identifier("VSS-106")

    if not issue_101:
        return 0.0, "Target issue VSS-101 not found in final state"
    if not issue_106:
        return 0.0, "Target issue VSS-106 not found in final state"

    # Collect one message per issue that is not in its expected column/assignee.
    problems = []
    for identifier, issue, wanted_status in (
        ("VSS-101", issue_101, "in_progress"),
        ("VSS-106", issue_106, "queued"),
    ):
        status = issue.get("status")
        assignee = issue.get("assigneeId")
        if status != wanted_status or assignee != "1":
            problems.append(f"{identifier}: status={status}, assigneeId={assignee} (expected {wanted_status}, 1)")

    if problems:
        return 0.0, "; ".join(problems)
    return 1.0, "Both issues successfully moved to target columns for user 1"
118
+
119
+
120
+ # Registry of all reward functions for easy lookup
121
# Maps each reward function's name to the function itself.
REWARD_FUNCTIONS = dict(
    _validate_get_2048=_validate_get_2048,
    _validate_search_for_dzaka=_validate_search_for_dzaka,
    _validate_drag_to_different_column=_validate_drag_to_different_column,
    _validate_drag_two_issues_same_user=_validate_drag_two_issues_same_user,
)
127
+
128
+
129
def get_reward_function(name: str):
    """Return the reward function registered under ``name``, or None if there is none."""
    try:
        return REWARD_FUNCTIONS[name]
    except KeyError:
        return None
@@ -0,0 +1,167 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Optional
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class SettingsConfig(BaseModel):
    """Settings configuration."""

    # Every field is a plain string that defaults to "" when not supplied.
    anthropic_api_key: str = Field("", description="Anthropic API key")
    openai_api_key: str = Field("", description="OpenAI API key")
    openai_api_url: str = Field("", description="OpenAI API URL")
    dojo_websocket_endpoint: str = Field("", description="Dojo websocket endpoint")
    dojo_http_endpoint: str = Field("", description="Dojo http endpoint")
16
+
17
+
18
class EnvironmentConfig(BaseModel):
    """Environment configuration."""

    # Both fields are required (`...` marks a field without a default).
    type: str = Field(..., description="Environment type (e.g., 'spa')")
    # Joined onto a base directory by TaskDefinition.get_environment_path().
    path: str = Field(..., description="Path to environment file")
23
+
24
+
25
class InstructionsConfig(BaseModel):
    """Task instructions configuration."""

    # Both fields are required free-text strings.
    user_prompt: str = Field(..., description="Prompt to show to the agent")
    success_criteria: str = Field(..., description="What constitutes success")
30
+
31
+
32
class TaskDefinition(BaseModel):
    """Complete task definition."""

    # Exactly one of `reward_function` / `valid_target_states` must be set;
    # model_post_init() rejects instances with neither or both.
    id: str = Field(..., description="Unique task identifier")
    name: str = Field(..., description="Human-readable task name")
    description: str = Field(..., description="Task description")
    environment: EnvironmentConfig = Field(..., description="Environment configuration")
    initial_state: dict[str, Any] = Field(
        ..., description="Initial state for the environment"
    )
    instructions: InstructionsConfig = Field(..., description="Task instructions")
    reward_function: Optional[
        Callable[[dict[str, Any], dict[str, Any]], tuple[float, str]]
    ] = Field(
        None,
        description="Reward function to be used if valid_target_states is not provided",
    )
    valid_target_states: Optional[list[dict[str, Any]]] = Field(
        None,
        description="List of valid target states for binary reward (1 if reached, 0 otherwise)",
    )
    max_steps: int = Field(default=10, description="Maximum number of steps allowed")
    timeout_seconds: int = Field(default=60, description="Task timeout in seconds")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )

    def model_post_init(self, __context: Any) -> None:
        """Validate that exactly one of reward_function / valid_target_states is provided.

        An empty ``valid_target_states`` list counts as "not provided".

        Raises:
            ValueError: If neither or both of the two fields are set.
        """
        if not self.reward_function and not self.valid_target_states:
            raise ValueError(
                "Either reward_function or valid_target_states must be provided"
            )
        if self.reward_function and self.valid_target_states:
            raise ValueError(
                "Cannot specify both reward_function and valid_target_states"
            )

    def get_environment_path(self, base_path: Optional[Path] = None) -> Path:
        """Return ``base_path / environment.path``.

        ``base_path`` defaults to the current working directory. The result is
        not resolved, so it is only absolute when ``base_path`` (or the
        environment path itself) is absolute.
        """
        if base_path is None:
            base_path = Path.cwd()
        return base_path / self.environment.path

    def load_reward_function(
        self,
    ) -> Callable[[dict[str, Any], dict[str, Any]], tuple[float, str]]:
        """Return the callable that scores ``(initial_state, final_state)``.

        When ``valid_target_states`` is set, a built-in binary reward function
        is returned; otherwise the configured ``reward_function``.

        Raises:
            ValueError: If neither source of a reward function is available.
        """
        if self.valid_target_states:
            # Return built-in binary reward function
            return self._create_binary_reward_function()

        if not self.reward_function:
            raise ValueError("No reward function or valid_target_states specified")

        return self.reward_function

    def _create_binary_reward_function(
        self,
    ) -> Callable[[dict[str, Any], dict[str, Any]], tuple[float, str]]:
        """Create a binary reward function based on valid_target_states."""

        def binary_reward_function(
            initial_state: dict[str, Any], final_state: dict[str, Any]
        ) -> tuple[float, str]:
            """Return 1.0 if final_state matches any valid target state, else 0.0."""
            # Read at call time, so later changes to valid_target_states are seen.
            if not self.valid_target_states:
                return 0.0, "No valid target states defined"

            for i, target_state in enumerate(self.valid_target_states):
                if self._states_match(final_state, target_state):
                    return 1.0, f"Reached valid target state {i + 1}"

            return 0.0, "Did not reach any valid target state"

        return binary_reward_function

    def _states_match(self, state1: dict[str, Any], state2: dict[str, Any]) -> bool:
        """Return True when ``state1`` contains every key-value pair of ``state2``.

        The comparison is one-directional and top-level only: extra keys in
        ``state1`` are ignored and values are compared with ``!=``.
        """
        return not any(
            key not in state1 or state1[key] != value
            for key, value in state2.items()
        )

    @classmethod
    def from_hf_row(
        cls,
        row: dict[str, Any],
        reward_function_importer: Optional[Callable[[str], Optional[Callable]]] = None,
    ) -> "TaskDefinition":
        """
        Create a TaskDefinition from a HuggingFace dataset row.

        This handles the stringified JSON fields and optional reward function importing
        that's specific to the HF dataset format.

        Args:
            row: Dictionary representing a row from the HF dataset
            reward_function_importer: Optional function to import reward functions by name

        Returns:
            TaskDefinition instance

        Raises:
            ValueError: If the row names a reward function that cannot be
                resolved (no importer given, or the importer returned nothing)
                and provides no valid_target_states either.
        """
        # Parse stringified JSON fields
        initial_state = json.loads(row["initial_state"])
        environment = json.loads(row["environment"])
        instructions = json.loads(row["instructions"])
        metadata = json.loads(row["metadata"])

        # Handle valid_target_states (may be empty string)
        valid_target_states = None
        raw_target_states = row["valid_target_states"]
        if raw_target_states and raw_target_states.strip():
            valid_target_states = json.loads(raw_target_states)

        # Handle reward_function (may be empty string)
        reward_function = None
        reward_function_name = row["reward_function"]
        if reward_function_name and reward_function_name.strip():
            if reward_function_importer:
                reward_function = reward_function_importer(reward_function_name)
            # Without this check the unresolved name would be silently dropped
            # and construction would fail later with an unrelated message.
            if not reward_function and not valid_target_states:
                raise ValueError(
                    f"Reward function {reward_function_name!r} for task {row['id']!r} "
                    "could not be resolved and no valid_target_states were provided"
                )

        return cls(
            id=row["id"],
            name=row["name"],
            description=row["description"],
            environment=environment,
            initial_state=initial_state,
            instructions=instructions,
            reward_function=reward_function,
            valid_target_states=valid_target_states,
            max_steps=row["max_steps"],
            timeout_seconds=row["timeout_seconds"],
            metadata=metadata,
        )
@@ -0,0 +1,47 @@
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+
5
+ from core.models import SettingsConfig
6
+
7
+ load_dotenv()
8
+
9
+
10
+ _DEFAULT_HTTP_ENDPOINT = "https://orchestrator.trydojo.ai/api/v1"
11
+
12
+
13
def _resolve_http_endpoint() -> str:
    """Return DOJO_HTTP_ENDPOINT from the environment, or the built-in default.

    An unset or empty variable falls back to ``_DEFAULT_HTTP_ENDPOINT``.
    """
    configured = os.getenv("DOJO_HTTP_ENDPOINT")
    if configured:
        return configured
    return _DEFAULT_HTTP_ENDPOINT
17
+
18
+
19
def _derive_ws_endpoint(http_endpoint: str) -> str:
    """Build the websocket jobs URL that corresponds to ``http_endpoint``.

    Trailing slashes are dropped, ``https://`` becomes ``wss://`` and
    ``http://`` becomes ``ws://`` (other schemes are left alone), and
    ``/jobs`` is appended unless already present. An empty endpoint yields
    the local default ``ws://localhost:8765/api/v1/jobs``.
    """
    if not http_endpoint:
        return "ws://localhost:8765/api/v1/jobs"

    endpoint = http_endpoint.rstrip("/")
    # Swap only the first matching scheme prefix.
    for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
        if endpoint.startswith(http_scheme):
            endpoint = ws_scheme + endpoint[len(http_scheme):]
            break

    return endpoint if endpoint.endswith("/jobs") else f"{endpoint}/jobs"
35
+
36
+
37
# Endpoints are resolved once, at import time. An explicit (non-empty)
# DOJO_WEBSOCKET_ENDPOINT wins; otherwise the websocket URL is derived from
# the HTTP endpoint.
_http_endpoint = _resolve_http_endpoint()
_ws_endpoint = os.getenv("DOJO_WEBSOCKET_ENDPOINT") or _derive_ws_endpoint(_http_endpoint)

# Module-level settings object; API-key variables that are unset become "".
settings = SettingsConfig(
    anthropic_api_key=os.getenv("ANTHROPIC_API_KEY", ""),
    openai_api_key=os.getenv("OPENAI_API_KEY", ""),
    openai_api_url=os.getenv("OPENAI_API_URL", ""),
    # TODO: switch to prod endpoint as default
    # NOTE(review): the HTTP default already points at orchestrator.trydojo.ai;
    # confirm whether the TODO above is stale.
    dojo_websocket_endpoint=_ws_endpoint,
    dojo_http_endpoint=_http_endpoint,
)
@@ -0,0 +1,123 @@
1
+ """Task definition and loading system."""
2
+
3
+ import importlib
4
+ import json
5
+ import string
6
+ from abc import ABC, abstractmethod
7
+ from typing import Any, Callable, List, Optional
8
+
9
+ from datasets import load_dataset
10
+
11
+ from core.dojos import Dojo
12
+ from core.dojos.dojos_loader import load_dojos
13
+ from core.dojos.rewards import get_reward_function
14
+ from core.models import TaskDefinition
15
+
16
+
17
class TaskLoader(ABC):
    """Abstract base class for task loaders.

    NOTE(review): PyTaskLoader and RemoteTaskLoader in this module do not
    subclass TaskLoader and neither defines ``load_dojo_tasks`` — confirm
    whether this interface is meant to be enforced.
    """

    @abstractmethod
    def load_task(self, task_path: str) -> TaskDefinition:
        """Load the task definition identified by ``task_path``.

        The concrete loaders in this module take ``"dojo/task_id"`` paths
        rather than JSON file paths.
        """
        pass

    @abstractmethod
    def load_dojo_tasks(self, dojo_name: str) -> List[TaskDefinition]:
        """Load all tasks that belong to the dojo named ``dojo_name``."""
        pass

    @abstractmethod
    def list_tasks(self, category: Optional[str] = None) -> dict[str, TaskDefinition]:
        """List all available tasks.

        ``category`` is optional; how it filters the result is left to the
        implementation (none of the loaders in this module implement it).
        """
        pass
34
+
35
+
36
class PyTaskLoader:
    """Loads tasks from python-based task registry"""

    @staticmethod
    def _register_all(bms: list[Dojo]) -> dict[str, Dojo]:
        """Index the given dojos by ``get_id()``; a later duplicate id replaces an earlier one."""
        return {bm.get_id(): bm for bm in bms}

    def __init__(self):
        """Build the id -> dojo index from ``load_dojos()``."""
        self.benchmarks_by_name = PyTaskLoader._register_all(load_dojos())

    def load_task(self, benchmark_task_path: str) -> TaskDefinition:
        """Load one task by its ``"dojo/task_id"`` path.

        Only the first two ``/``-separated segments are used.

        Raises:
            ValueError: If the path has no ``/`` separator, or the dojo has
                no task with that id.
            KeyError: If no dojo is registered under the given name.
        """
        split_path = benchmark_task_path.split("/")
        # Reject malformed paths explicitly (same message as RemoteTaskLoader)
        # instead of failing with a bare IndexError.
        if len(split_path) < 2:
            raise ValueError(f"Invalid task path format: {benchmark_task_path}. Expected 'dojo/task_id'")
        benchmark_name = split_path[0]
        task_id = split_path[1]
        for task in self.load_benchmark_tasks(benchmark_name):
            if task.id == task_id:
                return task
        raise ValueError(f"Task {task_id} not found in benchmark {benchmark_name}")

    def load_benchmark_tasks(self, benchmark_name: str) -> List[TaskDefinition]:
        """Return every task of the dojo registered as ``benchmark_name``.

        Raises:
            KeyError: If no dojo is registered under that name.
        """
        return self.benchmarks_by_name[benchmark_name].get_tasks_list()

    def list_tasks(self, category: Optional[str] = None) -> dict[str, TaskDefinition]:
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError("List tasks not implemented")
65
+
66
+
67
class RemoteTaskLoader:
    """Loads tasks from a Hugging Face dataset."""

    def __init__(self, dataset_name: str):
        """Remember the dataset name; nothing is downloaded until first use."""
        self.dataset_name = dataset_name
        self.dataset = None
        self.tasks_cache = None

    def _load_dataset(self):
        """Return the "train" split of the dataset, loading it on first call."""
        if self.dataset is not None:
            return self.dataset
        self.dataset = load_dataset(self.dataset_name)["train"]
        print(f"✓ Loaded {len(self.dataset)} tasks from HF dataset {self.dataset_name}")
        return self.dataset

    def _import_reward_function(self, function_name: str) -> Optional[Callable[[dict[str, Any], dict[str, Any]], tuple[float, str]]]:
        """Resolve a reward function name through the centralized rewards module.

        Returns None for an empty name or a name that is not registered.
        """
        return get_reward_function(function_name) if function_name else None

    def _get_all_tasks(self) -> List[TaskDefinition]:
        """Return all dataset rows converted to TaskDefinition objects (cached)."""
        if self.tasks_cache is not None:
            return self.tasks_cache

        # from_hf_row handles the HF-specific parsing of each row.
        converted = [
            TaskDefinition.from_hf_row(row, reward_function_importer=self._import_reward_function)
            for row in self._load_dataset()
        ]
        self.tasks_cache = converted
        print(f"✓ Converted {len(converted)} HF tasks to TaskDefinition objects")
        return self.tasks_cache

    def load_task(self, task_path: str) -> TaskDefinition:
        """Load a specific task by its ``"dojo/task_id"`` path.

        Only the task id is used for the lookup; the dojo part of the path
        appears in the error message but is not matched against the dataset.

        Raises:
            ValueError: If the path contains no ``/`` or no task has that id.
        """
        # Parse task_path like "action-tester/must-click"
        dojo_name, separator, task_id = task_path.partition("/")
        if not separator:
            raise ValueError(f"Invalid task path format: {task_path}. Expected 'dojo/task_id'")

        for candidate in self._get_all_tasks():
            if candidate.id == task_id:
                return candidate

        raise ValueError(f"Task {task_id} not found in dojo {dojo_name}")

    def list_tasks(self, category: Optional[str] = None) -> dict[str, TaskDefinition]:
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError("List tasks not implemented")
@@ -0,0 +1,131 @@
1
+ from enum import Enum
2
+ from typing import List, Literal, Union
3
+
4
+ from pydantic import BaseModel
5
+
6
+ """Type definitions for Dojo."""
7
+
8
+
9
class ActionType(str, Enum):
    # String-valued enum of action kinds. Each member is used as the
    # `type` tag of exactly one *Action model defined in this module.
    KEY = "key"
    CLICK = "click"
    RIGHT_CLICK = "right_click"
    SCROLL = "scroll"
    TYPE = "type"
    DOUBLE_CLICK = "double_click"
    DRAG = "drag"
    MOVE_TO = "move_to"
    PRESS = "press"
    HOTKEY = "hotkey"
    MIDDLE_CLICK = "middle_click"
    DONE = "done"
    WAIT = "wait"
    FAIL = "fail"
24
+
25
+
26
class KeyAction(BaseModel):
    # Action tagged ActionType.KEY carrying one required key name.
    type: Literal[ActionType.KEY]
    key: str
29
+
30
+
31
class ClickAction(BaseModel):
    # Action tagged ActionType.CLICK at integer position (x, y).
    # NOTE(review): coordinate space/units are not defined here — confirm.
    type: Literal[ActionType.CLICK]
    x: int
    y: int
35
+
36
+
37
class RightClickAction(BaseModel):
    # Action tagged ActionType.RIGHT_CLICK at integer position (x, y).
    type: Literal[ActionType.RIGHT_CLICK]
    x: int
    y: int
41
+
42
+
43
class ScrollAction(BaseModel):
    # Action tagged ActionType.SCROLL. `direction` is a free-form string
    # (default "up"); `amount` defaults to 100.
    # NOTE(review): accepted directions and the unit of `amount` are not
    # constrained here — confirm against the consumer.
    type: Literal[ActionType.SCROLL]
    direction: str = "up"
    amount: int = 100
47
+
48
+
49
class TypeAction(BaseModel):
    # Action tagged ActionType.TYPE carrying the required text.
    type: Literal[ActionType.TYPE]
    text: str
52
+
53
+
54
class DoubleClickAction(BaseModel):
    # Action tagged ActionType.DOUBLE_CLICK at integer position (x, y).
    type: Literal[ActionType.DOUBLE_CLICK]
    x: int
    y: int
58
+
59
+
60
class DragAction(BaseModel):
    # Action tagged ActionType.DRAG from (from_x, from_y) to (to_x, to_y).
    # `duration` defaults to 1.0 (presumably seconds — confirm).
    type: Literal[ActionType.DRAG]
    from_x: int
    from_y: int
    to_x: int
    to_y: int
    duration: float = 1.0
67
+
68
+
69
class MoveToAction(BaseModel):
    # Action tagged ActionType.MOVE_TO targeting integer position (x, y).
    # `duration` defaults to 0.0 (presumably seconds — confirm).
    type: Literal[ActionType.MOVE_TO]
    x: int
    y: int
    duration: float = 0.0
74
+
75
+
76
class PressAction(BaseModel):
    # Action tagged ActionType.PRESS carrying one required key name.
    # NOTE(review): has the same fields as KeyAction; the difference in
    # meaning is not visible here.
    type: Literal[ActionType.PRESS]
    key: str
79
+
80
+
81
class HotkeyAction(BaseModel):
    # Action tagged ActionType.HOTKEY carrying a required list of key names.
    type: Literal[ActionType.HOTKEY]
    keys: List[str]
84
+
85
+
86
class MiddleClickAction(BaseModel):
    # Action tagged ActionType.MIDDLE_CLICK at integer position (x, y).
    type: Literal[ActionType.MIDDLE_CLICK]
    x: int
    y: int
90
+
91
+
92
class DoneAction(BaseModel):
    # Action tagged ActionType.DONE; carries no payload besides the tag.
    type: Literal[ActionType.DONE]
94
+
95
+
96
class WaitAction(BaseModel):
    # Action tagged ActionType.WAIT; `seconds` is an int that defaults to 1.
    type: Literal[ActionType.WAIT]
    seconds: int = 1
99
+
100
+
101
class FailAction(BaseModel):
    # Action tagged ActionType.FAIL carrying a required message.
    type: Literal[ActionType.FAIL]
    message: str
104
+
105
+
106
# Union of every concrete action model; each member pins `type` to a
# different ActionType value via Literal.
Action = Union[
    KeyAction,
    ClickAction,
    RightClickAction,
    ScrollAction,
    TypeAction,
    DoubleClickAction,
    DragAction,
    MoveToAction,
    PressAction,
    HotkeyAction,
    MiddleClickAction,
    DoneAction,
    WaitAction,
    FailAction,
]
122
+
123
+
124
class Score(BaseModel):
    # Result record for one task; all fields are required.
    # NOTE(review): how `score` and `reward` differ is not visible here — confirm.
    task_name: str
    score: float
    status: str
    success: bool
    steps_taken: int
    reward: float
    completion_reason: str
@@ -0,0 +1,134 @@
1
+ from datetime import datetime
2
+ from enum import Enum
3
+ from typing import List, Literal, Optional, Union
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from core.types import Action
8
+
9
+
10
class ExecutionTrace(BaseModel):
    # Record of one task execution; all fields are required.
    # NOTE(review): the encoding of `final_screenshot` (path, URL, base64?)
    # and the schema of `history` entries are not defined here — confirm.
    task_name: str
    status: str
    success: bool
    steps_taken: int
    reward: float
    completion_reason: str
    history: List[dict]
    final_state: dict
    final_screenshot: str
20
+
21
+
22
+ # WebSocket API Message Types
23
class TaskStatus(str, Enum):
    # String-valued status of a single task; each value equals its member name.
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    PENDING = "PENDING"
    TIMEOUT = "TIMEOUT"
29
+
30
+
31
class JobStatus(str, Enum):
    # String-valued status of a job; same members as TaskStatus except TIMEOUT.
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    PENDING = "PENDING"
36
+
37
+
38
class MsgType(str, Enum):
    # String tags identifying message kinds. Within this module only
    # JOB_COMPLETE is used (as the `type` of JobComplete).
    START_JOB = "StartJobRequest"
    START_JOB_RESP = "StartJobResponse"
    SUBMIT_ACTION = "SubmitActionRequest"
    GET_NEXT_ACTION = "GetNextActionRequest"
    TASK_COMPLETE = "TaskComplete"
    JOB_STATUS_RESP = "JobStatusResponse"
    JOB_COMPLETE = "JobComplete"
    AUTH = "Auth"
    AUTH_RESP = "AuthResp"
    CANCEL_JOB = "CancelJobRequest"
49
+
50
+
51
class HistoryStep(BaseModel):
    # One recorded step: the validated Action plus the text and score
    # stored with it. All fields are required.
    # NOTE(review): encoding of `after_screenshot` is not defined here — confirm.
    step: int
    after_screenshot: str
    agent_response: str
    raw_response: str
    action: Action
    score: float
58
+
59
+
60
class NextStep(BaseModel):
    # A step number together with a screenshot string; used as
    # PendingTask.pending_step.
    number: int
    after_screenshot: str
63
+
64
+
65
class TaskResult(BaseModel):
    # Result of one task as carried in JobComplete.tasks. Every field is
    # required except `recording`, which defaults to None.
    status: TaskStatus
    start_time: datetime
    end_time: datetime
    exec_id: str
    task_id: str
    task_name: str
    job_id: str
    final_score: float
    reason: str
    history: List[HistoryStep]
    recording: Optional[str] = None
77
+
78
+
79
class JobComplete(BaseModel):
    # Message tagged MsgType.JOB_COMPLETE carrying the results of a job's tasks.
    type: Literal[MsgType.JOB_COMPLETE]
    job_id: str
    tasks: List[TaskResult]
83
+
84
+
85
+ # Stateless API Types
86
+
87
+
88
class PendingTask(BaseModel):
    # Entry of JobStatusResponse.pending_tasks. Every field is required
    # except `pending_step`, which defaults to None.
    start_time: datetime
    task_name: str
    id: str
    exec_id: str
    prompt: str
    status: TaskStatus
    history: List[HistoryStep]
    pending_step: Optional[NextStep] = None
97
+
98
+
99
class FinishedTask(BaseModel):
    # Entry of JobStatusResponse.finished_tasks. Every field is required
    # except `recording`, which defaults to None.
    status: TaskStatus
    start_time: datetime
    end_time: datetime
    task_name: str
    id: str
    exec_id: str
    prompt: str
    history: List[HistoryStep]
    score: float
    recording: Optional[str] = None
110
+
111
+
112
class JobStatusResponse(BaseModel):
    # Status of a job with its pending and finished tasks.
    status: JobStatus
    start_time: datetime
    # Optional[...] rather than `datetime | None`: the package declares
    # requires-python >=3.9, where evaluating `X | None` in a class body raises
    # TypeError. The field stays required but may be None, as before.
    end_time: Optional[datetime]
    pending_tasks: List[PendingTask]
    finished_tasks: List[FinishedTask]
118
+
119
+
120
class JobMetadata(BaseModel):
    # Summary of one job as listed in JobsResponse.jobs.
    id: str
    status: JobStatus
    start_time: datetime
    # Optional[...] rather than `datetime | None`: the package declares
    # requires-python >=3.9, where evaluating `X | None` in a class body raises
    # TypeError. The field stays required but may be None, as before.
    end_time: Optional[datetime]
    num_tasks: int
    num_completed_tasks: int
127
+
128
+
129
class JobsResponse(BaseModel):
    # Wrapper holding a required list of job summaries.
    jobs: List[JobMetadata]
131
+
132
+
133
+ # Union type for all WebSocket messages
134
# A Union with a single member collapses to that member, so this alias is
# currently just JobComplete.
WebSocketMessage = Union[JobComplete,]
@@ -0,0 +1,38 @@
1
+ [project]
2
+ name = "dojo-sdk-core"
3
+ version = "0.1.0"
4
+ description = "Core functionality of dojo"
5
+ readme = "README.md"
6
+ requires-python = ">=3.9"
7
+ dependencies = ["pydantic>=2.11.9", "python-dotenv>=1.1.1", "datasets>=4.1.1"]
8
+
9
+ [project.optional-dependencies]
10
+ dev = ["pytest>=7.0.0", "pytest-cov", "black", "isort", "mypy", "ruff"]
11
+
12
+ [build-system]
13
+ requires = ["hatchling"]
14
+ build-backend = "hatchling.build"
15
+
16
+ [tool.hatch.build.targets.sdist]
17
+ include = ["/src", "/README.md"]
18
+
19
+ [tool.isort]
20
+ profile = "black"
21
+ src_paths = ["src", "tests"]
22
+
23
+ [tool.ruff]
24
+ line-length = 128
25
+ target-version = "py311"
26
+
27
+ [tool.ruff.lint]
28
+ select = ["E", "F", "W", "B", "I"]
29
+
30
+ [tool.ruff.format]
31
+ quote-style = "double"
32
+ indent-style = "space"
33
+
34
+ [tool.hatch.build]
35
+ sources = ["src"]
36
+
37
+ [tool.hatch.build.targets.wheel]
38
+ packages = ["src/core"]