droidrun 0.3.5__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,7 +18,8 @@ from droidrun.agent.codeact.events import (
18
18
  TaskThinkingEvent,
19
19
  EpisodicMemoryEvent,
20
20
  )
21
- from droidrun.agent.common.events import ScreenshotEvent
21
+ from droidrun.agent.common.events import ScreenshotEvent, RecordUIStateEvent
22
+ from droidrun.agent.usage import get_usage_from_response
22
23
  from droidrun.agent.utils import chat_utils
23
24
  from droidrun.agent.utils.executer import SimpleCodeExecutor
24
25
  from droidrun.agent.codeact.prompts import (
@@ -182,6 +183,7 @@ class CodeActAgent(Workflow):
182
183
  try:
183
184
  state = self.tools.get_state()
184
185
  await ctx.set("ui_state", state["a11y_tree"])
186
+ ctx.write_event_to_stream(RecordUIStateEvent(ui_state=state["a11y_tree"]))
185
187
  chat_history = await chat_utils.add_ui_text_block(
186
188
  state["a11y_tree"], chat_history
187
189
  )
@@ -202,11 +204,17 @@ class CodeActAgent(Workflow):
202
204
  success=False, reason="LLM response is None. This is a critical error."
203
205
  )
204
206
 
207
+ try:
208
+ usage = get_usage_from_response(self.llm.class_name(), response)
209
+ except Exception as e:
210
+ logger.warning(f"Could not get llm usage from response: {e}")
211
+ usage = None
212
+
205
213
  await self.chat_memory.aput(response.message)
206
214
 
207
215
  code, thoughts = chat_utils.extract_code_and_thought(response.message.content)
208
216
 
209
- event = TaskThinkingEvent(thoughts=thoughts, code=code)
217
+ event = TaskThinkingEvent(thoughts=thoughts, code=code, usage=usage)
210
218
  ctx.write_event_to_stream(event)
211
219
  return event
212
220
 
@@ -255,6 +263,10 @@ class CodeActAgent(Workflow):
255
263
  for screenshot in screenshots[:-1]: # the last screenshot will be captured by next step
256
264
  ctx.write_event_to_stream(ScreenshotEvent(screenshot=screenshot))
257
265
 
266
+ ui_states = result['ui_states']
267
+ for ui_state in ui_states[:-1]:
268
+ ctx.write_event_to_stream(RecordUIStateEvent(ui_state=ui_state['a11y_tree']))
269
+
258
270
  if self.tools.finished == True:
259
271
  logger.debug(" - Task completed.")
