snowglobe 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,732 @@
1
+ # Set up the scheduler
2
+ import asyncio
3
+ import datetime
4
+ import importlib.util
5
+ import json
6
+ import logging
7
+ import os
8
+ import time
9
+ import traceback
10
+ from collections import defaultdict, deque
11
+ from contextlib import asynccontextmanager
12
+ from functools import wraps
13
+ from logging import getLogger
14
+ from typing import Dict
15
+ from urllib.parse import quote_plus
16
+
17
+ import httpx
18
+ import uvicorn
19
+ from apscheduler import AsyncScheduler
20
+ from apscheduler.triggers.interval import IntervalTrigger
21
+ from fastapi import FastAPI, HTTPException, Request
22
+
23
+ from .cli_utils import info
24
+ from .config import config
25
+ from .models import CompletionFunctionOutputs, CompletionRequest, RiskEvaluationRequest
26
+ from .stats import initialize_stats, track_batch_completion
27
+ from .utils import fetch_experiments, fetch_messages
28
+
29
# Configure logging - check DEBUG from environment directly to avoid config initialization
def resolve_log_level(name, default):
    """Translate a level name such as ``"debug"`` into a numeric logging level.

    Args:
        name: Level name taken from the environment; matched case-insensitively
            against the integer constants of the ``logging`` module.
        default: Numeric level returned when ``name`` is not a known level.

    Returns:
        The matching ``logging`` constant, or ``default``.
    """
    level = getattr(logging, str(name).upper(), None)
    # Only the level constants are ints; anything else (missing attribute,
    # function, class) means the name was not a level and must not reach
    # setLevel()/basicConfig(), which would raise at import time.
    return level if isinstance(level, int) else default


if os.getenv("DEBUG", "false").lower() == "true":
    logging.basicConfig(
        handlers=[logging.StreamHandler()],
        level=resolve_log_level(os.getenv("LOG_LEVEL", "DEBUG"), logging.DEBUG),
    )
else:
    logging.basicConfig(
        handlers=[logging.StreamHandler()],
    )

LOGGER = getLogger("snowglobe_client")
# Allow LOG_LEVEL environment variable to override default INFO level
LOGGER.setLevel(resolve_log_level(os.getenv("LOG_LEVEL", "INFO"), logging.INFO))

# In-memory storage for rate limiting per route:
# route key -> client id -> deque of request timestamps
route_request_times = defaultdict(lambda: defaultdict(deque))
46
+
47
+
48
class ConfigurableRateLimiter:
    """Sliding-window request limiter whose limits are set per route.

    Request timestamps live in the module-level ``route_request_times``
    mapping, not on the instance; the instance only holds the route limits.
    """

    def __init__(self):
        # route key -> {"max_requests": int, "time_window": int}
        self.route_configs = {}

    def configure_route(self, route_key: str, max_requests: int, time_window: int):
        """Register (or overwrite) the limits applied to ``route_key``."""
        self.route_configs[route_key] = dict(
            max_requests=max_requests,
            time_window=time_window,
        )

    def is_allowed(self, client_id: str, route_key: str) -> bool:
        """Record a request from ``client_id`` and report whether it fits the limit.

        Unconfigured routes fall back to one request per one-second window.
        Returns ``True`` (and stores the timestamp) when the request is within
        the limit, ``False`` (storing nothing) otherwise.
        """
        limits = self.route_configs.get(
            route_key, {"max_requests": 1, "time_window": 1}
        )
        current = time.time()
        cutoff = current - limits["time_window"]
        history = route_request_times[route_key][client_id]

        # Timestamps are appended in order, so expired ones sit at the left end.
        while history and history[0] <= cutoff:
            history.popleft()

        if len(history) < limits["max_requests"]:
            history.append(current)
            return True
        return False

    def get_route_config(self, route_key: str) -> Dict[str, int]:
        """Return the limits for ``route_key``, or the 1-per-second default."""
        fallback = {"max_requests": 1, "time_window": 1}
        return self.route_configs.get(route_key, fallback)
85
+
86
+
87
# Create global rate limiter instance
rate_limiter = ConfigurableRateLimiter()


