dojo-sdk-core 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dojo_sdk_core-0.1.0/.gitignore +4 -0
- dojo_sdk_core-0.1.0/PKG-INFO +21 -0
- dojo_sdk_core-0.1.0/README.md +4 -0
- dojo_sdk_core-0.1.0/core/__about__.py +1 -0
- dojo_sdk_core-0.1.0/core/__init__.py +46 -0
- dojo_sdk_core-0.1.0/core/dojos/__init__.py +4 -0
- dojo_sdk_core-0.1.0/core/dojos/dojos_loader.py +25 -0
- dojo_sdk_core-0.1.0/core/dojos/model.py +14 -0
- dojo_sdk_core-0.1.0/core/dojos/rewards.py +131 -0
- dojo_sdk_core-0.1.0/core/models.py +167 -0
- dojo_sdk_core-0.1.0/core/settings.py +47 -0
- dojo_sdk_core-0.1.0/core/tasks.py +123 -0
- dojo_sdk_core-0.1.0/core/types.py +131 -0
- dojo_sdk_core-0.1.0/core/ws_types.py +134 -0
- dojo_sdk_core-0.1.0/pyproject.toml +38 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dojo-sdk-core
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Core functionality of dojo
|
|
5
|
+
Requires-Python: >=3.9
|
|
6
|
+
Requires-Dist: datasets>=4.1.1
|
|
7
|
+
Requires-Dist: pydantic>=2.11.9
|
|
8
|
+
Requires-Dist: python-dotenv>=1.1.1
|
|
9
|
+
Provides-Extra: dev
|
|
10
|
+
Requires-Dist: black; extra == 'dev'
|
|
11
|
+
Requires-Dist: isort; extra == 'dev'
|
|
12
|
+
Requires-Dist: mypy; extra == 'dev'
|
|
13
|
+
Requires-Dist: pytest-cov; extra == 'dev'
|
|
14
|
+
Requires-Dist: pytest>=7.0.0; extra == 'dev'
|
|
15
|
+
Requires-Dist: ruff; extra == 'dev'
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
|
|
18
|
+
## dojo-core
|
|
19
|
+
|
|
20
|
+
Core functionality of dojo.
|
|
21
|
+
Contains models shared by other parts of the dojo project.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Single source of the package version string.
# Fixed to match [project] version in pyproject.toml (and PKG-INFO), which
# both declare 0.1.0 — the previous "0.0.1" was out of sync.
__version__ = "0.1.0"
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from core.__about__ import __version__
|
|
2
|
+
from core.models import EnvironmentConfig, InstructionsConfig, TaskDefinition
|
|
3
|
+
from core.tasks import PyTaskLoader, RemoteTaskLoader
|
|
4
|
+
from core.types import (
|
|
5
|
+
Action,
|
|
6
|
+
ActionType,
|
|
7
|
+
ClickAction,
|
|
8
|
+
DoneAction,
|
|
9
|
+
DoubleClickAction,
|
|
10
|
+
DragAction,
|
|
11
|
+
FailAction,
|
|
12
|
+
HotkeyAction,
|
|
13
|
+
KeyAction,
|
|
14
|
+
MiddleClickAction,
|
|
15
|
+
MoveToAction,
|
|
16
|
+
PressAction,
|
|
17
|
+
RightClickAction,
|
|
18
|
+
ScrollAction,
|
|
19
|
+
TypeAction,
|
|
20
|
+
WaitAction,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Public API of dojo-sdk-core; keep this list in sync with the imports above.
__all__ = [
    "__version__",
    "EnvironmentConfig",
    "InstructionsConfig",
    "TaskDefinition",
    "PyTaskLoader",
    "RemoteTaskLoader",
    "KeyAction",
    "ClickAction",
    "RightClickAction",
    "ScrollAction",
    "TypeAction",
    "DoubleClickAction",
    "DragAction",
    "MoveToAction",
    "PressAction",
    "HotkeyAction",
    "MiddleClickAction",
    "DoneAction",
    "WaitAction",
    "FailAction",
    "Action",
    "ActionType",
]
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from core.dojos.model import Dojo
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Add new dojos here
def load_dojos() -> list[Dojo]:
    """Return the static registry of available dojos (currently empty)."""
    return []
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def collect_tasks(dojos: list[Dojo]) -> list[str]:
    """Return fully-qualified task names ('<dojo_id>/<task_id>') for every task."""
    names = []
    for dojo in dojos:
        for task in dojo.get_tasks_list():
            names.append(dojo.get_id() + "/" + task.id)
    return names
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# TODO: For now this is just a sanity check to make sure the task name is valid.
def run_tasks_by_name(task_names: list[str]) -> list[str]:
    """Validate each task name against the dojo registry.

    Raises:
        ValueError: if any name does not match a '<dojo_id>/<task_id>' entry.
    """
    dojos = load_dojos()
    for task_name in task_names:
        known = any(
            dojo.get_id() + "/" + task.id == task_name
            for dojo in dojos
            for task in dojo.get_tasks_list()
        )
        if not known:
            raise ValueError(f"Task {task_name} not found")
    return task_names
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import string
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
from core.models import TaskDefinition
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Dojo(ABC):
    """A named collection of tasks that agents can be evaluated on."""

    @abstractmethod
    def get_id(self) -> str:
        # Fixed return annotation: the original said `-> string`, which refers
        # to the imported stdlib `string` module, not a type.
        """Return the unique identifier of this dojo (used as the task-path prefix)."""
        pass

    @abstractmethod
    def get_tasks_list(self) -> list[TaskDefinition]:
        """Return all task definitions belonging to this dojo."""
        pass
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Centralized reward functions for all dojos.
|
|
3
|
+
|
|
4
|
+
This module contains all reward validation functions used across different dojos.
|
|
5
|
+
Each function takes (initial_state, final_state) and returns (score, reason).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Any, Dict, Tuple
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _validate_get_2048(initial_state: Dict[str, Any], final_state: Dict[str, Any]) -> Tuple[float, str]:
    """Check whether the final board contains a 2048 tile.

    Returns (1.0, reason) on success, (0.0, reason) otherwise.
    """
    if "board" not in final_state:
        return 0.0, "No board in final state"

    logging.getLogger(__name__).debug(f"running reward function on state: {final_state}")
    # NOTE(review): membership test assumes a flat board list; for a nested
    # (2D) board this would compare against rows — confirm the state shape.
    if 2048 in final_state["board"]:
        return 1.0, "A 2048 tile is present."
    return 0.0, "No 2048 tile is present."
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _validate_search_for_dzaka(initial_state: Dict[str, Any], final_state: Dict[str, Any]) -> Tuple[float, str]:
    """Verify the agent ended on the search view having searched for Dzaka Athif."""
    logging.getLogger(__name__).debug(f"Running reward function on state: {final_state}")

    # Guard 1: must have landed on the search view.
    if final_state.get("currentView") != "search":
        return 0.0, f"Not on search page, current view: {final_state.get('currentView')}"

    # Guard 2: the query must mention "dzaka" (case-insensitive).
    if "dzaka" not in final_state.get("searchQuery", "").lower():
        return 0.0, f"Search query doesn't contain 'dzaka': {final_state.get('searchQuery')}"

    # Guard 3: the expected person must appear in the people results.
    people = final_state.get("searchResults", {}).get("allPeople", [])
    for person in people:
        if person.get("name") == "Dzaka Athif":
            return 1.0, "Successfully searched for Dzaka Athif"

    return 0.0, f"Dzaka Athif not found in search results. Found {len(people)} people."
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _validate_drag_to_different_column(initial_state: Dict[str, Any], final_state: Dict[str, Any]) -> Tuple[float, str]:
    """Check that issue VSS-101 ended up In Progress on user 1's board."""
    if "issues" not in final_state:
        return 0.0, "No issues in final state"

    logging.getLogger(__name__).debug(f"Running reward function on state: {final_state}")

    # Locate the tracked issue (VSS-101) in the final issue list.
    target_issue = None
    for candidate in final_state["issues"]:
        if candidate.get("identifier") == "VSS-101":
            target_issue = candidate
            break

    if not target_issue:
        return 0.0, "Target issue VSS-101 not found in final state"

    status_ok = target_issue.get("status") == "in_progress"
    assignee_ok = target_issue.get("assigneeId") == "1"
    if status_ok and assignee_ok:
        return 1.0, "Issue VSS-101 successfully moved to In Progress column for user 1"

    return 0.0, f"Issue VSS-101 has status={target_issue.get('status')}, assigneeId={target_issue.get('assigneeId')}, expected status=in_progress, assigneeId=1"
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _validate_drag_two_issues_same_user(initial_state: Dict[str, Any], final_state: Dict[str, Any]) -> Tuple[float, str]:
    """Check that VSS-101 and VSS-106 were both moved on user 1's board.

    VSS-101 must end up in_progress and VSS-106 in queued, both assigned to
    user "1". Returns (1.0, reason) only when both conditions hold.
    """
    if "issues" not in final_state:
        return 0.0, "No issues in final state"

    logging.getLogger(__name__).debug(f"Running reward function on state: {final_state}")

    def find_issue(identifier):
        # Linear scan for the first issue with the given identifier.
        for candidate in final_state["issues"]:
            if candidate.get("identifier") == identifier:
                return candidate
        return None

    issue_101 = find_issue("VSS-101")
    issue_106 = find_issue("VSS-106")

    if not issue_101:
        return 0.0, "Target issue VSS-101 not found in final state"
    if not issue_106:
        return 0.0, "Target issue VSS-106 not found in final state"

    issue_101_correct = issue_101.get("status") == "in_progress" and issue_101.get("assigneeId") == "1"
    issue_106_correct = issue_106.get("status") == "queued" and issue_106.get("assigneeId") == "1"

    if issue_101_correct and issue_106_correct:
        return 1.0, "Both issues successfully moved to target columns for user 1"

    # Accumulate a diagnostic message for each issue that is in the wrong place.
    errors = []
    if not issue_101_correct:
        errors.append(f"VSS-101: status={issue_101.get('status')}, assigneeId={issue_101.get('assigneeId')} (expected in_progress, 1)")
    if not issue_106_correct:
        errors.append(f"VSS-106: status={issue_106.get('status')}, assigneeId={issue_106.get('assigneeId')} (expected queued, 1)")

    return 0.0, "; ".join(errors)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# Registry of all reward functions for easy lookup.
# Keys are the function names as stored in the HF dataset's `reward_function`
# column (see RemoteTaskLoader._import_reward_function).
REWARD_FUNCTIONS = {
    "_validate_get_2048": _validate_get_2048,
    "_validate_search_for_dzaka": _validate_search_for_dzaka,
    "_validate_drag_to_different_column": _validate_drag_to_different_column,
    "_validate_drag_two_issues_same_user": _validate_drag_two_issues_same_user,
}
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def get_reward_function(name: str):
    """Look up a reward function by its registry name; None when unknown."""
    try:
        return REWARD_FUNCTIONS[name]
    except KeyError:
        return None
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Callable, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SettingsConfig(BaseModel):
    """Settings configuration.

    All values default to empty strings; core.settings populates them from
    environment variables at import time.
    """

    anthropic_api_key: str = Field("", description="Anthropic API key")
    openai_api_key: str = Field("", description="OpenAI API key")
    openai_api_url: str = Field("", description="OpenAI API URL")
    dojo_websocket_endpoint: str = Field("", description="Dojo websocket endpoint")
    dojo_http_endpoint: str = Field("", description="Dojo http endpoint")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class EnvironmentConfig(BaseModel):
    """Environment configuration for a task (what to run and where to find it)."""

    type: str = Field(..., description="Environment type (e.g., 'spa')")
    # Interpreted relative to a base path; see TaskDefinition.get_environment_path.
    path: str = Field(..., description="Path to environment file")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InstructionsConfig(BaseModel):
    """Task instructions configuration shown to / evaluated against the agent."""

    user_prompt: str = Field(..., description="Prompt to show to the agent")
    success_criteria: str = Field(..., description="What constitutes success")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TaskDefinition(BaseModel):
    """Complete task definition.

    Couples an environment, an initial state, and agent-facing instructions
    with exactly one scoring mechanism: either a callable ``reward_function``
    or a list of ``valid_target_states`` (mutually exclusive; enforced in
    ``model_post_init``).
    """

    id: str = Field(..., description="Unique task identifier")
    name: str = Field(..., description="Human-readable task name")
    description: str = Field(..., description="Task description")
    environment: EnvironmentConfig = Field(..., description="Environment configuration")
    initial_state: dict[str, Any] = Field(
        ..., description="Initial state for the environment"
    )
    instructions: InstructionsConfig = Field(..., description="Task instructions")
    # Signature: (initial_state, final_state) -> (score, reason).
    reward_function: Optional[
        Callable[[dict[str, Any], dict[str, Any]], tuple[float, str]]
    ] = Field(
        None,
        description="Reward function to be used if valid_target_states is not provided",
    )
    valid_target_states: Optional[list[dict[str, Any]]] = Field(
        None,
        description="List of valid target states for binary reward (1 if reached, 0 otherwise)",
    )
    max_steps: int = Field(default=10, description="Maximum number of steps allowed")
    timeout_seconds: int = Field(default=60, description="Task timeout in seconds")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )

    def model_post_init(self, __context: Any) -> None:
        """Validate that either reward_function or valid_target_states is provided.

        Raises:
            ValueError: if neither, or both, scoring mechanisms are set.
        """
        if not self.reward_function and not self.valid_target_states:
            raise ValueError(
                "Either reward_function or valid_target_states must be provided"
            )
        if self.reward_function and self.valid_target_states:
            raise ValueError(
                "Cannot specify both reward_function and valid_target_states"
            )

    def get_environment_path(self, base_path: Optional[Path] = None) -> Path:
        """Get absolute path to environment file.

        Resolved against ``base_path`` (the current working directory when
        not given).
        """
        if base_path is None:
            base_path = Path.cwd()
        return base_path / self.environment.path

    def load_reward_function(
        self,
    ) -> Callable[[dict[str, Any], dict[str, Any]], tuple[float, str]]:
        """Load and return the reward function.

        ``valid_target_states`` takes precedence: when set, a synthesized
        binary reward function is returned instead of ``reward_function``.

        Raises:
            ValueError: if neither scoring mechanism is configured.
        """
        if self.valid_target_states:
            # Return built-in binary reward function
            return self._create_binary_reward_function()

        if not self.reward_function:
            raise ValueError("No reward function or valid_target_states specified")

        return self.reward_function

    def _create_binary_reward_function(
        self,
    ) -> Callable[[dict[str, Any], dict[str, Any]], tuple[float, str]]:
        """Create a binary reward function based on valid_target_states."""

        def binary_reward_function(
            initial_state: dict[str, Any], final_state: dict[str, Any]
        ) -> tuple[float, str]:
            """Binary reward function that checks if final_state matches any valid target state."""
            if not self.valid_target_states:
                return 0.0, "No valid target states defined"

            for i, target_state in enumerate(self.valid_target_states):
                if self._states_match(final_state, target_state):
                    return 1.0, f"Reached valid target state {i + 1}"

            return 0.0, "Did not reach any valid target state"

        return binary_reward_function

    def _states_match(self, state1: dict[str, Any], state2: dict[str, Any]) -> bool:
        """Check if two states match by comparing all key-value pairs.

        This is a subset check: state1 may contain extra keys beyond those
        required by state2.
        """
        # Check that state1 contains all key-value pairs from state2
        for key, value in state2.items():
            if key not in state1:
                return False
            if state1[key] != value:
                return False
        return True

    @classmethod
    def from_hf_row(cls, row: dict[str, Any], reward_function_importer: Optional[Callable[[str], Optional[Callable]]] = None) -> "TaskDefinition":
        """
        Create a TaskDefinition from a HuggingFace dataset row.

        This handles the stringified JSON fields and optional reward function importing
        that's specific to the HF dataset format.

        Args:
            row: Dictionary representing a row from the HF dataset
            reward_function_importer: Optional function to import reward functions by name

        Returns:
            TaskDefinition instance
        """
        # Parse stringified JSON fields
        initial_state = json.loads(row["initial_state"])
        environment = json.loads(row["environment"])
        instructions = json.loads(row["instructions"])
        metadata = json.loads(row["metadata"])

        # Handle valid_target_states (may be empty string)
        valid_target_states = None
        if row["valid_target_states"] and row["valid_target_states"].strip():
            valid_target_states = json.loads(row["valid_target_states"])

        # Handle reward_function (may be empty string)
        reward_function = None
        if row["reward_function"] and row["reward_function"].strip():
            if reward_function_importer:
                reward_function = reward_function_importer(row["reward_function"])
            else:
                # Store function name as string if no importer provided
                # This allows the caller to handle importing later
                # NOTE(review): nothing is actually stored here — without an
                # importer the name is dropped, and construction will then fail
                # in model_post_init unless valid_target_states is set.
                pass

        return cls(
            id=row["id"],
            name=row["name"],
            description=row["description"],
            environment=environment,
            initial_state=initial_state,
            instructions=instructions,
            reward_function=reward_function,
            valid_target_states=valid_target_states,
            max_steps=row["max_steps"],
            timeout_seconds=row["timeout_seconds"],
            metadata=metadata
        )
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from dotenv import load_dotenv
|
|
4
|
+
|
|
5
|
+
from core.models import SettingsConfig
|
|
6
|
+
|
|
7
|
+
load_dotenv()
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Fallback orchestrator endpoint used when DOJO_HTTP_ENDPOINT is not set.
_DEFAULT_HTTP_ENDPOINT = "https://orchestrator.trydojo.ai/api/v1"


