android-env 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- android_env/__init__.py +1 -1
- android_env/components/__init__.py +1 -1
- android_env/components/a11y/__init__.py +15 -0
- android_env/components/a11y/a11y_events.py +118 -0
- android_env/components/a11y/a11y_events_test.py +173 -0
- android_env/components/a11y/a11y_forests.py +128 -0
- android_env/components/a11y/a11y_forests_test.py +237 -0
- android_env/components/a11y/a11y_servicer.py +199 -0
- android_env/components/a11y/a11y_servicer_test.py +224 -0
- android_env/components/action_fns.py +132 -0
- android_env/components/action_fns_test.py +227 -0
- android_env/components/action_type.py +26 -3
- android_env/components/adb_call_parser.py +239 -196
- android_env/components/adb_call_parser_test.py +179 -209
- android_env/components/adb_controller.py +90 -52
- android_env/components/adb_controller_test.py +187 -16
- android_env/components/adb_log_stream.py +17 -5
- android_env/components/adb_log_stream_test.py +17 -3
- android_env/components/app_screen_checker.py +17 -15
- android_env/components/app_screen_checker_test.py +7 -8
- android_env/components/config_classes.py +203 -0
- android_env/components/coordinator.py +102 -338
- android_env/components/coordinator_test.py +59 -199
- android_env/components/device_settings.py +174 -0
- android_env/components/device_settings_test.py +228 -0
- android_env/components/dumpsys_thread.py +3 -4
- android_env/components/dumpsys_thread_test.py +1 -1
- android_env/components/errors.py +52 -10
- android_env/components/errors_test.py +110 -0
- android_env/components/log_stream.py +7 -5
- android_env/components/log_stream_test.py +1 -1
- android_env/components/logcat_thread.py +9 -8
- android_env/components/logcat_thread_test.py +3 -4
- android_env/components/{utils.py → pixel_fns.py} +20 -20
- android_env/components/{utils_test.py → pixel_fns_test.py} +20 -15
- android_env/components/setup_step_interpreter.py +47 -39
- android_env/components/setup_step_interpreter_test.py +4 -4
- android_env/components/simulators/__init__.py +1 -1
- android_env/components/simulators/base_simulator.py +116 -44
- android_env/components/simulators/base_simulator_test.py +131 -9
- android_env/components/simulators/emulator/__init__.py +1 -1
- android_env/components/simulators/emulator/emulator_launcher.py +67 -77
- android_env/components/simulators/emulator/emulator_launcher_test.py +153 -49
- android_env/components/simulators/emulator/emulator_simulator.py +276 -95
- android_env/components/simulators/emulator/emulator_simulator_test.py +314 -89
- android_env/components/simulators/fake/__init__.py +1 -1
- android_env/components/simulators/fake/fake_simulator.py +17 -25
- android_env/components/simulators/fake/fake_simulator_test.py +29 -12
- android_env/components/specs.py +18 -28
- android_env/components/specs_test.py +1 -44
- android_env/components/task_manager.py +48 -48
- android_env/components/task_manager_test.py +71 -60
- android_env/env_interface.py +37 -23
- android_env/environment.py +83 -51
- android_env/environment_test.py +68 -29
- android_env/loader.py +57 -43
- android_env/loader_test.py +115 -35
- android_env/proto/__init__.py +1 -1
- android_env/proto/a11y/__init__.py +15 -0
- android_env/proto/a11y/a11y.proto +75 -0
- android_env/proto/a11y/a11y_pb2.py +54 -0
- android_env/proto/a11y/a11y_pb2.pyi +49 -0
- android_env/proto/a11y/a11y_pb2_grpc.py +202 -0
- android_env/proto/a11y/android_accessibility_action.proto +32 -0
- android_env/proto/a11y/android_accessibility_action_pb2.py +37 -0
- android_env/proto/a11y/android_accessibility_action_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_action_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_forest.proto +29 -0
- android_env/proto/a11y/android_accessibility_forest_pb2.py +38 -0
- android_env/proto/a11y/android_accessibility_forest_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_forest_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_node_info.proto +122 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +49 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.py +39 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2.pyi +28 -0
- android_env/proto/a11y/android_accessibility_node_info_clickable_span_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2.py +42 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2.pyi +75 -0
- android_env/proto/a11y/android_accessibility_node_info_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_tree.proto +29 -0
- android_env/proto/a11y/android_accessibility_tree_pb2.py +38 -0
- android_env/proto/a11y/android_accessibility_tree_pb2.pyi +13 -0
- android_env/proto/a11y/android_accessibility_tree_pb2_grpc.py +24 -0
- android_env/proto/a11y/android_accessibility_window_info.proto +84 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2.py +41 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2.pyi +48 -0
- android_env/proto/a11y/android_accessibility_window_info_pb2_grpc.py +24 -0
- android_env/proto/a11y/rect.proto +30 -0
- android_env/proto/a11y/rect_pb2.py +37 -0
- android_env/proto/a11y/rect_pb2.pyi +17 -0
- android_env/proto/a11y/rect_pb2_grpc.py +24 -0
- android_env/proto/adb.proto +17 -6
- android_env/proto/adb_pb2.py +120 -107
- android_env/proto/adb_pb2.pyi +396 -0
- android_env/proto/adb_pb2_grpc.py +20 -0
- android_env/proto/emulator_controller.proto +68 -63
- android_env/proto/emulator_controller_pb2.py +142 -131
- android_env/proto/emulator_controller_pb2.pyi +672 -0
- android_env/proto/emulator_controller_pb2_grpc.py +505 -142
- android_env/proto/snapshot.proto +169 -0
- android_env/proto/snapshot_pb2.py +47 -0
- android_env/proto/snapshot_pb2.pyi +117 -0
- android_env/proto/snapshot_pb2_grpc.py +24 -0
- android_env/proto/snapshot_service.proto +289 -0
- android_env/proto/snapshot_service_pb2.py +54 -0
- android_env/proto/snapshot_service_pb2.pyi +86 -0
- android_env/proto/snapshot_service_pb2_grpc.py +487 -0
- android_env/proto/state.proto +63 -0
- android_env/proto/state_pb2.py +63 -0
- android_env/proto/state_pb2.pyi +85 -0
- android_env/proto/state_pb2_grpc.py +24 -0
- android_env/proto/task.proto +5 -1
- android_env/proto/task_pb2.py +42 -31
- android_env/proto/task_pb2.pyi +160 -0
- android_env/proto/task_pb2_grpc.py +20 -0
- android_env/wrappers/__init__.py +1 -1
- android_env/wrappers/a11y_grpc_wrapper.py +500 -0
- android_env/wrappers/a11y_grpc_wrapper_test.py +849 -0
- android_env/wrappers/base_wrapper.py +34 -13
- android_env/wrappers/base_wrapper_test.py +22 -16
- android_env/wrappers/discrete_action_wrapper.py +18 -17
- android_env/wrappers/discrete_action_wrapper_test.py +4 -4
- android_env/wrappers/flat_interface_wrapper.py +5 -5
- android_env/wrappers/flat_interface_wrapper_test.py +7 -11
- android_env/wrappers/float_pixels_wrapper.py +9 -10
- android_env/wrappers/float_pixels_wrapper_test.py +3 -3
- android_env/wrappers/gym_wrapper.py +19 -13
- android_env/wrappers/gym_wrapper_test.py +3 -5
- android_env/wrappers/image_rescale_wrapper.py +18 -21
- android_env/wrappers/image_rescale_wrapper_test.py +25 -37
- android_env/wrappers/last_action_wrapper.py +16 -13
- android_env/wrappers/last_action_wrapper_test.py +44 -51
- android_env/wrappers/rate_limit_wrapper.py +6 -3
- android_env/wrappers/rate_limit_wrapper_test.py +22 -1
- android_env/wrappers/tap_action_wrapper.py +16 -17
- android_env/wrappers/tap_action_wrapper_test.py +51 -16
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/METADATA +14 -18
- android_env-1.2.3.dist-info/RECORD +141 -0
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/WHEEL +1 -1
- android_env/proto/raw_observation.proto +0 -39
- android_env/proto/raw_observation_pb2.py +0 -27
- android_env/proto/raw_observation_pb2_grpc.py +0 -4
- android_env-1.2.1.dist-info/RECORD +0 -81
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info/licenses}/LICENSE +0 -0
- {android_env-1.2.1.dist-info → android_env-1.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -17,6 +17,7 @@
|
|
17
17
|
|
18
18
|
import re
|
19
19
|
from absl.testing import absltest
|
20
|
+
from android_env.components import config_classes
|
20
21
|
from android_env.components.simulators.fake import fake_simulator
|
21
22
|
import numpy as np
|
22
23
|
|
@@ -24,18 +25,24 @@ import numpy as np
|
|
24
25
|
class FakeSimulatorTest(absltest.TestCase):
|
25
26
|
|
26
27
|
def test_device_name(self):
|
27
|
-
simulator = fake_simulator.FakeSimulator(
|
28
|
+
simulator = fake_simulator.FakeSimulator(
|
29
|
+
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
|
30
|
+
)
|
28
31
|
self.assertEqual(simulator.adb_device_name(), 'fake_simulator')
|
29
32
|
|
30
33
|
def test_launch_close(self):
|
31
34
|
# The simulator should launch and not crash.
|
32
|
-
simulator = fake_simulator.FakeSimulator(
|
35
|
+
simulator = fake_simulator.FakeSimulator(
|
36
|
+
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
|
37
|
+
)
|
33
38
|
simulator.launch()
|
34
39
|
# Closing the simulator should also not crash.
|
35
40
|
simulator.close()
|
36
41
|
|
37
42
|
def test_get_screenshot(self):
|
38
|
-
simulator = fake_simulator.FakeSimulator(
|
43
|
+
simulator = fake_simulator.FakeSimulator(
|
44
|
+
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
|
45
|
+
)
|
39
46
|
simulator.launch()
|
40
47
|
|
41
48
|
screenshot = simulator.get_screenshot()
|
@@ -43,7 +50,9 @@ class FakeSimulatorTest(absltest.TestCase):
|
|
43
50
|
np.testing.assert_equal(screenshot.dtype, np.uint8)
|
44
51
|
|
45
52
|
def test_log_stream(self):
|
46
|
-
simulator = fake_simulator.FakeSimulator(
|
53
|
+
simulator = fake_simulator.FakeSimulator(
|
54
|
+
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
|
55
|
+
)
|
47
56
|
simulator.launch()
|
48
57
|
log_stream = simulator.create_log_stream()
|
49
58
|
# Start yielding lines from LogStream.
|
@@ -61,12 +70,16 @@ class FakeSimulatorTest(absltest.TestCase):
|
|
61
70
|
break
|
62
71
|
|
63
72
|
def test_adb_output(self):
|
64
|
-
simulator = fake_simulator.FakeSimulator(
|
73
|
+
simulator = fake_simulator.FakeSimulator(
|
74
|
+
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
|
75
|
+
)
|
65
76
|
simulator.launch()
|
66
77
|
adb_controller = simulator.create_adb_controller()
|
67
78
|
line = adb_controller.execute_command(['shell', 'dumpsys', 'input'])
|
68
79
|
line = line.decode('utf-8')
|
69
|
-
|
80
|
+
matches = re.match(r'\s+SurfaceOrientation:\s+(\d)', line)
|
81
|
+
self.assertIsNotNone(matches)
|
82
|
+
orientation = matches.group(1)
|
70
83
|
self.assertEqual(orientation, '0')
|
71
84
|
line = adb_controller.execute_command(['shell', 'service', 'check', 'foo'])
|
72
85
|
line = line.decode('utf-8')
|
@@ -77,7 +90,9 @@ class FakeSimulatorTest(absltest.TestCase):
|
|
77
90
|
'topActivity=ComponentInfo{fake_activity}')
|
78
91
|
|
79
92
|
def test_send_touch(self):
|
80
|
-
simulator = fake_simulator.FakeSimulator(
|
93
|
+
simulator = fake_simulator.FakeSimulator(
|
94
|
+
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
|
95
|
+
)
|
81
96
|
simulator.launch()
|
82
97
|
simulator.send_touch([(0, 1, True, 0)])
|
83
98
|
simulator.send_touch([(0, 1, False, 0)])
|
@@ -85,11 +100,13 @@ class FakeSimulatorTest(absltest.TestCase):
|
|
85
100
|
# without crashing anything.
|
86
101
|
|
87
102
|
def test_send_key(self):
|
88
|
-
simulator = fake_simulator.FakeSimulator(
|
103
|
+
simulator = fake_simulator.FakeSimulator(
|
104
|
+
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
|
105
|
+
)
|
89
106
|
simulator.launch()
|
90
|
-
simulator.send_key(123, 'keydown')
|
91
|
-
simulator.send_key(123, 'keyup')
|
92
|
-
simulator.send_key(123, 'keypress')
|
107
|
+
simulator.send_key(np.int32(123), 'keydown')
|
108
|
+
simulator.send_key(np.int32(123), 'keyup')
|
109
|
+
simulator.send_key(np.int32(123), 'keypress')
|
93
110
|
# No assertions, we just want to ensure that `send_key()` can be called
|
94
111
|
# without crashing anything.
|
95
112
|
|
android_env/components/specs.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -15,11 +15,8 @@
|
|
15
15
|
|
16
16
|
"""Base specs for AndroidEnv."""
|
17
17
|
|
18
|
-
from typing import Dict
|
19
|
-
|
20
18
|
from android_env.components import action_type
|
21
19
|
from android_env.proto import task_pb2
|
22
|
-
import dm_env
|
23
20
|
from dm_env import specs
|
24
21
|
import numpy as np
|
25
22
|
|
@@ -41,11 +38,13 @@ _PROTO_DTYPE_TO_NUMPY_DTYPE = {
|
|
41
38
|
task_pb2.ArraySpec.DataType.STRING_U25: np.dtype(('<U25')),
|
42
39
|
task_pb2.ArraySpec.DataType.STRING_U250: np.dtype(('<U250')),
|
43
40
|
task_pb2.ArraySpec.DataType.STRING: np.dtype(('<U0')),
|
41
|
+
task_pb2.ArraySpec.DataType.OBJECT: np.dtype('O'),
|
44
42
|
}
|
45
43
|
|
46
44
|
|
47
|
-
def base_action_spec(
|
48
|
-
|
45
|
+
def base_action_spec(
|
46
|
+
num_fingers: int = 1, enable_key_events: bool = False
|
47
|
+
) -> dict[str, specs.Array]:
|
49
48
|
"""Default action spec for AndroidEnv.
|
50
49
|
|
51
50
|
Args:
|
@@ -58,6 +57,8 @@ def base_action_spec(num_fingers: int = 1,
|
|
58
57
|
touch_position: Position [x, y] of the touch action, where x, y are float
|
59
58
|
values between 0.0 and 1.0 corresponding to the relative position on the
|
60
59
|
screen. IGNORED when (action_type != ActionType.TOUCH).
|
60
|
+
keycode: code representing the desired key press in XKB format. See the
|
61
|
+
emulator_controller_pb2 for details.
|
61
62
|
action_type_i: Action type for additional fingers (i>1).
|
62
63
|
touch_position_i: Touch position for additional fingers (i>1).
|
63
64
|
"""
|
@@ -98,7 +99,7 @@ def base_action_spec(num_fingers: int = 1,
|
|
98
99
|
return action_spec
|
99
100
|
|
100
101
|
|
101
|
-
def base_observation_spec(height: int, width: int) ->
|
102
|
+
def base_observation_spec(height: int, width: int) -> dict[str, specs.Array]:
|
102
103
|
"""Default observation spec for AndroidEnv.
|
103
104
|
|
104
105
|
Args:
|
@@ -119,30 +120,19 @@ def base_observation_spec(height: int, width: int) -> Dict[str, specs.Array]:
|
|
119
120
|
|
120
121
|
return {
|
121
122
|
'pixels':
|
122
|
-
specs.
|
123
|
+
specs.BoundedArray(
|
123
124
|
shape=(height, width, 3),
|
124
125
|
dtype=np.uint8,
|
125
|
-
name='pixels'
|
126
|
+
name='pixels',
|
127
|
+
minimum=0,
|
128
|
+
maximum=255),
|
126
129
|
'timedelta':
|
127
130
|
specs.Array(shape=(), dtype=np.int64, name='timedelta'),
|
128
131
|
'orientation':
|
129
|
-
specs.
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
return {
|
137
|
-
spec.name: _convert_spec(spec)
|
138
|
-
for spec in task.extras_spec
|
132
|
+
specs.BoundedArray(
|
133
|
+
shape=np.array([4]),
|
134
|
+
dtype=np.uint8,
|
135
|
+
name='orientation',
|
136
|
+
minimum=0,
|
137
|
+
maximum=1),
|
139
138
|
}
|
140
|
-
|
141
|
-
|
142
|
-
def _convert_spec(array_spec: task_pb2.ArraySpec) -> specs.Array:
|
143
|
-
"""Converts ArraySpec proto to dm_env specs.Array."""
|
144
|
-
|
145
|
-
return specs.Array(
|
146
|
-
shape=array_spec.shape,
|
147
|
-
dtype=_PROTO_DTYPE_TO_NUMPY_DTYPE[array_spec.dtype],
|
148
|
-
name=array_spec.name)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -79,49 +79,6 @@ class SpecsTest(parameterized.TestCase):
|
|
79
79
|
self.assertEqual(observation_spec['orientation'].shape, (4,))
|
80
80
|
self.assertEqual(observation_spec['orientation'].dtype, np.uint8)
|
81
81
|
|
82
|
-
def test_base_task_extras_spec(self):
|
83
|
-
array_spec_1 = task_pb2.ArraySpec()
|
84
|
-
array_spec_1.name = 'my_extra_1'
|
85
|
-
array_spec_1.shape.extend([10, 10])
|
86
|
-
array_spec_1.dtype = task_pb2.ArraySpec.FLOAT
|
87
|
-
|
88
|
-
array_spec_2 = task_pb2.ArraySpec()
|
89
|
-
array_spec_2.name = 'my_extra_2'
|
90
|
-
array_spec_2.shape.extend([1])
|
91
|
-
array_spec_2.dtype = task_pb2.ArraySpec.UINT8
|
92
|
-
|
93
|
-
fake_task = task_pb2.Task()
|
94
|
-
fake_task.extras_spec.extend([array_spec_1, array_spec_2])
|
95
|
-
task_extras_spec = specs.base_task_extras_spec(fake_task)
|
96
|
-
for spec in task_extras_spec.values():
|
97
|
-
self.assertIsInstance(spec, dm_env_specs.Array)
|
98
|
-
|
99
|
-
self.assertEqual(task_extras_spec['my_extra_1'].shape, (10, 10))
|
100
|
-
self.assertEqual(task_extras_spec['my_extra_2'].shape, (1,))
|
101
|
-
self.assertEqual(task_extras_spec['my_extra_1'].dtype, np.float32)
|
102
|
-
self.assertEqual(task_extras_spec['my_extra_2'].dtype, np.uint8)
|
103
|
-
|
104
|
-
@parameterized.parameters(
|
105
|
-
('name_1', [480, 320, 3], task_pb2.ArraySpec.FLOAT, np.float32),
|
106
|
-
('name_2', [100, 100, 3], task_pb2.ArraySpec.INT32, np.int32),
|
107
|
-
('name_3', [123, 456, 3], task_pb2.ArraySpec.UINT8, np.uint8),
|
108
|
-
('name_4', [480, 320, 1], task_pb2.ArraySpec.BOOL, np.bool_),
|
109
|
-
('', [480, 320], task_pb2.ArraySpec.STRING_U25, np.dtype(('<U25'))),
|
110
|
-
)
|
111
|
-
def test_convert_spec(self, name, shape, dtype, expected_dtype):
|
112
|
-
fake_array_spec = task_pb2.ArraySpec()
|
113
|
-
fake_array_spec.name = name
|
114
|
-
fake_array_spec.dtype = dtype
|
115
|
-
fake_array_spec.shape.extend(shape)
|
116
|
-
fake_task = task_pb2.Task()
|
117
|
-
fake_task.extras_spec.extend([fake_array_spec])
|
118
|
-
task_extras_spec = specs.base_task_extras_spec(fake_task)
|
119
|
-
for spec in task_extras_spec.values():
|
120
|
-
self.assertIsInstance(spec, dm_env_specs.Array)
|
121
|
-
|
122
|
-
self.assertEqual(task_extras_spec[name].shape, tuple(shape))
|
123
|
-
self.assertEqual(task_extras_spec[name].dtype, expected_dtype)
|
124
|
-
|
125
82
|
|
126
83
|
if __name__ == '__main__':
|
127
84
|
absltest.main()
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 DeepMind Technologies Limited.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -16,16 +16,18 @@
|
|
16
16
|
"""TaskManager handles all events and information related to the task."""
|
17
17
|
|
18
18
|
import ast
|
19
|
+
from collections.abc import Callable
|
19
20
|
import copy
|
20
21
|
import datetime
|
21
22
|
import json
|
22
23
|
import re
|
23
24
|
import threading
|
24
|
-
from typing import Any
|
25
|
+
from typing import Any
|
25
26
|
|
26
27
|
from absl import logging
|
27
28
|
from android_env.components import adb_call_parser as adb_call_parser_lib
|
28
29
|
from android_env.components import app_screen_checker
|
30
|
+
from android_env.components import config_classes
|
29
31
|
from android_env.components import dumpsys_thread
|
30
32
|
from android_env.components import log_stream as log_stream_lib
|
31
33
|
from android_env.components import logcat_thread
|
@@ -35,34 +37,24 @@ import dm_env
|
|
35
37
|
import numpy as np
|
36
38
|
|
37
39
|
|
38
|
-
class TaskManager
|
40
|
+
class TaskManager:
|
39
41
|
"""Handles all events and information related to the task."""
|
40
42
|
|
41
43
|
def __init__(
|
42
44
|
self,
|
43
45
|
task: task_pb2.Task,
|
44
|
-
|
45
|
-
dumpsys_check_frequency: int = 150,
|
46
|
-
max_failed_current_activity: int = 10,
|
46
|
+
config: config_classes.TaskManagerConfig | None = None,
|
47
47
|
):
|
48
48
|
"""Controls task-relevant events and information.
|
49
49
|
|
50
50
|
Args:
|
51
51
|
task: A task proto defining the RL task.
|
52
|
-
|
53
|
-
of the simulator is triggered.
|
54
|
-
dumpsys_check_frequency: Frequency, in steps, at which to check
|
55
|
-
current_activity and view hierarchy
|
56
|
-
max_failed_current_activity: The maximum number of tries for extracting
|
57
|
-
the current activity before forcing the episode to restart.
|
52
|
+
config: Configuration for this instance.
|
58
53
|
"""
|
59
|
-
self._task = task
|
60
|
-
self._max_bad_states = max_bad_states
|
61
|
-
self._dumpsys_check_frequency = dumpsys_check_frequency
|
62
|
-
self._max_failed_current_activity = max_failed_current_activity
|
63
54
|
|
55
|
+
self._task = task
|
56
|
+
self._config = config or config_classes.TaskManagerConfig()
|
64
57
|
self._lock = threading.Lock()
|
65
|
-
self._extras_max_buffer_size = 100
|
66
58
|
self._logcat_thread = None
|
67
59
|
self._dumpsys_thread = None
|
68
60
|
self._setup_step_interpreter = None
|
@@ -92,14 +84,7 @@ class TaskManager():
|
|
92
84
|
|
93
85
|
logging.info('Task config: %s', self._task)
|
94
86
|
|
95
|
-
def
|
96
|
-
return self._task
|
97
|
-
|
98
|
-
def update_task(self, task: task_pb2.Task) -> None:
|
99
|
-
self._stats['task_updates'] += 1
|
100
|
-
self._task = task
|
101
|
-
|
102
|
-
def stats(self) -> Dict[str, Any]:
|
87
|
+
def stats(self) -> dict[str, Any]:
|
103
88
|
"""Returns a dictionary of stats.
|
104
89
|
|
105
90
|
This method is expected to be called after setup_task() has been called.
|
@@ -109,21 +94,24 @@ class TaskManager():
|
|
109
94
|
output.update(self._setup_step_interpreter.stats())
|
110
95
|
return output
|
111
96
|
|
112
|
-
def setup_task(
|
97
|
+
def setup_task(self) -> None:
|
98
|
+
"""Performs one-off task setup.."""
|
99
|
+
self._setup_step_interpreter.interpret(self._task.setup_steps)
|
100
|
+
|
101
|
+
def stop(self) -> None:
|
102
|
+
"""Suspends task processing."""
|
103
|
+
self._stop_logcat_thread()
|
104
|
+
|
105
|
+
def start(
|
113
106
|
self,
|
114
107
|
adb_call_parser_factory: Callable[[], adb_call_parser_lib.AdbCallParser],
|
115
108
|
log_stream: log_stream_lib.LogStream) -> None:
|
116
|
-
"""
|
109
|
+
"""Starts task processing."""
|
117
110
|
|
118
111
|
self._start_logcat_thread(log_stream=log_stream)
|
112
|
+
self._logcat_thread.resume()
|
119
113
|
self._start_dumpsys_thread(adb_call_parser_factory())
|
120
114
|
self._start_setup_step_interpreter(adb_call_parser_factory())
|
121
|
-
self._setup_step_interpreter.interpret(self._task.setup_steps)
|
122
|
-
|
123
|
-
def stop_task(self) -> None:
|
124
|
-
"""Suspends task processing."""
|
125
|
-
|
126
|
-
self._stop_logcat_thread()
|
127
115
|
|
128
116
|
def reset_task(self) -> None:
|
129
117
|
"""Resets a task for a new run."""
|
@@ -146,7 +134,7 @@ class TaskManager():
|
|
146
134
|
'episode_end': False,
|
147
135
|
}
|
148
136
|
|
149
|
-
def rl_reset(self, observation:
|
137
|
+
def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:
|
150
138
|
"""Performs one RL step."""
|
151
139
|
|
152
140
|
self._stats['episode_steps'] = 0
|
@@ -163,7 +151,7 @@ class TaskManager():
|
|
163
151
|
discount=0.0,
|
164
152
|
observation=observation)
|
165
153
|
|
166
|
-
def rl_step(self, observation:
|
154
|
+
def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:
|
167
155
|
"""Performs one RL step."""
|
168
156
|
|
169
157
|
self._stats['episode_steps'] += 1
|
@@ -184,7 +172,7 @@ class TaskManager():
|
|
184
172
|
self._latest_values['reward'] = 0.0
|
185
173
|
return reward
|
186
174
|
|
187
|
-
def _get_current_extras(self) ->
|
175
|
+
def _get_current_extras(self) -> dict[str, Any]:
|
188
176
|
"""Returns task extras accumulated since the last step."""
|
189
177
|
extras = {}
|
190
178
|
for name, values in self._latest_values['extra'].items():
|
@@ -213,8 +201,8 @@ class TaskManager():
|
|
213
201
|
if self._task.max_episode_steps > 0:
|
214
202
|
if self._stats['episode_steps'] > self._task.max_episode_steps:
|
215
203
|
self._stats['reset_count_max_duration_reached'] += 1
|
216
|
-
logging.info('Maximum task duration (steps) reached. '
|
217
|
-
'Truncating the episode.')
|
204
|
+
logging.info('Maximum task duration (%r steps) reached. '
|
205
|
+
'Truncating the episode.', self._task.max_episode_steps)
|
218
206
|
return dm_env.truncation
|
219
207
|
|
220
208
|
if self._task.max_episode_sec > 0.0:
|
@@ -222,8 +210,8 @@ class TaskManager():
|
|
222
210
|
max_episode_sec = self._task.max_episode_sec
|
223
211
|
if task_duration > datetime.timedelta(seconds=int(max_episode_sec)):
|
224
212
|
self._stats['reset_count_max_duration_reached'] += 1
|
225
|
-
logging.info('Maximum task duration (sec) reached. '
|
226
|
-
'Truncating the episode.')
|
213
|
+
logging.info('Maximum task duration (%r sec) reached. '
|
214
|
+
'Truncating the episode.', max_episode_sec)
|
227
215
|
return dm_env.truncation
|
228
216
|
|
229
217
|
return dm_env.transition
|
@@ -245,9 +233,11 @@ class TaskManager():
|
|
245
233
|
self._dumpsys_thread = dumpsys_thread.DumpsysThread(
|
246
234
|
app_screen_checker=app_screen_checker.AppScreenChecker(
|
247
235
|
adb_call_parser=adb_call_parser,
|
248
|
-
expected_app_screen=self._task.expected_app_screen
|
249
|
-
|
250
|
-
|
236
|
+
expected_app_screen=self._task.expected_app_screen,
|
237
|
+
),
|
238
|
+
check_frequency=self._config.dumpsys_check_frequency,
|
239
|
+
max_failed_current_activity=self._config.max_failed_current_activity,
|
240
|
+
)
|
251
241
|
|
252
242
|
def _stop_logcat_thread(self):
|
253
243
|
if self._logcat_thread is not None:
|
@@ -263,11 +253,11 @@ class TaskManager():
|
|
263
253
|
to a good state.
|
264
254
|
"""
|
265
255
|
logging.warning('Bad state detected.')
|
266
|
-
if self.
|
256
|
+
if self._config.max_bad_states:
|
267
257
|
self._is_bad_episode = True
|
268
258
|
self._bad_state_counter += 1
|
269
259
|
logging.warning('Bad state counter: %d.', self._bad_state_counter)
|
270
|
-
if self._bad_state_counter >= self.
|
260
|
+
if self._bad_state_counter >= self._config.max_bad_states:
|
271
261
|
logging.error('Too many consecutive bad states. Restarting simulator.')
|
272
262
|
self._stats['restart_count_max_bad_states'] += 1
|
273
263
|
self._should_restart = True
|
@@ -339,9 +329,16 @@ class TaskManager():
|
|
339
329
|
if extra:
|
340
330
|
try:
|
341
331
|
extra = ast.literal_eval(extra)
|
342
|
-
|
343
|
-
|
332
|
+
except (
|
333
|
+
ValueError,
|
334
|
+
TypeError,
|
335
|
+
SyntaxError,
|
336
|
+
MemoryError,
|
337
|
+
RecursionError,
|
338
|
+
):
|
344
339
|
logging.exception('Could not parse extra: %s', extra)
|
340
|
+
# Don't try to process the extra as text; that would probably crash.
|
341
|
+
return
|
345
342
|
else:
|
346
343
|
# No extra value provided for boolean extra. Setting value to True.
|
347
344
|
extra = 1
|
@@ -375,7 +372,10 @@ class TaskManager():
|
|
375
372
|
latest_extras = self._latest_values['extra']
|
376
373
|
if extra_name in latest_extras:
|
377
374
|
# If latest extra is not flushed, append.
|
378
|
-
if
|
375
|
+
if (
|
376
|
+
len(latest_extras[extra_name])
|
377
|
+
>= self._config.extras_max_buffer_size
|
378
|
+
):
|
379
379
|
latest_extras[extra_name].pop(0)
|
380
380
|
latest_extras[extra_name].append(extra)
|
381
381
|
else:
|