def rate_limit(route_name: str, max_requests: int = 1, time_window: int = 1):
    """Decorator for per-route rate limiting with configurable limits.

    Args:
        route_name: Key under which the limits are registered on the global
            ``rate_limiter``.
        max_requests: Requests allowed per client within ``time_window``.
        time_window: Window length in seconds.

    Returns:
        A decorator for async handlers whose first positional argument is the
        ``Request``. The wrapped handler raises ``HTTPException(429)`` when the
        calling client has exceeded the limit.
    """

    def decorator(func):
        # Configure the route when decorator is applied
        rate_limiter.configure_route(route_name, max_requests, time_window)

        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # request.client is None when the peer address is unavailable;
            # bucket those callers together instead of failing with
            # AttributeError (which would surface as a 500).
            client_ip = request.client.host if request.client else "unknown"

            if not rate_limiter.is_allowed(client_ip, route_name):
                limits = rate_limiter.get_route_config(route_name)
                raise HTTPException(
                    status_code=429,
                    detail=f"Rate limit exceeded for route '{route_name}'. Maximum {limits['max_requests']} requests per {limits['time_window']} seconds.",
                )

            return await func(request, *args, **kwargs)

        return wrapper

    return decorator
114
+
115
+
116
# Process-wide registries shared by the scheduled jobs and the HTTP endpoints.

# test id -> True; set by poll_for_completions once a test has been handed to
# /completion, removed (if present) by process_test.
queued_tests = {}
# Marks tests that poll_for_risk_evaluations has already sent to /risk-evaluation.
queued_evaluations = {}
# risk (judge) name -> risk_evaluation_fn loaded from custom_risks/ at startup.
risks = {}
# app id -> {"completion_fn": callable, "name": str}, filled in by lifespan().
apps = {}
120
+
121
+
122
async def process_application_heartbeat(app_id):
    """Run a connection test for one application and report it to the control plane.

    Sends a fixed prompt to the application's registered completion function,
    validates the reply, and POSTs the outcome to the control plane's
    ``successful-code-connection-tests`` or ``failed-code-connection-tests``
    endpoint.

    Args:
        app_id: Key of the application in the module-level ``apps`` registry.

    Returns:
        The decoded JSON body of the control plane's reply, or ``None`` when no
        completion function is registered for ``app_id`` (nothing is posted in
        that case).
    """
    connection_test_payload = {
        "appId": app_id,
    }
    try:
        prompt = "Hello from Snowglobe!"
        test_request = CompletionRequest(messages=[{"role": "user", "content": prompt}])
        completion_fn = apps.get(app_id, {}).get("completion_fn")
        if not completion_fn:
            LOGGER.warning(
                f"No completion function found for application {app_id}. Skipping heartbeat."
            )
            return
        if asyncio.iscoroutinefunction(completion_fn):
            response = await completion_fn(test_request)
        else:
            response = completion_fn(test_request)

        # Read the attribute defensively so a reply without `.response` is
        # reported as a validation failure, not an AttributeError.
        response_text = getattr(response, "response", None)

        # The checks are ordered from most to least fundamental and are
        # mutually exclusive, so the first problem found is the one reported.
        if not isinstance(response, CompletionFunctionOutputs):
            LOGGER.error(
                f"Completion function for application {app_id} did not return a valid response. Expected CompletionFunctionOutputs, got {type(response)}"
            )
            connection_test_payload["status"] = "failed"
            connection_test_payload["error"] = (
                "Completion function did not return a valid response. "
                f"Expected CompletionFunctionOutputs, got {type(response)}"
            )
        elif not isinstance(response_text, str):
            LOGGER.error(
                f"Completion function for application {app_id} did not return a valid response. Expected a string, got {type(response_text)}"
            )
            connection_test_payload["status"] = "failed"
            connection_test_payload["error"] = (
                f"Completion function did not return a valid response. Expected a string, got {type(response_text)}"
            )
        elif response_text == "This is a string response from your application":
            LOGGER.error(
                f"Completion function for application {app_id} returned a default response. This indicates the application is not properly connected."
            )
            connection_test_payload["status"] = "failed"
            connection_test_payload["error"] = (
                "Completion function returned a default response. "
            )
        else:
            connection_test_payload["response"] = response_text
            connection_test_payload["prompt"] = prompt

    except Exception as e:
        connection_test_payload["status"] = "failed"
        connection_test_payload["error"] = f"{str(e)}\n{traceback.format_exc()}"
    # NOTE(review): the source's indentation was lost; these two aliases are
    # assumed to be set on every payload, not only after an exception - confirm.
    connection_test_payload["app_id"] = app_id
    connection_test_payload["applicationId"] = app_id

    if connection_test_payload.get("status") == "failed":
        connection_test_url = (
            f"{config.CONTROL_PLANE_URL}/api/failed-code-connection-tests"
        )
    else:
        connection_test_url = (
            f"{config.CONTROL_PLANE_URL}/api/successful-code-connection-tests"
        )

    async with httpx.AsyncClient() as client:
        LOGGER.info(
            f"Posting code connection test for application {app_id} connection_test_payload: {connection_test_payload}"
        )
        connection_test_response = await client.post(
            connection_test_url,
            json=connection_test_payload,
            headers={"x-api-key": config.API_KEY},
        )

    if not connection_test_response.is_success:
        LOGGER.error(
            f"Error posting code connection test for application {app_id}: {connection_test_response.text}"
        )
    return connection_test_response.json()
