computer-use-ootb-internal 0.0.188__py3-none-any.whl → 0.0.190__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,666 +1,669 @@
1
- import argparse
2
- import time
3
- import json
4
- from datetime import datetime
5
- import threading
6
- import requests
7
- import platform # Add platform import
8
- import pyautogui # Add pyautogui import
9
- import webbrowser # Add webbrowser import
10
- import os # Import os for path joining
11
- import logging # Import logging
12
- import importlib # For dynamic imports
13
- import pkgutil # To find modules
14
- import sys # For logging setup
15
- import traceback # For logging setup
16
- from logging.handlers import RotatingFileHandler # For logging setup
17
- from fastapi import FastAPI, Request
18
- from fastapi.responses import JSONResponse
19
- from fastapi.middleware.cors import CORSMiddleware
20
- from computer_use_ootb_internal.computer_use_demo.tools.computer import get_screen_details
21
- from computer_use_ootb_internal.run_teachmode_ootb_args import simple_teachmode_sampling_loop
22
- from computer_use_ootb_internal.computer_use_demo.executor.teachmode_executor import TeachmodeExecutor
23
- import uvicorn # Assuming uvicorn is used to run FastAPI
24
- import concurrent.futures
25
- import asyncio
26
-
27
- # --- App Logging Setup ---
28
- try:
29
- # Log to user's AppData directory for better accessibility
30
- log_dir_base = os.environ.get('APPDATA', os.path.expanduser('~'))
31
- log_dir = os.path.join(log_dir_base, 'OOTBAppLogs')
32
- os.makedirs(log_dir, exist_ok=True)
33
- log_file = os.path.join(log_dir, 'ootb_app.log')
34
-
35
- log_format = '%(asctime)s - %(levelname)s - %(process)d - %(threadName)s - %(message)s'
36
- log_level = logging.INFO # Or logging.DEBUG for more detail
37
-
38
- # Use rotating file handler
39
- handler = RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=2, encoding='utf-8')
40
- handler.setFormatter(logging.Formatter(log_format))
41
-
42
- # Configure root logger
43
- logging.basicConfig(level=log_level, handlers=[handler])
44
-
45
- # Add stream handler to see logs if running interactively (optional)
46
- # logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
47
-
48
- logging.info("="*20 + " OOTB App Starting " + "="*20)
49
- logging.info(f"Running with args: {sys.argv}")
50
- logging.info(f"Python Executable: {sys.executable}")
51
- logging.info(f"Working Directory: {os.getcwd()}")
52
- logging.info(f"User: {os.getenv('USERNAME')}")
53
-
54
- except Exception as log_setup_e:
55
- print(f"FATAL: Failed to set up logging: {log_setup_e}")
56
- # Fallback logging might be needed here if file logging fails
57
-
58
- # --- Get the root logger ---
59
- root_logger = logging.getLogger()
60
- root_logger.setLevel(log_level) # Ensure root logger level is set
61
-
62
- # --- File Handler (as before) ---
63
- file_handler = RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=2, encoding='utf-8')
64
- file_handler.setFormatter(logging.Formatter(log_format))
65
- root_logger.addHandler(file_handler)
66
-
67
- # --- Console Handler (New) ---
68
- console_handler = logging.StreamHandler(sys.stdout) # Log to standard output
69
- console_handler.setFormatter(logging.Formatter(log_format))
70
- root_logger.addHandler(console_handler)
71
-
72
- # --- End App Logging Setup ---
73
-
74
- app = FastAPI()
75
-
76
- # Add CORS middleware to allow requests from the frontend
77
- app.add_middleware(
78
- CORSMiddleware,
79
- allow_origins=["*"],
80
- allow_credentials=True,
81
- allow_methods=["*"],
82
- allow_headers=["*"],
83
- )
84
-
85
- # Rate limiter for API endpoints
86
- class RateLimiter:
87
- def __init__(self, interval_seconds=2):
88
- self.interval = interval_seconds
89
- self.last_request_time = {}
90
- self.lock = threading.Lock()
91
-
92
- def allow_request(self, endpoint):
93
- with self.lock:
94
- current_time = time.time()
95
- # Priority endpoints always allowed
96
- if endpoint in ["/update_params", "/update_message"]:
97
- return True
98
-
99
- # For other endpoints, apply rate limiting
100
- if endpoint not in self.last_request_time:
101
- self.last_request_time[endpoint] = current_time
102
- return True
103
-
104
- elapsed = current_time - self.last_request_time[endpoint]
105
- if elapsed < self.interval:
106
- return False
107
-
108
- self.last_request_time[endpoint] = current_time
109
- return True
110
-
111
-
112
- def log_ootb_request(server_url, ootb_request_type, data):
113
- logging.info(f"OOTB Request: Type={ootb_request_type}, Data={data}")
114
- # Keep the requests post for now if it serves a specific purpose
115
- logging_data = {
116
- "type": ootb_request_type,
117
- "data": data,
118
- "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
119
- }
120
- if not server_url.endswith("/update_ootb_logging"):
121
- server_logging_url = server_url + "/update_ootb_logging"
122
- else:
123
- server_logging_url = server_url
124
- try:
125
- requests.post(server_logging_url, json=logging_data, timeout=5)
126
- except Exception as req_log_e:
127
- logging.warning(f"Could not log ootb request to server {server_logging_url}: {req_log_e}")
128
-
129
-
130
- class SharedState:
131
- def __init__(self, args):
132
- self.args = args
133
- self.task_updated = False
134
- self.chatbot_messages = []
135
- # Store all state-related data here
136
- self.model = args.model
137
- self.task = getattr(args, 'task', "")
138
- self.selected_screen = args.selected_screen
139
- self.user_id = args.user_id
140
- self.trace_id = args.trace_id
141
- self.api_keys = args.api_keys
142
- self.server_url = args.server_url
143
- self.message_queue = []
144
- self.is_processing = False
145
- self.should_stop = False
146
- self.is_paused = False
147
- self.full_screen_game_mode = getattr(args, 'full_screen_game_mode', 0)
148
- # Add a new event to better control stopping
149
- self.stop_event = threading.Event()
150
- # Add a reference to the processing thread
151
- self.processing_thread = None
152
- self.max_steps = getattr(args, 'max_steps', 50)
153
-
154
- shared_state = None
155
- rate_limiter = RateLimiter(interval_seconds=2)
156
-
157
- # Set up logging for this module
158
- log = logging.getLogger(__name__)
159
-
160
- def prepare_environment(state):
161
- """Dynamically loads and runs preparation logic based on software name."""
162
- # Determine software name from state (user_id, trace_id, or task)
163
- software_name = ""
164
-
165
- # Check user_id first
166
- user_id = getattr(state, 'user_id', '').lower()
167
- task = getattr(state, 'task', '').lower()
168
- trace_id = getattr(state, 'trace_id', '').lower()
169
-
170
- log.info(f"Checking for software in: user_id='{user_id}', trace_id='{trace_id}', task='{task}'")
171
-
172
- # Look for known software indicators
173
- if "star rail" in user_id or "star rail" in trace_id:
174
- software_name = "star rail"
175
- elif "powerpoint" in user_id or "powerpoint" in trace_id or "powerpoint" in task:
176
- software_name = "powerpoint"
177
- elif "word" in user_id or "word" in trace_id or "word" in task:
178
- software_name = "word"
179
- elif "excel" in user_id or "excel" in trace_id or "excel" in task:
180
- software_name = "excel"
181
- elif "premiere" in user_id or "premiere" in trace_id or "premiere" in task or \
182
- "pr" in user_id or "pr" in trace_id or "pr" in task: # Check for 'premiere' or 'pr'
183
- software_name = "pr" # Module name will be pr_prepare
184
- # Add more software checks here as needed
185
-
186
- # If no specific software found, check task for keywords
187
- if not software_name:
188
- log.info("No specific software detected from IDs or task content")
189
-
190
- if not software_name:
191
- log.info("No specific software preparation identified. Skipping preparation.")
192
- return
193
-
194
- log.info(f"Identified software for preparation: '{software_name}'")
195
-
196
- # Normalize the software name to be a valid Python module name
197
- # Replace spaces/hyphens with underscores, convert to lowercase
198
- module_name_base = software_name.replace(" ", "_").replace("-", "_").lower()
199
- module_to_run = f"{module_name_base}_prepare"
200
-
201
- log.info(f"Attempting preparation for software: '{software_name}' (Module: '{module_to_run}')")
202
-
203
- try:
204
- # Construct the full module path within the package
205
- prep_package = "computer_use_ootb_internal.preparation"
206
- full_module_path = f"{prep_package}.{module_to_run}"
207
-
208
- # Dynamically import the module
209
- # Check if module exists first using pkgutil to avoid import errors
210
- log.debug(f"Looking for preparation module: {full_module_path}")
211
- loader = pkgutil.find_loader(full_module_path)
212
- if loader is None:
213
- log.warning(f"Preparation module '{full_module_path}' not found. Skipping preparation.")
214
- return
215
-
216
- log.debug(f"Importing preparation module: {full_module_path}")
217
- prep_module = importlib.import_module(full_module_path)
218
-
219
- # Check if the module has the expected function
220
- if hasattr(prep_module, "run_preparation") and callable(prep_module.run_preparation):
221
- log.info(f"Running preparation function from {full_module_path}...")
222
- prep_module.run_preparation(state)
223
- log.info(f"Preparation function from {full_module_path} completed.")
224
- else:
225
- log.warning(f"Module {full_module_path} found, but does not have a callable 'run_preparation' function. Skipping.")
226
-
227
- except ModuleNotFoundError:
228
- log.warning(f"Preparation module '{full_module_path}' not found. Skipping preparation.")
229
- except Exception as e:
230
- log.error(f"Error during dynamic preparation loading/execution for '{module_to_run}': {e}", exc_info=True)
231
-
232
-
233
- @app.post("/update_params")
234
- async def update_parameters(request: Request):
235
- logging.info("Received request to /update_params")
236
- try:
237
- data = await request.json()
238
-
239
- if 'task' not in data:
240
- return JSONResponse(
241
- content={"status": "error", "message": "Missing required field: task"},
242
- status_code=400
243
- )
244
-
245
- # Clear message histories before updating parameters
246
- shared_state.message_queue.clear()
247
- shared_state.chatbot_messages.clear()
248
- logging.info("Cleared message queue and chatbot messages.")
249
-
250
- shared_state.args = argparse.Namespace(**data)
251
- shared_state.task_updated = True
252
-
253
- # Update shared state when parameters change
254
- shared_state.model = getattr(shared_state.args, 'model', "teach-mode-gpt-4o")
255
- shared_state.task = getattr(shared_state.args, 'task', "Following the instructions to complete the task.")
256
- shared_state.selected_screen = getattr(shared_state.args, 'selected_screen', 0)
257
- shared_state.user_id = getattr(shared_state.args, 'user_id', "hero_cases")
258
- shared_state.trace_id = getattr(shared_state.args, 'trace_id', "build_scroll_combat")
259
- shared_state.api_keys = getattr(shared_state.args, 'api_keys', "sk-proj-1234567890")
260
- shared_state.server_url = getattr(shared_state.args, 'server_url', "http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com")
261
- shared_state.max_steps = getattr(shared_state.args, 'max_steps', 50)
262
-
263
- log_ootb_request(shared_state.server_url, "update_params", data)
264
-
265
- # Call the (now dynamic) preparation function here, after parameters are updated
266
- prepare_environment(shared_state)
267
-
268
- logging.info("Parameters updated successfully.")
269
- return JSONResponse(
270
- content={"status": "success", "message": "Parameters updated", "new_args": vars(shared_state.args)},
271
- status_code=200
272
- )
273
- except Exception as e:
274
- logging.error("Error processing /update_params:", exc_info=True)
275
- return JSONResponse(content={"status": "error", "message": "Internal server error"}, status_code=500)
276
-
277
- @app.post("/update_message")
278
- async def update_message(request: Request):
279
- data = await request.json()
280
-
281
- if 'message' not in data:
282
- return JSONResponse(
283
- content={"status": "error", "message": "Missing required field: message"},
284
- status_code=400
285
- )
286
-
287
- log_ootb_request(shared_state.server_url, "update_message", data)
288
-
289
- message = data['message']
290
- full_screen_game_mode = data.get('full_screen_game_mode', 0) # Default to 0 if not provided
291
-
292
- # shared_state.chatbot_messages.append({"role": "user", "content": message, "type": "text"})
293
- shared_state.task = message
294
- shared_state.args.task = message
295
- shared_state.full_screen_game_mode = full_screen_game_mode
296
-
297
- # Reset stop event before starting
298
- shared_state.stop_event.clear()
299
-
300
- # Start processing if not already running
301
- if not shared_state.is_processing:
302
- # Create and store the thread
303
- shared_state.processing_thread = threading.Thread(target=process_input, daemon=True)
304
- shared_state.processing_thread.start()
305
-
306
- return JSONResponse(
307
- content={"status": "success", "message": "Message received", "task": shared_state.task},
308
- status_code=200
309
- )
310
-
311
- @app.get("/get_messages")
312
- async def get_messages(request: Request):
313
- # Apply rate limiting
314
- if not rate_limiter.allow_request(request.url.path):
315
- return JSONResponse(
316
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
317
- status_code=429
318
- )
319
-
320
- # log_ootb_request(shared_state.server_url, "get_messages", {})
321
-
322
- # Return all messages in the queue and clear it
323
- messages = shared_state.message_queue.copy()
324
- shared_state.message_queue = []
325
-
326
- return JSONResponse(
327
- content={"status": "success", "messages": messages},
328
- status_code=200
329
- )
330
-
331
- @app.get("/get_screens")
332
- async def get_screens(request: Request):
333
- # Apply rate limiting
334
- if not rate_limiter.allow_request(request.url.path):
335
- return JSONResponse(
336
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
337
- status_code=429
338
- )
339
-
340
- log_ootb_request(shared_state.server_url, "get_screens", {})
341
-
342
- screen_options, primary_index = get_screen_details()
343
-
344
- return JSONResponse(
345
- content={"status": "success", "screens": screen_options, "primary_index": primary_index},
346
- status_code=200
347
- )
348
-
349
- @app.post("/stop_processing")
350
- async def stop_processing(request: Request):
351
- # Apply rate limiting
352
- if not rate_limiter.allow_request(request.url.path):
353
- return JSONResponse(
354
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
355
- status_code=429
356
- )
357
-
358
- log_ootb_request(shared_state.server_url, "stop_processing", {})
359
-
360
- if shared_state.is_processing:
361
- # Set both flags to ensure stopping the current task
362
- shared_state.should_stop = True
363
- shared_state.stop_event.set()
364
-
365
- # Clear message histories
366
- shared_state.message_queue.clear()
367
- shared_state.chatbot_messages.clear()
368
- logging.info("Cleared message queue and chatbot messages during stop.")
369
-
370
- # Send an immediate message to the queue to inform the user
371
- stop_initiated_msg = {"role": "assistant", "content": f"Stopping task '{shared_state.task}'...", "type": "text", "action_type": ""}
372
- # Append the stop message AFTER clearing, so it's the only one left
373
- shared_state.message_queue.append(stop_initiated_msg)
374
- shared_state.chatbot_messages.append(stop_initiated_msg)
375
-
376
- return JSONResponse(
377
- content={"status": "success", "message": "Task is being stopped, server will remain available for new tasks"},
378
- status_code=200
379
- )
380
- else:
381
- # Clear message histories even if not processing, to ensure clean state
382
- shared_state.message_queue.clear()
383
- shared_state.chatbot_messages.clear()
384
- logging.info("Cleared message queue and chatbot messages (no active process to stop).")
385
- return JSONResponse(
386
- content={"status": "error", "message": "No active processing to stop"},
387
- status_code=400
388
- )
389
-
390
- @app.post("/toggle_pause")
391
- async def toggle_pause(request: Request):
392
- # Apply rate limiting
393
- if not rate_limiter.allow_request(request.url.path):
394
- return JSONResponse(
395
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
396
- status_code=429
397
- )
398
-
399
- log_ootb_request(shared_state.server_url, "toggle_pause", {})
400
-
401
- if not shared_state.is_processing:
402
- return JSONResponse(
403
- content={"status": "error", "message": "No active processing to pause/resume"},
404
- status_code=400
405
- )
406
-
407
- # Toggle the pause state
408
- shared_state.is_paused = not shared_state.is_paused
409
- current_state = shared_state.is_paused
410
-
411
- print(f"Toggled pause state to: {current_state}")
412
-
413
- status_message = "paused" if current_state else "resumed"
414
-
415
- # Add a message to the queue to inform the user
416
- if current_state:
417
- message = {"role": "assistant", "content": f"Task '{shared_state.task}' has been paused. Click Continue to resume.", "type": "text", "action_type": ""}
418
- else:
419
- message = {"role": "assistant", "content": f"Task '{shared_state.task}' has been resumed.", "type": "text", "action_type": ""}
420
-
421
- shared_state.chatbot_messages.append(message)
422
- shared_state.message_queue.append(message)
423
-
424
- return JSONResponse(
425
- content={
426
- "status": "success",
427
- "message": f"Processing {status_message}",
428
- "is_paused": current_state
429
- },
430
- status_code=200
431
- )
432
-
433
- @app.get("/status")
434
- async def get_status(request: Request):
435
- # Apply rate limiting
436
- if not rate_limiter.allow_request(request.url.path):
437
- return JSONResponse(
438
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
439
- status_code=429
440
- )
441
-
442
- # log_ootb_request(shared_state.server_url, "get_status", {})
443
-
444
- print(f"Status check - Processing: {shared_state.is_processing}, Paused: {shared_state.is_paused}")
445
- return JSONResponse(
446
- content={
447
- "status": "success",
448
- "is_processing": shared_state.is_processing,
449
- "is_paused": shared_state.is_paused
450
- },
451
- status_code=200
452
- )
453
-
454
- @app.post("/exec_computer_tool")
455
- async def exec_computer_tool(request: Request):
456
- logging.info("Received request to /exec_computer_tool")
457
- try:
458
- data = await request.json()
459
-
460
- # Extract parameters from the request
461
- selected_screen = data.get('selected_screen', 0)
462
- full_screen_game_mode = data.get('full_screen_game_mode', 0)
463
- response = data.get('response', {})
464
-
465
- logging.info(f"Executing TeachmodeExecutor with: screen={selected_screen}, mode={full_screen_game_mode}, response={response}")
466
-
467
- # Create TeachmodeExecutor in a separate process to avoid event loop conflicts
468
- # Since TeachmodeExecutor uses asyncio.run() internally, we need to run it in a way
469
- # that doesn't conflict with FastAPI's event loop
470
-
471
- def run_executor():
472
- executor = TeachmodeExecutor(
473
- selected_screen=selected_screen,
474
- full_screen_game_mode=full_screen_game_mode
475
- )
476
-
477
- results = []
478
- try:
479
- for action_result in executor(response):
480
- results.append(action_result)
481
- except Exception as exec_error:
482
- logging.error(f"Error executing action: {exec_error}", exc_info=True)
483
- return {"error": str(exec_error)}
484
-
485
- return results
486
-
487
- # Execute in a thread pool to avoid blocking the event loop
488
- with concurrent.futures.ThreadPoolExecutor() as pool:
489
- results = await asyncio.get_event_loop().run_in_executor(pool, run_executor)
490
-
491
- if isinstance(results, dict) and "error" in results:
492
- return JSONResponse(
493
- content={"status": "error", "message": results["error"]},
494
- status_code=500
495
- )
496
-
497
- logging.info(f"Action results: {results}")
498
-
499
- return JSONResponse(
500
- content={"status": "success", "results": results},
501
- status_code=200
502
- )
503
- except Exception as e:
504
- logging.error("Error processing /exec_computer_tool:", exc_info=True)
505
- return JSONResponse(
506
- content={"status": "error", "message": f"Internal server error: {str(e)}"},
507
- status_code=500
508
- )
509
-
510
- def process_input():
511
- global shared_state
512
- logging.info("process_input thread started.")
513
- shared_state.is_processing = True
514
- shared_state.should_stop = False
515
- shared_state.is_paused = False
516
- shared_state.stop_event.clear() # Ensure stop event is cleared at the start
517
-
518
- print(f"start sampling loop: {shared_state.chatbot_messages}")
519
- print(f"shared_state.args before sampling loop: {shared_state.args}")
520
-
521
-
522
- try:
523
- # Get the generator for the sampling loop
524
- sampling_loop = simple_teachmode_sampling_loop(
525
- model=shared_state.model,
526
- task=shared_state.task,
527
- selected_screen=shared_state.selected_screen,
528
- user_id=shared_state.user_id,
529
- trace_id=shared_state.trace_id,
530
- api_keys=shared_state.api_keys,
531
- server_url=shared_state.server_url,
532
- full_screen_game_mode=shared_state.full_screen_game_mode,
533
- max_steps=shared_state.max_steps,
534
- )
535
-
536
- # Process messages from the sampling loop
537
- for loop_msg in sampling_loop:
538
- # Check stop condition more frequently
539
- if shared_state.should_stop or shared_state.stop_event.is_set():
540
- print("Processing stopped by user")
541
- break
542
-
543
- # Check if paused and wait while paused
544
- while shared_state.is_paused and not shared_state.should_stop and not shared_state.stop_event.is_set():
545
- print(f"Processing paused at: {time.strftime('%H:%M:%S')}")
546
- # Wait a short time and check stop condition regularly
547
- for _ in range(5): # Check 5 times per second
548
- if shared_state.should_stop or shared_state.stop_event.is_set():
549
- break
550
- time.sleep(0.2)
551
-
552
- # Check again after pause loop
553
- if shared_state.should_stop or shared_state.stop_event.is_set():
554
- print("Processing stopped while paused or resuming")
555
- break
556
-
557
- shared_state.chatbot_messages.append(loop_msg)
558
- shared_state.message_queue.append(loop_msg)
559
-
560
- # Short sleep to allow stop signals to be processed
561
- for _ in range(5): # Check 5 times per second
562
- if shared_state.should_stop or shared_state.stop_event.is_set():
563
- print("Processing stopped during sleep")
564
- break
565
- time.sleep(0.1)
566
-
567
- if shared_state.should_stop or shared_state.stop_event.is_set():
568
- break
569
-
570
- except Exception as e:
571
- # Handle any exceptions in the processing loop
572
- error_msg = f"Error during task processing: {e}"
573
- print(error_msg)
574
- error_message = {"role": "assistant", "content": error_msg, "type": "error", "action_type": ""}
575
- shared_state.message_queue.append(error_message)
576
-
577
- finally:
578
- # Handle completion or interruption
579
- if shared_state.should_stop or shared_state.stop_event.is_set():
580
- stop_msg = f"Task '{shared_state.task}' was stopped. Ready for new tasks."
581
- final_message = {"role": "assistant", "content": stop_msg, "type": "text", "action_type": ""}
582
- else:
583
- complete_msg = f"Task '{shared_state.task}' completed. Thanks for using Marbot Run."
584
- final_message = {"role": "assistant", "content": complete_msg, "type": "text", "action_type": ""}
585
-
586
- shared_state.chatbot_messages.append(final_message)
587
- shared_state.message_queue.append(final_message)
588
-
589
- # Reset all state flags to allow for new tasks
590
- shared_state.is_processing = False
591
- shared_state.should_stop = False
592
- shared_state.is_paused = False
593
- shared_state.stop_event.clear()
594
- print("Processing completed, ready for new tasks")
595
- logging.info("process_input thread finished.")
596
-
597
- def main():
598
- # Logging is set up at the top level now
599
- logging.info("App main() function starting setup.")
600
- global app, shared_state, rate_limiter # Ensure app is global if needed by uvicorn
601
- parser = argparse.ArgumentParser()
602
- # Add arguments, but NOT host and port
603
- parser.add_argument("--model", type=str, default="teach-mode-gpt-4o", help="Model name")
604
- parser.add_argument("--task", type=str, default="Following the instructions to complete the task.", help="Initial task description")
605
- parser.add_argument("--selected_screen", type=int, default=0, help="Selected screen index")
606
- parser.add_argument("--user_id", type=str, default="hero_cases", help="User ID for the session")
607
- parser.add_argument("--trace_id", type=str, default="build_scroll_combat", help="Trace ID for the session")
608
- parser.add_argument("--api_keys", type=str, default="sk-proj-1234567890", help="API keys")
609
- parser.add_argument("--server_url", type=str, default="http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com", help="Server URL for the session")
610
-
611
- args = parser.parse_args()
612
-
613
- # Validate args or set defaults if needed (keep these)
614
- if not hasattr(args, 'model'): args.model = "default_model"
615
- if not hasattr(args, 'task'): args.task = "default_task"
616
- if not hasattr(args, 'selected_screen'): args.selected_screen = 0
617
- if not hasattr(args, 'user_id'): args.user_id = "unknown_user"
618
- if not hasattr(args, 'trace_id'): args.trace_id = "unknown_trace"
619
- if not hasattr(args, 'api_keys'): args.api_keys = "none"
620
- if not hasattr(args, 'server_url'): args.server_url = "none"
621
-
622
- shared_state = SharedState(args)
623
- rate_limiter = RateLimiter(interval_seconds=2) # Re-initialize rate limiter
624
- logging.info(f"Shared state initialized for user: {args.user_id}")
625
-
626
- # --- Restore original port calculation logic ---
627
- port = 7888 # Default port
628
- host = "0.0.0.0" # Listen on all interfaces
629
-
630
- if platform.system() == "Windows":
631
- try:
632
- username = os.environ["USERNAME"].lower()
633
- logging.info(f"Determining port based on Windows username: {username}")
634
- if username == "altair":
635
- port = 14000
636
- elif username.startswith("guest") and username[5:].isdigit():
637
- num = int(username[5:])
638
- if 1 <= num <= 10: # Assuming max 10 guests for this range
639
- port = 14000 + num
640
- else:
641
- logging.warning(f"Guest user number {num} out of range (1-10), using default port {port}.")
642
- else:
643
- logging.info(f"Username '{username}' doesn't match specific rules, using default port {port}.")
644
- except Exception as e:
645
- logging.error(f"Error determining port from username: {e}. Using default port {port}.", exc_info=True)
646
- else:
647
- logging.info(f"Not running on Windows, using default port {port}.")
648
- # --- End of restored port calculation ---
649
-
650
- logging.info(f"Final Host={host}, Port={port}")
651
-
652
- try:
653
- logging.info(f"Starting Uvicorn server on {host}:{port}")
654
- # Use the calculated port and specific host
655
- uvicorn.run(app, host=host, port=port)
656
- logging.info("Uvicorn server stopped.")
657
- except Exception as main_e:
658
- logging.error("Error in main execution:", exc_info=True)
659
- finally:
660
- logging.info("App main() function finished.")
661
-
662
- if __name__ == "__main__":
663
- main()
664
-
665
- # Test log_ootb_request
1
+ import argparse
2
+ import time
3
+ import json
4
+ from datetime import datetime
5
+ import threading
6
+ import requests
7
+ import platform # Add platform import
8
+ import pyautogui # Add pyautogui import
9
+ import webbrowser # Add webbrowser import
10
+ import os # Import os for path joining
11
+ import logging # Import logging
12
+ import importlib # For dynamic imports
13
+ import pkgutil # To find modules
14
+ import sys # For logging setup
15
+ import traceback # For logging setup
16
+ from logging.handlers import RotatingFileHandler # For logging setup
17
+ from fastapi import FastAPI, Request
18
+ from fastapi.responses import JSONResponse
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from computer_use_ootb_internal.computer_use_demo.tools.computer import get_screen_details
21
+ from computer_use_ootb_internal.run_teachmode_ootb_args import simple_teachmode_sampling_loop
22
+ from computer_use_ootb_internal.computer_use_demo.executor.teachmode_executor import TeachmodeExecutor
23
+ import uvicorn # Assuming uvicorn is used to run FastAPI
24
+ import concurrent.futures
25
+ import asyncio
26
+
27
+ # --- App Logging Setup ---
28
+ try:
29
+ # Log to user's AppData directory for better accessibility
30
+ log_dir_base = os.environ.get('APPDATA', os.path.expanduser('~'))
31
+ log_dir = os.path.join(log_dir_base, 'OOTBAppLogs')
32
+ os.makedirs(log_dir, exist_ok=True)
33
+ log_file = os.path.join(log_dir, 'ootb_app.log')
34
+
35
+ log_format = '%(asctime)s - %(levelname)s - %(process)d - %(threadName)s - %(message)s'
36
+ log_level = logging.INFO # Or logging.DEBUG for more detail
37
+
38
+ # Use rotating file handler
39
+ handler = RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=2, encoding='utf-8')
40
+ handler.setFormatter(logging.Formatter(log_format))
41
+
42
+ # Configure root logger
43
+ logging.basicConfig(level=log_level, handlers=[handler])
44
+
45
+ # Add stream handler to see logs if running interactively (optional)
46
+ # logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
47
+
48
+ logging.info("="*20 + " OOTB App Starting " + "="*20)
49
+ logging.info(f"Running with args: {sys.argv}")
50
+ logging.info(f"Python Executable: {sys.executable}")
51
+ logging.info(f"Working Directory: {os.getcwd()}")
52
+ logging.info(f"User: {os.getenv('USERNAME')}")
53
+
54
+ except Exception as log_setup_e:
55
+ print(f"FATAL: Failed to set up logging: {log_setup_e}")
56
+ # Fallback logging might be needed here if file logging fails
57
+
58
+ # --- Get the root logger ---
59
+ root_logger = logging.getLogger()
60
+ root_logger.setLevel(log_level) # Ensure root logger level is set
61
+
62
+ # --- File Handler (as before) ---
63
+ file_handler = RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=2, encoding='utf-8')
64
+ file_handler.setFormatter(logging.Formatter(log_format))
65
+ root_logger.addHandler(file_handler)
66
+
67
+ # --- Console Handler (New) ---
68
+ console_handler = logging.StreamHandler(sys.stdout) # Log to standard output
69
+ console_handler.setFormatter(logging.Formatter(log_format))
70
+ root_logger.addHandler(console_handler)
71
+
72
+ # --- End App Logging Setup ---
73
+
74
+ app = FastAPI()
75
+
76
+ # Add CORS middleware to allow requests from the frontend
77
+ app.add_middleware(
78
+ CORSMiddleware,
79
+ allow_origins=["*"],
80
+ allow_credentials=True,
81
+ allow_methods=["*"],
82
+ allow_headers=["*"],
83
+ )
84
+
85
+ # Rate limiter for API endpoints
86
class RateLimiter:
    """Thread-safe per-endpoint rate limiter enforcing a minimum interval.

    Requests to the two priority endpoints are always allowed; every other
    endpoint is permitted at most once per ``interval_seconds``.
    """

    def __init__(self, interval_seconds=2):
        self.interval = interval_seconds
        self.last_request_time = {}
        self.lock = threading.Lock()

    def allow_request(self, endpoint):
        """Return True if a request to *endpoint* may proceed right now."""
        with self.lock:
            # Priority endpoints bypass rate limiting entirely.
            if endpoint in ("/update_params", "/update_message"):
                return True

            now = time.time()
            last_seen = self.last_request_time.get(endpoint)
            # Deny only when we have seen this endpoint recently.
            if last_seen is not None and now - last_seen < self.interval:
                return False

            self.last_request_time[endpoint] = now
            return True
