android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. android_env/__init__.py +1 -1
  2. android_env/components/__init__.py +1 -1
  3. android_env/components/a11y/__init__.py +15 -0
  4. android_env/components/a11y/a11y_events.py +118 -0
  5. android_env/components/a11y/a11y_events_test.py +173 -0
  6. android_env/components/a11y/a11y_forests.py +128 -0
  7. android_env/components/a11y/a11y_forests_test.py +237 -0
  8. android_env/components/a11y/a11y_servicer.py +199 -0
  9. android_env/components/a11y/a11y_servicer_test.py +224 -0
  10. android_env/components/action_fns.py +132 -0
  11. android_env/components/action_fns_test.py +227 -0
  12. android_env/components/action_type.py +26 -3
  13. android_env/components/adb_call_parser.py +239 -196
  14. android_env/components/adb_call_parser_test.py +179 -209
  15. android_env/components/adb_controller.py +90 -52
  16. android_env/components/adb_controller_test.py +187 -16
  17. android_env/components/adb_log_stream.py +17 -5
  18. android_env/components/adb_log_stream_test.py +17 -3
  19. android_env/components/app_screen_checker.py +17 -15
  20. android_env/components/app_screen_checker_test.py +7 -8
  21. android_env/components/config_classes.py +203 -0
  22. android_env/components/coordinator.py +102 -338
  23. android_env/components/coordinator_test.py +59 -199
  24. android_env/components/device_settings.py +174 -0
  25. android_env/components/device_settings_test.py +228 -0
  26. android_env/components/dumpsys_thread.py +3 -4
  27. android_env/components/dumpsys_thread_test.py +1 -1
  28. android_env/components/errors.py +52 -10
  29. android_env/components/errors_test.py +110 -0
  30. android_env/components/log_stream.py +7 -5
  31. android_env/components/log_stream_test.py +1 -1
  32. android_env/components/logcat_thread.py +9 -8
  33. android_env/components/logcat_thread_test.py +3 -4
  34. android_env/components/{utils.py → pixel_fns.py} +20 -20
  35. android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
  36. android_env/components/setup_step_interpreter.py +47 -39
  37. android_env/components/setup_step_interpreter_test.py +4 -4
  38. android_env/components/simulators/__init__.py +1 -1
  39. android_env/components/simulators/base_simulator.py +116 -44
  40. android_env/components/simulators/base_simulator_test.py +131 -9
  41. android_env/components/simulators/emulator/__init__.py +1 -1
  42. android_env/components/simulators/emulator/emulator_launcher.py +67 -77
  43. android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
  44. android_env/components/simulators/emulator/emulator_simulator.py +276 -95
  45. android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
  46. android_env/components/simulators/fake/__init__.py +1 -1
  47. android_env/components/simulators/fake/fake_simulator.py +17 -25
  48. android_env/components/simulators/fake/fake_simulator_test.py +29 -12
  49. android_env/components/specs.py +18 -28
  50. android_env/components/specs_test.py +1 -44
  51. android_env/components/task_manager.py +48 -48
  52. android_env/components/task_manager_test.py +71 -60
  53. android_env/env_interface.py +37 -23
  54. android_env/environment.py +83 -51
  55. android_env/environment_test.py +68 -29
  56. android_env/loader.py +57 -43
  57. android_env/loader_test.py +115 -35
  58. android_env/proto/__init__.py +1 -1
  59. android_env/proto/a11y/__init__.py +15 -0
  60. android_env/proto/a11y/a11y.proto +75 -0
  61. android_env/proto/a11y/a11y_pb2.py +54 -0
  62. android_env/proto/a11y/a11y_pb2.pyi +49 -0
  63. android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
  64. android_env/proto/a11y/android_accessibility_action.proto +32 -0
  65. android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
  66. android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
  67. android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
  68. android_env/proto/a11y/android_accessibility_forest.proto +29 -0
  69. android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
  70. android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
  71. android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
  72. android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
  73. android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
  74. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
  75. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
  76. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
  77. android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
  78. android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
  79. android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
  80. android_env/proto/a11y/android_accessibility_tree.proto +29 -0
  81. android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
  82. android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
  83. android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
  84. android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
  85. android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
  86. android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
  87. android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
  88. android_env/proto/a11y/rect.proto +30 -0
  89. android_env/proto/a11y/rect_pb2.py +37 -0
  90. android_env/proto/a11y/rect_pb2.pyi +17 -0
  91. android_env/proto/a11y/rect_pb2_grpc.py +24 -0
  92. android_env/proto/adb.proto +17 -6
  93. android_env/proto/adb_pb2.py +120 -107
  94. android_env/proto/adb_pb2.pyi +396 -0
  95. android_env/proto/adb_pb2_grpc.py +20 -0
  96. android_env/proto/emulator_controller.proto +68 -63
  97. android_env/proto/emulator_controller_pb2.py +142 -131
  98. android_env/proto/emulator_controller_pb2.pyi +672 -0
  99. android_env/proto/emulator_controller_pb2_grpc.py +505 -142
  100. android_env/proto/snapshot.proto +169 -0
  101. android_env/proto/snapshot_pb2.py +47 -0
  102. android_env/proto/snapshot_pb2.pyi +117 -0
  103. android_env/proto/snapshot_pb2_grpc.py +24 -0
  104. android_env/proto/snapshot_service.proto +289 -0
  105. android_env/proto/snapshot_service_pb2.py +54 -0
  106. android_env/proto/snapshot_service_pb2.pyi +86 -0
  107. android_env/proto/snapshot_service_pb2_grpc.py +487 -0
  108. android_env/proto/state.proto +63 -0
  109. android_env/proto/state_pb2.py +63 -0
  110. android_env/proto/state_pb2.pyi +85 -0
  111. android_env/proto/state_pb2_grpc.py +24 -0
  112. android_env/proto/task.proto +5 -1
  113. android_env/proto/task_pb2.py +42 -31
  114. android_env/proto/task_pb2.pyi +160 -0
  115. android_env/proto/task_pb2_grpc.py +20 -0
  116. android_env/wrappers/__init__.py +1 -1
  117. android_env/wrappers/a11y_grpc_wrapper.py +500 -0
  118. android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
  119. android_env/wrappers/base_wrapper.py +34 -13
  120. android_env/wrappers/base_wrapper_test.py +22 -16
  121. android_env/wrappers/discrete_action_wrapper.py +18 -17
  122. android_env/wrappers/discrete_action_wrapper_test.py +4 -4
  123. android_env/wrappers/flat_interface_wrapper.py +5 -5
  124. android_env/wrappers/flat_interface_wrapper_test.py +7 -11
  125. android_env/wrappers/float_pixels_wrapper.py +9 -10
  126. android_env/wrappers/float_pixels_wrapper_test.py +3 -3
  127. android_env/wrappers/gym_wrapper.py +19 -13
  128. android_env/wrappers/gym_wrapper_test.py +3 -5
  129. android_env/wrappers/image_rescale_wrapper.py +18 -21
  130. android_env/wrappers/image_rescale_wrapper_test.py +25 -37
  131. android_env/wrappers/last_action_wrapper.py +16 -13
  132. android_env/wrappers/last_action_wrapper_test.py +44 -51
  133. android_env/wrappers/rate_limit_wrapper.py +6 -3
  134. android_env/wrappers/rate_limit_wrapper_test.py +22 -1
  135. android_env/wrappers/tap_action_wrapper.py +16 -17
  136. android_env/wrappers/tap_action_wrapper_test.py +51 -16
  137. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
  138. android_env-1.2.3.dist-info/RECORD +141 -0
  139. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
  140. android_env/proto/raw_observation.proto +0 -39
  141. android_env/proto/raw_observation_pb2.py +0 -27
  142. android_env/proto/raw_observation_pb2_grpc.py +0 -4
  143. android_env-1.2.1.dist-info/RECORD +0 -81
  144. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
  145. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,12 +15,12 @@
