android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. android_env/__init__.py +1 -1
  2. android_env/components/__init__.py +1 -1
  3. android_env/components/a11y/__init__.py +15 -0
  4. android_env/components/a11y/a11y_events.py +118 -0
  5. android_env/components/a11y/a11y_events_test.py +173 -0
  6. android_env/components/a11y/a11y_forests.py +128 -0
  7. android_env/components/a11y/a11y_forests_test.py +237 -0
  8. android_env/components/a11y/a11y_servicer.py +199 -0
  9. android_env/components/a11y/a11y_servicer_test.py +224 -0
  10. android_env/components/action_fns.py +132 -0
  11. android_env/components/action_fns_test.py +227 -0
  12. android_env/components/action_type.py +26 -3
  13. android_env/components/adb_call_parser.py +239 -196
  14. android_env/components/adb_call_parser_test.py +179 -209
  15. android_env/components/adb_controller.py +90 -52
  16. android_env/components/adb_controller_test.py +187 -16
  17. android_env/components/adb_log_stream.py +17 -5
  18. android_env/components/adb_log_stream_test.py +17 -3
  19. android_env/components/app_screen_checker.py +17 -15
  20. android_env/components/app_screen_checker_test.py +7 -8
  21. android_env/components/config_classes.py +203 -0
  22. android_env/components/coordinator.py +102 -338
  23. android_env/components/coordinator_test.py +59 -199
  24. android_env/components/device_settings.py +174 -0
  25. android_env/components/device_settings_test.py +228 -0
  26. android_env/components/dumpsys_thread.py +3 -4
  27. android_env/components/dumpsys_thread_test.py +1 -1
  28. android_env/components/errors.py +52 -10
  29. android_env/components/errors_test.py +110 -0
  30. android_env/components/log_stream.py +7 -5
  31. android_env/components/log_stream_test.py +1 -1
  32. android_env/components/logcat_thread.py +9 -8
  33. android_env/components/logcat_thread_test.py +3 -4
  34. android_env/components/{utils.py → pixel_fns.py} +20 -20
  35. android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
  36. android_env/components/setup_step_interpreter.py +47 -39
  37. android_env/components/setup_step_interpreter_test.py +4 -4
  38. android_env/components/simulators/__init__.py +1 -1
  39. android_env/components/simulators/base_simulator.py +116 -44
  40. android_env/components/simulators/base_simulator_test.py +131 -9
  41. android_env/components/simulators/emulator/__init__.py +1 -1
  42. android_env/components/simulators/emulator/emulator_launcher.py +67 -77
  43. android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
  44. android_env/components/simulators/emulator/emulator_simulator.py +276 -95
  45. android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
  46. android_env/components/simulators/fake/__init__.py +1 -1
  47. android_env/components/simulators/fake/fake_simulator.py +17 -25
  48. android_env/components/simulators/fake/fake_simulator_test.py +29 -12
  49. android_env/components/specs.py +18 -28
  50. android_env/components/specs_test.py +1 -44
  51. android_env/components/task_manager.py +48 -48
  52. android_env/components/task_manager_test.py +71 -60
  53. android_env/env_interface.py +37 -23
  54. android_env/environment.py +83 -51
  55. android_env/environment_test.py +68 -29
  56. android_env/loader.py +57 -43
  57. android_env/loader_test.py +115 -35
  58. android_env/proto/__init__.py +1 -1
  59. android_env/proto/a11y/__init__.py +15 -0
  60. android_env/proto/a11y/a11y.proto +75 -0
  61. android_env/proto/a11y/a11y_pb2.py +54 -0
  62. android_env/proto/a11y/a11y_pb2.pyi +49 -0
  63. android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
  64. android_env/proto/a11y/android_accessibility_action.proto +32 -0
  65. android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
  66. android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
  67. android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
  68. android_env/proto/a11y/android_accessibility_forest.proto +29 -0
  69. android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
  70. android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
  71. android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
  72. android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
  73. android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
  74. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
  75. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
  76. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
  77. android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
  78. android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
  79. android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
  80. android_env/proto/a11y/android_accessibility_tree.proto +29 -0
  81. android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
  82. android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
  83. android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
  84. android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
  85. android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
  86. android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
  87. android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
  88. android_env/proto/a11y/rect.proto +30 -0
  89. android_env/proto/a11y/rect_pb2.py +37 -0
  90. android_env/proto/a11y/rect_pb2.pyi +17 -0
  91. android_env/proto/a11y/rect_pb2_grpc.py +24 -0
  92. android_env/proto/adb.proto +17 -6
  93. android_env/proto/adb_pb2.py +120 -107
  94. android_env/proto/adb_pb2.pyi +396 -0
  95. android_env/proto/adb_pb2_grpc.py +20 -0
  96. android_env/proto/emulator_controller.proto +68 -63
  97. android_env/proto/emulator_controller_pb2.py +142 -131
  98. android_env/proto/emulator_controller_pb2.pyi +672 -0
  99. android_env/proto/emulator_controller_pb2_grpc.py +505 -142
  100. android_env/proto/snapshot.proto +169 -0
  101. android_env/proto/snapshot_pb2.py +47 -0
  102. android_env/proto/snapshot_pb2.pyi +117 -0
  103. android_env/proto/snapshot_pb2_grpc.py +24 -0
  104. android_env/proto/snapshot_service.proto +289 -0
  105. android_env/proto/snapshot_service_pb2.py +54 -0
  106. android_env/proto/snapshot_service_pb2.pyi +86 -0
  107. android_env/proto/snapshot_service_pb2_grpc.py +487 -0
  108. android_env/proto/state.proto +63 -0
  109. android_env/proto/state_pb2.py +63 -0
  110. android_env/proto/state_pb2.pyi +85 -0
  111. android_env/proto/state_pb2_grpc.py +24 -0
  112. android_env/proto/task.proto +5 -1
  113. android_env/proto/task_pb2.py +42 -31
  114. android_env/proto/task_pb2.pyi +160 -0
  115. android_env/proto/task_pb2_grpc.py +20 -0
  116. android_env/wrappers/__init__.py +1 -1
  117. android_env/wrappers/a11y_grpc_wrapper.py +500 -0
  118. android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
  119. android_env/wrappers/base_wrapper.py +34 -13
  120. android_env/wrappers/base_wrapper_test.py +22 -16
  121. android_env/wrappers/discrete_action_wrapper.py +18 -17
  122. android_env/wrappers/discrete_action_wrapper_test.py +4 -4
  123. android_env/wrappers/flat_interface_wrapper.py +5 -5
  124. android_env/wrappers/flat_interface_wrapper_test.py +7 -11
  125. android_env/wrappers/float_pixels_wrapper.py +9 -10
  126. android_env/wrappers/float_pixels_wrapper_test.py +3 -3
  127. android_env/wrappers/gym_wrapper.py +19 -13
  128. android_env/wrappers/gym_wrapper_test.py +3 -5
  129. android_env/wrappers/image_rescale_wrapper.py +18 -21
  130. android_env/wrappers/image_rescale_wrapper_test.py +25 -37
  131. android_env/wrappers/last_action_wrapper.py +16 -13
  132. android_env/wrappers/last_action_wrapper_test.py +44 -51
  133. android_env/wrappers/rate_limit_wrapper.py +6 -3
  134. android_env/wrappers/rate_limit_wrapper_test.py +22 -1
  135. android_env/wrappers/tap_action_wrapper.py +16 -17
  136. android_env/wrappers/tap_action_wrapper_test.py +51 -16
  137. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
  138. android_env-1.2.3.dist-info/RECORD +141 -0
  139. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
  140. android_env/proto/raw_observation.proto +0 -39
  141. android_env/proto/raw_observation_pb2.py +0 -27
  142. android_env/proto/raw_observation_pb2_grpc.py +0 -4
  143. android_env-1.2.1.dist-info/RECORD +0 -81
  144. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
  145. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -13,11 +13,11 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- """Tests for android_env.components.utils."""
