android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. android_env/__init__.py +1 -1
  2. android_env/components/__init__.py +1 -1
  3. android_env/components/a11y/__init__.py +15 -0
  4. android_env/components/a11y/a11y_events.py +118 -0
  5. android_env/components/a11y/a11y_events_test.py +173 -0
  6. android_env/components/a11y/a11y_forests.py +128 -0
  7. android_env/components/a11y/a11y_forests_test.py +237 -0
  8. android_env/components/a11y/a11y_servicer.py +199 -0
  9. android_env/components/a11y/a11y_servicer_test.py +224 -0
  10. android_env/components/action_fns.py +132 -0
  11. android_env/components/action_fns_test.py +227 -0
  12. android_env/components/action_type.py +26 -3
  13. android_env/components/adb_call_parser.py +239 -196
  14. android_env/components/adb_call_parser_test.py +179 -209
  15. android_env/components/adb_controller.py +90 -52
  16. android_env/components/adb_controller_test.py +187 -16
  17. android_env/components/adb_log_stream.py +17 -5
  18. android_env/components/adb_log_stream_test.py +17 -3
  19. android_env/components/app_screen_checker.py +17 -15
  20. android_env/components/app_screen_checker_test.py +7 -8
  21. android_env/components/config_classes.py +203 -0
  22. android_env/components/coordinator.py +102 -338
  23. android_env/components/coordinator_test.py +59 -199
  24. android_env/components/device_settings.py +174 -0
  25. android_env/components/device_settings_test.py +228 -0
  26. android_env/components/dumpsys_thread.py +3 -4
  27. android_env/components/dumpsys_thread_test.py +1 -1
  28. android_env/components/errors.py +52 -10
  29. android_env/components/errors_test.py +110 -0
  30. android_env/components/log_stream.py +7 -5
  31. android_env/components/log_stream_test.py +1 -1
  32. android_env/components/logcat_thread.py +9 -8
  33. android_env/components/logcat_thread_test.py +3 -4
  34. android_env/components/{utils.py → pixel_fns.py} +20 -20
  35. android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
  36. android_env/components/setup_step_interpreter.py +47 -39
  37. android_env/components/setup_step_interpreter_test.py +4 -4
  38. android_env/components/simulators/__init__.py +1 -1
  39. android_env/components/simulators/base_simulator.py +116 -44
  40. android_env/components/simulators/base_simulator_test.py +131 -9
  41. android_env/components/simulators/emulator/__init__.py +1 -1
  42. android_env/components/simulators/emulator/emulator_launcher.py +67 -77
  43. android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
  44. android_env/components/simulators/emulator/emulator_simulator.py +276 -95
  45. android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
  46. android_env/components/simulators/fake/__init__.py +1 -1
  47. android_env/components/simulators/fake/fake_simulator.py +17 -25
  48. android_env/components/simulators/fake/fake_simulator_test.py +29 -12
  49. android_env/components/specs.py +18 -28
  50. android_env/components/specs_test.py +1 -44
  51. android_env/components/task_manager.py +48 -48
  52. android_env/components/task_manager_test.py +71 -60
  53. android_env/env_interface.py +37 -23
  54. android_env/environment.py +83 -51
  55. android_env/environment_test.py +68 -29
  56. android_env/loader.py +57 -43
  57. android_env/loader_test.py +115 -35
  58. android_env/proto/__init__.py +1 -1
  59. android_env/proto/a11y/__init__.py +15 -0
  60. android_env/proto/a11y/a11y.proto +75 -0
  61. android_env/proto/a11y/a11y_pb2.py +54 -0
  62. android_env/proto/a11y/a11y_pb2.pyi +49 -0
  63. android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
  64. android_env/proto/a11y/android_accessibility_action.proto +32 -0
  65. android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
  66. android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
  67. android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
  68. android_env/proto/a11y/android_accessibility_forest.proto +29 -0
  69. android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
  70. android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
  71. android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
  72. android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
  73. android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
  74. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
  75. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
  76. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
  77. android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
  78. android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
  79. android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
  80. android_env/proto/a11y/android_accessibility_tree.proto +29 -0
  81. android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
  82. android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
  83. android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
  84. android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
  85. android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
  86. android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
  87. android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
  88. android_env/proto/a11y/rect.proto +30 -0
  89. android_env/proto/a11y/rect_pb2.py +37 -0
  90. android_env/proto/a11y/rect_pb2.pyi +17 -0
  91. android_env/proto/a11y/rect_pb2_grpc.py +24 -0
  92. android_env/proto/adb.proto +17 -6
  93. android_env/proto/adb_pb2.py +120 -107
  94. android_env/proto/adb_pb2.pyi +396 -0
  95. android_env/proto/adb_pb2_grpc.py +20 -0
  96. android_env/proto/emulator_controller.proto +68 -63
  97. android_env/proto/emulator_controller_pb2.py +142 -131
  98. android_env/proto/emulator_controller_pb2.pyi +672 -0
  99. android_env/proto/emulator_controller_pb2_grpc.py +505 -142
  100. android_env/proto/snapshot.proto +169 -0
  101. android_env/proto/snapshot_pb2.py +47 -0
  102. android_env/proto/snapshot_pb2.pyi +117 -0
  103. android_env/proto/snapshot_pb2_grpc.py +24 -0
  104. android_env/proto/snapshot_service.proto +289 -0
  105. android_env/proto/snapshot_service_pb2.py +54 -0
  106. android_env/proto/snapshot_service_pb2.pyi +86 -0
  107. android_env/proto/snapshot_service_pb2_grpc.py +487 -0
  108. android_env/proto/state.proto +63 -0
  109. android_env/proto/state_pb2.py +63 -0
  110. android_env/proto/state_pb2.pyi +85 -0
  111. android_env/proto/state_pb2_grpc.py +24 -0
  112. android_env/proto/task.proto +5 -1
  113. android_env/proto/task_pb2.py +42 -31
  114. android_env/proto/task_pb2.pyi +160 -0
  115. android_env/proto/task_pb2_grpc.py +20 -0
  116. android_env/wrappers/__init__.py +1 -1
  117. android_env/wrappers/a11y_grpc_wrapper.py +500 -0
  118. android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
  119. android_env/wrappers/base_wrapper.py +34 -13
  120. android_env/wrappers/base_wrapper_test.py +22 -16
  121. android_env/wrappers/discrete_action_wrapper.py +18 -17
  122. android_env/wrappers/discrete_action_wrapper_test.py +4 -4
  123. android_env/wrappers/flat_interface_wrapper.py +5 -5
  124. android_env/wrappers/flat_interface_wrapper_test.py +7 -11
  125. android_env/wrappers/float_pixels_wrapper.py +9 -10
  126. android_env/wrappers/float_pixels_wrapper_test.py +3 -3
  127. android_env/wrappers/gym_wrapper.py +19 -13
  128. android_env/wrappers/gym_wrapper_test.py +3 -5
  129. android_env/wrappers/image_rescale_wrapper.py +18 -21
  130. android_env/wrappers/image_rescale_wrapper_test.py +25 -37
  131. android_env/wrappers/last_action_wrapper.py +16 -13
  132. android_env/wrappers/last_action_wrapper_test.py +44 -51
  133. android_env/wrappers/rate_limit_wrapper.py +6 -3
  134. android_env/wrappers/rate_limit_wrapper_test.py +22 -1
  135. android_env/wrappers/tap_action_wrapper.py +16 -17
  136. android_env/wrappers/tap_action_wrapper_test.py +51 -16
  137. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
  138. android_env-1.2.3.dist-info/RECORD +141 -0
  139. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
  140. android_env/proto/raw_observation.proto +0 -39
  141. android_env/proto/raw_observation_pb2.py +0 -27
  142. android_env/proto/raw_observation_pb2_grpc.py +0 -4
  143. android_env-1.2.1.dist-info/RECORD +0 -81
  144. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
  145. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
