android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. android_env/__init__.py +1 -1
  2. android_env/components/__init__.py +1 -1
  3. android_env/components/a11y/__init__.py +15 -0
  4. android_env/components/a11y/a11y_events.py +118 -0
  5. android_env/components/a11y/a11y_events_test.py +173 -0
  6. android_env/components/a11y/a11y_forests.py +128 -0
  7. android_env/components/a11y/a11y_forests_test.py +237 -0
  8. android_env/components/a11y/a11y_servicer.py +199 -0
  9. android_env/components/a11y/a11y_servicer_test.py +224 -0
  10. android_env/components/action_fns.py +132 -0
  11. android_env/components/action_fns_test.py +227 -0
  12. android_env/components/action_type.py +26 -3
  13. android_env/components/adb_call_parser.py +239 -196
  14. android_env/components/adb_call_parser_test.py +179 -209
  15. android_env/components/adb_controller.py +90 -52
  16. android_env/components/adb_controller_test.py +187 -16
  17. android_env/components/adb_log_stream.py +17 -5
  18. android_env/components/adb_log_stream_test.py +17 -3
  19. android_env/components/app_screen_checker.py +17 -15
  20. android_env/components/app_screen_checker_test.py +7 -8
  21. android_env/components/config_classes.py +203 -0
  22. android_env/components/coordinator.py +102 -338
  23. android_env/components/coordinator_test.py +59 -199
  24. android_env/components/device_settings.py +174 -0
  25. android_env/components/device_settings_test.py +228 -0
  26. android_env/components/dumpsys_thread.py +3 -4
  27. android_env/components/dumpsys_thread_test.py +1 -1
  28. android_env/components/errors.py +52 -10
  29. android_env/components/errors_test.py +110 -0
  30. android_env/components/log_stream.py +7 -5
  31. android_env/components/log_stream_test.py +1 -1
  32. android_env/components/logcat_thread.py +9 -8
  33. android_env/components/logcat_thread_test.py +3 -4
  34. android_env/components/{utils.py → pixel_fns.py} +20 -20
  35. android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
  36. android_env/components/setup_step_interpreter.py +47 -39
  37. android_env/components/setup_step_interpreter_test.py +4 -4
  38. android_env/components/simulators/__init__.py +1 -1
  39. android_env/components/simulators/base_simulator.py +116 -44
  40. android_env/components/simulators/base_simulator_test.py +131 -9
  41. android_env/components/simulators/emulator/__init__.py +1 -1
  42. android_env/components/simulators/emulator/emulator_launcher.py +67 -77
  43. android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
  44. android_env/components/simulators/emulator/emulator_simulator.py +276 -95
  45. android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
  46. android_env/components/simulators/fake/__init__.py +1 -1
  47. android_env/components/simulators/fake/fake_simulator.py +17 -25
  48. android_env/components/simulators/fake/fake_simulator_test.py +29 -12
  49. android_env/components/specs.py +18 -28
  50. android_env/components/specs_test.py +1 -44
  51. android_env/components/task_manager.py +48 -48
  52. android_env/components/task_manager_test.py +71 -60
  53. android_env/env_interface.py +37 -23
  54. android_env/environment.py +83 -51
  55. android_env/environment_test.py +68 -29
  56. android_env/loader.py +57 -43
  57. android_env/loader_test.py +115 -35
  58. android_env/proto/__init__.py +1 -1
  59. android_env/proto/a11y/__init__.py +15 -0
  60. android_env/proto/a11y/a11y.proto +75 -0
  61. android_env/proto/a11y/a11y_pb2.py +54 -0
  62. android_env/proto/a11y/a11y_pb2.pyi +49 -0
  63. android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
  64. android_env/proto/a11y/android_accessibility_action.proto +32 -0
  65. android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
  66. android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
  67. android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
  68. android_env/proto/a11y/android_accessibility_forest.proto +29 -0
  69. android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
  70. android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
  71. android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
  72. android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
  73. android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
  74. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
  75. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
  76. android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
  77. android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
  78. android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
  79. android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
  80. android_env/proto/a11y/android_accessibility_tree.proto +29 -0
  81. android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
  82. android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
  83. android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
  84. android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
  85. android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
  86. android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
  87. android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
  88. android_env/proto/a11y/rect.proto +30 -0
  89. android_env/proto/a11y/rect_pb2.py +37 -0
  90. android_env/proto/a11y/rect_pb2.pyi +17 -0
  91. android_env/proto/a11y/rect_pb2_grpc.py +24 -0
  92. android_env/proto/adb.proto +17 -6
  93. android_env/proto/adb_pb2.py +120 -107
  94. android_env/proto/adb_pb2.pyi +396 -0
  95. android_env/proto/adb_pb2_grpc.py +20 -0
  96. android_env/proto/emulator_controller.proto +68 -63
  97. android_env/proto/emulator_controller_pb2.py +142 -131
  98. android_env/proto/emulator_controller_pb2.pyi +672 -0
  99. android_env/proto/emulator_controller_pb2_grpc.py +505 -142
  100. android_env/proto/snapshot.proto +169 -0
  101. android_env/proto/snapshot_pb2.py +47 -0
  102. android_env/proto/snapshot_pb2.pyi +117 -0
  103. android_env/proto/snapshot_pb2_grpc.py +24 -0
  104. android_env/proto/snapshot_service.proto +289 -0
  105. android_env/proto/snapshot_service_pb2.py +54 -0
  106. android_env/proto/snapshot_service_pb2.pyi +86 -0
  107. android_env/proto/snapshot_service_pb2_grpc.py +487 -0
  108. android_env/proto/state.proto +63 -0
  109. android_env/proto/state_pb2.py +63 -0
  110. android_env/proto/state_pb2.pyi +85 -0
  111. android_env/proto/state_pb2_grpc.py +24 -0
  112. android_env/proto/task.proto +5 -1
  113. android_env/proto/task_pb2.py +42 -31
  114. android_env/proto/task_pb2.pyi +160 -0
  115. android_env/proto/task_pb2_grpc.py +20 -0
  116. android_env/wrappers/__init__.py +1 -1
  117. android_env/wrappers/a11y_grpc_wrapper.py +500 -0
  118. android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
  119. android_env/wrappers/base_wrapper.py +34 -13
  120. android_env/wrappers/base_wrapper_test.py +22 -16
  121. android_env/wrappers/discrete_action_wrapper.py +18 -17
  122. android_env/wrappers/discrete_action_wrapper_test.py +4 -4
  123. android_env/wrappers/flat_interface_wrapper.py +5 -5
  124. android_env/wrappers/flat_interface_wrapper_test.py +7 -11
  125. android_env/wrappers/float_pixels_wrapper.py +9 -10
  126. android_env/wrappers/float_pixels_wrapper_test.py +3 -3
  127. android_env/wrappers/gym_wrapper.py +19 -13
  128. android_env/wrappers/gym_wrapper_test.py +3 -5
  129. android_env/wrappers/image_rescale_wrapper.py +18 -21
  130. android_env/wrappers/image_rescale_wrapper_test.py +25 -37
  131. android_env/wrappers/last_action_wrapper.py +16 -13
  132. android_env/wrappers/last_action_wrapper_test.py +44 -51
  133. android_env/wrappers/rate_limit_wrapper.py +6 -3
  134. android_env/wrappers/rate_limit_wrapper_test.py +22 -1
  135. android_env/wrappers/tap_action_wrapper.py +16 -17
  136. android_env/wrappers/tap_action_wrapper_test.py +51 -16
  137. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
  138. android_env-1.2.3.dist-info/RECORD +141 -0
  139. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
  140. android_env/proto/raw_observation.proto +0 -39
  141. android_env/proto/raw_observation_pb2.py +0 -27
  142. android_env/proto/raw_observation_pb2_grpc.py +0 -4
  143. android_env-1.2.1.dist-info/RECORD +0 -81
  144. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
  145. {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
17
17
 
18
18
  import re
19
19
  from absl.testing import absltest
20
+ from android_env.components import config_classes
20
21
  from android_env.components.simulators.fake import fake_simulator
21
22
  import numpy as np
22
23
 
@@ -24,18 +25,24 @@ import numpy as np
24
25
  class FakeSimulatorTest(absltest.TestCase):
25
26
 
26
27
  def test_device_name(self):
27
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(320, 480))
28
+ simulator = fake_simulator.FakeSimulator(
29
+ config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
30
+ )
28
31
  self.assertEqual(simulator.adb_device_name(), 'fake_simulator')