15
15
 
16
16
  """Base class for AndroidEnv wrappers."""
17
17
 
18
- from typing import Any, Dict
18
+ from typing import Any
19
19
 
20
20
  from absl import logging
21
21
  from android_env import env_interface
22
22
  from android_env.proto import adb_pb2
23
- from android_env.proto import task_pb2
23
+ from android_env.proto import state_pb2
24
24
  import dm_env
25
25
  from dm_env import specs
26
26
  import numpy as np
@@ -42,7 +42,7 @@ class BaseWrapper(env_interface.AndroidEnvInterface):
42
42
  action = self._process_action(action)
43
43
  return self._process_timestep(self._env.step(action))
44
44
 
45
- def task_extras(self, latest_only: bool = True) -> Dict[str, np.ndarray]:
45
+ def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
46
46
  return self._env.task_extras(latest_only=latest_only)
47
47
 
48
48
  def _reset_state(self):
@@ -54,31 +54,52 @@ class BaseWrapper(env_interface.AndroidEnvInterface):
54
54
  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
55
55
  return timestep
56
56
 
57
- def observation_spec(self) -> Dict[str, specs.Array]:
57
+ def observation_spec(self) -> dict[str, specs.Array]:
58
58
  return self._env.observation_spec()
59
59
 
60
- def action_spec(self) -> Dict[str, specs.Array]:
60
+ def action_spec(self) -> dict[str, specs.Array]:
61
61
  return self._env.action_spec()