260
272
  event = TaskEndEvent(
@@ -311,7 +323,8 @@ class CodeActAgent(Workflow):
311
323
  await ctx.set("chat_memory", self.chat_memory)
312
324
 
313
325
  # Add final state observation to episodic memory
314
- await self._add_final_state_observation(ctx)
326
+ if self.vision:
327
+ await self._add_final_state_observation(ctx)
315
328
 
316
329
  result = {}
317
330
  result.update(
@@ -1,6 +1,8 @@
1
1
  from llama_index.core.llms import ChatMessage
2
2
  from llama_index.core.workflow import Event
3
3
  from typing import Optional
4
+
5
+ from droidrun.agent.usage import UsageResult
4
6
  from ..context.episodic_memory import EpisodicMemory
5
7
 
6
8
  class TaskInputEvent(Event):
@@ -11,6 +13,7 @@ class TaskInputEvent(Event):
11
13
  class TaskThinkingEvent(Event):
12
14
  thoughts: Optional[str] = None
13
15
  code: Optional[str] = None
16
+ usage: Optional[UsageResult] = None
14
17
 
15
18
  class TaskExecutionEvent(Event):
16
19
  code: str
@@ -1,4 +1,5 @@
1
1
  from llama_index.core.workflow import Event
2
+ from typing import Dict, Any
2
3
 
3
4
  class ScreenshotEvent(Event):
4
5
  screenshot: bytes
@@ -44,4 +45,7 @@ class KeyPressActionEvent(MacroEvent):
44
45
  class StartAppEvent(MacroEvent):
45
46
  """"Event for starting an app"""
46
47
  package: str
47
- activity: str = None
48
+ activity: str = None
49
+
50
+ class RecordUIStateEvent(Event):
51
+ ui_state: list[Dict[str, Any]]
@@ -16,23 +16,28 @@ from droidrun.agent.planner import PlannerAgent
16
16
  from droidrun.agent.context.task_manager import TaskManager
17
17
  from droidrun.agent.utils.trajectory import Trajectory
18
18
  from droidrun.tools import Tools, describe_tools
19
- from droidrun.agent.common.events import ScreenshotEvent, MacroEvent
19
+ from droidrun.agent.common.events import ScreenshotEvent, MacroEvent, RecordUIStateEvent
20
20
  from droidrun.agent.common.default import MockWorkflow
21
21
  from droidrun.agent.context import ContextInjectionManager
22
22
  from droidrun.agent.context.agent_persona import AgentPersona
23
23
  from droidrun.agent.context.personas import DEFAULT
24
24
  from droidrun.agent.oneflows.reflector import Reflector
25
- from droidrun.telemetry import capture, flush, DroidAgentInitEvent, DroidAgentFinalizeEvent
26
-
25
+ from droidrun.telemetry import (
26
+ capture,
27
+ flush,
28
+ DroidAgentInitEvent,
29
+ DroidAgentFinalizeEvent,
30
+ )
27
31
 
28
32
  logger = logging.getLogger("droidrun")
29
33
 
34
+
30
35
  class DroidAgent(Workflow):
31
36
  """
32
- A wrapper class that coordinates between PlannerAgent (creates plans) and
33
- CodeActAgent (executes tasks) to achieve a user's goal.
37
+ A wrapper class that coordinates between PlannerAgent (creates plans) and
38
+ CodeActAgent (executes tasks) to achieve a user's goal.
34
39
  """
35
-
40
+
36
41
  @staticmethod
37
42
  def _configure_default_logging(debug: bool = False):
38
43
  """
@@ -43,20 +48,20 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
43
48
  if not logger.handlers:
44
49
  # Create a console handler
45
50
  handler = logging.StreamHandler()
46
-
51
+
47
52
  # Set format
48
53
  if debug:
49
54
  formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s", "%H:%M:%S")
50
55
  else:
51
56
  formatter = logging.Formatter("%(message)s")
52
-
57
+
53
58
  handler.setFormatter(formatter)
54
59
  logger.addHandler(handler)
55
60
  logger.setLevel(logging.DEBUG if debug else logging.INFO)
56
61
  logger.propagate = False
57
-
62
+
58
63
  def __init__(
59
- self,
64
+ self,
60
65
  goal: str,
61
66
  llm: LLM,
62
67
  tools: Tools,
@@ -71,17 +76,17 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
71
76
  save_trajectories: str = "none",
72
77
  excluded_tools: List[str] = None,
73
78
  *args,
74
- **kwargs
79
+ **kwargs,
75
80
  ):
76
81
  """
77
82
  Initialize the DroidAgent wrapper.
78
-
83
+
79
84
  Args:
80
85
  goal: The user's goal or command to execute
81
86
  llm: The language model to use for both agents
82
87
  max_steps: Maximum number of steps for both agents
83
88
  timeout: Timeout for agent execution in seconds
84
- reasoning: Whether to use the PlannerAgent for complex reasoning (True)
89
+ reasoning: Whether to use the PlannerAgent for complex reasoning (True)
85
90
  or send tasks directly to CodeActAgent (False)
86
91
  reflection: Whether to reflect on steps the CodeActAgent did to give the PlannerAgent advice
87
92
  enable_tracing: Whether to enable Arize Phoenix tracing
@@ -93,14 +98,15 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
93
98
  **kwargs: Additional keyword arguments to pass to the agents
94
99
  """
95
100
  self.user_id = kwargs.pop("user_id", None)
96
- super().__init__(timeout=timeout ,*args,**kwargs)
101
+ super().__init__(timeout=timeout, *args, **kwargs)
97
102
  # Configure default logging if not already configured
98
103
  self._configure_default_logging(debug=debug)
99
-
104
+
100
105
  # Setup global tracing first if enabled
101
106
  if enable_tracing:
102
107
  try:
103
108
  from llama_index.core import set_global_handler
109
+
104
110
  set_global_handler("arize_phoenix")
105
111
  logger.info("🔍 Arize Phoenix tracing enabled globally")
106
112
  except ImportError:
@@ -125,27 +131,27 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
125
131
  # Validate string values
126
132
  valid_values = ["none", "step", "action"]
127
133
  if save_trajectories not in valid_values:
128
- logger.warning(f"Invalid save_trajectories value: {save_trajectories}. Using 'none' instead.")
134
+ logger.warning(
135
+ f"Invalid save_trajectories value: {save_trajectories}. Using 'none' instead."
136
+ )
129
137
  self.save_trajectories = "none"
130
138
  else:
131
139
  self.save_trajectories = save_trajectories
132
-
140
+
133
141
  self.trajectory = Trajectory(goal=goal)
134
142
  self.task_manager = TaskManager()
135
143
  self.task_iter = None
136
144
 
137
-
138
145
  self.cim = ContextInjectionManager(personas=personas)
139
146
  self.current_episodic_memory = None
140
147
 
141
148
  logger.info("🤖 Initializing DroidAgent...")
142
149
  logger.info(f"💾 Trajectory saving level: {self.save_trajectories}")
143
-
150
+
144
151
  self.tool_list = describe_tools(tools, excluded_tools)
145
152
  self.tools_instance = tools
146
-
147
- self.tools_instance.save_trajectories = self.save_trajectories
148
153
 
154
+ self.tools_instance.save_trajectories = self.save_trajectories
149
155
 
150
156
  if self.reasoning:
151
157
  logger.info("📝 Initializing Planner Agent...")
@@ -157,14 +163,14 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
157
163
  task_manager=self.task_manager,
158
164
  tools_instance=tools,
159
165
  timeout=timeout,
160
- debug=debug
166
+ debug=debug,
161
167
  )
162
168
  self.add_workflows(planner_agent=self.planner_agent)
163
169
  self.max_codeact_steps = 5
164
170
 
165
171
  if self.reflection:
166
172
  self.reflector = Reflector(llm=llm, debug=debug)
167
-
173
+
168
174
  else:
169
175
  logger.debug("🚫 Planning disabled - will execute tasks directly with CodeActAgent")
170
176
  self.planner_agent = None
@@ -184,10 +190,9 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
184
190
  debug=debug,
185
191
  save_trajectories=save_trajectories,
186
192
  ),
187
- self.user_id
193
+ self.user_id,
188
194
  )
189
195
 
190
-
191
196
  logger.info("✅ DroidAgent initialized successfully.")
192
197
 
193
198
  def run(self, *args, **kwargs) -> WorkflowHandler:
@@ -195,19 +200,15 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
195
200
  Run the DroidAgent workflow.
196
201
  """
197
202
  return super().run(*args, **kwargs)
198
-
203
+
199
204
  @step
200
- async def execute_task(
201
- self,
202
- ctx: Context,
203
- ev: CodeActExecuteEvent
204
- ) -> CodeActResultEvent:
205
+ async def execute_task(self, ctx: Context, ev: CodeActExecuteEvent) -> CodeActResultEvent:
205
206
  """
206
207
  Execute a single task using the CodeActAgent.
207
-
208
+
208
209
  Args:
209
210
  task: Task dictionary with description and status
210
-
211
+
211
212
  Returns:
212
213
  Tuple of (success, reason)
213
214
  """
@@ -232,34 +233,53 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
232
233
  handler = codeact_agent.run(
233
234
  input=task.description,
234
235
  remembered_info=self.tools_instance.memory,
235
- reflection=reflection
236
+ reflection=reflection,
236
237
  )
237
-
238
+
238
239
  async for nested_ev in handler.stream_events():
239
240
  self.handle_stream_event(nested_ev, ctx)
240
241
 
241
242
  result = await handler
242
243
 
243
-
244
244
  if "success" in result and result["success"]:
245
- return CodeActResultEvent(success=True, reason=result["reason"], task=task, steps=result["codeact_steps"])
245
+ return CodeActResultEvent(
246
+ success=True,
247
+ reason=result["reason"],
248
+ task=task,
249
+ steps=result["codeact_steps"],
250
+ )
246
251
  else:
247
- return CodeActResultEvent(success=False, reason=result["reason"], task=task, steps=result["codeact_steps"])
248
-
252
+ return CodeActResultEvent(
253
+ success=False,
254
+ reason=result["reason"],
255
+ task=task,
256
+ steps=result["codeact_steps"],
257
+ )
258
+
249
259
  except Exception as e:
250
260
  logger.error(f"Error during task execution: {e}")
251
261
  if self.debug:
252
262
  import traceback
263
+
253
264
  logger.error(traceback.format_exc())
254
265
  return CodeActResultEvent(success=False, reason=f"Error: {str(e)}", task=task, steps=[])
255
-
266
+
256
267
  @step
257
- async def handle_codeact_execute(self, ctx: Context, ev: CodeActResultEvent) -> FinalizeEvent | ReflectionEvent | ReasoningLogicEvent:
268
+ async def handle_codeact_execute(
269
+ self, ctx: Context, ev: CodeActResultEvent
270
+ ) -> FinalizeEvent | ReflectionEvent | ReasoningLogicEvent:
258
271
  try:
259
272
  task = ev.task
260
273
  if not self.reasoning:
261
- return FinalizeEvent(success=ev.success, reason=ev.reason, output=ev.reason, task=[task], tasks=[task], steps=ev.steps)
262
-
274
+ return FinalizeEvent(
275
+ success=ev.success,
276
+ reason=ev.reason,
277
+ output=ev.reason,
278
+ task=[task],
279
+ tasks=[task],
280
+ steps=ev.steps,
281
+ )
282
+
263
283
  if self.reflection and ev.success:
264
284
  return ReflectionEvent(task=task)
265
285
 
@@ -277,51 +297,64 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
277
297
  logger.error(f"❌ Error during DroidAgent execution: {e}")
278
298
  if self.debug:
279
299
  import traceback
300
+
280
301
  logger.error(traceback.format_exc())
281
302
  tasks = self.task_manager.get_task_history()
282
- return FinalizeEvent(success=False, reason=str(e), output=str(e), task=tasks, tasks=tasks, steps=self.step_counter)
283
-
303
+ return FinalizeEvent(
304
+ success=False,
305
+ reason=str(e),
306
+ output=str(e),
307
+ task=tasks,
308
+ tasks=tasks,
309
+ steps=self.step_counter,
310
+ )
284
311
 
285
312
  @step
286
313
  async def reflect(
287
- self,
288
- ctx: Context,
289
- ev: ReflectionEvent
290
- ) -> ReasoningLogicEvent | CodeActExecuteEvent:
291
-
292
-
314
+ self, ctx: Context, ev: ReflectionEvent
315
+ ) -> ReasoningLogicEvent | CodeActExecuteEvent:
293
316
  task = ev.task
294
317
  if ev.task.agent_type == "AppStarterExpert":
295
318
  self.task_manager.complete_task(task)
296
319
  return ReasoningLogicEvent()
297
-
298
- reflection = await self.reflector.reflect_on_episodic_memory(episodic_memory=self.current_episodic_memory, goal=task.description)
320
+
321
+ reflection = await self.reflector.reflect_on_episodic_memory(
322
+ episodic_memory=self.current_episodic_memory, goal=task.description
323
+ )
299
324
 
300
325
  if reflection.goal_achieved:
301
326
  self.task_manager.complete_task(task)
302
327
  return ReasoningLogicEvent()
303
-
328
+
304
329
  else:
305
330
  self.task_manager.fail_task(task)
306
331
  return ReasoningLogicEvent(reflection=reflection)
307
-
308
332
 
309
333
  @step
310
334
  async def handle_reasoning_logic(
311
335
  self,
312
336
  ctx: Context,
313
337
  ev: ReasoningLogicEvent,
314
- planner_agent: Workflow = MockWorkflow()
315
- ) -> FinalizeEvent | CodeActExecuteEvent:
338
+ planner_agent: Workflow = MockWorkflow(),
339
+ ) -> FinalizeEvent | CodeActExecuteEvent:
316
340
  try:
317
341
  if self.step_counter >= self.max_steps:
318
342
  output = f"Reached maximum number of steps ({self.max_steps})"
319
343
  tasks = self.task_manager.get_task_history()
320
- return FinalizeEvent(success=False, reason=output, output=output, task=tasks, tasks=tasks, steps=self.step_counter)
344
+ return FinalizeEvent(
345
+ success=False,
346
+ reason=output,
347
+ output=output,
348
+ task=tasks,
349
+ tasks=tasks,
350
+ steps=self.step_counter,
351
+ )
321
352
  self.step_counter += 1
322
353
 
323
354
  if ev.reflection:
324
- handler = planner_agent.run(remembered_info=self.tools_instance.memory, reflection=ev.reflection)
355
+ handler = planner_agent.run(
356
+ remembered_info=self.tools_instance.memory, reflection=ev.reflection
357
+ )
325
358
  else:
326
359
  if not ev.force_planning and self.task_iter:
327
360
  try:
@@ -332,7 +365,9 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
332
365
 
333
366
  logger.debug(f"Planning step {self.step_counter}/{self.max_steps}")
334
367
 
335
- handler = planner_agent.run(remembered_info=self.tools_instance.memory, reflection=None)
368
+ handler = planner_agent.run(
369
+ remembered_info=self.tools_instance.memory, reflection=None
370
+ )
336
371
 