16
+ """Tests for pixel_fns."""
17
17
 
18
18
  from absl.testing import absltest
19
19
  from absl.testing import parameterized
20
- from android_env.components import utils
20
+ from android_env.components import pixel_fns
21
21
  from dm_env import specs
22
22
  import numpy as np
23
23
 
@@ -32,38 +32,42 @@ class UtilsTest(parameterized.TestCase):
32
32
  )
33
33
  def test_touch_position_to_pixel_position(
34
34
  self, touch_pos, width_height, pixel_pos):
35
- self.assertEqual(utils.touch_position_to_pixel_position(
36
- np.array(touch_pos), width_height), pixel_pos)
35
+ self.assertEqual(
36
+ pixel_fns.touch_position_to_pixel_position(
37
+ np.array(touch_pos), width_height
38
+ ),
39
+ pixel_pos,
40
+ )
37
41
 
38
42
  def test_transpose_pixels(self):
39
43
  image = np.reshape(np.array(range(12)), (3, 2, 2))
40
44
  expected = [[[0, 1], [4, 5], [8, 9]], [[2, 3], [6, 7], [10, 11]]]
41
- self.assertEqual(utils.transpose_pixels(image).shape, (2, 3, 2))
42
- self.assertTrue((utils.transpose_pixels(image) == expected).all())
45
+ self.assertEqual(pixel_fns.transpose_pixels(image).shape, (2, 3, 2))
46
+ self.assertTrue((pixel_fns.transpose_pixels(image) == expected).all())
43
47
 