62
62
 
63
- def task_extras_spec(self) -> Dict[str, specs.Array]:
64
- return self._env.task_extras_spec()
63
+ def reward_spec(self) -> specs.Array:
64
+ return self._env.reward_spec()
65
65
 
66
- def _wrapper_stats(self) -> Dict[str, Any]:
66
+ def discount_spec(self) -> specs.Array:
67
+ return self._env.discount_spec()
68
+
69
+ def _wrapper_stats(self) -> dict[str, Any]:
67
70
  """Add wrapper specific logging here."""
68
71
  return {}
69
72
 
70
- def stats(self) -> Dict[str, Any]:
73
+ def stats(self) -> dict[str, Any]:
71
74
  info = self._env.stats()
72
75
  info.update(self._wrapper_stats())
73
76
  return info
74
77
 
78
+ def load_state(
79
+ self, request: state_pb2.LoadStateRequest
80
+ ) -> state_pb2.LoadStateResponse:
81
+ """Loads a state."""
82
+ return self._env.load_state(request)
83
+
84
+ def save_state(
85
+ self, request: state_pb2.SaveStateRequest
86
+ ) -> state_pb2.SaveStateResponse:
87
+ """Saves a state.
88
+
89
+ Args:
90
+ request: A `SaveStateRequest` containing any parameters necessary to
91
+ specify how/what state to save.
92
+
93
+ Returns:
94
+ A `SaveStateResponse` containing the status, error message (if
95
+ applicable), and any other relevant information.
96
+ """
97
+ return self._env.save_state(request)
98
+
75
99
  def execute_adb_call(self,
76
100
  adb_call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
77
101
  return self._env.execute_adb_call(adb_call)
78
102
 
79
- def update_task(self, task: task_pb2.Task) -> bool:
80
- return self._env.update_task(task)
81
-
82
103
  @property
83
104
  def raw_action(self):
84
105
  return self._env.raw_action
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -19,8 +19,8 @@ from unittest import mock
19
19
 
20
20
  from absl import logging
21
21
  from absl.testing import absltest
22
- from android_env import environment
23
- from android_env.proto import task_pb2
22
+ from android_env import env_interface
23
+ from android_env.proto import state_pb2
24
24
  from android_env.wrappers import base_wrapper
25
25
 
26
26
 
@@ -28,7 +28,7 @@ class BaseWrapperTest(absltest.TestCase):
28
28
 
29
29
  @mock.patch.object(logging, 'info')
30
30
  def test_base_function_forwarding(self, mock_info):
31
- base_env = mock.create_autospec(environment.AndroidEnv)
31
+ base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
32
32
  wrapped_env = base_wrapper.BaseWrapper(base_env)
33
33
  mock_info.assert_called_with('Wrapping with %s', 'BaseWrapper')
34
34
 
@@ -58,11 +58,6 @@ class BaseWrapperTest(absltest.TestCase):
58
58
  self.assertEqual(fake_action_spec, wrapped_env.action_spec())
59
59
  base_env.action_spec.assert_called_once()
60
60
 
61
- fake_task_extras_spec = 'fake_task_extras_spec'
62
- base_env.task_extras_spec.return_value = fake_task_extras_spec
63
- self.assertEqual(fake_task_extras_spec, wrapped_env.task_extras_spec())
64
- base_env.task_extras_spec.assert_called_once()
65
-
66
61
  fake_raw_action = 'fake_raw_action'
67
62
  type(base_env).raw_action = mock.PropertyMock(return_value=fake_raw_action)
68
63
  self.assertEqual(fake_raw_action, wrapped_env.raw_action)
@@ -72,10 +67,21 @@ class BaseWrapperTest(absltest.TestCase):
72
67
  return_value=fake_raw_observation)