29
32
 
30
33
  def test_launch_close(self):
31
34
  # The simulator should launch and not crash.
32
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(320, 480))
35
+ simulator = fake_simulator.FakeSimulator(
36
+ config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
37
+ )
33
38
  simulator.launch()
34
39
  # Closing the simulator should also not crash.
35
40
  simulator.close()
36
41
 
37
42
  def test_get_screenshot(self):
38
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(320, 480))
43
+ simulator = fake_simulator.FakeSimulator(
44
+ config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
45
+ )
39
46
  simulator.launch()
40
47
 
41
48
  screenshot = simulator.get_screenshot()
@@ -43,7 +50,9 @@ class FakeSimulatorTest(absltest.TestCase):
43
50
  np.testing.assert_equal(screenshot.dtype, np.uint8)
44
51
 
45
52
  def test_log_stream(self):
46
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(320, 480))
53
+ simulator = fake_simulator.FakeSimulator(
54
+ config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
55
+ )
47
56
  simulator.launch()
48
57
  log_stream = simulator.create_log_stream()
49
58
  # Start yielding lines from LogStream.
@@ -61,12 +70,16 @@ class FakeSimulatorTest(absltest.TestCase):
61
70
  break
62
71
 
63
72
  def test_adb_output(self):
64
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(320, 480))
73
+ simulator = fake_simulator.FakeSimulator(
74
+ config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
75
+ )
65
76
  simulator.launch()