44
48
  def test_orient_pixels(self):
45
49
  image = np.reshape(np.array(range(12)), (3, 2, 2))
46
50
 
47
51
  expected_90 = [[[8, 9], [4, 5], [0, 1]], [[10, 11], [6, 7], [2, 3]]]
48
52
  rot_90 = 1 # LANDSCAPE_90
49
- rotated = utils.orient_pixels(image, rot_90)
53
+ rotated = pixel_fns.orient_pixels(image, rot_90)
50
54
  self.assertEqual(rotated.shape, (2, 3, 2))
51
55
  self.assertTrue((rotated == expected_90).all())
52
56
 
53
57
  expected_180 = [[[10, 11], [8, 9]], [[6, 7], [4, 5]], [[2, 3], [0, 1]]]
54
58
  rot_180 = 2 # PORTRAIT_180
55
- rotated = utils.orient_pixels(image, rot_180)
59
+ rotated = pixel_fns.orient_pixels(image, rot_180)
56
60
  self.assertEqual(rotated.shape, (3, 2, 2))
57
61
  self.assertTrue((rotated == expected_180).all())
58
62
 
59
63
  expected_270 = [[[2, 3], [6, 7], [10, 11]], [[0, 1], [4, 5], [8, 9]]]
60
64
  rot_270 = 3 # LANDSCAPE_270
61
- rotated = utils.orient_pixels(image, rot_270)
65
+ rotated = pixel_fns.orient_pixels(image, rot_270)
62
66
  self.assertEqual(rotated.shape, (2, 3, 2))
63
67
  self.assertTrue((rotated == expected_270).all())
64
68
 
65
69
  rot_0 = 0 # PORTRAIT_0
66
- rotated = utils.orient_pixels(image, rot_0)
70
+ rotated = pixel_fns.orient_pixels(image, rot_0)
67
71
  self.assertEqual(rotated.shape, (3, 2, 2))
68
72
  self.assertTrue((rotated == image).all())
69
73
 
@@ -75,15 +79,16 @@ class UtilsTest(parameterized.TestCase):
75
79
  maximum=[5, 5, 20, 2],
76
80
  name='bounded_array')
77
81
  data = np.array([2, 2, 10, 0], dtype=np.int32)
78
- float_data = utils.convert_int_to_float(data, spec, np.float64)
82
+ float_data = pixel_fns.convert_int_to_float(data, spec)
79
83
  np.testing.assert_equal(
80
- np.array([2. / 5., 1. / 4., 0., 0.5], dtype=np.float64), float_data)
84
+ np.array([2.0 / 5.0, 1.0 / 4.0, 0.0, 0.5], dtype=np.float32), float_data
85
+ )
81
86
 
82
87
  def test_convert_int_to_float_bounded_array_broadcast(self):
83
88
  spec = specs.BoundedArray(
84
89
  shape=(3,), dtype=np.int16, minimum=2, maximum=4, name='bounded_array')
85
90
  data = np.array([2, 3, 4], dtype=np.int16)
86
- float_data = utils.convert_int_to_float(data, spec, np.float32)
91
+ float_data = pixel_fns.convert_int_to_float(data, spec)
87
92
  np.testing.assert_equal(
88
93
  np.array([0.0, 0.5, 1.0], dtype=np.float32), float_data)
89
94
 
@@ -93,7 +98,7 @@ class UtilsTest(parameterized.TestCase):
93
98
  dtype=np.int8, # int8 implies min=-128, max=127
94
99
  name='bounded_array')
95
100
  data = np.array([-128, 0, 127], dtype=np.int16)