73
68
  self.assertEqual(fake_raw_observation, wrapped_env.raw_observation)
74
69
 
75
- task = task_pb2.Task(id='my_task')
76
- base_env.update_task.return_value = False
77
- self.assertFalse(wrapped_env.update_task(task))
78
- base_env.update_task.assert_called_once_with(task)
70
+ load_request = state_pb2.LoadStateRequest(args={})
71
+ expected_response = state_pb2.LoadStateResponse(
72
+ status=state_pb2.LoadStateResponse.Status.OK
73
+ )
74
+ base_env.load_state.return_value = expected_response
75
+ self.assertEqual(wrapped_env.load_state(load_request), expected_response)
76
+ base_env.load_state.assert_called_once_with(load_request)
77
+
78
+ save_request = state_pb2.SaveStateRequest(args={})
79
+ expected_response = state_pb2.SaveStateResponse(
80
+ status=state_pb2.SaveStateResponse.Status.OK
81
+ )
82
+ base_env.save_state.return_value = expected_response
83
+ self.assertEqual(wrapped_env.save_state(save_request), expected_response)
84
+ base_env.save_state.assert_called_once_with(save_request)
79
85
 
80
86
  wrapped_env.close()
81
87
  base_env.close.assert_called_once()
@@ -87,7 +93,7 @@ class BaseWrapperTest(absltest.TestCase):
87
93
  base_env.some_random_function.return_value = fake_return_value
88
94
 
89
95
  def test_multiple_wrappers(self):
90
- base_env = mock.create_autospec(environment.AndroidEnv)
96
+ base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
91
97
  wrapped_env_1 = base_wrapper.BaseWrapper(base_env)
92
98
  wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1)
93
99
 
@@ -101,7 +107,7 @@ class BaseWrapperTest(absltest.TestCase):
101
107
  self.assertEqual(base_env, wrapped_env_2.raw_env)
102
108
 
103
109
  def test_stats(self):
104
- base_env = mock.create_autospec(environment.AndroidEnv)
110
+ base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
105
111
  wrapped_env = base_wrapper.BaseWrapper(base_env)
106
112
  base_stats = {'base': 'stats'}
107
113
  base_env.stats.return_value = base_stats
@@ -109,7 +115,7 @@ class BaseWrapperTest(absltest.TestCase):
109
115
 
110
116
  @mock.patch.object(logging, 'info')
111
117
  def test_wrapped_stats(self, mock_info):
112
- base_env = mock.create_autospec(environment.AndroidEnv)
118
+ base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
113
119
 
114
120
  class LoggingWrapper1(base_wrapper.BaseWrapper):
115
121
 
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
15
15
 
16
16
  """Wraps the AndroidEnv environment to provide discrete actions."""
17
17
 
18
- from typing import Optional, Sequence, Dict
18
+ from collections.abc import Sequence
19
19
 
20
20
  from android_env.components import action_type
21
21
  from android_env.wrappers import base_wrapper
@@ -24,23 +24,24 @@ from dm_env import specs
24
24
  import numpy as np
25
25
 
26
26
 
27
- NOISE_CLIP_VALUE = 0.4999
27
+ _NOISE_CLIP_VALUE = 0.4999
28
28
 
29
29
 
30
30
  class DiscreteActionWrapper(base_wrapper.BaseWrapper):
31
31
  """AndroidEnv with discrete actions."""
32
32
 