337
372
  async for nested_ev in handler.stream_events():
338
373
  self.handle_stream_event(nested_ev, ctx)
@@ -345,51 +380,73 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
345
380
  if self.task_manager.goal_completed:
346
381
  logger.info(f"✅ Goal completed: {self.task_manager.message}")
347
382
  tasks = self.task_manager.get_task_history()
348
- return FinalizeEvent(success=True, reason=self.task_manager.message, output=self.task_manager.message, task=tasks, tasks=tasks, steps=self.step_counter)
383
+ return FinalizeEvent(
384
+ success=True,
385
+ reason=self.task_manager.message,
386
+ output=self.task_manager.message,
387
+ task=tasks,
388
+ tasks=tasks,
389
+ steps=self.step_counter,
390
+ )
349
391
  if not self.tasks:
350
392
  logger.warning("No tasks generated by planner")
351
393
  output = "Planner did not generate any tasks"
352
394
  tasks = self.task_manager.get_task_history()
353
- return FinalizeEvent(success=False, reason=output, output=output, task=tasks, tasks=tasks, steps=self.step_counter)
354
-
395
+ return FinalizeEvent(
396
+ success=False,
397
+ reason=output,
398
+ output=output,
399
+ task=tasks,
400
+ tasks=tasks,
401
+ steps=self.step_counter,
402
+ )
403
+
355
404
  return CodeActExecuteEvent(task=next(self.task_iter), reflection=None)