15
15
 
16
16
  """Wraps the AndroidEnv environment to rescale the observations."""
17
17
 
18
- from typing import Optional, Sequence, Dict
18
+ from collections.abc import Sequence
19
19
 
20
20
  from android_env.wrappers import base_wrapper
21
21
  import dm_env
@@ -37,8 +37,9 @@ class ImageRescaleWrapper(base_wrapper.BaseWrapper):
37
37
  def __init__(
38
38
  self,
39
39
  env: dm_env.Environment,
40
- zoom_factors: Optional[Sequence[float]] = (0.5, 0.5),
41
- grayscale: bool = False):
40
+ zoom_factors: Sequence[float] | None = (0.5, 0.5),
41
+ grayscale: bool = False,
42
+ ):
42
43
  super().__init__(env)
43
44
  assert 'pixels' in self._env.observation_spec()
44
45
  assert self._env.observation_spec()['pixels'].shape[-1] in [1, 3], (
@@ -50,16 +51,8 @@ class ImageRescaleWrapper(base_wrapper.BaseWrapper):
50
51
  # want to zoom the number of channels so we just multiply it by 1.0.
51
52
  self._zoom_factors = tuple(zoom_factors) + (1.0,)
52
53
 
53
- # Save the raw image for making videos, for example.
54
- self._raw_pixels = None
55
-
56
- @property
57
- def raw_pixels(self):
58
- return self._raw_pixels
59
-
60
54
  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
61
55
  observation = timestep.observation
62
- self._raw_pixels = observation['pixels'].copy()
63
56
  processed_observation = observation.copy()
64
57
  processed_observation['pixels'] = self._process_pixels(
65
58
  observation['pixels'])
@@ -78,14 +71,16 @@ class ImageRescaleWrapper(base_wrapper.BaseWrapper):
78
71
  return self._resize_image_array(image, new_shape)
79
72
 
80
73
  def _resize_image_array(
81
- self,
82
- grayscale_or_rbg_array: np.ndarray,
83
- new_shape: Sequence[int]) -> np.ndarray:
74
+ self, grayscale_or_rbg_array: np.ndarray, new_shape: np.ndarray
75
+ ) -> np.ndarray:
84
76
  """Resize color or grayscale/action_layer array to new_shape."""