def _resolve_http_endpoint() -> str:
    """Resolve the HTTP endpoint for the Dojo backend.

    The DOJO_HTTP_ENDPOINT environment variable wins; otherwise the default
    orchestrator endpoint is used.
    """
    configured = os.getenv("DOJO_HTTP_ENDPOINT")
    return configured if configured else _DEFAULT_HTTP_ENDPOINT
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _derive_ws_endpoint(http_endpoint: str) -> str:
    """Derive a websocket endpoint from the provided HTTP endpoint.

    Swaps http(s) for ws(s), strips a trailing slash, and appends a
    '/jobs' suffix when missing. An empty input falls back to the local
    development websocket endpoint.
    """
    if not http_endpoint:
        return "ws://localhost:8765/api/v1/jobs"

    endpoint = http_endpoint.rstrip("/")

    # Map the HTTP scheme onto the matching websocket scheme.
    for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
        if endpoint.startswith(http_scheme):
            endpoint = ws_scheme + endpoint[len(http_scheme):]
            break

    return endpoint if endpoint.endswith("/jobs") else f"{endpoint}/jobs"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Resolve endpoints once at import time; environment variables win over defaults.
_http_endpoint = _resolve_http_endpoint()
# An explicit DOJO_WEBSOCKET_ENDPOINT overrides the endpoint derived from HTTP.
_ws_endpoint = os.getenv("DOJO_WEBSOCKET_ENDPOINT") or _derive_ws_endpoint(_http_endpoint)

