android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- android_env/__init__.py +1 -1
- android_env/components/__init__.py +1 -1
- android_env/components/a11y/__init__.py +15 -0
- android_env/components/a11y/a11y_events.py +118 -0
- android_env/components/a11y/a11y_events_test.py +173 -0
- android_env/components/a11y/a11y_forests.py +128 -0
- android_env/components/a11y/a11y_forests_test.py +237 -0
- android_env/components/a11y/a11y_servicer.py +199 -0
- android_env/components/a11y/a11y_servicer_test.py +224 -0
- android_env/components/action_fns.py +132 -0
- android_env/components/action_fns_test.py +227 -0
- android_env/components/action_type.py +26 -3
- android_env/components/adb_call_parser.py +239 -196
- android_env/components/adb_call_parser_test.py +179 -209
- android_env/components/adb_controller.py +90 -52
- android_env/components/adb_controller_test.py +187 -16
- android_env/components/adb_log_stream.py +17 -5
- android_env/components/adb_log_stream_test.py +17 -3
- android_env/components/app_screen_checker.py +17 -15
- android_env/components/app_screen_checker_test.py +7 -8
- android_env/components/config_classes.py +203 -0
- android_env/components/coordinator.py +102 -338
- android_env/components/coordinator_test.py +59 -199
- android_env/components/device_settings.py +174 -0
- android_env/components/device_settings_test.py +228 -0
- android_env/components/dumpsys_thread.py +3 -4
- android_env/components/dumpsys_thread_test.py +1 -1
- android_env/components/errors.py +52 -10
- android_env/components/errors_test.py +110 -0
- android_env/components/log_stream.py +7 -5
- android_env/components/log_stream_test.py +1 -1
- android_env/components/logcat_thread.py +9 -8
- android_env/components/logcat_thread_test.py +3 -4
- android_env/components/{utils.py → pixel_fns.py} +20 -20
- android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
- android_env/components/setup_step_interpreter.py +47 -39
- android_env/components/setup_step_interpreter_test.py +4 -4
- android_env/components/simulators/__init__.py +1 -1
- android_env/components/simulators/base_simulator.py +116 -44
- android_env/components/simulators/base_simulator_test.py +131 -9
- android_env/components/simulators/emulator/__init__.py +1 -1
- android_env/components/simulators/emulator/emulator_launcher.py +67 -77
- android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
- android_env/components/simulators/emulator/emulator_simulator.py +276 -95
- android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
- android_env/components/simulators/fake/__init__.py +1 -1
- android_env/components/simulators/fake/fake_simulator.py +17 -25
- android_env/components/simulators/fake/fake_simulator_test.py +29 -12
- android_env/components/specs.py +18 -28
- android_env/components/specs_test.py +1 -44
- android_env/components/task_manager.py +48 -48
- android_env/components/task_manager_test.py +71 -60
- android_env/env_interface.py +37 -23
- android_env/environment.py +83 -51
- android_env/environment_test.py +68 -29
- android_env/loader.py +57 -43
- android_env/loader_test.py +115 -35
- android_env/proto/__init__.py +1 -1
- android_env/proto/a11y/__init__.py +15 -0
- android_env/proto/a11y/a11y.proto +75 -0
- android_env/proto/a11y/a11y_pb2.py +54 -0
- android_env/proto/a11y/a11y_pb2.pyi +49 -0
- android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
- android_env/proto/a11y/android_accessibility_action.proto +32 -0
- android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
- android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_forest.proto +29 -0
- android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
- android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_tree.proto +29 -0
- android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
- android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
- android_env/proto/a11y/rect.proto +30 -0
- android_env/proto/a11y/rect_pb2.py +37 -0
- android_env/proto/a11y/rect_pb2.pyi +17 -0
- android_env/proto/a11y/rect_pb2_grpc.py +24 -0
- android_env/proto/adb.proto +17 -6
- android_env/proto/adb_pb2.py +120 -107
- android_env/proto/adb_pb2.pyi +396 -0
- android_env/proto/adb_pb2_grpc.py +20 -0
- android_env/proto/emulator_controller.proto +68 -63
- android_env/proto/emulator_controller_pb2.py +142 -131
- android_env/proto/emulator_controller_pb2.pyi +672 -0
- android_env/proto/emulator_controller_pb2_grpc.py +505 -142
- android_env/proto/snapshot.proto +169 -0
- android_env/proto/snapshot_pb2.py +47 -0
- android_env/proto/snapshot_pb2.pyi +117 -0
- android_env/proto/snapshot_pb2_grpc.py +24 -0
- android_env/proto/snapshot_service.proto +289 -0
- android_env/proto/snapshot_service_pb2.py +54 -0
- android_env/proto/snapshot_service_pb2.pyi +86 -0
- android_env/proto/snapshot_service_pb2_grpc.py +487 -0
- android_env/proto/state.proto +63 -0
- android_env/proto/state_pb2.py +63 -0
- android_env/proto/state_pb2.pyi +85 -0
- android_env/proto/state_pb2_grpc.py +24 -0
- android_env/proto/task.proto +5 -1
- android_env/proto/task_pb2.py +42 -31
- android_env/proto/task_pb2.pyi +160 -0
- android_env/proto/task_pb2_grpc.py +20 -0
- android_env/wrappers/__init__.py +1 -1
- android_env/wrappers/a11y_grpc_wrapper.py +500 -0
- android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
- android_env/wrappers/base_wrapper.py +34 -13
- android_env/wrappers/base_wrapper_test.py +22 -16
- android_env/wrappers/discrete_action_wrapper.py +18 -17
- android_env/wrappers/discrete_action_wrapper_test.py +4 -4
- android_env/wrappers/flat_interface_wrapper.py +5 -5
- android_env/wrappers/flat_interface_wrapper_test.py +7 -11
- android_env/wrappers/float_pixels_wrapper.py +9 -10
- android_env/wrappers/float_pixels_wrapper_test.py +3 -3
- android_env/wrappers/gym_wrapper.py +19 -13
- android_env/wrappers/gym_wrapper_test.py +3 -5
- android_env/wrappers/image_rescale_wrapper.py +18 -21
- android_env/wrappers/image_rescale_wrapper_test.py +25 -37
- android_env/wrappers/last_action_wrapper.py +16 -13
- android_env/wrappers/last_action_wrapper_test.py +44 -51
- android_env/wrappers/rate_limit_wrapper.py +6 -3
- android_env/wrappers/rate_limit_wrapper_test.py +22 -1
- android_env/wrappers/tap_action_wrapper.py +16 -17
- android_env/wrappers/tap_action_wrapper_test.py +51 -16
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
- android_env-1.2.3.dist-info/RECORD +141 -0
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
- android_env/proto/raw_observation.proto +0 -39
- android_env/proto/raw_observation_pb2.py +0 -27
- android_env/proto/raw_observation_pb2_grpc.py +0 -4
- android_env-1.2.1.dist-info/RECORD +0 -81
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,12 +15,12 @@
|
|
15
15
|
|
16
16
|
"""Base class for AndroidEnv wrappers."""
|
17
17
|
|
18
|
-
from typing import Any
|
18
|
+
from typing import Any
|
19
19
|
|
20
20
|
from absl import logging
|
21
21
|
from android_env import env_interface
|
22
22
|
from android_env.proto import adb_pb2
|
23
|
-
from android_env.proto import
|
23
|
+
from android_env.proto import state_pb2
|
24
24
|
import dm_env
|
25
25
|
from dm_env import specs
|
26
26
|
import numpy as np
|
@@ -42,7 +42,7 @@ class BaseWrapper(env_interface.AndroidEnvInterface):
|
|
42
42
|
action = self._process_action(action)
|
43
43
|
return self._process_timestep(self._env.step(action))
|
44
44
|
|
45
|
-
def task_extras(self, latest_only: bool = True) ->
|
45
|
+
def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
|
46
46
|
return self._env.task_extras(latest_only=latest_only)
|
47
47
|
|
48
48
|
def _reset_state(self):
|
@@ -54,31 +54,52 @@ class BaseWrapper(env_interface.AndroidEnvInterface):
|
|
54
54
|
def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
|
55
55
|
return timestep
|
56
56
|
|
57
|
-
def observation_spec(self) ->
|
57
|
+
def observation_spec(self) -> dict[str, specs.Array]:
|
58
58
|
return self._env.observation_spec()
|
59
59
|
|
60
|
-
def action_spec(self) ->
|
60
|
+
def action_spec(self) -> dict[str, specs.Array]:
|
61
61
|
return self._env.action_spec()
|
62
62
|
|
63
|
-
def
|
64
|
-
return self._env.
|
63
|
+
def reward_spec(self) -> specs.Array:
|
64
|
+
return self._env.reward_spec()
|
65
65
|
|
66
|
-
def
|
66
|
+
def discount_spec(self) -> specs.Array:
|
67
|
+
return self._env.discount_spec()
|
68
|
+
|
69
|
+
def _wrapper_stats(self) -> dict[str, Any]:
|
67
70
|
"""Add wrapper specific logging here."""
|
68
71
|
return {}
|
69
72
|
|
70
|
-
def stats(self) ->
|
73
|
+
def stats(self) -> dict[str, Any]:
|
71
74
|
info = self._env.stats()
|
72
75
|
info.update(self._wrapper_stats())
|
73
76
|
return info
|
74
77
|
|
78
|
+
def load_state(
|
79
|
+
self, request: state_pb2.LoadStateRequest
|
80
|
+
) -> state_pb2.LoadStateResponse:
|
81
|
+
"""Loads a state."""
|
82
|
+
return self._env.load_state(request)
|
83
|
+
|
84
|
+
def save_state(
|
85
|
+
self, request: state_pb2.SaveStateRequest
|
86
|
+
) -> state_pb2.SaveStateResponse:
|
87
|
+
"""Saves a state.
|
88
|
+
|
89
|
+
Args:
|
90
|
+
request: A `SaveStateRequest` containing any parameters necessary to
|
91
|
+
specify how/what state to save.
|
92
|
+
|
93
|
+
Returns:
|
94
|
+
A `SaveStateResponse` containing the status, error message (if
|
95
|
+
applicable), and any other relevant information.
|
96
|
+
"""
|
97
|
+
return self._env.save_state(request)
|
98
|
+
|
75
99
|
def execute_adb_call(self,
|
76
100
|
adb_call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
|
77
101
|
return self._env.execute_adb_call(adb_call)
|
78
102
|
|
79
|
-
def update_task(self, task: task_pb2.Task) -> bool:
|
80
|
-
return self._env.update_task(task)
|
81
|
-
|
82
103
|
@property
|
83
104
|
def raw_action(self):
|
84
105
|
return self._env.raw_action
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -19,8 +19,8 @@ from unittest import mock
|
|
19
19
|
|
20
20
|
from absl import logging
|
21
21
|
from absl.testing import absltest
|
22
|
-
from android_env import
|
23
|
-
from android_env.proto import
|
22
|
+
from android_env import env_interface
|
23
|
+
from android_env.proto import state_pb2
|
24
24
|
from android_env.wrappers import base_wrapper
|
25
25
|
|
26
26
|
|
@@ -28,7 +28,7 @@ class BaseWrapperTest(absltest.TestCase):
|
|
28
28
|
|
29
29
|
@mock.patch.object(logging, 'info')
|
30
30
|
def test_base_function_forwarding(self, mock_info):
|
31
|
-
base_env = mock.create_autospec(
|
31
|
+
base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
32
32
|
wrapped_env = base_wrapper.BaseWrapper(base_env)
|
33
33
|
mock_info.assert_called_with('Wrapping with %s', 'BaseWrapper')
|
34
34
|
|
@@ -58,11 +58,6 @@ class BaseWrapperTest(absltest.TestCase):
|
|
58
58
|
self.assertEqual(fake_action_spec, wrapped_env.action_spec())
|
59
59
|
base_env.action_spec.assert_called_once()
|
60
60
|
|
61
|
-
fake_task_extras_spec = 'fake_task_extras_spec'
|
62
|
-
base_env.task_extras_spec.return_value = fake_task_extras_spec
|
63
|
-
self.assertEqual(fake_task_extras_spec, wrapped_env.task_extras_spec())
|
64
|
-
base_env.task_extras_spec.assert_called_once()
|
65
|
-
|
66
61
|
fake_raw_action = 'fake_raw_action'
|
67
62
|
type(base_env).raw_action = mock.PropertyMock(return_value=fake_raw_action)
|
68
63
|
self.assertEqual(fake_raw_action, wrapped_env.raw_action)
|
@@ -72,10 +67,21 @@ class BaseWrapperTest(absltest.TestCase):
|
|
72
67
|
return_value=fake_raw_observation)
|
73
68
|
self.assertEqual(fake_raw_observation, wrapped_env.raw_observation)
|
74
69
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
70
|
+
load_request = state_pb2.LoadStateRequest(args={})
|
71
|
+
expected_response = state_pb2.LoadStateResponse(
|
72
|
+
status=state_pb2.LoadStateResponse.Status.OK
|
73
|
+
)
|
74
|
+
base_env.load_state.return_value = expected_response
|
75
|
+
self.assertEqual(wrapped_env.load_state(load_request), expected_response)
|
76
|
+
base_env.load_state.assert_called_once_with(load_request)
|
77
|
+
|
78
|
+
save_request = state_pb2.SaveStateRequest(args={})
|
79
|
+
expected_response = state_pb2.SaveStateResponse(
|
80
|
+
status=state_pb2.SaveStateResponse.Status.OK
|
81
|
+
)
|
82
|
+
base_env.save_state.return_value = expected_response
|
83
|
+
self.assertEqual(wrapped_env.save_state(save_request), expected_response)
|
84
|
+
base_env.save_state.assert_called_once_with(save_request)
|
79
85
|
|
80
86
|
wrapped_env.close()
|
81
87
|
base_env.close.assert_called_once()
|
@@ -87,7 +93,7 @@ class BaseWrapperTest(absltest.TestCase):
|
|
87
93
|
base_env.some_random_function.return_value = fake_return_value
|
88
94
|
|
89
95
|
def test_multiple_wrappers(self):
|
90
|
-
base_env = mock.create_autospec(
|
96
|
+
base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
91
97
|
wrapped_env_1 = base_wrapper.BaseWrapper(base_env)
|
92
98
|
wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1)
|
93
99
|
|
@@ -101,7 +107,7 @@ class BaseWrapperTest(absltest.TestCase):
|
|
101
107
|
self.assertEqual(base_env, wrapped_env_2.raw_env)
|
102
108
|
|
103
109
|
def test_stats(self):
|
104
|
-
base_env = mock.create_autospec(
|
110
|
+
base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
105
111
|
wrapped_env = base_wrapper.BaseWrapper(base_env)
|
106
112
|
base_stats = {'base': 'stats'}
|
107
113
|
base_env.stats.return_value = base_stats
|
@@ -109,7 +115,7 @@ class BaseWrapperTest(absltest.TestCase):
|
|
109
115
|
|
110
116
|
@mock.patch.object(logging, 'info')
|
111
117
|
def test_wrapped_stats(self, mock_info):
|
112
|
-
base_env = mock.create_autospec(
|
118
|
+
base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
113
119
|
|
114
120
|
class LoggingWrapper1(base_wrapper.BaseWrapper):
|
115
121
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Wraps the AndroidEnv environment to provide discrete actions."""
|
17
17
|
|
18
|
-
from
|
18
|
+
from collections.abc import Sequence
|
19
19
|
|
20
20
|
from android_env.components import action_type
|
21
21
|
from android_env.wrappers import base_wrapper
|
@@ -24,23 +24,24 @@ from dm_env import specs
|
|
24
24
|
import numpy as np
|
25
25
|
|
26
26
|
|
27
|
-
|
27
|
+
_NOISE_CLIP_VALUE = 0.4999
|
28
28
|
|
29
29
|
|
30
30
|
class DiscreteActionWrapper(base_wrapper.BaseWrapper):
|
31
31
|
"""AndroidEnv with discrete actions."""
|
32
32
|
|
33
|
-
def __init__(
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
33
|
+
def __init__(
|
34
|
+
self,
|
35
|
+
env: dm_env.Environment,
|
36
|
+
action_grid: Sequence[int] = (10, 10),
|
37
|
+
redundant_actions: bool = True,
|
38
|
+
noise: float = 0.1,
|
39
|
+
):
|
39
40
|
super().__init__(env)
|
40
41
|
self._parent_action_spec = self._env.action_spec()
|
41
42
|
self._assert_base_env()
|
42
43
|
self._action_grid = action_grid # [height, width]
|
43
|
-
self._grid_size = np.
|
44
|
+
self._grid_size = np.prod(self._action_grid)
|
44
45
|
self._num_action_types = self._parent_action_spec['action_type'].num_values
|
45
46
|
self._redundant_actions = redundant_actions
|
46
47
|
self._noise = noise
|
@@ -57,16 +58,16 @@ class DiscreteActionWrapper(base_wrapper.BaseWrapper):
|
|
57
58
|
"""Number of discrete actions."""
|
58
59
|
|
59
60
|
if self._redundant_actions:
|
60
|
-
return
|
61
|
+
return self._grid_size * self._num_action_types
|
61
62
|
else:
|
62
|
-
return
|
63
|
+
return self._grid_size + self._num_action_types - 1
|
63
64
|
|
64
|
-
def step(self, action:
|
65
|
+
def step(self, action: dict[str, int]) -> dm_env.TimeStep:
|
65
66
|
"""Take a step in the base environment."""
|
66
67
|
|
67
68
|
return self._env.step(self._process_action(action))
|
68
69
|
|
69
|
-
def _process_action(self, action:
|
70
|
+
def _process_action(self, action: dict[str, int]) -> dict[str, np.ndarray]:
|
70
71
|
"""Transforms action so that it agrees with AndroidEnv's action spec."""
|
71
72
|
|
72
73
|
return {
|
@@ -133,8 +134,8 @@ class DiscreteActionWrapper(base_wrapper.BaseWrapper):
|
|
133
134
|
noise_y = np.random.normal(loc=0.0, scale=self._noise)
|
134
135
|
|
135
136
|
# Noise is clipped so that the action will strictly stay in the cell.
|
136
|
-
noise_x = max(min(noise_x,
|
137
|
-
noise_y = max(min(noise_y,
|
137
|
+
noise_x = max(min(noise_x, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE)
|
138
|
+
noise_y = max(min(noise_y, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE)
|
138
139
|
|
139
140
|
x_pos = (x_pos_grid + 0.5 + noise_x) / self._action_grid[1] # WIDTH
|
140
141
|
y_pos = (y_pos_grid + 0.5 + noise_y) / self._action_grid[0] # HEIGHT
|
@@ -149,7 +150,7 @@ class DiscreteActionWrapper(base_wrapper.BaseWrapper):
|
|
149
150
|
|
150
151
|
return [x_pos, y_pos]
|
151
152
|
|
152
|
-
def action_spec(self) ->
|
153
|
+
def action_spec(self) -> dict[str, specs.Array]:
|
153
154
|
"""Action spec of the wrapped environment."""
|
154
155
|
|
155
156
|
return {
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -18,7 +18,7 @@
|
|
18
18
|
from unittest import mock
|
19
19
|
|
20
20
|
from absl.testing import absltest
|
21
|
-
from android_env import
|
21
|
+
from android_env import env_interface
|
22
22
|
from android_env.components import action_type as action_type_lib
|
23
23
|
from android_env.wrappers import discrete_action_wrapper
|
24
24
|
from dm_env import specs
|
@@ -64,7 +64,7 @@ class DiscreteActionWrapperTest(absltest.TestCase):
|
|
64
64
|
'touch_position': _make_array_spec(
|
65
65
|
shape=(2,), dtype=np.float32, name='touch_position'),
|
66
66
|
}
|
67
|
-
self.base_env = mock.create_autospec(
|
67
|
+
self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
68
68
|
self.base_env.action_spec.return_value = self._base_action_spec
|
69
69
|
|
70
70
|
def test_num_actions(self):
|
@@ -295,7 +295,7 @@ class DiscreteActionWrapperTest(absltest.TestCase):
|
|
295
295
|
'touch_position': _make_array_spec(
|
296
296
|
shape=(2,), dtype=np.float64, name='touch_position'),
|
297
297
|
}
|
298
|
-
base_env = mock.create_autospec(
|
298
|
+
base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
299
299
|
base_env.action_spec.return_value = base_action_spec
|
300
300
|
|
301
301
|
wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Wraps the AndroidEnv environment to make its interface flat."""
|
17
17
|
|
18
|
-
from typing import
|
18
|
+
from typing import Any
|
19
19
|
|
20
20
|
from android_env.wrappers import base_wrapper
|
21
21
|
import dm_env
|
@@ -74,7 +74,7 @@ class FlatInterfaceWrapper(base_wrapper.BaseWrapper):
|
|
74
74
|
assert isinstance(base_action_spec, dict)
|
75
75
|
assert isinstance(base_action_spec[self._action_name], specs.BoundedArray)
|
76
76
|
|
77
|
-
def _process_action(self, action:
|
77
|
+
def _process_action(self, action: int | np.ndarray | dict[str, Any]):
|
78
78
|
if self._flat_actions:
|
79
79
|
return {self._action_name: action}
|
80
80
|
else:
|
@@ -103,7 +103,7 @@ class FlatInterfaceWrapper(base_wrapper.BaseWrapper):
|
|
103
103
|
timestep = self._env.step(self._process_action(action))
|
104
104
|
return self._process_timestep(timestep)
|
105
105
|
|
106
|
-
def observation_spec(self) ->
|
106
|
+
def observation_spec(self) -> specs.Array | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks
|
107
107
|
if self._flat_observations:
|
108
108
|
pixels_spec = self._env.observation_spec()['pixels']
|
109
109
|
if not self._keep_action_layer:
|
@@ -112,7 +112,7 @@ class FlatInterfaceWrapper(base_wrapper.BaseWrapper):
|
|
112
112
|
else:
|
113
113
|
return self._env.observation_spec()
|
114
114
|
|
115
|
-
def action_spec(self) ->
|
115
|
+
def action_spec(self) -> specs.BoundedArray | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks
|
116
116
|
if self._flat_actions:
|
117
117
|
return self._env.action_spec()[self._action_name]
|
118
118
|
else:
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Tests for android_env.wrappers.flat_interface_wrapper."""
|
17
17
|
|
18
|
+
from typing import cast
|
18
19
|
from unittest import mock
|
19
20
|
|
20
21
|
from absl.testing import absltest
|
@@ -33,10 +34,6 @@ def _make_array_spec(shape, dtype=np.float32, name=None, maximum=3, minimum=0):
|
|
33
34
|
minimum=np.ones(shape) * minimum)
|
34
35
|
|
35
36
|
|
36
|
-
def _make_discrete_array_spec(name, num_values):
|
37
|
-
return specs.DiscreteArray(name=name, num_values=num_values)
|
38
|
-
|
39
|
-
|
40
37
|
def _make_timestep(observation):
|
41
38
|
return dm_env.TimeStep(
|
42
39
|
step_type='fake_step_type',
|
@@ -51,9 +48,8 @@ class FlatInterfaceWrapperTest(absltest.TestCase):
|
|
51
48
|
def setUp(self):
|
52
49
|
super().setUp()
|
53
50
|
self.action_shape = (1,)
|
54
|
-
self.base_action_spec = {
|
55
|
-
'action_id':
|
56
|
-
name='action_id', num_values=4)
|
51
|
+
self.base_action_spec: dict[str, specs.DiscreteArray] = {
|
52
|
+
'action_id': specs.DiscreteArray(name='action_id', num_values=4)
|
57
53
|
}
|
58
54
|
self.int_obs_shape = (3, 4, 2)
|
59
55
|
self.float_obs_shape = (2,)
|
@@ -148,10 +144,10 @@ class FlatInterfaceWrapperTest(absltest.TestCase):
|
|
148
144
|
|
149
145
|
def test_action_spec(self):
|
150
146
|
wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
|
151
|
-
action_spec = wrapped_env.action_spec()
|
152
|
-
|
147
|
+
action_spec = cast(specs.BoundedArray, wrapped_env.action_spec())
|
148
|
+
parent_action_spec = self.base_action_spec['action_id']
|
153
149
|
|
154
|
-
self.assertEqual(
|
150
|
+
self.assertEqual(parent_action_spec.name, action_spec.name)
|
155
151
|
self.assertEqual((), action_spec.shape)
|
156
152
|
self.assertEqual(np.int32, action_spec.dtype)
|
157
153
|
self.assertEqual(0, action_spec.minimum)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,9 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Converts pixel observation to from int to float32 between 0.0 and 1.0."""
|
17
17
|
|
18
|
-
from
|
19
|
-
|
20
|
-
from android_env.components import utils
|
18
|
+
from android_env.components import pixel_fns
|
21
19
|
from android_env.wrappers import base_wrapper
|
22
20
|
import dm_env
|
23
21
|
from dm_env import specs
|
@@ -34,11 +32,12 @@ class FloatPixelsWrapper(base_wrapper.BaseWrapper):
|
|
34
32
|
np.integer)
|
35
33
|
|
36
34
|
def _process_observation(
|
37
|
-
self, observation:
|
38
|
-
) ->
|
35
|
+
self, observation: dict[str, np.ndarray]
|
36
|
+
) -> dict[str, np.ndarray]:
|
39
37
|
if self._should_convert_int_to_float:
|
40
|
-
float_pixels =
|
41
|
-
|
38
|
+
float_pixels = pixel_fns.convert_int_to_float(
|
39
|
+
observation['pixels'], self._input_spec
|
40
|
+
)
|
42
41
|
observation['pixels'] = float_pixels
|
43
42
|
return observation
|
44
43
|
|
@@ -53,10 +52,10 @@ class FloatPixelsWrapper(base_wrapper.BaseWrapper):
|
|
53
52
|
def reset(self) -> dm_env.TimeStep:
|
54
53
|
return self._process_timestep(self._env.reset())
|
55
54
|
|
56
|
-
def step(self, action:
|
55
|
+
def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
|
57
56
|
return self._process_timestep(self._env.step(action))
|
58
57
|
|
59
|
-
def observation_spec(self) ->
|
58
|
+
def observation_spec(self) -> dict[str, specs.Array]:
|
60
59
|
if self._should_convert_int_to_float:
|
61
60
|
observation_spec = self._env.observation_spec()
|
62
61
|
observation_spec['pixels'] = specs.BoundedArray(
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -96,7 +96,7 @@ class FloatPixelsWrapperTest(absltest.TestCase):
|
|
96
96
|
|
97
97
|
def test_float_pixels_wrapper_step(self):
|
98
98
|
wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env)
|
99
|
-
ts = wrapped_env.step('fake_action')
|
99
|
+
ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])})
|
100
100
|
|
101
101
|
self.assertEqual(self.base_timestep.step_type, ts.step_type)
|
102
102
|
self.assertEqual(self.base_timestep.reward, ts.reward)
|
@@ -141,7 +141,7 @@ class FloatPixelsWrapperTest(absltest.TestCase):
|
|
141
141
|
# The wrapper should not touch the timestep in this case.
|
142
142
|
fake_timestep = ('step_type', 'reward', 'discount', 'obs')
|
143
143
|
base_env.step.return_value = fake_timestep
|
144
|
-
ts = wrapped_env.step('fake_action')
|
144
|
+
ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])})
|
145
145
|
self.assertEqual(fake_timestep, ts)
|
146
146
|
|
147
147
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Wraps the AndroidEnv to expose an OpenAI Gym interface."""
|
17
17
|
|
18
|
-
from typing import Any
|
18
|
+
from typing import Any
|
19
19
|
|
20
20
|
from android_env.wrappers import base_wrapper
|
21
21
|
import dm_env
|
@@ -25,16 +25,16 @@ from gym import spaces
|
|
25
25
|
import numpy as np
|
26
26
|
|
27
27
|
|
28
|
-
class GymInterfaceWrapper(
|
28
|
+
class GymInterfaceWrapper(gym.Env):
|
29
29
|
"""AndroidEnv with OpenAI Gym interface."""
|
30
30
|
|
31
31
|
def __init__(self, env: dm_env.Environment):
|
32
|
-
|
33
|
-
base_wrapper.BaseWrapper.__init__(self, env)
|
32
|
+
self._env = env
|
34
33
|
self.spec = None
|
35
34
|
self.action_space = self._spec_to_space(self._env.action_spec())
|
36
35
|
self.observation_space = self._spec_to_space(self._env.observation_spec())
|
37
36
|
self.metadata = {'render.modes': ['rgb_array']}
|
37
|
+
self._latest_observation = None
|
38
38
|
|
39
39
|
def _spec_to_space(self, spec: specs.Array) -> spaces.Space:
|
40
40
|
"""Converts dm_env specs to OpenAI Gym spaces."""
|
@@ -44,7 +44,8 @@ class GymInterfaceWrapper(base_wrapper.BaseWrapper, gym.Env):
|
|
44
44
|
|
45
45
|
if isinstance(spec, dict):
|
46
46
|
return spaces.Dict(
|
47
|
-
{name: self._spec_to_space(s) for name, s in spec.items()}
|
47
|
+
{name: self._spec_to_space(s) for name, s in spec.items()}
|
48
|
+
)
|
48
49
|
|
49
50
|
if isinstance(spec, specs.DiscreteArray):
|
50
51
|
return spaces.Box(
|
@@ -61,11 +62,13 @@ class GymInterfaceWrapper(base_wrapper.BaseWrapper, gym.Env):
|
|
61
62
|
high=spec.maximum)
|
62
63
|
|
63
64
|
if isinstance(spec, specs.Array):
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
65
|
+
if spec.dtype == np.uint8:
|
66
|
+
low = 0
|
67
|
+
high = 255
|
68
|
+
else:
|
69
|
+
low = -np.inf
|
70
|
+
high = np.inf
|
71
|
+
return spaces.Box(shape=spec.shape, dtype=spec.dtype, low=low, high=high)
|
69
72
|
|
70
73
|
raise ValueError('Unknown type for specs: {}'.format(spec))
|
71
74
|
|
@@ -74,18 +77,21 @@ class GymInterfaceWrapper(base_wrapper.BaseWrapper, gym.Env):
|
|
74
77
|
if mode == 'rgb_array':
|
75
78
|
if self._latest_observation is None:
|
76
79
|
return
|
80
|
+
|
77
81
|
return self._latest_observation['pixels']
|
78
82
|
else:
|
79
83
|
raise ValueError('Only supported render mode is rgb_array.')
|
80
84
|
|
81
85
|
def reset(self) -> np.ndarray:
|
86
|
+
self._latest_observation = None
|
82
87
|
timestep = self._env.reset()
|
83
88
|
return timestep.observation
|
84
89
|
|
85
|
-
def step(self, action:
|
90
|
+
def step(self, action: dict[str, int]) -> tuple[Any, ...]:
|
86
91
|
"""Take a step in the base environment."""
|
87
|
-
timestep = self._env.step(
|
92
|
+
timestep = self._env.step(action)
|
88
93
|
observation = timestep.observation
|
94
|
+
self._latest_observation = observation
|
89
95
|
reward = timestep.reward
|
90
96
|
done = timestep.step_type == dm_env.StepType.LAST
|
91
97
|
info = {'discount': timestep.discount}
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -18,7 +18,7 @@
|
|
18
18
|
from unittest import mock
|
19
19
|
|
20
20
|
from absl.testing import absltest
|
21
|
-
from android_env import
|
21
|
+
from android_env import env_interface
|
22
22
|
from android_env.wrappers import gym_wrapper
|
23
23
|
import dm_env
|
24
24
|
from dm_env import specs
|
@@ -30,7 +30,7 @@ class GymInterfaceWrapperTest(absltest.TestCase):
|
|
30
30
|
|
31
31
|
def setUp(self):
|
32
32
|
super().setUp()
|
33
|
-
self._base_env = mock.create_autospec(
|
33
|
+
self._base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
34
34
|
self._base_env.action_spec.return_value = {
|
35
35
|
'action_type':
|
36
36
|
specs.DiscreteArray(
|
@@ -68,7 +68,6 @@ class GymInterfaceWrapperTest(absltest.TestCase):
|
|
68
68
|
def test_render(self):
|
69
69
|
self._base_env.step.return_value = self._fake_ts
|
70
70
|
_ = self._wrapped_env.step(action=np.zeros(shape=(1,)))
|
71
|
-
self._base_env._latest_observation = {'pixels': np.ones(shape=(2, 3))}
|
72
71
|
image = self._wrapped_env.render(mode='rgb_array')
|
73
72
|
self.assertTrue(np.array_equal(image, np.ones(shape=(2, 3))))
|
74
73
|
|
@@ -90,7 +89,6 @@ class GymInterfaceWrapperTest(absltest.TestCase):
|
|
90
89
|
self._base_env.step.return_value = self._fake_ts
|
91
90
|
obs, _, _, _ = self._wrapped_env.step(action=np.zeros(shape=(1,)))
|
92
91
|
self._base_env.step.assert_called_once()
|
93
|
-
print(obs)
|
94
92
|
self.assertTrue(np.array_equal(obs['pixels'], np.ones(shape=(2, 3))))
|
95
93
|
|
96
94
|
def test_spec_to_space(self):
|