201
+
202
+
203
async def process_risk_evaluation(test, risk_name):
    """Evaluate one test with a registered risk function and post the result.

    Args:
        test: Test record; ``test["id"]`` and ``test["experiment_id"]`` are read
            here, and the whole record is passed to ``fetch_messages``.
        risk_name: Key of the evaluation function in the module-level ``risks``
            registry.

    Raises:
        KeyError: If ``risk_name`` is not registered in ``risks``.
        Exception: If the control plane rejects the posted evaluation.
    """
    start = time.time()

    messages = await fetch_messages(test=test)

    risk_fn = risks[risk_name]
    evaluation_request = RiskEvaluationRequest(messages=messages)
    if asyncio.iscoroutinefunction(risk_fn):
        risk_evaluation = await risk_fn(evaluation_request)
    else:
        risk_evaluation = risk_fn(evaluation_request)

    LOGGER.debug(f"Risk evaluation output: {risk_evaluation}")

    # Extract fields from risk_evaluation object
    severity = getattr(risk_evaluation, "severity", "")
    reason = getattr(risk_evaluation, "reason", "")
    risk_triggered = getattr(risk_evaluation, "triggered", "")

    response_xml = (
        f"<risk>"
        f"<name>{risk_name}</name>"
        f"<severity>{severity}</severity>"
        f"<reason>{reason}</reason>"
        f"<risk_triggered>{risk_triggered}</risk_triggered>"
        f"</risk>"
    )

    # Post a Risk Evaluation (async)
    async with httpx.AsyncClient() as client:
        risk_evaluation_response = await client.post(
            f"{config.CONTROL_PLANE_URL}/api/experiments/{test['experiment_id']}/tests/{test['id']}/evaluations",
            json={
                "test_id": test["id"],
                "judge_prompt": "",  # does this need to be set?
                "judge_response": response_xml,
                "risk_type": risk_name,
                # Read directly (no default): an evaluation object without
                # `triggered` fails here instead of posting an empty value.
                "risk_triggered": risk_evaluation.triggered,
            },
            headers={"x-api-key": config.API_KEY},
        )
    LOGGER.debug(f"Time taken: {time.time() - start} seconds")
    if not risk_evaluation_response.is_success:
        # The placeholder is required: an extra positional argument with no
        # %s makes the logging module fail to format the record.
        LOGGER.error(
            "Error posting risk evaluation: %s", risk_evaluation_response.text
        )
        raise Exception("Error posting risk evaluation, task is not healthy")
