inspect-ai 0.3.69__py3-none-any.whl → 0.3.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +27 -9
- inspect_ai/_display/core/display.py +2 -0
- inspect_ai/_display/core/footer.py +13 -3
- inspect_ai/_display/plain/display.py +6 -2
- inspect_ai/_display/rich/display.py +19 -6
- inspect_ai/_display/textual/app.py +9 -3
- inspect_ai/_display/textual/display.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +4 -10
- inspect_ai/_display/textual/widgets/transcript.py +35 -18
- inspect_ai/_eval/eval.py +14 -2
- inspect_ai/_eval/evalset.py +6 -1
- inspect_ai/_eval/run.py +6 -0
- inspect_ai/_eval/task/run.py +49 -23
- inspect_ai/_eval/task/task.py +26 -3
- inspect_ai/_util/content.py +20 -1
- inspect_ai/_util/interrupt.py +6 -0
- inspect_ai/_util/logger.py +19 -0
- inspect_ai/_util/rich.py +7 -8
- inspect_ai/_util/text.py +13 -0
- inspect_ai/_util/transcript.py +20 -6
- inspect_ai/_util/working.py +50 -0
- inspect_ai/_view/www/App.css +6 -0
- inspect_ai/_view/www/dist/assets/index.css +171 -99
- inspect_ai/_view/www/dist/assets/index.js +5972 -2770
- inspect_ai/_view/www/eslint.config.mjs +24 -1
- inspect_ai/_view/www/log-schema.json +619 -21
- inspect_ai/_view/www/package.json +8 -3
- inspect_ai/_view/www/src/App.tsx +2 -2
- inspect_ai/_view/www/src/appearance/icons.ts +3 -1
- inspect_ai/_view/www/src/components/AnsiDisplay.tsx +4 -3
- inspect_ai/_view/www/src/components/Card.tsx +9 -8
- inspect_ai/_view/www/src/components/DownloadButton.tsx +2 -1
- inspect_ai/_view/www/src/components/EmptyPanel.tsx +2 -2
- inspect_ai/_view/www/src/components/ErrorPanel.tsx +4 -3
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +13 -5
- inspect_ai/_view/www/src/components/FindBand.tsx +3 -3
- inspect_ai/_view/www/src/components/HumanBaselineView.tsx +3 -3
- inspect_ai/_view/www/src/components/LabeledValue.tsx +5 -4
- inspect_ai/_view/www/src/components/LargeModal.tsx +18 -13
- inspect_ai/_view/www/src/components/{LightboxCarousel.css → LightboxCarousel.module.css} +22 -18
- inspect_ai/_view/www/src/components/LightboxCarousel.tsx +36 -27
- inspect_ai/_view/www/src/components/MessageBand.tsx +2 -1
- inspect_ai/_view/www/src/components/NavPills.tsx +9 -8
- inspect_ai/_view/www/src/components/ProgressBar.tsx +2 -1
- inspect_ai/_view/www/src/components/TabSet.tsx +21 -15
- inspect_ai/_view/www/src/index.tsx +2 -2
- inspect_ai/_view/www/src/metadata/MetaDataGrid.tsx +11 -9
- inspect_ai/_view/www/src/metadata/MetaDataView.tsx +3 -2
- inspect_ai/_view/www/src/metadata/MetadataGrid.module.css +1 -0
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +16 -1
- inspect_ai/_view/www/src/plan/DatasetDetailView.tsx +3 -2
- inspect_ai/_view/www/src/plan/DetailStep.tsx +2 -1
- inspect_ai/_view/www/src/plan/PlanCard.tsx +2 -5
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +6 -9
- inspect_ai/_view/www/src/plan/ScorerDetailView.tsx +2 -1
- inspect_ai/_view/www/src/plan/SolverDetailView.tsx +3 -3
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +2 -2
- inspect_ai/_view/www/src/samples/SampleDialog.tsx +3 -3
- inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +30 -3
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +25 -4
- inspect_ai/_view/www/src/samples/SamplesTools.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +3 -19
- inspect_ai/_view/www/src/samples/chat/ChatMessageRenderer.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatMessageRow.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatView.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatViewVirtualList.tsx +22 -7
- inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +35 -6
- inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -2
- inspect_ai/_view/www/src/samples/chat/messages.ts +15 -2
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +13 -4
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.module.css +2 -2
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +18 -19
- inspect_ai/_view/www/src/samples/chat/tools/ToolOutput.module.css +1 -1
- inspect_ai/_view/www/src/samples/chat/tools/ToolOutput.tsx +4 -3
- inspect_ai/_view/www/src/samples/chat/tools/ToolTitle.tsx +2 -2
- inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx +2 -3
- inspect_ai/_view/www/src/samples/error/SampleErrorView.tsx +3 -2
- inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +2 -1
- inspect_ai/_view/www/src/samples/list/SampleHeader.tsx +2 -1
- inspect_ai/_view/www/src/samples/list/SampleList.tsx +57 -45
- inspect_ai/_view/www/src/samples/list/SampleRow.tsx +2 -1
- inspect_ai/_view/www/src/samples/list/SampleSeparator.tsx +2 -1
- inspect_ai/_view/www/src/samples/sample-tools/EpochFilter.tsx +2 -2
- inspect_ai/_view/www/src/samples/sample-tools/SelectScorer.tsx +4 -3
- inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +2 -5
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +2 -2
- inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +2 -1
- inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +2 -2
- inspect_ai/_view/www/src/samples/transcript/ApprovalEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/InputEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/LoggerEventView.module.css +4 -0
- inspect_ai/_view/www/src/samples/transcript/LoggerEventView.tsx +12 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +1 -1
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +25 -28
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +9 -4
- inspect_ai/_view/www/src/samples/transcript/SampleTranscript.tsx +2 -2
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +153 -0
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.tsx +2 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +12 -5
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx +18 -14
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +5 -5
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +53 -16
- inspect_ai/_view/www/src/samples/transcript/event/EventNav.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/event/EventNavs.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
- inspect_ai/_view/www/src/samples/transcript/event/EventRow.tsx +3 -2
- inspect_ai/_view/www/src/samples/transcript/event/EventSection.tsx +2 -2
- inspect_ai/_view/www/src/samples/transcript/event/EventTimingPanel.module.css +28 -0
- inspect_ai/_view/www/src/samples/transcript/event/EventTimingPanel.tsx +115 -0
- inspect_ai/_view/www/src/samples/transcript/event/utils.ts +29 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateDiffView.tsx +2 -1
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +3 -3
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +11 -8
- inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
- inspect_ai/_view/www/src/types/log.d.ts +312 -137
- inspect_ai/_view/www/src/usage/ModelTokenTable.tsx +6 -10
- inspect_ai/_view/www/src/usage/ModelUsagePanel.module.css +4 -0
- inspect_ai/_view/www/src/usage/ModelUsagePanel.tsx +32 -9
- inspect_ai/_view/www/src/usage/TokenTable.tsx +4 -6
- inspect_ai/_view/www/src/usage/UsageCard.tsx +2 -1
- inspect_ai/_view/www/src/utils/format.ts +8 -5
- inspect_ai/_view/www/src/utils/json.ts +24 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.tsx +6 -5
- inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +18 -8
- inspect_ai/_view/www/src/workspace/error/TaskErrorPanel.tsx +2 -1
- inspect_ai/_view/www/src/workspace/navbar/Navbar.tsx +2 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +3 -3
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +4 -3
- inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +5 -4
- inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx +5 -8
- inspect_ai/_view/www/src/workspace/sidebar/EvalStatus.tsx +5 -4
- inspect_ai/_view/www/src/workspace/sidebar/LogDirectoryTitleView.tsx +2 -1
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.tsx +2 -1
- inspect_ai/_view/www/src/workspace/sidebar/SidebarLogEntry.tsx +2 -2
- inspect_ai/_view/www/src/workspace/sidebar/SidebarScoreView.tsx +2 -1
- inspect_ai/_view/www/src/workspace/sidebar/SidebarScoresView.tsx +2 -2
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -2
- inspect_ai/_view/www/src/workspace/tabs/JsonTab.tsx +2 -5
- inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +12 -11
- inspect_ai/_view/www/yarn.lock +241 -5
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_condense.py +4 -0
- inspect_ai/log/_log.py +72 -12
- inspect_ai/log/_recorders/eval.py +6 -1
- inspect_ai/log/_samples.py +5 -1
- inspect_ai/log/_transcript.py +89 -2
- inspect_ai/model/__init__.py +2 -0
- inspect_ai/model/_call_tools.py +8 -1
- inspect_ai/model/_chat_message.py +22 -7
- inspect_ai/model/_conversation.py +11 -9
- inspect_ai/model/_generate_config.py +25 -4
- inspect_ai/model/_model.py +164 -72
- inspect_ai/model/_model_call.py +10 -3
- inspect_ai/model/_model_output.py +3 -0
- inspect_ai/model/_openai.py +106 -40
- inspect_ai/model/_providers/anthropic.py +145 -26
- inspect_ai/model/_providers/bedrock.py +7 -0
- inspect_ai/model/_providers/cloudflare.py +20 -7
- inspect_ai/model/_providers/google.py +29 -8
- inspect_ai/model/_providers/groq.py +66 -27
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +78 -51
- inspect_ai/model/_providers/openai.py +66 -4
- inspect_ai/model/_providers/openai_o1.py +10 -0
- inspect_ai/model/_providers/providers.py +2 -2
- inspect_ai/model/_providers/util/tracker.py +92 -0
- inspect_ai/model/_providers/vllm.py +13 -5
- inspect_ai/model/_reasoning.py +15 -2
- inspect_ai/scorer/_model.py +23 -19
- inspect_ai/solver/_basic_agent.py +1 -3
- inspect_ai/solver/_bridge/patch.py +0 -2
- inspect_ai/solver/_human_agent/agent.py +14 -10
- inspect_ai/solver/_human_agent/commands/__init__.py +7 -3
- inspect_ai/solver/_human_agent/commands/submit.py +76 -30
- inspect_ai/solver/_limit.py +4 -4
- inspect_ai/solver/_plan.py +0 -3
- inspect_ai/solver/_task_state.py +7 -0
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +3 -1
- inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_resources/.pylintrc +8 -0
- inspect_ai/tool/_tools/_web_browser/_resources/.vscode/launch.json +24 -0
- inspect_ai/tool/_tools/_web_browser/_resources/.vscode/settings.json +25 -0
- inspect_ai/tool/_tools/_web_browser/_resources/Dockerfile +5 -6
- inspect_ai/tool/_tools/_web_browser/_resources/README.md +10 -11
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree.py +71 -0
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree_node.py +323 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/__init__.py +5 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/a11y.py +279 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom.py +9 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom_snapshot.py +293 -0
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/page.py +94 -0
- inspect_ai/tool/_tools/_web_browser/_resources/constants.py +2 -0
- inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.svg +2 -0
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_browser.py +50 -0
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py +31 -359
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_page_crawler.py +280 -0
- inspect_ai/tool/_tools/_web_browser/_resources/pyproject.toml +65 -0
- inspect_ai/tool/_tools/_web_browser/_resources/rectangle.py +64 -0
- inspect_ai/tool/_tools/_web_browser/_resources/rpc_client_helpers.py +146 -0
- inspect_ai/tool/_tools/_web_browser/_resources/scale_factor.py +64 -0
- inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_tree_node.py +180 -0
- inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py +15 -9
- inspect_ai/tool/_tools/_web_browser/_resources/test_rectangle.py +15 -0
- inspect_ai/tool/_tools/_web_browser/_resources/test_web_client.py +44 -0
- inspect_ai/tool/_tools/_web_browser/_resources/web_browser_rpc_types.py +39 -0
- inspect_ai/tool/_tools/_web_browser/_resources/web_client.py +198 -48
- inspect_ai/tool/_tools/_web_browser/_resources/web_client_new_session.py +26 -25
- inspect_ai/tool/_tools/_web_browser/_resources/web_server.py +178 -39
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +38 -19
- inspect_ai/tool/_tools/_web_search.py +3 -3
- inspect_ai/util/__init__.py +2 -1
- inspect_ai/util/_concurrency.py +14 -8
- inspect_ai/util/_display.py +12 -0
- inspect_ai/util/_sandbox/context.py +15 -0
- inspect_ai/util/_sandbox/docker/docker.py +7 -5
- inspect_ai/util/_sandbox/environment.py +32 -1
- inspect_ai/util/_sandbox/events.py +183 -0
- inspect_ai/util/_sandbox/local.py +3 -3
- inspect_ai/util/_sandbox/self_check.py +131 -43
- inspect_ai/util/_subtask.py +11 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/RECORD +233 -211
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/WHEEL +1 -1
- inspect_ai/_view/www/src/components/VirtualList.module.css +0 -19
- inspect_ai/_view/www/src/components/VirtualList.tsx +0 -292
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_node.py +0 -312
- inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +0 -275
- inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.png +0 -0
- inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_node.py +0 -176
- inspect_ai/tool/_tools/_web_browser/_resources/test_dm_env_servicer.py +0 -135
- inspect_ai/tool/_tools/_web_browser/_resources/test_web_environment.py +0 -71
- inspect_ai/tool/_tools/_web_browser/_resources/web_environment.py +0 -184
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.69.dist-info → inspect_ai-0.3.71.dist-info}/top_level.txt +0 -0
@@ -1,275 +0,0 @@
|
|
1
|
-
"""Environment service that allows clients to run shell commands in steps."""
|
2
|
-
|
3
|
-
import threading
|
4
|
-
from typing import Any, Iterable, Type
|
5
|
-
|
6
|
-
import dm_env
|
7
|
-
import grpc
|
8
|
-
import playwright_crawler
|
9
|
-
from dm_env import specs
|
10
|
-
from dm_env_rpc.v1 import (
|
11
|
-
dm_env_rpc_pb2,
|
12
|
-
dm_env_rpc_pb2_grpc,
|
13
|
-
dm_env_utils,
|
14
|
-
spec_manager,
|
15
|
-
)
|
16
|
-
from google.rpc import code_pb2, status_pb2
|
17
|
-
|
18
|
-
_DEFAULT_WORLD_NAME = "WebBrowser"
|
19
|
-
|
20
|
-
|
21
|
-
class EnvironmentSpec:
|
22
|
-
"""Specifications for a dm_environment.
|
23
|
-
|
24
|
-
This class holds action and observation specs, as well as the required
|
25
|
-
managers to pack actions and observations.
|
26
|
-
"""
|
27
|
-
|
28
|
-
def __init__(self, env: dm_env.Environment):
|
29
|
-
convert = dm_env_utils.dm_env_spec_to_tensor_spec
|
30
|
-
|
31
|
-
# We support either a single spec, of flat dictionary of specs.
|
32
|
-
# In the dictionary case we need to map names to unique IDs.
|
33
|
-
env_obs_spec: dict[str, Any] = env.observation_spec()
|
34
|
-
if isinstance(env_obs_spec, specs.Array):
|
35
|
-
self.observation_spec = {1: convert(env_obs_spec)}
|
36
|
-
else:
|
37
|
-
self.observation_spec = {}
|
38
|
-
for i, obs_spec in enumerate(env_obs_spec.values()):
|
39
|
-
self.observation_spec[i + 1] = convert(obs_spec)
|
40
|
-
|
41
|
-
assert isinstance(env.action_spec(), specs.Array), (
|
42
|
-
"Only a single action type is supported."
|
43
|
-
)
|
44
|
-
self.action_spec = {1: convert(env.action_spec())}
|
45
|
-
|
46
|
-
self.observation_manager = spec_manager.SpecManager(self.observation_spec)
|
47
|
-
self.action_manager = spec_manager.SpecManager(self.action_spec)
|
48
|
-
|
49
|
-
|
50
|
-
class EnvironmentService(dm_env_rpc_pb2_grpc.EnvironmentServicer):
|
51
|
-
"""Runs the environment as a gRPC EnvironmentServicer."""
|
52
|
-
|
53
|
-
def __init__(self, env_type: Type[dm_env.Environment]) -> None:
|
54
|
-
"""Initializes the environment.
|
55
|
-
|
56
|
-
Args:
|
57
|
-
env_type: A dm_env class to serve.
|
58
|
-
"""
|
59
|
-
self._env_type = env_type
|
60
|
-
self._envs: dict[str, dm_env.Environment] = {}
|
61
|
-
self._specs: dict[str, EnvironmentSpec] = {}
|
62
|
-
self._joined_worlds: set[str] = set()
|
63
|
-
self._browser: playwright_crawler.PlaywrightBrowser = None
|
64
|
-
self._lock = threading.Lock()
|
65
|
-
self._num_worlds = 0
|
66
|
-
|
67
|
-
def Process(
|
68
|
-
self,
|
69
|
-
request_iterator: Iterable[dm_env_rpc_pb2.EnvironmentRequest],
|
70
|
-
context: grpc.ServicerContext,
|
71
|
-
):
|
72
|
-
"""Processes incoming EnvironmentRequests.
|
73
|
-
|
74
|
-
For each EnvironmentRequest the internal message is extracted and handled.
|
75
|
-
The response for that message is then placed in a EnvironmentResponse which
|
76
|
-
is returned to the client.
|
77
|
-
|
78
|
-
An error status will be returned if an unknown message type is received or
|
79
|
-
if the message is invalid for the current world state.
|
80
|
-
|
81
|
-
|
82
|
-
Args:
|
83
|
-
request_iterator: Message iterator provided by gRPC.
|
84
|
-
context: Context provided by gRPC.
|
85
|
-
|
86
|
-
Yields:
|
87
|
-
EnvironmentResponse: Response for each incoming EnvironmentRequest.
|
88
|
-
"""
|
89
|
-
cur_world = None
|
90
|
-
for request in request_iterator:
|
91
|
-
environment_response = dm_env_rpc_pb2.EnvironmentResponse()
|
92
|
-
try:
|
93
|
-
message_type = request.WhichOneof("payload")
|
94
|
-
internal_request = getattr(request, message_type)
|
95
|
-
match type(internal_request):
|
96
|
-
case dm_env_rpc_pb2.CreateWorldRequest:
|
97
|
-
response = self._handle_create_world_request(internal_request)
|
98
|
-
case dm_env_rpc_pb2.JoinWorldRequest:
|
99
|
-
response, cur_world = self._handle_join_world_request(
|
100
|
-
internal_request
|
101
|
-
)
|
102
|
-
case dm_env_rpc_pb2.LeaveWorldRequest:
|
103
|
-
response = self._handle_leave_world_request(
|
104
|
-
internal_request, cur_world
|
105
|
-
)
|
106
|
-
case dm_env_rpc_pb2.DestroyWorldRequest:
|
107
|
-
response = self._handle_destroy_world_request(internal_request)
|
108
|
-
case dm_env_rpc_pb2.StepRequest:
|
109
|
-
response = self._handle_step_request(
|
110
|
-
internal_request, cur_world
|
111
|
-
)
|
112
|
-
case _:
|
113
|
-
raise ValueError(
|
114
|
-
f"Unsupported request type: {type(internal_request)}"
|
115
|
-
)
|
116
|
-
getattr(environment_response, message_type).CopyFrom(response)
|
117
|
-
except Exception as e: # pylint: disable=broad-except
|
118
|
-
environment_response.error.CopyFrom(
|
119
|
-
status_pb2.Status(code=code_pb2.INTERNAL, message=str(e))
|
120
|
-
)
|
121
|
-
yield environment_response
|
122
|
-
|
123
|
-
def _validate_settings(self, settings, valid_settings):
|
124
|
-
"""Validate the provided settings with list of valid setting keys."""
|
125
|
-
unrecognized_settings = [
|
126
|
-
setting for setting in settings if setting not in valid_settings
|
127
|
-
]
|
128
|
-
|
129
|
-
if unrecognized_settings:
|
130
|
-
raise ValueError(
|
131
|
-
"Unrecognized settings provided! Invalid settings:"
|
132
|
-
f" {unrecognized_settings}"
|
133
|
-
)
|
134
|
-
|
135
|
-
def _add_spec_to_response(
|
136
|
-
self, world_name: str, response: dm_env_rpc_pb2.EnvironmentResponse
|
137
|
-
):
|
138
|
-
"""Modifies given respose to include action/observation specifications."""
|
139
|
-
if not self._specs.get(world_name):
|
140
|
-
raise ValueError(f"Not found a spec for {world_name} world")
|
141
|
-
|
142
|
-
spec = self._specs[world_name]
|
143
|
-
for uid, action in spec.action_spec.items():
|
144
|
-
response.specs.actions[uid].CopyFrom(action)
|
145
|
-
for uid, observation in spec.observation_spec.items():
|
146
|
-
response.specs.observations[uid].CopyFrom(observation)
|
147
|
-
|
148
|
-
def _handle_create_world_request(
|
149
|
-
self, request: dm_env_rpc_pb2.CreateWorldRequest
|
150
|
-
) -> dm_env_rpc_pb2.CreateWorldResponse:
|
151
|
-
"""Handles create_world requests."""
|
152
|
-
self._validate_settings(request.settings, [])
|
153
|
-
del request
|
154
|
-
world_name = _DEFAULT_WORLD_NAME
|
155
|
-
with self._lock:
|
156
|
-
if self._browser is None:
|
157
|
-
self._browser = playwright_crawler.PlaywrightBrowser()
|
158
|
-
else:
|
159
|
-
world_name += f"_{self._num_worlds}"
|
160
|
-
self._num_worlds += 1
|
161
|
-
|
162
|
-
new_context = self._browser.get_new_context()
|
163
|
-
env = self._env_type(new_context)
|
164
|
-
spec = EnvironmentSpec(env)
|
165
|
-
self._envs[world_name] = env
|
166
|
-
self._specs[world_name] = spec
|
167
|
-
|
168
|
-
return dm_env_rpc_pb2.CreateWorldResponse(world_name=world_name)
|
169
|
-
|
170
|
-
def _handle_join_world_request(
|
171
|
-
self, request: dm_env_rpc_pb2.JoinWorldRequest
|
172
|
-
) -> tuple[dm_env_rpc_pb2.JoinWorldResponse, str]:
|
173
|
-
"""Handles join_world requests."""
|
174
|
-
self._validate_settings(request.settings, [])
|
175
|
-
response = dm_env_rpc_pb2.JoinWorldResponse()
|
176
|
-
world_name = request.world_name
|
177
|
-
with self._lock:
|
178
|
-
if not self._envs.get(world_name):
|
179
|
-
raise ValueError(f"Joining with the wrong world_name {world_name}")
|
180
|
-
if world_name in self._joined_worlds:
|
181
|
-
raise ValueError(f"Only one client can joint the world {world_name}")
|
182
|
-
self._joined_worlds.add(world_name)
|
183
|
-
self._add_spec_to_response(world_name, response)
|
184
|
-
|
185
|
-
del request
|
186
|
-
return (response, world_name)
|
187
|
-
|
188
|
-
def _handle_leave_world_request(
|
189
|
-
self, request: dm_env_rpc_pb2.LeaveWorldRequest, world_name: str
|
190
|
-
) -> dm_env_rpc_pb2.LeaveWorldResponse:
|
191
|
-
"""Handles leave_world requests."""
|
192
|
-
del request
|
193
|
-
if world_name in self._joined_worlds:
|
194
|
-
self._joined_worlds.remove(world_name)
|
195
|
-
response = dm_env_rpc_pb2.LeaveWorldResponse()
|
196
|
-
return response
|
197
|
-
|
198
|
-
def _handle_destroy_world_request(
|
199
|
-
self, request: dm_env_rpc_pb2.DestroyWorldRequest
|
200
|
-
) -> dm_env_rpc_pb2.DestroyWorldResponse:
|
201
|
-
"""Handles destroy_world requests."""
|
202
|
-
world_name = request.world_name
|
203
|
-
del request
|
204
|
-
with self._lock:
|
205
|
-
if not self._envs.get(world_name):
|
206
|
-
raise ValueError("Can not destroy uncreated environment.")
|
207
|
-
if world_name in self._joined_worlds:
|
208
|
-
raise ValueError("Can not destroy environment with a joined agent.")
|
209
|
-
env = self._envs.pop(world_name)
|
210
|
-
env.close()
|
211
|
-
env = None
|
212
|
-
self._specs.pop(world_name, None)
|
213
|
-
|
214
|
-
if not self._envs:
|
215
|
-
self._browser.close()
|
216
|
-
self._browser = None
|
217
|
-
response = dm_env_rpc_pb2.DestroyWorldResponse()
|
218
|
-
return response
|
219
|
-
|
220
|
-
def _handle_step_request(
|
221
|
-
self, request: dm_env_rpc_pb2.StepRequest, cur_world: str
|
222
|
-
) -> dm_env_rpc_pb2.StepResponse:
|
223
|
-
"""Handles step requests.
|
224
|
-
|
225
|
-
Args:
|
226
|
-
request: The request, which should contain a 'command' entry.
|
227
|
-
cur_world: The name of the world in which we're making a step.
|
228
|
-
|
229
|
-
Returns:
|
230
|
-
Response including requested observations.
|
231
|
-
|
232
|
-
Raises:
|
233
|
-
KeyError: If the requested observation is not in the list of available
|
234
|
-
observations.
|
235
|
-
"""
|
236
|
-
with self._lock:
|
237
|
-
assert cur_world in self._envs, (
|
238
|
-
"Current world does not have an assosiated environment"
|
239
|
-
)
|
240
|
-
assert cur_world in self._joined_worlds, (
|
241
|
-
"Please join world before calling step."
|
242
|
-
)
|
243
|
-
env = self._envs[cur_world]
|
244
|
-
spec = self._specs[cur_world]
|
245
|
-
|
246
|
-
action = spec.action_manager.unpack(request.actions)
|
247
|
-
|
248
|
-
if "command" in action:
|
249
|
-
command = action["command"]
|
250
|
-
else:
|
251
|
-
# For some reason dm_env calls step without actions after a reset.
|
252
|
-
command = ""
|
253
|
-
|
254
|
-
timestep: dm_env.TimeStep = env.step(command)
|
255
|
-
|
256
|
-
packed_observations = spec.observation_manager.pack(timestep.observation)
|
257
|
-
|
258
|
-
match timestep.step_type:
|
259
|
-
case dm_env.StepType.MID:
|
260
|
-
step_state = dm_env_rpc_pb2.RUNNING
|
261
|
-
case dm_env.StepType.LAST:
|
262
|
-
step_state = dm_env_rpc_pb2.TERMINATED
|
263
|
-
case _:
|
264
|
-
raise ValueError(f"Unsupported step type {timestep.step_type}.")
|
265
|
-
|
266
|
-
response = dm_env_rpc_pb2.StepResponse(state=step_state)
|
267
|
-
for requested_observation in request.requested_observations:
|
268
|
-
if requested_observation not in packed_observations:
|
269
|
-
name = spec.observation_manager.uid_to_name(requested_observation)
|
270
|
-
raise KeyError(f"Requested observation not found: {name}")
|
271
|
-
response.observations[requested_observation].CopyFrom(
|
272
|
-
packed_observations[requested_observation]
|
273
|
-
)
|
274
|
-
|
275
|
-
return response
|
Binary file
|
@@ -1,176 +0,0 @@
|
|
1
|
-
from absl.testing import absltest
|
2
|
-
from accessibility_node import AccessibilityNode, NodeBounds
|
3
|
-
|
4
|
-
|
5
|
-
class TestNodeBounds(absltest.TestCase):
|
6
|
-
def test_union(self):
|
7
|
-
bounds1 = NodeBounds(10, 20, 30, 40)
|
8
|
-
bounds2 = NodeBounds(20, 30, 40, 50)
|
9
|
-
union_bounds = bounds1.union(bounds2)
|
10
|
-
self.assertEqual(union_bounds.left, 10)
|
11
|
-
self.assertEqual(union_bounds.top, 20)
|
12
|
-
self.assertEqual(union_bounds.right, 60)
|
13
|
-
self.assertEqual(union_bounds.bottom, 80)
|
14
|
-
|
15
|
-
def test_union_with_empty_bound(self):
|
16
|
-
bounds = NodeBounds(10, 20, 30, 40)
|
17
|
-
empty_bounds = NodeBounds(0, 0, 0, 0)
|
18
|
-
union_bounds = bounds.union(empty_bounds)
|
19
|
-
self.assertEqual(union_bounds.left, 10)
|
20
|
-
self.assertEqual(union_bounds.top, 20)
|
21
|
-
self.assertEqual(union_bounds.right, 40)
|
22
|
-
self.assertEqual(union_bounds.bottom, 60)
|
23
|
-
|
24
|
-
def test_is_inside(self):
|
25
|
-
bounds = NodeBounds(10, 20, 30, 40)
|
26
|
-
self.assertTrue(bounds.is_inside(20, 30))
|
27
|
-
self.assertFalse(bounds.is_inside(5, 10))
|
28
|
-
self.assertFalse(bounds.is_inside(50, 70))
|
29
|
-
|
30
|
-
def test_overlaps(self):
|
31
|
-
bounds1 = NodeBounds(10, 20, 30, 40)
|
32
|
-
bounds2 = NodeBounds(20, 30, 40, 50)
|
33
|
-
self.assertTrue(bounds1.overlaps(bounds2))
|
34
|
-
bounds3 = NodeBounds(50, 60, 20, 30)
|
35
|
-
self.assertFalse(bounds1.overlaps(bounds3))
|
36
|
-
|
37
|
-
|
38
|
-
class TestAccessibilityNode(absltest.TestCase):
|
39
|
-
def test_getitem(self):
|
40
|
-
node_data = {"name": {"value": "Test"}, "role": {"value": "button"}}
|
41
|
-
node = AccessibilityNode(node_data)
|
42
|
-
self.assertEqual(node["name"], {"value": "Test"})
|
43
|
-
self.assertEqual(node["role"], {"value": "button"})
|
44
|
-
self.assertIsNone(node["invalid_key"])
|
45
|
-
|
46
|
-
def test_setitem(self):
|
47
|
-
node_data = {"name": {"value": "Test"}}
|
48
|
-
node = AccessibilityNode(node_data)
|
49
|
-
node["role"] = {"value": "button"}
|
50
|
-
self.assertEqual(
|
51
|
-
node._node, {"name": {"value": "Test"}, "role": {"value": "button"}}
|
52
|
-
)
|
53
|
-
|
54
|
-
def test_str_role_name(self):
|
55
|
-
node_data = {
|
56
|
-
"nodeId": "1",
|
57
|
-
"role": {"value": "button"},
|
58
|
-
"name": {"value": "Test Button"},
|
59
|
-
}
|
60
|
-
node = AccessibilityNode(node_data, bounds=NodeBounds(10, 10, 20, 20))
|
61
|
-
self.assertIn('[1] button "Test Button"', str(node))
|
62
|
-
|
63
|
-
def test_str_image_src(self):
|
64
|
-
node_data = {
|
65
|
-
"nodeId": "1",
|
66
|
-
"role": {"value": "image"},
|
67
|
-
"name": {"value": "Test Image"},
|
68
|
-
"src": "image.png",
|
69
|
-
}
|
70
|
-
node = AccessibilityNode(node_data, bounds=NodeBounds(10, 10, 20, 20))
|
71
|
-
self.assertIn('[1] image "Test Image" image.png', str(node))
|
72
|
-
|
73
|
-
def test_str_property(self):
|
74
|
-
node_data = {
|
75
|
-
"nodeId": "1",
|
76
|
-
"role": {"value": "link"},
|
77
|
-
"name": {"value": "Test Link"},
|
78
|
-
"properties": [{"name": "url", "value": {"value": "www.example.com"}}],
|
79
|
-
}
|
80
|
-
node = AccessibilityNode(node_data, bounds=NodeBounds(10, 10, 20, 20))
|
81
|
-
self.assertIn('[1] link "Test Link" [url: www.example.com]', str(node))
|
82
|
-
|
83
|
-
def test_str_empty(self):
|
84
|
-
node_data = {"nodeId": "1"}
|
85
|
-
node = AccessibilityNode(node_data)
|
86
|
-
self.assertIn('[*] ""', str(node))
|
87
|
-
|
88
|
-
def test_link_children(self):
|
89
|
-
node_data = {"nodeId": "1", "childIds": [2, 3]}
|
90
|
-
node = AccessibilityNode(node_data)
|
91
|
-
child1 = AccessibilityNode({"nodeId": "2"})
|
92
|
-
child2 = AccessibilityNode({"nodeId": "3"})
|
93
|
-
node_lookup = {2: child1, 3: child2}
|
94
|
-
node.link_children(node_lookup)
|
95
|
-
|
96
|
-
self.assertEqual(node.children, [child1, child2])
|
97
|
-
self.assertEqual(child1._parent, node)
|
98
|
-
self.assertEqual(child2._parent, node)
|
99
|
-
|
100
|
-
def test_get_union_bounds(self):
|
101
|
-
node_data = {"nodeId": "1", "childIds": [2, 3]}
|
102
|
-
node = AccessibilityNode(node_data, bounds=NodeBounds(10, 10, 20, 20))
|
103
|
-
child1 = AccessibilityNode({"nodeId": "2"}, bounds=NodeBounds(20, 20, 30, 30))
|
104
|
-
child2 = AccessibilityNode({"nodeId": "3"}, bounds=NodeBounds(5, 5, 10, 10))
|
105
|
-
node.link_children({2: child1, 3: child2}) # Manually link children
|
106
|
-
|
107
|
-
union_bounds = node.get_union_bounds()
|
108
|
-
self.assertEqual(union_bounds.left, 5)
|
109
|
-
self.assertEqual(union_bounds.top, 5)
|
110
|
-
self.assertEqual(union_bounds.right, 50)
|
111
|
-
self.assertEqual(union_bounds.bottom, 50)
|
112
|
-
|
113
|
-
def test_get_property(self):
|
114
|
-
node_data = {"name": {"value": "Test"}, "role": {"value": "button"}}
|
115
|
-
node = AccessibilityNode(node_data)
|
116
|
-
self.assertEqual(node.get_property("name"), {"value": "Test"})
|
117
|
-
self.assertEqual(node.get_property("role"), {"value": "button"})
|
118
|
-
self.assertIsNone(node.get_property("invalid_key"))
|
119
|
-
|
120
|
-
def test_to_string(self):
|
121
|
-
node_data = {
|
122
|
-
"nodeId": "1",
|
123
|
-
"ignored": False,
|
124
|
-
"childIds": [2],
|
125
|
-
"role": {"value": "button"},
|
126
|
-
"name": {"value": "Test Button"},
|
127
|
-
}
|
128
|
-
node = AccessibilityNode(node_data, bounds=NodeBounds(20, 20, 30, 30))
|
129
|
-
self.assertIn('[1] button "Test Button"', node.to_string())
|
130
|
-
|
131
|
-
def test_to_string_with_children(self):
|
132
|
-
node_data = {
|
133
|
-
"nodeId": "1",
|
134
|
-
"ignored": False,
|
135
|
-
"childIds": [2, 3],
|
136
|
-
"role": {"value": "button"},
|
137
|
-
"name": {"value": "Test Button"},
|
138
|
-
}
|
139
|
-
node = AccessibilityNode(node_data, bounds=NodeBounds(20, 20, 30, 30))
|
140
|
-
child_data = {
|
141
|
-
"nodeId": "2",
|
142
|
-
"ignored": False,
|
143
|
-
"role": {"value": "text"},
|
144
|
-
"name": {"value": "Child Text"},
|
145
|
-
}
|
146
|
-
child = AccessibilityNode(child_data, bounds=NodeBounds(20, 20, 30, 30))
|
147
|
-
ignored_child_data = {
|
148
|
-
"nodeId": "3",
|
149
|
-
"ignored": True,
|
150
|
-
"role": {"value": "text"},
|
151
|
-
"name": {"value": "Child Text"},
|
152
|
-
}
|
153
|
-
ignored_child = AccessibilityNode(
|
154
|
-
ignored_child_data, bounds=NodeBounds(20, 20, 30, 30)
|
155
|
-
)
|
156
|
-
node.link_children({2: child, 3: ignored_child})
|
157
|
-
self.assertIn('[2] text "Child Text"', node.to_string())
|
158
|
-
self.assertNotIn('[3] text "Child Text"', node.to_string())
|
159
|
-
|
160
|
-
# Test with include_children=False
|
161
|
-
self.assertNotIn(
|
162
|
-
'[2] text "Child Text"', node.to_string(include_children=False)
|
163
|
-
)
|
164
|
-
|
165
|
-
def test_property_string(self):
|
166
|
-
node_data = {"properties": [{"name": "checked", "value": {"value": True}}]}
|
167
|
-
node = AccessibilityNode(node_data)
|
168
|
-
self.assertEqual(node.property_string(), " [checked: True]")
|
169
|
-
|
170
|
-
def test_property_string_with_ignored_property(self):
|
171
|
-
node_data = {
|
172
|
-
"properties": [{"name": "focusable", "value": {"value": "some_value"}}]
|
173
|
-
}
|
174
|
-
node = AccessibilityNode(node_data)
|
175
|
-
# 'focusable' is in _IGNORED_ACTREE_PROPERTIES
|
176
|
-
self.assertEqual(node.property_string(), "")
|
@@ -1,135 +0,0 @@
|
|
1
|
-
from concurrent import futures
|
2
|
-
|
3
|
-
import dm_env_servicer
|
4
|
-
import grpc
|
5
|
-
import mock_environment
|
6
|
-
from dm_env_rpc.v1 import (
|
7
|
-
compliance,
|
8
|
-
dm_env_rpc_pb2,
|
9
|
-
dm_env_rpc_pb2_grpc,
|
10
|
-
tensor_utils,
|
11
|
-
)
|
12
|
-
from dm_env_rpc.v1 import connection as dm_env_rpc_connection
|
13
|
-
|
14
|
-
|
15
|
-
class ServerConnection:
|
16
|
-
def __init__(self):
|
17
|
-
self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
|
18
|
-
servicer = dm_env_servicer.EnvironmentService(mock_environment.MockEnvironment)
|
19
|
-
dm_env_rpc_pb2_grpc.add_EnvironmentServicer_to_server(servicer, self._server)
|
20
|
-
port = self._server.add_secure_port("[::]:0", grpc.local_server_credentials())
|
21
|
-
self._server.start()
|
22
|
-
|
23
|
-
self._channel = grpc.secure_channel(
|
24
|
-
f"[::]:{port}", grpc.local_channel_credentials()
|
25
|
-
)
|
26
|
-
grpc.channel_ready_future(self._channel).result()
|
27
|
-
|
28
|
-
self.connection = dm_env_rpc_connection.Connection(self._channel)
|
29
|
-
|
30
|
-
def close(self):
|
31
|
-
self.connection.close()
|
32
|
-
self._channel.close()
|
33
|
-
self._server.stop(grace=None)
|
34
|
-
|
35
|
-
|
36
|
-
class JoinedServerConnection(ServerConnection):
|
37
|
-
def __init__(self):
|
38
|
-
super().__init__()
|
39
|
-
response = self.connection.send(dm_env_rpc_pb2.CreateWorldRequest())
|
40
|
-
self.world_name = response.world_name
|
41
|
-
|
42
|
-
response = self.connection.send(
|
43
|
-
dm_env_rpc_pb2.JoinWorldRequest(world_name=self.world_name)
|
44
|
-
)
|
45
|
-
self.specs = response.specs
|
46
|
-
|
47
|
-
def close(self):
|
48
|
-
try:
|
49
|
-
self.connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
|
50
|
-
self.connection.send(
|
51
|
-
dm_env_rpc_pb2.DestroyWorldRequest(world_name=self.world_name)
|
52
|
-
)
|
53
|
-
finally:
|
54
|
-
super().close()
|
55
|
-
|
56
|
-
|
57
|
-
class DmEnvRpcStepTest(compliance.Step):
|
58
|
-
@property
|
59
|
-
def connection(self):
|
60
|
-
return self._server_connection.connection
|
61
|
-
|
62
|
-
@property
|
63
|
-
def specs(self):
|
64
|
-
return self._server_connection.specs
|
65
|
-
|
66
|
-
def setUp(self):
|
67
|
-
super().setUp()
|
68
|
-
self._server_connection = JoinedServerConnection()
|
69
|
-
|
70
|
-
def tearDown(self):
|
71
|
-
self._server_connection.close()
|
72
|
-
super().tearDown()
|
73
|
-
|
74
|
-
# Overriding this test since this behaviour does not make sence
|
75
|
-
# for our use case.
|
76
|
-
def test_first_step_actions_are_ignored(self):
|
77
|
-
pass
|
78
|
-
|
79
|
-
|
80
|
-
class DmEnvRpcCreateAndDestoryWorldTest(compliance.CreateDestroyWorld):
|
81
|
-
@property
|
82
|
-
def connection(self):
|
83
|
-
return self._server_connection.connection
|
84
|
-
|
85
|
-
@property
|
86
|
-
def required_world_settings(self):
|
87
|
-
"""A string to Tensor mapping of the minimum set of required settings."""
|
88
|
-
return {}
|
89
|
-
|
90
|
-
@property
|
91
|
-
def invalid_world_settings(self):
|
92
|
-
"""World creation settings which are invalid in some way."""
|
93
|
-
return {"invalid_setting": tensor_utils.pack_tensor(123)}
|
94
|
-
|
95
|
-
@property
|
96
|
-
def has_multiple_world_support(self):
|
97
|
-
"""Does the server support creating more than one world?"""
|
98
|
-
return True
|
99
|
-
|
100
|
-
def setUp(self):
|
101
|
-
self._server_connection = ServerConnection()
|
102
|
-
super().setUp()
|
103
|
-
|
104
|
-
def tearDown(self):
|
105
|
-
super().tearDown()
|
106
|
-
self._server_connection.close()
|
107
|
-
|
108
|
-
|
109
|
-
class DmEnvRpcJoinAndLeaveWorldTest(compliance.JoinLeaveWorld):
|
110
|
-
@property
|
111
|
-
def connection(self):
|
112
|
-
return self._server_connection.connection
|
113
|
-
|
114
|
-
@property
|
115
|
-
def world_name(self):
|
116
|
-
return self._world_name
|
117
|
-
|
118
|
-
@property
|
119
|
-
def invalid_join_settings(self):
|
120
|
-
return {"invalid_setting": tensor_utils.pack_tensor(123)}
|
121
|
-
|
122
|
-
def setUp(self):
|
123
|
-
self._server_connection = ServerConnection()
|
124
|
-
response = self.connection.send(dm_env_rpc_pb2.CreateWorldRequest())
|
125
|
-
self._world_name = response.world_name
|
126
|
-
super().setUp()
|
127
|
-
|
128
|
-
def tearDown(self):
|
129
|
-
super().tearDown()
|
130
|
-
try:
|
131
|
-
self.connection.send(
|
132
|
-
dm_env_rpc_pb2.DestroyWorldRequest(world_name=self.world_name)
|
133
|
-
)
|
134
|
-
finally:
|
135
|
-
self._server_connection.close()
|
@@ -1,71 +0,0 @@
|
|
1
|
-
from unittest.mock import MagicMock, call, patch
|
2
|
-
|
3
|
-
import web_environment
|
4
|
-
from absl.testing import parameterized
|
5
|
-
from playwright_crawler import CrawlerOutputFormat
|
6
|
-
|
7
|
-
|
8
|
-
class TestWebEnvironment(parameterized.TestCase):
|
9
|
-
@patch("playwright_crawler.PlaywrightCrawler")
|
10
|
-
def setUp(self, MockCrawler):
|
11
|
-
self._mock_crawler = MagicMock()
|
12
|
-
MockCrawler.return_value = self._mock_crawler
|
13
|
-
self._mock_browser_context = MagicMock()
|
14
|
-
self._web_env = web_environment.WebEnvironment(self._mock_browser_context)
|
15
|
-
|
16
|
-
def test_step_go_to_command(self):
|
17
|
-
self._web_env.step("web_go https://en.wikipedia.org/wiki/Sun ignored_param")
|
18
|
-
self._mock_crawler.go_to_page.assert_called_once_with(
|
19
|
-
"https://en.wikipedia.org/wiki/Sun"
|
20
|
-
)
|
21
|
-
|
22
|
-
def test_step_click_command(self):
|
23
|
-
self._web_env.step("web_click 1111 ignored_param")
|
24
|
-
# click() might be also called later but we only check the first call
|
25
|
-
self.assertEqual(self._mock_crawler.mock_calls[0], call.click("1111"))
|
26
|
-
|
27
|
-
def test_step_scroll_command(self):
|
28
|
-
self._web_env.step("web_scroll up ignored_param")
|
29
|
-
self._mock_crawler.scroll.assert_called_once_with("up")
|
30
|
-
|
31
|
-
def test_step_forward_command(self):
|
32
|
-
self._web_env.step("web_forward ignored_param")
|
33
|
-
self._mock_crawler.forward.assert_called_once_with()
|
34
|
-
|
35
|
-
def test_step_back_command(self):
|
36
|
-
self._web_env.step("web_back ignored_param")
|
37
|
-
self._mock_crawler.back.assert_called_once_with()
|
38
|
-
|
39
|
-
def test_step_refresh_command(self):
|
40
|
-
self._web_env.step("web_refresh ignored_param")
|
41
|
-
self._mock_crawler.refresh.assert_called_once_with()
|
42
|
-
|
43
|
-
def test_step_type_command(self):
|
44
|
-
self._web_env.step("web_type some_element_id text to type into element")
|
45
|
-
self._mock_crawler.type.assert_called_once_with(
|
46
|
-
"some_element_id", "text to type into element"
|
47
|
-
)
|
48
|
-
|
49
|
-
def test_step_type_submit_command(self):
|
50
|
-
self._web_env.step("web_type_submit some_element_id text to type into element")
|
51
|
-
self._mock_crawler.clear.assert_called_once()
|
52
|
-
self._mock_crawler.type.assert_called_once_with(
|
53
|
-
"some_element_id", "text to type into element\n"
|
54
|
-
)
|
55
|
-
|
56
|
-
@parameterized.parameters(
|
57
|
-
("web_go"),
|
58
|
-
("web_click"),
|
59
|
-
("web_scroll"),
|
60
|
-
("web_type"),
|
61
|
-
("web_type_submit"),
|
62
|
-
("some_random_command"),
|
63
|
-
)
|
64
|
-
def test_step_invalid_command(self, command):
|
65
|
-
self._web_env.step(command)
|
66
|
-
self.assertEqual(self._web_env._last_error, f'\n\nInvalid command: "{command}"')
|
67
|
-
|
68
|
-
def test_get_observations_returns_only_required_observations(self):
|
69
|
-
obs = self._web_env.get_observations(["web_at"])
|
70
|
-
self.assertTrue(set(obs.keys()) == set(["web_at"]))
|
71
|
-
self._mock_crawler.render.assert_called_once_with(CrawlerOutputFormat.AT)
|