# Module-level settings singleton consumed by the rest of the SDK.
settings = SettingsConfig(
    anthropic_api_key=os.getenv("ANTHROPIC_API_KEY", ""),
    openai_api_key=os.getenv("OPENAI_API_KEY", ""),
    openai_api_url=os.getenv("OPENAI_API_URL", ""),
    # TODO: switch to prod endpoint as default
    # NOTE(review): _DEFAULT_HTTP_ENDPOINT above already points at
    # orchestrator.trydojo.ai — confirm whether this TODO is stale.
    dojo_websocket_endpoint=_ws_endpoint,
    dojo_http_endpoint=_http_endpoint,
)
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""Task definition and loading system."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import json
|
|
5
|
+
import string
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any, Callable, List, Optional
|
|
8
|
+
|
|
9
|
+
from datasets import load_dataset
|
|
10
|
+
|
|
11
|
+
from core.dojos import Dojo
|
|
12
|
+
from core.dojos.dojos_loader import load_dojos
|
|
13
|
+
from core.dojos.rewards import get_reward_function
|
|
14
|
+
from core.models import TaskDefinition
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TaskLoader(ABC):
    """Abstract base class for task loaders.

    NOTE(review): PyTaskLoader and RemoteTaskLoader in this module do not
    actually subclass this ABC (and PyTaskLoader exposes
    ``load_benchmark_tasks`` rather than ``load_dojo_tasks``) — confirm the
    intended interface.
    """

    @abstractmethod
    def load_task(self, task_path: str) -> TaskDefinition:
        """Load a task definition from a JSON file."""
        pass

    @abstractmethod
    def load_dojo_tasks(self, dojo_name: str) -> List[TaskDefinition]:
        """Load all tasks from a benchmark directory."""
        pass

    @abstractmethod
    def list_tasks(self, category: Optional[str] = None) -> dict[str, TaskDefinition]:
        """List all available tasks."""
        pass
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class PyTaskLoader:
    """Loads tasks from python-based task registry"""

    @staticmethod
    def _register_all(bms: list[Dojo]) -> dict[str, Dojo]:
        # Fixed annotation: the original said dict[string, Dojo], where
        # `string` is the imported stdlib module, not a type.
        """Index the given dojos by their id for O(1) lookup."""
        return {bm.get_id(): bm for bm in bms}

    def __init__(self):
        # Eagerly build the registry from the static dojo list.
        self.benchmarks_by_name = PyTaskLoader._register_all(load_dojos())

    def load_task(self, benchmark_task_path: str) -> TaskDefinition:
        """Load one task addressed as '<benchmark>/<task_id>'.

        Raises:
            ValueError: if the task id is not found in the benchmark.
        """
        split_path = benchmark_task_path.split("/")
        benchmark_name = split_path[0]
        task_id = split_path[1]
        for task in self.load_benchmark_tasks(benchmark_name):
            if task.id == task_id:
                return task
        raise ValueError(f"Task {task_id} not found in benchmark {benchmark_name}")

    def load_benchmark_tasks(self, benchmark_name: str) -> List[TaskDefinition]:
        """Return every task published by the named benchmark.

        Raises:
            KeyError: if the benchmark name is unknown.
        """
        return self.benchmarks_by_name[benchmark_name].get_tasks_list()

    def list_tasks(self, category: Optional[str] = None) -> dict[str, TaskDefinition]:
        """Not supported by this loader."""
        raise NotImplementedError("List tasks not implemented")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class RemoteTaskLoader:
    """Loads tasks from a Hugging Face dataset.

    The dataset is fetched lazily on first use and the converted
    TaskDefinition objects are cached for the lifetime of the loader.
    """

    def __init__(self, dataset_name: str):
        # Hub dataset identifier; nothing is downloaded until first access.
        self.dataset_name = dataset_name
        self.dataset = None  # populated lazily by _load_dataset()
        self.tasks_cache = None  # populated lazily by _get_all_tasks()

    def _load_dataset(self):
        """Lazy load the HF dataset (only the 'train' split is used)."""
        if self.dataset is None:
            self.dataset = load_dataset(self.dataset_name)["train"]
            print(f"✓ Loaded {len(self.dataset)} tasks from HF dataset {self.dataset_name}")
        return self.dataset

    def _import_reward_function(self, function_name: str) -> Optional[Callable[[dict[str, Any], dict[str, Any]], tuple[float, str]]]:
        """Import a reward function by name from the centralized rewards module.

        Returns None for empty or unknown names.
        """
        if not function_name or function_name == "":
            return None

        return get_reward_function(function_name)

    def _get_all_tasks(self) -> List[TaskDefinition]:
        """Load all tasks from HF dataset and convert to TaskDefinition objects."""
        if self.tasks_cache is None:
            dataset = self._load_dataset()
            tasks = []

            for row in dataset:
                # Use the clean constructor that handles HF-specific parsing
                task = TaskDefinition.from_hf_row(row, reward_function_importer=self._import_reward_function)
                tasks.append(task)

            self.tasks_cache = tasks
            print(f"✓ Converted {len(tasks)} HF tasks to TaskDefinition objects")

        return self.tasks_cache

    def load_task(self, task_path: str) -> TaskDefinition:
        """Load a specific task by dojo/task_id path.

        Raises:
            ValueError: on a malformed path or unknown task id.
        """
        # Parse task_path like "action-tester/must-click"
        if "/" not in task_path:
            raise ValueError(f"Invalid task path format: {task_path}. Expected 'dojo/task_id'")

        dojo_name, task_id = task_path.split("/", 1)

        all_tasks = self._get_all_tasks()

        # Find task by task_id (dojo info is embedded in the task path, not metadata)
        # NOTE(review): dojo_name is not used when matching, so identical task
        # ids in different dojos resolve to whichever row comes first.
        for task in all_tasks:
            if task.id == task_id:
                return task

        raise ValueError(f"Task {task_id} not found in dojo {dojo_name}")

    def list_tasks(self, category: Optional[str] = None) -> dict[str, TaskDefinition]:
        """Not supported by this loader."""
        raise NotImplementedError("List tasks not implemented")
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import List, Literal, Union
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
"""Type definitions for Dojo."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ActionType(str, Enum):
    """Discriminator values for the Action union; str-valued for JSON round-tripping."""

    KEY = "key"
    CLICK = "click"
    RIGHT_CLICK = "right_click"
    SCROLL = "scroll"
    TYPE = "type"
    DOUBLE_CLICK = "double_click"
    DRAG = "drag"
    MOVE_TO = "move_to"
    PRESS = "press"
    HOTKEY = "hotkey"
    MIDDLE_CLICK = "middle_click"
    DONE = "done"
    WAIT = "wait"
    FAIL = "fail"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class KeyAction(BaseModel):
    """A single key event identified by key name."""

    type: Literal[ActionType.KEY]
    key: str


class ClickAction(BaseModel):
    """Left click at (x, y) — screen coordinates, presumably pixels; TODO confirm."""

    type: Literal[ActionType.CLICK]
    x: int
    y: int


class RightClickAction(BaseModel):
    """Right click at (x, y)."""

    type: Literal[ActionType.RIGHT_CLICK]
    x: int
    y: int


class ScrollAction(BaseModel):
    """Scroll in a direction by an amount (units unspecified here; TODO confirm)."""

    type: Literal[ActionType.SCROLL]
    direction: str = "up"
    amount: int = 100


class TypeAction(BaseModel):
    """Type a literal text string."""

    type: Literal[ActionType.TYPE]
    text: str


class DoubleClickAction(BaseModel):
    """Double left click at (x, y)."""

    type: Literal[ActionType.DOUBLE_CLICK]
    x: int
    y: int


class DragAction(BaseModel):
    """Drag from (from_x, from_y) to (to_x, to_y) over `duration` seconds."""

    type: Literal[ActionType.DRAG]
    from_x: int
    from_y: int
    to_x: int
    to_y: int
    duration: float = 1.0


class MoveToAction(BaseModel):
    """Move the cursor to (x, y) over `duration` seconds (0 = instant)."""

    type: Literal[ActionType.MOVE_TO]
    x: int
    y: int
    duration: float = 0.0


class PressAction(BaseModel):
    """Press a named key."""

    type: Literal[ActionType.PRESS]
    key: str


class HotkeyAction(BaseModel):
    """Press a key combination (e.g. several keys held together)."""

    type: Literal[ActionType.HOTKEY]
    keys: List[str]


class MiddleClickAction(BaseModel):
    """Middle click at (x, y)."""

    type: Literal[ActionType.MIDDLE_CLICK]
    x: int
    y: int


class DoneAction(BaseModel):
    """Agent signals the task is complete; carries no payload."""

    type: Literal[ActionType.DONE]


class WaitAction(BaseModel):
    """Pause for the given number of seconds."""

    type: Literal[ActionType.WAIT]
    seconds: int = 1


class FailAction(BaseModel):
    """Agent signals it cannot complete the task, with an explanation."""

    type: Literal[ActionType.FAIL]
    message: str
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# Discriminated union of every action an agent can emit; the `type` Literal on
# each model lets pydantic select the right variant during validation.
Action = Union[
    KeyAction,
    ClickAction,
    RightClickAction,
    ScrollAction,
    TypeAction,
    DoubleClickAction,
    DragAction,
    MoveToAction,
    PressAction,
    HotkeyAction,
    MiddleClickAction,
    DoneAction,
    WaitAction,
    FailAction,
]
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class Score(BaseModel):
    """Aggregate result of a single task run."""

    task_name: str
    score: float
    status: str
    success: bool
    steps_taken: int
    reward: float
    completion_reason: str
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import List, Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from core.types import Action
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ExecutionTrace(BaseModel):
    """Full record of a single task execution, including per-step history."""

    task_name: str
    status: str
    success: bool
    steps_taken: int
    reward: float
    completion_reason: str
    history: List[dict]
    final_state: dict
    # presumably an encoded image (e.g. base64) — TODO confirm with producer
    final_screenshot: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# WebSocket API Message Types
class TaskStatus(str, Enum):
    """Lifecycle states of an individual task."""

    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    PENDING = "PENDING"
    TIMEOUT = "TIMEOUT"


class JobStatus(str, Enum):
    """Lifecycle states of a job (a batch of tasks); unlike TaskStatus it has no TIMEOUT."""

    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    PENDING = "PENDING"


class MsgType(str, Enum):
    """Wire-level `type` discriminator values for WebSocket messages."""

    START_JOB = "StartJobRequest"
    START_JOB_RESP = "StartJobResponse"
    SUBMIT_ACTION = "SubmitActionRequest"
    GET_NEXT_ACTION = "GetNextActionRequest"
    TASK_COMPLETE = "TaskComplete"
    JOB_STATUS_RESP = "JobStatusResponse"
    JOB_COMPLETE = "JobComplete"
    AUTH = "Auth"
    AUTH_RESP = "AuthResp"
    CANCEL_JOB = "CancelJobRequest"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class HistoryStep(BaseModel):
    """One executed step in a task's history."""

    step: int
    # presumably an encoded screenshot taken after the action — TODO confirm
    after_screenshot: str
    agent_response: str
    raw_response: str
    action: Action
    score: float