249
+
250
+
251
async def process_test(test, completion_fn, app_id):
    """Processes a test by converting it to OpenAI style messages and calling the completion function.

    Args:
        test: Test record; ``id``, ``experiment_id``, ``prompt`` and ``persona``
            are read here, and the whole record is passed to ``fetch_messages``.
        completion_fn: Sync or async callable that takes a ``CompletionRequest``
            and returns an object with a ``response`` attribute.
        app_id: Application id stored with the response on the control plane.

    Returns:
        Whatever ``completion_fn`` returned.
    """
    start = time.time()
    # convert test to openai style messages
    messages = await fetch_messages(test=test)

    completion_request = CompletionRequest(messages=messages)
    if asyncio.iscoroutinefunction(completion_fn):
        completionOutput = await completion_fn(completion_request)
    else:
        completionOutput = completion_fn(completion_request)

    LOGGER.debug(f"Completion output: {completionOutput}")

    async with httpx.AsyncClient() as client:
        update_response = await client.put(
            f"{config.CONTROL_PLANE_URL}/api/experiments/{test['experiment_id']}/tests/{test['id']}",
            json={
                "id": test["id"],
                "appId": app_id,
                "prompt": test["prompt"],
                "response": completionOutput.response,
                "persona": test["persona"],
            },
            headers={"x-api-key": config.API_KEY},
        )
    # A rejected update used to pass silently; surface it so a lost response
    # can be diagnosed. The call still completes normally, as before.
    if not update_response.is_success:
        LOGGER.error(
            f"Error storing response for test {test['id']}: {update_response.text}"
        )
    LOGGER.debug(f"Time taken: {time.time() - start} seconds")
    # remove test['id'] from queued_tests
    queued_tests.pop(test["id"], None)

    return completionOutput
281
+
282
+
283
+ # The task to run
284
async def poll_for_completions():
    """Scheduled job: hand unprocessed tests to the local /completion endpoint.

    For each experiment whose ``app_id`` is registered in ``apps``, fetches up
    to 10 unprocessed tests from the control plane and POSTs each one that has
    no response and is not in ``queued_tests`` to
    ``{SNOWGLOBE_CLIENT_URL}/completion``. Dispatched test ids are recorded in
    ``queued_tests`` so later polls skip them.

    Raises:
        ValueError: When /completion answers 429; this aborts the whole poll.
    """
    if not apps:
        LOGGER.warning("No applications found. Skipping completions check.")
        return
    experiments = await fetch_experiments()

    async with httpx.AsyncClient() as client:
        LOGGER.info(f"Polling {len(experiments)} experiments for completions...")
        for experiment in experiments:
            LOGGER.debug(
                f"Checking experiment: id={experiment.get('id')}, name={experiment.get('name', 'unknown')}"
            )
        for experiment in experiments:
            app_id = experiment.get("app_id")
            if app_id not in apps:
                LOGGER.debug(
                    f"Skipping experiment as we do not have a completion function for app_id: {app_id}"
                )
                continue
            experiment_id = experiment["id"]
            experiment_request = await client.get(
                f"{config.CONTROL_PLANE_URL}/api/experiments/{experiment_id}?appId={app_id}",
                headers={"x-api-key": config.API_KEY},
            )
            if not experiment_request.is_success:
                LOGGER.error(
                    f"Error fetching experiment {experiment_id}: {experiment_request.text}"
                )
                continue
            # Replace the listing entry with the full record just fetched.
            experiment = experiment_request.json()

            limit = 10
            LOGGER.debug(f"Checking for tests for experiment {experiment_id}")
            tests_response = await client.get(
                f"{config.CONTROL_PLANE_URL}/api/experiments/{experiment_id}/tests?appId={app_id}&include-risk-evaluations=false&limit={limit}&unprocessed-only=true",
                headers={"x-api-key": config.API_KEY},
            )

            if not tests_response.is_success:
                LOGGER.error(f"Error fetching tests: {tests_response.text}")
                continue

            tests = tests_response.json()
            completion_count = 0
            for test in tests:
                LOGGER.debug(f"Found test {test['id']} for experiment {experiment_id}")
                if not test["response"] and test["id"] not in queued_tests:
                    # Reuse the managed client: a bare AsyncClient() per test
                    # is never closed and leaks its connection pool.
                    completion_request = await client.post(
                        f"{config.SNOWGLOBE_CLIENT_URL}/completion",
                        json={"test": test, "app_id": app_id},
                        timeout=30,
                    )
                    # if 429 raise an exception and stop this batch
                    if completion_request.status_code == 429:
                        LOGGER.warning(
                            f"Completion Rate limit exceeded for test {test['id']}: {completion_request.text}"
                        )
                        # ValueError accepts no keyword arguments; passing
                        # status_code=/detail= raised TypeError instead.
                        raise ValueError(
                            f"Rate limit exceeded for test {test['id']}"
                        )
                    completion_count += 1
                    queued_tests[test["id"]] = True
            if completion_count > 0:
                experiment_name = experiment.get("name", "unknown")
                if LOGGER.level <= logging.INFO:  # Verbose mode
                    LOGGER.info(
                        f"Processed {completion_count} completions for experiment {experiment_name} ({experiment_id})"
                    )
                else:  # Clean UI mode
                    timestamp = datetime.datetime.now().strftime("%H:%M:%S")
                    info(
                        f"[{timestamp}] ✓ Batch complete: {completion_count} responses sent ({experiment_name})"
                    )

                # Track batch completion
                track_batch_completion(experiment_name, completion_count)
