android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. android_env/__init__.py +1 -1
  2. android_env/components/__init__.py +1 -1
  3. android_env/components/a11y/__init__.py +15 -0
  4. android_env/components/a11y/a11y_events.py +118 -0
  5. android_env/components/a11y/a11y_events_test.py +173 -0
  6. android_env/components/a11y/a11y_forests.py +128 -0
  7. android_env/components/a11y/a11y_forests_test.py +237 -0
  8. android_env/components/a11y/a11y_servicer.py +199 -0
  9. android_env/components/a11y/a11y_servicer_test.py +224 -0
  10. android_env/components/action_fns.py +132 -0
  11. android_env/components/action_fns_test.py +227 -0
  12. android_env/components/action_type.py +26 -3
  13. android_env/components/adb_call_parser.py +239 -196
  14. android_env/components/adb_call_parser_test.py +179 -209
  15. android_env/components/adb_controller.py +90 -52
  16. android_env/components/adb_controller_test.py +187 -16
  17. android_env/components/adb_log_stream.py +17 -5
  18. android_env/components/adb_log_stream_test.py +17 -3
  19. android_env/components/app_screen_checker.py +17 -15
  20. android_env/components/app_screen_checker_test.py +7 -8
  21. android_env/components/config_classes.py +203 -0
  22. android_env/components/coordinator.py +102 -338
  23. android_env/components/coordinator_test.py +59 -199
  24. android_env/components/device_settings.py +174 -0
  25. android_env/components/device_settings_test.py +228 -0
  26. android_env/components/dumpsys_thread.py +3 -4
  27. android_env/components/dumpsys_thread_test.py +1 -1
  28. android_env/components/errors.py +52 -10
  29. android_env/components/errors_test.py +110 -0
  30. android_env/components/log_stream.py +7 -5
  31. android_env/components/log_stream_test.py +1 -1
  32. android_env/components/logcat_thread.py +9 -8
  33. android_env/components/logcat_thread_test.py +3 -4
  34. android_env/components/{utils.py → pixel_fns.py} +20 -20
  35. android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
  36. android_env/components/setup_step_interpreter.py +47 -39
  37. android_env/components/setup_step_interpreter_test.py +4 -4
  38. android_env/components/simulators/__init__.py +1 -1
  39. android_env/components/simulators/base_simulator.py +116 -44
  40. android_env/components/simulators/base_simulator_test.py +131 -9
  41. android_env/components/simulators/emulator/__init__.py +1 -1
  42. android_env/components/simulators/emulator/emulator_launcher.py +67 -77
  43. android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
  44. android_env/components/simulators/emulator/emulator_simulator.py +276 -95
  45. android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
  46. android_env/components/simulators/fake/__init__.py +1 -1
  47. android_env/components/simulators/fake/fake_simulator.py +17 -25
  48. android_env/components/simulators/fake/fake_simulator_test.py +29 -12
  49. android_env/components/specs.py +18 -28
  50. android_env/components/specs_test.py +1 -44
  51. android_env/components/task_manager.py +48 -48
  52. android_env/components/task_manager_test.py +71 -60
  53. android_env/env_interface.py +37 -23
  54. android_env/environment.py +83 -51
  55. android_env/environment_test.py +68 -29
  56. android_env/loader.py +57 -43
  57. android_env/loader_test.py +115 -35
  58. android_env/proto/__init__.py +1 -1
  59. android_env/proto/a11y/__init__.py +15 -0
  60. android_env/proto/a11y/a11y.proto +75 -0
  61. android_env/proto/a11y/a11y_pb2.py +54 -0
  62. android_env/proto/a11y/a11y_pb2.pyi +49 -0
  63. android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
  64. android_env/proto/a11y/android_accessibility_action.proto +32 -0
  65. android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
  66. android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
  67. android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
  68. android_env/proto/a11y/android_accessibility_forest.proto +29 -0
  69. android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
  70. android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
  71. android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
  72. android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
  73. android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
  74. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
  75. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
  76. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
  77. android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
  78. android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
  79. android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
  80. android_env/proto/a11y/android_accessibility_tree.proto +29 -0
  81. android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
  82. android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
  83. android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
  84. android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
  85. android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
  86. android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
  87. android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
  88. android_env/proto/a11y/rect.proto +30 -0
  89. android_env/proto/a11y/rect_pb2.py +37 -0
  90. android_env/proto/a11y/rect_pb2.pyi +17 -0
  91. android_env/proto/a11y/rect_pb2_grpc.py +24 -0
  92. android_env/proto/adb.proto +17 -6
  93. android_env/proto/adb_pb2.py +120 -107
  94. android_env/proto/adb_pb2.pyi +396 -0
  95. android_env/proto/adb_pb2_grpc.py +20 -0
  96. android_env/proto/emulator_controller.proto +68 -63
  97. android_env/proto/emulator_controller_pb2.py +142 -131
  98. android_env/proto/emulator_controller_pb2.pyi +672 -0
  99. android_env/proto/emulator_controller_pb2_grpc.py +505 -142
  100. android_env/proto/snapshot.proto +169 -0
  101. android_env/proto/snapshot_pb2.py +47 -0
  102. android_env/proto/snapshot_pb2.pyi +117 -0
  103. android_env/proto/snapshot_pb2_grpc.py +24 -0
  104. android_env/proto/snapshot_service.proto +289 -0
  105. android_env/proto/snapshot_service_pb2.py +54 -0
  106. android_env/proto/snapshot_service_pb2.pyi +86 -0
  107. android_env/proto/snapshot_service_pb2_grpc.py +487 -0
  108. android_env/proto/state.proto +63 -0
  109. android_env/proto/state_pb2.py +63 -0
  110. android_env/proto/state_pb2.pyi +85 -0
  111. android_env/proto/state_pb2_grpc.py +24 -0
  112. android_env/proto/task.proto +5 -1
  113. android_env/proto/task_pb2.py +42 -31
  114. android_env/proto/task_pb2.pyi +160 -0
  115. android_env/proto/task_pb2_grpc.py +20 -0
  116. android_env/wrappers/__init__.py +1 -1
  117. android_env/wrappers/a11y_grpc_wrapper.py +500 -0
  118. android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
  119. android_env/wrappers/base_wrapper.py +34 -13
  120. android_env/wrappers/base_wrapper_test.py +22 -16
  121. android_env/wrappers/discrete_action_wrapper.py +18 -17
  122. android_env/wrappers/discrete_action_wrapper_test.py +4 -4
  123. android_env/wrappers/flat_interface_wrapper.py +5 -5
  124. android_env/wrappers/flat_interface_wrapper_test.py +7 -11
  125. android_env/wrappers/float_pixels_wrapper.py +9 -10
  126. android_env/wrappers/float_pixels_wrapper_test.py +3 -3
  127. android_env/wrappers/gym_wrapper.py +19 -13
  128. android_env/wrappers/gym_wrapper_test.py +3 -5
  129. android_env/wrappers/image_rescale_wrapper.py +18 -21
  130. android_env/wrappers/image_rescale_wrapper_test.py +25 -37
  131. android_env/wrappers/last_action_wrapper.py +16 -13
  132. android_env/wrappers/last_action_wrapper_test.py +44 -51
  133. android_env/wrappers/rate_limit_wrapper.py +6 -3
  134. android_env/wrappers/rate_limit_wrapper_test.py +22 -1
  135. android_env/wrappers/tap_action_wrapper.py +16 -17
  136. android_env/wrappers/tap_action_wrapper_test.py +51 -16
  137. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
  138. android_env-1.2.3.dist-info/RECORD +141 -0
  139. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
  140. android_env/proto/raw_observation.proto +0 -39
  141. android_env/proto/raw_observation_pb2.py +0 -27
  142. android_env/proto/raw_observation_pb2_grpc.py +0 -4
  143. android_env-1.2.1.dist-info/RECORD +0 -81
  144. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
  145. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,500 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Wraps AndroidEnv to retrieve accessibility messages from gRPC."""
17
+
18
+ from concurrent import futures
19
+ import time
20
+ from typing import Any
21
+
22
+ import urllib
23
+
24
+ from absl import logging
25
+ from android_env import env_interface
26
+ from android_env.components import action_type as android_action_type_lib
27
+ from android_env.components.a11y import a11y_events
28
+ from android_env.components.a11y import a11y_forests
29
+ from android_env.components.a11y import a11y_servicer
30
+ from android_env.proto import adb_pb2
31
+ from android_env.proto.a11y import a11y_pb2_grpc
32
+ from android_env.wrappers import base_wrapper
33
+ import dm_env
34
+ import grpc
35
+ import numpy as np
36
+ import portpicker
37
+
38
+
39
def _get_accessibility_forwarder_apk() -> bytes:
  """Downloads the accessibility forwarder APK and returns its raw bytes.

  Returns:
    The bytes read from the HTTP response for the forwarder APK URL.
  """
  # `import urllib` on its own does not load the `urllib.request` submodule, so
  # `urllib.request.urlopen` would only resolve if some other module happened
  # to import it first. Import the submodule explicitly to make this function
  # self-sufficient.
  import urllib.request  # pylint: disable=g-import-not-at-top

  logging.info('Downloading accessibility forwarder apk....')
  with urllib.request.urlopen(
      'https://storage.googleapis.com/android_env-tasks/2024.05.13-accessibility_forwarder.apk'
  ) as response:
    return response.read()
45
+
46
+
47
class EnableNetworkingError(ValueError):
  """Raised when the attempts to re-enable device networking are exhausted."""
49
+
50
+
51
class A11yGrpcWrapper(base_wrapper.BaseWrapper):
  """Wrapper which receives A11y events and forests over gRPC.

  A11y forest protobufs and event dicts are sent from the Android emulator via
  gRPC from the `AccessibilityForwarder` (for use in developing reward
  functions, etc). This wrapper constructs a server which receives these
  messages and channels them into `task_extras`.

  The downside of forwarding this information through gRPC is that no messages
  will be sent if networking is turned off (e.g., if the AVD is in airplane
  mode). To mitigate this problem, the `AccessibilityForwarder` logs an error
  message if it fails to contact the server. This wrapper monitors the task
  extras for such error messages, and attempts (in another thread, to not block
  environment transitions) to reconnect the AVD to the network. If this keeps
  failing until no attempts are left, an `EnableNetworkingError` is raised.

  This wrapper is implemented to be robust to multiple upstream callers of
  `task_extras`, and to ensure they each receive the same extras at every
  timestep. Thus, the logic is the following:
  * New a11y events/forests are fetched during `reset` and `step`, *not* during
    `task_extras()` calls.
  * If no one has called `task_extras()` since the last `step` or `reset`, the
    extras are accumulated (so that no extras are missed because someone called
    `step()` twice without calling `task_extras()`).
  * If someone *has* called `task_extras()` since last step, the newly fetched
    extras replace the old extras.
  """

  def __init__(
      self,
      env: env_interface.AndroidEnvInterface,
      disable_other_network_traffic: bool = False,
      install_a11y_forwarding: bool = False,
      start_a11y_service: bool = True,
      enable_a11y_tree_info: bool = False,
      add_latest_a11y_info_to_obs: bool = False,
      a11y_info_timeout: float | None = None,
      max_enable_networking_attempts: int = 10,
      latest_a11y_info_only: bool = False,
  ):
    """Initializes wrapper.

    Args:
      env: Environment to wrap.
      disable_other_network_traffic: When True, all network traffic, other than
        the connection to the servicer, is disabled. NOTE: This requires root
        access on the device (i.e. it uses the `su` command). An
        `AdbControllerError` exception will be raised if the underlying command
        fails.
      install_a11y_forwarding: If True, the wrapper handles the installation of
        all packages required for the servicer to collect a11y information.
      start_a11y_service: If True, starts the a11y forwarding services. NOTE:
        The packages must be installed beforehand, e.g., using the
        install_a11y_forwarding flag.
      enable_a11y_tree_info: When False, this wrapper collects only a11y events
        and not a11y tree.
      add_latest_a11y_info_to_obs: When True, the latest observed a11y forest is
        added to the observation.
      a11y_info_timeout: When larger than zero, the wrapper will wait the
        corresponding amount of time, measured in seconds, to collect the
        latest a11y forest. The forest is kept in the observation only if
        add_latest_a11y_info_to_obs is True.
      max_enable_networking_attempts: When the a11y gRPC service fails to
        provide a11y information, we attempt this many times to re-enable the
        networking. If all these attempts fail, fetching task_extras will raise
        an EnableNetworkingError.
      latest_a11y_info_only: When True, the a11y servicer is setup to save only
        the latest tree it has received from the Android app.

    Raises:
      ValueError: If the forwarder APK cannot be installed, if a11y tree
        logging cannot be enabled, or if `a11y_info_timeout` is positive and
        the wrapped environment's action spec has no `action_type` entry.
      RuntimeError: If the accessibility forwarder service cannot be started.
    """
    # `_env` is assigned first because the setup helpers below issue ADB calls
    # through it.
    self._env = env
    if install_a11y_forwarding:
      self._install_a11y_forwarding_apk()
      # Fixed pause after installing; presumably lets the install settle before
      # the service is enabled.
      time.sleep(10.0)
    if start_a11y_service:
      self._start_a11y_services()
      time.sleep(3.0)
    if enable_a11y_tree_info:
      self._enable_a11y_tree_logs()
    self._relaunch_count = 0

    # Set up the gRPC server that receives forests/events from the device.
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    self._servicer = a11y_servicer.A11yServicer(
        latest_forest_only=latest_a11y_info_only
    )
    a11y_pb2_grpc.add_A11yServiceServicer_to_server(
        self._servicer, self._server
    )
    server_credentials = grpc.local_server_credentials()
    self._port = portpicker.pick_unused_port()
    logging.info('Using port %s', self._port)
    uri_address = f'[::]:{self._port}'
    self._server.add_secure_port(uri_address, server_credentials)
    logging.info('Starting server')
    self._server.start()
    logging.info('Server now running.')

    self._max_enable_networking_attempts = max_enable_networking_attempts
    self._reset_enable_networking_attempts()

    self._disable_other_network_traffic = disable_other_network_traffic
    # False until extras have been fetched at least once; set back to False
    # whenever `task_extras()` is called.
    self._should_accumulate = False
    # None until the first `reset()` or `step()` populates it.
    self._accumulated_extras = None
    self._add_latest_a11y_info_to_obs = add_latest_a11y_info_to_obs
    self._a11y_info_timeout = a11y_info_timeout
    self._parent_action_spec = self._env.action_spec()
    if self._a11y_info_timeout is not None and self._a11y_info_timeout > 0.0:
      # Waiting for a11y info issues an extra REPEAT step, which needs the
      # dtype of the parent's `action_type` spec.
      if 'action_type' not in self._parent_action_spec:
        raise ValueError(
            'action_type not in the parent action spec: '
            f'{self._parent_action_spec}. This is a strong requirement when '
            f'a11y_info_timeout = {a11y_info_timeout} > 0'
        )

  def _start_a11y_services(self) -> None:
    """Starts the accessibility forwarder services.

    Raises:
      RuntimeError: If accessibility service is not started.
    """
    start_service_request = adb_pb2.AdbRequest(
        settings=adb_pb2.AdbRequest.SettingsRequest(
            name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE,
            put=adb_pb2.AdbRequest.SettingsRequest.Put(
                key='enabled_accessibility_services',
                value=(
                    'com.google.androidenv.accessibilityforwarder/com.google.'
                    'androidenv.accessibilityforwarder.AccessibilityForwarder'
                ),
            ),
        )
    )
    start_service_response = self._env.execute_adb_call(start_service_request)
    if start_service_response.status != adb_pb2.AdbResponse.Status.OK:
      raise RuntimeError(
          'Could not start accessibility forwarder '
          'service: '
          f'{start_service_response}.'
      )

  def _install_a11y_forwarding_apk(self) -> None:
    """Downloads and installs the accessibility forwarder APK on the device.

    Raises:
      ValueError: If the install ADB call does not return an OK status.
    """
    a11y_fwd_apk = _get_accessibility_forwarder_apk()
    # Install and setup the Accessibility Forwarder.
    install_request = adb_pb2.AdbRequest(
        install_apk=adb_pb2.AdbRequest.InstallApk(
            blob=adb_pb2.AdbRequest.InstallApk.Blob(contents=a11y_fwd_apk),
        )
    )
    install_response = self._env.execute_adb_call(install_request)
    if install_response.status != adb_pb2.AdbResponse.Status.OK:
      raise ValueError(
          f'Could not install accessibility_forwarder.apk: {install_response}.'
      )

  def _enable_a11y_tree_logs(self) -> None:
    """Broadcasts the intent that turns on a11y tree logging in the forwarder.

    Raises:
      ValueError: If the broadcast ADB call does not return an OK status.
    """
    enable_tree_logs_request = adb_pb2.AdbRequest(
        send_broadcast=adb_pb2.AdbRequest.SendBroadcast(
            action=(
                'accessibility_forwarder.intent.action.'
                'ENABLE_ACCESSIBILITY_TREE_LOGS'
            ),
            component=(
                'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver'
            ),
        )
    )
    enable_tree_logs_response = self._env.execute_adb_call(
        enable_tree_logs_request
    )
    if enable_tree_logs_response.status != adb_pb2.AdbResponse.Status.OK:
      raise ValueError(
          'Could not enable accessibility tree logging: '
          f'{enable_tree_logs_response}.'
      )

  def _reset_enable_networking_attempts(self) -> None:
    """Restores the retry budget and forgets any pending networking attempt."""
    self._enable_networking_attempts_left = self._max_enable_networking_attempts
    self._enabling_networking_future = None
    self._a11y_exception = None

  def get_port(self) -> int:
    """Returns the port the gRPC server was bound to."""
    return self._port

  def close(self) -> None:
    """Stops the gRPC server and closes the wrapped environment."""
    self._server.stop(None)
    logging.info('gRPC server stopped')
    self._env.close()

  def attempt_enable_networking(self) -> None:
    """Attempts to turn on networking within the Android device.

    Attempt to turn on the networking in the Android device, by:
      - turning off airplane mode;
      - turning on the wifi connection.

    Each ADB call is followed by a one second pause.
    """
    self.execute_adb_call(
        adb_pb2.AdbRequest(
            settings=adb_pb2.AdbRequest.SettingsRequest(
                name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
                put=adb_pb2.AdbRequest.SettingsRequest.Put(
                    key='airplane_mode_on', value='0'
                ),
            )
        )
    )
    time.sleep(1.0)
    self.execute_adb_call(
        adb_pb2.AdbRequest(
            generic=adb_pb2.AdbRequest.GenericRequest(
                args=[
                    'shell',
                    'svc',
                    'wifi',
                    'enable',
                ]
            )
        )
    )
    time.sleep(1.0)

  def _configure_grpc(self) -> None:
    """Configure networking and set the gRPC port in the AVD."""

    if self._disable_other_network_traffic:
      # Rule order matters: first accept TCP traffic to the servicer's
      # address/port, then append a rule dropping all other outgoing traffic.
      self.execute_adb_call(
          adb_pb2.AdbRequest(
              generic=adb_pb2.AdbRequest.GenericRequest(
                  args=[
                      'shell',
                      'su',
                      '0',
                      'iptables',
                      '-A',
                      'OUTPUT',
                      '-p',
                      'tcp',
                      '-d',
                      '10.0.2.2',
                      '--dport',
                      str(self._port),
                      '-j',
                      'ACCEPT',
                  ]
              )
          )
      )
      time.sleep(3.0)
      self.execute_adb_call(
          adb_pb2.AdbRequest(
              generic=adb_pb2.AdbRequest.GenericRequest(
                  args=[
                      'shell',
                      'su',
                      '0',
                      'iptables',
                      '-A',
                      'OUTPUT',
                      '-j',
                      'DROP',
                  ]
              )
          )
      )
      time.sleep(3.0)

    self.execute_adb_call(
        adb_pb2.AdbRequest(
            settings=adb_pb2.AdbRequest.SettingsRequest(
                name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
                put=adb_pb2.AdbRequest.SettingsRequest.Put(
                    key='no_proxy', value=f'10.0.2.2:{self._port}'
                ),
            )
        )
    )
    self.attempt_enable_networking()
    # Tell the forwarder which port the gRPC server is listening on.
    self.execute_adb_call(
        adb_pb2.AdbRequest(
            send_broadcast=adb_pb2.AdbRequest.SendBroadcast(
                action=(
                    'accessibility_forwarder.intent.action.SET_GRPC --ei'
                    f' "port" {self._port}'
                ),
                component=(
                    'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver'
                ),
            )
        )
    )

  def _accumulate_and_return_a11y_info(
      self, timer: float | None = None, get_env_observation: bool = True
  ) -> dict[str, Any]:
    """Accumulates and returns the latest a11y tree info and observation.

    Args:
      timer: If larger than 0, the system will wait this long for a11y info to
        accumulate before it returns a value.
      get_env_observation: If True, the wrapped environment is stepped once
        with a REPEAT action and its observation is used as the base of the
        returned dict. If False, the returned dict only has the a11y entry.

    Returns:
      a dict with a11y forest under key 'a11y_forest' (None when no forest is
      available). All other fields will provide the observation, if requested.
    """
    timer = timer or 0.0
    if timer > 0.0:
      time.sleep(timer)

    if get_env_observation:
      # Fetch observation.
      new_ts = self._env.step({
          'action_type': np.array(
              android_action_type_lib.ActionType.REPEAT,
              dtype=self._parent_action_spec['action_type'].dtype,
          ),
      })
      observation = new_ts.observation
    else:
      observation = {}

    extras = self.accumulate_new_extras()
    forests = a11y_forests.extract_forests_from_task_extras(extras)
    # Only the most recent forest is exposed in the observation.
    observation['a11y_forest'] = forests[-1] if forests else None
    return observation

  def _fetch_task_extras_and_update_observation(
      self, observation: dict[str, Any], timeout: float = 0.0
  ) -> dict[str, Any]:
    """Fetches new extras and returns the observation to expose upstream.

    Args:
      observation: Observation of the timestep that was just produced.
      timeout: If larger than 0, wait this many seconds and then replace
        `observation` with a freshly fetched one (obtained through an extra
        REPEAT step).

    Returns:
      The observation, containing an 'a11y_forest' entry only if the wrapper
      was configured with `add_latest_a11y_info_to_obs=True`.
    """
    if timeout > 0.0:
      observation = self._accumulate_and_return_a11y_info(
          timeout, get_env_observation=True
      )
      if not self._add_latest_a11y_info_to_obs:
        observation.pop('a11y_forest')
    else:
      # Called unconditionally: extras must be accumulated even when the forest
      # is not added to the observation.
      new_obs = self._accumulate_and_return_a11y_info(get_env_observation=False)
      if self._add_latest_a11y_info_to_obs:
        observation.update(new_obs)
    return observation

  def reset(self) -> dm_env.TimeStep:
    """Resets the wrapped environment and restarts a11y info collection.

    Returns:
      The first timestep, with its observation updated by this wrapper.
    """
    self._reset_enable_networking_attempts()
    # Ignore anything the servicer receives while the environment is resetting.
    self._servicer.pause_and_clear()
    timestep = self._env.reset()
    self._servicer.resume()
    # gRPC is (re)configured on the device each time the wrapped environment
    # reports a higher relaunch count than the one recorded here.
    if self._env.stats()['relaunch_count'] > self._relaunch_count:
      self._configure_grpc()
      self._relaunch_count = self._env.stats()['relaunch_count']
    self._accumulated_extras = {}
    timeout = self._a11y_info_timeout or 0.0
    new_observation = self._fetch_task_extras_and_update_observation(
        timestep.observation, timeout
    )
    timestep = timestep._replace(observation=new_observation)
    return timestep

  def step(self, action: Any) -> dm_env.TimeStep:
    """Steps the wrapped environment and collects new a11y info.

    Args:
      action: Action for the wrapped environment. It may carry an extra
        'wait_time' entry (seconds) which overrides `a11y_info_timeout` for
        this step; that entry is not forwarded to the wrapped environment.

    Returns:
      The resulting timestep, with its observation updated by this wrapper.
    """
    # Work on a shallow copy so that removing 'wait_time' does not mutate the
    # dict owned by the caller, who may reuse it for later steps.
    action = dict(action)
    timeout = float(action.pop('wait_time', self._a11y_info_timeout or 0.0))
    timestep = self._env.step(action)
    new_observation = self._fetch_task_extras_and_update_observation(
        timestep.observation, timeout=timeout
    )
    timestep = timestep._replace(observation=new_observation)
    return timestep

  def accumulate_new_extras(self) -> dict[str, Any]:
    """Fetches new extras and merges them into the accumulated ones.

    If `task_extras()` has not been called since the previous fetch, the new
    values are concatenated (along axis 0) onto the existing ones key by key;
    otherwise the new extras replace the accumulated ones.

    Returns:
      The accumulated extras (the internal dict, not a copy).
    """
    new_extras = self._fetch_task_extras()
    if self._should_accumulate:
      for key, value in new_extras.items():
        if key in self._accumulated_extras:
          self._accumulated_extras[key] = np.concatenate(
              (self._accumulated_extras[key], value), axis=0
          )
        else:
          self._accumulated_extras[key] = value
    else:
      self._accumulated_extras = new_extras
    self._should_accumulate = True
    return self._accumulated_extras

  def _fetch_task_extras(self) -> dict[str, Any]:
    """Fetches task_extras from the services.

    NOTE: If you want to access the latest a11y information, please use
    `_accumulate_and_return_a11y_info` instead. This function has the side
    effect of clearing the content from the servicer, hence all the a11y info
    returned here won't be accumulated.

    Returns:
      A dict with the corresponding task_extras.

    Raises:
      EnableNetworkingError: after a fixed number of attempts to revive the a11y
        services by re-enabling the network connection.
    """
    base_extras = self._env.task_extras(latest_only=False).copy()
    # If the previous future is done, reset it to the initial state.
    if (
        self._enabling_networking_future is not None
        and self._enabling_networking_future.done()
    ):
      self._enabling_networking_future = None
      self._enable_networking_attempts_left -= 1
      logging.info('Finished enabling networking.')

    # Only react to forwarder exceptions when no re-enabling attempt is
    # currently in flight.
    if (
        self._enabling_networking_future is None
        and 'exception' in base_extras
        and base_extras['exception'].shape[0]
    ):
      self._a11y_exception = base_extras['exception']
      logging.warning(
          'AccessibilityForwarder logged exceptions: %s', self._a11y_exception
      )
      if self._enable_networking_attempts_left > 0:
        logging.warning(
            'Attempting to enable networking. %s attempts left.',
            self._enable_networking_attempts_left - 1,
        )
        # Run in a background thread so environment transitions do not block.
        executor = futures.ThreadPoolExecutor(max_workers=1)
        self._enabling_networking_future = executor.submit(
            self.attempt_enable_networking
        )
        # The executor is single-use: stop accepting work right away so its
        # worker is released once the submitted task finishes, without relying
        # on garbage collection. The submitted task still runs.
        executor.shutdown(wait=False)
      else:
        raise EnableNetworkingError(
            'A11y service failed multiple times with'
            f' exception.{self._a11y_exception}.'
        )

    # Receiving any forest or event restores the full retry budget.
    forests = self._servicer.gather_forests()
    if forests:
      base_extras.update(a11y_forests.package_forests_to_task_extras(forests))
      self._reset_enable_networking_attempts()
    events = self._servicer.gather_events()
    if events:
      base_extras.update(a11y_events.package_events_to_task_extras(events))
      self._reset_enable_networking_attempts()
    return base_extras

  def task_extras(self, latest_only: bool = False) -> dict[str, Any]:
    """Returns the extras accumulated during `reset()`/`step()` calls.

    Calling this marks the accumulated extras as consumed: the next fetch
    replaces them instead of appending to them.

    Args:
      latest_only: If True, the returned copy is reduced to the latest a11y
        event and the latest a11y forest.

    Returns:
      A shallow copy of the accumulated extras.

    Raises:
      RuntimeError: If no extras have been collected yet.
    """
    if self._accumulated_extras is None:
      raise RuntimeError('You must call .reset() before calling .task_extras()')
    self._should_accumulate = False
    extras = self._accumulated_extras.copy()
    if latest_only:
      a11y_events.keep_latest_event_only(extras)
      a11y_forests.keep_latest_forest_only(extras)
    return extras