class NextStep(BaseModel):
    """The step currently awaiting an agent decision (stateless API)."""

    number: int
    after_screenshot: str
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class TaskResult(BaseModel):
    """Final outcome of one task within a completed job."""

    status: TaskStatus
    start_time: datetime
    end_time: datetime
    exec_id: str
    task_id: str
    task_name: str
    job_id: str
    final_score: float
    reason: str
    history: List[HistoryStep]
    # presumably a URL or path to a run recording — TODO confirm
    recording: Optional[str] = None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class JobComplete(BaseModel):
    """WebSocket message sent once every task in a job has finished."""

    type: Literal[MsgType.JOB_COMPLETE]
    job_id: str
    tasks: List[TaskResult]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Stateless API Types


class PendingTask(BaseModel):
    """A task still in flight; may carry the step awaiting an agent action."""

    start_time: datetime
    task_name: str
    id: str
    exec_id: str
    prompt: str
    status: TaskStatus
    history: List[HistoryStep]
    pending_step: Optional[NextStep] = None


class FinishedTask(BaseModel):
    """A task that has reached a terminal status."""

    status: TaskStatus
    start_time: datetime
    end_time: datetime
    task_name: str
    id: str
    exec_id: str
    prompt: str
    history: List[HistoryStep]
    score: float
    recording: Optional[str] = None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class JobStatusResponse(BaseModel):
    """Point-in-time status of a job, split into pending and finished tasks."""

    status: JobStatus
    start_time: datetime
    # Optional[...] instead of `datetime | None`: pydantic evaluates annotations
    # at class creation, and PEP 604 unions require Python 3.10 while the
    # package declares requires-python >= 3.9 (Optional is already imported).
    end_time: Optional[datetime]
    pending_tasks: List[PendingTask]
    finished_tasks: List[FinishedTask]
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class JobMetadata(BaseModel):
    """Summary of a job for list views (no per-task detail)."""

    id: str
    status: JobStatus
    start_time: datetime
    # Optional[...] instead of `datetime | None`: pydantic evaluates annotations
    # at class creation, and PEP 604 unions require Python 3.10 while the
    # package declares requires-python >= 3.9 (Optional is already imported).
    end_time: Optional[datetime]
    num_tasks: int
    num_completed_tasks: int
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class JobsResponse(BaseModel):
    """Listing of jobs returned by the jobs endpoint."""

    jobs: List[JobMetadata]
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# Union type for all WebSocket messages
# NOTE(review): a one-member Union collapses to just JobComplete; extend this
# alias as further message models are added.
WebSocketMessage = Union[JobComplete,]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "dojo-sdk-core"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Core functionality of dojo"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.9"
|
|
7
|
+
dependencies = ["pydantic>=2.11.9", "python-dotenv>=1.1.1", "datasets>=4.1.1"]
|
|
8
|
+
|
|
9
|
+
[project.optional-dependencies]
|
|
10
|
+
dev = ["pytest>=7.0.0", "pytest-cov", "black", "isort", "mypy", "ruff"]
|
|
11
|
+
|
|
12
|
+
[build-system]
|
|
13
|
+
requires = ["hatchling"]
|
|
14
|
+
build-backend = "hatchling.build"
|
|
15
|
+
|
|
16
|
+
[tool.hatch.build.targets.sdist]
|
|
17
|
+
include = ["/src", "/README.md"]
|
|
18
|
+
|
|
19
|
+
[tool.isort]
|
|
20
|
+
profile = "black"
|
|
21
|
+
src_paths = ["src", "tests"]
|
|
22
|
+
|
|
23
|
+
[tool.ruff]
|
|
24
|
+
line-length = 128
|
|
25
|
+
target-version = "py311"
|
|
26
|
+
|
|
27
|
+
[tool.ruff.lint]
|
|
28
|
+
select = ["E", "F", "W", "B", "I"]
|
|
29
|
+
|
|
30
|
+
[tool.ruff.format]
|
|
31
|
+
quote-style = "double"
|
|
32
|
+
indent-style = "space"
|
|
33
|
+
|
|
34
|
+
[tool.hatch.build]
|
|
35
|
+
sources = ["src"]
|
|
36
|
+
|
|
37
|
+
[tool.hatch.build.targets.wheel]
|
|
38
|
+
packages = ["src/core"]
|