android-env 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142) hide show
  1. android_env/__init__.py +1 -1
  2. android_env/components/__init__.py +1 -1
  3. android_env/components/a11y/__init__.py +15 -0
  4. android_env/components/a11y/a11y_events.py +118 -0
  5. android_env/components/a11y/a11y_events_test.py +173 -0
  6. android_env/components/a11y/a11y_forests.py +128 -0
  7. android_env/components/a11y/a11y_forests_test.py +237 -0
  8. android_env/components/a11y/a11y_servicer.py +199 -0
  9. android_env/components/a11y/a11y_servicer_test.py +224 -0
  10. android_env/components/action_fns.py +132 -0
  11. android_env/components/action_fns_test.py +227 -0
  12. android_env/components/action_type.py +26 -3
  13. android_env/components/adb_call_parser.py +233 -185
  14. android_env/components/adb_call_parser_test.py +165 -163
  15. android_env/components/adb_controller.py +19 -28
  16. android_env/components/adb_controller_test.py +100 -9
  17. android_env/components/adb_log_stream.py +3 -3
  18. android_env/components/adb_log_stream_test.py +1 -1
  19. android_env/components/app_screen_checker.py +15 -13
  20. android_env/components/app_screen_checker_test.py +1 -1
  21. android_env/components/config_classes.py +203 -0
  22. android_env/components/coordinator.py +53 -338
  23. android_env/components/coordinator_test.py +26 -283
  24. android_env/components/device_settings.py +174 -0
  25. android_env/components/device_settings_test.py +228 -0
  26. android_env/components/dumpsys_thread.py +3 -4
  27. android_env/components/dumpsys_thread_test.py +1 -1
  28. android_env/components/errors.py +2 -5
  29. android_env/components/errors_test.py +1 -1
  30. android_env/components/log_stream.py +2 -2
  31. android_env/components/log_stream_test.py +1 -1
  32. android_env/components/logcat_thread.py +9 -8
  33. android_env/components/logcat_thread_test.py +2 -3
  34. android_env/components/{utils.py → pixel_fns.py} +19 -20
  35. android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
  36. android_env/components/setup_step_interpreter.py +45 -37
  37. android_env/components/setup_step_interpreter_test.py +1 -1
  38. android_env/components/simulators/__init__.py +1 -1
  39. android_env/components/simulators/base_simulator.py +79 -23
  40. android_env/components/simulators/base_simulator_test.py +131 -9
  41. android_env/components/simulators/emulator/__init__.py +1 -1
  42. android_env/components/simulators/emulator/emulator_launcher.py +62 -81
  43. android_env/components/simulators/emulator/emulator_launcher_test.py +120 -43
  44. android_env/components/simulators/emulator/emulator_simulator.py +111 -98
  45. android_env/components/simulators/emulator/emulator_simulator_test.py +174 -138
  46. android_env/components/simulators/fake/__init__.py +1 -1
  47. android_env/components/simulators/fake/fake_simulator.py +9 -17
  48. android_env/components/simulators/fake/fake_simulator_test.py +23 -8
  49. android_env/components/specs.py +1 -1
  50. android_env/components/specs_test.py +1 -1
  51. android_env/components/task_manager.py +26 -31
  52. android_env/components/task_manager_test.py +1 -18
  53. android_env/env_interface.py +1 -17
  54. android_env/environment.py +27 -17
  55. android_env/environment_test.py +51 -25
  56. android_env/loader.py +57 -43
  57. android_env/loader_test.py +115 -35
  58. android_env/proto/__init__.py +1 -1
  59. android_env/proto/a11y/__init__.py +15 -0
  60. android_env/proto/a11y/a11y.proto +75 -0
  61. android_env/proto/a11y/a11y_pb2.py +54 -0
  62. android_env/proto/a11y/a11y_pb2.pyi +49 -0
  63. android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
  64. android_env/proto/a11y/android_accessibility_action.proto +32 -0
  65. android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
  66. android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
  67. android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
  68. android_env/proto/a11y/android_accessibility_forest.proto +29 -0
  69. android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
  70. android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
  71. android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
  72. android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
  73. android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
  74. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
  75. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
  76. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
  77. android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
  78. android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
  79. android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
  80. android_env/proto/a11y/android_accessibility_tree.proto +29 -0
  81. android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
  82. android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
  83. android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
  84. android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
  85. android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
  86. android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
  87. android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
  88. android_env/proto/a11y/rect.proto +30 -0
  89. android_env/proto/a11y/rect_pb2.py +37 -0
  90. android_env/proto/a11y/rect_pb2.pyi +17 -0
  91. android_env/proto/a11y/rect_pb2_grpc.py +24 -0
  92. android_env/proto/adb.proto +13 -1
  93. android_env/proto/adb_pb2.py +120 -107
  94. android_env/proto/adb_pb2.pyi +396 -0
  95. android_env/proto/adb_pb2_grpc.py +20 -0
  96. android_env/proto/emulator_controller.proto +1 -1
  97. android_env/proto/emulator_controller_pb2.py +142 -131
  98. android_env/proto/emulator_controller_pb2.pyi +672 -0
  99. android_env/proto/emulator_controller_pb2_grpc.py +497 -136
  100. android_env/proto/snapshot.proto +1 -1
  101. android_env/proto/snapshot_pb2.py +30 -19
  102. android_env/proto/snapshot_pb2.pyi +117 -0
  103. android_env/proto/snapshot_pb2_grpc.py +20 -0
  104. android_env/proto/snapshot_service.proto +1 -1
  105. android_env/proto/snapshot_service_pb2.py +36 -25
  106. android_env/proto/snapshot_service_pb2.pyi +86 -0
  107. android_env/proto/snapshot_service_pb2_grpc.py +119 -28
  108. android_env/proto/state.proto +1 -1
  109. android_env/proto/state_pb2.py +46 -35
  110. android_env/proto/state_pb2.pyi +85 -0
  111. android_env/proto/state_pb2_grpc.py +20 -0
  112. android_env/proto/task.proto +4 -1
  113. android_env/proto/task_pb2.py +41 -30
  114. android_env/proto/task_pb2.pyi +160 -0
  115. android_env/proto/task_pb2_grpc.py +20 -0
  116. android_env/wrappers/__init__.py +1 -1
  117. android_env/wrappers/a11y_grpc_wrapper.py +500 -0
  118. android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
  119. android_env/wrappers/base_wrapper.py +1 -5
  120. android_env/wrappers/base_wrapper_test.py +1 -7
  121. android_env/wrappers/discrete_action_wrapper.py +15 -14
  122. android_env/wrappers/discrete_action_wrapper_test.py +1 -1
  123. android_env/wrappers/flat_interface_wrapper.py +5 -5
  124. android_env/wrappers/flat_interface_wrapper_test.py +1 -1
  125. android_env/wrappers/float_pixels_wrapper.py +5 -4
  126. android_env/wrappers/float_pixels_wrapper_test.py +1 -1
  127. android_env/wrappers/gym_wrapper.py +1 -1
  128. android_env/wrappers/gym_wrapper_test.py +1 -1
  129. android_env/wrappers/image_rescale_wrapper.py +13 -10
  130. android_env/wrappers/image_rescale_wrapper_test.py +1 -1
  131. android_env/wrappers/last_action_wrapper.py +5 -4
  132. android_env/wrappers/last_action_wrapper_test.py +1 -1
  133. android_env/wrappers/rate_limit_wrapper.py +1 -1
  134. android_env/wrappers/rate_limit_wrapper_test.py +1 -1
  135. android_env/wrappers/tap_action_wrapper.py +12 -12
  136. android_env/wrappers/tap_action_wrapper_test.py +49 -14
  137. {android_env-1.2.2.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -16
  138. android_env-1.2.3.dist-info/RECORD +141 -0
  139. {android_env-1.2.2.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
  140. android_env-1.2.2.dist-info/RECORD +0 -88
  141. {android_env-1.2.2.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
  142. {android_env-1.2.2.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2023 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -16,16 +16,18 @@
16
16
  """TaskManager handles all events and information related to the task."""
17
17
 
18
18
  import ast
19
+ from collections.abc import Callable
19
20
  import copy
20
21
  import datetime
21
22
  import json
22
23
  import re
23
24
  import threading
24
- from typing import Any, Callable
25
+ from typing import Any
25
26
 
26
27
  from absl import logging
27
28
  from android_env.components import adb_call_parser as adb_call_parser_lib
28
29
  from android_env.components import app_screen_checker
30
+ from android_env.components import config_classes
29
31
  from android_env.components import dumpsys_thread
30
32
  from android_env.components import log_stream as log_stream_lib
31
33
  from android_env.components import logcat_thread
@@ -41,28 +43,18 @@ class TaskManager:
41
43
  def __init__(
42
44
  self,
43
45
  task: task_pb2.Task,
44
- max_bad_states: int = 3,
45
- dumpsys_check_frequency: int = 150,
46
- max_failed_current_activity: int = 10,
46
+ config: config_classes.TaskManagerConfig | None = None,
47
47
  ):
48
48
  """Controls task-relevant events and information.
49
49
 
50
50
  Args:
51
51
  task: A task proto defining the RL task.
52
- max_bad_states: How many bad states in a row are allowed before a restart
53
- of the simulator is triggered.
54
- dumpsys_check_frequency: Frequency, in steps, at which to check
55
- current_activity and view hierarchy
56
- max_failed_current_activity: The maximum number of tries for extracting
57
- the current activity before forcing the episode to restart.
52
+ config: Configuration for this instance.
58
53
  """
59
- self._task = task
60
- self._max_bad_states = max_bad_states
61
- self._dumpsys_check_frequency = dumpsys_check_frequency
62
- self._max_failed_current_activity = max_failed_current_activity
63
54
 
55
+ self._task = task
56
+ self._config = config or config_classes.TaskManagerConfig()
64
57
  self._lock = threading.Lock()
65
- self._extras_max_buffer_size = 100
66
58
  self._logcat_thread = None
67
59
  self._dumpsys_thread = None
68
60
  self._setup_step_interpreter = None
@@ -92,13 +84,6 @@ class TaskManager:
92
84
 
93
85
  logging.info('Task config: %s', self._task)
94
86
 
95
- def task(self) -> task_pb2.Task:
96
- return self._task
97
-
98
- def update_task(self, task: task_pb2.Task) -> None:
99
- self._stats['task_updates'] += 1
100
- self._task = task
101
-
102
87
  def stats(self) -> dict[str, Any]:
103
88
  """Returns a dictionary of stats.
104
89
 
@@ -248,9 +233,11 @@ class TaskManager:
248
233
  self._dumpsys_thread = dumpsys_thread.DumpsysThread(
249
234
  app_screen_checker=app_screen_checker.AppScreenChecker(
250
235
  adb_call_parser=adb_call_parser,
251
- expected_app_screen=self._task.expected_app_screen),
252
- check_frequency=self._dumpsys_check_frequency,
253
- max_failed_current_activity=self._max_failed_current_activity)
236
+ expected_app_screen=self._task.expected_app_screen,
237
+ ),
238
+ check_frequency=self._config.dumpsys_check_frequency,
239
+ max_failed_current_activity=self._config.max_failed_current_activity,
240
+ )
254
241
 
255
242
  def _stop_logcat_thread(self):
256
243
  if self._logcat_thread is not None:
@@ -266,11 +253,11 @@ class TaskManager:
266
253
  to a good state.
267
254
  """
268
255
  logging.warning('Bad state detected.')
269
- if self._max_bad_states:
256
+ if self._config.max_bad_states:
270
257
  self._is_bad_episode = True
271
258
  self._bad_state_counter += 1
272
259
  logging.warning('Bad state counter: %d.', self._bad_state_counter)
273
- if self._bad_state_counter >= self._max_bad_states:
260
+ if self._bad_state_counter >= self._config.max_bad_states:
274
261
  logging.error('Too many consecutive bad states. Restarting simulator.')
275
262
  self._stats['restart_count_max_bad_states'] += 1
276
263
  self._should_restart = True
@@ -342,8 +329,13 @@ class TaskManager:
342
329
  if extra:
343
330
  try:
344
331
  extra = ast.literal_eval(extra)
345
- # Except all to avoid unnecessary crashes, only log error.
346
- except Exception: # pylint: disable=broad-except
332
+ except (
333
+ ValueError,
334
+ TypeError,
335
+ SyntaxError,
336
+ MemoryError,
337
+ RecursionError,
338
+ ):
347
339
  logging.exception('Could not parse extra: %s', extra)
348
340
  # Don't try to process the extra as text; that would probably crash.
349
341
  return
@@ -380,7 +372,10 @@ class TaskManager:
380
372
  latest_extras = self._latest_values['extra']
381
373
  if extra_name in latest_extras:
382
374
  # If latest extra is not flushed, append.
383
- if len(latest_extras[extra_name]) >= self._extras_max_buffer_size:
375
+ if (
376
+ len(latest_extras[extra_name])
377
+ >= self._config.extras_max_buffer_size
378
+ ):
384
379
  latest_extras[extra_name].pop(0)
385
380
  latest_extras[extra_name].append(extra)
386
381
  else:
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2023 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -55,23 +55,6 @@ class TaskManagerTest(absltest.TestCase):
55
55
  log_stream, 'LogStream',
56
56
  return_value=self._log_stream).start()
57
57
 
58
- def test_update_task(self):
59
- init_task = task_pb2.Task(id='initial_task')
60
- new_task = task_pb2.Task(id='updated_task')
61
- task_mgr = task_manager.TaskManager(task=init_task)
62
-
63
- # Setting up the initial task so that the setup_step_interpreter
64
- # is properly initialized.
65
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
66
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
67
- task_mgr.setup_task()
68
-
69
- self.assertEqual(init_task, task_mgr.task())
70
- self.assertEqual(0, task_mgr.stats()['task_updates'])
71
- task_mgr.update_task(new_task)
72
- self.assertEqual(new_task, task_mgr.task())
73
- self.assertEqual(1, task_mgr.stats()['task_updates'])
74
-
75
58
  def test_start(self):
76
59
  task_mgr = task_manager.TaskManager(task=task_pb2.Task())
77
60
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2023 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -24,7 +24,6 @@ from typing import Any
24
24
 
25
25
  from android_env.proto import adb_pb2
26
26
  from android_env.proto import state_pb2
27
- from android_env.proto import task_pb2
28
27
  import dm_env
29
28
  import numpy as np
30
29
 
@@ -79,20 +78,6 @@ class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
79
78
 
80
79
  return adb_pb2.AdbResponse()
81
80
 
82
- def update_task(self, task: task_pb2.Task) -> bool:
83
- """Replaces the current task with a new task.
84
-
85
- It is the caller's responsibility to call `reset()` after the task update.
86
-
87
- Args:
88
- task: A new task to replace the current one.
89
-
90
- Returns:
91
- A bool indicating the success of the task setup.
92
- """
93
-
94
- return True
95
-
96
81
  def load_state(
97
82
  self, request: state_pb2.LoadStateRequest
98
83
  ) -> state_pb2.LoadStateResponse:
@@ -122,4 +107,3 @@ class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
122
107
  applicable), and any other relevant information.
123
108
  """
124
109
  raise NotImplementedError('This environment does not support saving state')
125
-
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2023 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -19,10 +19,12 @@ from typing import Any
19
19
 
20
20
  from absl import logging
21
21
  from android_env import env_interface
22
+ from android_env.components import adb_call_parser
22
23
  from android_env.components import coordinator as coordinator_lib
24
+ from android_env.components import task_manager as task_manager_lib
25
+ from android_env.components.simulators import base_simulator
23
26
  from android_env.proto import adb_pb2
24
27
  from android_env.proto import state_pb2
25
- from android_env.proto import task_pb2
26
28
  import dm_env
27
29
  import numpy as np
28
30
 
@@ -30,10 +32,17 @@ import numpy as np
30
32
  class AndroidEnv(env_interface.AndroidEnvInterface):
31
33
  """An RL environment that interacts with Android apps."""
32
34
 
33
- def __init__(self, coordinator: coordinator_lib.Coordinator):
35
+ def __init__(
36
+ self,
37
+ simulator: base_simulator.BaseSimulator,
38
+ coordinator: coordinator_lib.Coordinator,
39
+ task_manager: task_manager_lib.TaskManager,
40
+ ):
34
41
  """Initializes the state of this AndroidEnv object."""
35
42
 
43
+ self._simulator = simulator
36
44
  self._coordinator = coordinator
45
+ self._task_manager = task_manager
37
46
  self._latest_action = {}
38
47
  self._latest_observation = {}
39
48
  self._latest_extras = {}
@@ -133,22 +142,13 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
133
142
  return self._latest_observation.copy()
134
143
 
135
144
  def stats(self) -> dict[str, Any]:
136
- return self._coordinator.stats()
145
+ coordinator_stats = self._coordinator.stats()
146
+ task_manager_stats = self._task_manager.stats()
147
+ return coordinator_stats | task_manager_stats
137
148
 
138
149
  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
139
150
  return self._coordinator.execute_adb_call(call)
140
151
 
141
- def update_task(self, task: task_pb2.Task) -> bool:
142
- """Replaces the current task with a new task.
143
-
144
- Args:
145
- task: A new task to replace the current one.
146
-
147
- Returns:
148
- A bool indicating the success of the task setup.
149
- """
150
- return self._coordinator.update_task(task)
151
-
152
152
  def load_state(
153
153
  self, request: state_pb2.LoadStateRequest
154
154
  ) -> state_pb2.LoadStateResponse:
@@ -162,7 +162,16 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
162
162
  A `LoadStateResponse` containing the status, error message (if
163
163
  applicable), and any other relevant information.
164
164
  """
165
- return self._coordinator.load_state(request)
165
+
166
+ self._task_manager.stop()
167
+ response = self._simulator.load_state(request)
168
+ self._task_manager.start(
169
+ adb_call_parser_factory=lambda: adb_call_parser.AdbCallParser(
170
+ self._simulator.create_adb_controller()
171
+ ),
172
+ log_stream=self._simulator.create_log_stream(),
173
+ )
174
+ return response
166
175
 
167
176
  def save_state(
168
177
  self, request: state_pb2.SaveStateRequest
@@ -177,4 +186,5 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
177
186
  A `SaveStateResponse` containing the status, error message (if
178
187
  applicable), and any other relevant information.
179
188
  """
180
- return self._coordinator.save_state(request)
189
+
190
+ return self._simulator.save_state(request)
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2023 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -19,10 +19,13 @@ from unittest import mock
19
19
 
20
20
  from absl.testing import absltest
21
21
  from android_env import environment
22
+ from android_env.components import config_classes
22
23
  from android_env.components import coordinator as coordinator_lib
24
+ from android_env.components import task_manager as task_manager_lib
25
+ from android_env.components.simulators import base_simulator
26
+ from android_env.components.simulators.fake import fake_simulator
23
27
  from android_env.proto import adb_pb2
24
28
  from android_env.proto import state_pb2
25
- from android_env.proto import task_pb2
26
29
  import dm_env
27
30
  import numpy as np
28
31
 
@@ -44,10 +47,21 @@ def _create_mock_coordinator() -> coordinator_lib.Coordinator:
44
47
  return coordinator
45
48
 
46
49
 
50
+ def _create_fake_simulator() -> fake_simulator.FakeSimulator:
51
+ return fake_simulator.FakeSimulator(
52
+ config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456))
53
+ )
54
+
55
+
47
56
  class AndroidEnvTest(absltest.TestCase):
48
57
 
49
58
  def test_specs(self):
50
- env = environment.AndroidEnv(_create_mock_coordinator())
59
+ simulator = _create_fake_simulator()
60
+ coordinator = _create_mock_coordinator()
61
+ task_manager = mock.create_autospec(task_manager_lib.TaskManager)
62
+ env = environment.AndroidEnv(
63
+ simulator=simulator, coordinator=coordinator, task_manager=task_manager
64
+ )
51
65
 
52
66
  # Check action spec.
53
67
  self.assertNotEmpty(env.action_spec())
@@ -77,7 +91,9 @@ class AndroidEnvTest(absltest.TestCase):
77
91
  self.assertEqual(env.observation_spec()['orientation'].shape, (4,))
78
92
 
79
93
  def test_reset_and_step(self):
80
- coordinator = mock.create_autospec(coordinator_lib.Coordinator)
94
+ simulator = _create_fake_simulator()
95
+ coordinator = _create_mock_coordinator()
96
+ task_manager = mock.create_autospec(task_manager_lib.TaskManager)
81
97
  coordinator.action_spec.return_value = {
82
98
  'action_type':
83
99
  dm_env.specs.DiscreteArray(num_values=3),
@@ -90,7 +106,9 @@ class AndroidEnvTest(absltest.TestCase):
90
106
  'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64),
91
107
  'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8),
92
108
  }
93
- env = environment.AndroidEnv(coordinator)
109
+ env = environment.AndroidEnv(
110
+ simulator=simulator, coordinator=coordinator, task_manager=task_manager
111
+ )
94
112
  coordinator.rl_reset.return_value = dm_env.TimeStep(
95
113
  step_type=dm_env.StepType.FIRST,
96
114
  reward=0.0,
@@ -125,9 +143,8 @@ class AndroidEnvTest(absltest.TestCase):
125
143
  self.assertIn('click', extras)
126
144
  self.assertEqual(extras['click'], np.array([246], dtype=np.int64))
127
145
 
128
- coordinator.stats.return_value = {
129
- 'my_measurement': 135,
130
- }
146
+ coordinator.stats.return_value = {'my_measurement': 135}
147
+ task_manager.stats.return_value = {'another_measurement': 79}
131
148
 
132
149
  # Step again in the environment and check expectations again.
133
150
  pixels = np.random.rand(987, 654, 3)
@@ -189,8 +206,12 @@ class AndroidEnvTest(absltest.TestCase):
189
206
  np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))
190
207
 
191
208
  def test_adb_call(self):
209
+ simulator = _create_fake_simulator()
192
210
  coordinator = _create_mock_coordinator()
193
- env = environment.AndroidEnv(coordinator)
211
+ task_manager = mock.create_autospec(task_manager_lib.TaskManager)
212
+ env = environment.AndroidEnv(
213
+ simulator=simulator, coordinator=coordinator, task_manager=task_manager
214
+ )
194
215
  call = adb_pb2.AdbRequest(
195
216
  force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah'))
196
217
  expected_response = adb_pb2.AdbResponse(
@@ -203,41 +224,46 @@ class AndroidEnvTest(absltest.TestCase):
203
224
  coordinator.execute_adb_call.assert_called_once_with(call)
204
225
 
205
226
  def test_load_state(self):
227
+ simulator = mock.create_autospec(base_simulator.BaseSimulator)
206
228
  coordinator = _create_mock_coordinator()
207
- env = environment.AndroidEnv(coordinator)
229
+ task_manager = mock.create_autospec(task_manager_lib.TaskManager)
230
+ env = environment.AndroidEnv(
231
+ simulator=simulator, coordinator=coordinator, task_manager=task_manager
232
+ )
208
233
  expected_response = state_pb2.LoadStateResponse(
209
234
  status=state_pb2.LoadStateResponse.Status.OK
210
235
  )
211
236
  request = state_pb2.LoadStateRequest(args={'foo': 'bar'})
212
- coordinator.load_state.return_value = expected_response
237
+ simulator.load_state.return_value = expected_response
213
238
  response = env.load_state(request)
214
239
  self.assertEqual(response, expected_response)
215
- coordinator.load_state.assert_called_once_with(request)
240
+ simulator.load_state.assert_called_once_with(request)
241
+ task_manager.stop.assert_called_once()
242
+ task_manager.start.assert_called_once()
216
243
 
217
244
  def test_save_state(self):
245
+ simulator = mock.create_autospec(base_simulator.BaseSimulator)
218
246
  coordinator = _create_mock_coordinator()
219
- env = environment.AndroidEnv(coordinator)
247
+ task_manager = mock.create_autospec(task_manager_lib.TaskManager)
248
+ env = environment.AndroidEnv(
249
+ simulator=simulator, coordinator=coordinator, task_manager=task_manager
250
+ )
220
251
  expected_response = state_pb2.SaveStateResponse(
221
252
  status=state_pb2.SaveStateResponse.Status.OK
222
253
  )
223
254
  request = state_pb2.SaveStateRequest(args={'foo': 'bar'})
224
- coordinator.save_state.return_value = expected_response
255
+ simulator.save_state.return_value = expected_response
225
256
  response = env.save_state(request)
226
257
  self.assertEqual(response, expected_response)
227
- coordinator.save_state.assert_called_once_with(request)
228
-
229
- def test_update_task(self):
230
- coordinator = _create_mock_coordinator()
231
- env = environment.AndroidEnv(coordinator)
232
- task = task_pb2.Task()
233
- coordinator.update_task.return_value = True
234
- response = env.update_task(task)
235
- self.assertEqual(response, True)
236
- coordinator.update_task.assert_called_once_with(task)
258
+ simulator.save_state.assert_called_once_with(request)
237
259
 
238
260
  def test_double_close(self):
261
+ simulator = _create_fake_simulator()
239
262
  coordinator = _create_mock_coordinator()
240
- env = environment.AndroidEnv(coordinator)
263
+ task_manager = mock.create_autospec(task_manager_lib.TaskManager)
264
+ env = environment.AndroidEnv(
265
+ simulator=simulator, coordinator=coordinator, task_manager=task_manager
266
+ )
241
267
  env.close()
242
268
  env.close()
243
269
  coordinator.close.assert_called_once()
android_env/loader.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2023 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -17,59 +17,73 @@
17
17
 
18
18
  import os
19
19
 
20
+ from absl import logging
20
21
  from android_env import environment
22
+ from android_env.components import config_classes
21
23
  from android_env.components import coordinator as coordinator_lib
24
+ from android_env.components import device_settings as device_settings_lib
22
25
  from android_env.components import task_manager as task_manager_lib
23
26
  from android_env.components.simulators.emulator import emulator_simulator
27
+ from android_env.components.simulators.fake import fake_simulator
24
28
  from android_env.proto import task_pb2
25
29
 
26
30
  from google.protobuf import text_format
27
31
 
28
32
 
29
- def load(task_path: str,
30
- avd_name: str,
31
- android_avd_home: str = '~/.android/avd',
32
- android_sdk_root: str = '~/Android/Sdk',
33
- emulator_path: str = '~/Android/Sdk/emulator/emulator',
34
- adb_path: str = '~/Android/Sdk/platform-tools/adb',
35
- run_headless: bool = False) -> environment.AndroidEnv:
36
- """Loads an AndroidEnv instance.
37
-
38
- Args:
39
- task_path: Path to the task textproto file.
40
- avd_name: Name of the AVD (Android Virtual Device).
41
- android_avd_home: Path to the AVD (Android Virtual Device).
42
- android_sdk_root: Root directory of the SDK.
43
- emulator_path: Path to the emulator binary.
44
- adb_path: Path to the ADB (Android Debug Bridge).
45
- run_headless: If True, the emulator display is turned off.
46
- Returns:
47
- env: An AndroidEnv instance.
48
- """
49
-
50
- # Create simulator.
51
- simulator = emulator_simulator.EmulatorSimulator(
52
- adb_controller_args=dict(
53
- adb_path=os.path.expanduser(adb_path),
54
- adb_server_port=5037,
55
- ),
56
- emulator_launcher_args=dict(
57
- avd_name=avd_name,
58
- android_avd_home=os.path.expanduser(android_avd_home),
59
- android_sdk_root=os.path.expanduser(android_sdk_root),
60
- emulator_path=os.path.expanduser(emulator_path),
61
- run_headless=run_headless,
62
- gpu_mode='swiftshader_indirect',
63
- ),
64
- )
33
+ def _load_task(task_config: config_classes.TaskConfig) -> task_pb2.Task:
34
+ """Returns the task according to `task_config`."""
65
35
 
66
- # Prepare task.
67
36
  task = task_pb2.Task()
68
- with open(task_path, 'r') as proto_file:
69
- text_format.Parse(proto_file.read(), task)
37
+ match task_config:
38
+ case config_classes.FilesystemTaskConfig():
39
+ with open(task_config.path, 'r') as proto_file:
40
+ text_format.Parse(proto_file.read(), task)
41
+ case _:
42
+ logging.error('Unsupported TaskConfig: %r', task_config)
43
+
44
+ return task
45
+
46
+
47
+ def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv:
48
+ """Loads an AndroidEnv instance."""
70
49
 
50
+ task = _load_task(config.task)
71
51
  task_manager = task_manager_lib.TaskManager(task)
72
- coordinator = coordinator_lib.Coordinator(simulator, task_manager)
73
52
 
74
- # Load environment.
75
- return environment.AndroidEnv(coordinator=coordinator)
53
+ match config.simulator:
54
+ case config_classes.EmulatorConfig():
55
+ _process_emulator_launcher_config(config.simulator)
56
+ simulator = emulator_simulator.EmulatorSimulator(config=config.simulator)
57
+ case config_classes.FakeSimulatorConfig():
58
+ simulator = fake_simulator.FakeSimulator(config=config.simulator)
59
+ case _:
60
+ raise ValueError('Unsupported simulator config: {config.simulator}')
61
+
62
+ device_settings = device_settings_lib.DeviceSettings(simulator)
63
+ coordinator = coordinator_lib.Coordinator(
64
+ simulator, task_manager, device_settings
65
+ )
66
+ return environment.AndroidEnv(
67
+ simulator=simulator, coordinator=coordinator, task_manager=task_manager
68
+ )
69
+
70
+
71
+ def _process_emulator_launcher_config(
72
+ emulator_config: config_classes.EmulatorConfig,
73
+ ) -> None:
74
+ """Adjusts the configuration of the emulator depending on some conditions."""
75
+
76
+ # Expand the user directory if specified.
77
+ launcher_config = emulator_config.emulator_launcher
78
+ launcher_config.android_avd_home = os.path.expanduser(
79
+ launcher_config.android_avd_home
80
+ )
81
+ launcher_config.android_sdk_root = os.path.expanduser(
82
+ launcher_config.android_sdk_root
83
+ )
84
+ launcher_config.emulator_path = os.path.expanduser(
85
+ launcher_config.emulator_path
86
+ )
87
+ emulator_config.adb_controller.adb_path = os.path.expanduser(
88
+ emulator_config.adb_controller.adb_path
89
+ )