96
- float_data = utils.convert_int_to_float(data, spec, np.float32)
101
+ float_data = pixel_fns.convert_int_to_float(data, spec)
97
102
  np.testing.assert_equal(
98
103
  np.array([0.0, 128. / 255., 1.0], dtype=np.float32), float_data)
99
104
 
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,9 +15,10 @@
15
15
 
16
16
  """A component that parses and processes SetupSteps."""
17
17
 
18
+ from collections.abc import Sequence
18
19
  import copy
19
20
  import time
20
- from typing import Any, Dict, Optional, Sequence
21
+ from typing import Any
21
22
 
22
23
  from absl import logging
23
24
  from android_env.components import adb_call_parser as adb_call_parser_lib
@@ -27,7 +28,7 @@ from android_env.proto import adb_pb2
27
28
  from android_env.proto import task_pb2
28
29
 
29
30
 
30
- class SetupStepInterpreter():
31
+ class SetupStepInterpreter:
31
32
  """An interpreter for SetupSteps."""
32
33
 
33
34
  def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser):
@@ -45,7 +46,7 @@ class SetupStepInterpreter():
45
46
  'total_time_waiting_for_app_screen': 0
46
47
  }
47
48
 
48
- def stats(self) -> Dict[str, Any]:
49
+ def stats(self) -> dict[str, Any]:
49
50
  return copy.deepcopy(self._stats)
50
51
 
51
52
  def interpret(self, setup_steps: Sequence[task_pb2.SetupStep]) -> None:
@@ -110,45 +111,52 @@ class SetupStepInterpreter():
110
111
  f'Step failed: [{step_cmd}]') from latest_error
111
112
 
112
113
  def _execute_step_cmd(
113
- self, step_cmd: task_pb2.SetupStep,
114
- step_type: Optional[str]) -> Optional[adb_pb2.AdbResponse]:
114
+ self, step_cmd: task_pb2.SetupStep, step_type: str | None
115
+ ) -> adb_pb2.AdbResponse | None:
115
116
  """Executes a step command of given type."""
116
- if not step_type:
117
- return
118
117
 
119
- if step_type == 'sleep':
120
- time.sleep(step_cmd.sleep.time_sec)
121
- elif step_type == 'adb_request':
122
- response = self._adb_call_parser.parse(step_cmd.adb_request)
123
- if response.status != adb_pb2.AdbResponse.Status.OK:
124
- raise errors.AdbControllerError(
125
- f'Failed to execute AdbRequest [{step_cmd.adb_request}].\n'
126
- f'Status: {response.status}\n'
127
- f'Error: {response.error_message}')
128
- return response
129
- else:
130
- raise NotImplementedError('No step command of type [%s].' % step_type)
131
-
132
- def _check_success(self,
133
- success_check: Optional[str],
134
- success_condition: task_pb2.SuccessCondition) -> None:
118
+ match step_type:
119
+ case None:
120
+ return None
121
+ case 'sleep':
122
+ time.sleep(step_cmd.sleep.time_sec)
123
+ return None
124
+ case 'adb_request':
125
+ response = self._adb_call_parser.parse(step_cmd.adb_request)
126
+ if response.status != adb_pb2.AdbResponse.Status.OK:
127
+ raise errors.AdbControllerError(
128
+ f'Failed to execute AdbRequest [{step_cmd.adb_request}].\n'
129
+ f'Status: {response.status}\n'
130
+ f'Error: {response.error_message}'
131
+ )
132
+ return response
133
+ case _:
134
+ raise NotImplementedError(f'No step command of type [{step_type}].')
135
+
136
+ def _check_success(
137
+ self,
138
+ success_check: str | None,
139
+ success_condition: task_pb2.SuccessCondition,
140
+ ) -> None:
135
141
  """Checks whether the given success condition was met."""
136
142
 
