computer-use-ootb-internal 0.0.145__py3-none-any.whl → 0.0.147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,552 +1,572 @@
1
- import argparse
2
- import time
3
- import json
4
- from datetime import datetime
5
- import threading
6
- import requests
7
- import platform # Add platform import
8
- import pyautogui # Add pyautogui import
9
- import webbrowser # Add webbrowser import
10
- import os # Import os for path joining
11
- import logging # Import logging
12
- import importlib # For dynamic imports
13
- import pkgutil # To find modules
14
- import sys # For logging setup
15
- import traceback # For logging setup
16
- from logging.handlers import RotatingFileHandler # For logging setup
17
- from fastapi import FastAPI, Request
18
- from fastapi.responses import JSONResponse
19
- from fastapi.middleware.cors import CORSMiddleware
20
- from computer_use_ootb_internal.computer_use_demo.tools.computer import get_screen_details
21
- from computer_use_ootb_internal.run_teachmode_ootb_args import simple_teachmode_sampling_loop
22
- import uvicorn # Assuming uvicorn is used to run FastAPI
23
-
24
- # --- App Logging Setup ---
25
- try:
26
- # Log to user's AppData directory for better accessibility
27
- log_dir_base = os.environ.get('APPDATA', os.path.expanduser('~'))
28
- log_dir = os.path.join(log_dir_base, 'OOTBAppLogs')
29
- os.makedirs(log_dir, exist_ok=True)
30
- log_file = os.path.join(log_dir, 'ootb_app.log')
31
-
32
- log_format = '%(asctime)s - %(levelname)s - %(process)d - %(threadName)s - %(message)s'
33
- log_level = logging.INFO # Or logging.DEBUG for more detail
34
-
35
- # Use rotating file handler
36
- handler = RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=2, encoding='utf-8')
37
- handler.setFormatter(logging.Formatter(log_format))
38
-
39
- # Configure root logger
40
- logging.basicConfig(level=log_level, handlers=[handler])
41
-
42
- # Add stream handler to see logs if running interactively (optional)
43
- # logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
44
-
45
- logging.info("="*20 + " OOTB App Starting " + "="*20)
46
- logging.info(f"Running with args: {sys.argv}")
47
- logging.info(f"Python Executable: {sys.executable}")
48
- logging.info(f"Working Directory: {os.getcwd()}")
49
- logging.info(f"User: {os.getenv('USERNAME')}")
50
-
51
- except Exception as log_setup_e:
52
- print(f"FATAL: Failed to set up logging: {log_setup_e}")
53
- # Fallback logging might be needed here if file logging fails
54
-
55
- # --- End App Logging Setup ---
56
-
57
- app = FastAPI()
58
-
59
- # Add CORS middleware to allow requests from the frontend
60
- app.add_middleware(
61
- CORSMiddleware,
62
- allow_origins=["*"],
63
- allow_credentials=True,
64
- allow_methods=["*"],
65
- allow_headers=["*"],
66
- )
67
-
68
- # Rate limiter for API endpoints
69
- class RateLimiter:
70
- def __init__(self, interval_seconds=2):
71
- self.interval = interval_seconds
72
- self.last_request_time = {}
73
- self.lock = threading.Lock()
74
-
75
- def allow_request(self, endpoint):
76
- with self.lock:
77
- current_time = time.time()
78
- # Priority endpoints always allowed
79
- if endpoint in ["/update_params", "/update_message"]:
80
- return True
81
-
82
- # For other endpoints, apply rate limiting
83
- if endpoint not in self.last_request_time:
84
- self.last_request_time[endpoint] = current_time
85
- return True
86
-
87
- elapsed = current_time - self.last_request_time[endpoint]
88
- if elapsed < self.interval:
89
- return False
90
-
91
- self.last_request_time[endpoint] = current_time
92
- return True
93
-
94
-
95
- def log_ootb_request(server_url, ootb_request_type, data):
96
- logging.info(f"OOTB Request: Type={ootb_request_type}, Data={data}")
97
- # Keep the requests post for now if it serves a specific purpose
98
- logging_data = {
99
- "type": ootb_request_type,
100
- "data": data,
101
- "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
102
- }
103
- if not server_url.endswith("/update_ootb_logging"):
104
- server_logging_url = server_url + "/update_ootb_logging"
105
- else:
106
- server_logging_url = server_url
107
- try:
108
- requests.post(server_logging_url, json=logging_data, timeout=5)
109
- except Exception as req_log_e:
110
- logging.warning(f"Could not log ootb request to server {server_logging_url}: {req_log_e}")
111
-
112
-
113
- class SharedState:
114
- def __init__(self, args):
115
- self.args = args
116
- self.task_updated = False
117
- self.chatbot_messages = []
118
- # Store all state-related data here
119
- self.model = args.model
120
- self.task = getattr(args, 'task', "")
121
- self.selected_screen = args.selected_screen
122
- self.user_id = args.user_id
123
- self.trace_id = args.trace_id
124
- self.api_keys = args.api_keys
125
- self.server_url = args.server_url
126
- self.message_queue = []
127
- self.is_processing = False
128
- self.should_stop = False
129
- self.is_paused = False
130
- # Add a new event to better control stopping
131
- self.stop_event = threading.Event()
132
- # Add a reference to the processing thread
133
- self.processing_thread = None
134
-
135
- shared_state = None
136
- rate_limiter = RateLimiter(interval_seconds=2)
137
-
138
- # Set up logging for this module
139
- log = logging.getLogger(__name__)
140
-
141
- def prepare_environment(state):
142
- """Dynamically loads and runs preparation logic based on software name."""
143
- # TODO: Replace hardcoded software name with value from shared_state when available
144
- software_name = "star rail"
145
- # Normalize the software name to be a valid Python module name
146
- # Replace spaces/hyphens with underscores, convert to lowercase
147
- module_name_base = software_name.replace(" ", "_").replace("-", "_").lower()
148
- module_to_run = f"{module_name_base}_prepare"
149
-
150
- log.info(f"Attempting preparation for software: '{software_name}' (Module: '{module_to_run}')")
151
-
152
- try:
153
- # Construct the full module path within the package
154
- prep_package = "computer_use_ootb_internal.preparation"
155
- full_module_path = f"{prep_package}.{module_to_run}"
156
-
157
- # Dynamically import the module
158
- # Check if module exists first using pkgutil to avoid import errors
159
- # Note: pkgutil.find_loader might be deprecated, consider importlib.util.find_spec
160
- loader = pkgutil.find_loader(full_module_path)
161
- if loader is None:
162
- log.warning(f"Preparation module '{full_module_path}' not found. Skipping preparation.")
163
- return
164
-
165
- prep_module = importlib.import_module(full_module_path)
166
-
167
- # Check if the module has the expected function
168
- if hasattr(prep_module, "run_preparation") and callable(prep_module.run_preparation):
169
- log.info(f"Running preparation function from {full_module_path}...")
170
- prep_module.run_preparation(state)
171
- log.info(f"Preparation function from {full_module_path} completed.")
172
- else:
173
- log.warning(f"Module {full_module_path} found, but does not have a callable 'run_preparation' function. Skipping.")
174
-
175
- except ModuleNotFoundError:
176
- log.warning(f"Preparation module '{full_module_path}' not found. Skipping preparation.")
177
- except Exception as e:
178
- log.error(f"Error during dynamic preparation loading/execution for '{module_to_run}': {e}", exc_info=True)
179
-
180
-
181
- @app.post("/update_params")
182
- async def update_parameters(request: Request):
183
- logging.info("Received request to /update_params")
184
- try:
185
- data = await request.json()
186
-
187
- if 'task' not in data:
188
- return JSONResponse(
189
- content={"status": "error", "message": "Missing required field: task"},
190
- status_code=400
191
- )
192
-
193
- # Clear message histories before updating parameters
194
- shared_state.message_queue.clear()
195
- shared_state.chatbot_messages.clear()
196
- logging.info("Cleared message queue and chatbot messages.")
197
-
198
- shared_state.args = argparse.Namespace(**data)
199
- shared_state.task_updated = True
200
-
201
- # Update shared state when parameters change
202
- shared_state.model = getattr(shared_state.args, 'model', "teach-mode-gpt-4o")
203
- shared_state.task = getattr(shared_state.args, 'task', "Following the instructions to complete the task.")
204
- shared_state.selected_screen = getattr(shared_state.args, 'selected_screen', 0)
205
- shared_state.user_id = getattr(shared_state.args, 'user_id', "hero_cases")
206
- shared_state.trace_id = getattr(shared_state.args, 'trace_id', "build_scroll_combat")
207
- shared_state.api_keys = getattr(shared_state.args, 'api_keys', "sk-proj-1234567890")
208
- shared_state.server_url = getattr(shared_state.args, 'server_url', "http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com")
209
-
210
- log_ootb_request(shared_state.server_url, "update_params", data)
211
-
212
- # Call the (now dynamic) preparation function here, after parameters are updated
213
- prepare_environment(shared_state)
214
-
215
- logging.info("Parameters updated successfully.")
216
- return JSONResponse(
217
- content={"status": "success", "message": "Parameters updated", "new_args": vars(shared_state.args)},
218
- status_code=200
219
- )
220
- except Exception as e:
221
- logging.error("Error processing /update_params:", exc_info=True)
222
- return JSONResponse(content={"status": "error", "message": "Internal server error"}, status_code=500)
223
-
224
- @app.post("/update_message")
225
- async def update_message(request: Request):
226
- data = await request.json()
227
-
228
- if 'message' not in data:
229
- return JSONResponse(
230
- content={"status": "error", "message": "Missing required field: message"},
231
- status_code=400
232
- )
233
-
234
- log_ootb_request(shared_state.server_url, "update_message", data)
235
-
236
- message = data['message']
237
- # shared_state.chatbot_messages.append({"role": "user", "content": message, "type": "text"})
238
- shared_state.task = message
239
- shared_state.args.task = message
240
-
241
- # Reset stop event before starting
242
- shared_state.stop_event.clear()
243
-
244
- # Start processing if not already running
245
- if not shared_state.is_processing:
246
- # Create and store the thread
247
- shared_state.processing_thread = threading.Thread(target=process_input, daemon=True)
248
- shared_state.processing_thread.start()
249
-
250
- return JSONResponse(
251
- content={"status": "success", "message": "Message received", "task": shared_state.task},
252
- status_code=200
253
- )
254
-
255
- @app.get("/get_messages")
256
- async def get_messages(request: Request):
257
- # Apply rate limiting
258
- if not rate_limiter.allow_request(request.url.path):
259
- return JSONResponse(
260
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
261
- status_code=429
262
- )
263
-
264
- # log_ootb_request(shared_state.server_url, "get_messages", {})
265
-
266
- # Return all messages in the queue and clear it
267
- messages = shared_state.message_queue.copy()
268
- shared_state.message_queue = []
269
-
270
- return JSONResponse(
271
- content={"status": "success", "messages": messages},
272
- status_code=200
273
- )
274
-
275
- @app.get("/get_screens")
276
- async def get_screens(request: Request):
277
- # Apply rate limiting
278
- if not rate_limiter.allow_request(request.url.path):
279
- return JSONResponse(
280
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
281
- status_code=429
282
- )
283
-
284
- log_ootb_request(shared_state.server_url, "get_screens", {})
285
-
286
- screen_options, primary_index = get_screen_details()
287
-
288
- return JSONResponse(
289
- content={"status": "success", "screens": screen_options, "primary_index": primary_index},
290
- status_code=200
291
- )
292
-
293
- @app.post("/stop_processing")
294
- async def stop_processing(request: Request):
295
- # Apply rate limiting
296
- if not rate_limiter.allow_request(request.url.path):
297
- return JSONResponse(
298
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
299
- status_code=429
300
- )
301
-
302
- log_ootb_request(shared_state.server_url, "stop_processing", {})
303
-
304
- if shared_state.is_processing:
305
- # Set both flags to ensure stopping the current task
306
- shared_state.should_stop = True
307
- shared_state.stop_event.set()
308
-
309
- # Clear message histories
310
- shared_state.message_queue.clear()
311
- shared_state.chatbot_messages.clear()
312
- logging.info("Cleared message queue and chatbot messages during stop.")
313
-
314
- # Send an immediate message to the queue to inform the user
315
- stop_initiated_msg = {"role": "assistant", "content": f"Stopping task '{shared_state.task}'...", "type": "text", "action_type": ""}
316
- # Append the stop message AFTER clearing, so it's the only one left
317
- shared_state.message_queue.append(stop_initiated_msg)
318
- shared_state.chatbot_messages.append(stop_initiated_msg)
319
-
320
- return JSONResponse(
321
- content={"status": "success", "message": "Task is being stopped, server will remain available for new tasks"},
322
- status_code=200
323
- )
324
- else:
325
- # Clear message histories even if not processing, to ensure clean state
326
- shared_state.message_queue.clear()
327
- shared_state.chatbot_messages.clear()
328
- logging.info("Cleared message queue and chatbot messages (no active process to stop).")
329
- return JSONResponse(
330
- content={"status": "error", "message": "No active processing to stop"},
331
- status_code=400
332
- )
333
-
334
- @app.post("/toggle_pause")
335
- async def toggle_pause(request: Request):
336
- # Apply rate limiting
337
- if not rate_limiter.allow_request(request.url.path):
338
- return JSONResponse(
339
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
340
- status_code=429
341
- )
342
-
343
- log_ootb_request(shared_state.server_url, "toggle_pause", {})
344
-
345
- if not shared_state.is_processing:
346
- return JSONResponse(
347
- content={"status": "error", "message": "No active processing to pause/resume"},
348
- status_code=400
349
- )
350
-
351
- # Toggle the pause state
352
- shared_state.is_paused = not shared_state.is_paused
353
- current_state = shared_state.is_paused
354
-
355
- print(f"Toggled pause state to: {current_state}")
356
-
357
- status_message = "paused" if current_state else "resumed"
358
-
359
- # Add a message to the queue to inform the user
360
- if current_state:
361
- message = {"role": "assistant", "content": f"Task '{shared_state.task}' has been paused. Click Continue to resume.", "type": "text", "action_type": ""}
362
- else:
363
- message = {"role": "assistant", "content": f"Task '{shared_state.task}' has been resumed.", "type": "text", "action_type": ""}
364
-
365
- shared_state.chatbot_messages.append(message)
366
- shared_state.message_queue.append(message)
367
-
368
- return JSONResponse(
369
- content={
370
- "status": "success",
371
- "message": f"Processing {status_message}",
372
- "is_paused": current_state
373
- },
374
- status_code=200
375
- )
376
-
377
- @app.get("/status")
378
- async def get_status(request: Request):
379
- # Apply rate limiting
380
- if not rate_limiter.allow_request(request.url.path):
381
- return JSONResponse(
382
- content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
383
- status_code=429
384
- )
385
-
386
- # log_ootb_request(shared_state.server_url, "get_status", {})
387
-
388
- print(f"Status check - Processing: {shared_state.is_processing}, Paused: {shared_state.is_paused}")
389
- return JSONResponse(
390
- content={
391
- "status": "success",
392
- "is_processing": shared_state.is_processing,
393
- "is_paused": shared_state.is_paused
394
- },
395
- status_code=200
396
- )
397
-
398
- def process_input():
399
- global shared_state
400
- logging.info("process_input thread started.")
401
- shared_state.is_processing = True
402
- shared_state.should_stop = False
403
- shared_state.is_paused = False
404
- shared_state.stop_event.clear() # Ensure stop event is cleared at the start
405
-
406
- print(f"start sampling loop: {shared_state.chatbot_messages}")
407
- print(f"shared_state.args before sampling loop: {shared_state.args}")
408
-
409
-
410
- try:
411
- # Get the generator for the sampling loop
412
- sampling_loop = simple_teachmode_sampling_loop(
413
- model=shared_state.model,
414
- task=shared_state.task,
415
- selected_screen=shared_state.selected_screen,
416
- user_id=shared_state.user_id,
417
- trace_id=shared_state.trace_id,
418
- api_keys=shared_state.api_keys,
419
- server_url=shared_state.server_url,
420
- )
421
-
422
- # Process messages from the sampling loop
423
- for loop_msg in sampling_loop:
424
- # Check stop condition more frequently
425
- if shared_state.should_stop or shared_state.stop_event.is_set():
426
- print("Processing stopped by user")
427
- break
428
-
429
- # Check if paused and wait while paused
430
- while shared_state.is_paused and not shared_state.should_stop and not shared_state.stop_event.is_set():
431
- print(f"Processing paused at: {time.strftime('%H:%M:%S')}")
432
- # Wait a short time and check stop condition regularly
433
- for _ in range(5): # Check 5 times per second
434
- if shared_state.should_stop or shared_state.stop_event.is_set():
435
- break
436
- time.sleep(0.2)
437
-
438
- # Check again after pause loop
439
- if shared_state.should_stop or shared_state.stop_event.is_set():
440
- print("Processing stopped while paused or resuming")
441
- break
442
-
443
- shared_state.chatbot_messages.append(loop_msg)
444
- shared_state.message_queue.append(loop_msg)
445
-
446
- # Short sleep to allow stop signals to be processed
447
- for _ in range(5): # Check 5 times per second
448
- if shared_state.should_stop or shared_state.stop_event.is_set():
449
- print("Processing stopped during sleep")
450
- break
451
- time.sleep(0.1)
452
-
453
- if shared_state.should_stop or shared_state.stop_event.is_set():
454
- break
455
-
456
- except Exception as e:
457
- # Handle any exceptions in the processing loop
458
- error_msg = f"Error during task processing: {e}"
459
- print(error_msg)
460
- error_message = {"role": "assistant", "content": error_msg, "type": "error", "action_type": ""}
461
- shared_state.message_queue.append(error_message)
462
-
463
- finally:
464
- # Handle completion or interruption
465
- if shared_state.should_stop or shared_state.stop_event.is_set():
466
- stop_msg = f"Task '{shared_state.task}' was stopped. Ready for new tasks."
467
- final_message = {"role": "assistant", "content": stop_msg, "type": "text", "action_type": ""}
468
- else:
469
- complete_msg = f"Task '{shared_state.task}' completed. Thanks for using Teachmode-OOTB."
470
- final_message = {"role": "assistant", "content": complete_msg, "type": "text", "action_type": ""}
471
-
472
- shared_state.chatbot_messages.append(final_message)
473
- shared_state.message_queue.append(final_message)
474
-
475
- # Reset all state flags to allow for new tasks
476
- shared_state.is_processing = False
477
- shared_state.should_stop = False
478
- shared_state.is_paused = False
479
- shared_state.stop_event.clear()
480
- print("Processing completed, ready for new tasks")
481
- logging.info("process_input thread finished.")
482
-
483
- def main():
484
- # Logging is set up at the top level now
485
- logging.info("App main() function starting setup.")
486
- global app, shared_state, rate_limiter # Ensure app is global if needed by uvicorn
487
- parser = argparse.ArgumentParser()
488
- # Add arguments, but NOT host and port
489
- parser.add_argument("--model", type=str, default="teach-mode-gpt-4o", help="Model name")
490
- parser.add_argument("--task", type=str, default="Following the instructions to complete the task.", help="Initial task description")
491
- parser.add_argument("--selected_screen", type=int, default=0, help="Selected screen index")
492
- parser.add_argument("--user_id", type=str, default="hero_cases", help="User ID for the session")
493
- parser.add_argument("--trace_id", type=str, default="build_scroll_combat", help="Trace ID for the session")
494
- parser.add_argument("--api_keys", type=str, default="sk-proj-1234567890", help="API keys")
495
- parser.add_argument("--server_url", type=str, default="http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com", help="Server URL for the session")
496
-
497
- args = parser.parse_args()
498
-
499
- # Validate args or set defaults if needed (keep these)
500
- if not hasattr(args, 'model'): args.model = "default_model"
501
- if not hasattr(args, 'task'): args.task = "default_task"
502
- if not hasattr(args, 'selected_screen'): args.selected_screen = 0
503
- if not hasattr(args, 'user_id'): args.user_id = "unknown_user"
504
- if not hasattr(args, 'trace_id'): args.trace_id = "unknown_trace"
505
- if not hasattr(args, 'api_keys'): args.api_keys = "none"
506
- if not hasattr(args, 'server_url'): args.server_url = "none"
507
-
508
- shared_state = SharedState(args)
509
- rate_limiter = RateLimiter(interval_seconds=2) # Re-initialize rate limiter
510
- logging.info(f"Shared state initialized for user: {args.user_id}")
511
-
512
- # --- Restore original port calculation logic ---
513
- port = 7888 # Default port
514
- host = "0.0.0.0" # Listen on all interfaces
515
-
516
- if platform.system() == "Windows":
517
- try:
518
- username = os.environ["USERNAME"].lower()
519
- logging.info(f"Determining port based on Windows username: {username}")
520
- if username == "altair":
521
- port = 14000
522
- elif username.startswith("guest") and username[5:].isdigit():
523
- num = int(username[5:])
524
- if 1 <= num <= 10: # Assuming max 10 guests for this range
525
- port = 14000 + num
526
- else:
527
- logging.warning(f"Guest user number {num} out of range (1-10), using default port {port}.")
528
- else:
529
- logging.info(f"Username '{username}' doesn't match specific rules, using default port {port}.")
530
- except Exception as e:
531
- logging.error(f"Error determining port from username: {e}. Using default port {port}.", exc_info=True)
532
- else:
533
- logging.info(f"Not running on Windows, using default port {port}.")
534
- # --- End of restored port calculation ---
535
-
536
- logging.info(f"Final Host={host}, Port={port}")
537
-
538
- try:
539
- logging.info(f"Starting Uvicorn server on {host}:{port}")
540
- # Use the calculated port and specific host
541
- uvicorn.run(app, host=host, port=port)
542
- logging.info("Uvicorn server stopped.")
543
- except Exception as main_e:
544
- logging.error("Error in main execution:", exc_info=True)
545
- finally:
546
- logging.info("App main() function finished.")
547
-
548
- if __name__ == "__main__":
549
- main()
550
-
551
- # Test log_ootb_request
1
+ import argparse
2
+ import time
3
+ import json
4
+ from datetime import datetime
5
+ import threading
6
+ import requests
7
+ import platform # Add platform import
8
+ import pyautogui # Add pyautogui import
9
+ import webbrowser # Add webbrowser import
10
+ import os # Import os for path joining
11
+ import logging # Import logging
12
+ import importlib # For dynamic imports
13
+ import pkgutil # To find modules
14
+ import sys # For logging setup
15
+ import traceback # For logging setup
16
+ from logging.handlers import RotatingFileHandler # For logging setup
17
+ from fastapi import FastAPI, Request
18
+ from fastapi.responses import JSONResponse
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from computer_use_ootb_internal.computer_use_demo.tools.computer import get_screen_details
21
+ from computer_use_ootb_internal.run_teachmode_ootb_args import simple_teachmode_sampling_loop
22
+ import uvicorn # Assuming uvicorn is used to run FastAPI
23
+
24
+ # --- App Logging Setup ---
25
+ try:
26
+ # Log to user's AppData directory for better accessibility
27
+ log_dir_base = os.environ.get('APPDATA', os.path.expanduser('~'))
28
+ log_dir = os.path.join(log_dir_base, 'OOTBAppLogs')
29
+ os.makedirs(log_dir, exist_ok=True)
30
+ log_file = os.path.join(log_dir, 'ootb_app.log')
31
+
32
+ log_format = '%(asctime)s - %(levelname)s - %(process)d - %(threadName)s - %(message)s'
33
+ log_level = logging.INFO # Or logging.DEBUG for more detail
34
+
35
+ # Use rotating file handler
36
+ handler = RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=2, encoding='utf-8')
37
+ handler.setFormatter(logging.Formatter(log_format))
38
+
39
+ # Configure root logger
40
+ logging.basicConfig(level=log_level, handlers=[handler])
41
+
42
+ # Add stream handler to see logs if running interactively (optional)
43
+ # logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
44
+
45
+ logging.info("="*20 + " OOTB App Starting " + "="*20)
46
+ logging.info(f"Running with args: {sys.argv}")
47
+ logging.info(f"Python Executable: {sys.executable}")
48
+ logging.info(f"Working Directory: {os.getcwd()}")
49
+ logging.info(f"User: {os.getenv('USERNAME')}")
50
+
51
+ except Exception as log_setup_e:
52
+ print(f"FATAL: Failed to set up logging: {log_setup_e}")
53
+ # Fallback logging might be needed here if file logging fails
54
+
55
+ # --- End App Logging Setup ---
56
+
57
+ app = FastAPI()
58
+
59
+ # Add CORS middleware to allow requests from the frontend
60
+ app.add_middleware(
61
+ CORSMiddleware,
62
+ allow_origins=["*"],
63
+ allow_credentials=True,
64
+ allow_methods=["*"],
65
+ allow_headers=["*"],
66
+ )
67
+
68
+ # Rate limiter for API endpoints
69
+ class RateLimiter:
70
+ def __init__(self, interval_seconds=2):
71
+ self.interval = interval_seconds
72
+ self.last_request_time = {}
73
+ self.lock = threading.Lock()
74
+
75
+ def allow_request(self, endpoint):
76
+ with self.lock:
77
+ current_time = time.time()
78
+ # Priority endpoints always allowed
79
+ if endpoint in ["/update_params", "/update_message"]:
80
+ return True
81
+
82
+ # For other endpoints, apply rate limiting
83
+ if endpoint not in self.last_request_time:
84
+ self.last_request_time[endpoint] = current_time
85
+ return True
86
+
87
+ elapsed = current_time - self.last_request_time[endpoint]
88
+ if elapsed < self.interval:
89
+ return False
90
+
91
+ self.last_request_time[endpoint] = current_time
92
+ return True
93
+
94
+
95
+ def log_ootb_request(server_url, ootb_request_type, data):
96
+ logging.info(f"OOTB Request: Type={ootb_request_type}, Data={data}")
97
+ # Keep the requests post for now if it serves a specific purpose
98
+ logging_data = {
99
+ "type": ootb_request_type,
100
+ "data": data,
101
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
102
+ }
103
+ if not server_url.endswith("/update_ootb_logging"):
104
+ server_logging_url = server_url + "/update_ootb_logging"
105
+ else:
106
+ server_logging_url = server_url
107
+ try:
108
+ requests.post(server_logging_url, json=logging_data, timeout=5)
109
+ except Exception as req_log_e:
110
+ logging.warning(f"Could not log ootb request to server {server_logging_url}: {req_log_e}")
111
+
112
+
113
+ class SharedState:
114
+ def __init__(self, args):
115
+ self.args = args
116
+ self.task_updated = False
117
+ self.chatbot_messages = []
118
+ # Store all state-related data here
119
+ self.model = args.model
120
+ self.task = getattr(args, 'task', "")
121
+ self.selected_screen = args.selected_screen
122
+ self.user_id = args.user_id
123
+ self.trace_id = args.trace_id
124
+ self.api_keys = args.api_keys
125
+ self.server_url = args.server_url
126
+ self.port = getattr(args, 'port', None)
127
+ self.message_queue = []
128
+ self.is_processing = False
129
+ self.should_stop = False
130
+ self.is_paused = False
131
+ # Add a new event to better control stopping
132
+ self.stop_event = threading.Event()
133
+ # Add a reference to the processing thread
134
+ self.processing_thread = None
135
+
136
+ shared_state = None
137
+ rate_limiter = RateLimiter(interval_seconds=2)
138
+
139
+ # Set up logging for this module
140
+ log = logging.getLogger(__name__)
141
+
142
+ def prepare_environment(state):
143
+ """Dynamically loads and runs preparation logic based on software name."""
144
+ # TODO: Replace hardcoded software name with value from shared_state when available
145
+ software_name = "star rail"
146
+ # Normalize the software name to be a valid Python module name
147
+ # Replace spaces/hyphens with underscores, convert to lowercase
148
+ module_name_base = software_name.replace(" ", "_").replace("-", "_").lower()
149
+ module_to_run = f"{module_name_base}_prepare"
150
+
151
+ log.info(f"Attempting preparation for software: '{software_name}' (Module: '{module_to_run}')")
152
+
153
+ try:
154
+ # Construct the full module path within the package
155
+ prep_package = "computer_use_ootb_internal.preparation"
156
+ full_module_path = f"{prep_package}.{module_to_run}"
157
+
158
+ # Dynamically import the module
159
+ # Check if module exists first using pkgutil to avoid import errors
160
+ # Note: pkgutil.find_loader might be deprecated, consider importlib.util.find_spec
161
+ loader = pkgutil.find_loader(full_module_path)
162
+ if loader is None:
163
+ log.warning(f"Preparation module '{full_module_path}' not found. Skipping preparation.")
164
+ return
165
+
166
+ prep_module = importlib.import_module(full_module_path)
167
+
168
+ # Check if the module has the expected function
169
+ if hasattr(prep_module, "run_preparation") and callable(prep_module.run_preparation):
170
+ log.info(f"Running preparation function from {full_module_path}...")
171
+ prep_module.run_preparation(state)
172
+ log.info(f"Preparation function from {full_module_path} completed.")
173
+ else:
174
+ log.warning(f"Module {full_module_path} found, but does not have a callable 'run_preparation' function. Skipping.")
175
+
176
+ except ModuleNotFoundError:
177
+ log.warning(f"Preparation module '{full_module_path}' not found. Skipping preparation.")
178
+ except Exception as e:
179
+ log.error(f"Error during dynamic preparation loading/execution for '{module_to_run}': {e}", exc_info=True)
180
+
181
+
182
+ @app.post("/update_params")
183
+ async def update_parameters(request: Request):
184
+ logging.info("Received request to /update_params")
185
+ try:
186
+ data = await request.json()
187
+
188
+ if 'task' not in data:
189
+ return JSONResponse(
190
+ content={"status": "error", "message": "Missing required field: task"},
191
+ status_code=400
192
+ )
193
+
194
+ # Clear message histories before updating parameters
195
+ shared_state.message_queue.clear()
196
+ shared_state.chatbot_messages.clear()
197
+ logging.info("Cleared message queue and chatbot messages.")
198
+
199
+ shared_state.args = argparse.Namespace(**data)
200
+ shared_state.task_updated = True
201
+
202
+ # Update shared state when parameters change
203
+ shared_state.model = getattr(shared_state.args, 'model', "teach-mode-gpt-4o")
204
+ shared_state.task = getattr(shared_state.args, 'task', "Following the instructions to complete the task.")
205
+ shared_state.selected_screen = getattr(shared_state.args, 'selected_screen', 0)
206
+ shared_state.user_id = getattr(shared_state.args, 'user_id', "hero_cases")
207
+ shared_state.trace_id = getattr(shared_state.args, 'trace_id', "build_scroll_combat")
208
+ shared_state.api_keys = getattr(shared_state.args, 'api_keys', "sk-proj-1234567890")
209
+ shared_state.server_url = getattr(shared_state.args, 'server_url', "http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com")
210
+ # Include the specified port if available
211
+ shared_state.port = getattr(shared_state.args, 'port', None)
212
+
213
+ log_ootb_request(shared_state.server_url, "update_params", data)
214
+
215
+ # Call the (now dynamic) preparation function here, after parameters are updated
216
+ prepare_environment(shared_state)
217
+
218
+ logging.info("Parameters updated successfully.")
219
+ return JSONResponse(
220
+ content={"status": "success", "message": "Parameters updated", "new_args": vars(shared_state.args)},
221
+ status_code=200
222
+ )
223
+ except Exception as e:
224
+ logging.error("Error processing /update_params:", exc_info=True)
225
+ return JSONResponse(content={"status": "error", "message": "Internal server error"}, status_code=500)
226
+
227
+ @app.post("/update_message")
228
+ async def update_message(request: Request):
229
+ data = await request.json()
230
+
231
+ if 'message' not in data:
232
+ return JSONResponse(
233
+ content={"status": "error", "message": "Missing required field: message"},
234
+ status_code=400
235
+ )
236
+
237
+ log_ootb_request(shared_state.server_url, "update_message", data)
238
+
239
+ message = data['message']
240
+ # shared_state.chatbot_messages.append({"role": "user", "content": message, "type": "text"})
241
+ shared_state.task = message
242
+ shared_state.args.task = message
243
+
244
+ # Reset stop event before starting
245
+ shared_state.stop_event.clear()
246
+
247
+ # Start processing if not already running
248
+ if not shared_state.is_processing:
249
+ # Create and store the thread
250
+ shared_state.processing_thread = threading.Thread(target=process_input, daemon=True)
251
+ shared_state.processing_thread.start()
252
+
253
+ return JSONResponse(
254
+ content={"status": "success", "message": "Message received", "task": shared_state.task},
255
+ status_code=200
256
+ )
257
+
258
+ @app.get("/get_messages")
259
+ async def get_messages(request: Request):
260
+ # Apply rate limiting
261
+ if not rate_limiter.allow_request(request.url.path):
262
+ return JSONResponse(
263
+ content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
264
+ status_code=429
265
+ )
266
+
267
+ # log_ootb_request(shared_state.server_url, "get_messages", {})
268
+
269
+ # Return all messages in the queue and clear it
270
+ messages = shared_state.message_queue.copy()
271
+ shared_state.message_queue = []
272
+
273
+ return JSONResponse(
274
+ content={"status": "success", "messages": messages},
275
+ status_code=200
276
+ )
277
+
278
+ @app.get("/get_screens")
279
+ async def get_screens(request: Request):
280
+ # Apply rate limiting
281
+ if not rate_limiter.allow_request(request.url.path):
282
+ return JSONResponse(
283
+ content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
284
+ status_code=429
285
+ )
286
+
287
+ log_ootb_request(shared_state.server_url, "get_screens", {})
288
+
289
+ screen_options, primary_index = get_screen_details()
290
+
291
+ return JSONResponse(
292
+ content={"status": "success", "screens": screen_options, "primary_index": primary_index},
293
+ status_code=200
294
+ )
295
+
296
+ @app.post("/stop_processing")
297
+ async def stop_processing(request: Request):
298
+ # Apply rate limiting
299
+ if not rate_limiter.allow_request(request.url.path):
300
+ return JSONResponse(
301
+ content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
302
+ status_code=429
303
+ )
304
+
305
+ log_ootb_request(shared_state.server_url, "stop_processing", {})
306
+
307
+ if shared_state.is_processing:
308
+ # Set both flags to ensure stopping the current task
309
+ shared_state.should_stop = True
310
+ shared_state.stop_event.set()
311
+
312
+ # Clear message histories
313
+ shared_state.message_queue.clear()
314
+ shared_state.chatbot_messages.clear()
315
+ logging.info("Cleared message queue and chatbot messages during stop.")
316
+
317
+ # Send an immediate message to the queue to inform the user
318
+ stop_initiated_msg = {"role": "assistant", "content": f"Stopping task '{shared_state.task}'...", "type": "text", "action_type": ""}
319
+ # Append the stop message AFTER clearing, so it's the only one left
320
+ shared_state.message_queue.append(stop_initiated_msg)
321
+ shared_state.chatbot_messages.append(stop_initiated_msg)
322
+
323
+ return JSONResponse(
324
+ content={"status": "success", "message": "Task is being stopped, server will remain available for new tasks"},
325
+ status_code=200
326
+ )
327
+ else:
328
+ # Clear message histories even if not processing, to ensure clean state
329
+ shared_state.message_queue.clear()
330
+ shared_state.chatbot_messages.clear()
331
+ logging.info("Cleared message queue and chatbot messages (no active process to stop).")
332
+ return JSONResponse(
333
+ content={"status": "error", "message": "No active processing to stop"},
334
+ status_code=400
335
+ )
336
+
337
+ @app.post("/toggle_pause")
338
+ async def toggle_pause(request: Request):
339
+ # Apply rate limiting
340
+ if not rate_limiter.allow_request(request.url.path):
341
+ return JSONResponse(
342
+ content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
343
+ status_code=429
344
+ )
345
+
346
+ log_ootb_request(shared_state.server_url, "toggle_pause", {})
347
+
348
+ if not shared_state.is_processing:
349
+ return JSONResponse(
350
+ content={"status": "error", "message": "No active processing to pause/resume"},
351
+ status_code=400
352
+ )
353
+
354
+ # Toggle the pause state
355
+ shared_state.is_paused = not shared_state.is_paused
356
+ current_state = shared_state.is_paused
357
+
358
+ print(f"Toggled pause state to: {current_state}")
359
+
360
+ status_message = "paused" if current_state else "resumed"
361
+
362
+ # Add a message to the queue to inform the user
363
+ if current_state:
364
+ message = {"role": "assistant", "content": f"Task '{shared_state.task}' has been paused. Click Continue to resume.", "type": "text", "action_type": ""}
365
+ else:
366
+ message = {"role": "assistant", "content": f"Task '{shared_state.task}' has been resumed.", "type": "text", "action_type": ""}
367
+
368
+ shared_state.chatbot_messages.append(message)
369
+ shared_state.message_queue.append(message)
370
+
371
+ return JSONResponse(
372
+ content={
373
+ "status": "success",
374
+ "message": f"Processing {status_message}",
375
+ "is_paused": current_state
376
+ },
377
+ status_code=200
378
+ )
379
+
380
+ @app.get("/status")
381
+ async def get_status(request: Request):
382
+ # Apply rate limiting
383
+ if not rate_limiter.allow_request(request.url.path):
384
+ return JSONResponse(
385
+ content={"status": "error", "message": "Rate limit exceeded. Try again after 2 seconds."},
386
+ status_code=429
387
+ )
388
+
389
+ # log_ootb_request(shared_state.server_url, "get_status", {})
390
+
391
+ print(f"Status check - Processing: {shared_state.is_processing}, Paused: {shared_state.is_paused}")
392
+ return JSONResponse(
393
+ content={
394
+ "status": "success",
395
+ "is_processing": shared_state.is_processing,
396
+ "is_paused": shared_state.is_paused
397
+ },
398
+ status_code=200
399
+ )
400
+
401
def process_input():
    """Worker-thread entry point: run the teach-mode sampling loop for the current task.

    Streams messages from simple_teachmode_sampling_loop into both the shared
    chatbot history and the client-facing queue, honoring the pause flag and
    the two stop signals (should_stop bool and stop_event) between steps.
    Always resets the processing flags in the finally block so a new task can
    start afterwards.
    """
    global shared_state
    logging.info("process_input thread started.")
    shared_state.is_processing = True
    shared_state.should_stop = False
    shared_state.is_paused = False
    shared_state.stop_event.clear()  # Ensure stop event is cleared at the start

    print(f"start sampling loop: {shared_state.chatbot_messages}")
    print(f"shared_state.args before sampling loop: {shared_state.args}")

    try:
        # Get the generator for the sampling loop
        sampling_loop = simple_teachmode_sampling_loop(
            model=shared_state.model,
            task=shared_state.task,
            selected_screen=shared_state.selected_screen,
            user_id=shared_state.user_id,
            trace_id=shared_state.trace_id,
            api_keys=shared_state.api_keys,
            server_url=shared_state.server_url,
        )

        # Process messages from the sampling loop
        for loop_msg in sampling_loop:
            # Check stop condition before doing any work on this step
            if shared_state.should_stop or shared_state.stop_event.is_set():
                print("Processing stopped by user")
                break

            # Busy-wait in short slices while paused, still honoring stop signals
            while shared_state.is_paused and not shared_state.should_stop and not shared_state.stop_event.is_set():
                print(f"Processing paused at: {time.strftime('%H:%M:%S')}")
                # Wait a short time and check stop condition regularly
                for _ in range(5):  # Check 5 times per second
                    if shared_state.should_stop or shared_state.stop_event.is_set():
                        break
                    time.sleep(0.2)

            # Check again after pause loop
            if shared_state.should_stop or shared_state.stop_event.is_set():
                print("Processing stopped while paused or resuming")
                break

            shared_state.chatbot_messages.append(loop_msg)
            shared_state.message_queue.append(loop_msg)

            # Short sleep to allow stop signals to be processed
            for _ in range(5):  # Check 5 times per second
                if shared_state.should_stop or shared_state.stop_event.is_set():
                    print("Processing stopped during sleep")
                    break
                time.sleep(0.1)

            if shared_state.should_stop or shared_state.stop_event.is_set():
                break

    except Exception as e:
        # Surface loop failures to the client instead of dying silently
        error_msg = f"Error during task processing: {e}"
        print(error_msg)
        error_message = {"role": "assistant", "content": error_msg, "type": "error", "action_type": ""}
        # Fix: record the error in BOTH histories, matching every other append
        # path in this function. Previously only message_queue received it, so
        # the error never appeared in the chatbot history.
        shared_state.chatbot_messages.append(error_message)
        shared_state.message_queue.append(error_message)

    finally:
        # Report how the task ended: stopped by the user vs. ran to completion
        if shared_state.should_stop or shared_state.stop_event.is_set():
            stop_msg = f"Task '{shared_state.task}' was stopped. Ready for new tasks."
            final_message = {"role": "assistant", "content": stop_msg, "type": "text", "action_type": ""}
        else:
            complete_msg = f"Task '{shared_state.task}' completed. Thanks for using Teachmode-OOTB."
            final_message = {"role": "assistant", "content": complete_msg, "type": "text", "action_type": ""}

        shared_state.chatbot_messages.append(final_message)
        shared_state.message_queue.append(final_message)

        # Reset all state flags to allow for new tasks
        shared_state.is_processing = False
        shared_state.should_stop = False
        shared_state.is_paused = False
        shared_state.stop_event.clear()
        print("Processing completed, ready for new tasks")
        logging.info("process_input thread finished.")