364
+
365
+
366
async def process_application_heartbeats():
    """Scheduled job: trigger a heartbeat for every registered application.

    POSTs each app id in ``apps`` to ``{SNOWGLOBE_CLIENT_URL}/heartbeat`` and
    logs how many requests succeeded. A failed request is logged and skipped.
    """
    connection_test_count = 0
    LOGGER.info("Processing application heartbeats...")
    # One managed client for the whole run; a bare AsyncClient() per app is
    # never closed and leaks its connection pool.
    async with httpx.AsyncClient() as client:
        for app_id in apps:
            connection_test_request = await client.post(
                f"{config.SNOWGLOBE_CLIENT_URL}/heartbeat",
                json={"app_id": app_id},
                timeout=30,
            )
            if not connection_test_request.is_success:
                LOGGER.error(
                    f"Error sending heartbeat for application {app_id}: {connection_test_request.text}"
                )
                continue
            connection_test_count += 1

    LOGGER.info(f"Processed {connection_test_count} heartbeats for applications.")
383
+
384
+
385
async def poll_for_risk_evaluations():
    """Poll for risk evaluations and process them.

    For every experiment and every registered risk that the experiment lists
    under ``source_data.evaluation_configuration``, fetches the tests still
    unevaluated for that risk and POSTs each answered test to
    ``{SNOWGLOBE_CLIENT_URL}/risk-evaluation``. Errors while handling one risk
    are logged and the loop moves on to the next risk.
    """
    experiments = await fetch_experiments()
    LOGGER.info("Checking for pending risk evaluations...")
    LOGGER.debug(f"Found {len(experiments)} experiments with validation in progress")
    async with httpx.AsyncClient() as client:
        for experiment in experiments:
            experiment_request = await client.get(
                f"{config.CONTROL_PLANE_URL}/api/experiments/{experiment['id']}",
                headers={"x-api-key": config.API_KEY},
            )
            if not experiment_request.is_success:
                LOGGER.error(
                    f"Error fetching experiment {experiment['id']}: {experiment_request.text}"
                )
                continue
            # Replace the listing entry with the full record just fetched.
            experiment = experiment_request.json()
            risk_eval_count = 0
            for risk_name in risks:
                try:
                    configured_risks = experiment.get("source_data", {}).get(
                        "evaluation_configuration", {}
                    )
                    if risk_name not in configured_risks:
                        LOGGER.debug(
                            f"Skipping experiment {experiment['id']} as it does not have risk {risk_name}"
                        )
                        continue
                    LOGGER.debug(
                        f"checking for tests for experiment {experiment['id']}"
                    )
                    tests_response = await client.get(
                        f"{config.CONTROL_PLANE_URL}/api/experiments/{experiment['id']}/tests?unevaluated-risk={quote_plus(risk_name)}&include-risk-evaluations=true",
                        headers={"x-api-key": config.API_KEY},
                    )

                    if not tests_response.is_success:
                        # An error body is not guaranteed to be a JSON object.
                        try:
                            message = tests_response.json().get("message")
                        except (ValueError, AttributeError):
                            message = None
                        # ValueError accepts no keyword arguments; the former
                        # status_code=/message= call raised TypeError and the
                        # real message never reached the log.
                        raise ValueError(
                            f"{tests_response.status_code}: {message or tests_response.text}"
                        )
                    tests = tests_response.json()

                    for test in tests:
                        test_id = test["id"]
                        # Key on (test, risk): keyed on the test id alone, a
                        # test queued for one risk was skipped forever for
                        # every other risk.
                        evaluation_key = (test_id, risk_name)
                        if (
                            evaluation_key not in queued_evaluations
                            and test.get("response") is not None
                        ):
                            # Reuse the managed client; a bare AsyncClient()
                            # per test is never closed.
                            risk_eval_response = await client.post(
                                f"{config.SNOWGLOBE_CLIENT_URL}/risk-evaluation",
                                json={"test": test, "risk_name": risk_name},
                                timeout=30,
                            )
                            # if risk evaluation response is 429 raise an exception and bail on this batch
                            if risk_eval_response.status_code == 429:
                                LOGGER.error(
                                    f"Rate limit exceeded for risk evaluation {test['id']}: {risk_eval_response.text}"
                                )
                                raise ValueError(
                                    f"Rate limit exceeded for risk evaluation {test['id']}"
                                )
                            queued_evaluations[evaluation_key] = True
                            risk_eval_count += 1
                except Exception as e:
                    LOGGER.error(f"Error fetching tests: {e}")
            if risk_eval_count > 0:
                LOGGER.info(
                    f"Processed {risk_eval_count} risk evaluations for experiment {experiment.get('name', 'unknown')} ({experiment['id']})"
                )
