android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- android_env/__init__.py +1 -1
- android_env/components/__init__.py +1 -1
- android_env/components/a11y/__init__.py +15 -0
- android_env/components/a11y/a11y_events.py +118 -0
- android_env/components/a11y/a11y_events_test.py +173 -0
- android_env/components/a11y/a11y_forests.py +128 -0
- android_env/components/a11y/a11y_forests_test.py +237 -0
- android_env/components/a11y/a11y_servicer.py +199 -0
- android_env/components/a11y/a11y_servicer_test.py +224 -0
- android_env/components/action_fns.py +132 -0
- android_env/components/action_fns_test.py +227 -0
- android_env/components/action_type.py +26 -3
- android_env/components/adb_call_parser.py +239 -196
- android_env/components/adb_call_parser_test.py +179 -209
- android_env/components/adb_controller.py +90 -52
- android_env/components/adb_controller_test.py +187 -16
- android_env/components/adb_log_stream.py +17 -5
- android_env/components/adb_log_stream_test.py +17 -3
- android_env/components/app_screen_checker.py +17 -15
- android_env/components/app_screen_checker_test.py +7 -8
- android_env/components/config_classes.py +203 -0
- android_env/components/coordinator.py +102 -338
- android_env/components/coordinator_test.py +59 -199
- android_env/components/device_settings.py +174 -0
- android_env/components/device_settings_test.py +228 -0
- android_env/components/dumpsys_thread.py +3 -4
- android_env/components/dumpsys_thread_test.py +1 -1
- android_env/components/errors.py +52 -10
- android_env/components/errors_test.py +110 -0
- android_env/components/log_stream.py +7 -5
- android_env/components/log_stream_test.py +1 -1
- android_env/components/logcat_thread.py +9 -8
- android_env/components/logcat_thread_test.py +3 -4
- android_env/components/{utils.py → pixel_fns.py} +20 -20
- android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
- android_env/components/setup_step_interpreter.py +47 -39
- android_env/components/setup_step_interpreter_test.py +4 -4
- android_env/components/simulators/__init__.py +1 -1
- android_env/components/simulators/base_simulator.py +116 -44
- android_env/components/simulators/base_simulator_test.py +131 -9
- android_env/components/simulators/emulator/__init__.py +1 -1
- android_env/components/simulators/emulator/emulator_launcher.py +67 -77
- android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
- android_env/components/simulators/emulator/emulator_simulator.py +276 -95
- android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
- android_env/components/simulators/fake/__init__.py +1 -1
- android_env/components/simulators/fake/fake_simulator.py +17 -25
- android_env/components/simulators/fake/fake_simulator_test.py +29 -12
- android_env/components/specs.py +18 -28
- android_env/components/specs_test.py +1 -44
- android_env/components/task_manager.py +48 -48
- android_env/components/task_manager_test.py +71 -60
- android_env/env_interface.py +37 -23
- android_env/environment.py +83 -51
- android_env/environment_test.py +68 -29
- android_env/loader.py +57 -43
- android_env/loader_test.py +115 -35
- android_env/proto/__init__.py +1 -1
- android_env/proto/a11y/__init__.py +15 -0
- android_env/proto/a11y/a11y.proto +75 -0
- android_env/proto/a11y/a11y_pb2.py +54 -0
- android_env/proto/a11y/a11y_pb2.pyi +49 -0
- android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
- android_env/proto/a11y/android_accessibility_action.proto +32 -0
- android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
- android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_forest.proto +29 -0
- android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
- android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_tree.proto +29 -0
- android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
- android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
- android_env/proto/a11y/rect.proto +30 -0
- android_env/proto/a11y/rect_pb2.py +37 -0
- android_env/proto/a11y/rect_pb2.pyi +17 -0
- android_env/proto/a11y/rect_pb2_grpc.py +24 -0
- android_env/proto/adb.proto +17 -6
- android_env/proto/adb_pb2.py +120 -107
- android_env/proto/adb_pb2.pyi +396 -0
- android_env/proto/adb_pb2_grpc.py +20 -0
- android_env/proto/emulator_controller.proto +68 -63
- android_env/proto/emulator_controller_pb2.py +142 -131
- android_env/proto/emulator_controller_pb2.pyi +672 -0
- android_env/proto/emulator_controller_pb2_grpc.py +505 -142
- android_env/proto/snapshot.proto +169 -0
- android_env/proto/snapshot_pb2.py +47 -0
- android_env/proto/snapshot_pb2.pyi +117 -0
- android_env/proto/snapshot_pb2_grpc.py +24 -0
- android_env/proto/snapshot_service.proto +289 -0
- android_env/proto/snapshot_service_pb2.py +54 -0
- android_env/proto/snapshot_service_pb2.pyi +86 -0
- android_env/proto/snapshot_service_pb2_grpc.py +487 -0
- android_env/proto/state.proto +63 -0
- android_env/proto/state_pb2.py +63 -0
- android_env/proto/state_pb2.pyi +85 -0
- android_env/proto/state_pb2_grpc.py +24 -0
- android_env/proto/task.proto +5 -1
- android_env/proto/task_pb2.py +42 -31
- android_env/proto/task_pb2.pyi +160 -0
- android_env/proto/task_pb2_grpc.py +20 -0
- android_env/wrappers/__init__.py +1 -1
- android_env/wrappers/a11y_grpc_wrapper.py +500 -0
- android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
- android_env/wrappers/base_wrapper.py +34 -13
- android_env/wrappers/base_wrapper_test.py +22 -16
- android_env/wrappers/discrete_action_wrapper.py +18 -17
- android_env/wrappers/discrete_action_wrapper_test.py +4 -4
- android_env/wrappers/flat_interface_wrapper.py +5 -5
- android_env/wrappers/flat_interface_wrapper_test.py +7 -11
- android_env/wrappers/float_pixels_wrapper.py +9 -10
- android_env/wrappers/float_pixels_wrapper_test.py +3 -3
- android_env/wrappers/gym_wrapper.py +19 -13
- android_env/wrappers/gym_wrapper_test.py +3 -5
- android_env/wrappers/image_rescale_wrapper.py +18 -21
- android_env/wrappers/image_rescale_wrapper_test.py +25 -37
- android_env/wrappers/last_action_wrapper.py +16 -13
- android_env/wrappers/last_action_wrapper_test.py +44 -51
- android_env/wrappers/rate_limit_wrapper.py +6 -3
- android_env/wrappers/rate_limit_wrapper_test.py +22 -1
- android_env/wrappers/tap_action_wrapper.py +16 -17
- android_env/wrappers/tap_action_wrapper_test.py +51 -16
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
- android_env-1.2.3.dist-info/RECORD +141 -0
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
- android_env/proto/raw_observation.proto +0 -39
- android_env/proto/raw_observation_pb2.py +0 -27
- android_env/proto/raw_observation_pb2_grpc.py +0 -4
- android_env-1.2.1.dist-info/RECORD +0 -81
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Wraps the AndroidEnv environment to rescale the observations."""
|
17
17
|
|
18
|
-
from
|
18
|
+
from collections.abc import Sequence
|
19
19
|
|
20
20
|
from android_env.wrappers import base_wrapper
|
21
21
|
import dm_env
|
@@ -37,8 +37,9 @@ class ImageRescaleWrapper(base_wrapper.BaseWrapper):
|
|
37
37
|
def __init__(
|
38
38
|
self,
|
39
39
|
env: dm_env.Environment,
|
40
|
-
zoom_factors:
|
41
|
-
grayscale: bool = False
|
40
|
+
zoom_factors: Sequence[float] | None = (0.5, 0.5),
|
41
|
+
grayscale: bool = False,
|
42
|
+
):
|
42
43
|
super().__init__(env)
|
43
44
|
assert 'pixels' in self._env.observation_spec()
|
44
45
|
assert self._env.observation_spec()['pixels'].shape[-1] in [1, 3], (
|
@@ -50,16 +51,8 @@ class ImageRescaleWrapper(base_wrapper.BaseWrapper):
|
|
50
51
|
# want to zoom the number of channels so we just multiply it by 1.0.
|
51
52
|
self._zoom_factors = tuple(zoom_factors) + (1.0,)
|
52
53
|
|
53
|
-
# Save the raw image for making videos, for example.
|
54
|
-
self._raw_pixels = None
|
55
|
-
|
56
|
-
@property
|
57
|
-
def raw_pixels(self):
|
58
|
-
return self._raw_pixels
|
59
|
-
|
60
54
|
def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
|
61
55
|
observation = timestep.observation
|
62
|
-
self._raw_pixels = observation['pixels'].copy()
|
63
56
|
processed_observation = observation.copy()
|
64
57
|
processed_observation['pixels'] = self._process_pixels(
|
65
58
|
observation['pixels'])
|
@@ -78,14 +71,16 @@ class ImageRescaleWrapper(base_wrapper.BaseWrapper):
|
|
78
71
|
return self._resize_image_array(image, new_shape)
|
79
72
|
|
80
73
|
def _resize_image_array(
|
81
|
-
self,
|
82
|
-
|
83
|
-
new_shape: Sequence[int]) -> np.ndarray:
|
74
|
+
self, grayscale_or_rbg_array: np.ndarray, new_shape: np.ndarray
|
75
|
+
) -> np.ndarray:
|
84
76
|
"""Resize color or grayscale/action_layer array to new_shape."""
|
85
|
-
assert
|
77
|
+
assert new_shape.ndim == 1
|
86
78
|
assert len(new_shape) == 2
|
87
|
-
resized_array = np.array(
|
88
|
-
grayscale_or_rbg_array.astype('uint8')).resize(
|
79
|
+
resized_array = np.array(
|
80
|
+
Image.fromarray(grayscale_or_rbg_array.astype('uint8')).resize(
|
81
|
+
tuple(new_shape)
|
82
|
+
)
|
83
|
+
)
|
89
84
|
if resized_array.ndim == 2:
|
90
85
|
return np.expand_dims(resized_array, axis=-1)
|
91
86
|
return resized_array
|
@@ -98,15 +93,17 @@ class ImageRescaleWrapper(base_wrapper.BaseWrapper):
|
|
98
93
|
timestep = self._env.step(action)
|
99
94
|
return self._process_timestep(timestep)
|
100
95
|
|
101
|
-
def observation_spec(self) ->
|
96
|
+
def observation_spec(self) -> dict[str, specs.Array]:
|
102
97
|
parent_spec = self._env.observation_spec().copy()
|
103
98
|
out_shape = np.multiply(parent_spec['pixels'].shape,
|
104
99
|
self._zoom_factors).astype(np.int32)
|
105
100
|
if self._grayscale:
|
106
101
|
# In grayscale mode we want the output shape to be [W, H, 1]
|
107
102
|
out_shape[-1] = 1
|
108
|
-
parent_spec['pixels'] = specs.
|
103
|
+
parent_spec['pixels'] = specs.BoundedArray(
|
109
104
|
shape=out_shape,
|
110
105
|
dtype=parent_spec['pixels'].dtype,
|
111
|
-
name=parent_spec['pixels'].name
|
106
|
+
name=parent_spec['pixels'].name,
|
107
|
+
minimum=parent_spec['pixels'].minimum,
|
108
|
+
maximum=parent_spec['pixels'].maximum)
|
112
109
|
return parent_spec
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,42 +15,24 @@
|
|
15
15
|
|
16
16
|
"""Tests for android_env.wrappers.image_rescale_wrapper."""
|
17
17
|
|
18
|
-
from typing import Any
|
18
|
+
from typing import Any
|
19
|
+
from unittest import mock
|
19
20
|
|
20
21
|
from absl.testing import absltest
|
21
|
-
from android_env import
|
22
|
+
from android_env import env_interface
|
22
23
|
from android_env.wrappers import image_rescale_wrapper
|
23
24
|
import dm_env
|
24
25
|
from dm_env import specs
|
25
26
|
import numpy as np
|
26
27
|
|
27
28
|
|
28
|
-
class FakeEnv(environment.AndroidEnv):
|
29
|
-
"""A class that we can use to inject custom observations and specs."""
|
30
|
-
|
31
|
-
def __init__(self, obs_spec):
|
32
|
-
self._obs_spec = obs_spec
|
33
|
-
self._next_obs = None
|
34
|
-
|
35
|
-
def reset(self) -> dm_env.TimeStep:
|
36
|
-
return self._next_timestep
|
37
|
-
|
38
|
-
def step(self, action: Any) -> dm_env.TimeStep:
|
39
|
-
return self._next_timestep
|
40
|
-
|
41
|
-
def observation_spec(self) -> Dict[str, specs.Array]:
|
42
|
-
return self._obs_spec
|
43
|
-
|
44
|
-
def action_spec(self) -> Dict[str, specs.Array]:
|
45
|
-
assert False, 'This should not be called by tests.'
|
46
|
-
|
47
|
-
def set_next_timestep(self, timestep):
|
48
|
-
self._next_timestep = timestep
|
49
|
-
|
50
|
-
|
51
29
|
def _simple_spec():
|
52
|
-
return specs.
|
53
|
-
shape=np.array([300, 300, 3]),
|
30
|
+
return specs.BoundedArray(
|
31
|
+
shape=np.array([300, 300, 3]),
|
32
|
+
dtype=np.uint8,
|
33
|
+
name='pixels',
|
34
|
+
minimum=0,
|
35
|
+
maximum=255)
|
54
36
|
|
55
37
|
|
56
38
|
def _simple_timestep():
|
@@ -65,9 +47,11 @@ def _simple_timestep():
|
|
65
47
|
class ImageRescaleWrapperTest(absltest.TestCase):
|
66
48
|
|
67
49
|
def test_100x50_grayscale(self):
|
68
|
-
|
69
|
-
fake_env =
|
70
|
-
fake_env.
|
50
|
+
fake_timestep = _simple_timestep()
|
51
|
+
fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
52
|
+
fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
|
53
|
+
fake_env.reset.return_value = fake_timestep
|
54
|
+
fake_env.step.return_value = fake_timestep
|
71
55
|
|
72
56
|
wrapper = image_rescale_wrapper.ImageRescaleWrapper(
|
73
57
|
fake_env, zoom_factors=(1.0 / 3, 1.0 / 6.0), grayscale=True)
|
@@ -81,9 +65,11 @@ class ImageRescaleWrapperTest(absltest.TestCase):
|
|
81
65
|
self.assertEqual(step_image.shape, (100, 50, 1))
|
82
66
|
|
83
67
|
def test_150x60_full_channels(self):
|
84
|
-
|
85
|
-
fake_env =
|
86
|
-
fake_env.
|
68
|
+
fake_timestep = _simple_timestep()
|
69
|
+
fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
70
|
+
fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
|
71
|
+
fake_env.reset.return_value = fake_timestep
|
72
|
+
fake_env.step.return_value = fake_timestep
|
87
73
|
|
88
74
|
wrapper = image_rescale_wrapper.ImageRescaleWrapper(
|
89
75
|
fake_env, zoom_factors=(1.0 / 2.0, 1.0 / 5.0))
|
@@ -97,9 +83,11 @@ class ImageRescaleWrapperTest(absltest.TestCase):
|
|
97
83
|
self.assertEqual(step_image.shape, (150, 60, 3))
|
98
84
|
|
99
85
|
def test_list_zoom_factor(self):
|
100
|
-
|
101
|
-
fake_env =
|
102
|
-
fake_env.
|
86
|
+
fake_timestep = _simple_timestep()
|
87
|
+
fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
88
|
+
fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
|
89
|
+
fake_env.reset.return_value = fake_timestep
|
90
|
+
fake_env.step.return_value = fake_timestep
|
103
91
|
|
104
92
|
wrapper = image_rescale_wrapper.ImageRescaleWrapper(
|
105
93
|
fake_env, zoom_factors=[0.5, 0.2])
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,10 +15,8 @@
|
|
15
15
|
|
16
16
|
"""Extends Android observation with the latest action taken."""
|
17
17
|
|
18
|
-
from typing import Dict
|
19
|
-
|
20
18
|
from android_env.components import action_type
|
21
|
-
from android_env.components import
|
19
|
+
from android_env.components import pixel_fns
|
22
20
|
from android_env.wrappers import base_wrapper
|
23
21
|
import dm_env
|
24
22
|
from dm_env import specs
|
@@ -56,8 +54,8 @@ class LastActionWrapper(base_wrapper.BaseWrapper):
|
|
56
54
|
return timestep._replace(observation=processed_observation)
|
57
55
|
|
58
56
|
def _process_observation(
|
59
|
-
self, observation:
|
60
|
-
) ->
|
57
|
+
self, observation: dict[str, np.ndarray]
|
58
|
+
) -> dict[str, np.ndarray]:
|
61
59
|
"""Extends observation with last_action data."""
|
62
60
|
processed_observation = observation.copy()
|
63
61
|
last_action_layer = self._get_last_action_layer(observation['pixels'])
|
@@ -78,8 +76,9 @@ class LastActionWrapper(base_wrapper.BaseWrapper):
|
|
78
76
|
if ('action_type' in last_action and
|
79
77
|
last_action['action_type'] == action_type.ActionType.TOUCH):
|
80
78
|
touch_position = last_action['touch_position']
|
81
|
-
x, y =
|
82
|
-
touch_position, width_height=self._screen_dimensions[::-1]
|
79
|
+
x, y = pixel_fns.touch_position_to_pixel_position(
|
80
|
+
touch_position, width_height=self._screen_dimensions[::-1]
|
81
|
+
)
|
83
82
|
last_action_layer[y, x] = 255
|
84
83
|
|
85
84
|
return last_action_layer
|
@@ -92,20 +91,24 @@ class LastActionWrapper(base_wrapper.BaseWrapper):
|
|
92
91
|
timestep = self._env.step(action)
|
93
92
|
return self._process_timestep(timestep)
|
94
93
|
|
95
|
-
def observation_spec(self) ->
|
94
|
+
def observation_spec(self) -> dict[str, specs.Array]:
|
96
95
|
parent_spec = self._env.observation_spec().copy()
|
97
96
|
shape = parent_spec['pixels'].shape
|
98
97
|
if self._concat_to_pixels:
|
99
|
-
parent_spec['pixels'] = specs.
|
98
|
+
parent_spec['pixels'] = specs.BoundedArray(
|
100
99
|
shape=(shape[0], shape[1], shape[2] + 1),
|
101
100
|
dtype=parent_spec['pixels'].dtype,
|
102
|
-
name=parent_spec['pixels'].name
|
101
|
+
name=parent_spec['pixels'].name,
|
102
|
+
minimum=parent_spec['pixels'].minimum,
|
103
|
+
maximum=parent_spec['pixels'].maximum)
|
103
104
|
else:
|
104
105
|
parent_spec.update({
|
105
106
|
'last_action':
|
106
|
-
specs.
|
107
|
+
specs.BoundedArray(
|
107
108
|
shape=(shape[0], shape[1]),
|
108
109
|
dtype=parent_spec['pixels'].dtype,
|
109
|
-
name='last_action'
|
110
|
+
name='last_action',
|
111
|
+
minimum=parent_spec['pixels'].minimum,
|
112
|
+
maximum=parent_spec['pixels'].maximum)
|
110
113
|
})
|
111
114
|
return parent_spec
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,10 +15,11 @@
|
|
15
15
|
|
16
16
|
"""Tests for android_env.wrappers.last_action_wrapper."""
|
17
17
|
|
18
|
-
from typing import Any
|
18
|
+
from typing import Any
|
19
|
+
from unittest import mock
|
19
20
|
|
20
21
|
from absl.testing import absltest
|
21
|
-
from android_env import
|
22
|
+
from android_env import env_interface
|
22
23
|
from android_env.components import action_type
|
23
24
|
from android_env.wrappers import last_action_wrapper
|
24
25
|
import dm_env
|
@@ -26,37 +27,13 @@ from dm_env import specs
|
|
26
27
|
import numpy as np
|
27
28
|
|
28
29
|
|
29
|
-
class FakeEnv(environment.AndroidEnv):
|
30
|
-
"""A class that we can use to inject custom observations and specs."""
|
31
|
-
|
32
|
-
def __init__(self, obs_spec):
|
33
|
-
self._obs_spec = obs_spec
|
34
|
-
self._next_obs = None
|
35
|
-
self._latest_action = {}
|
36
|
-
|
37
|
-
def reset(self) -> dm_env.TimeStep:
|
38
|
-
return self._next_timestep
|
39
|
-
|
40
|
-
def step(self, action: Any) -> dm_env.TimeStep:
|
41
|
-
self._latest_action = action
|
42
|
-
return self._next_timestep
|
43
|
-
|
44
|
-
def observation_spec(self) -> Dict[str, specs.Array]:
|
45
|
-
return self._obs_spec
|
46
|
-
|
47
|
-
def action_spec(self) -> Dict[str, specs.Array]:
|
48
|
-
assert False, 'This should not be called by tests.'
|
49
|
-
|
50
|
-
def set_next_timestep(self, timestep):
|
51
|
-
self._next_timestep = timestep
|
52
|
-
|
53
|
-
def close(self):
|
54
|
-
pass
|
55
|
-
|
56
|
-
|
57
30
|
def _simple_spec():
|
58
|
-
return specs.
|
59
|
-
shape=np.array([120, 80, 3]),
|
31
|
+
return specs.BoundedArray(
|
32
|
+
shape=np.array([120, 80, 3]),
|
33
|
+
dtype=np.uint8,
|
34
|
+
name='pixels',
|
35
|
+
minimum=0,
|
36
|
+
maximum=255)
|
60
37
|
|
61
38
|
|
62
39
|
def _simple_timestep():
|
@@ -71,9 +48,11 @@ def _simple_timestep():
|
|
71
48
|
class LastActionWrapperTest(absltest.TestCase):
|
72
49
|
|
73
50
|
def test_concat_to_pixels(self):
|
74
|
-
|
75
|
-
fake_env =
|
76
|
-
fake_env.
|
51
|
+
fake_timestep = _simple_timestep()
|
52
|
+
fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
53
|
+
fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
|
54
|
+
fake_env.reset.return_value = fake_timestep
|
55
|
+
fake_env.step.return_value = fake_timestep
|
77
56
|
|
78
57
|
wrapper = last_action_wrapper.LastActionWrapper(
|
79
58
|
fake_env, concat_to_pixels=True)
|
@@ -86,10 +65,12 @@ class LastActionWrapperTest(absltest.TestCase):
|
|
86
65
|
last_action_layer = reset_image[:, :, -1]
|
87
66
|
self.assertEqual(np.sum(last_action_layer), 0)
|
88
67
|
|
89
|
-
|
68
|
+
action1 = {
|
90
69
|
'action_type': action_type.ActionType.TOUCH,
|
91
70
|
'touch_position': np.array([0.25, 0.75], dtype=np.float32), # (W x H)
|
92
|
-
}
|
71
|
+
}
|
72
|
+
type(fake_env).raw_action = mock.PropertyMock(return_value=action1)
|
73
|
+
step_timestep = wrapper.step(action=action1)
|
93
74
|
step_image = step_timestep.observation['pixels']
|
94
75
|
self.assertEqual(step_image.shape, (120, 80, 4)) # (H x W)
|
95
76
|
last_action_layer = step_image[:, :, -1]
|
@@ -97,19 +78,23 @@ class LastActionWrapperTest(absltest.TestCase):
|
|
97
78
|
y, x = np.where(last_action_layer == 255)
|
98
79
|
self.assertEqual((y.item(), x.item()), (90, 20))
|
99
80
|
|
100
|
-
|
81
|
+
action2 = {
|
101
82
|
'action_type': action_type.ActionType.LIFT,
|
102
83
|
'touch_position': np.array([0.25, 0.75], dtype=np.float32),
|
103
|
-
}
|
84
|
+
}
|
85
|
+
type(fake_env).raw_action = mock.PropertyMock(return_value=action2)
|
86
|
+
step_timestep = wrapper.step(action=action2)
|
104
87
|
step_image = step_timestep.observation['pixels']
|
105
88
|
self.assertEqual(step_image.shape, (120, 80, 4))
|
106
89
|
last_action_layer = step_image[:, :, -1]
|
107
90
|
self.assertEqual(np.sum(last_action_layer), 0)
|
108
91
|
|
109
|
-
|
92
|
+
action3 = {
|
110
93
|
'action_type': action_type.ActionType.TOUCH,
|
111
94
|
'touch_position': np.array([0.25, 1.0], dtype=np.float32),
|
112
|
-
}
|
95
|
+
}
|
96
|
+
type(fake_env).raw_action = mock.PropertyMock(return_value=action3)
|
97
|
+
step_timestep = wrapper.step(action=action3)
|
113
98
|
step_image = step_timestep.observation['pixels']
|
114
99
|
self.assertEqual(step_image.shape, (120, 80, 4))
|
115
100
|
last_action_layer = step_image[:, :, -1]
|
@@ -118,9 +103,11 @@ class LastActionWrapperTest(absltest.TestCase):
|
|
118
103
|
self.assertEqual((y.item(), x.item()), (119, 20))
|
119
104
|
|
120
105
|
def test_no_concat_to_pixels(self):
|
121
|
-
|
122
|
-
fake_env =
|
123
|
-
fake_env.
|
106
|
+
fake_timestep = _simple_timestep()
|
107
|
+
fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
|
108
|
+
fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
|
109
|
+
fake_env.reset.return_value = fake_timestep
|
110
|
+
fake_env.step.return_value = fake_timestep
|
124
111
|
|
125
112
|
wrapper = last_action_wrapper.LastActionWrapper(
|
126
113
|
fake_env, concat_to_pixels=False)
|
@@ -134,10 +121,12 @@ class LastActionWrapperTest(absltest.TestCase):
|
|
134
121
|
last_action_layer = reset_timestep.observation['last_action']
|
135
122
|
self.assertEqual(np.sum(last_action_layer), 0)
|
136
123
|
|
137
|
-
|
124
|
+
action1 = {
|
138
125
|
'action_type': action_type.ActionType.TOUCH,
|
139
126
|
'touch_position': np.array([0.25, 0.75], dtype=np.float32),
|
140
|
-
}
|
127
|
+
}
|
128
|
+
type(fake_env).raw_action = mock.PropertyMock(return_value=action1)
|
129
|
+
step_timestep = wrapper.step(action=action1)
|
141
130
|
step_image = step_timestep.observation['pixels']
|
142
131
|
self.assertEqual(step_image.shape, (120, 80, 3))
|
143
132
|
last_action_layer = step_timestep.observation['last_action']
|
@@ -145,19 +134,23 @@ class LastActionWrapperTest(absltest.TestCase):
|
|
145
134
|
y, x = np.where(last_action_layer == 255)
|
146
135
|
self.assertEqual((y.item(), x.item()), (90, 20))
|
147
136
|
|
148
|
-
|
137
|
+
action2 = {
|
149
138
|
'action_type': action_type.ActionType.LIFT,
|
150
139
|
'touch_position': np.array([0.25, 0.75], dtype=np.float32),
|
151
|
-
}
|
140
|
+
}
|
141
|
+
type(fake_env).raw_action = mock.PropertyMock(return_value=action2)
|
142
|
+
step_timestep = wrapper.step(action=action2)
|
152
143
|
step_image = step_timestep.observation['pixels']
|
153
144
|
self.assertEqual(step_image.shape, (120, 80, 3))
|
154
145
|
last_action_layer = step_timestep.observation['last_action']
|
155
146
|
self.assertEqual(np.sum(last_action_layer), 0)
|
156
147
|
|
157
|
-
|
148
|
+
action3 = {
|
158
149
|
'action_type': action_type.ActionType.TOUCH,
|
159
150
|
'touch_position': np.array([1.0, 0.75], dtype=np.float32),
|
160
|
-
}
|
151
|
+
}
|
152
|
+
type(fake_env).raw_action = mock.PropertyMock(return_value=action3)
|
153
|
+
step_timestep = wrapper.step(action=action3)
|
161
154
|
step_image = step_timestep.observation['pixels']
|
162
155
|
self.assertEqual(step_image.shape, (120, 80, 3))
|
163
156
|
last_action_layer = step_timestep.observation['last_action']
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
import enum
|
19
19
|
import time
|
20
|
-
from typing import Dict
|
21
20
|
|
22
21
|
from android_env import env_interface
|
23
22
|
from android_env.components import action_type
|
@@ -78,9 +77,13 @@ class RateLimitWrapper(base_wrapper.BaseWrapper):
|
|
78
77
|
self._last_step_time = time.time()
|
79
78
|
return timestep
|
80
79
|
|
81
|
-
def step(self, action:
|
80
|
+
def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
|
82
81
|
"""Takes a step while maintaining a steady interaction rate."""
|
83
82
|
|
83
|
+
# If max_wait is non-positive, the wrapper has no effect.
|
84
|
+
if self._max_wait <= 0.0:
|
85
|
+
return self._env.step(action)
|
86
|
+
|
84
87
|
if self._sleep_type == RateLimitWrapper.SleepType.BEFORE:
|
85
88
|
self._wait()
|
86
89
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -16,6 +16,7 @@
|
|
16
16
|
"""Tests for rate_limit_wrapper."""
|
17
17
|
|
18
18
|
import time
|
19
|
+
from typing import Any, Protocol
|
19
20
|
from unittest import mock
|
20
21
|
|
21
22
|
from absl.testing import absltest
|
@@ -46,6 +47,17 @@ def _get_base_env():
|
|
46
47
|
return env
|
47
48
|
|
48
49
|
|
50
|
+
class _FnWithTimestamps(Protocol):
|
51
|
+
"""A function with `timestamp` and `timestamps` attributes."""
|
52
|
+
|
53
|
+
timestamp: float
|
54
|
+
timestamps: list[float]
|
55
|
+
|
56
|
+
|
57
|
+
def _with_timestamp(fn: Any) -> _FnWithTimestamps:
|
58
|
+
return fn
|
59
|
+
|
60
|
+
|
49
61
|
class RateLimitWrapperTest(parameterized.TestCase):
|
50
62
|
|
51
63
|
@parameterized.named_parameters(
|
@@ -64,6 +76,8 @@ class RateLimitWrapperTest(parameterized.TestCase):
|
|
64
76
|
'touch_position': np.array([0.123, 0.456])
|
65
77
|
})
|
66
78
|
mock_sleep.assert_not_called()
|
79
|
+
# When the wrapper is disabled, base step should only be called once.
|
80
|
+
env.step.assert_called_once()
|
67
81
|
|
68
82
|
@mock.patch.object(time, 'sleep', autospec=True)
|
69
83
|
def test_enabled(self, mock_sleep):
|
@@ -105,6 +119,7 @@ class RateLimitWrapperTest(parameterized.TestCase):
|
|
105
119
|
_ = wrapper.reset()
|
106
120
|
mock_sleep.assert_not_called() # It should never sleep during reset().
|
107
121
|
|
122
|
+
@_with_timestamp
|
108
123
|
def _sleep_fn(sleep_time):
|
109
124
|
_sleep_fn.timestamp = time.time()
|
110
125
|
self.assertBetween(sleep_time, 0.0, 33.33)
|
@@ -143,6 +158,7 @@ class RateLimitWrapperTest(parameterized.TestCase):
|
|
143
158
|
_ = wrapper.reset()
|
144
159
|
mock_sleep.assert_not_called() # It should never sleep during reset().
|
145
160
|
|
161
|
+
@_with_timestamp
|
146
162
|
def _sleep_fn(sleep_time):
|
147
163
|
_sleep_fn.timestamp = time.time()
|
148
164
|
self.assertBetween(sleep_time, 0.0, 33.33)
|
@@ -183,12 +199,14 @@ class RateLimitWrapperTest(parameterized.TestCase):
|
|
183
199
|
_ = wrapper.reset()
|
184
200
|
mock_sleep.assert_not_called() # It should never sleep during reset().
|
185
201
|
|
202
|
+
@_with_timestamp
|
186
203
|
def _sleep_fn(sleep_time):
|
187
204
|
_sleep_fn.timestamp = time.time()
|
188
205
|
self.assertBetween(sleep_time, 0.0, 33.33)
|
189
206
|
|
190
207
|
mock_sleep.side_effect = _sleep_fn
|
191
208
|
|
209
|
+
@_with_timestamp
|
192
210
|
def _step_fn(action):
|
193
211
|
# On even calls the action should be the actual agent action, but on odd
|
194
212
|
# calls they should be REPEATs.
|
@@ -212,6 +230,9 @@ class RateLimitWrapperTest(parameterized.TestCase):
|
|
212
230
|
'touch_position': np.array([0.123, 0.456])
|
213
231
|
})
|
214
232
|
|
233
|
+
# When the wrapper is enabled, base step should be called twice.
|
234
|
+
self.assertEqual(env.step.call_count, 2)
|
235
|
+
|
215
236
|
# `step()` should be called twice: before `sleep()` and after it.
|
216
237
|
self.assertLen(_step_fn.timestamps, 2)
|
217
238
|
self.assertGreaterEqual(_sleep_fn.timestamp, _step_fn.timestamps[0])
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Wraps the AndroidEnv environment to provide tap actions of a given duration."""
|
17
17
|
|
18
|
-
from
|
18
|
+
from collections.abc import Sequence
|
19
19
|
|
20
20
|
from android_env.components import action_type
|
21
21
|
from android_env.wrappers import base_wrapper
|
@@ -23,9 +23,6 @@ import dm_env
|
|
23
23
|
import numpy as np
|
24
24
|
|
25
25
|
|
26
|
-
ActionType = action_type.ActionType
|
27
|
-
|
28
|
-
|
29
26
|
class TapActionWrapper(base_wrapper.BaseWrapper):
|
30
27
|
"""AndroidEnv with tap actions."""
|
31
28
|
|
@@ -46,33 +43,35 @@ class TapActionWrapper(base_wrapper.BaseWrapper):
|
|
46
43
|
return logs
|
47
44
|
|
48
45
|
def _process_action(
|
49
|
-
self, action:
|
50
|
-
) -> Sequence[
|
51
|
-
|
46
|
+
self, action: dict[str, np.ndarray]
|
47
|
+
) -> Sequence[dict[str, np.ndarray]]:
|
52
48
|
if self._touch_only:
|
53
49
|
assert action['action_type'] == 0
|
54
50
|
touch_action = action.copy()
|
55
|
-
touch_action['action_type'] = np.array(
|
56
|
-
|
51
|
+
touch_action['action_type'] = np.array(
|
52
|
+
action_type.ActionType.TOUCH
|
53
|
+
).astype(self.action_spec()['action_type'].dtype)
|
57
54
|
actions = [touch_action] * self._num_frames
|
58
55
|
lift_action = action.copy()
|
59
|
-
lift_action['action_type'] = np.array(ActionType.LIFT).astype(
|
60
|
-
self.action_spec()['action_type'].dtype
|
56
|
+
lift_action['action_type'] = np.array(action_type.ActionType.LIFT).astype(
|
57
|
+
self.action_spec()['action_type'].dtype
|
58
|
+
)
|
61
59
|
actions.append(lift_action)
|
62
60
|
|
63
61
|
else:
|
64
|
-
if action['action_type'] == ActionType.TOUCH:
|
62
|
+
if action['action_type'] == action_type.ActionType.TOUCH:
|
65
63
|
actions = [action] * self._num_frames
|
66
64
|
lift_action = action.copy()
|
67
|
-
lift_action['action_type'] = np.array(
|
68
|
-
|
65
|
+
lift_action['action_type'] = np.array(
|
66
|
+
action_type.ActionType.LIFT
|
67
|
+
).astype(self.action_spec()['action_type'].dtype)
|
69
68
|
actions.append(lift_action)
|
70
69
|
else:
|
71
70
|
actions = [action] * (self._num_frames + 1)
|
72
71
|
|
73
72
|
return actions
|
74
73
|
|
75
|
-
def step(self, action:
|
74
|
+
def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
|
76
75
|
"""Takes a step in the environment."""
|
77
76
|
self._env_steps += self._num_frames + 1
|
78
77
|
actions = self._process_action(action)
|
@@ -93,7 +92,7 @@ class TapActionWrapper(base_wrapper.BaseWrapper):
|
|
93
92
|
discount=discount,
|
94
93
|
observation=observation)
|
95
94
|
|
96
|
-
def action_spec(self) ->
|
95
|
+
def action_spec(self) -> dict[str, dm_env.specs.Array]:
|
97
96
|
if self._touch_only:
|
98
97
|
return {
|
99
98
|
'action_type':
|