66
77
  adb_controller = simulator.create_adb_controller()
67
78
  line = adb_controller.execute_command(['shell', 'dumpsys', 'input'])
68
79
  line = line.decode('utf-8')
69
- orientation = re.match(r'\s+SurfaceOrientation:\s+(\d)', line).group(1)
80
+ matches = re.match(r'\s+SurfaceOrientation:\s+(\d)', line)
81
+ self.assertIsNotNone(matches)
82
+ orientation = matches.group(1)
70
83
  self.assertEqual(orientation, '0')
71
84
  line = adb_controller.execute_command(['shell', 'service', 'check', 'foo'])
72
85
  line = line.decode('utf-8')
@@ -77,7 +90,9 @@ class FakeSimulatorTest(absltest.TestCase):
77
90
  'topActivity=ComponentInfo{fake_activity}')
78
91
 
79
92
  def test_send_touch(self):
80
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(320, 480))
93
+ simulator = fake_simulator.FakeSimulator(
94
+ config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
95
+ )
81
96
  simulator.launch()
82
97
  simulator.send_touch([(0, 1, True, 0)])
83
98
  simulator.send_touch([(0, 1, False, 0)])
@@ -85,11 +100,13 @@ class FakeSimulatorTest(absltest.TestCase):
85
100
  # without crashing anything.
86
101
 
87
102
  def test_send_key(self):
88
- simulator = fake_simulator.FakeSimulator(screen_dimensions=(320, 480))
103
+ simulator = fake_simulator.FakeSimulator(
104
+ config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
105
+ )
89
106
  simulator.launch()
90
- simulator.send_key(123, 'keydown')
91
- simulator.send_key(123, 'keyup')
92
- simulator.send_key(123, 'keypress')
107
+ simulator.send_key(np.int32(123), 'keydown')
108
+ simulator.send_key(np.int32(123), 'keyup')
109
+ simulator.send_key(np.int32(123), 'keypress')
93
110
  # No assertions, we just want to ensure that `send_key()` can be called
94
111
  # without crashing anything.
95
112
 
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -15,11 +15,8 @@
15
15
 
16
16
  """Base specs for AndroidEnv."""
17
17
 
18
- from typing import Dict
19
-
20
18
  from android_env.components import action_type