465
+
466
+
467
def _load_module_from_path(module_name, file_path):
    """Import the Python file at ``file_path`` as a module named ``module_name``."""
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def _wrap_process_scenario(process_fn):
    """Adapt an agent's ``process_scenario`` to the completion_fn interface.

    The wrapper stays a coroutine function when ``process_fn`` is one. Callers
    decide whether to ``await`` via ``asyncio.iscoroutinefunction``; a plain
    sync wrapper around an async agent would hand them an un-awaited coroutine.
    """
    if asyncio.iscoroutinefunction(process_fn):

        async def completion_fn(request):
            return await process_fn(request)

    else:

        def completion_fn(request):
            return process_fn(request)

    return completion_fn


def _load_agents_from_manifest():
    """Register every agent listed in .snowglobe/agents.json into ``apps``."""
    agents_json_path = os.path.join(os.getcwd(), ".snowglobe", "agents.json")
    if not os.path.exists(agents_json_path):
        LOGGER.warning(
            "No .snowglobe/agents.json found. Run 'snowglobe-connect init' to set up an agent."
        )
        return

    try:
        with open(agents_json_path, "r") as f:
            agents_data = json.load(f)
    except (json.JSONDecodeError, IOError) as e:
        LOGGER.error(f"Error reading agents.json: {e}")
        return

    for filename, agent_info in agents_data.items():
        # One broken agent must not prevent the others from loading.
        try:
            app_id = agent_info["uuid"]
            app_name = agent_info["name"]
            agent_file_path = os.path.join(os.getcwd(), filename)

            if not os.path.exists(agent_file_path):
                LOGGER.warning(f"Agent file not found: {agent_file_path}")
                continue

            agent_module = _load_module_from_path("agent_wrapper", agent_file_path)

            if not hasattr(agent_module, "process_scenario"):
                LOGGER.warning(
                    f"Agent {filename} does not have a process_scenario function"
                )
                continue

            apps[app_id] = {
                "completion_fn": _wrap_process_scenario(
                    agent_module.process_scenario
                ),
                "name": app_name,
            }
        except Exception as e:
            LOGGER.error(f"Error loading agent {filename}: {e}")


def _load_legacy_application():
    """Register the legacy single application from ./snowglobe_connect.py.

    Does nothing unless ``config.APPLICATION_ID`` is set and not already
    registered in ``apps``.
    """
    # NOTE(review): the source's indentation was lost; the legacy file is
    # assumed to be loaded only when the id is not already registered - confirm.
    if not config.APPLICATION_ID or config.APPLICATION_ID in apps:
        return

    LOGGER.warning(
        "\n********* START WARNING *********"
        f"\nLegacy single application detected with ID {config.APPLICATION_ID}. "
        "\nPlease migrate to the new applications structure."
        "\nRun snowglobe-connect init and follow the prompts to set up your application."
        "\nThis configuration option will be removed in the next major release."
        "\n********* END WARNING *********"
    )
    # load the legacy applications connect file out of the base of this directory
    legacy_connect_file = os.path.join(os.getcwd(), "snowglobe_connect.py")
    if not os.path.exists(legacy_connect_file):
        return

    sg_connect = _load_module_from_path("snowglobe_connect", legacy_connect_file)
    if hasattr(sg_connect, "completion_fn"):
        apps[config.APPLICATION_ID] = {
            "completion_fn": sg_connect.completion_fn,
            "name": "Legacy Single Application",
        }
        LOGGER.info(
            f"Loaded legacy application with ID {config.APPLICATION_ID} for completions."
        )
    else:
        LOGGER.error(
            f"Legacy application with ID {config.APPLICATION_ID} does not have a completion_fn."
        )