33
- def __init__(self,
34
- env: dm_env.Environment,
35
- action_grid: Optional[Sequence[int]] = (10, 10),
36
- redundant_actions: bool = True,
37
- noise: float = 0.1):
38
-
33
+ def __init__(
34
+ self,
35
+ env: dm_env.Environment,
36
+ action_grid: Sequence[int] = (10, 10),
37
+ redundant_actions: bool = True,
38
+ noise: float = 0.1,
39
+ ):
39
40
  super().__init__(env)
40
41
  self._parent_action_spec = self._env.action_spec()
41
42
  self._assert_base_env()
42
43
  self._action_grid = action_grid # [height, width]
43
- self._grid_size = np.product(self._action_grid)
44
+ self._grid_size = np.prod(self._action_grid)
44
45
  self._num_action_types = self._parent_action_spec['action_type'].num_values
45
46
  self._redundant_actions = redundant_actions
46
47
  self._noise = noise
@@ -57,16 +58,16 @@ class DiscreteActionWrapper(base_wrapper.BaseWrapper):
57
58
  """Number of discrete actions."""
58
59
 
59
60
  if self._redundant_actions:
60
- return np.product(self._action_grid) * self._num_action_types
61
+ return self._grid_size * self._num_action_types
61
62
  else:
62
- return np.product(self._action_grid) + self._num_action_types - 1
63
+ return self._grid_size + self._num_action_types - 1
63
64
 
64
- def step(self, action: Dict[str, int]) -> dm_env.TimeStep:
65
+ def step(self, action: dict[str, int]) -> dm_env.TimeStep:
65
66
  """Take a step in the base environment."""
66
67
 
67
68
  return self._env.step(self._process_action(action))
68
69
 
69
- def _process_action(self, action: Dict[str, int]) -> Dict[str, np.ndarray]:
70
+ def _process_action(self, action: dict[str, int]) -> dict[str, np.ndarray]:
70
71
  """Transforms action so that it agrees with AndroidEnv's action spec."""
71
72
 