356
-
405
+
357
406
  except Exception as e:
358
407
  logger.error(f"❌ Error during DroidAgent execution: {e}")
359
408
  if self.debug:
360
409
  import traceback
410
+
361
411
  logger.error(traceback.format_exc())
362
412
  tasks = self.task_manager.get_task_history()
363
- return FinalizeEvent(success=False, reason=str(e), output=str(e), task=tasks, tasks=tasks, steps=self.step_counter)
364
-
413
+ return FinalizeEvent(
414
+ success=False,
415
+ reason=str(e),
416
+ output=str(e),
417
+ task=tasks,
418
+ tasks=tasks,
419
+ steps=self.step_counter,
420
+ )
365
421
 
366
422
  @step
367
- async def start_handler(self, ctx: Context, ev: StartEvent) -> CodeActExecuteEvent | ReasoningLogicEvent:
423
+ async def start_handler(
424
+ self, ctx: Context, ev: StartEvent
425
+ ) -> CodeActExecuteEvent | ReasoningLogicEvent:
368
426
  """
369
427
  Main execution loop that coordinates between planning and execution.
370
-
428
+
371
429
  Returns:
372
430
  Dict containing the execution result
373
431
  """
374
432
  logger.info(f"🚀 Running DroidAgent to achieve goal: {self.goal}")