137
- if not success_check:
138
- return
139
-
140
- if success_check == 'wait_for_app_screen':
141
- wait_for_app_screen = success_condition.wait_for_app_screen
142
- screen_checker = app_screen_checker.AppScreenChecker(
143
- adb_call_parser=self._adb_call_parser,
144
- expected_app_screen=wait_for_app_screen.app_screen)
145
- wait_time = screen_checker.wait_for_app_screen(
146
- timeout_sec=wait_for_app_screen.timeout_sec)
147
- self._stats['total_time_waiting_for_app_screen'] += wait_time
148
- elif success_check == 'check_install':
149
- self._check_install(success_condition.check_install)
150
- else:
151
- raise NotImplementedError('No success check called [%s].' % success_check)
143
+ match success_check:
144
+ case None:
145
+ return None
146
+ case 'wait_for_app_screen':
147
+ wait_for_app_screen = success_condition.wait_for_app_screen
148
+ screen_checker = app_screen_checker.AppScreenChecker(
149
+ adb_call_parser=self._adb_call_parser,
150
+ expected_app_screen=wait_for_app_screen.app_screen,
151
+ )
152
+ wait_time = screen_checker.wait_for_app_screen(
153
+ timeout_sec=wait_for_app_screen.timeout_sec
154
+ )
155
+ self._stats['total_time_waiting_for_app_screen'] += wait_time
156
+ case 'check_install':
157
+ self._check_install(success_condition.check_install)
158
+ case _:
159
+ raise NotImplementedError(f'No success check called [{success_check}].')
152
160
 
153
161
  def _check_install(self, check_install: task_pb2.CheckInstall) -> None:
154
162
  """Checks that the given package is installed."""
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -57,14 +57,14 @@ class SetupStepInterpreterTest(absltest.TestCase):
57
57
  interpreter = setup_step_interpreter.SetupStepInterpreter(
58
58
  adb_call_parser=self._parser)
59
59
  # Empty setup steps should be ignored.
60
- interpreter.interpret([None])
60
+ interpreter.interpret([])
61
61
 
62
62
  def test_invalid_setup_step(self):
63
63
  interpreter = setup_step_interpreter.SetupStepInterpreter(
64
64
  adb_call_parser=self._parser)
65
65
  # Empty setup steps should be ignored.
66
- with self.assertRaises(AssertionError):
67
- interpreter.interpret([_to_proto(task_pb2.SetupStep, '')])
66
+ self.assertRaises(AssertionError, interpreter.interpret,
67
+ [task_pb2.SetupStep()])
68
68
 
69
69
  def test_adb_install_apk_filesystem(self):
