android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. android_env/__init__.py +1 -1
  2. android_env/components/__init__.py +1 -1
  3. android_env/components/a11y/__init__.py +15 -0
  4. android_env/components/a11y/a11y_events.py +118 -0
  5. android_env/components/a11y/a11y_events_test.py +173 -0
  6. android_env/components/a11y/a11y_forests.py +128 -0
  7. android_env/components/a11y/a11y_forests_test.py +237 -0
  8. android_env/components/a11y/a11y_servicer.py +199 -0
  9. android_env/components/a11y/a11y_servicer_test.py +224 -0
  10. android_env/components/action_fns.py +132 -0
  11. android_env/components/action_fns_test.py +227 -0
  12. android_env/components/action_type.py +26 -3
  13. android_env/components/adb_call_parser.py +239 -196
  14. android_env/components/adb_call_parser_test.py +179 -209
  15. android_env/components/adb_controller.py +90 -52
  16. android_env/components/adb_controller_test.py +187 -16
  17. android_env/components/adb_log_stream.py +17 -5
  18. android_env/components/adb_log_stream_test.py +17 -3
  19. android_env/components/app_screen_checker.py +17 -15
  20. android_env/components/app_screen_checker_test.py +7 -8
  21. android_env/components/config_classes.py +203 -0
  22. android_env/components/coordinator.py +102 -338
  23. android_env/components/coordinator_test.py +59 -199
  24. android_env/components/device_settings.py +174 -0
  25. android_env/components/device_settings_test.py +228 -0
  26. android_env/components/dumpsys_thread.py +3 -4
  27. android_env/components/dumpsys_thread_test.py +1 -1
  28. android_env/components/errors.py +52 -10
  29. android_env/components/errors_test.py +110 -0
  30. android_env/components/log_stream.py +7 -5
  31. android_env/components/log_stream_test.py +1 -1
  32. android_env/components/logcat_thread.py +9 -8
  33. android_env/components/logcat_thread_test.py +3 -4
  34. android_env/components/{utils.py → pixel_fns.py} +20 -20
  35. android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
  36. android_env/components/setup_step_interpreter.py +47 -39
  37. android_env/components/setup_step_interpreter_test.py +4 -4
  38. android_env/components/simulators/__init__.py +1 -1
  39. android_env/components/simulators/base_simulator.py +116 -44
  40. android_env/components/simulators/base_simulator_test.py +131 -9
  41. android_env/components/simulators/emulator/__init__.py +1 -1
  42. android_env/components/simulators/emulator/emulator_launcher.py +67 -77
  43. android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
  44. android_env/components/simulators/emulator/emulator_simulator.py +276 -95
  45. android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
  46. android_env/components/simulators/fake/__init__.py +1 -1
  47. android_env/components/simulators/fake/fake_simulator.py +17 -25
  48. android_env/components/simulators/fake/fake_simulator_test.py +29 -12
  49. android_env/components/specs.py +18 -28
  50. android_env/components/specs_test.py +1 -44
  51. android_env/components/task_manager.py +48 -48
  52. android_env/components/task_manager_test.py +71 -60
  53. android_env/env_interface.py +37 -23
  54. android_env/environment.py +83 -51
  55. android_env/environment_test.py +68 -29
  56. android_env/loader.py +57 -43
  57. android_env/loader_test.py +115 -35
  58. android_env/proto/__init__.py +1 -1
  59. android_env/proto/a11y/__init__.py +15 -0
  60. android_env/proto/a11y/a11y.proto +75 -0
  61. android_env/proto/a11y/a11y_pb2.py +54 -0
  62. android_env/proto/a11y/a11y_pb2.pyi +49 -0
  63. android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
  64. android_env/proto/a11y/android_accessibility_action.proto +32 -0
  65. android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
  66. android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
  67. android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
  68. android_env/proto/a11y/android_accessibility_forest.proto +29 -0
  69. android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
  70. android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
  71. android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
  72. android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
  73. android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
  74. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
  75. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
  76. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
  77. android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
  78. android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
  79. android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
  80. android_env/proto/a11y/android_accessibility_tree.proto +29 -0
  81. android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
  82. android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
  83. android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
  84. android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
  85. android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
  86. android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
  87. android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
  88. android_env/proto/a11y/rect.proto +30 -0
  89. android_env/proto/a11y/rect_pb2.py +37 -0
  90. android_env/proto/a11y/rect_pb2.pyi +17 -0
  91. android_env/proto/a11y/rect_pb2_grpc.py +24 -0
  92. android_env/proto/adb.proto +17 -6
  93. android_env/proto/adb_pb2.py +120 -107
  94. android_env/proto/adb_pb2.pyi +396 -0
  95. android_env/proto/adb_pb2_grpc.py +20 -0
  96. android_env/proto/emulator_controller.proto +68 -63
  97. android_env/proto/emulator_controller_pb2.py +142 -131
  98. android_env/proto/emulator_controller_pb2.pyi +672 -0
  99. android_env/proto/emulator_controller_pb2_grpc.py +505 -142
  100. android_env/proto/snapshot.proto +169 -0
  101. android_env/proto/snapshot_pb2.py +47 -0
  102. android_env/proto/snapshot_pb2.pyi +117 -0
  103. android_env/proto/snapshot_pb2_grpc.py +24 -0
  104. android_env/proto/snapshot_service.proto +289 -0
  105. android_env/proto/snapshot_service_pb2.py +54 -0
  106. android_env/proto/snapshot_service_pb2.pyi +86 -0
  107. android_env/proto/snapshot_service_pb2_grpc.py +487 -0
  108. android_env/proto/state.proto +63 -0
  109. android_env/proto/state_pb2.py +63 -0
  110. android_env/proto/state_pb2.pyi +85 -0
  111. android_env/proto/state_pb2_grpc.py +24 -0
  112. android_env/proto/task.proto +5 -1
  113. android_env/proto/task_pb2.py +42 -31
  114. android_env/proto/task_pb2.pyi +160 -0
  115. android_env/proto/task_pb2_grpc.py +20 -0
  116. android_env/wrappers/__init__.py +1 -1
  117. android_env/wrappers/a11y_grpc_wrapper.py +500 -0
  118. android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
  119. android_env/wrappers/base_wrapper.py +34 -13
  120. android_env/wrappers/base_wrapper_test.py +22 -16
  121. android_env/wrappers/discrete_action_wrapper.py +18 -17
  122. android_env/wrappers/discrete_action_wrapper_test.py +4 -4
  123. android_env/wrappers/flat_interface_wrapper.py +5 -5
  124. android_env/wrappers/flat_interface_wrapper_test.py +7 -11
  125. android_env/wrappers/float_pixels_wrapper.py +9 -10
  126. android_env/wrappers/float_pixels_wrapper_test.py +3 -3
  127. android_env/wrappers/gym_wrapper.py +19 -13
  128. android_env/wrappers/gym_wrapper_test.py +3 -5
  129. android_env/wrappers/image_rescale_wrapper.py +18 -21
  130. android_env/wrappers/image_rescale_wrapper_test.py +25 -37
  131. android_env/wrappers/last_action_wrapper.py +16 -13
  132. android_env/wrappers/last_action_wrapper_test.py +44 -51
  133. android_env/wrappers/rate_limit_wrapper.py +6 -3
  134. android_env/wrappers/rate_limit_wrapper_test.py +22 -1
  135. android_env/wrappers/tap_action_wrapper.py +16 -17
  136. android_env/wrappers/tap_action_wrapper_test.py +51 -16
  137. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
  138. android_env-1.2.3.dist-info/RECORD +141 -0
  139. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
  140. android_env/proto/raw_observation.proto +0 -39
  141. android_env/proto/raw_observation_pb2.py +0 -27
  142. android_env/proto/raw_observation_pb2_grpc.py +0 -4
  143. android_env-1.2.1.dist-info/RECORD +0 -81
  144. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
  145. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -55,33 +55,26 @@ class TaskManagerTest(absltest.TestCase):