375
433
  ctx.write_event_to_stream(ev)
376
-
434
+
377
435
  self.step_counter = 0
378
436
  self.retry_counter = 0
379
-
437
+
380
438
  if not self.reasoning:
381
439
  logger.info(f"🔄 Direct execution mode - executing goal: {self.goal}")
382
440
  task = Task(
383
441
  description=self.goal,
384
442
  status=self.task_manager.STATUS_PENDING,
385
- agent_type="Default"
443
+ agent_type="Default",
386
444
  )
387
-
445
+
388
446
  return CodeActExecuteEvent(task=task, reflection=None)
389
-
447
+
390
448
  return ReasoningLogicEvent()
391
-
392
-
449
+
393
450
  @step
394
451
  async def finalize(self, ctx: Context, ev: FinalizeEvent) -> StopEvent:
395
452
  ctx.write_event_to_stream(ev)
@@ -400,7 +457,7 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
400
457
  output=ev.output,
401
458
  steps=ev.steps,
402
459
  ),
403
- self.user_id
460
+ self.user_id,
404
461
  )
405
462
  flush()
406
463
 
@@ -416,24 +473,20 @@ A wrapper class that coordinates between PlannerAgent (creates plans) and
416
473
  self.trajectory.save_trajectory()
417
474
 
418
475
  return StopEvent(result)
419
-
420
- def handle_stream_event(self, ev: Event, ctx: Context):
421
476
 
477
+ def handle_stream_event(self, ev: Event, ctx: Context):
422
478
  if isinstance(ev, EpisodicMemoryEvent):
423
479
  self.current_episodic_memory = ev.episodic_memory
424
480
  return
425
-
426
-
427
481
 
428
482
  if not isinstance(ev, StopEvent):
429
483
  ctx.write_event_to_stream(ev)
430
-
484
+
431
485
  if isinstance(ev, ScreenshotEvent):
432
486
  self.trajectory.screenshots.append(ev.screenshot)
433
487
  elif isinstance(ev, MacroEvent):
434
488
  self.trajectory.macro.append(ev)
489
+ elif isinstance(ev, RecordUIStateEvent):
490
+ self.trajectory.ui_states.append(ev.ui_state)
435
491
  else:
436
492
  self.trajectory.events.append(ev)
437
-
438
-
439
-
@@ -2,6 +2,7 @@ from llama_index.core.workflow import Event
2
2
  from llama_index.core.base.llms.types import ChatMessage
3
3
  from typing import Optional, Any
4
4
  from droidrun.agent.context import Task
5
+ from droidrun.agent.usage import UsageResult
5
6
 
6
7
  class PlanInputEvent(Event):
7
8
  input: list[ChatMessage]