485
+
486
def _resolve_port(requested_port):
    """Determine the TCP port to serve on.

    An explicit --port value always wins. Otherwise, on Windows, the port is
    derived from the USERNAME environment variable (special-cased accounts
    get fixed ports in the 14000 range); everywhere else the default 7888 is
    used.
    """
    if requested_port is not None:
        logging.info(f"Using specified port from --port argument: {requested_port}")
        return requested_port

    if platform.system() != "Windows":
        # Non-windows, and no --port arg: fall back to the default port.
        port = 7888
        logging.info(f"Not on Windows or username not found, using default port {port}.")
        return port

    # Fallback to username calculation ONLY if --port was not provided.
    port = 7888  # Default port for fallback calculation
    try:
        username_to_check = os.environ.get("USERNAME")
        if not username_to_check:
            logging.warning(f"USERNAME environment variable not found. Using default port {port}.")
            return port
        username = username_to_check.lower()
        logging.info(f"Determining port based on current user environment: {username}")
        if username == "altair":
            port = 14000
        elif username.startswith("guest") and username[5:].isdigit():
            num = int(username[5:])
            if 1 <= num <= 10:
                port = 14000 + num
            else:
                logging.warning(f"Guest user number {num} out of range (1-10), using default port {port}.")
        elif username == "administrator":
            logging.info(f"Running as Administrator, using default port {port}.")
        else:
            logging.info(f"Username '{username}' doesn't match specific rules, using default port {port}.")
    except Exception as e:
        logging.error(f"Error calculating port from username: {e}. Using default port {port}.", exc_info=True)
    return port

