lumivor 0.1.7__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,1017 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import base64
5
+ import io
6
+ import json
7
+ import logging
8
+ import os
9
+ import textwrap
10
+ import time
11
+ import uuid
12
+ from io import BytesIO
13
+ from pathlib import Path
14
+ from typing import Any, Optional, Type, TypeVar
15
+
16
+ from dotenv import load_dotenv
17
+ from langchain_core.language_models.chat_models import BaseChatModel
18
+ from langchain_core.messages import (
19
+ BaseMessage,
20
+ SystemMessage,
21
+ )
22
+ from openai import RateLimitError
23
+ from PIL import Image, ImageDraw, ImageFont
24
+ from pydantic import BaseModel, ValidationError
25
+
26
+ from lumivor.agent.message_manager.service import MessageManager
27
+ from lumivor.agent.prompts import AgentMessagePrompt, SystemPrompt
28
+ from lumivor.agent.views import (
29
+ ActionResult,
30
+ AgentError,
31
+ AgentHistory,
32
+ AgentHistoryList,
33
+ AgentOutput,
34
+ AgentStepInfo,
35
+ )
36
+ from lumivor.browser.browser import Browser
37
+ from lumivor.browser.context import BrowserContext
38
+ from lumivor.browser.views import BrowserState, BrowserStateHistory
39
+ from lumivor.controller.registry.views import ActionModel
40
+ from lumivor.controller.service import Controller
41
+ from lumivor.dom.history_tree_processor.service import (
42
+ DOMHistoryElement,
43
+ HistoryTreeProcessor,
44
+ )
45
+ from lumivor.telemetry.service import ProductTelemetry
46
+ from lumivor.telemetry.views import (
47
+ AgentEndTelemetryEvent,
48
+ AgentRunTelemetryEvent,
49
+ AgentStepErrorTelemetryEvent,
50
+ )
51
+ from lumivor.utils import time_execution_async
52
+
53
# Generic type variable restricted to pydantic models.
T = TypeVar('T', bound=BaseModel)