85
- assert np.array(new_shape).ndim == 1
77
+ assert new_shape.ndim == 1
86
78
  assert len(new_shape) == 2
87
- resized_array = np.array(Image.fromarray(
88
- grayscale_or_rbg_array.astype('uint8')).resize(new_shape))
79
+ resized_array = np.array(
80
+ Image.fromarray(grayscale_or_rbg_array.astype('uint8')).resize(
81
+ tuple(new_shape)
82
+ )
83
+ )
89
84
  if resized_array.ndim == 2:
90
85
  return np.expand_dims(resized_array, axis=-1)
91
86
  return resized_array
@@ -98,15 +93,17 @@ class ImageRescaleWrapper(base_wrapper.BaseWrapper):
98
93
  timestep = self._env.step(action)
99
94
  return self._process_timestep(timestep)
100
95
 
101
- def observation_spec(self) -> Dict[str, specs.Array]:
96
+ def observation_spec(self) -> dict[str, specs.Array]:
102
97
  parent_spec = self._env.observation_spec().copy()
103
98
  out_shape = np.multiply(parent_spec['pixels'].shape,
104
99
  self._zoom_factors).astype(np.int32)
105
100
  if self._grayscale:
106
101
  # In grayscale mode we want the output shape to be [W, H, 1]
107
102
  out_shape[-1] = 1
108
- parent_spec['pixels'] = specs.Array(
103
+ parent_spec['pixels'] = specs.BoundedArray(
109
104
  shape=out_shape,
110
105
  dtype=parent_spec['pixels'].dtype,
111
- name=parent_spec['pixels'].name)
106
+ name=parent_spec['pixels'].name,
107
+ minimum=parent_spec['pixels'].minimum,
108
+ maximum=parent_spec['pixels'].maximum)
112
109
  return parent_spec
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,42 +15,24 @@
15
15
 
16
16
  """Tests for android_env.wrappers.image_rescale_wrapper."""
17
17
 
18
- from typing import Any, Dict
18
+ from typing import Any
19
+ from unittest import mock
19
20
 
20
21
  from absl.testing import absltest
21
- from android_env import environment
22
+ from android_env import env_interface
22
23
  from android_env.wrappers import image_rescale_wrapper
23
24
  import dm_env
24
25
  from dm_env import specs
25
26
  import numpy as np
26
27
 
27
28
 
28
- class FakeEnv(environment.AndroidEnv):
29
- """A class that we can use to inject custom observations and specs."""
30
-
31
- def __init__(self, obs_spec):
32
- self._obs_spec = obs_spec
33
- self._next_obs = None
34
-
35
- def reset(self) -> dm_env.TimeStep:
36
- return self._next_timestep
37
-
38
- def step(self, action: Any) -> dm_env.TimeStep:
39
- return self._next_timestep
40
-
41
- def observation_spec(self) -> Dict[str, specs.Array]:
42
- return self._obs_spec
43
-
44
- def action_spec(self) -> Dict[str, specs.Array]:
45
- assert False, 'This should not be called by tests.'
46
-
47
- def set_next_timestep(self, timestep):
48
- self._next_timestep = timestep
49
-
50
-
51
29
  def _simple_spec():