110
+
111
+
112
def log_ootb_request(server_url, ootb_request_type, data):
    """Log an OOTB request locally and mirror it to the remote logging endpoint.

    Best-effort: a failure to reach the server is logged as a warning and
    never raised to the caller.
    """
    logging.info(f"OOTB Request: Type={ootb_request_type}, Data={data}")
    # Keep the requests post for now if it serves a specific purpose
    payload = {
        "type": ootb_request_type,
        "data": data,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    suffix = "/update_ootb_logging"
    server_logging_url = server_url if server_url.endswith(suffix) else server_url + suffix
    try:
        requests.post(server_logging_url, json=payload, timeout=5)
    except Exception as req_log_e:
        logging.warning(f"Could not log ootb request to server {server_logging_url}: {req_log_e}")
128
+
129
+
130
class SharedState:
    """Mutable state shared between the FastAPI endpoints and the worker thread."""

    def __init__(self, args):
        self.args = args

        # Session / task configuration, mirrored from ``args``.
        self.model = args.model
        self.task = getattr(args, 'task', "")
        self.selected_screen = args.selected_screen
        self.user_id = args.user_id
        self.trace_id = args.trace_id
        self.api_keys = args.api_keys
        self.server_url = args.server_url
        self.full_screen_game_mode = getattr(args, 'full_screen_game_mode', 0)
        self.max_steps = getattr(args, 'max_steps', 50)

        # Conversation buffers drained by /get_messages.
        self.task_updated = False
        self.chatbot_messages = []
        self.message_queue = []

        # Processing-control flags read/written across threads.
        self.is_processing = False
        self.should_stop = False
        self.is_paused = False
        # Event giving the worker thread a prompt stop signal.
        self.stop_event = threading.Event()
        # Reference to the background processing thread, if any.
        self.processing_thread = None
153
+
154
# Populated in main() once CLI args are parsed; None until then.
shared_state = None
rate_limiter = RateLimiter(interval_seconds=2)

# Set up logging for this module
log = logging.getLogger(__name__)
159
+
160
def prepare_environment(state):
    """Dynamically loads and runs preparation logic based on software name.

    Inspects ``state.user_id`` / ``state.trace_id`` / ``state.task`` for known
    software keywords, then imports
    ``computer_use_ootb_internal.preparation.<software>_prepare`` and calls its
    ``run_preparation(state)`` hook. Missing modules or hooks are logged and
    skipped; this function never raises.
    """
    import importlib.util  # local import: needed only for the existence check

    # Determine software name from state (user_id, trace_id, or task)
    software_name = ""

    user_id = getattr(state, 'user_id', '').lower()
    task = getattr(state, 'task', '').lower()
    trace_id = getattr(state, 'trace_id', '').lower()

    log.info(f"Checking for software in: user_id='{user_id}', trace_id='{trace_id}', task='{task}'")

    # Look for known software indicators
    if "star rail" in user_id or "star rail" in trace_id:
        software_name = "star rail"
    elif "powerpoint" in user_id or "powerpoint" in trace_id or "powerpoint" in task:
        software_name = "powerpoint"
    elif "word" in user_id or "word" in trace_id or "word" in task:
        software_name = "word"
    elif "excel" in user_id or "excel" in trace_id or "excel" in task:
        software_name = "excel"
    elif "premiere" in user_id or "premiere" in trace_id or "premiere" in task or \
         "pr" in user_id or "pr" in trace_id or "pr" in task:  # Check for 'premiere' or 'pr'
        # NOTE(review): the bare substring "pr" also matches words such as
        # "print" or "price" — consider word-boundary matching here.
        software_name = "pr"  # Module name will be pr_prepare
    # Add more software checks here as needed

    if not software_name:
        log.info("No specific software detected from IDs or task content")
        log.info("No specific software preparation identified. Skipping preparation.")
        return

    log.info(f"Identified software for preparation: '{software_name}'")

    # Normalize the software name to a valid Python module name:
    # spaces/hyphens -> underscores, lowercase.
    module_name_base = software_name.replace(" ", "_").replace("-", "_").lower()
    module_to_run = f"{module_name_base}_prepare"

    log.info(f"Attempting preparation for software: '{software_name}' (Module: '{module_to_run}')")

    # Construct the full module path within the package.
    prep_package = "computer_use_ootb_internal.preparation"
    full_module_path = f"{prep_package}.{module_to_run}"
    try:
        # Check if the module exists first to avoid noisy import errors.
        # importlib.util.find_spec replaces pkgutil.find_loader, which is
        # deprecated since Python 3.12 and removed in 3.14.
        log.debug(f"Looking for preparation module: {full_module_path}")
        if importlib.util.find_spec(full_module_path) is None:
            log.warning(f"Preparation module '{full_module_path}' not found. Skipping preparation.")
            return

        log.debug(f"Importing preparation module: {full_module_path}")
        prep_module = importlib.import_module(full_module_path)

        # Run the module's hook if it exposes one.
        if hasattr(prep_module, "run_preparation") and callable(prep_module.run_preparation):
            log.info(f"Running preparation function from {full_module_path}...")
            prep_module.run_preparation(state)
            log.info(f"Preparation function from {full_module_path} completed.")
        else:
            log.warning(f"Module {full_module_path} found, but does not have a callable 'run_preparation' function. Skipping.")

    except ModuleNotFoundError:
        log.warning(f"Preparation module '{full_module_path}' not found. Skipping preparation.")
    except Exception as e:
        log.error(f"Error during dynamic preparation loading/execution for '{module_to_run}': {e}", exc_info=True)
231
+
232
+
233
@app.post("/update_params")
async def update_parameters(request: Request):
    """Replace the session parameters from the request JSON.

    Requires a 'task' field (400 otherwise). Clears both message buffers,
    rebuilds ``shared_state.args`` from the payload, mirrors each field into
    the flat shared-state attributes, then re-runs environment preparation.
    Returns 500 on any unexpected error.
    """
    logging.info("Received request to /update_params")
    try:
        data = await request.json()

        if 'task' not in data:
            return JSONResponse(
                content={"status": "error", "message": "Missing required field: task"},
                status_code=400
            )

        # Clear message histories before updating parameters
        shared_state.message_queue.clear()
        shared_state.chatbot_messages.clear()
        logging.info("Cleared message queue and chatbot messages.")

        # Replace args wholesale; any field absent from the payload falls back
        # to the defaults used in the getattr calls below.
        shared_state.args = argparse.Namespace(**data)
        shared_state.task_updated = True

        # Update shared state when parameters change
        shared_state.model = getattr(shared_state.args, 'model', "teach-mode-gpt-4o")
        shared_state.task = getattr(shared_state.args, 'task', "Following the instructions to complete the task.")
        shared_state.selected_screen = getattr(shared_state.args, 'selected_screen', 0)
        shared_state.user_id = getattr(shared_state.args, 'user_id', "hero_cases")
        shared_state.trace_id = getattr(shared_state.args, 'trace_id', "build_scroll_combat")
        shared_state.api_keys = getattr(shared_state.args, 'api_keys', "sk-proj-1234567890")
        shared_state.server_url = getattr(shared_state.args, 'server_url', "http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com")
        shared_state.max_steps = getattr(shared_state.args, 'max_steps', 50)
        shared_state.full_screen_game_mode = getattr(shared_state.args, 'full_screen_game_mode', 0)

        log_ootb_request(shared_state.server_url, "update_params", data)

        # Call the (now dynamic) preparation function here, after parameters are updated
        prepare_environment(shared_state)

        logging.info("Parameters updated successfully.")
        return JSONResponse(
            content={"status": "success", "message": "Parameters updated", "new_args": vars(shared_state.args)},
            status_code=200
        )
    except Exception as e:
        logging.error("Error processing /update_params:", exc_info=True)
        return JSONResponse(content={"status": "error", "message": "Internal server error"}, status_code=500)
277
+
278
@app.post("/update_message")
async def update_message(request: Request):
    """Set a new task message and start the background worker if idle.

    Requires a 'message' field (400 otherwise). Updates the shared task,
    clears the stop event, and spawns a daemon thread running process_input
    when no task is currently processing.
    """
    data = await request.json()

    if 'message' not in data:
        return JSONResponse(
            content={"status": "error", "message": "Missing required field: message"},
            status_code=400
        )

    log_ootb_request(shared_state.server_url, "update_message", data)

    message = data['message']

    # shared_state.chatbot_messages.append({"role": "user", "content": message, "type": "text"})
    # The message becomes the new task, kept in sync in both places.
    shared_state.task = message
    shared_state.args.task = message

    # TODO: adaptively change full_screen_game_mode
    # full_screen_game_mode = data.get('full_screen_game_mode', 0) # Default to 0 if not provided
    # shared_state.full_screen_game_mode = full_screen_game_mode

    # Reset stop event before starting
    shared_state.stop_event.clear()

    # Start processing if not already running
    if not shared_state.is_processing:
        # Create and store the thread (daemon so it won't block interpreter exit)
        shared_state.processing_thread = threading.Thread(target=process_input, daemon=True)
        shared_state.processing_thread.start()

    return JSONResponse(
        content={"status": "success", "message": "Message received", "task": shared_state.task},
        status_code=200
    )
313
+
314
@app.get("/get_messages")
async def get_messages(request: Request):
    """Drain and return every queued message (rate limited)."""
    if not rate_limiter.allow_request(request.url.path):
        return JSONResponse(
            content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
            status_code=429
        )

    # log_ootb_request(shared_state.server_url, "get_messages", {})

    # Hand the pending queue to the caller and start a fresh one.
    pending = shared_state.message_queue.copy()
    shared_state.message_queue = []

    return JSONResponse(
        content={"status": "success", "messages": pending},
        status_code=200
    )
333
+
334
@app.get("/get_screens")
async def get_screens(request: Request):
    """Return the available screen options and the primary screen index."""
    # Apply rate limiting
    if not rate_limiter.allow_request(request.url.path):
        return JSONResponse(
            content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
            status_code=429
        )

    log_ootb_request(shared_state.server_url, "get_screens", {})

    # Delegate actual enumeration to the computer tool helper.
    screen_options, primary_index = get_screen_details()

    return JSONResponse(
        content={"status": "success", "screens": screen_options, "primary_index": primary_index},
        status_code=200
    )
351
+
352
@app.post("/stop_processing")
async def stop_processing(request: Request):
    """Stop the current task, clear message buffers, and keep the server alive.

    Returns 200 when a running task is signalled to stop, 400 when nothing
    is processing (buffers are cleared in both cases).
    """
    # Apply rate limiting
    if not rate_limiter.allow_request(request.url.path):
        return JSONResponse(
            content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
            status_code=429
        )

    log_ootb_request(shared_state.server_url, "stop_processing", {})

    if shared_state.is_processing:
        # Set both flags to ensure stopping the current task
        shared_state.should_stop = True
        shared_state.stop_event.set()

        # Clear message histories
        shared_state.message_queue.clear()
        shared_state.chatbot_messages.clear()
        logging.info("Cleared message queue and chatbot messages during stop.")

        # Send an immediate message to the queue to inform the user
        stop_initiated_msg = {"role": "assistant", "content": f"Stopping task '{shared_state.task}'...", "type": "text", "action_type": ""}
        # Append the stop message AFTER clearing, so it's the only one left
        shared_state.message_queue.append(stop_initiated_msg)
        shared_state.chatbot_messages.append(stop_initiated_msg)

        return JSONResponse(
            content={"status": "success", "message": "Task is being stopped, server will remain available for new tasks"},
            status_code=200
        )
    else:
        # Clear message histories even if not processing, to ensure clean state
        shared_state.message_queue.clear()
        shared_state.chatbot_messages.clear()
        logging.info("Cleared message queue and chatbot messages (no active process to stop).")
        return JSONResponse(
            content={"status": "error", "message": "No active processing to stop"},
            status_code=400
        )
392
+
393
@app.post("/toggle_pause")
async def toggle_pause(request: Request):
    """Toggle the pause/resume state of the running task and notify the client.

    Returns 400 when nothing is processing; otherwise flips is_paused and
    appends a user-facing status message to both message buffers.
    """
    # Apply rate limiting
    if not rate_limiter.allow_request(request.url.path):
        return JSONResponse(
            content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
            status_code=429
        )

    log_ootb_request(shared_state.server_url, "toggle_pause", {})

    if not shared_state.is_processing:
        return JSONResponse(
            content={"status": "error", "message": "No active processing to pause/resume"},
            status_code=400
        )

    # Toggle the pause state (process_input polls this flag)
    shared_state.is_paused = not shared_state.is_paused
    current_state = shared_state.is_paused

    print(f"Toggled pause state to: {current_state}")

    status_message = "paused" if current_state else "resumed"

    # Add a message to the queue to inform the user
    if current_state:
        message = {"role": "assistant", "content": f"Task '{shared_state.task}' has been paused. Click Continue to resume.", "type": "text", "action_type": ""}
    else:
        message = {"role": "assistant", "content": f"Task '{shared_state.task}' has been resumed.", "type": "text", "action_type": ""}

    shared_state.chatbot_messages.append(message)
    shared_state.message_queue.append(message)

    return JSONResponse(
        content={
            "status": "success",
            "message": f"Processing {status_message}",
            "is_paused": current_state
        },
        status_code=200
    )
435
+
436
@app.get("/status")
async def get_status(request: Request):
    """Report whether a task is currently processing and/or paused."""
    # Apply rate limiting
    if not rate_limiter.allow_request(request.url.path):
        return JSONResponse(
            content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
            status_code=429
        )

    # log_ootb_request(shared_state.server_url, "get_status", {})

    print(f"Status check - Processing: {shared_state.is_processing}, Paused: {shared_state.is_paused}")
    return JSONResponse(
        content={
            "status": "success",
            "is_processing": shared_state.is_processing,
            "is_paused": shared_state.is_paused
        },
        status_code=200
    )
456
+
457
@app.post("/exec_computer_tool")
async def exec_computer_tool(request: Request):
    """Execute a single teach-mode action described by the request body.

    Expects JSON with optional 'selected_screen' (int),
    'full_screen_game_mode' (int) and 'response' (dict) fields. The
    TeachmodeExecutor is run in a worker thread because it calls
    asyncio.run() internally, which must not happen on FastAPI's own
    event loop. Returns the collected action results, or a 500 response
    on failure.
    """
    logging.info("Received request to /exec_computer_tool")
    try:
        data = await request.json()

        # Extract parameters from the request
        selected_screen = data.get('selected_screen', 0)
        full_screen_game_mode = data.get('full_screen_game_mode', 0)
        response = data.get('response', {})

        logging.info(f"Executing TeachmodeExecutor with: screen={selected_screen}, mode={full_screen_game_mode}, response={response}")

        def run_executor():
            # Runs in a worker thread: collects every action result the
            # executor yields, or returns an {"error": ...} dict on failure.
            executor = TeachmodeExecutor(
                selected_screen=selected_screen,
                full_screen_game_mode=full_screen_game_mode
            )

            results = []
            try:
                for action_result in executor(response):
                    results.append(action_result)
            except Exception as exec_error:
                logging.error(f"Error executing action: {exec_error}", exc_info=True)
                return {"error": str(exec_error)}

            return results

        # Execute in a thread pool to avoid blocking the event loop.
        # asyncio.get_running_loop() replaces asyncio.get_event_loop(),
        # which is deprecated inside a running coroutine since Python 3.10.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            results = await asyncio.get_running_loop().run_in_executor(pool, run_executor)

        if isinstance(results, dict) and "error" in results:
            return JSONResponse(
                content={"status": "error", "message": results["error"]},
                status_code=500
            )

        logging.info(f"Action results: {results}")

        return JSONResponse(
            content={"status": "success", "results": results},
            status_code=200
        )
    except Exception as e:
        logging.error("Error processing /exec_computer_tool:", exc_info=True)
        return JSONResponse(
            content={"status": "error", "message": f"Internal server error: {str(e)}"},
            status_code=500
        )
512
+
513
def process_input():
    """Worker-thread entry point: run the teach-mode sampling loop for the current task.

    Streams messages from simple_teachmode_sampling_loop into the shared
    buffers, honoring the should_stop flag / stop_event and the is_paused
    flag between steps. Always appends a final status message and resets
    all control flags so a new task can start.
    """
    global shared_state
    logging.info("process_input thread started.")
    shared_state.is_processing = True
    shared_state.should_stop = False
    shared_state.is_paused = False
    shared_state.stop_event.clear() # Ensure stop event is cleared at the start

    print(f"start sampling loop: {shared_state.chatbot_messages}")
    print(f"shared_state.args before sampling loop: {shared_state.args}")


    try:
        # Get the generator for the sampling loop
        sampling_loop = simple_teachmode_sampling_loop(
            model=shared_state.model,
            task=shared_state.task,
            selected_screen=shared_state.selected_screen,
            user_id=shared_state.user_id,
            trace_id=shared_state.trace_id,
            api_keys=shared_state.api_keys,
            server_url=shared_state.server_url,
            full_screen_game_mode=shared_state.full_screen_game_mode,
            max_steps=shared_state.max_steps,
        )

        # Process messages from the sampling loop
        for loop_msg in sampling_loop:
            # Check stop condition more frequently
            if shared_state.should_stop or shared_state.stop_event.is_set():
                print("Processing stopped by user")
                break

            # Check if paused and wait while paused
            while shared_state.is_paused and not shared_state.should_stop and not shared_state.stop_event.is_set():
                print(f"Processing paused at: {time.strftime('%H:%M:%S')}")
                # Wait a short time and check stop condition regularly
                for _ in range(5):  # Check 5 times per second
                    if shared_state.should_stop or shared_state.stop_event.is_set():
                        break
                    time.sleep(0.2)

            # Check again after pause loop
            if shared_state.should_stop or shared_state.stop_event.is_set():
                print("Processing stopped while paused or resuming")
                break

            # Deliver the step's message to both the UI history and the queue.
            shared_state.chatbot_messages.append(loop_msg)
            shared_state.message_queue.append(loop_msg)

            # Short sleep to allow stop signals to be processed
            for _ in range(5):  # Check 5 times per second
                if shared_state.should_stop or shared_state.stop_event.is_set():
                    print("Processing stopped during sleep")
                    break
                time.sleep(0.1)

            if shared_state.should_stop or shared_state.stop_event.is_set():
                break

    except Exception as e:
        # Handle any exceptions in the processing loop
        error_msg = f"Error during task processing: {e}"
        print(error_msg)
        error_message = {"role": "assistant", "content": error_msg, "type": "error", "action_type": ""}
        shared_state.message_queue.append(error_message)

    finally:
        # Handle completion or interruption
        if shared_state.should_stop or shared_state.stop_event.is_set():
            stop_msg = f"Task '{shared_state.task}' was stopped. Ready for new tasks."
            final_message = {"role": "assistant", "content": stop_msg, "type": "text", "action_type": ""}
        else:
            complete_msg = f"Task '{shared_state.task}' completed. Thanks for using Marbot Run."
            final_message = {"role": "assistant", "content": complete_msg, "type": "text", "action_type": ""}

        shared_state.chatbot_messages.append(final_message)
        shared_state.message_queue.append(final_message)

        # Reset all state flags to allow for new tasks
        shared_state.is_processing = False
        shared_state.should_stop = False
        shared_state.is_paused = False
        shared_state.stop_event.clear()
        print("Processing completed, ready for new tasks")
        logging.info("process_input thread finished.")
599
+
600
def main():
    """Parse CLI args, initialize shared state, and serve the FastAPI app via Uvicorn.

    The listen port is derived from the Windows username (altair -> 14000,
    guestN -> 14000+N for N in 1..10); everything else uses 7888.
    """
    # Logging is set up at the top level now
    logging.info("App main() function starting setup.")
    global app, shared_state, rate_limiter # Ensure app is global if needed by uvicorn
    parser = argparse.ArgumentParser()
    # Add arguments, but NOT host and port
    parser.add_argument("--model", type=str, default="teach-mode-gpt-4o", help="Model name")
    parser.add_argument("--task", type=str, default="Following the instructions to complete the task.", help="Initial task description")
    parser.add_argument("--selected_screen", type=int, default=0, help="Selected screen index")
    parser.add_argument("--user_id", type=str, default="hero_cases", help="User ID for the session")
    parser.add_argument("--trace_id", type=str, default="build_scroll_combat", help="Trace ID for the session")
    parser.add_argument("--api_keys", type=str, default="sk-proj-1234567890", help="API keys")
    parser.add_argument("--server_url", type=str, default="http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com", help="Server URL for the session")

    args = parser.parse_args()

    # Validate args or set defaults if needed (keep these)
    # NOTE(review): argparse always sets these attributes given the defaults
    # above, so these hasattr guards look redundant — confirm before removing.
    if not hasattr(args, 'model'): args.model = "default_model"
    if not hasattr(args, 'task'): args.task = "default_task"
    if not hasattr(args, 'selected_screen'): args.selected_screen = 0
    if not hasattr(args, 'user_id'): args.user_id = "unknown_user"
    if not hasattr(args, 'trace_id'): args.trace_id = "unknown_trace"
    if not hasattr(args, 'api_keys'): args.api_keys = "none"
    if not hasattr(args, 'server_url'): args.server_url = "none"

    shared_state = SharedState(args)
    rate_limiter = RateLimiter(interval_seconds=2) # Re-initialize rate limiter
    logging.info(f"Shared state initialized for user: {args.user_id}")

    # --- Restore original port calculation logic ---
    port = 7888 # Default port
    host = "0.0.0.0" # Listen on all interfaces

    if platform.system() == "Windows":
        try:
            username = os.environ["USERNAME"].lower()
            logging.info(f"Determining port based on Windows username: {username}")
            if username == "altair":
                port = 14000
            elif username.startswith("guest") and username[5:].isdigit():
                num = int(username[5:])
                if 1 <= num <= 10: # Assuming max 10 guests for this range
                    port = 14000 + num
                else:
                    logging.warning(f"Guest user number {num} out of range (1-10), using default port {port}.")
            else:
                logging.info(f"Username '{username}' doesn't match specific rules, using default port {port}.")
        except Exception as e:
            logging.error(f"Error determining port from username: {e}. Using default port {port}.", exc_info=True)
    else:
        logging.info(f"Not running on Windows, using default port {port}.")
    # --- End of restored port calculation ---

    logging.info(f"Final Host={host}, Port={port}")

    try:
        logging.info(f"Starting Uvicorn server on {host}:{port}")
        # Use the calculated port and specific host
        uvicorn.run(app, host=host, port=port)
        logging.info("Uvicorn server stopped.")
    except Exception as main_e:
        logging.error("Error in main execution:", exc_info=True)
    finally:
        logging.info("App main() function finished.")
664
+
665
if __name__ == "__main__":
    main()
    # Removed leftover debug call: a module-level
    # log_ootb_request(..., "test_request", ...) fired a network request to
    # the logging server every time this script's server loop exited.