21
19
  from android_env.proto import task_pb2
22
- import dm_env
23
20
  from dm_env import specs
24
21
  import numpy as np
25
22
 
@@ -41,11 +38,13 @@ _PROTO_DTYPE_TO_NUMPY_DTYPE = {
41
38
  task_pb2.ArraySpec.DataType.STRING_U25: np.dtype(('<U25')),
42
39
  task_pb2.ArraySpec.DataType.STRING_U250: np.dtype(('<U250')),
43
40
  task_pb2.ArraySpec.DataType.STRING: np.dtype(('<U0')),
41
+ task_pb2.ArraySpec.DataType.OBJECT: np.dtype('O'),
44
42
  }
45
43
 
46
44
 
47
- def base_action_spec(num_fingers: int = 1,
48
- enable_key_events: bool = False) -> Dict[str, specs.Array]:
45
+ def base_action_spec(
46
+ num_fingers: int = 1, enable_key_events: bool = False
47
+ ) -> dict[str, specs.Array]:
49
48
  """Default action spec for AndroidEnv.
50
49
 
51
50
  Args:
@@ -58,6 +57,8 @@ def base_action_spec(num_fingers: int = 1,
58
57
  touch_position: Position [x, y] of the touch action, where x, y are float
59
58
  values between 0.0 and 1.0 corresponding to the relative position on the
60
59
  screen. IGNORED when (action_type != ActionType.TOUCH).
60
+ keycode: code representing the desired key press in XKB format. See the
61
+ emulator_controller_pb2 for details.
61
62
  action_type_i: Action type for additional fingers (i>1).
62
63
  touch_position_i: Touch position for additional fingers (i>1).
63
64
  """