55
55
  log_stream, 'LogStream',
56
56
  return_value=self._log_stream).start()
57
57
 
58
- def test_update_task(self):
59
- init_task = task_pb2.Task(id='initial_task')
60
- new_task = task_pb2.Task(id='updated_task')
61
- task_mgr = task_manager.TaskManager(task=init_task)
62
-
63
- # Setting up the initial task so that the setup_step_interpreter
64
- # is properly initialized.
58
+ def test_start(self):
59
+ task_mgr = task_manager.TaskManager(task=task_pb2.Task())
65
60
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
66
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
67
-
68
- self.assertEqual(init_task, task_mgr.task())
69
- self.assertEqual(0, task_mgr.stats()['task_updates'])
70
- task_mgr.update_task(new_task)
71
- self.assertEqual(new_task, task_mgr.task())
72
- self.assertEqual(1, task_mgr.stats()['task_updates'])
61
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
62
+ self.assertIsNotNone(task_mgr._logcat_thread)
63
+ self.assertIsNotNone(task_mgr._dumpsys_thread)
64
+ self.assertIsNotNone(task_mgr._setup_step_interpreter)
73
65
 
74
66
  def test_setup_task(self):
75
67
  task_mgr = task_manager.TaskManager(task=task_pb2.Task())
76
68
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
77
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
78
- self.assertIsNotNone(task_mgr._logcat_thread)
79
- self.assertIsNotNone(task_mgr._setup_step_interpreter)
69
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
70
+ task_mgr.setup_task()
71
+ self._setup_step_interpreter.interpret.assert_called_once()
80
72
 