72
73
  return {
@@ -133,8 +134,8 @@ class DiscreteActionWrapper(base_wrapper.BaseWrapper):
133
134
  noise_y = np.random.normal(loc=0.0, scale=self._noise)
134
135
 
135
136
  # Noise is clipped so that the action will strictly stay in the cell.
136
- noise_x = max(min(noise_x, NOISE_CLIP_VALUE), -NOISE_CLIP_VALUE)
137
- noise_y = max(min(noise_y, NOISE_CLIP_VALUE), -NOISE_CLIP_VALUE)
137
+ noise_x = max(min(noise_x, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE)
138
+ noise_y = max(min(noise_y, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE)
138
139
 
139
140
  x_pos = (x_pos_grid + 0.5 + noise_x) / self._action_grid[1] # WIDTH
140
141
  y_pos = (y_pos_grid + 0.5 + noise_y) / self._action_grid[0] # HEIGHT
@@ -149,7 +150,7 @@ class DiscreteActionWrapper(base_wrapper.BaseWrapper):
149
150
 
150
151
  return [x_pos, y_pos]
151
152
 
152
- def action_spec(self) -> Dict[str, specs.Array]:
153
+ def action_spec(self) -> dict[str, specs.Array]:
153
154
  """Action spec of the wrapped environment."""
154
155
 
155
156
  return {
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
18
18
  from unittest import mock
19
19
 
20
20
  from absl.testing import absltest
21
- from android_env import environment
21
+ from android_env import env_interface
22
22
  from android_env.components import action_type as action_type_lib
23
23
  from android_env.wrappers import discrete_action_wrapper
24
24
  from dm_env import specs
@@ -64,7 +64,7 @@ class DiscreteActionWrapperTest(absltest.TestCase):
64
64
  'touch_position': _make_array_spec(
65
65
  shape=(2,), dtype=np.float32, name='touch_position'),
66
66
  }
67
- self.base_env = mock.create_autospec(environment.AndroidEnv)
67
+ self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
68
68
  self.base_env.action_spec.return_value = self._base_action_spec
69
69
 
70
70
  def test_num_actions(self):
@@ -295,7 +295,7 @@ class DiscreteActionWrapperTest(absltest.TestCase):
295
295
  'touch_position': _make_array_spec(
296
296
  shape=(2,), dtype=np.float64, name='touch_position'),
297
297
  }
298
- base_env = mock.create_autospec(environment.AndroidEnv)
298
+ base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
299
299
  base_env.action_spec.return_value = base_action_spec
300
300
 
301
301
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
15
15
 
16
16
  """Wraps the AndroidEnv environment to make its interface flat."""
17
17
 
18
- from typing import Union, Dict, Any
18
+ from typing import Any
19
19
 
20
20
  from android_env.wrappers import base_wrapper
21
21
  import dm_env
@@ -74,7 +74,7 @@ class FlatInterfaceWrapper(base_wrapper.BaseWrapper):
74
74
  assert isinstance(base_action_spec, dict)
75
75
  assert isinstance(base_action_spec[self._action_name], specs.BoundedArray)
76
76
 
77
- def _process_action(self, action: Union[int, np.ndarray, Dict[str, Any]]):
77
+ def _process_action(self, action: int | np.ndarray | dict[str, Any]):
78
78
  if self._flat_actions:
79
79
  return {self._action_name: action}
80
80
  else:
@@ -103,7 +103,7 @@ class FlatInterfaceWrapper(base_wrapper.BaseWrapper):
103
103
  timestep = self._env.step(self._process_action(action))
104
104
  return self._process_timestep(timestep)
105
105
 
106
- def observation_spec(self) -> Union[specs.Array, Dict[str, specs.Array]]:
106
+ def observation_spec(self) -> specs.Array | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks
107
107
  if self._flat_observations:
108
108
  pixels_spec = self._env.observation_spec()['pixels']
109
109
  if not self._keep_action_layer:
@@ -112,7 +112,7 @@ class FlatInterfaceWrapper(base_wrapper.BaseWrapper):
112
112
  else:
113
113
  return self._env.observation_spec()
114
114
 
115
- def action_spec(self) -> Union[specs.Array, Dict[str, specs.Array]]:
115
+ def action_spec(self) -> specs.BoundedArray | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks
116
116
  if self._flat_actions:
117
117
  return self._env.action_spec()[self._action_name]
118
118
  else:
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
15
15
 
16
16
  """Tests for android_env.wrappers.flat_interface_wrapper."""
17
17
 
18
+ from typing import cast
18
19
  from unittest import mock
19
20
 
20
21
  from absl.testing import absltest
@@ -33,10 +34,6 @@ def _make_array_spec(shape, dtype=np.float32, name=None, maximum=3, minimum=0):
33
34
  minimum=np.ones(shape) * minimum)
34
35
 
35
36
 
36
- def _make_discrete_array_spec(name, num_values):
37
- return specs.DiscreteArray(name=name, num_values=num_values)
38
-
39
-
40
37
  def _make_timestep(observation):
41
38
  return dm_env.TimeStep(
42
39
  step_type='fake_step_type',
@@ -51,9 +48,8 @@ class FlatInterfaceWrapperTest(absltest.TestCase):
51
48
  def setUp(self):
52
49
  super().setUp()
53
50
  self.action_shape = (1,)
54
- self.base_action_spec = {
55
- 'action_id': _make_discrete_array_spec(
56
- name='action_id', num_values=4)
51
+ self.base_action_spec: dict[str, specs.DiscreteArray] = {
52
+ 'action_id': specs.DiscreteArray(name='action_id', num_values=4)
57
53
  }
58
54
  self.int_obs_shape = (3, 4, 2)
59
55
  self.float_obs_shape = (2,)
@@ -148,10 +144,10 @@ class FlatInterfaceWrapperTest(absltest.TestCase):
148
144
 
149
145
  def test_action_spec(self):
150
146
  wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
151
- action_spec = wrapped_env.action_spec()
152
- base_action_spec = self.base_action_spec['action_id']
147
+ action_spec = cast(specs.BoundedArray, wrapped_env.action_spec())
148
+ parent_action_spec = self.base_action_spec['action_id']
153
149
 
154
- self.assertEqual(base_action_spec.name, action_spec.name)
150
+ self.assertEqual(parent_action_spec.name, action_spec.name)
155
151
  self.assertEqual((), action_spec.shape)
156
152
  self.assertEqual(np.int32, action_spec.dtype)
157
153
  self.assertEqual(0, action_spec.minimum)
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,9 +15,7 @@
15
15
 
16
16
  """Converts pixel observation to from int to float32 between 0.0 and 1.0."""
17
17
 
18
- from typing import Dict
19
-
20
- from android_env.components import utils
18
+ from android_env.components import pixel_fns
21
19
  from android_env.wrappers import base_wrapper
22
20
  import dm_env
23
21
  from dm_env import specs
@@ -34,11 +32,12 @@ class FloatPixelsWrapper(base_wrapper.BaseWrapper):
34
32
  np.integer)
35
33
 
36
34
  def _process_observation(
37
- self, observation: Dict[str, np.ndarray]
38
- ) -> Dict[str, np.ndarray]:
35
+ self, observation: dict[str, np.ndarray]
36
+ ) -> dict[str, np.ndarray]:
39
37
  if self._should_convert_int_to_float:
40
- float_pixels = utils.convert_int_to_float(observation['pixels'],
41
- self._input_spec, np.float32)
38
+ float_pixels = pixel_fns.convert_int_to_float(
39
+ observation['pixels'], self._input_spec
40
+ )
42
41
  observation['pixels'] = float_pixels
43
42
  return observation
44
43
 
@@ -53,10 +52,10 @@ class FloatPixelsWrapper(base_wrapper.BaseWrapper):
53
52
  def reset(self) -> dm_env.TimeStep:
54
53
  return self._process_timestep(self._env.reset())
55
54
 
56
- def step(self, action: Dict[str, np.ndarray]) -> dm_env.TimeStep:
55
+ def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
57
56
  return self._process_timestep(self._env.step(action))
58
57
 
59
- def observation_spec(self) -> Dict[str, specs.Array]:
58
+ def observation_spec(self) -> dict[str, specs.Array]:
60
59
  if self._should_convert_int_to_float:
61
60
  observation_spec = self._env.observation_spec()
62
61
  observation_spec['pixels'] = specs.BoundedArray(
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -96,7 +96,7 @@ class FloatPixelsWrapperTest(absltest.TestCase):
96
96
 
97
97
  def test_float_pixels_wrapper_step(self):
98
98
  wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env)
99
- ts = wrapped_env.step('fake_action')
99
+ ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])})
100
100
 