@@ -98,7 +99,7 @@ def base_action_spec(num_fingers: int = 1,
98
99
  return action_spec
99
100
 
100
101
 
101
- def base_observation_spec(height: int, width: int) -> Dict[str, specs.Array]:
102
+ def base_observation_spec(height: int, width: int) -> dict[str, specs.Array]:
102
103
  """Default observation spec for AndroidEnv.
103
104
 
104
105
  Args:
@@ -119,30 +120,19 @@ def base_observation_spec(height: int, width: int) -> Dict[str, specs.Array]:
119
120
 
120
121
  return {
121
122
  'pixels':
122
- specs.Array(
123
+ specs.BoundedArray(
123
124
  shape=(height, width, 3),
124
125
  dtype=np.uint8,
125
- name='pixels'),
126
+ name='pixels',
127
+ minimum=0,
128
+ maximum=255),
126
129
  'timedelta':
127
130
  specs.Array(shape=(), dtype=np.int64, name='timedelta'),
128
131
  'orientation':
129
- specs.Array(shape=np.array([4]), dtype=np.uint8, name='orientation'),
130
- }
131
-
132
-
133
- def base_task_extras_spec(task: task_pb2.Task) -> Dict[str, dm_env.specs.Array]:
134
- """Task extras spec for AndroidEnv, as read from a task_pb2.Task."""
135
-
136
- return {
137
- spec.name: _convert_spec(spec)
138
- for spec in task.extras_spec
132
+ specs.BoundedArray(
133
+ shape=np.array([4]),
134
+ dtype=np.uint8,
135
+ name='orientation',
136
+ minimum=0,
137
+ maximum=1),
139
138
  }
140
-
141
-
142
- def _convert_spec(array_spec: task_pb2.ArraySpec) -> specs.Array:
143
- """Converts ArraySpec proto to dm_env specs.Array."""
144
-
145
- return specs.Array(
146
- shape=array_spec.shape,
147
- dtype=_PROTO_DTYPE_TO_NUMPY_DTYPE[array_spec.dtype],
148
- name=array_spec.name)
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -79,49 +79,6 @@ class SpecsTest(parameterized.TestCase):
79
79
  self.assertEqual(observation_spec['orientation'].shape, (4,))
80
80
  self.assertEqual(observation_spec['orientation'].dtype, np.uint8)
81
81
 
82
- def test_base_task_extras_spec(self):
83
- array_spec_1 = task_pb2.ArraySpec()
84
- array_spec_1.name = 'my_extra_1'
85
- array_spec_1.shape.extend([10, 10])
86
- array_spec_1.dtype = task_pb2.ArraySpec.FLOAT
87
-
88
- array_spec_2 = task_pb2.ArraySpec()
89
- array_spec_2.name = 'my_extra_2'
90
- array_spec_2.shape.extend([1])
91
- array_spec_2.dtype = task_pb2.ArraySpec.UINT8
92
-
93
- fake_task = task_pb2.Task()
94
- fake_task.extras_spec.extend([array_spec_1, array_spec_2])
95
- task_extras_spec = specs.base_task_extras_spec(fake_task)
96
- for spec in task_extras_spec.values():
97
- self.assertIsInstance(spec, dm_env_specs.Array)
98
-
99
- self.assertEqual(task_extras_spec['my_extra_1'].shape, (10, 10))
100
- self.assertEqual(task_extras_spec['my_extra_2'].shape, (1,))
101
- self.assertEqual(task_extras_spec['my_extra_1'].dtype, np.float32)
102
- self.assertEqual(task_extras_spec['my_extra_2'].dtype, np.uint8)
103
-
104
- @parameterized.parameters(
105
- ('name_1', [480, 320, 3], task_pb2.ArraySpec.FLOAT, np.float32),
106
- ('name_2', [100, 100, 3], task_pb2.ArraySpec.INT32, np.int32),
107
- ('name_3', [123, 456, 3], task_pb2.ArraySpec.UINT8, np.uint8),
108
- ('name_4', [480, 320, 1], task_pb2.ArraySpec.BOOL, np.bool_),
109
- ('', [480, 320], task_pb2.ArraySpec.STRING_U25, np.dtype(('<U25'))),
110
- )
111
- def test_convert_spec(self, name, shape, dtype, expected_dtype):
112
- fake_array_spec = task_pb2.ArraySpec()
113
- fake_array_spec.name = name
114
- fake_array_spec.dtype = dtype
115
- fake_array_spec.shape.extend(shape)
116
- fake_task = task_pb2.Task()
117
- fake_task.extras_spec.extend([fake_array_spec])
118
- task_extras_spec = specs.base_task_extras_spec(fake_task)
119
- for spec in task_extras_spec.values():
120
- self.assertIsInstance(spec, dm_env_specs.Array)
121
-
122
- self.assertEqual(task_extras_spec[name].shape, tuple(shape))
123
- self.assertEqual(task_extras_spec[name].dtype, expected_dtype)
124
-
125
82
 
126
83
  if __name__ == '__main__':
127
84
  absltest.main()
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2022 DeepMind Technologies Limited.
2
+ # Copyright 2024 DeepMind Technologies Limited.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -16,16 +16,18 @@
16
16
  """TaskManager handles all events and information related to the task."""
17
17
 
18
18
  import ast
19
+ from collections.abc import Callable
19
20
  import copy
20
21
  import datetime
21
22
  import json
22
23
  import re
23
24
  import threading
24
- from typing import Any, Callable, Dict
25
+ from typing import Any
25
26
 
26
27
  from absl import logging
27
28
  from android_env.components import adb_call_parser as adb_call_parser_lib
28
29
  from android_env.components import app_screen_checker
30
+ from android_env.components import config_classes
29
31
  from android_env.components import dumpsys_thread
30
32
  from android_env.components import log_stream as log_stream_lib
31
33
  from android_env.components import logcat_thread
@@ -35,34 +37,24 @@ import dm_env
35
37
  import numpy as np
36
38
 
37
39
 
38
- class TaskManager():
40
+ class TaskManager:
39
41
  """Handles all events and information related to the task."""
40
42
 
41
43
  def __init__(
42
44
  self,
43
45
  task: task_pb2.Task,
44
- max_bad_states: int = 3,
45
- dumpsys_check_frequency: int = 150,
46
- max_failed_current_activity: int = 10,
46
+ config: config_classes.TaskManagerConfig | None = None,
47
47
  ):
48
48
  """Controls task-relevant events and information.
49
49
 
50
50
  Args:
51
51
  task: A task proto defining the RL task.
52
- max_bad_states: How many bad states in a row are allowed before a restart
53
- of the simulator is triggered.
54
- dumpsys_check_frequency: Frequency, in steps, at which to check
55
- current_activity and view hierarchy
56
- max_failed_current_activity: The maximum number of tries for extracting
57
- the current activity before forcing the episode to restart.
52
+ config: Configuration for this instance.
58
53
  """
59
- self._task = task
60
- self._max_bad_states = max_bad_states
61
- self._dumpsys_check_frequency = dumpsys_check_frequency
62
- self._max_failed_current_activity = max_failed_current_activity
63
54
 
55
+ self._task = task
56
+ self._config = config or config_classes.TaskManagerConfig()
64
57
  self._lock = threading.Lock()
65
- self._extras_max_buffer_size = 100
66
58
  self._logcat_thread = None
67
59
  self._dumpsys_thread = None
68
60
  self._setup_step_interpreter = None
@@ -92,14 +84,7 @@ class TaskManager():
92
84
 
93
85
  logging.info('Task config: %s', self._task)
94
86
 
95
- def task(self) -> task_pb2.Task:
96
- return self._task
97
-
98
- def update_task(self, task: task_pb2.Task) -> None:
99
- self._stats['task_updates'] += 1
100
- self._task = task
101
-
102
- def stats(self) -> Dict[str, Any]:
87
+ def stats(self) -> dict[str, Any]:
103
88
  """Returns a dictionary of stats.
104
89
 
105
90
  This method is expected to be called after setup_task() has been called.
@@ -109,21 +94,24 @@ class TaskManager():
109
94
  output.update(self._setup_step_interpreter.stats())
110
95
  return output
111
96
 
112
- def setup_task(
97
+ def setup_task(self) -> None:
98
+ """Performs one-off task setup.."""
99
+ self._setup_step_interpreter.interpret(self._task.setup_steps)
100
+
101
+ def stop(self) -> None:
102
+ """Suspends task processing."""
103
+ self._stop_logcat_thread()
104
+
105
+ def start(
113
106
  self,
114
107
  adb_call_parser_factory: Callable[[], adb_call_parser_lib.AdbCallParser],
115
108
  log_stream: log_stream_lib.LogStream) -> None:
116
- """Performs one-off task setup.."""
109
+ """Starts task processing."""
117
110
 
118
111
  self._start_logcat_thread(log_stream=log_stream)
112
+ self._logcat_thread.resume()
119
113
  self._start_dumpsys_thread(adb_call_parser_factory())
120
114
  self._start_setup_step_interpreter(adb_call_parser_factory())
121
- self._setup_step_interpreter.interpret(self._task.setup_steps)
122
-
123
- def stop_task(self) -> None:
124
- """Suspends task processing."""
125
-
126
- self._stop_logcat_thread()
127
115
 
128
116
  def reset_task(self) -> None:
129
117
  """Resets a task for a new run."""
@@ -146,7 +134,7 @@ class TaskManager():
146
134
  'episode_end': False,
147
135
  }
148
136
 
149
- def rl_reset(self, observation: Dict[str, Any]) -> dm_env.TimeStep:
137
+ def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:
150
138
  """Performs one RL step."""
151
139
 
152
140
  self._stats['episode_steps'] = 0
@@ -163,7 +151,7 @@ class TaskManager():
163
151
  discount=0.0,
164
152
  observation=observation)