52
- return specs.Array(
53
- shape=np.array([300, 300, 3]), dtype=np.uint8, name='pixels')
30
+ return specs.BoundedArray(
31
+ shape=np.array([300, 300, 3]),
32
+ dtype=np.uint8,
33
+ name='pixels',
34
+ minimum=0,
35
+ maximum=255)
54
36
 
55
37
 
56
38
  def _simple_timestep():
@@ -65,9 +47,11 @@ def _simple_timestep():
65
47
  class ImageRescaleWrapperTest(absltest.TestCase):
66
48
 
67
49
  def test_100x50_grayscale(self):
68
- obs_spec = {'pixels': _simple_spec()}
69
- fake_env = FakeEnv(obs_spec)
70
- fake_env.set_next_timestep(_simple_timestep())
50
+ fake_timestep = _simple_timestep()
51
+ fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
52
+ fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
53
+ fake_env.reset.return_value = fake_timestep
54
+ fake_env.step.return_value = fake_timestep
71
55
 
72
56
  wrapper = image_rescale_wrapper.ImageRescaleWrapper(
73
57
  fake_env, zoom_factors=(1.0 / 3, 1.0 / 6.0), grayscale=True)
@@ -81,9 +65,11 @@ class ImageRescaleWrapperTest(absltest.TestCase):
81
65
  self.assertEqual(step_image.shape, (100, 50, 1))
82
66
 
83
67
  def test_150x60_full_channels(self):
84
- obs_spec = {'pixels': _simple_spec()}
85
- fake_env = FakeEnv(obs_spec)
86
- fake_env.set_next_timestep(_simple_timestep())
68
+ fake_timestep = _simple_timestep()
69
+ fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
70
+ fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
71
+ fake_env.reset.return_value = fake_timestep
72
+ fake_env.step.return_value = fake_timestep
87
73
 
88
74
  wrapper = image_rescale_wrapper.ImageRescaleWrapper(
89
75
  fake_env, zoom_factors=(1.0 / 2.0, 1.0 / 5.0))
@@ -97,9 +83,11 @@ class ImageRescaleWrapperTest(absltest.TestCase):
97
83
  self.assertEqual(step_image.shape, (150, 60, 3))
98
84
 
99
85
  def test_list_zoom_factor(self):
100
- obs_spec = {'pixels': _simple_spec()}
101
- fake_env = FakeEnv(obs_spec)
102
- fake_env.set_next_timestep(_simple_timestep())
86
+ fake_timestep = _simple_timestep()
87
+ fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
88
+ fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
89
+ fake_env.reset.return_value = fake_timestep
90
+ fake_env.step.return_value = fake_timestep
103
91
 
104
92
  wrapper = image_rescale_wrapper.ImageRescaleWrapper(
105
93
  fake_env, zoom_factors=[0.5, 0.2])
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,10 +15,8 @@
15
15
 
16
16
  """Extends Android observation with the latest action taken."""
17
17
 
18
- from typing import Dict
19
-
20
18
  from android_env.components import action_type
21
- from android_env.components import utils
19
+ from android_env.components import pixel_fns
22
20
  from android_env.wrappers import base_wrapper
23
21
  import dm_env
24
22
  from dm_env import specs
@@ -56,8 +54,8 @@ class LastActionWrapper(base_wrapper.BaseWrapper):
56
54
  return timestep._replace(observation=processed_observation)
57
55
 
58
56
  def _process_observation(
59
- self, observation: Dict[str, np.ndarray]
60
- ) -> Dict[str, np.ndarray]:
57
+ self, observation: dict[str, np.ndarray]
58
+ ) -> dict[str, np.ndarray]:
61
59
  """Extends observation with last_action data."""
62
60
  processed_observation = observation.copy()
63
61
  last_action_layer = self._get_last_action_layer(observation['pixels'])
@@ -78,8 +76,9 @@ class LastActionWrapper(base_wrapper.BaseWrapper):
78
76
  if ('action_type' in last_action and
79
77
  last_action['action_type'] == action_type.ActionType.TOUCH):
80
78
  touch_position = last_action['touch_position']