def _load_custom_risks():
    """Register ``risk_evaluation_fn`` from every .py file in ./custom_risks.

    Each judge name is encoded in the filename with spaces replaced by
    underscores.
    """
    risks_dir = os.path.join(os.getcwd(), "custom_risks")
    if os.path.exists(risks_dir):
        for judge_file in os.listdir(risks_dir):
            if not judge_file.endswith(".py"):
                continue
            judge_name = judge_file[:-3].replace("_", " ")
            judge_module = _load_module_from_path(
                judge_name, os.path.join(risks_dir, judge_file)
            )
            if hasattr(judge_module, "risk_evaluation_fn"):
                risks[judge_name] = judge_module.risk_evaluation_fn
            else:
                LOGGER.warning(
                    f"Judge {judge_name} does not have a risk_evaluation_fn. Skipping."
                )
    # NOTE(review): indentation was lost in the source; this summary is assumed
    # to be logged whether or not the directory exists - confirm.
    LOGGER.info(f"Loaded risks: {list(risks.keys())}")


async def _post_shutdown_connection_tests():
    """Report a failed connection test for every app so it shows as offline."""
    for app_id in apps:
        try:
            connection_test_url = (
                f"{config.CONTROL_PLANE_URL}/api/failed-code-connection-tests"
            )
            connection_test_payload = {
                "appId": app_id,
                "status": "failed",
                "error": "snowglobe-connect shut down gracefully",
            }
            async with httpx.AsyncClient() as client:
                LOGGER.info(
                    f"Posting shut down code connection test for application {app_id} connection_test_payload: {connection_test_payload}"
                )
                connection_test_response = await client.post(
                    connection_test_url,
                    json=connection_test_payload,
                    headers={"x-api-key": config.API_KEY},
                )
            if not connection_test_response.is_success:
                LOGGER.error(
                    f"Error posting shut down code connection test for application {app_id}: {connection_test_response.text}"
                )
            else:
                # Only claim success when the control plane accepted the post.
                LOGGER.info(
                    f"Posted shut down heart beat for application {app_id} successfully."
                )
        except Exception as e:
            LOGGER.error(
                f"Error processing application heartbeat for {app_id}: {e}\n{traceback.format_exc()}"
            )