def main():
    """Parse CLI arguments, initialize shared state, and run the FastAPI app."""
    # Logging is set up at the top level now
    logging.info("App main() function starting setup.")
    global app, shared_state, rate_limiter  # Ensure app is global if needed by uvicorn
    parser = argparse.ArgumentParser()
    # Add arguments, but NOT host and port
    parser.add_argument("--model", type=str, default="teach-mode-gpt-4o", help="Model name")
    parser.add_argument("--task", type=str, default="Following the instructions to complete the task.", help="Initial task description")
    parser.add_argument("--selected_screen", type=int, default=0, help="Selected screen index")
    parser.add_argument("--user_id", type=str, default="hero_cases", help="User ID for the session")
    parser.add_argument("--trace_id", type=str, default="build_scroll_combat", help="Trace ID for the session")
    parser.add_argument("--api_keys", type=str, default="sk-proj-1234567890", help="API keys")
    parser.add_argument("--server_url", type=str, default="http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com", help="Server URL for the session")
    # Add argument to directly specify the port
    parser.add_argument("-p", "--port", type=int, default=None, help="Specify the exact port to run the server on, overriding username-based calculation.")

    args = parser.parse_args()

    # Defensive fallbacks (argparse already supplies these defaults; kept for safety)
    fallbacks = {
        'model': "default_model",
        'task': "default_task",
        'selected_screen': 0,
        'user_id': "unknown_user",
        'trace_id': "unknown_trace",
        'api_keys': "none",
        'server_url': "none",
        'port': None,
    }
    for attr, fallback in fallbacks.items():
        if not hasattr(args, attr):
            setattr(args, attr, fallback)

    shared_state = SharedState(args)
    rate_limiter = RateLimiter(interval_seconds=2)  # Re-initialize rate limiter
    logging.info(f"Shared state initialized for user: {args.user_id}")

    # --- Port Determination ---
    host = "0.0.0.0"  # Listen on all interfaces
    port = _resolve_port(args.port)
    # --- End Port Determination ---

    logging.info(f"Final Host={host}, Port={port}")

    try:
        logging.info(f"Starting Uvicorn server on {host}:{port}")
        # Use the calculated port and specific host
        uvicorn.run(app, host=host, port=port)
        logging.info("Uvicorn server stopped.")
    except Exception:
        logging.error("Error in main execution:", exc_info=True)
    finally:
        logging.info("App main() function finished.")
567
+
568
+ if __name__ == "__main__":
569
+ main()
570
+
571
+ # Test log_ootb_request
552
572
  log_ootb_request("http://ec2-44-234-43-86.us-west-2.compute.amazonaws.com", "test_request", {"message": "Test message"})