81
- x, y = utils.touch_position_to_pixel_position(
82
- touch_position, width_height=self._screen_dimensions[::-1])
79
+ x, y = pixel_fns.touch_position_to_pixel_position(
80
+ touch_position, width_height=self._screen_dimensions[::-1]
81
+ )
83
82
  last_action_layer[y, x] = 255
84
83
 
85
84
  return last_action_layer
@@ -92,20 +91,24 @@ class LastActionWrapper(base_wrapper.BaseWrapper):
92
91
  timestep = self._env.step(action)
93
92
  return self._process_timestep(timestep)
94
93
 
95
- def observation_spec(self) -> Dict[str, specs.Array]:
94
+ def observation_spec(self) -> dict[str, specs.Array]:
96
95
  parent_spec = self._env.observation_spec().copy()
97
96
  shape = parent_spec['pixels'].shape
98
97
  if self._concat_to_pixels:
99
- parent_spec['pixels'] = specs.Array(
98
+ parent_spec['pixels'] = specs.BoundedArray(
100
99
  shape=(shape[0], shape[1], shape[2] + 1),
101
100
  dtype=parent_spec['pixels'].dtype,
102
- name=parent_spec['pixels'].name)
101
+ name=parent_spec['pixels'].name,
102
+ minimum=parent_spec['pixels'].minimum,
103
+ maximum=parent_spec['pixels'].maximum)
103
104
  else:
104
105
  parent_spec.update({
105
106
  'last_action':
106
- specs.Array(
107
+ specs.BoundedArray(
107
108
  shape=(shape[0], shape[1]),
108
109
  dtype=parent_spec['pixels'].dtype,
109
- name='last_action')
110
+ name='last_action',
111
+ minimum=parent_spec['pixels'].minimum,
112
+ maximum=parent_spec['pixels'].maximum)
110
113
  })
111
114
  return parent_spec
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,10 +15,11 @@
15
15
 
16
16
  """Tests for android_env.wrappers.last_action_wrapper."""
17
17
 
18
- from typing import Any, Dict
18
+ from typing import Any
19
+ from unittest import mock
19
20
 
20
21
  from absl.testing import absltest
21
- from android_env import environment
22
+ from android_env import env_interface
22
23
  from android_env.components import action_type
23
24
  from android_env.wrappers import last_action_wrapper
24
25
  import dm_env
@@ -26,37 +27,13 @@ from dm_env import specs
26
27
  import numpy as np
27
28
 
28
29
 
29
- class FakeEnv(environment.AndroidEnv):
30
- """A class that we can use to inject custom observations and specs."""
31
-
32
- def __init__(self, obs_spec):
33
- self._obs_spec = obs_spec
34
- self._next_obs = None
35
- self._latest_action = {}
36
-
37
- def reset(self) -> dm_env.TimeStep:
38
- return self._next_timestep
39
-
40
- def step(self, action: Any) -> dm_env.TimeStep:
41
- self._latest_action = action
42
- return self._next_timestep
43
-
44
- def observation_spec(self) -> Dict[str, specs.Array]:
45
- return self._obs_spec
46
-
47
- def action_spec(self) -> Dict[str, specs.Array]:
48
- assert False, 'This should not be called by tests.'
49
-
50
- def set_next_timestep(self, timestep):
51
- self._next_timestep = timestep
52
-
53
- def close(self):
54
- pass
55
-
56
-
57
30
  def _simple_spec():
58
- return specs.Array(
59
- shape=np.array([120, 80, 3]), dtype=np.uint8, name='pixels')
31
+ return specs.BoundedArray(
32
+ shape=np.array([120, 80, 3]),
33
+ dtype=np.uint8,
34
+ name='pixels',
35
+ minimum=0,
36
+ maximum=255)
60
37
 
61
38
 
62
39
  def _simple_timestep():
@@ -71,9 +48,11 @@ def _simple_timestep():
71
48
  class LastActionWrapperTest(absltest.TestCase):
72
49
 
73
50
  def test_concat_to_pixels(self):
74
- obs_spec = {'pixels': _simple_spec()}
75
- fake_env = FakeEnv(obs_spec)
76
- fake_env.set_next_timestep(_simple_timestep())
51
+ fake_timestep = _simple_timestep()
52
+ fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
53
+ fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
54
+ fake_env.reset.return_value = fake_timestep
55
+ fake_env.step.return_value = fake_timestep
77
56
 
78
57
  wrapper = last_action_wrapper.LastActionWrapper(
79
58
  fake_env, concat_to_pixels=True)
@@ -86,10 +65,12 @@ class LastActionWrapperTest(absltest.TestCase):
86
65
  last_action_layer = reset_image[:, :, -1]
87
66
  self.assertEqual(np.sum(last_action_layer), 0)
88
67
 
89
- step_timestep = wrapper.step(action={
68
+ action1 = {
90
69
  'action_type': action_type.ActionType.TOUCH,
91
70
  'touch_position': np.array([0.25, 0.75], dtype=np.float32), # (W x H)
92
- })
71
+ }
72
+ type(fake_env).raw_action = mock.PropertyMock(return_value=action1)
73
+ step_timestep = wrapper.step(action=action1)
93
74
  step_image = step_timestep.observation['pixels']
94
75
  self.assertEqual(step_image.shape, (120, 80, 4)) # (H x W)
95
76
  last_action_layer = step_image[:, :, -1]
@@ -97,19 +78,23 @@ class LastActionWrapperTest(absltest.TestCase):
97
78
  y, x = np.where(last_action_layer == 255)
98
79
  self.assertEqual((y.item(), x.item()), (90, 20))
99
80
 
100
- step_timestep = wrapper.step(action={
81
+ action2 = {
101
82
  'action_type': action_type.ActionType.LIFT,
102
83
  'touch_position': np.array([0.25, 0.75], dtype=np.float32),
103
- })
84
+ }
85
+ type(fake_env).raw_action = mock.PropertyMock(return_value=action2)
86
+ step_timestep = wrapper.step(action=action2)
104
87
  step_image = step_timestep.observation['pixels']
105
88
  self.assertEqual(step_image.shape, (120, 80, 4))
106
89
  last_action_layer = step_image[:, :, -1]
107
90
  self.assertEqual(np.sum(last_action_layer), 0)
108
91
 
109
- step_timestep = wrapper.step(action={
92
+ action3 = {
110
93
  'action_type': action_type.ActionType.TOUCH,
111
94
  'touch_position': np.array([0.25, 1.0], dtype=np.float32),
112
- })
95
+ }
96
+ type(fake_env).raw_action = mock.PropertyMock(return_value=action3)
97
+ step_timestep = wrapper.step(action=action3)
113
98
  step_image = step_timestep.observation['pixels']
114
99
  self.assertEqual(step_image.shape, (120, 80, 4))
115
100
  last_action_layer = step_image[:, :, -1]
@@ -118,9 +103,11 @@ class LastActionWrapperTest(absltest.TestCase):
118
103
  self.assertEqual((y.item(), x.item()), (119, 20))
119
104
 
120
105
  def test_no_concat_to_pixels(self):
121
- obs_spec = {'pixels': _simple_spec()}
122
- fake_env = FakeEnv(obs_spec)
123
- fake_env.set_next_timestep(_simple_timestep())
106
+ fake_timestep = _simple_timestep()
107
+ fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
108
+ fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
109
+ fake_env.reset.return_value = fake_timestep
110
+ fake_env.step.return_value = fake_timestep
124
111
 
125
112
  wrapper = last_action_wrapper.LastActionWrapper(
126
113
  fake_env, concat_to_pixels=False)
@@ -134,10 +121,12 @@ class LastActionWrapperTest(absltest.TestCase):
134
121
  last_action_layer = reset_timestep.observation['last_action']
135
122
  self.assertEqual(np.sum(last_action_layer), 0)
136
123
 
137
- step_timestep = wrapper.step(action={
124
+ action1 = {
138
125
  'action_type': action_type.ActionType.TOUCH,
139
126
  'touch_position': np.array([0.25, 0.75], dtype=np.float32),
140
- })
127
+ }
128
+ type(fake_env).raw_action = mock.PropertyMock(return_value=action1)
129
+ step_timestep = wrapper.step(action=action1)
141
130
  step_image = step_timestep.observation['pixels']
142
131
  self.assertEqual(step_image.shape, (120, 80, 3))
143
132
  last_action_layer = step_timestep.observation['last_action']
@@ -145,19 +134,23 @@ class LastActionWrapperTest(absltest.TestCase):
145
134
  y, x = np.where(last_action_layer == 255)
146
135
  self.assertEqual((y.item(), x.item()), (90, 20))
147
136
 