# python-dotenv: read variables from a local `.env` file at import time.
load_dotenv()

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
57
+
58
+
59
class Agent:
    """LLM-driven browser agent.

    Each :meth:`step` reads the current browser state, asks the LLM for the
    next actions, executes them through the controller and appends an
    ``AgentHistory`` item. :meth:`run` repeats steps until the history reports
    completion, ``max_steps`` is exhausted or ``max_failures`` consecutive
    steps fail. Recorded histories can be replayed with :meth:`rerun_history`
    / :meth:`load_and_rerun` and rendered with :meth:`create_history_gif`.
    """

    # DOM attributes forwarded to MessageManager / AgentMessagePrompt when the
    # caller does not pass ``include_attributes``. Stored as a tuple so the
    # default cannot be mutated and leak between agents.
    DEFAULT_INCLUDE_ATTRIBUTES = (
        'title',
        'type',
        'name',
        'role',
        'tabindex',
        'aria-label',
        'placeholder',
        'value',
        'alt',
        'aria-expanded',
    )

    def __init__(
        self,
        task: str,
        llm: BaseChatModel,
        browser: Browser | None = None,
        browser_context: BrowserContext | None = None,
        controller: Controller | None = None,
        use_vision: bool = True,
        save_conversation_path: Optional[str] = None,
        max_failures: int = 3,
        retry_delay: int = 10,
        system_prompt_class: Type[SystemPrompt] = SystemPrompt,
        max_input_tokens: int = 128000,
        validate_output: bool = False,
        generate_gif: bool = True,
        include_attributes: Optional[list[str]] = None,
        max_error_length: int = 400,
        max_actions_per_step: int = 10,
        tool_call_in_content: bool = True,
    ):
        """Create an agent for ``task``.

        Args:
            task: Natural-language task given to the LLM.
            llm: Chat model used for planning and for output validation.
            browser: Existing browser to use. If neither ``browser`` nor
                ``browser_context`` is given, a new ``Browser()`` is created
                and closed again at the end of :meth:`run`.
            browser_context: Existing browser context to use as-is.
                Injected browsers/contexts are never closed by :meth:`run`.
            controller: Action controller. ``None`` (the default) creates a
                fresh ``Controller()`` for this agent.
            use_vision: Forwarded to ``BrowserContext.get_state``.
            save_conversation_path: If set, each LLM exchange is written to
                ``<save_conversation_path>_<step>.txt``.
            max_failures: Consecutive failed steps after which :meth:`run`
                stops.
            retry_delay: Seconds to wait after an OpenAI ``RateLimitError``.
            system_prompt_class: Prompt class forwarded to ``MessageManager``.
            max_input_tokens: Initial input-token budget for ``MessageManager``.
            validate_output: If true, :meth:`run` asks the LLM to confirm the
                result before accepting completion (skipped on the last step).
            generate_gif: If true, :meth:`run` renders ``agent_history.gif``.
            include_attributes: DOM attributes forwarded to the prompt
                builders; ``None`` means ``DEFAULT_INCLUDE_ATTRIBUTES``.
            max_error_length: Forwarded to the prompt builders.
            max_actions_per_step: Upper bound on actions executed per step;
                extra actions proposed by the LLM are dropped.
            tool_call_in_content: Forwarded to ``MessageManager``.
        """
        self.agent_id = str(uuid.uuid4())  # unique identifier for the agent

        self.task = task
        self.use_vision = use_vision
        self.llm = llm
        self.save_conversation_path = save_conversation_path
        self._last_result = None
        self.include_attributes = (
            list(self.DEFAULT_INCLUDE_ATTRIBUTES)
            if include_attributes is None
            else include_attributes
        )
        self.max_error_length = max_error_length
        self.generate_gif = generate_gif

        # Controller setup: a fresh controller per agent unless one is injected
        # (a default-argument instance would be shared by every Agent).
        self.controller = controller if controller is not None else Controller()
        self.max_actions_per_step = max_actions_per_step

        # Remember which resources were injected so run() only closes the
        # ones this agent created itself.
        self.injected_browser = browser is not None
        self.injected_browser_context = browser_context is not None

        if browser_context is not None:
            # Use the caller's context as-is; self.browser stays None unless
            # a browser was passed as well.
            self.browser = browser
            self.browser_context = browser_context
        else:
            self.browser = browser if browser is not None else Browser()
            self.browser_context = BrowserContext(
                browser=self.browser, config=self.browser.config.new_context_config
            )

        self.system_prompt_class = system_prompt_class

        # Telemetry setup
        self.telemetry = ProductTelemetry()

        # Action and output models setup (needs self.controller)
        self._setup_action_models()

        self.max_input_tokens = max_input_tokens

        self.message_manager = MessageManager(
            llm=self.llm,
            task=self.task,
            action_descriptions=self.controller.registry.get_prompt_description(),
            system_prompt_class=self.system_prompt_class,
            max_input_tokens=self.max_input_tokens,
            include_attributes=self.include_attributes,
            max_error_length=self.max_error_length,
            max_actions_per_step=self.max_actions_per_step,
            tool_call_in_content=tool_call_in_content,
        )

        # Tracking variables
        self.history: AgentHistoryList = AgentHistoryList(history=[])
        self.n_steps = 1
        self.consecutive_failures = 0
        self.max_failures = max_failures
        self.retry_delay = retry_delay
        self.validate_output = validate_output

        if save_conversation_path:
            logger.info(f'Saving conversation to {save_conversation_path}')

    def _setup_action_models(self) -> None:
        """Build the dynamic action model and the matching output model from the controller's registry."""
        self.ActionModel = self.controller.registry.create_action_model()
        self.AgentOutput = AgentOutput.type_with_custom_actions(self.ActionModel)

    @time_execution_async('--step')
    async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
        """Execute one step of the task.

        Reads the browser state, asks the LLM for actions, runs them and
        records a history item. Exceptions are converted into error
        ``ActionResult``s by :meth:`_handle_step_error`; only
        ``BaseException``s such as cancellation propagate.
        """
        logger.info(f'\n📍 Step {self.n_steps}')
        state = None
        model_output = None
        result: list[ActionResult] = []

        try:
            state = await self.browser_context.get_state(use_vision=self.use_vision)
            self.message_manager.add_state_message(state, self._last_result, step_info)
            input_messages = self.message_manager.get_messages()
            try:
                model_output = await self.get_next_action(input_messages)
                self._save_conversation(input_messages, model_output)
                # we don't want the whole state in the chat history
                self.message_manager._remove_last_state_message()
                self.message_manager.add_model_output(model_output)
            except Exception:
                # model call failed, remove last state message from history
                self.message_manager._remove_last_state_message()
                raise

            result = await self.controller.multi_act(model_output.action, self.browser_context)
            self._last_result = result

            if result and result[-1].is_done:
                logger.info(f'📄 Result: {result[-1].extracted_content}')

            self.consecutive_failures = 0

        except Exception as e:
            result = self._handle_step_error(e)
            self._last_result = result
            if isinstance(e, RateLimitError):
                # Back off without blocking the event loop (time.sleep would).
                await asyncio.sleep(self.retry_delay)

        finally:
            # No `return` in this finally: it would silently swallow
            # in-flight BaseExceptions such as asyncio.CancelledError.
            if result:
                for r in result:
                    if r.error:
                        self.telemetry.capture(
                            AgentStepErrorTelemetryEvent(
                                agent_id=self.agent_id,
                                error=r.error,
                            )
                        )
                if state:
                    self._make_history_item(model_output, state, result)

    def _handle_step_error(self, error: Exception) -> list[ActionResult]:
        """Log ``error``, apply recovery hints and count it as a failure.

        Returns a single error ``ActionResult`` to feed back to the LLM.
        Rate-limit back-off is performed by :meth:`step` (asynchronously).
        """
        include_trace = logger.isEnabledFor(logging.DEBUG)
        error_msg = AgentError.format_error(error, include_trace=include_trace)
        prefix = f'❌ Result failed {self.consecutive_failures + 1}/{self.max_failures} times:\n '

        if isinstance(error, (ValidationError, ValueError)):
            logger.error(f'{prefix}{error_msg}')
            if 'Max token limit reached' in error_msg:
                # Shrink the budget further on every overflow; resetting it to
                # max_input_tokens - 500 each time would never cut deeper.
                self.message_manager.max_input_tokens -= 500
                logger.info(
                    f'Cutting tokens from history - new max input tokens: {self.message_manager.max_input_tokens}'
                )
                self.message_manager.cut_messages()
            elif 'Could not parse response' in error_msg:
                # give model a hint how output should look like
                error_msg += '\n\nReturn a valid JSON object with the required fields.'
        elif isinstance(error, RateLimitError):
            logger.warning(f'{prefix}{error_msg}')
        else:
            logger.error(f'{prefix}{error_msg}')

        self.consecutive_failures += 1
        return [ActionResult(error=error_msg, include_in_memory=True)]

    def _make_history_item(
        self,
        model_output: AgentOutput | None,
        state: BrowserState,
        result: list[ActionResult],
    ) -> None:
        """Create a history item from this step's output, state and results and append it."""
        if model_output:
            interacted_elements = AgentHistory.get_interacted_element(
                model_output, state.selector_map
            )
        else:
            interacted_elements = [None]

        state_history = BrowserStateHistory(
            url=state.url,
            title=state.title,
            tabs=state.tabs,
            interacted_element=interacted_elements,
            screenshot=state.screenshot,
        )

        history_item = AgentHistory(model_output=model_output, result=result, state=state_history)
        self.history.history.append(history_item)

    @time_execution_async('--get_next_action')
    async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
        """Ask the LLM for the next actions.

        Raises:
            ValueError: ``'Could not parse response.'`` when the structured
                output could not be parsed.
        """
        structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
        response: dict[str, Any] = await structured_llm.ainvoke(input_messages)  # type: ignore

        parsed: AgentOutput | None = response['parsed']
        if parsed is None:
            raise ValueError('Could not parse response.')

        # cut the number of actions to max_actions_per_step
        parsed.action = parsed.action[: self.max_actions_per_step]
        self._log_response(parsed)
        self.n_steps += 1

        return parsed

    def _log_response(self, response: AgentOutput) -> None:
        """Log the model's evaluation, memory, next goal and proposed actions."""
        if 'Success' in response.current_state.evaluation_previous_goal:
            emoji = '👍'
        elif 'Failed' in response.current_state.evaluation_previous_goal:
            emoji = '⚠'
        else:
            emoji = '🤷'

        logger.info(f'{emoji} Eval: {response.current_state.evaluation_previous_goal}')
        logger.info(f'🧠 Memory: {response.current_state.memory}')
        logger.info(f'🎯 Next goal: {response.current_state.next_goal}')
        for i, action in enumerate(response.action):
            logger.info(
                f'🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}'
            )

    def _save_conversation(self, input_messages: list[BaseMessage], response: Any) -> None:
        """Write the prompt messages and model response to ``<path>_<step>.txt`` if a path is set."""
        if not self.save_conversation_path:
            return

        # os.makedirs('') raises, so only create a directory when the path has one.
        directory = os.path.dirname(self.save_conversation_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Explicit UTF-8: the content may contain non-ASCII text.
        with open(
            self.save_conversation_path + f'_{self.n_steps}.txt', 'w', encoding='utf-8'
        ) as f:
            self._write_messages_to_file(f, input_messages)
            self._write_response_to_file(f, response)

    def _write_messages_to_file(self, f: Any, messages: list[BaseMessage]) -> None:
        """Write each message's class name and its text content (JSON pretty-printed when possible)."""
        for message in messages:
            f.write(f' {message.__class__.__name__} \n')

            if isinstance(message.content, list):
                # Multimodal content: only text parts are written.
                for item in message.content:
                    if isinstance(item, dict) and item.get('type') == 'text':
                        f.write(item['text'].strip() + '\n')
            elif isinstance(message.content, str):
                try:
                    content = json.loads(message.content)
                    f.write(json.dumps(content, indent=2) + '\n')
                except json.JSONDecodeError:
                    f.write(message.content.strip() + '\n')

            f.write('\n')

    def _write_response_to_file(self, f: Any, response: Any) -> None:
        """Write the model response as indented JSON (unset fields excluded)."""
        f.write(' RESPONSE\n')
        f.write(json.dumps(json.loads(response.model_dump_json(exclude_unset=True)), indent=2))

    async def run(self, max_steps: int = 100) -> AgentHistoryList:
        """Execute the task for at most ``max_steps`` steps and return the history.

        Always emits an end-telemetry event, closes browser resources this
        agent created, and (if enabled) renders the history GIF.
        """
        try:
            logger.info(f'🚀 Starting task: {self.task}')

            self.telemetry.capture(
                AgentRunTelemetryEvent(
                    agent_id=self.agent_id,
                    task=self.task,
                )
            )

            for step_index in range(max_steps):
                if self._too_many_failures():
                    break

                await self.step()

                if self.history.is_done():
                    # if last step, we don't need to validate
                    if self.validate_output and step_index < max_steps - 1:
                        if not await self._validate_output():
                            continue

                    logger.info('✅ Task completed successfully')
                    break
            else:
                logger.info('❌ Failed to complete task in maximum steps')

            return self.history

        finally:
            self.telemetry.capture(
                AgentEndTelemetryEvent(
                    agent_id=self.agent_id,
                    task=self.task,
                    success=self.history.is_done(),
                    steps=len(self.history.history),
                )
            )
            # Close the browser even if closing the context fails.
            try:
                if not self.injected_browser_context:
                    await self.browser_context.close()
            finally:
                if not self.injected_browser and self.browser:
                    await self.browser.close()

            if self.generate_gif:
                # Best effort: a rendering problem must not mask the run's
                # own result or exception.
                try:
                    self.create_history_gif()
                except Exception as e:
                    logger.warning(f'Could not create history GIF: {e}')

    def _too_many_failures(self) -> bool:
        """Return True (and log) once ``max_failures`` consecutive failures occurred."""
        if self.consecutive_failures >= self.max_failures:
            logger.error(f'❌ Stopping due to {self.max_failures} consecutive failures')
            return True
        return False

    async def _validate_output(self) -> bool:
        """Ask the LLM whether the last result fulfils the task.

        Returns True when valid, when there is no browser session to inspect,
        or when the validator reply cannot be parsed. On rejection the reason
        is stored in ``self._last_result`` so the next step sees it.
        """
        system_msg = (
            'You are a validator of an agent who interacts with a browser. '
            'Validate if the output of last action is what the user wanted and if the task is completed. '
            'If the task is unclear defined, you can let it pass. But if something is missing or the image does not show what was requested dont let it pass. '
            'Try to understand the page and help the model with suggestions like scroll, do x, ... to get the solution right. '
            f'Task to validate: {self.task}. Return a JSON object with 2 keys: is_valid and reason. '
            'is_valid is a boolean that indicates if the output is correct. '
            'reason is a string that explains why it is valid or not.'
            # The example must itself be valid JSON, so the inner quotes are single quotes.
            ' example: {"is_valid": false, "reason": "The user wanted to search for \'cat photos\', but the agent searched for \'dog photos\' instead."}'
        )

        if not self.browser_context.session:
            # if no browser session, we can't validate the output
            return True

        state = await self.browser_context.get_state(use_vision=self.use_vision)
        content = AgentMessagePrompt(
            state=state,
            result=self._last_result,
            include_attributes=self.include_attributes,
            max_error_length=self.max_error_length,
        )
        messages = [SystemMessage(content=system_msg), content.get_user_message()]

        class ValidationResult(BaseModel):
            is_valid: bool
            reason: str

        validator = self.llm.with_structured_output(ValidationResult, include_raw=True)
        response: dict[str, Any] = await validator.ainvoke(messages)  # type: ignore
        parsed: ValidationResult | None = response['parsed']
        if parsed is None:
            # Same policy as "no session": if we cannot validate, let it pass.
            logger.warning('Validator response could not be parsed - accepting the output')
            return True

        if not parsed.is_valid:
            logger.info(f'❌ Validator decision: {parsed.reason}')
            feedback = f'The ouput is not yet correct. {parsed.reason}.'
            self._last_result = [ActionResult(extracted_content=feedback, include_in_memory=True)]
        else:
            logger.info(f'✅ Validator decision: {parsed.reason}')
        return parsed.is_valid

    async def rerun_history(
        self,
        history: AgentHistoryList,
        max_retries: int = 3,
        skip_failures: bool = True,
        delay_between_actions: float = 2.0,
    ) -> list[ActionResult]:
        """
        Rerun a saved history of actions with error handling and retry logic.

        Args:
            history: The history to replay
            max_retries: Maximum number of retries per action
            skip_failures: Whether to skip failed actions or stop execution
            delay_between_actions: Delay between actions in seconds

        Returns:
            List of action results. Steps without actions, and steps that
            failed after all retries (when ``skip_failures`` is true), are
            represented by an error ``ActionResult``.

        Raises:
            RuntimeError: a step failed ``max_retries`` times and
                ``skip_failures`` is false.
        """
        results = []

        for i, history_item in enumerate(history.history):
            goal = (
                history_item.model_output.current_state.next_goal
                if history_item.model_output
                else ''
            )
            logger.info(f'Replaying step {i + 1}/{len(history.history)}: goal: {goal}')

            if (
                not history_item.model_output
                or not history_item.model_output.action
                or history_item.model_output.action == [None]
            ):
                logger.warning(f'Step {i + 1}: No action to replay, skipping')
                results.append(ActionResult(error='No action to replay'))
                continue

            retry_count = 0
            while retry_count < max_retries:
                try:
                    result = await self._execute_history_step(history_item, delay_between_actions)
                    results.extend(result)
                    break

                except Exception as e:
                    retry_count += 1
                    if retry_count == max_retries:
                        error_msg = f'Step {i + 1} failed after {max_retries} attempts: {str(e)}'
                        logger.error(error_msg)
                        # Record the failure even when skipping, so callers can see it.
                        results.append(ActionResult(error=error_msg))
                        if not skip_failures:
                            raise RuntimeError(error_msg) from e
                    else:
                        logger.warning(
                            f'Step {i + 1} failed (attempt {retry_count}/{max_retries}), retrying...'
                        )
                        await asyncio.sleep(delay_between_actions)

        return results

    async def _execute_history_step(
        self, history_item: AgentHistory, delay: float
    ) -> list[ActionResult]:
        """Re-resolve each recorded action's element on the current page, run the actions, then wait ``delay`` s.

        Raises:
            ValueError: no state/model output, or a recorded element cannot be
                found on the current page.
        """
        state = await self.browser_context.get_state()
        if not state or not history_item.model_output:
            raise ValueError('Invalid state or model output')

        updated_actions = []
        for i, action in enumerate(history_item.model_output.action):
            updated_action = await self._update_action_indices(
                history_item.state.interacted_element[i],
                action,
                state,
            )
            if updated_action is None:
                raise ValueError(f'Could not find matching element {i} in current page')
            updated_actions.append(updated_action)

        result = await self.controller.multi_act(updated_actions, self.browser_context)

        await asyncio.sleep(delay)
        return result

    async def _update_action_indices(
        self,
        historical_element: Optional[DOMHistoryElement],
        action: ActionModel,
        current_state: BrowserState,
    ) -> Optional[ActionModel]:
        """
        Update action indices based on current page state.

        Returns the (possibly mutated) action, unchanged when there is no
        historical element or element tree, or None if the element cannot be
        found / has no highlight index.
        """
        if not historical_element or not current_state.element_tree:
            return action

        current_element = HistoryTreeProcessor.find_history_element_in_tree(
            historical_element, current_state.element_tree
        )

        if not current_element or current_element.highlight_index is None:
            return None

        old_index = action.get_index()
        if old_index != current_element.highlight_index:
            action.set_index(current_element.highlight_index)
            logger.info(
                f'Element moved in DOM, updated index from {old_index} to {current_element.highlight_index}'
            )

        return action

    async def load_and_rerun(
        self, history_file: Optional[str | Path] = None, **kwargs
    ) -> list[ActionResult]:
        """
        Load history from file and rerun it.

        Args:
            history_file: Path to the history file (default ``'AgentHistory.json'``)
            **kwargs: Additional arguments passed to rerun_history
        """
        if not history_file:
            history_file = 'AgentHistory.json'
        history = AgentHistoryList.load_from_file(history_file, self.AgentOutput)
        return await self.rerun_history(history, **kwargs)

    def save_history(self, file_path: Optional[str | Path] = None) -> None:
        """Save the history to a file (default ``'AgentHistory.json'``)."""
        if not file_path:
            file_path = 'AgentHistory.json'
        self.history.save_to_file(file_path)

    @staticmethod
    def _load_gif_fonts(font_size: int, title_font_size: int) -> tuple[Any, Any]:
        """Return ``(regular_font, title_font)`` from the first installed preferred font, else Pillow's default."""
        for font_name in ('Helvetica', 'Arial', 'DejaVuSans', 'Verdana'):
            try:
                return (
                    ImageFont.truetype(font_name, font_size),
                    ImageFont.truetype(font_name, title_font_size),
                )
            except OSError:
                continue
        return ImageFont.load_default(), ImageFont.load_default()

    def create_history_gif(
        self,
        output_path: str = 'agent_history.gif',
        duration: int = 3000,
        show_goals: bool = True,
        show_task: bool = True,
        show_logo: bool = False,
        font_size: int = 40,
        title_font_size: int = 56,
        goal_font_size: int = 44,
        margin: int = 40,
        line_spacing: float = 1.5,
    ) -> None:
        """Create a GIF from the agent's history with overlaid task and goal text.

        Frames: an optional task frame, then one frame per history item that
        has a screenshot. ``duration`` is the per-frame display time in ms.
        ``goal_font_size`` is accepted for API compatibility but currently
        unused (goal text is drawn with the title font).
        """
        # if history is empty or first screenshot is None, we can't create a gif
        if not self.history.history or not self.history.history[0].state.screenshot:
            logger.warning('No history or first screenshot to create GIF from')
            return

        images = []
        regular_font, title_font = self._load_gif_fonts(font_size, title_font_size)

        # Load logo if requested
        logo = None
        if show_logo:
            try:
                logo = Image.open('./static/lumivor.png')
                logo_height = 150
                aspect_ratio = logo.width / logo.height
                logo_width = int(logo_height * aspect_ratio)
                logo = logo.resize((logo_width, logo_height), Image.Resampling.LANCZOS)
            except Exception as e:
                logger.warning(f'Could not load logo: {e}')

        # Create task frame if requested
        if show_task and self.task:
            task_frame = self._create_task_frame(
                self.task,
                self.history.history[0].state.screenshot,
                title_font,
                regular_font,
                logo,
                line_spacing,
            )
            images.append(task_frame)

        # Process each history item (steps are numbered from 1)
        for i, item in enumerate(self.history.history, 1):
            if not item.state.screenshot:
                continue

            # Convert base64 screenshot to PIL Image
            img_data = base64.b64decode(item.state.screenshot)
            image = Image.open(io.BytesIO(img_data))

            if show_goals and item.model_output:
                image = self._add_overlay_to_image(
                    image=image,
                    step_number=i,
                    goal_text=item.model_output.current_state.next_goal,
                    regular_font=regular_font,
                    title_font=title_font,
                    margin=margin,
                    logo=logo,
                )

            images.append(image)

        if images:
            images[0].save(
                output_path,
                save_all=True,
                append_images=images[1:],
                duration=duration,
                loop=0,
                optimize=False,
            )
            logger.info(f'Created GIF at {output_path}')
        else:
            logger.warning('No images found in history to create GIF')

    def _create_task_frame(
        self,
        task: str,
        first_screenshot: str,
        title_font: ImageFont.FreeTypeFont,
        regular_font: ImageFont.FreeTypeFont,
        logo: Optional[Image.Image] = None,
        line_spacing: float = 1.5,
    ) -> Image.Image:
        """Create a black frame (sized like the first screenshot) showing the wrapped, centred task."""
        img_data = base64.b64decode(first_screenshot)
        template = Image.open(io.BytesIO(img_data))
        image = Image.new('RGB', template.size, (0, 0, 0))
        draw = ImageDraw.Draw(image)

        center_y = image.height // 2

        margin = 140
        max_width = image.width - (2 * margin)
        try:
            larger_font = ImageFont.truetype(regular_font.path, regular_font.size + 16)
        except (AttributeError, OSError, TypeError, ValueError):
            # Pillow's load_default() fonts may lack a reloadable path/size.
            larger_font = regular_font
        wrapped_text = self._wrap_text(task, larger_font, max_width)

        # Line height with spacing; bitmap fonts have no .size, so measure.
        glyph_size = getattr(larger_font, 'size', None)
        if glyph_size is None:
            _, top, _, bottom = larger_font.getbbox('Ag')
            glyph_size = bottom - top
        line_height = glyph_size * line_spacing

        lines = wrapped_text.split('\n')
        total_height = line_height * len(lines)

        # Start position for first line (shifted down slightly)
        text_y = center_y - (total_height / 2) + 50

        for line in lines:
            line_bbox = draw.textbbox((0, 0), line, font=larger_font)
            text_x = (image.width - (line_bbox[2] - line_bbox[0])) // 2

            draw.text(
                (text_x, text_y),
                line,
                font=larger_font,
                fill=(255, 255, 255),
            )
            text_y += line_height

        # Add logo if provided (top right corner)
        if logo:
            logo_margin = 20
            logo_x = image.width - logo.width - logo_margin
            image.paste(logo, (logo_x, logo_margin), logo if logo.mode == 'RGBA' else None)

        return image

    def _add_overlay_to_image(
        self,
        image: Image.Image,
        step_number: int,
        goal_text: str,
        regular_font: ImageFont.FreeTypeFont,
        title_font: ImageFont.FreeTypeFont,
        margin: int,
        logo: Optional[Image.Image] = None,
    ) -> Image.Image:
        """Overlay the step number (bottom left), goal text (centred above it) and optional logo; return RGB."""
        image = image.convert('RGBA')
        txt_layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(txt_layer)

        step_text = str(step_number)
        step_bbox = draw.textbbox((0, 0), step_text, font=title_font)
        step_width = step_bbox[2] - step_bbox[0]
        step_height = step_bbox[3] - step_bbox[1]

        x_step = margin + 10
        y_step = image.height - margin - step_height - 10

        padding = 20
        step_bg_bbox = (
            x_step - padding,
            y_step - padding,
            x_step + step_width + padding,
            y_step + step_height + padding,
        )
        draw.rounded_rectangle(step_bg_bbox, radius=15, fill=(0, 0, 0, 255))
        draw.text((x_step, y_step), step_text, font=title_font, fill=(255, 255, 255, 255))

        # Goal text: centred horizontally, placed above the step number.
        max_width = image.width - (4 * margin)
        wrapped_goal = self._wrap_text(goal_text, title_font, max_width)
        goal_bbox = draw.multiline_textbbox((0, 0), wrapped_goal, font=title_font)
        goal_width = goal_bbox[2] - goal_bbox[0]
        goal_height = goal_bbox[3] - goal_bbox[1]

        x_goal = (image.width - goal_width) // 2
        y_goal = y_step - goal_height - padding * 4

        padding_goal = 25
        goal_bg_bbox = (
            x_goal - padding_goal,
            y_goal - padding_goal,
            x_goal + goal_width + padding_goal,
            y_goal + goal_height + padding_goal,
        )
        draw.rounded_rectangle(goal_bg_bbox, radius=15, fill=(0, 0, 0, 255))
        draw.multiline_text(
            (x_goal, y_goal),
            wrapped_goal,
            font=title_font,
            fill=(255, 255, 255, 255),
            align='center',
        )

        # Add logo if provided (top right corner), beneath the text layer.
        if logo:
            logo_layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
            logo_margin = 20
            logo_x = image.width - logo.width - logo_margin
            logo_layer.paste(logo, (logo_x, logo_margin), logo if logo.mode == 'RGBA' else None)
            txt_layer = Image.alpha_composite(logo_layer, txt_layer)

        result = Image.alpha_composite(image, txt_layer)
        return result.convert('RGB')

    def _wrap_text(self, text: str, font: ImageFont.FreeTypeFont, max_width: int) -> str:
        """
        Wrap text to fit within a given width.

        Greedy word wrap on whitespace; a single word wider than
        ``max_width`` is kept on its own line.

        Args:
            text: Text to wrap
            font: Font to use for text
            max_width: Maximum width in pixels

        Returns:
            Wrapped text with newlines
        """
        words = text.split()
        lines = []
        current_line = []

        for word in words:
            current_line.append(word)
            line = ' '.join(current_line)
            bbox = font.getbbox(line)
            if bbox[2] > max_width:
                if len(current_line) == 1:
                    lines.append(current_line.pop())
                else:
                    current_line.pop()
                    lines.append(' '.join(current_line))
                    current_line = [word]

        if current_line:
            lines.append(' '.join(current_line))

        return '\n'.join(lines)

    @staticmethod
    def _text_size(draw: ImageDraw.ImageDraw, text: str, font: Any) -> tuple[int, int]:
        """Return ``(width, height)`` of ``text``; replaces ``ImageDraw.textsize`` (removed in Pillow 10)."""
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top

    def _create_frame(
        self, screenshot: str, text: str, step_number: int, width: int = 1200, height: int = 800
    ) -> Image.Image:
        """Create a ``width`` x ``height`` frame: screenshot, goal box with ``text``, and step badge."""
        frame = Image.new('RGB', (width, height), 'white')

        # Screenshot scaled down to leave room for the header and text.
        screenshot_img = Image.open(BytesIO(base64.b64decode(screenshot)))
        screenshot_img.thumbnail((width - 40, height - 160))
        screenshot_x = (width - screenshot_img.width) // 2
        screenshot_y = 120
        frame.paste(screenshot_img, (screenshot_x, screenshot_y))

        # Large logo (top right), if the asset exists next to this module.
        logo_size = 100
        logo_path = os.path.join(os.path.dirname(__file__), 'assets/lumivor-logo.png')
        has_logo = os.path.exists(logo_path)
        if has_logo:
            with Image.open(logo_path) as logo:
                logo.thumbnail((logo_size, logo_size))
                frame.paste(
                    logo, (width - logo_size - 20, 20), logo if 'A' in logo.getbands() else None
                )

        draw = ImageDraw.Draw(frame)

        try:
            text_font = ImageFont.truetype('Arial.ttf', 24)
            number_font = ImageFont.truetype('Arial.ttf', 48)
        except (OSError, ImportError):
            text_font = ImageFont.load_default()
            number_font = ImageFont.load_default()

        margin = 80

        # Rounded background box sized to the wrapped goal text.
        text_padding = 20
        text_lines = textwrap.wrap(text, width=60)
        text_height = sum(self._text_size(draw, line, text_font)[1] for line in text_lines)
        text_box_height = text_height + (2 * text_padding)

        goal_bg_coords = [
            margin - text_padding,
            40,
            width - margin + text_padding,
            40 + text_box_height,
        ]
        draw.rounded_rectangle(goal_bg_coords, radius=15, fill='#f0f0f0')

        # Small logo inside the goal box (top left).
        small_logo_size = 30
        if has_logo:
            with Image.open(logo_path) as small_logo:
                small_logo.thumbnail((small_logo_size, small_logo_size))
                frame.paste(
                    small_logo,
                    (margin - text_padding + 10, 45),
                    small_logo if 'A' in small_logo.getbands() else None,
                )

        y = 50
        for line in text_lines:
            draw.text((margin + small_logo_size + 20, y), line, font=text_font, fill='black')
            y += self._text_size(draw, line, text_font)[1] + 5

        # Step number centred in a blue rounded box (bottom left).
        number_text = str(step_number)
        number_width, number_height = self._text_size(draw, number_text, number_font)
        number_padding = 20
        number_box_width = number_width + (2 * number_padding)
        number_box_height = number_height + (2 * number_padding)

        number_bg_coords = [
            20,
            height - number_box_height - 20,
            20 + number_box_width,
            height - 20,
        ]
        draw.rounded_rectangle(number_bg_coords, radius=15, fill='#007AFF')

        number_x = number_bg_coords[0] + ((number_box_width - number_width) // 2)
        number_y = number_bg_coords[1] + ((number_box_height - number_height) // 2)
        draw.text((number_x, number_y), number_text, font=number_font, fill='white')

        return frame