android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- android_env/__init__.py +1 -1
- android_env/components/__init__.py +1 -1
- android_env/components/a11y/__init__.py +15 -0
- android_env/components/a11y/a11y_events.py +118 -0
- android_env/components/a11y/a11y_events_test.py +173 -0
- android_env/components/a11y/a11y_forests.py +128 -0
- android_env/components/a11y/a11y_forests_test.py +237 -0
- android_env/components/a11y/a11y_servicer.py +199 -0
- android_env/components/a11y/a11y_servicer_test.py +224 -0
- android_env/components/action_fns.py +132 -0
- android_env/components/action_fns_test.py +227 -0
- android_env/components/action_type.py +26 -3
- android_env/components/adb_call_parser.py +239 -196
- android_env/components/adb_call_parser_test.py +179 -209
- android_env/components/adb_controller.py +90 -52
- android_env/components/adb_controller_test.py +187 -16
- android_env/components/adb_log_stream.py +17 -5
- android_env/components/adb_log_stream_test.py +17 -3
- android_env/components/app_screen_checker.py +17 -15
- android_env/components/app_screen_checker_test.py +7 -8
- android_env/components/config_classes.py +203 -0
- android_env/components/coordinator.py +102 -338
- android_env/components/coordinator_test.py +59 -199
- android_env/components/device_settings.py +174 -0
- android_env/components/device_settings_test.py +228 -0
- android_env/components/dumpsys_thread.py +3 -4
- android_env/components/dumpsys_thread_test.py +1 -1
- android_env/components/errors.py +52 -10
- android_env/components/errors_test.py +110 -0
- android_env/components/log_stream.py +7 -5
- android_env/components/log_stream_test.py +1 -1
- android_env/components/logcat_thread.py +9 -8
- android_env/components/logcat_thread_test.py +3 -4
- android_env/components/{utils.py → pixel_fns.py} +20 -20
- android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
- android_env/components/setup_step_interpreter.py +47 -39
- android_env/components/setup_step_interpreter_test.py +4 -4
- android_env/components/simulators/__init__.py +1 -1
- android_env/components/simulators/base_simulator.py +116 -44
- android_env/components/simulators/base_simulator_test.py +131 -9
- android_env/components/simulators/emulator/__init__.py +1 -1
- android_env/components/simulators/emulator/emulator_launcher.py +67 -77
- android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
- android_env/components/simulators/emulator/emulator_simulator.py +276 -95
- android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
- android_env/components/simulators/fake/__init__.py +1 -1
- android_env/components/simulators/fake/fake_simulator.py +17 -25
- android_env/components/simulators/fake/fake_simulator_test.py +29 -12
- android_env/components/specs.py +18 -28
- android_env/components/specs_test.py +1 -44
- android_env/components/task_manager.py +48 -48
- android_env/components/task_manager_test.py +71 -60
- android_env/env_interface.py +37 -23
- android_env/environment.py +83 -51
- android_env/environment_test.py +68 -29
- android_env/loader.py +57 -43
- android_env/loader_test.py +115 -35
- android_env/proto/__init__.py +1 -1
- android_env/proto/a11y/__init__.py +15 -0
- android_env/proto/a11y/a11y.proto +75 -0
- android_env/proto/a11y/a11y_pb2.py +54 -0
- android_env/proto/a11y/a11y_pb2.pyi +49 -0
- android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
- android_env/proto/a11y/android_accessibility_action.proto +32 -0
- android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
- android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_forest.proto +29 -0
- android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
- android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_tree.proto +29 -0
- android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
- android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
- android_env/proto/a11y/rect.proto +30 -0
- android_env/proto/a11y/rect_pb2.py +37 -0
- android_env/proto/a11y/rect_pb2.pyi +17 -0
- android_env/proto/a11y/rect_pb2_grpc.py +24 -0
- android_env/proto/adb.proto +17 -6
- android_env/proto/adb_pb2.py +120 -107
- android_env/proto/adb_pb2.pyi +396 -0
- android_env/proto/adb_pb2_grpc.py +20 -0
- android_env/proto/emulator_controller.proto +68 -63
- android_env/proto/emulator_controller_pb2.py +142 -131
- android_env/proto/emulator_controller_pb2.pyi +672 -0
- android_env/proto/emulator_controller_pb2_grpc.py +505 -142
- android_env/proto/snapshot.proto +169 -0
- android_env/proto/snapshot_pb2.py +47 -0
- android_env/proto/snapshot_pb2.pyi +117 -0
- android_env/proto/snapshot_pb2_grpc.py +24 -0
- android_env/proto/snapshot_service.proto +289 -0
- android_env/proto/snapshot_service_pb2.py +54 -0
- android_env/proto/snapshot_service_pb2.pyi +86 -0
- android_env/proto/snapshot_service_pb2_grpc.py +487 -0
- android_env/proto/state.proto +63 -0
- android_env/proto/state_pb2.py +63 -0
- android_env/proto/state_pb2.pyi +85 -0
- android_env/proto/state_pb2_grpc.py +24 -0
- android_env/proto/task.proto +5 -1
- android_env/proto/task_pb2.py +42 -31
- android_env/proto/task_pb2.pyi +160 -0
- android_env/proto/task_pb2_grpc.py +20 -0
- android_env/wrappers/__init__.py +1 -1
- android_env/wrappers/a11y_grpc_wrapper.py +500 -0
- android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
- android_env/wrappers/base_wrapper.py +34 -13
- android_env/wrappers/base_wrapper_test.py +22 -16
- android_env/wrappers/discrete_action_wrapper.py +18 -17
- android_env/wrappers/discrete_action_wrapper_test.py +4 -4
- android_env/wrappers/flat_interface_wrapper.py +5 -5
- android_env/wrappers/flat_interface_wrapper_test.py +7 -11
- android_env/wrappers/float_pixels_wrapper.py +9 -10
- android_env/wrappers/float_pixels_wrapper_test.py +3 -3
- android_env/wrappers/gym_wrapper.py +19 -13
- android_env/wrappers/gym_wrapper_test.py +3 -5
- android_env/wrappers/image_rescale_wrapper.py +18 -21
- android_env/wrappers/image_rescale_wrapper_test.py +25 -37
- android_env/wrappers/last_action_wrapper.py +16 -13
- android_env/wrappers/last_action_wrapper_test.py +44 -51
- android_env/wrappers/rate_limit_wrapper.py +6 -3
- android_env/wrappers/rate_limit_wrapper_test.py +22 -1
- android_env/wrappers/tap_action_wrapper.py +16 -17
- android_env/wrappers/tap_action_wrapper_test.py +51 -16
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
- android_env-1.2.3.dist-info/RECORD +141 -0
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
- android_env/proto/raw_observation.proto +0 -39
- android_env/proto/raw_observation_pb2.py +0 -27
- android_env/proto/raw_observation_pb2_grpc.py +0 -4
- android_env-1.2.1.dist-info/RECORD +0 -81
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -55,33 +55,26 @@ class TaskManagerTest(absltest.TestCase):
|
|
55
55
|
log_stream, 'LogStream',
|
56
56
|
return_value=self._log_stream).start()
|
57
57
|
|
58
|
-
def
|
59
|
-
|
60
|
-
new_task = task_pb2.Task(id='updated_task')
|
61
|
-
task_mgr = task_manager.TaskManager(task=init_task)
|
62
|
-
|
63
|
-
# Setting up the initial task so that the setup_step_interpreter
|
64
|
-
# is properly initialized.
|
58
|
+
def test_start(self):
|
59
|
+
task_mgr = task_manager.TaskManager(task=task_pb2.Task())
|
65
60
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
66
|
-
task_mgr.
|
67
|
-
|
68
|
-
self.
|
69
|
-
self.
|
70
|
-
task_mgr.update_task(new_task)
|
71
|
-
self.assertEqual(new_task, task_mgr.task())
|
72
|
-
self.assertEqual(1, task_mgr.stats()['task_updates'])
|
61
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
62
|
+
self.assertIsNotNone(task_mgr._logcat_thread)
|
63
|
+
self.assertIsNotNone(task_mgr._dumpsys_thread)
|
64
|
+
self.assertIsNotNone(task_mgr._setup_step_interpreter)
|
73
65
|
|
74
66
|
def test_setup_task(self):
|
75
67
|
task_mgr = task_manager.TaskManager(task=task_pb2.Task())
|
76
68
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
77
|
-
task_mgr.
|
78
|
-
|
79
|
-
self.
|
69
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
70
|
+
task_mgr.setup_task()
|
71
|
+
self._setup_step_interpreter.interpret.assert_called_once()
|
80
72
|
|
81
73
|
def test_step_count(self):
|
82
74
|
task_mgr = task_manager.TaskManager(task=task_pb2.Task())
|
83
75
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
84
|
-
task_mgr.
|
76
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
77
|
+
task_mgr.setup_task()
|
85
78
|
task_mgr.rl_reset(observation={})
|
86
79
|
self.assertEqual(task_mgr.stats()['episode_steps'], 0)
|
87
80
|
task_mgr.rl_step(observation={})
|
@@ -109,7 +102,8 @@ class TaskManagerTest(absltest.TestCase):
|
|
109
102
|
task_mgr = task_manager.TaskManager(task=task)
|
110
103
|
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
111
104
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
112
|
-
task_mgr.
|
105
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
106
|
+
task_mgr.setup_task()
|
113
107
|
timestep = task_mgr.rl_step(
|
114
108
|
observation={
|
115
109
|
'pixels': np.array([1, 2, 3]),
|
@@ -144,7 +138,8 @@ class TaskManagerTest(absltest.TestCase):
|
|
144
138
|
task_mgr = task_manager.TaskManager(task=task)
|
145
139
|
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
146
140
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
147
|
-
task_mgr.
|
141
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
142
|
+
task_mgr.setup_task()
|
148
143
|
timestep = task_mgr.rl_step(
|
149
144
|
observation={
|
150
145
|
'pixels': np.array([1, 2, 3]),
|
@@ -174,7 +169,8 @@ class TaskManagerTest(absltest.TestCase):
|
|
174
169
|
task_mgr = task_manager.TaskManager(task=task)
|
175
170
|
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
176
171
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
177
|
-
task_mgr.
|
172
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
173
|
+
task_mgr.setup_task()
|
178
174
|
timestep = task_mgr.rl_step(
|
179
175
|
observation={
|
180
176
|
'pixels': np.array([1, 2, 3]),
|
@@ -207,7 +203,8 @@ class TaskManagerTest(absltest.TestCase):
|
|
207
203
|
task_mgr = task_manager.TaskManager(task=task)
|
208
204
|
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
209
205
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
210
|
-
task_mgr.
|
206
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
207
|
+
task_mgr.setup_task()
|
211
208
|
|
212
209
|
timestep = task_mgr.rl_step(
|
213
210
|
observation={
|
@@ -256,7 +253,8 @@ class TaskManagerTest(absltest.TestCase):
|
|
256
253
|
task_mgr = task_manager.TaskManager(task=task)
|
257
254
|
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
258
255
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
259
|
-
task_mgr.
|
256
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
257
|
+
task_mgr.setup_task()
|
260
258
|
|
261
259
|
timestep = task_mgr.rl_step(
|
262
260
|
observation={
|
@@ -286,6 +284,49 @@ class TaskManagerTest(absltest.TestCase):
|
|
286
284
|
np.testing.assert_equal(
|
287
285
|
expected_extra.get('extra_dict'), extras.get('extra_dict'))
|
288
286
|
|
287
|
+
def test_get_current_extras_failed_to_parse(self):
|
288
|
+
# Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
|
289
|
+
# right away.
|
290
|
+
def my_add_ev_listener(event_listener: logcat_thread.EventListener):
|
291
|
+
# Check that the event matches what's expected.
|
292
|
+
event = event_listener.regexp
|
293
|
+
match = event.match('extra: some_extra [1, 2]')
|
294
|
+
if match is None: # Ignore events that are not extras.
|
295
|
+
return
|
296
|
+
|
297
|
+
# Emit events.
|
298
|
+
fn = event_listener.handler_fn
|
299
|
+
fn(event, event.match('extra: extra_with_malformed_1 [1]'))
|
300
|
+
fn(event, event.match('extra: extra_with_malformed_1 [\'this is \\ bad]'))
|
301
|
+
fn(event, event.match('extra: extra_with_malformed_1 [2]'))
|
302
|
+
fn(event, event.match('extra: extra_with_malformed_2 [\'this is bad]'))
|
303
|
+
fn(event, event.match('extra: extra_with_malformed_2 [1]'))
|
304
|
+
fn(event, event.match('extra: extra_malformed_only [_very_bad_news]'))
|
305
|
+
|
306
|
+
# Setup the task and trigger the listener.
|
307
|
+
task = task_pb2.Task()
|
308
|
+
task.log_parsing_config.log_regexps.extra.extend([
|
309
|
+
'^extra: (?P<name>[^ ]*)[ ]?(?P<extra>.*)$'
|
310
|
+
])
|
311
|
+
task_mgr = task_manager.TaskManager(task=task)
|
312
|
+
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
313
|
+
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
314
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
315
|
+
task_mgr.setup_task()
|
316
|
+
|
317
|
+
timestep = task_mgr.rl_step(
|
318
|
+
observation={
|
319
|
+
'pixels': np.array([1, 2, 3]),
|
320
|
+
})
|
321
|
+
|
322
|
+
# Check expectations.
|
323
|
+
self.assertIn('extras', timestep.observation)
|
324
|
+
extras = timestep.observation['extras']
|
325
|
+
np.testing.assert_almost_equal(extras.get('extra_with_malformed_1'),
|
326
|
+
[[1], [2]])
|
327
|
+
np.testing.assert_almost_equal(extras.get('extra_with_malformed_2'), [[1]])
|
328
|
+
self.assertNotIn('extra_malformed_only', extras)
|
329
|
+
|
289
330
|
def test_multi_log_regexp(self):
|
290
331
|
# Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
|
291
332
|
# right away.
|
@@ -305,7 +346,8 @@ class TaskManagerTest(absltest.TestCase):
|
|
305
346
|
task_mgr = task_manager.TaskManager(task=task)
|
306
347
|
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
307
348
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
308
|
-
task_mgr.
|
349
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
350
|
+
task_mgr.setup_task()
|
309
351
|
timestep = task_mgr.rl_step(
|
310
352
|
observation={
|
311
353
|
'pixels': np.array([1, 2, 3]),
|
@@ -335,7 +377,8 @@ class TaskManagerTest(absltest.TestCase):
|
|
335
377
|
task_mgr = task_manager.TaskManager(task=task)
|
336
378
|
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
337
379
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
338
|
-
task_mgr.
|
380
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
381
|
+
task_mgr.setup_task()
|
339
382
|
timestep = task_mgr.rl_step(
|
340
383
|
observation={
|
341
384
|
'pixels': np.array([1, 2, 3]),
|
@@ -359,7 +402,8 @@ class TaskManagerTest(absltest.TestCase):
|
|
359
402
|
task_mgr = task_manager.TaskManager(task=task)
|
360
403
|
self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
|
361
404
|
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
362
|
-
task_mgr.
|
405
|
+
task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
|
406
|
+
task_mgr.setup_task()
|
363
407
|
timestep = task_mgr.rl_step(
|
364
408
|
observation={
|
365
409
|
'pixels': np.array([1, 2, 3]),
|
@@ -367,38 +411,5 @@ class TaskManagerTest(absltest.TestCase):
|
|
367
411
|
self.assertTrue(timestep.last())
|
368
412
|
|
369
413
|
|
370
|
-
def test_setup_steps_skipped_if_no_change(self):
|
371
|
-
setup_steps = [task_pb2.SetupStep(sleep=task_pb2.Sleep(time_sec=10))]
|
372
|
-
task_mgr = task_manager.TaskManager(
|
373
|
-
task=task_pb2.Task(id='first_task', setup_steps=setup_steps))
|
374
|
-
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
375
|
-
task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
|
376
|
-
self.assertIsNotNone(task_mgr._logcat_thread)
|
377
|
-
self.assertIsNotNone(task_mgr._setup_step_interpreter)
|
378
|
-
self._setup_step_interpreter.interpret.assert_called_once_with(setup_steps)
|
379
|
-
task_mgr.update_task(
|
380
|
-
task=task_pb2.Task(id='second_task', setup_steps=setup_steps))
|
381
|
-
# The interpreter should not be called again.
|
382
|
-
task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
|
383
|
-
self._setup_step_interpreter.interpret.assert_called_once()
|
384
|
-
|
385
|
-
|
386
|
-
def test_setup_steps_interpreted_if_changed(self):
|
387
|
-
setup_steps = [task_pb2.SetupStep(sleep=task_pb2.Sleep(time_sec=10))]
|
388
|
-
task_mgr = task_manager.TaskManager(
|
389
|
-
task=task_pb2.Task(id='first_task', setup_steps=setup_steps))
|
390
|
-
adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
|
391
|
-
task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
|
392
|
-
self.assertIsNotNone(task_mgr._logcat_thread)
|
393
|
-
self.assertIsNotNone(task_mgr._setup_step_interpreter)
|
394
|
-
self._setup_step_interpreter.interpret.assert_called_once_with(setup_steps)
|
395
|
-
new_setup_steps = [task_pb2.SetupStep(sleep=task_pb2.Sleep(time_sec=5))]
|
396
|
-
task_mgr.update_task(
|
397
|
-
task=task_pb2.Task(id='second_task', setup_steps=new_setup_steps))
|
398
|
-
task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
|
399
|
-
self.assertEqual(2, self._setup_step_interpreter.interpret.num_calls())
|
400
|
-
self._setup_step_interpreter.interpret.assert_has_calls(
|
401
|
-
mock.call(setup_steps), mock.call(new_setup_steps))
|
402
|
-
|
403
414
|
if __name__ == '__main__':
|
404
415
|
absltest.main()
|
android_env/env_interface.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -20,10 +20,10 @@ extra methods that clients may use for extended functionality.
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
import abc
|
23
|
-
from typing import Any
|
23
|
+
from typing import Any
|
24
24
|
|
25
25
|
from android_env.proto import adb_pb2
|
26
|
-
from android_env.proto import
|
26
|
+
from android_env.proto import state_pb2
|
27
27
|
import dm_env
|
28
28
|
import numpy as np
|
29
29
|
|
@@ -34,11 +34,11 @@ class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
|
|
34
34
|
# Methods required by dm_env.Environment.
|
35
35
|
|
36
36
|
@abc.abstractmethod
|
37
|
-
def action_spec(self) ->
|
37
|
+
def action_spec(self) -> dict[str, dm_env.specs.Array]:
|
38
38
|
"""Returns the action specification."""
|
39
39
|
|
40
40
|
@abc.abstractmethod
|
41
|
-
def observation_spec(self) ->
|
41
|
+
def observation_spec(self) -> dict[str, dm_env.specs.Array]:
|
42
42
|
"""Returns the observation specification."""
|
43
43
|
|
44
44
|
@abc.abstractmethod
|
@@ -46,7 +46,7 @@ class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
|
|
46
46
|
"""Resets the current episode."""
|
47
47
|
|
48
48
|
@abc.abstractmethod
|
49
|
-
def step(self, action:
|
49
|
+
def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
|
50
50
|
"""Executes `action` and returns a `TimeStep`."""
|
51
51
|
|
52
52
|
@abc.abstractmethod
|
@@ -55,41 +55,55 @@ class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
|
|
55
55
|
|
56
56
|
# Extensions provided by AndroidEnv.
|
57
57
|
|
58
|
-
|
59
|
-
def task_extras_spec(self) -> Dict[str, dm_env.specs.Array]:
|
60
|
-
"""Returns the specification for extra info provided by tasks."""
|
61
|
-
|
62
|
-
@abc.abstractmethod
|
63
|
-
def task_extras(self, latest_only: bool = True) -> Dict[str, np.ndarray]:
|
58
|
+
def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
|
64
59
|
"""Returns extra info provided by tasks."""
|
65
60
|
|
61
|
+
return {}
|
62
|
+
|
66
63
|
@property
|
67
|
-
@abc.abstractmethod
|
68
64
|
def raw_action(self):
|
69
65
|
"""Returns the latest action."""
|
70
66
|
|
71
67
|
@property
|
72
|
-
@abc.abstractmethod
|
73
68
|
def raw_observation(self):
|
74
69
|
"""Returns the latest observation."""
|
75
70
|
|
76
|
-
|
77
|
-
def stats(self) -> Dict[str, Any]:
|
71
|
+
def stats(self) -> dict[str, Any]:
|
78
72
|
"""Returns information generated inside the implementation."""
|
79
73
|
|
80
|
-
|
74
|
+
return {}
|
75
|
+
|
81
76
|
def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
|
82
77
|
"""Executes `call` and returns its response."""
|
83
78
|
|
84
|
-
|
85
|
-
|
86
|
-
|
79
|
+
return adb_pb2.AdbResponse()
|
80
|
+
|
81
|
+
def load_state(
|
82
|
+
self, request: state_pb2.LoadStateRequest
|
83
|
+
) -> state_pb2.LoadStateResponse:
|
84
|
+
"""Loads a state.
|
85
|
+
|
86
|
+
Args:
|
87
|
+
request: A `LoadStateRequest` containing any parameters necessary to
|
88
|
+
specify how/what state to load.
|
89
|
+
|
90
|
+
Returns:
|
91
|
+
A `LoadStateResponse` containing the status, error message (if
|
92
|
+
applicable), and any other relevant information.
|
93
|
+
"""
|
94
|
+
raise NotImplementedError('This environment does not support loading state')
|
87
95
|
|
88
|
-
|
96
|
+
def save_state(
|
97
|
+
self, request: state_pb2.SaveStateRequest
|
98
|
+
) -> state_pb2.SaveStateResponse:
|
99
|
+
"""Saves a state.
|
89
100
|
|
90
101
|
Args:
|
91
|
-
|
102
|
+
request: A `SaveStateRequest` containing any parameters necessary to
|
103
|
+
specify how/what state to save.
|
92
104
|
|
93
105
|
Returns:
|
94
|
-
A
|
106
|
+
A `SaveStateResponse` containing the status, error message (if
|
107
|
+
applicable), and any other relevant information.
|
95
108
|
"""
|
109
|
+
raise NotImplementedError('This environment does not support saving state')
|
android_env/environment.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,12 +15,16 @@
|
|
15
15
|
|
16
16
|
"""Android environment implementation."""
|
17
17
|
|
18
|
-
from typing import Any
|
18
|
+
from typing import Any
|
19
|
+
|
19
20
|
from absl import logging
|
20
21
|
from android_env import env_interface
|
22
|
+
from android_env.components import adb_call_parser
|
21
23
|
from android_env.components import coordinator as coordinator_lib
|
24
|
+
from android_env.components import task_manager as task_manager_lib
|
25
|
+
from android_env.components.simulators import base_simulator
|
22
26
|
from android_env.proto import adb_pb2
|
23
|
-
from android_env.proto import
|
27
|
+
from android_env.proto import state_pb2
|
24
28
|
import dm_env
|
25
29
|
import numpy as np
|
26
30
|
|
@@ -28,10 +32,17 @@ import numpy as np
|
|
28
32
|
class AndroidEnv(env_interface.AndroidEnvInterface):
|
29
33
|
"""An RL environment that interacts with Android apps."""
|
30
34
|
|
31
|
-
def __init__(
|
35
|
+
def __init__(
|
36
|
+
self,
|
37
|
+
simulator: base_simulator.BaseSimulator,
|
38
|
+
coordinator: coordinator_lib.Coordinator,
|
39
|
+
task_manager: task_manager_lib.TaskManager,
|
40
|
+
):
|
32
41
|
"""Initializes the state of this AndroidEnv object."""
|
33
42
|
|
43
|
+
self._simulator = simulator
|
34
44
|
self._coordinator = coordinator
|
45
|
+
self._task_manager = task_manager
|
35
46
|
self._latest_action = {}
|
36
47
|
self._latest_observation = {}
|
37
48
|
self._latest_extras = {}
|
@@ -40,27 +51,17 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
|
|
40
51
|
|
41
52
|
logging.info('Action spec: %s', self.action_spec())
|
42
53
|
logging.info('Observation spec: %s', self.observation_spec())
|
43
|
-
logging.info('Task extras spec: %s', self.task_extras_spec())
|
44
|
-
|
45
|
-
def action_spec(self) -> Dict[str, dm_env.specs.Array]:
|
46
|
-
return self._coordinator.action_spec()
|
47
54
|
|
48
|
-
def
|
49
|
-
|
50
|
-
|
51
|
-
def task_extras_spec(self) -> Dict[str, dm_env.specs.Array]:
|
52
|
-
return self._coordinator.task_extras_spec()
|
55
|
+
def __del__(self) -> None:
|
56
|
+
self.close()
|
53
57
|
|
54
|
-
|
55
|
-
def raw_action(self):
|
56
|
-
return self._latest_action.copy()
|
58
|
+
# Methods required by dm_env.Environment.
|
57
59
|
|
58
|
-
|
59
|
-
|
60
|
-
return self._latest_observation.copy()
|
60
|
+
def action_spec(self) -> dict[str, dm_env.specs.Array]:
|
61
|
+
return self._coordinator.action_spec()
|
61
62
|
|
62
|
-
def
|
63
|
-
return self._coordinator.
|
63
|
+
def observation_spec(self) -> dict[str, dm_env.specs.Array]:
|
64
|
+
return self._coordinator.observation_spec()
|
64
65
|
|
65
66
|
def reset(self) -> dm_env.TimeStep:
|
66
67
|
"""Resets the environment for a new RL episode."""
|
@@ -86,7 +87,7 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
|
|
86
87
|
|
87
88
|
return timestep
|
88
89
|
|
89
|
-
def step(self, action:
|
90
|
+
def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
|
90
91
|
"""Takes a step in the environment."""
|
91
92
|
|
92
93
|
# Check if it's time to reset the episode.
|
@@ -112,47 +113,78 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
|
|
112
113
|
|
113
114
|
return timestep
|
114
115
|
|
115
|
-
def
|
116
|
+
def close(self) -> None:
|
117
|
+
"""Cleans up running processes, threads and local files."""
|
118
|
+
if not self._is_closed:
|
119
|
+
logging.info('Cleaning up AndroidEnv...')
|
120
|
+
if hasattr(self, '_coordinator'):
|
121
|
+
self._coordinator.close()
|
122
|
+
logging.info('Done cleaning up AndroidEnv.')
|
123
|
+
self._is_closed = True
|
124
|
+
|
125
|
+
# Extensions provided by AndroidEnv.
|
126
|
+
|
127
|
+
def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
|
116
128
|
"""Returns latest task extras."""
|
117
129
|
|
118
|
-
task_extras = {}
|
119
|
-
for
|
120
|
-
|
121
|
-
|
122
|
-
for extra in extra_values:
|
123
|
-
self._validate_type(spec, extra)
|
124
|
-
task_extras[key] = extra_values[-1] if latest_only else extra_values
|
130
|
+
task_extras = {} # Build a copy to avoid reusing objects.
|
131
|
+
for k, spec in self._latest_extras.items():
|
132
|
+
extra_values = spec.astype(spec.dtype)
|
133
|
+
task_extras[k] = extra_values[-1] if latest_only else extra_values
|
125
134
|
return task_extras
|
126
135
|
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
136
|
+
@property
|
137
|
+
def raw_action(self):
|
138
|
+
return self._latest_action.copy()
|
139
|
+
|
140
|
+
@property
|
141
|
+
def raw_observation(self):
|
142
|
+
return self._latest_observation.copy()
|
143
|
+
|
144
|
+
def stats(self) -> dict[str, Any]:
|
145
|
+
coordinator_stats = self._coordinator.stats()
|
146
|
+
task_manager_stats = self._task_manager.stats()
|
147
|
+
return coordinator_stats | task_manager_stats
|
133
148
|
|
134
149
|
def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
|
135
150
|
return self._coordinator.execute_adb_call(call)
|
136
151
|
|
137
|
-
def
|
138
|
-
|
152
|
+
def load_state(
|
153
|
+
self, request: state_pb2.LoadStateRequest
|
154
|
+
) -> state_pb2.LoadStateResponse:
|
155
|
+
"""Loads a state.
|
139
156
|
|
140
157
|
Args:
|
141
|
-
|
158
|
+
request: A `LoadStateRequest` containing any parameters necessary to
|
159
|
+
specify how/what state to load.
|
142
160
|
|
143
161
|
Returns:
|
144
|
-
A
|
162
|
+
A `LoadStateResponse` containing the status, error message (if
|
163
|
+
applicable), and any other relevant information.
|
145
164
|
"""
|
146
|
-
return self._coordinator.update_task(task)
|
147
165
|
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
166
|
+
self._task_manager.stop()
|
167
|
+
response = self._simulator.load_state(request)
|
168
|
+
self._task_manager.start(
|
169
|
+
adb_call_parser_factory=lambda: adb_call_parser.AdbCallParser(
|
170
|
+
self._simulator.create_adb_controller()
|
171
|
+
),
|
172
|
+
log_stream=self._simulator.create_log_stream(),
|
173
|
+
)
|
174
|
+
return response
|
175
|
+
|
176
|
+
def save_state(
|
177
|
+
self, request: state_pb2.SaveStateRequest
|
178
|
+
) -> state_pb2.SaveStateResponse:
|
179
|
+
"""Saves a state.
|
156
180
|
|
157
|
-
|
158
|
-
|
181
|
+
Args:
|
182
|
+
request: A `SaveStateRequest` containing any parameters necessary to
|
183
|
+
specify how/what state to save.
|
184
|
+
|
185
|
+
Returns:
|
186
|
+
A `SaveStateResponse` containing the status, error message (if
|
187
|
+
applicable), and any other relevant information.
|
188
|
+
"""
|
189
|
+
|
190
|
+
return self._simulator.save_state(request)
|