148
- step_timestep = wrapper.step(action={
137
+ action2 = {
149
138
  'action_type': action_type.ActionType.LIFT,
150
139
  'touch_position': np.array([0.25, 0.75], dtype=np.float32),
151
- })
140
+ }
141
+ type(fake_env).raw_action = mock.PropertyMock(return_value=action2)
142
+ step_timestep = wrapper.step(action=action2)
152
143
  step_image = step_timestep.observation['pixels']
153
144
  self.assertEqual(step_image.shape, (120, 80, 3))
154
145
  last_action_layer = step_timestep.observation['last_action']
155
146
  self.assertEqual(np.sum(last_action_layer), 0)
156
147
 
157
- step_timestep = wrapper.step(action={
148
+ action3 = {
158
149
  'action_type': action_type.ActionType.TOUCH,
159
150
  'touch_position': np.array([1.0, 0.75], dtype=np.float32),
160
- })
151
+ }
152
+ type(fake_env).raw_action = mock.PropertyMock(return_value=action3)
153
+ step_timestep = wrapper.step(action=action3)
161
154
  step_image = step_timestep.observation['pixels']
162
155
  self.assertEqual(step_image.shape, (120, 80, 3))
163
156
  last_action_layer = step_timestep.observation['last_action']
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
17
17
 
18
18
  import enum
19
19
  import time
20
- from typing import Dict
21
20
 
22
21
  from android_env import env_interface
23
22
  from android_env.components import action_type
@@ -78,9 +77,13 @@ class RateLimitWrapper(base_wrapper.BaseWrapper):
78
77
  self._last_step_time = time.time()
79
78
  return timestep
80
79
 
81
- def step(self, action: Dict[str, np.ndarray]) -> dm_env.TimeStep:
80
+ def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
82
81
  """Takes a step while maintaining a steady interaction rate."""
83
82
 
83
+ # If max_wait is non-positive, the wrapper has no effect.
84
+ if self._max_wait <= 0.0:
85
+ return self._env.step(action)
86
+
84
87
  if self._sleep_type == RateLimitWrapper.SleepType.BEFORE:
85
88
  self._wait()
86
89
 
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
16
16
  """Tests for rate_limit_wrapper."""
17
17
 
18
18
  import time
19
+ from typing import Any, Protocol
19
20
  from unittest import mock
20
21
 
21
22
  from absl.testing import absltest
@@ -46,6 +47,17 @@ def _get_base_env():
46
47
  return env
47
48
 
48
49
 
50
+ class _FnWithTimestamps(Protocol):
51
+ """A function with `timestamp` and `timestamps` attributes."""
52
+
53
+ timestamp: float
54
+ timestamps: list[float]
55
+
56
+
57
+ def _with_timestamp(fn: Any) -> _FnWithTimestamps:
58
+ return fn
59
+
60
+
49
61
  class RateLimitWrapperTest(parameterized.TestCase):
50
62
 
51
63
  @parameterized.named_parameters(
@@ -64,6 +76,8 @@ class RateLimitWrapperTest(parameterized.TestCase):
64
76
  'touch_position': np.array([0.123, 0.456])
65
77
  })
66
78
  mock_sleep.assert_not_called()
79
+ # When the wrapper is disabled, base step should only be called once.
80
+ env.step.assert_called_once()
67
81
 
68
82
  @mock.patch.object(time, 'sleep', autospec=True)
69
83
  def test_enabled(self, mock_sleep):
@@ -105,6 +119,7 @@ class RateLimitWrapperTest(parameterized.TestCase):
105
119
  _ = wrapper.reset()
106
120
  mock_sleep.assert_not_called() # It should never sleep during reset().
107
121
 
122
+ @_with_timestamp
108
123
  def _sleep_fn(sleep_time):
109
124
  _sleep_fn.timestamp = time.time()
110
125
  self.assertBetween(sleep_time, 0.0, 33.33)
@@ -143,6 +158,7 @@ class RateLimitWrapperTest(parameterized.TestCase):
143
158
  _ = wrapper.reset()
144
159
  mock_sleep.assert_not_called() # It should never sleep during reset().
145
160
 
161
+ @_with_timestamp
146
162
  def _sleep_fn(sleep_time):
147
163
  _sleep_fn.timestamp = time.time()
148
164
  self.assertBetween(sleep_time, 0.0, 33.33)
@@ -183,12 +199,14 @@ class RateLimitWrapperTest(parameterized.TestCase):
183
199
  _ = wrapper.reset()
184
200
  mock_sleep.assert_not_called() # It should never sleep during reset().
185
201
 