101
101
  self.assertEqual(self.base_timestep.step_type, ts.step_type)
102
102
  self.assertEqual(self.base_timestep.reward, ts.reward)
@@ -141,7 +141,7 @@ class FloatPixelsWrapperTest(absltest.TestCase):
141
141
  # The wrapper should not touch the timestep in this case.
142
142
  fake_timestep = ('step_type', 'reward', 'discount', 'obs')
143
143
  base_env.step.return_value = fake_timestep
144
- ts = wrapped_env.step('fake_action')
144
+ ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])})
145
145
  self.assertEqual(fake_timestep, ts)
146
146
 
147
147
 
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
15
15
 
16
16
  """Wraps the AndroidEnv to expose an OpenAI Gym interface."""
17
17
 
18
- from typing import Any, Dict, Tuple
18
+ from typing import Any
19
19
 
20
20
  from android_env.wrappers import base_wrapper
21
21
  import dm_env
@@ -25,16 +25,16 @@ from gym import spaces
25
25
  import numpy as np
26
26
 
27
27
 
28
- class GymInterfaceWrapper(base_wrapper.BaseWrapper, gym.Env):
28
+ class GymInterfaceWrapper(gym.Env):
29
29
  """AndroidEnv with OpenAI Gym interface."""
30
30
 
31
31
  def __init__(self, env: dm_env.Environment):
32
-
33
- base_wrapper.BaseWrapper.__init__(self, env)
32
+ self._env = env
34
33
  self.spec = None
35
34
  self.action_space = self._spec_to_space(self._env.action_spec())
36
35
  self.observation_space = self._spec_to_space(self._env.observation_spec())
37
36
  self.metadata = {'render.modes': ['rgb_array']}
37
+ self._latest_observation = None
38
38
 
39
39
  def _spec_to_space(self, spec: specs.Array) -> spaces.Space:
40
40
  """Converts dm_env specs to OpenAI Gym spaces."""