# Ensure the scheduler shuts down properly on application exit.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load agents and risks, run the scheduler, report shutdown.

    Startup fills the module-level ``apps`` and ``risks`` registries and
    schedules the three polling jobs. On shutdown a failed connection test is
    posted for every application before the scheduler is stopped.
    """
    try:
        # Load agents from .snowglobe/agents.json system
        _load_agents_from_manifest()
        for app_id, app_info in apps.items():
            LOGGER.info(
                f"Loaded application {app_info['name']} with ID {app_id} for completions."
            )
        _load_legacy_application()
    except Exception as e:
        LOGGER.error(f"Error loading applications: {e}")

    # attempt to load judge risks from custom_risks/
    try:
        _load_custom_risks()
    except Exception as e:
        LOGGER.error(
            f"Error loading risks from custom_risks: {e}. Custom judging will not be available."
        )

    async with AsyncScheduler() as scheduler:
        await scheduler.add_schedule(poll_for_completions, IntervalTrigger(seconds=3))
        await scheduler.add_schedule(
            poll_for_risk_evaluations, IntervalTrigger(seconds=7)
        )
        await scheduler.add_schedule(
            process_application_heartbeats, IntervalTrigger(minutes=5)
        )
        await scheduler.start_in_background()
        yield
        # try to loop over applications and send a failed connection test
        await _post_shutdown_connection_tests()
        await scheduler.stop()
        await scheduler.wait_until_stopped()
631
+
632
+
633
def create_client():
    """Create and configure the FastAPI application.

    Registers a root route plus the rate-limited ``/completion``,
    ``/heartbeat`` and ``/risk-evaluation`` POST endpoints that the scheduled
    polling jobs call.

    Returns:
        The configured ``FastAPI`` instance, using ``lifespan`` for startup
        and shutdown.
    """
    app = FastAPI(lifespan=lifespan)

    @app.get("/")
    def read_root():
        return {"message": "Dashing through the snow..."}

    @app.post("/completion")
    @rate_limit(
        "completion",
        max_requests=config.CONCURRENT_COMPLETIONS_PER_INTERVAL,
        time_window=config.CONCURRENT_COMPLETIONS_INTERVAL_SECONDS,
    )
    async def completion_endpoint(request: Request):
        """Run the registered completion function for the posted test."""
        completion_body = await request.json()
        test = completion_body.get("test")
        app_id = completion_body.get("app_id")
        # both are required and must be non-empty
        if not test or not app_id:
            raise HTTPException(
                status_code=400,
                detail="Both 'test' and 'app_id' must be provided in the request body.",
            )
        if app_id not in apps:
            raise HTTPException(
                status_code=404,
                detail=f"Application with ID {app_id} not found.",
            )
        completion_fn = apps[app_id].get("completion_fn")
        LOGGER.debug(f"Received test: {test['id']}")

        await process_test(test, completion_fn, app_id)
        return {"status": "processed"}

    @app.post("/heartbeat")
    @rate_limit(
        "heartbeat",
        max_requests=config.CONCURRENT_HEARTBEATS_PER_INTERVAL,
        time_window=config.CONCURRENT_HEARTBEATS_INTERVAL_SECONDS,
    )
    async def heartbeat_endpoint(request: Request):
        """Endpoint to check if the client is alive and well."""
        body = await request.json()
        app_id = body.get("app_id")
        if not app_id:
            raise HTTPException(
                status_code=400,
                detail="Application ID must be provided in the request body.",
            )
        if app_id not in apps:
            raise HTTPException(
                status_code=404,
                detail=f"Application with ID {app_id} not found.",
            )
        LOGGER.debug(f"Received heartbeat for application: {app_id}")

        # Runs the connection test and reports it to the control plane.
        await process_application_heartbeat(app_id)
        return {"status": "heartbeat received"}

    @app.post("/risk-evaluation")
    @rate_limit(
        "risk_evaluation",
        max_requests=config.CONCURRENT_RISK_EVALUATIONS,
        time_window=config.CONCURRENT_RISK_EVALUATIONS_INTERVAL_SECONDS,
    )
    async def risk_evaluation_endpoint(request: Request):
        """Evaluate the posted test with the named custom risk."""
        body = await request.json()
        test = body.get("test")
        risk_name = body.get("risk_name")
        # Validate like /completion does; without this a missing field or an
        # unknown risk surfaced as an unhandled TypeError/KeyError (HTTP 500).
        if not test or not risk_name:
            raise HTTPException(
                status_code=400,
                detail="Both 'test' and 'risk_name' must be provided in the request body.",
            )
        if risk_name not in risks:
            raise HTTPException(
                status_code=404,
                detail=f"Risk '{risk_name}' not found.",
            )
        LOGGER.debug(f"Received risk evaluation for test: {test['id']}")

        await process_risk_evaluation(test, risk_name)
        return {"status": "risk evaluation processed"}

    return app
713
+
714
+
715
def start_client(verbose=False):
    """Start the FastAPI client.

    Unless ``verbose`` is true, the client logger is raised to WARNING and the
    apscheduler logger to ERROR. Stats tracking is initialised before the app
    is built, then uvicorn serves the app on ``config.SNOWGLOBE_CLIENT_PORT``.
    """
    # NOTE(review): indentation was lost in the source; both level changes are
    # assumed to apply only in non-verbose mode - confirm.
    quiet = not verbose
    if quiet:
        LOGGER.setLevel(logging.WARNING)
        scheduler_logger = logging.getLogger("apscheduler")
        scheduler_logger.setLevel(logging.ERROR)

    # Stats tracking must be ready before the app (and its jobs) exist.
    initialize_stats()

    uvicorn.run(
        create_client(),
        host="0.0.0.0",
        port=config.SNOWGLOBE_CLIENT_PORT,
        log_level="warning",
    )


if __name__ == "__main__":
    start_client()