202
+ @_with_timestamp
186
203
  def _sleep_fn(sleep_time):
187
204
  _sleep_fn.timestamp = time.time()
188
205
  self.assertBetween(sleep_time, 0.0, 33.33)
189
206
 
190
207
  mock_sleep.side_effect = _sleep_fn
191
208
 
209
+ @_with_timestamp
192
210
  def _step_fn(action):
193
211
  # On even calls the action should be the actual agent action, but on odd
194
212
  # calls they should be REPEATs.
@@ -212,6 +230,9 @@ class RateLimitWrapperTest(parameterized.TestCase):
212
230
  'touch_position': np.array([0.123, 0.456])
213
231
  })
214
232
 
233
+ # When the wrapper is enabled, base step should be called twice.
234
+ self.assertEqual(env.step.call_count, 2)
235
+
215
236
  # `step()` should be called twice: before `sleep()` and after it.
216
237
  self.assertLen(_step_fn.timestamps, 2)
217
238
  self.assertGreaterEqual(_sleep_fn.timestamp, _step_fn.timestamps[0])
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
15
15
 
16
16
  """Wraps the AndroidEnv environment to provide tap actions of a given duration."""
17
17
 
18
- from typing import Dict, Sequence
18
+ from collections.abc import Sequence
19
19
 
20
20
  from android_env.components import action_type
21
21
  from android_env.wrappers import base_wrapper
@@ -23,9 +23,6 @@ import dm_env
23
23
  import numpy as np
24
24
 
25
25
 
26
- ActionType = action_type.ActionType
27
-
28
-
29
26
  class TapActionWrapper(base_wrapper.BaseWrapper):
30
27
  """AndroidEnv with tap actions."""
31
28
 
@@ -46,33 +43,35 @@ class TapActionWrapper(base_wrapper.BaseWrapper):
46
43
  return logs
47
44
 
48
45
  def _process_action(
49
- self, action: Dict[str, np.ndarray]
50
- ) -> Sequence[Dict[str, np.ndarray]]:
51
-
46
+ self, action: dict[str, np.ndarray]
47
+ ) -> Sequence[dict[str, np.ndarray]]:
52
48
  if self._touch_only:
53
49
  assert action['action_type'] == 0
54
50
  touch_action = action.copy()
55
- touch_action['action_type'] = np.array(ActionType.TOUCH).astype(
56
- self.action_spec()['action_type'].dtype)
51
+ touch_action['action_type'] = np.array(
52
+ action_type.ActionType.TOUCH
53
+ ).astype(self.action_spec()['action_type'].dtype)
57
54
  actions = [touch_action] * self._num_frames
58
55
  lift_action = action.copy()
59
- lift_action['action_type'] = np.array(ActionType.LIFT).astype(
60
- self.action_spec()['action_type'].dtype)
56
+ lift_action['action_type'] = np.array(action_type.ActionType.LIFT).astype(
57
+ self.action_spec()['action_type'].dtype
58
+ )
61
59
  actions.append(lift_action)
62
60
 
63
61
  else:
64
- if action['action_type'] == ActionType.TOUCH:
62
+ if action['action_type'] == action_type.ActionType.TOUCH:
65
63
  actions = [action] * self._num_frames
66
64
  lift_action = action.copy()
67
- lift_action['action_type'] = np.array(ActionType.LIFT).astype(
68
- self.action_spec()['action_type'].dtype)
65
+ lift_action['action_type'] = np.array(
66
+ action_type.ActionType.LIFT
67
+ ).astype(self.action_spec()['action_type'].dtype)
69
68
  actions.append(lift_action)
70
69
  else:
71
70
  actions = [action] * (self._num_frames + 1)
72
71
 
73
72
  return actions
74
73
 
75
- def step(self, action: Dict[str, np.ndarray]) -> dm_env.TimeStep:
74
+ def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
76
75
  """Takes a step in the environment."""
77
76
  self._env_steps += self._num_frames + 1
78
77
  actions = self._process_action(action)
@@ -93,7 +92,7 @@ class TapActionWrapper(base_wrapper.BaseWrapper):
93
92
  discount=discount,
94
93
  observation=observation)
95
94
 
96
- def action_spec(self) -> Dict[str, dm_env.specs.Array]:
95
+ def action_spec(self) -> dict[str, dm_env.specs.Array]:
97
96
  if self._touch_only:
98
97
  return {
99
98
  'action_type':