@@ -44,7 +44,8 @@ class GymInterfaceWrapper(base_wrapper.BaseWrapper, gym.Env):
44
44
 
45
45
  if isinstance(spec, dict):
46
46
  return spaces.Dict(
47
- {name: self._spec_to_space(s) for name, s in spec.items()})
47
+ {name: self._spec_to_space(s) for name, s in spec.items()}
48
+ )
48
49
 
49
50
  if isinstance(spec, specs.DiscreteArray):
50
51
  return spaces.Box(
@@ -61,11 +62,13 @@ class GymInterfaceWrapper(base_wrapper.BaseWrapper, gym.Env):
61
62
  high=spec.maximum)
62
63
 
63
64
  if isinstance(spec, specs.Array):
64
- return spaces.Box(
65
- shape=spec.shape,
66
- dtype=spec.dtype,
67
- low=-np.inf,
68
- high=np.inf)
65
+ if spec.dtype == np.uint8:
66
+ low = 0
67
+ high = 255
68
+ else:
69
+ low = -np.inf
70
+ high = np.inf
71
+ return spaces.Box(shape=spec.shape, dtype=spec.dtype, low=low, high=high)
69
72
 
70
73
  raise ValueError('Unknown type for specs: {}'.format(spec))
71
74
 
@@ -74,18 +77,21 @@ class GymInterfaceWrapper(base_wrapper.BaseWrapper, gym.Env):
74
77
  if mode == 'rgb_array':
75
78
  if self._latest_observation is None:
76
79
  return
80
+
77
81
  return self._latest_observation['pixels']
78
82
  else:
79
83
  raise ValueError('Only supported render mode is rgb_array.')
80
84
 
81
85
  def reset(self) -> np.ndarray:
86
+ self._latest_observation = None
82
87
  timestep = self._env.reset()
83
88
  return timestep.observation
84
89
 
85
- def step(self, action: Dict[str, int]) -> Tuple[Any, ...]:
90
+ def step(self, action: dict[str, int]) -> tuple[Any, ...]:
86
91
  """Take a step in the base environment."""
87
- timestep = self._env.step(self._process_action(action))
92
+ timestep = self._env.step(action)
88
93
  observation = timestep.observation
94
+ self._latest_observation = observation
89
95
  reward = timestep.reward
90
96
  done = timestep.step_type == dm_env.StepType.LAST
91
97
  info = {'discount': timestep.discount}
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
18
18
  from unittest import mock
19
19
 
20
20
  from absl.testing import absltest
21
- from android_env import environment
21
+ from android_env import env_interface
22
22
  from android_env.wrappers import gym_wrapper
23
23
  import dm_env
24
24
  from dm_env import specs
@@ -30,7 +30,7 @@ class GymInterfaceWrapperTest(absltest.TestCase):
30
30
 
31
31
  def setUp(self):
32
32
  super().setUp()
33
- self._base_env = mock.create_autospec(environment.AndroidEnv)
33
+ self._base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
34
34
  self._base_env.action_spec.return_value = {
35
35
  'action_type':
36
36
  specs.DiscreteArray(
@@ -68,7 +68,6 @@ class GymInterfaceWrapperTest(absltest.TestCase):
68
68
  def test_render(self):
69
69
  self._base_env.step.return_value = self._fake_ts
70
70
  _ = self._wrapped_env.step(action=np.zeros(shape=(1,)))
71
- self._base_env._latest_observation = {'pixels': np.ones(shape=(2, 3))}
72
71
  image = self._wrapped_env.render(mode='rgb_array')
73
72
  self.assertTrue(np.array_equal(image, np.ones(shape=(2, 3))))
74
73
 
@@ -90,7 +89,6 @@ class GymInterfaceWrapperTest(absltest.TestCase):
90
89
  self._base_env.step.return_value = self._fake_ts
91
90
  obs, _, _, _ = self._wrapped_env.step(action=np.zeros(shape=(1,)))
92
91
  self._base_env.step.assert_called_once()
93
- print(obs)
94
92
  self.assertTrue(np.array_equal(obs['pixels'], np.ones(shape=(2, 3))))
95
93
 
96
94
  def test_spec_to_space(self):