70
70
  self._parser.parse.return_value = adb_pb2.AdbResponse(
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -16,34 +16,23 @@
16
16
  """A base class for talking to different types of Android simulators."""
17
17
 
18
18
  import abc
19
- from typing import List, Tuple
19
+ from collections.abc import Callable
20
+ import threading
21
+ import time
20
22
 
21
23
  from absl import logging
22
24
  from android_env.components import adb_controller
25
+ from android_env.components import config_classes
23
26
  from android_env.components import errors
24
27
  from android_env.components import log_stream
28
+ from android_env.proto import state_pb2
25
29
  import numpy as np
26
30
 
27
31
 
28
- def _print_logs_on_exception(func):
29
- """Decorator function for printing simulator logs upon any exception."""
30
- def wrapper(*args, **kwargs):
31
- try:
32
- return func(*args, **kwargs)
33
- except Exception as error:
34
- # Calls self.get_logs since self is the first arg.
35
- for line in args[0].get_logs().splitlines():
36
- logging.error(line)
37
- raise errors.SimulatorError(
38
- 'Exception caught in simulator. Please see the simulator logs '
39
- 'above for more details.') from error
40
- return wrapper
41
-
42
-
43
32
  class BaseSimulator(metaclass=abc.ABCMeta):
44
33
  """An interface for communicating with an Android simulator."""
45
34
 
46
- def __init__(self, verbose_logs: bool = False):
35
+ def __init__(self, config: config_classes.SimulatorConfig):
47
36
  """Instantiates a BaseSimulator object.
48
37
 
49
38
  The simulator may be an emulator, virtual machine or even a physical device.
@@ -51,14 +40,14 @@ class BaseSimulator(metaclass=abc.ABCMeta):
51
40
  bookkeeping.
52
41
 
53
42
  Args:
54
- verbose_logs: If true, the log stream of the simulator will be verbose.
43
+ config: Settings for this simulator.
55
44
  """
56
45
 
57
- self._verbose_logs = verbose_logs
58
- self._launched = False
46
+ self._config = config
47
+ self._interaction_thread: InteractionThread | None = None
59
48
 
60
- def is_launched(self) -> bool:
61
- return self._launched
49
+ # An increasing number that tracks the attempt at launching the simulator.
50
+ self._num_launch_attempts: int = 0
62
51
 
63
52
  def get_logs(self) -> str:
64
53
  """Returns logs recorded by the simulator (if provided)."""
@@ -76,20 +65,42 @@ class BaseSimulator(metaclass=abc.ABCMeta):
76
65
  def create_log_stream(self) -> log_stream.LogStream:
77
66
  """Creates a stream of logs from the simulator."""
78
67
 
79
- @abc.abstractmethod
80
- def _restart_impl(self) -> None:
81
- """Platform specific restart implementation."""
68
+ def launch(self) -> None:
69
+ """Starts the simulator."""
70
+
71
+ # Stop screenshot thread if it's enabled.
72
+ if self._interaction_thread is not None:
73
+ self._interaction_thread.stop()
74
+ self._interaction_thread.join()
75
+
76
+ self._num_launch_attempts += 1
77
+ try:
78
+ self._launch_impl()
79
+ except Exception as error:
80
+ for line in self.get_logs().splitlines():
81
+ logging.error(line)
82
+ raise errors.SimulatorError(
83
+ 'Exception caught in simulator. Please see the simulator logs '
84
+ 'above for more details.'
85
+ ) from error
86
+
87
+ # Start interaction thread.
88
+ if self._config.interaction_rate_sec > 0:
89
+ self._interaction_thread = InteractionThread(
90
+ self._get_screenshot_impl, self._config.interaction_rate_sec
91
+ )
92
+ self._interaction_thread.start()
82
93
 
83
94
  @abc.abstractmethod
84
95
  def _launch_impl(self) -> None:
85
96
  """Platform specific launch implementation."""
86
97
 
87
98
  @abc.abstractmethod
88
- def send_touch(self, touches: List[Tuple[int, int, bool, int]]) -> None:
99
+ def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
89
100
  """Sends a touch event to be executed on the simulator.
90
101
 
91
102
  Args:
92
- touches: A list of touch events. Each elemet in the list corresponds to a
103
+ touches: A list of touch events. Each element in the list corresponds to a
93
104
  single touch event. Each touch event tuple should have:
94
105
  0 x: The horizontal coordinate of this event.
95
106
  1 y: The vertical coordinate of this event.
@@ -107,25 +118,48 @@ class BaseSimulator(metaclass=abc.ABCMeta):
107
118
  event_type: Type of key event to be sent.
108
119
  """
109
120
 
110
- @_print_logs_on_exception
111
- def launch(self) -> None:
112
- """Launches the simulator."""
113
- if not self._launched:
114
- self._launched = True
115
- self._launch_impl()
116
- else:
117
- self.restart()
121
+ def load_state(
122
+ self, request: state_pb2.LoadStateRequest
123
+ ) -> state_pb2.LoadStateResponse:
124
+ """Loads a state.
118
125
 
119
- @_print_logs_on_exception
120
- def restart(self) -> None:
121
- """Restarts the simulator."""
122
- logging.info('Restarting the simulator...')
123
- self._restart_impl()
124
- logging.info('Done restarting the simulator.')
126
+ Args:
127
+ request: A `LoadStateRequest` containing any parameters necessary to
128
+ specify how/what state to load.
129
+
130
+ Returns:
131
+ A `LoadStateResponse` containing the status, error message (if
132
+ applicable), and any other relevant information.
133
+ """
134
+ raise NotImplementedError('This simulator does not support load_state()')
135
+
136
+ def save_state(
137
+ self, request: state_pb2.SaveStateRequest
138
+ ) -> state_pb2.SaveStateResponse:
139
+ """Saves a state.
140
+
141
+ Args:
142
+ request: A `SaveStateRequest` containing any parameters necessary to
143
+ specify how/what state to save.
144
+
145
+ Returns:
146
+ A `SaveStateResponse` containing the status, error message (if
147
+ applicable), and any other relevant information.
148
+ """
149
+ raise NotImplementedError('This simulator does not support save_state()')
125
150
 
126
- @abc.abstractmethod
127
151
  def get_screenshot(self) -> np.ndarray:
128
- """Returns pixels representing the current screenshot of the simulator.
152
+ """Returns pixels representing the current screenshot of the simulator."""
153
+
154
+ if self._config.interaction_rate_sec > 0:
155
+ assert self._interaction_thread is not None
156
+ return self._interaction_thread.screenshot() # Async mode.
157
+ else:
158
+ return self._get_screenshot_impl() # Sync mode.
159
+
160
+ @abc.abstractmethod
161
+ def _get_screenshot_impl(self) -> np.ndarray:
162
+ """Actual implementation of `get_screenshot()`.
129
163
 
130
164
  The output numpy array should have shape [height, width, num_channels] and
131
165
  can be loaded into PIL using Image.fromarray(img, mode='RGB') and be saved
@@ -134,3 +168,41 @@ class BaseSimulator(metaclass=abc.ABCMeta):
134
168
 
135
169
  def close(self):
136
170
  """Frees up resources allocated by this object."""
171
+
172
+ if self._interaction_thread is not None:
173
+ self._interaction_thread.stop()
174
+ self._interaction_thread.join()
175
+
176
+
177
+ class InteractionThread(threading.Thread):
178
+ """A thread that gets screenshot in the background."""
179
+
180
+ def __init__(
181
+ self,
182
+ get_screenshot_fn: Callable[[], np.ndarray],
183
+ interaction_rate_sec: float,
184
+ ):
185
+ super().__init__()
186
+ self._get_screenshot_fn = get_screenshot_fn
187
+ self._interaction_rate_sec = interaction_rate_sec
188
+ self._should_stop = threading.Event()
189
+ self._screenshot = self._get_screenshot_fn()
190
+
191
+ def run(self):
192
+ last_read = time.time()
193
+ while not self._should_stop.is_set():
194
+ self._screenshot = self._get_screenshot_fn()
195
+ now = time.time()
196
+ elapsed = now - last_read
197
+ last_read = now
198
+ sleep_time = self._interaction_rate_sec - elapsed
199
+ if sleep_time > 0.0:
200
+ time.sleep(sleep_time)
201
+ logging.info('InteractionThread.run() finished.')
202
+
203
+ def stop(self):
204
+ logging.info('Stopping InteractionThread.')
205
+ self._should_stop.set()
206
+
207
+ def screenshot(self) -> np.ndarray:
208
+ return self._screenshot
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -13,14 +13,16 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- """Tests for base_simulator."""
17
-
16
+ import itertools
17
+ import time
18
18
  from unittest import mock
19
19
 
20
20
  from absl.testing import absltest
21
+ from android_env.components import config_classes
21
22
  from android_env.components import errors
22
23
  # fake_simulator.FakeSimulator inherits from BaseSimulator, so there's no need
23
24
  # to import it here explicitly.
25
+ from android_env.components.simulators import base_simulator
24
26
  from android_env.components.simulators.fake import fake_simulator
25
27
  import numpy as np
26
28
 
@@ -28,19 +30,25 @@ import numpy as np
28
30
  class BaseSimulatorTest(absltest.TestCase):
29
31
 
30
32
  def test_launch(self):
31
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(640, 480))
33
+ simulator = fake_simulator.FakeSimulator(
34
+ config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480))
35
+ )
32
36
  # The simulator should launch and not crash.