@@ -10,6 +11,7 @@ class PlanInputEvent(Event):
10
11
  class PlanThinkingEvent(Event):
11
12
  thoughts: Optional[str] = None
12
13
  code: Optional[str] = None
14
+ usage: Optional[UsageResult] = None
13
15
 
14
16
 
15
17
  class PlanCreatedEvent(Event):
@@ -13,11 +13,12 @@ from llama_index.core.llms.llm import LLM
13
13
  from llama_index.core.workflow import Workflow, StartEvent, StopEvent, Context, step
14
14
  from llama_index.core.memory import Memory
15
15
  from llama_index.core.llms.llm import LLM
16
+ from droidrun.agent.usage import get_usage_from_response
16
17
  from droidrun.agent.utils.executer import SimpleCodeExecutor
17
18
  from droidrun.agent.utils import chat_utils
18
19
  from droidrun.agent.context.task_manager import TaskManager
19
20
  from droidrun.tools import Tools
20
- from droidrun.agent.common.events import ScreenshotEvent
21
+ from droidrun.agent.common.events import ScreenshotEvent, RecordUIStateEvent
21
22
  from droidrun.agent.planner.events import (
22
23
  PlanInputEvent,
23
24
  PlanCreatedEvent,
@@ -130,16 +131,16 @@ class PlannerAgent(Workflow):
130
131
  self.steps_counter += 1
131
132
  logger.info(f"🧠 Thinking about how to plan the goal...")
132
133
 
133
- # if vision is disabled, screenshot should save to trajectory
134
- screenshot = (self.tools_instance.take_screenshot())[1]
135
- ctx.write_event_to_stream(ScreenshotEvent(screenshot=screenshot))
136
134
  if self.vision:
135
+ screenshot = (self.tools_instance.take_screenshot())[1]
136
+ ctx.write_event_to_stream(ScreenshotEvent(screenshot=screenshot))
137
137
  await ctx.set("screenshot", screenshot)
138
138
 
139
139
  try:
140
140
  state = self.tools_instance.get_state()
141
141
  await ctx.set("ui_state", state["a11y_tree"])
142
142
  await ctx.set("phone_state", state["phone_state"])
143
+ ctx.write_event_to_stream(RecordUIStateEvent(ui_state=state["a11y_tree"]))
143
144
  except Exception as e:
144
145
  logger.warning(f"⚠️ Error retrieving state from the connected device. Is the Accessibility Service enabled?")
145
146
 
@@ -148,11 +149,16 @@ class PlannerAgent(Workflow):
148
149
  await ctx.set("reflection", self.reflection)
149
150
 
150
151
  response = await self._get_llm_response(ctx, chat_history)
152
+ try:
153
+ usage = get_usage_from_response(self.llm.class_name(), response)
154
+ except Exception as e:
155
+ logger.warning(f"Could not get llm usage from response: {e}")
156
+ usage = None
151
157
  await self.chat_memory.aput(response.message)
152
158
 
153
159
  code, thoughts = chat_utils.extract_code_and_thought(response.message.content)
154
160
 
155
- event = PlanThinkingEvent(thoughts=thoughts, code=code)
161
+ event = PlanThinkingEvent(thoughts=thoughts, code=code, usage=usage)
156
162
  ctx.write_event_to_stream(event)
157
163
  return event
158
164
 
@@ -174,6 +180,10 @@ class PlannerAgent(Workflow):
174
180
  screenshots = result['screenshots']
175
181
  for screenshot in screenshots[:-1]: # the last screenshot will be captured by next step
176
182
  ctx.write_event_to_stream(ScreenshotEvent(screenshot=screenshot))
183
+
184
+ ui_states = result['ui_states']
185
+ for ui_state in ui_states[:-1]:
186
+ ctx.write_event_to_stream(RecordUIStateEvent(ui_state=ui_state['a11y_tree']))
177
187
 
178
188
  await self.chat_memory.aput(
179
189
  ChatMessage(