snowglobe-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowglobe/client/__init__.py +17 -0
- snowglobe/client/src/app.py +732 -0
- snowglobe/client/src/cli.py +736 -0
- snowglobe/client/src/cli_utils.py +361 -0
- snowglobe/client/src/config.py +213 -0
- snowglobe/client/src/models.py +37 -0
- snowglobe/client/src/project_manager.py +290 -0
- snowglobe/client/src/stats.py +53 -0
- snowglobe/client/src/utils.py +117 -0
- snowglobe-0.4.0.dist-info/METADATA +128 -0
- snowglobe-0.4.0.dist-info/RECORD +15 -0
- snowglobe-0.4.0.dist-info/WHEEL +5 -0
- snowglobe-0.4.0.dist-info/entry_points.txt +2 -0
- snowglobe-0.4.0.dist-info/licenses/LICENSE +21 -0
- snowglobe-0.4.0.dist-info/top_level.txt +1 -0
snowglobe/client/src/app.py
@@ -0,0 +1,732 @@
# Set up the scheduler
import asyncio
import datetime
import importlib.util
import json
import logging
import os
import time
import traceback
from collections import defaultdict, deque
from contextlib import asynccontextmanager
from functools import wraps
from logging import getLogger
from typing import Dict
from urllib.parse import quote_plus

import httpx
import uvicorn
from apscheduler import AsyncScheduler
from apscheduler.triggers.interval import IntervalTrigger
from fastapi import FastAPI, HTTPException, Request

from .cli_utils import info
from .config import config
from .models import CompletionFunctionOutputs, CompletionRequest, RiskEvaluationRequest
from .stats import initialize_stats, track_batch_completion
from .utils import fetch_experiments, fetch_messages

# Configure logging - check DEBUG from environment directly to avoid config initialization
if os.getenv("DEBUG", "false").lower() == "true":
    logging.basicConfig(
        handlers=[logging.StreamHandler()],
        level=os.getenv("LOG_LEVEL", "DEBUG").upper(),
    )
else:
    logging.basicConfig(
        handlers=[logging.StreamHandler()],
    )

LOGGER = getLogger("snowglobe_client")
# Allow LOG_LEVEL environment variable to override default INFO level
LOGGER.setLevel(getattr(logging, os.getenv("LOG_LEVEL", "INFO").upper()))

# In-memory storage for rate limiting per route
route_request_times = defaultdict(lambda: defaultdict(deque))

class ConfigurableRateLimiter:
    def __init__(self):
        self.route_configs = {}

    def configure_route(self, route_key: str, max_requests: int, time_window: int):
        """Configure rate limiting for a specific route"""
        self.route_configs[route_key] = {
            "max_requests": max_requests,
            "time_window": time_window,
        }

    def is_allowed(self, client_id: str, route_key: str) -> bool:
        # Get route configuration or use default
        route_config = self.route_configs.get(
            route_key, {"max_requests": 1, "time_window": 1}
        )
        max_requests = route_config["max_requests"]
        time_window = route_config["time_window"]

        now = time.time()
        client_requests = route_request_times[route_key][client_id]

        # Remove old requests outside the time window
        while client_requests and client_requests[0] <= now - time_window:
            client_requests.popleft()

        # Check if client has exceeded rate limit for this route
        if len(client_requests) >= max_requests:
            return False

        # Add current request time
        client_requests.append(now)
        return True

    def get_route_config(self, route_key: str) -> Dict[str, int]:
        """Get the configuration for a specific route"""
        return self.route_configs.get(route_key, {"max_requests": 1, "time_window": 1})


# Create global rate limiter instance
rate_limiter = ConfigurableRateLimiter()


def rate_limit(route_name: str, max_requests: int = 1, time_window: int = 1):
    """Decorator for per-route rate limiting with configurable limits"""

    def decorator(func):
        # Configure the route when decorator is applied
        rate_limiter.configure_route(route_name, max_requests, time_window)

        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            client_ip = request.client.host

            if not rate_limiter.is_allowed(client_ip, route_name):
                route_config = rate_limiter.get_route_config(route_name)
                raise HTTPException(
                    status_code=429,
                    detail=f"Rate limit exceeded for route '{route_name}'. Maximum {route_config['max_requests']} requests per {route_config['time_window']} seconds.",
                )

            return await func(request, *args, **kwargs)

        return wrapper

    return decorator

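# Usage sketch (illustrative only, not part of the shipped module): the decorator
# is meant to sit beneath a FastAPI route decorator, mirroring how create_client()
# applies it further down. The route name and limits here are hypothetical.
#
#   @app.post("/example")
#   @rate_limit("example", max_requests=5, time_window=60)
#   async def example_endpoint(request: Request):
#       return {"status": "ok"}
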
queued_tests = {}
queued_evaluations = {}
risks = {}
apps = {}

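# Shape of the registries above, as populated by lifespan() below:
#   apps maps app_id -> {"completion_fn": callable, "name": str}
#   risks maps judge name -> risk_evaluation_fn
#   queued_tests / queued_evaluations map a test id -> True while work is in flight
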
async def process_application_heartbeat(app_id):
    connection_test_payload = {
        "appId": app_id,
    }
    try:
        prompt = "Hello from Snowglobe!"
        test_request = CompletionRequest(messages=[{"role": "user", "content": prompt}])
        completion_fn = apps.get(app_id, {}).get("completion_fn")
        if not completion_fn:
            LOGGER.warning(
                f"No completion function found for application {app_id}. Skipping heartbeat."
            )
            return
        if asyncio.iscoroutinefunction(completion_fn):
            response = await completion_fn(test_request)
        else:
            response = completion_fn(test_request)
        if not isinstance(response, CompletionFunctionOutputs):
            LOGGER.error(
                f"Completion function for application {app_id} did not return a valid response. Expected CompletionFunctionOutputs, got {type(response)}"
            )
            connection_test_payload["status"] = "failed"
            connection_test_payload["error"] = (
                "Completion function did not return a valid response. "
                f"Expected CompletionFunctionOutputs, got {type(response)}"
            )

        # Read the response text defensively so a missing attribute cannot raise here
        response_text = getattr(response, "response", None)
        if not isinstance(response_text, str):
            LOGGER.error(
                f"Completion function for application {app_id} did not return a valid response. Expected a string, got {type(response_text)}"
            )
            connection_test_payload["status"] = "failed"
            connection_test_payload["error"] = (
                f"Completion function did not return a valid response. Expected a string, got {type(response_text)}"
            )

        if response_text == "This is a string response from your application":
            LOGGER.error(
                f"Completion function for application {app_id} returned a default response. This indicates the application is not properly connected."
            )
            connection_test_payload["status"] = "failed"
            connection_test_payload["error"] = (
                "Completion function returned a default response."
            )

        if connection_test_payload.get("status") != "failed":
            connection_test_payload["response"] = response_text
            connection_test_payload["prompt"] = prompt

    except Exception as e:
        connection_test_payload["status"] = "failed"
        connection_test_payload["error"] = f"{str(e)}\n{traceback.format_exc()}"
    connection_test_payload["app_id"] = app_id
    connection_test_payload["applicationId"] = app_id

    connection_test_url = (
        f"{config.CONTROL_PLANE_URL}/api/successful-code-connection-tests"
    )

    if connection_test_payload.get("status") == "failed":
        connection_test_url = (
            f"{config.CONTROL_PLANE_URL}/api/failed-code-connection-tests"
        )

    async with httpx.AsyncClient() as client:
        LOGGER.info(
            f"Posting code connection test for application {app_id} connection_test_payload: {connection_test_payload}"
        )
        connection_test_response = await client.post(
            connection_test_url,
            json=connection_test_payload,
            headers={"x-api-key": config.API_KEY},
        )

        if not connection_test_response.is_success:
            LOGGER.error(
                f"Error posting code connection test for application {app_id}: {connection_test_response.text}"
            )
        return connection_test_response.json()

async def process_risk_evaluation(test, risk_name):
    """Find the matching risk, call its evaluation function, and create a risk evaluation for the test."""
    start = time.time()

    messages = await fetch_messages(test=test)

    if asyncio.iscoroutinefunction(risks[risk_name]):
        risk_evaluation = await risks[risk_name](
            RiskEvaluationRequest(messages=messages)
        )
    else:
        risk_evaluation = risks[risk_name](RiskEvaluationRequest(messages=messages))

    LOGGER.debug(f"Risk evaluation output: {risk_evaluation}")

    # Extract fields from risk_evaluation object
    severity = getattr(risk_evaluation, "severity", "")
    reason = getattr(risk_evaluation, "reason", "")
    risk_triggered = getattr(risk_evaluation, "triggered", "")

    response_xml = (
        "<risk>"
        f"<name>{risk_name}</name>"
        f"<severity>{severity}</severity>"
        f"<reason>{reason}</reason>"
        f"<risk_triggered>{risk_triggered}</risk_triggered>"
        "</risk>"
    )

    # Post a Risk Evaluation (async)
    async with httpx.AsyncClient() as client:
        risk_evaluation_response = await client.post(
            f"{config.CONTROL_PLANE_URL}/api/experiments/{test['experiment_id']}/tests/{test['id']}/evaluations",
            json={
                "test_id": test["id"],
                "judge_prompt": "",  # does this need to be set?
                "judge_response": response_xml,
                "risk_type": risk_name,
                "risk_triggered": risk_triggered,
            },
            headers={"x-api-key": config.API_KEY},
        )
        LOGGER.debug(f"Time taken: {time.time() - start} seconds")
        if not risk_evaluation_response.is_success:
            LOGGER.error(
                f"Error posting risk evaluation: {risk_evaluation_response.text}"
            )
            raise Exception("Error posting risk evaluation, task is not healthy")

async def process_test(test, completion_fn, app_id):
    """Processes a test by converting it to OpenAI-style messages and calling the completion function"""
    start = time.time()
    # convert test to openai style messages
    messages = await fetch_messages(test=test)

    if asyncio.iscoroutinefunction(completion_fn):
        completion_output = await completion_fn(CompletionRequest(messages=messages))
    else:
        completion_output = completion_fn(CompletionRequest(messages=messages))

    LOGGER.debug(f"Completion output: {completion_output}")

    async with httpx.AsyncClient() as client:
        await client.put(
            f"{config.CONTROL_PLANE_URL}/api/experiments/{test['experiment_id']}/tests/{test['id']}",
            json={
                "id": test["id"],
                "appId": app_id,
                "prompt": test["prompt"],
                "response": completion_output.response,
                "persona": test["persona"],
            },
            headers={"x-api-key": config.API_KEY},
        )
    LOGGER.debug(f"Time taken: {time.time() - start} seconds")
    # remove test['id'] from queued_tests
    queued_tests.pop(test["id"], None)

    return completion_output

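# Minimal completion-function sketch for reference: process_test() accepts either a
# sync or an async callable. The CompletionFunctionOutputs(response=...) constructor
# is an assumption inferred from the heartbeat checks above, which require a string
# `response` attribute; messages are OpenAI-style role/content dicts.
#
#   def completion_fn(request: CompletionRequest) -> CompletionFunctionOutputs:
#       last_user_message = request.messages[-1]["content"]  # assumed message shape
#       return CompletionFunctionOutputs(response=f"Echo: {last_user_message}")
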
# The task to run
async def poll_for_completions():
    if len(apps) == 0:
        LOGGER.warning("No applications found. Skipping completions check.")
        return
    experiments = await fetch_experiments()

    async with httpx.AsyncClient() as client:
        LOGGER.info(f"Polling {len(experiments)} experiments for completions...")
        for experiment in experiments:
            LOGGER.debug(
                f"Checking experiment: id={experiment.get('id')}, name={experiment.get('name', 'unknown')}"
            )
        for experiment in experiments:
            app_id = experiment.get("app_id")
            if app_id not in apps:
                LOGGER.debug(
                    f"Skipping experiment as we do not have a completion function for app_id: {app_id}"
                )
                continue
            experiment_id = experiment["id"]
            experiment_request = await client.get(
                f"{config.CONTROL_PLANE_URL}/api/experiments/{experiment_id}?appId={app_id}",
                headers={"x-api-key": config.API_KEY},
            )
            if not experiment_request.is_success:
                LOGGER.error(
                    f"Error fetching experiment {experiment_id}: {experiment_request.text}"
                )
                continue
            experiment = experiment_request.json()

            limit = 10
            LOGGER.debug(f"Checking for tests for experiment {experiment_id}")
            tests_response = await client.get(
                f"{config.CONTROL_PLANE_URL}/api/experiments/{experiment_id}/tests?appId={app_id}&include-risk-evaluations=false&limit={limit}&unprocessed-only=true",
                headers={"x-api-key": config.API_KEY},
            )

            if not tests_response.is_success:
                LOGGER.error(f"Error fetching tests: {tests_response.text}")
                continue

            tests = tests_response.json()
            completion_count = 0
            for test in tests:
                LOGGER.debug(f"Found test {test['id']} for experiment {experiment_id}")
                if not test["response"] and test["id"] not in queued_tests:
                    completion_request = await client.post(
                        f"{config.SNOWGLOBE_CLIENT_URL}/completion",
                        json={"test": test, "app_id": app_id},
                        timeout=30,
                    )
                    # if we get a 429, raise an exception and stop this batch
                    if (
                        not completion_request.is_success
                        and completion_request.status_code == 429
                    ):
                        LOGGER.warning(
                            f"Completion rate limit exceeded for test {test['id']}: {completion_request.text}"
                        )
                        raise ValueError(
                            f"Rate limit exceeded for test {test['id']}"
                        )
                    completion_count += 1
                    queued_tests[test["id"]] = True
            if completion_count > 0:
                experiment_name = experiment.get("name", "unknown")
                if LOGGER.level <= logging.INFO:  # Verbose mode
                    LOGGER.info(
                        f"Processed {completion_count} completions for experiment {experiment_name} ({experiment_id})"
                    )
                else:  # Clean UI mode
                    timestamp = datetime.datetime.now().strftime("%H:%M:%S")
                    info(
                        f"[{timestamp}] ✓ Batch complete: {completion_count} responses sent ({experiment_name})"
                    )

                # Track batch completion
                track_batch_completion(experiment_name, completion_count)

async def process_application_heartbeats():
    connection_test_count = 0
    LOGGER.info("Processing application heartbeats...")
    async with httpx.AsyncClient() as client:
        for app_id, app_info in apps.items():
            connection_test_request = await client.post(
                f"{config.SNOWGLOBE_CLIENT_URL}/heartbeat",
                json={"app_id": app_id},
                timeout=30,
            )
            if not connection_test_request.is_success:
                LOGGER.error(
                    f"Error sending heartbeat for application {app_id}: {connection_test_request.text}"
                )
                continue
            connection_test_count += 1

    LOGGER.info(f"Processed {connection_test_count} heartbeats for applications.")

async def poll_for_risk_evaluations():
    """Poll for risk evaluations and process them."""
    experiments = await fetch_experiments()
    LOGGER.info("Checking for pending risk evaluations...")
    LOGGER.debug(f"Found {len(experiments)} experiments with validation in progress")
    async with httpx.AsyncClient() as client:
        for experiment in experiments:
            experiment_request = await client.get(
                f"{config.CONTROL_PLANE_URL}/api/experiments/{experiment['id']}",
                headers={"x-api-key": config.API_KEY},
            )
            if not experiment_request.is_success:
                LOGGER.error(
                    f"Error fetching experiment {experiment['id']}: {experiment_request.text}"
                )
                continue
            experiment = experiment_request.json()
            risk_eval_count = 0
            for risk_name in risks.keys():
                try:
                    if (
                        risk_name
                        not in experiment.get("source_data", {})
                        .get("evaluation_configuration", {})
                        .keys()
                    ):
                        LOGGER.debug(
                            f"Skipping experiment {experiment['id']} as it does not have risk {risk_name}"
                        )
                        continue
                    LOGGER.debug(
                        f"Checking for tests for experiment {experiment['id']}"
                    )
                    tests_response = await client.get(
                        f"{config.CONTROL_PLANE_URL}/api/experiments/{experiment['id']}/tests?unevaluated-risk={quote_plus(risk_name)}&include-risk-evaluations=true",
                        headers={"x-api-key": config.API_KEY},
                    )

                    if not tests_response.is_success:
                        message = (
                            tests_response.json().get("message") or tests_response.text
                        )
                        raise ValueError(
                            f"Error fetching tests ({tests_response.status_code}): {message}"
                        )
                    tests = tests_response.json()

                    for test in tests:
                        test_id = test["id"]
                        if (
                            test_id not in queued_evaluations
                            and test.get("response") is not None
                        ):
                            risk_eval_response = await client.post(
                                f"{config.SNOWGLOBE_CLIENT_URL}/risk-evaluation",
                                json={"test": test, "risk_name": risk_name},
                                timeout=30,
                            )
                            # if the risk evaluation response is 429, raise an exception and bail on this batch
                            if (
                                not risk_eval_response.is_success
                                and risk_eval_response.status_code == 429
                            ):
                                LOGGER.error(
                                    f"Rate limit exceeded for risk evaluation {test['id']}: {risk_eval_response.text}"
                                )
                                raise ValueError(
                                    f"Rate limit exceeded for risk evaluation {test['id']}"
                                )
                            queued_evaluations[test_id] = True
                            risk_eval_count += 1
                except Exception as e:
                    LOGGER.error(f"Error fetching tests: {e}")
            if risk_eval_count > 0:
                LOGGER.info(
                    f"Processed {risk_eval_count} risk evaluations for experiment {experiment.get('name', 'unknown')} ({experiment['id']})"
                )

# Ensure the scheduler shuts down properly on application exit.
@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        # Load agents from .snowglobe/agents.json system
        agents_json_path = os.path.join(os.getcwd(), ".snowglobe", "agents.json")
        if os.path.exists(agents_json_path):
            try:
                with open(agents_json_path, "r") as f:
                    agents_data = json.load(f)

                for filename, agent_info in agents_data.items():
                    try:
                        app_id = agent_info["uuid"]
                        app_name = agent_info["name"]
                        agent_file_path = os.path.join(os.getcwd(), filename)

                        if not os.path.exists(agent_file_path):
                            LOGGER.warning(f"Agent file not found: {agent_file_path}")
                            continue

                        spec = importlib.util.spec_from_file_location(
                            "agent_wrapper", agent_file_path
                        )
                        agent_module = importlib.util.module_from_spec(spec)
                        spec.loader.exec_module(agent_module)

                        if not hasattr(agent_module, "process_scenario"):
                            LOGGER.warning(
                                f"Agent {filename} does not have a process_scenario function"
                            )
                            continue

                        # Wrap process_scenario to match the expected completion_fn interface
                        def make_completion_fn(process_fn):
                            def completion_fn(request):
                                return process_fn(request)

                            return completion_fn

                        apps[app_id] = {
                            "completion_fn": make_completion_fn(
                                agent_module.process_scenario
                            ),
                            "name": app_name,
                        }

                    except Exception as e:
                        LOGGER.error(f"Error loading agent {filename}: {e}")
                        continue

            except (json.JSONDecodeError, IOError) as e:
                LOGGER.error(f"Error reading agents.json: {e}")
        else:
            LOGGER.warning(
                "No .snowglobe/agents.json found. Run 'snowglobe-connect init' to set up an agent."
            )

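        # Expected agents.json shape, inferred from the loader above (illustrative):
        #   {
        #       "my_agent.py": {"uuid": "<application id>", "name": "My Agent"}
        #   }
        # Each key is a path, relative to the working directory, to a module that
        # exposes a process_scenario function.
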
        for app_id, app_info in apps.items():
            LOGGER.info(
                f"Loaded application {app_info['name']} with ID {app_id} for completions."
            )
        if config.APPLICATION_ID:
            if config.APPLICATION_ID not in apps:
                LOGGER.warning(
                    "\n********* START WARNING *********"
                    f"\nLegacy single application detected with ID {config.APPLICATION_ID}. "
                    "\nPlease migrate to the new applications structure."
                    "\nRun snowglobe-connect init and follow the prompts to set up your application."
                    "\nThis configuration option will be removed in the next major release."
                    "\n********* END WARNING *********"
                )
                # load the legacy application's connect file from the base of this directory
                legacy_connect_file = os.path.join(os.getcwd(), "snowglobe_connect.py")
                if os.path.exists(legacy_connect_file):
                    spec = importlib.util.spec_from_file_location(
                        "snowglobe_connect", legacy_connect_file
                    )
                    sg_connect = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(sg_connect)
                    if hasattr(sg_connect, "completion_fn"):
                        apps[config.APPLICATION_ID] = {
                            "completion_fn": sg_connect.completion_fn,
                            "name": "Legacy Single Application",
                        }
                        LOGGER.info(
                            f"Loaded legacy application with ID {config.APPLICATION_ID} for completions."
                        )
                    else:
                        LOGGER.error(
                            f"Legacy application with ID {config.APPLICATION_ID} does not have a completion_fn."
                        )
    except Exception as e:
        LOGGER.error(f"Error loading applications: {e}")

    # attempt to load risk judges from custom_risks/
    # each judge name is encoded in the filename with spaces replaced by underscores
    try:
        risks_dir = os.path.join(os.getcwd(), "custom_risks")
        if os.path.exists(risks_dir):
            for judge_file in os.listdir(risks_dir):
                if judge_file.endswith(".py"):
                    judge_name = judge_file[:-3].replace("_", " ")

                    spec = importlib.util.spec_from_file_location(
                        judge_name, os.path.join(risks_dir, judge_file)
                    )
                    judge_module = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(judge_module)
                    if hasattr(judge_module, "risk_evaluation_fn"):
                        risks[judge_name] = judge_module.risk_evaluation_fn
                    else:
                        LOGGER.warning(
                            f"Judge {judge_name} does not have a risk_evaluation_fn. Skipping."
                        )
        LOGGER.info(f"Loaded risks: {list(risks.keys())}")
    except Exception as e:
        LOGGER.error(
            f"Error loading risks from custom_risks: {e}. Custom judging will not be available."
        )

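    # Sketch of a custom risk module, e.g. custom_risks/PII_Leak.py (hypothetical
    # file; it would register the judge name "PII Leak"). The returned object only
    # needs the severity/reason/triggered attributes that process_risk_evaluation()
    # reads via getattr:
    #
    #   from types import SimpleNamespace
    #
    #   def risk_evaluation_fn(request):
    #       triggered = any("ssn" in m["content"].lower() for m in request.messages)
    #       return SimpleNamespace(
    #           severity="high" if triggered else "none",
    #           reason="Possible SSN mention" if triggered else "No PII found",
    #           triggered=triggered,
    #       )
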
    async with AsyncScheduler() as scheduler:
        await scheduler.add_schedule(poll_for_completions, IntervalTrigger(seconds=3))
        await scheduler.add_schedule(
            poll_for_risk_evaluations, IntervalTrigger(seconds=7)
        )
        await scheduler.add_schedule(
            process_application_heartbeats, IntervalTrigger(minutes=5)
        )
        await scheduler.start_in_background()
        yield
        # on shutdown, loop over applications and send a failed connection test
        for app_id in apps.keys():
            try:
                connection_test_url = (
                    f"{config.CONTROL_PLANE_URL}/api/failed-code-connection-tests"
                )
                connection_test_payload = {
                    "appId": app_id,
                    "status": "failed",
                    "error": "snowglobe-connect shut down gracefully",
                }
                async with httpx.AsyncClient() as client:
                    LOGGER.info(
                        f"Posting shutdown code connection test for application {app_id} connection_test_payload: {connection_test_payload}"
                    )
                    connection_test_response = await client.post(
                        connection_test_url,
                        json=connection_test_payload,
                        headers={"x-api-key": config.API_KEY},
                    )
                    if not connection_test_response.is_success:
                        LOGGER.error(
                            f"Error posting shutdown code connection test for application {app_id}: {connection_test_response.text}"
                        )
                    else:
                        LOGGER.info(
                            f"Posted shutdown heartbeat for application {app_id} successfully."
                        )
            except Exception as e:
                LOGGER.error(
                    f"Error processing application heartbeat for {app_id}: {e}\n{traceback.format_exc()}"
                )
        await scheduler.stop()
        await scheduler.wait_until_stopped()

def create_client():
    """Create and configure the FastAPI application."""
    app = FastAPI(lifespan=lifespan)

    @app.get("/")
    def read_root():
        return {"message": "Dashing through the snow..."}

    @app.post("/completion")
    @rate_limit(
        "completion",
        max_requests=config.CONCURRENT_COMPLETIONS_PER_INTERVAL,
        time_window=config.CONCURRENT_COMPLETIONS_INTERVAL_SECONDS,
    )
    async def completion_endpoint(request: Request):
        # request body carries the test and the target app_id
        completion_body = await request.json()
        test = completion_body.get("test")
        app_id = completion_body.get("app_id")
        # both fields are required and must be non-empty
        if not test or not app_id:
            raise HTTPException(
                status_code=400,
                detail="Both 'test' and 'app_id' must be provided in the request body.",
            )
        if app_id not in apps:
            raise HTTPException(
                status_code=404,
                detail=f"Application with ID {app_id} not found.",
            )
        completion_fn = apps.get(app_id, {}).get("completion_fn")
        LOGGER.debug(f"Received test: {test['id']}")

        await process_test(test, completion_fn, app_id)
        return {"status": "processed"}

    @app.post("/heartbeat")
    @rate_limit(
        "heartbeat",
        max_requests=config.CONCURRENT_HEARTBEATS_PER_INTERVAL,
        time_window=config.CONCURRENT_HEARTBEATS_INTERVAL_SECONDS,
    )
    async def heartbeat_endpoint(request: Request):
        """Endpoint to check if the client is alive and well."""
        body = await request.json()
        app_id = body.get("app_id")
        if not app_id:
            raise HTTPException(
                status_code=400,
                detail="Application ID must be provided in the request body.",
            )
        if app_id not in apps:
            raise HTTPException(
                status_code=404,
                detail=f"Application with ID {app_id} not found.",
            )
        LOGGER.debug(f"Received heartbeat for application: {app_id}")

        # Run a live connection test for this application
        await process_application_heartbeat(app_id)
        return {"status": "heartbeat received"}

    @app.post("/risk-evaluation")
    @rate_limit(
        "risk_evaluation",
        max_requests=config.CONCURRENT_RISK_EVALUATIONS,
        time_window=config.CONCURRENT_RISK_EVALUATIONS_INTERVAL_SECONDS,
    )
    async def risk_evaluation_endpoint(request: Request):
        # request body carries the test and the risk name
        body = await request.json()
        test = body.get("test")
        risk_name = body.get("risk_name")
        LOGGER.debug(f"Received risk evaluation for test: {test['id']}")

        # Evaluate the risk and post the result to the control plane
        await process_risk_evaluation(test, risk_name)
        return {"status": "risk evaluation processed"}

    return app

def start_client(verbose=False):
    """Start the FastAPI client."""
    # Configure logging based on verbose flag
    if not verbose:
        LOGGER.setLevel(logging.WARNING)
        logging.getLogger("apscheduler").setLevel(logging.ERROR)

    # Initialize stats tracking
    initialize_stats()

    app = create_client()

    port = config.SNOWGLOBE_CLIENT_PORT
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="warning")


if __name__ == "__main__":
    start_client()
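# Example usage (sketch): because this module relies on relative imports, it is
# normally launched via the package (or the snowglobe-connect CLI) rather than run
# as a loose script, e.g.:
#
#   from snowglobe.client.src.app import start_client
#   start_client(verbose=True)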