33
37
  simulator.launch()
34
38
 
35
39
  def test_launch_close(self):
36
- simulator = fake_simulator.FakeSimulator()
40
+ simulator = fake_simulator.FakeSimulator(
41
+ config_classes.FakeSimulatorConfig()
42
+ )
37
43
  # The simulator should launch and not crash.
38
44
  simulator.launch()
39
45
  # Closing the simulator should also not crash.
40
46
  simulator.close()
41
47
 
42
48
  def test_get_screenshot(self):
43
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(640, 480))
49
+ simulator = fake_simulator.FakeSimulator(
50
+ config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480))
51
+ )
44
52
  # The simulator should launch and not crash.
45
53
  simulator.launch()
46
54
 
@@ -48,12 +56,126 @@ class BaseSimulatorTest(absltest.TestCase):
48
56
  np.testing.assert_equal(screenshot.shape, [640, 480, 3])
49
57
 
50
58
  def test_print_logs_on_exception(self):
51
- simulator = fake_simulator.FakeSimulator()
52
- with mock.patch.object(simulator, 'get_logs') as mock_get_logs, \
53
- mock.patch.object(simulator, '_launch_impl', autospec=True) as mock_launch:
59
+ simulator = fake_simulator.FakeSimulator(
60
+ config_classes.FakeSimulatorConfig()
61
+ )
62
+ with mock.patch.object(
63
+ simulator, 'get_logs'
64
+ ) as mock_get_logs, mock.patch.object(
65
+ simulator, '_launch_impl', autospec=True
66
+ ) as mock_launch:
54
67
  mock_launch.side_effect = ValueError('Oh no!')
55
68
  self.assertRaises(errors.SimulatorError, simulator.launch)
56
69
  mock_get_logs.assert_called_once()
57
70
 
71
+ def test_get_screenshot_error_async(self):
72
+ """An exception in the underlying interaction thread should bubble up."""
73
+
74
+ # Arrange.
75
+ mock_interaction_thread = mock.create_autospec(
76
+ base_simulator.InteractionThread
77
+ )
78
+ mock_interaction_thread.screenshot.side_effect = (
79
+ errors.ReadObservationError()
80
+ )
81
+ simulator = fake_simulator.FakeSimulator(
82
+ config_classes.FakeSimulatorConfig(interaction_rate_sec=0.5)
83
+ )
84
+ with mock.patch.object(
85
+ base_simulator,
86
+ 'InteractionThread',
87
+ autospec=True,
88
+ return_value=mock_interaction_thread,
89
+ ):
90
+ simulator.launch()
91
+
92
+ # Act & Assert.
93
+ self.assertRaises(errors.ReadObservationError, simulator.get_screenshot)
94
+
95
+ # Cleanup.
96
+ simulator.close()
97
+
98
+ def test_get_screenshot_faster_than_screenshot_impl(self):
99
+ """Return same screenshot when step is faster than the interaction rate."""
100
+
101
+ # Arrange.
102
+ slow_rate = 0.5
103
+ simulator = fake_simulator.FakeSimulator(
104
+ config_classes.FakeSimulatorConfig(interaction_rate_sec=slow_rate)
105
+ )
106
+
107
+ # Act.
108
+ with mock.patch.object(
109
+ simulator, '_get_screenshot_impl', autospec=True
110
+ ) as mock_get_screenshot_impl:
111
+ mock_get_screenshot_impl.side_effect = (
112
+ np.array(i, ndmin=3) for i in itertools.count(0, 1)
113
+ )
114
+ simulator.launch()
115
+ # Get two screenshots one after the other without pausing.
116
+ screenshot1 = simulator.get_screenshot()
117
+ screenshot2 = simulator.get_screenshot()
118
+
119
+ # Assert.
120
+ self.assertAlmostEqual(screenshot1[0][0][0], screenshot2[0][0][0])
121
+
122
+ # Cleanup.
123
+ simulator.close()
124
+
125
+ def test_get_screenshot_slower_than_screenshot_impl(self):
126
+ """Return different screenshots when step slower than the interaction rate."""
127
+
128
+ # Arrange.
129
+ fast_rate = 0.01
130
+ simulator = fake_simulator.FakeSimulator(
131
+ config_classes.FakeSimulatorConfig(interaction_rate_sec=fast_rate)
132
+ )
133
+
134
+ # Act.
135
+ with mock.patch.object(
136
+ simulator, '_get_screenshot_impl', autospec=True
137
+ ) as mock_get_screenshot_impl:
138
+ mock_get_screenshot_impl.side_effect = (
139
+ np.array(i, ndmin=3) for i in itertools.count(0, 1)
140
+ )
141
+ simulator.launch()
142
+ # Sleep for 500ms between two screenshots.
143
+ screenshot1 = simulator.get_screenshot()
144
+ time.sleep(0.5)
145
+ screenshot2 = simulator.get_screenshot()
146
+
147
+ # Assert.
148
+ self.assertNotEqual(screenshot1[0][0][0], screenshot2[0][0][0])
149
+
150
+ # Cleanup.
151
+ simulator.close()
152
+
153
+ def test_interaction_thread_closes_upon_relaunch(self):
154
+ """Async interaction should kill the InteractionThread when relaunching."""
155
+
156
+ # Arrange.
157
+ simulator = fake_simulator.FakeSimulator(
158
+ config_classes.FakeSimulatorConfig(interaction_rate_sec=0.01)
159
+ )
160
+ mock_interaction_thread = mock.create_autospec(
161
+ base_simulator.InteractionThread
162
+ )
163
+
164
+ # Act & Assert.
165
+ with mock.patch.object(
166
+ base_simulator,
167
+ 'InteractionThread',
168
+ autospec=True,
169
+ return_value=mock_interaction_thread,
170
+ ):
171
+ simulator.launch()
172
+ mock_interaction_thread.stop.assert_not_called()
173
+ mock_interaction_thread.join.assert_not_called()
174
+ simulator.launch()
175
+ mock_interaction_thread.stop.assert_called_once()
176
+ mock_interaction_thread.join.assert_called_once()
177
+ simulator.close()
178
+
179
+
58
180
  if __name__ == '__main__':
59
181
  absltest.main()
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.