81
73
  def test_step_count(self):
82
74
  task_mgr = task_manager.TaskManager(task=task_pb2.Task())
83
75
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
84
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
76
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
77
+ task_mgr.setup_task()
85
78
  task_mgr.rl_reset(observation={})
86
79
  self.assertEqual(task_mgr.stats()['episode_steps'], 0)
87
80
  task_mgr.rl_step(observation={})
@@ -109,7 +102,8 @@ class TaskManagerTest(absltest.TestCase):
109
102
  task_mgr = task_manager.TaskManager(task=task)
110
103
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
111
104
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
112
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
105
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
106
+ task_mgr.setup_task()
113
107
  timestep = task_mgr.rl_step(
114
108
  observation={
115
109
  'pixels': np.array([1, 2, 3]),
@@ -144,7 +138,8 @@ class TaskManagerTest(absltest.TestCase):
144
138
  task_mgr = task_manager.TaskManager(task=task)
145
139
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
146
140
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
147
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
141
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
142
+ task_mgr.setup_task()
148
143
  timestep = task_mgr.rl_step(
149
144
  observation={
150
145
  'pixels': np.array([1, 2, 3]),
@@ -174,7 +169,8 @@ class TaskManagerTest(absltest.TestCase):
174
169
  task_mgr = task_manager.TaskManager(task=task)
175
170
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
176
171
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
177
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
172
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
173
+ task_mgr.setup_task()
178
174
  timestep = task_mgr.rl_step(
179
175
  observation={
180
176
  'pixels': np.array([1, 2, 3]),
@@ -207,7 +203,8 @@ class TaskManagerTest(absltest.TestCase):
207
203
  task_mgr = task_manager.TaskManager(task=task)
208
204
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
209
205
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
210
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
206
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
207
+ task_mgr.setup_task()
211
208
 
212
209
  timestep = task_mgr.rl_step(
213
210
  observation={
@@ -256,7 +253,8 @@ class TaskManagerTest(absltest.TestCase):
256
253
  task_mgr = task_manager.TaskManager(task=task)
257
254
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
258
255
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
259
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
256
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
257
+ task_mgr.setup_task()
260
258
 
261
259
  timestep = task_mgr.rl_step(
262
260
  observation={
@@ -286,6 +284,49 @@ class TaskManagerTest(absltest.TestCase):
286
284
  np.testing.assert_equal(
287
285
  expected_extra.get('extra_dict'), extras.get('extra_dict'))
288
286
 
287
+ def test_get_current_extras_failed_to_parse(self):
288
+ # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
289
+ # right away.
290
+ def my_add_ev_listener(event_listener: logcat_thread.EventListener):
291
+ # Check that the event matches what's expected.
292
+ event = event_listener.regexp
293
+ match = event.match('extra: some_extra [1, 2]')
294
+ if match is None: # Ignore events that are not extras.
295
+ return
296
+
297
+ # Emit events.
298
+ fn = event_listener.handler_fn
299
+ fn(event, event.match('extra: extra_with_malformed_1 [1]'))
300
+ fn(event, event.match('extra: extra_with_malformed_1 [\'this is \\ bad]'))
301
+ fn(event, event.match('extra: extra_with_malformed_1 [2]'))
302
+ fn(event, event.match('extra: extra_with_malformed_2 [\'this is bad]'))
303
+ fn(event, event.match('extra: extra_with_malformed_2 [1]'))
304
+ fn(event, event.match('extra: extra_malformed_only [_very_bad_news]'))
305
+
306
+ # Setup the task and trigger the listener.
307
+ task = task_pb2.Task()
308
+ task.log_parsing_config.log_regexps.extra.extend([
309
+ '^extra: (?P<name>[^ ]*)[ ]?(?P<extra>.*)$'
310
+ ])
311
+ task_mgr = task_manager.TaskManager(task=task)
312
+ self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
313
+ adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
314
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
315
+ task_mgr.setup_task()
316
+
317
+ timestep = task_mgr.rl_step(
318
+ observation={
319
+ 'pixels': np.array([1, 2, 3]),
320
+ })
321
+
322
+ # Check expectations.
323
+ self.assertIn('extras', timestep.observation)
324
+ extras = timestep.observation['extras']
325
+ np.testing.assert_almost_equal(extras.get('extra_with_malformed_1'),
326
+ [[1], [2]])
327
+ np.testing.assert_almost_equal(extras.get('extra_with_malformed_2'), [[1]])
328
+ self.assertNotIn('extra_malformed_only', extras)
329
+
289
330
  def test_multi_log_regexp(self):
290
331
  # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
291
332
  # right away.
@@ -305,7 +346,8 @@ class TaskManagerTest(absltest.TestCase):
305
346
  task_mgr = task_manager.TaskManager(task=task)
306
347
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
307
348
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
308
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
349
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
350
+ task_mgr.setup_task()
309
351
  timestep = task_mgr.rl_step(
310
352
  observation={
311
353
  'pixels': np.array([1, 2, 3]),
@@ -335,7 +377,8 @@ class TaskManagerTest(absltest.TestCase):
335
377
  task_mgr = task_manager.TaskManager(task=task)
336
378
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
337
379
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
338
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
380
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
381
+ task_mgr.setup_task()
339
382
  timestep = task_mgr.rl_step(
340
383
  observation={
341
384
  'pixels': np.array([1, 2, 3]),
@@ -359,7 +402,8 @@ class TaskManagerTest(absltest.TestCase):
359
402
  task_mgr = task_manager.TaskManager(task=task)
360
403
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
361
404
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
362
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
405
+ task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
406
+ task_mgr.setup_task()
363
407
  timestep = task_mgr.rl_step(
364
408
  observation={
365
409
  'pixels': np.array([1, 2, 3]),
@@ -367,38 +411,5 @@ class TaskManagerTest(absltest.TestCase):
367
411
  self.assertTrue(timestep.last())
368
412
 
369
413
 
370
- def test_setup_steps_skipped_if_no_change(self):
371
- setup_steps = [task_pb2.SetupStep(sleep=task_pb2.Sleep(time_sec=10))]
372
- task_mgr = task_manager.TaskManager(
373
- task=task_pb2.Task(id='first_task', setup_steps=setup_steps))
374
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
375
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
376
- self.assertIsNotNone(task_mgr._logcat_thread)
377
- self.assertIsNotNone(task_mgr._setup_step_interpreter)
378
- self._setup_step_interpreter.interpret.assert_called_once_with(setup_steps)
379
- task_mgr.update_task(
380
- task=task_pb2.Task(id='second_task', setup_steps=setup_steps))
381
- # The interpreter should not be called again.
382
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
383
- self._setup_step_interpreter.interpret.assert_called_once()
384
-
385
-
386
- def test_setup_steps_interpreted_if_changed(self):
387
- setup_steps = [task_pb2.SetupStep(sleep=task_pb2.Sleep(time_sec=10))]
388
- task_mgr = task_manager.TaskManager(
389
- task=task_pb2.Task(id='first_task', setup_steps=setup_steps))
390
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
391
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
392
- self.assertIsNotNone(task_mgr._logcat_thread)
393
- self.assertIsNotNone(task_mgr._setup_step_interpreter)
394
- self._setup_step_interpreter.interpret.assert_called_once_with(setup_steps)
395
- new_setup_steps = [task_pb2.SetupStep(sleep=task_pb2.Sleep(time_sec=5))]
396
- task_mgr.update_task(
397
- task=task_pb2.Task(id='second_task', setup_steps=new_setup_steps))
398
- task_mgr.setup_task(lambda: adb_call_parser, log_stream=self._log_stream)
399
- self.assertEqual(2, self._setup_step_interpreter.interpret.num_calls())
400
- self._setup_step_interpreter.interpret.assert_has_calls(
401
- mock.call(setup_steps), mock.call(new_setup_steps))
402
-
403
414
  if __name__ == '__main__':
404
415
  absltest.main()
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -20,10 +20,10 @@ extra methods that clients may use for extended functionality.
20
20
  """
21
21
 
22
22
  import abc
23
- from typing import Any, Dict
23
+ from typing import Any
24
24
 
25
25
  from android_env.proto import adb_pb2
26
- from android_env.proto import task_pb2
26
+ from android_env.proto import state_pb2
27
27
  import dm_env
28
28
  import numpy as np
29
29
 
@@ -34,11 +34,11 @@ class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
34
34
  # Methods required by dm_env.Environment.
35
35
 
36
36
  @abc.abstractmethod
37
- def action_spec(self) -> Dict[str, dm_env.specs.Array]:
37
+ def action_spec(self) -> dict[str, dm_env.specs.Array]:
38
38
  """Returns the action specification."""
39
39
 
40
40
  @abc.abstractmethod
41
- def observation_spec(self) -> Dict[str, dm_env.specs.Array]:
41
+ def observation_spec(self) -> dict[str, dm_env.specs.Array]:
42
42
  """Returns the observation specification."""
43
43
 
44
44
  @abc.abstractmethod
@@ -46,7 +46,7 @@ class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
46
46
  """Resets the current episode."""
47
47
 
48
48
  @abc.abstractmethod
49
- def step(self, action: Dict[str, np.ndarray]) -> dm_env.TimeStep:
49
+ def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
50
50
  """Executes `action` and returns a `TimeStep`."""
51
51
 
52
52
  @abc.abstractmethod
@@ -55,41 +55,55 @@ class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
55
55
 
56
56
  # Extensions provided by AndroidEnv.
57
57
 
58
- @abc.abstractmethod
59
- def task_extras_spec(self) -> Dict[str, dm_env.specs.Array]:
60
- """Returns the specification for extra info provided by tasks."""
61
-
62
- @abc.abstractmethod
63
- def task_extras(self, latest_only: bool = True) -> Dict[str, np.ndarray]:
58
+ def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
64
59
  """Returns extra info provided by tasks."""
65
60
 
61
+ return {}
62
+
66
63
  @property
67
- @abc.abstractmethod
68
64
  def raw_action(self):
69
65
  """Returns the latest action."""
70
66
 
71
67
  @property
72
- @abc.abstractmethod
73
68
  def raw_observation(self):
74
69
  """Returns the latest observation."""
75
70
 
76
- @abc.abstractmethod
77
- def stats(self) -> Dict[str, Any]:
71
+ def stats(self) -> dict[str, Any]:
78
72
  """Returns information generated inside the implementation."""
79
73
 
80
- @abc.abstractmethod
74
+ return {}
75
+
81
76
  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
82
77
  """Executes `call` and returns its response."""
83
78
 
84
- @abc.abstractmethod
85
- def update_task(self, task: task_pb2.Task) -> bool:
86
- """Replaces the current task with a new task.
79
+ return adb_pb2.AdbResponse()
80
+
81
+ def load_state(
82
+ self, request: state_pb2.LoadStateRequest
83
+ ) -> state_pb2.LoadStateResponse:
84
+ """Loads a state.
85
+
86
+ Args:
87
+ request: A `LoadStateRequest` containing any parameters necessary to
88
+ specify how/what state to load.
89
+
90
+ Returns:
91
+ A `LoadStateResponse` containing the status, error message (if
92
+ applicable), and any other relevant information.
93
+ """
94
+ raise NotImplementedError('This environment does not support loading state')
87
95
 
88
- It is the caller's responsibility to call `reset()` after the task update.
96
+ def save_state(
97
+ self, request: state_pb2.SaveStateRequest
98
+ ) -> state_pb2.SaveStateResponse:
99
+ """Saves a state.
89
100
 
90
101
  Args:
91
- task: A new task to replace the current one.
102
+ request: A `SaveStateRequest` containing any parameters necessary to
103
+ specify how/what state to save.
92
104
 
93
105
  Returns:
94
- A bool indicating the success of the task setup.
106
+ A `SaveStateResponse` containing the status, error message (if
107
+ applicable), and any other relevant information.
95
108
  """
109
+ raise NotImplementedError('This environment does not support saving state')
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,12 +15,16 @@
15
15
 
16
16
  """Android environment implementation."""
17
17
 
18
- from typing import Any, Dict
18
+ from typing import Any
19
+
19
20
  from absl import logging
20
21
  from android_env import env_interface
22
+ from android_env.components import adb_call_parser
21
23
  from android_env.components import coordinator as coordinator_lib
24
+ from android_env.components import task_manager as task_manager_lib
25
+ from android_env.components.simulators import base_simulator
22
26
  from android_env.proto import adb_pb2
23
- from android_env.proto import task_pb2
27
+ from android_env.proto import state_pb2
24
28
  import dm_env
25
29
  import numpy as np
26
30
 
@@ -28,10 +32,17 @@ import numpy as np
28
32
  class AndroidEnv(env_interface.AndroidEnvInterface):
29
33
  """An RL environment that interacts with Android apps."""
30
34
 
31
- def __init__(self, coordinator: coordinator_lib.Coordinator):
35
+ def __init__(
36
+ self,
37
+ simulator: base_simulator.BaseSimulator,
38
+ coordinator: coordinator_lib.Coordinator,
39
+ task_manager: task_manager_lib.TaskManager,
40
+ ):
32
41
  """Initializes the state of this AndroidEnv object."""
33
42
 
43
+ self._simulator = simulator
34
44
  self._coordinator = coordinator
45
+ self._task_manager = task_manager
35
46
  self._latest_action = {}
36
47
  self._latest_observation = {}
37
48
  self._latest_extras = {}
@@ -40,27 +51,17 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
40
51
 
41
52
  logging.info('Action spec: %s', self.action_spec())
42
53
  logging.info('Observation spec: %s', self.observation_spec())
43
- logging.info('Task extras spec: %s', self.task_extras_spec())
44
-
45
- def action_spec(self) -> Dict[str, dm_env.specs.Array]:
46
- return self._coordinator.action_spec()
47
54
 
48
- def observation_spec(self) -> Dict[str, dm_env.specs.Array]:
49
- return self._coordinator.observation_spec()
50
-
51
- def task_extras_spec(self) -> Dict[str, dm_env.specs.Array]:
52
- return self._coordinator.task_extras_spec()
55
+ def __del__(self) -> None:
56
+ self.close()
53
57
 
54
- @property
55
- def raw_action(self):
56
- return self._latest_action.copy()
58
+ # Methods required by dm_env.Environment.
57
59
 
58
- @property
59
- def raw_observation(self):
60
- return self._latest_observation.copy()
60
+ def action_spec(self) -> dict[str, dm_env.specs.Array]:
61
+ return self._coordinator.action_spec()
61
62
 
62
- def stats(self) -> Dict[str, Any]:
63
- return self._coordinator.stats()
63
+ def observation_spec(self) -> dict[str, dm_env.specs.Array]:
64
+ return self._coordinator.observation_spec()
64
65
 
65
66
  def reset(self) -> dm_env.TimeStep:
66
67
  """Resets the environment for a new RL episode."""
@@ -86,7 +87,7 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
86
87
 
87
88
  return timestep
88
89
 
89
- def step(self, action: Dict[str, np.ndarray]) -> dm_env.TimeStep:
90
+ def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
90
91
  """Takes a step in the environment."""
91
92
 
92
93
  # Check if it's time to reset the episode.
@@ -112,47 +113,78 @@ class AndroidEnv(env_interface.AndroidEnvInterface):
112
113
 
113
114
  return timestep
114
115
 
115
- def task_extras(self, latest_only: bool = True) -> Dict[str, np.ndarray]:
116
+ def close(self) -> None:
117
+ """Cleans up running processes, threads and local files."""
118
+ if not self._is_closed:
119
+ logging.info('Cleaning up AndroidEnv...')
120
+ if hasattr(self, '_coordinator'):
121
+ self._coordinator.close()
122
+ logging.info('Done cleaning up AndroidEnv.')
123
+ self._is_closed = True
124
+
125
+ # Extensions provided by AndroidEnv.
126
+
127
+ def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
116
128
  """Returns latest task extras."""
117
129
 
118
- task_extras = {}
119
- for key, spec in self.task_extras_spec().items():
120
- if key in self._latest_extras:
121
- extra_values = self._latest_extras[key].astype(spec.dtype)
122
- for extra in extra_values:
123
- self._validate_type(spec, extra)
124
- task_extras[key] = extra_values[-1] if latest_only else extra_values
130
+ task_extras = {} # Build a copy to avoid reusing objects.
131
+ for k, spec in self._latest_extras.items():
132
+ extra_values = spec.astype(spec.dtype)
133
+ task_extras[k] = extra_values[-1] if latest_only else extra_values
125
134
  return task_extras
126
135
 
127
- def _validate_type(self, spec: dm_env.specs.Array, data: np.ndarray):
128
- # Handling of spec data type of string of arbitary length.
129
- if spec.dtype.str == '<U0' and data.dtype.str.startswith(
130
- '<U') and spec.shape == data.shape:
131
- return
132
- spec.validate(data)
136
+ @property
137
+ def raw_action(self):
138
+ return self._latest_action.copy()
139
+
140
+ @property
141
+ def raw_observation(self):
142
+ return self._latest_observation.copy()
143
+
144
+ def stats(self) -> dict[str, Any]:
145
+ coordinator_stats = self._coordinator.stats()
146
+ task_manager_stats = self._task_manager.stats()
147
+ return coordinator_stats | task_manager_stats
133
148
 
134
149
  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
135
150
  return self._coordinator.execute_adb_call(call)
136
151
 
137
- def update_task(self, task: task_pb2.Task) -> bool:
138
- """Replaces the current task with a new task.
152
+ def load_state(
153
+ self, request: state_pb2.LoadStateRequest
154
+ ) -> state_pb2.LoadStateResponse:
155
+ """Loads a state.
139
156
 
140
157
  Args:
141
- task: A new task to replace the current one.
158
+ request: A `LoadStateRequest` containing any parameters necessary to
159
+ specify how/what state to load.
142
160
 
143
161
  Returns:
144
- A bool indicating the success of the task setup.
162
+ A `LoadStateResponse` containing the status, error message (if
163
+ applicable), and any other relevant information.
145
164
  """
146
- return self._coordinator.update_task(task)
147
165
 
148
- def close(self) -> None:
149
- """Cleans up running processes, threads and local files."""
150
- if not self._is_closed:
151
- logging.info('Cleaning up AndroidEnv...')
152
- if hasattr(self, '_coordinator'):
153
- self._coordinator.close()
154
- logging.info('Done cleaning up AndroidEnv.')
155
- self._is_closed = True
166
+ self._task_manager.stop()
167
+ response = self._simulator.load_state(request)
168
+ self._task_manager.start(
169
+ adb_call_parser_factory=lambda: adb_call_parser.AdbCallParser(
170
+ self._simulator.create_adb_controller()
171
+ ),
172
+ log_stream=self._simulator.create_log_stream(),
173
+ )
174
+ return response
175
+
176
+ def save_state(
177
+ self, request: state_pb2.SaveStateRequest
178
+ ) -> state_pb2.SaveStateResponse:
179
+ """Saves a state.
156
180
 
157
- def __del__(self) -> None:
158
- self.close()
181
+ Args:
182
+ request: A `SaveStateRequest` containing any parameters necessary to
183
+ specify how/what state to save.
184
+
185
+ Returns:
186
+ A `SaveStateResponse` containing the status, error message (if
187
+ applicable), and any other relevant information.
188
+ """
189
+
190
+ return self._simulator.save_state(request)