165
153
 
166
- def rl_step(self, observation: Dict[str, Any]) -> dm_env.TimeStep:
154
+ def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:
167
155
  """Performs one RL step."""
168
156
 
169
157
  self._stats['episode_steps'] += 1
@@ -184,7 +172,7 @@ class TaskManager():
184
172
  self._latest_values['reward'] = 0.0
185
173
  return reward
186
174
 
187
- def _get_current_extras(self) -> Dict[str, Any]:
175
+ def _get_current_extras(self) -> dict[str, Any]:
188
176
  """Returns task extras accumulated since the last step."""
189
177
  extras = {}
190
178
  for name, values in self._latest_values['extra'].items():
@@ -213,8 +201,8 @@ class TaskManager():
213
201
  if self._task.max_episode_steps > 0:
214
202
  if self._stats['episode_steps'] > self._task.max_episode_steps:
215
203
  self._stats['reset_count_max_duration_reached'] += 1
216
- logging.info('Maximum task duration (steps) reached. '
217
- 'Truncating the episode.')
204
+ logging.info('Maximum task duration (%r steps) reached. '
205
+ 'Truncating the episode.', self._task.max_episode_steps)
218
206
  return dm_env.truncation
219
207
 
220
208
  if self._task.max_episode_sec > 0.0:
@@ -222,8 +210,8 @@ class TaskManager():
222
210
  max_episode_sec = self._task.max_episode_sec
223
211
  if task_duration > datetime.timedelta(seconds=int(max_episode_sec)):
224
212
  self._stats['reset_count_max_duration_reached'] += 1
225
- logging.info('Maximum task duration (sec) reached. '
226
- 'Truncating the episode.')
213
+ logging.info('Maximum task duration (%r sec) reached. '
214
+ 'Truncating the episode.', max_episode_sec)
227
215
  return dm_env.truncation
228
216
 
229
217
  return dm_env.transition
@@ -245,9 +233,11 @@ class TaskManager():
245
233
  self._dumpsys_thread = dumpsys_thread.DumpsysThread(
246
234
  app_screen_checker=app_screen_checker.AppScreenChecker(
247
235
  adb_call_parser=adb_call_parser,
248
- expected_app_screen=self._task.expected_app_screen),
249
- check_frequency=self._dumpsys_check_frequency,
250
- max_failed_current_activity=self._max_failed_current_activity)
236
+ expected_app_screen=self._task.expected_app_screen,
237
+ ),
238
+ check_frequency=self._config.dumpsys_check_frequency,
239
+ max_failed_current_activity=self._config.max_failed_current_activity,
240
+ )
251
241
 
252
242
  def _stop_logcat_thread(self):
253
243
  if self._logcat_thread is not None:
@@ -263,11 +253,11 @@ class TaskManager():
263
253
  to a good state.
264
254
  """
265
255
  logging.warning('Bad state detected.')
266
- if self._max_bad_states:
256
+ if self._config.max_bad_states:
267
257
  self._is_bad_episode = True
268
258
  self._bad_state_counter += 1
269
259
  logging.warning('Bad state counter: %d.', self._bad_state_counter)
270
- if self._bad_state_counter >= self._max_bad_states:
260
+ if self._bad_state_counter >= self._config.max_bad_states:
271
261
  logging.error('Too many consecutive bad states. Restarting simulator.')
272
262
  self._stats['restart_count_max_bad_states'] += 1
273
263
  self._should_restart = True
@@ -339,9 +329,16 @@ class TaskManager():
339
329
  if extra:
340
330
  try:
341
331
  extra = ast.literal_eval(extra)
342
- # Except all to avoid unnecessary crashes, only log error.
343
- except Exception: # pylint: disable=broad-except
332
+ except (
333
+ ValueError,
334
+ TypeError,
335
+ SyntaxError,
336
+ MemoryError,
337
+ RecursionError,
338
+ ):
344
339
  logging.exception('Could not parse extra: %s', extra)
340
+ # Don't try to process the extra as text; that would probably crash.
341
+ return
345
342
  else:
346
343
  # No extra value provided for boolean extra. Setting value to True.
347
344
  extra = 1
@@ -375,7 +372,10 @@ class TaskManager():
375
372
  latest_extras = self._latest_values['extra']
376
373
  if extra_name in latest_extras:
377
374
  # If latest extra is not flushed, append.
378
- if len(latest_extras[extra_name]) >= self._extras_max_buffer_size:
375
+ if (
376
+ len(latest_extras[extra_name])
377
+ >= self._config.extras_max_buffer_size
378
+ ):
379
379
  latest_extras[extra_name].pop(0)
380
380
  latest_extras[extra_name].append(extra)
381
381
  else: