snowglobe 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowglobe/client/__init__.py +17 -0
- snowglobe/client/src/app.py +732 -0
- snowglobe/client/src/cli.py +736 -0
- snowglobe/client/src/cli_utils.py +361 -0
- snowglobe/client/src/config.py +213 -0
- snowglobe/client/src/models.py +37 -0
- snowglobe/client/src/project_manager.py +290 -0
- snowglobe/client/src/stats.py +53 -0
- snowglobe/client/src/utils.py +117 -0
- snowglobe-0.4.0.dist-info/METADATA +128 -0
- snowglobe-0.4.0.dist-info/RECORD +15 -0
- snowglobe-0.4.0.dist-info/WHEEL +5 -0
- snowglobe-0.4.0.dist-info/entry_points.txt +2 -0
- snowglobe-0.4.0.dist-info/licenses/LICENSE +21 -0
- snowglobe-0.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,361 @@
|
|
1
|
+
import contextlib
import math
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
from rich.console import Console
from rich.prompt import Confirm, Prompt
from rich.table import Table

from .config import get_rc_file_path
from .stats import get_shutdown_stats
|
15
|
+
|
16
|
+
# Shared rich console used by every output helper in this module.
console = Console()
|
17
|
+
|
18
|
+
|
19
|
+
class CliState:
    """Mutable holder for process-wide CLI output flags.

    ``verbose`` enables ``debug`` output; ``quiet`` silences the success,
    warning, info and docs-link helpers and the spinner (errors still print);
    ``json_output`` silences all human-readable helper output.
    """

    def __init__(self):
        # Every mode starts disabled; presumably the CLI entry point toggles them.
        self.verbose, self.quiet, self.json_output = False, False, False


# Module-level singleton consulted by this module's output helpers.
cli_state = CliState()
|
29
|
+
|
30
|
+
|
31
|
+
def get_api_key() -> Optional[str]:
    """Return the configured Snowglobe API key, or a falsy value if there is none.

    Lookup order: the ``SNOWGLOBE_API_KEY`` environment variable, then
    ``GUARDRAILS_API_KEY``, then the first line of the rc file (from
    ``get_rc_file_path()``) that starts with ``SNOWGLOBE_API_KEY=``.
    """
    env_key = os.getenv("SNOWGLOBE_API_KEY") or os.getenv("GUARDRAILS_API_KEY")
    if env_key:
        return env_key

    rc_path = get_rc_file_path()
    if not os.path.exists(rc_path):
        return env_key
    with open(rc_path, "r") as handle:
        matching = (ln for ln in handle if ln.startswith("SNOWGLOBE_API_KEY="))
        first = next(matching, None)
    if first is None:
        return env_key
    # Value is everything after the first "=", minus trailing whitespace/newline.
    return first.strip().split("=", 1)[1]
|
43
|
+
|
44
|
+
|
45
|
+
def get_control_plane_url() -> str:
    """Return the control plane base URL.

    A non-empty ``CONTROL_PLANE_URL`` environment variable wins. Otherwise the
    first ``CONTROL_PLANE_URL=`` line of the rc file is used, and the hosted
    default applies when the file or the line is absent.
    """
    from_env = os.getenv("CONTROL_PLANE_URL")
    if from_env:
        return from_env

    url = "https://api.snowglobe.guardrailsai.com"
    rc_path = get_rc_file_path()
    if os.path.exists(rc_path):
        with open(rc_path, "r") as handle:
            for raw in handle:
                if not raw.startswith("CONTROL_PLANE_URL="):
                    continue
                url = raw.strip().split("=", 1)[1]
                break
    return url
|
58
|
+
|
59
|
+
|
60
|
+
def success(message: str) -> None:
    """Print ``message`` in green with a check-mark prefix.

    Nothing is printed in JSON-output mode or in quiet mode.
    """
    silenced = cli_state.json_output or cli_state.quiet
    if not silenced:
        console.print(f"✅ {message}", style="green")
|
66
|
+
|
67
|
+
|
68
|
+
def warning(message: str) -> None:
    """Print ``message`` in yellow with a warning-sign prefix.

    Nothing is printed in JSON-output mode or in quiet mode.
    """
    silenced = cli_state.json_output or cli_state.quiet
    if not silenced:
        console.print(f"⚠️ {message}", style="yellow")
|
74
|
+
|
75
|
+
|
76
|
+
def error(message: str) -> None:
    """Print ``message`` in red with a cross-mark prefix.

    Unlike the other helpers this ignores quiet mode; only JSON-output mode
    suppresses it.
    """
    if not cli_state.json_output:
        console.print(f"❌ {message}", style="red")
|
81
|
+
|
82
|
+
|
83
|
+
def info(message: str) -> None:
    """Print ``message`` in blue with a light-bulb prefix.

    Nothing is printed in JSON-output mode or in quiet mode.
    """
    if cli_state.json_output or cli_state.quiet:
        return
    console.print(f"💡 {message}", style="blue")
|
89
|
+
|
90
|
+
|
91
|
+
def debug(message: str) -> None:
    """Print ``message`` dimmed, but only in verbose mode.

    JSON-output mode suppresses it even when verbose; quiet mode is not consulted.
    """
    if cli_state.json_output or not cli_state.verbose:
        return
    console.print(f"🔍 {message}", style="dim")
|
97
|
+
|
98
|
+
|
99
|
+
def docs_link(message: str, url: str = "https://www.snowglobe.so/docs") -> None:
    """Print ``message: url`` in cyan with a book prefix.

    Args:
        message: Label shown before the link.
        url: Link target; defaults to the Snowglobe docs home page.

    Nothing is printed in JSON-output mode or in quiet mode.
    """
    if cli_state.json_output or cli_state.quiet:
        return
    console.print(f"📖 {message}: {url}", style="cyan")
|
105
|
+
|
106
|
+
|
107
|
+
def graceful_shutdown():
    """Print a session summary and terminate the process with exit status 0.

    Counters come from ``get_shutdown_stats()``. When any messages were
    processed, a per-experiment breakdown (only if there is more than one
    experiment) and the total with uptime are printed.

    The raw ``console.print`` calls are gated on the same ``json_output`` /
    ``quiet`` flags that ``success``/``warning`` honour, so JSON mode is not
    polluted with human-readable text.

    Raises:
        SystemExit: always, with code 0.
    """
    # console.print bypasses the helper gating, so apply it explicitly.
    show_details = not (cli_state.json_output or cli_state.quiet)

    if show_details:
        console.print("\n")
    warning("🛑 Shutting down gracefully...")
    success("Completing current scenarios")
    success("Connection closed")

    stats = get_shutdown_stats()

    # .get() tolerates a stats mapping that lacks optional keys.
    if stats and stats.get("total_messages", 0) > 0:
        success("Session summary:")
        experiment_totals = stats.get("experiment_totals") or {}
        uptime = stats.get("uptime", "")
        if len(experiment_totals) > 1:
            # Multiple experiments - show breakdown
            if show_details:
                for exp_name, count in experiment_totals.items():
                    console.print(f" • {exp_name}: {count} scenarios processed")
                console.print(
                    f" • Total: {stats['total_messages']} scenarios in {uptime}"
                )
        elif show_details:
            # Single experiment or total only
            console.print(
                f" • {stats['total_messages']} scenarios processed in {uptime}"
            )
    else:
        success("No scenarios processed during this session")

    if show_details:
        console.print()
    success("Agent disconnected successfully")
    sys.exit(0)
|
136
|
+
|
137
|
+
|
138
|
+
@contextlib.contextmanager
def spinner(text: str):
    """Show a rich status spinner labelled ``"<text>..."`` while the body runs.

    In JSON-output or quiet mode this context manager is a no-op.
    """
    if not (cli_state.json_output or cli_state.quiet):
        with console.status(f"[bold blue]{text}..."):
            yield
    else:
        yield
|
147
|
+
|
148
|
+
|
149
|
+
def check_auth_status() -> Tuple[bool, str, Dict[str, Any]]:
    """Probe the control plane to see whether the configured API key works.

    Returns:
        ``(ok, message, payload)``. On HTTP 200 ``payload`` is the decoded JSON
        body of ``GET <control plane>/api/applications``; otherwise it is ``{}``.
        ``requests.RequestException`` is reported as
        ``(False, "Connection error: ...", {})`` rather than raised.

    NOTE(review): the annotation says ``Dict`` but ``get_remote_applications``
    treats the same endpoint's body as a list — confirm the payload type.
    """
    api_key = get_api_key()
    if not api_key:
        return False, "No API key found", {}

    endpoint = f"{get_control_plane_url()}/api/applications"
    auth_headers = {"x-api-key": api_key}
    try:
        resp = requests.get(endpoint, headers=auth_headers, timeout=10)
        if resp.status_code != 200:
            return False, f"Authentication failed: {resp.status_code}", {}
        return True, "Authenticated", resp.json()
    except requests.RequestException as exc:
        return False, f"Connection error: {str(exc)}", {}
|
168
|
+
|
169
|
+
|
170
|
+
def select_stateful_interactive(
    stateful: bool = False,
) -> bool:
    """Ask the user whether their agent needs stateful integration.

    Args:
        stateful: Returned unchanged in JSON-output mode (no prompt is shown);
            otherwise used as the answer when the user just presses Enter.

    Returns:
        True if the user confirms the agent is stateful, else False.
    """
    if cli_state.json_output:
        # For JSON mode, just return the default stateful value
        return stateful
    info("Some stateful agents such as ones that maintain communication over a websocket or convo specific completion endpoint require stateful integration.")
    info("If your agent takes messages and completions on a single completion endpoint regardless of context, you can answer no to the following question.")
    # Ask the question the two info lines above set up, defaulting to the
    # caller-supplied value.
    return Confirm.ask("Does your agent require stateful integration?", default=stateful)
|
182
|
+
|
183
|
+
def select_application_interactive(
    applications: List[Dict[str, Any]],
) -> Optional[Union[Dict[str, Any], str]]:
    """Let the user pick an application, request a new one, or cancel.

    Args:
        applications: Application records fetched from the control plane.

    Returns:
        The chosen application dict, the string ``"new"`` if the user wants to
        create an application, or ``None`` if they decline or cancel. In
        JSON-output mode no prompt is shown and the first application (or
        ``None`` for an empty list) is returned.
    """
    if cli_state.json_output:
        # For JSON mode, just return the first app or None
        return applications[0] if applications else None

    if not applications:
        info("No applications found")
        if Confirm.ask("Would you like to create a new application?"):
            return "new"
        return None

    # Most recently updated applications are listed first.
    sorted_applications = sort_applications_by_date(applications)
    return display_applications_clean(sorted_applications)
|
201
|
+
|
202
|
+
|
203
|
+
def sort_applications_by_date(
    applications: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return a new list of applications ordered newest ``updated_at`` first.

    ``updated_at`` values are compared as-is (lexicographic for strings, which
    orders ISO-8601 timestamps of a uniform format chronologically). Records
    whose ``updated_at`` is missing or falsy sort as ``""`` and so end up last.
    The sort is stable: ties keep their input order.
    """
    return sorted(
        applications,
        key=lambda app: app.get("updated_at") or "",
        reverse=True,
    )
|
216
|
+
|
217
|
+
|
218
|
+
def _format_app_description(app: Dict[str, Any]) -> str:
    """Return the app's description collapsed to one line and capped at 20 chars."""
    # `or` also covers an explicit null description from the API, which would
    # otherwise crash on .split().
    raw = app.get("description") or "No description"
    # Clean up description - remove newlines and extra spaces
    description = " ".join(str(raw).split())
    # Truncate description to 20 characters (17 + "...")
    if len(description) > 20:
        description = description[:17] + "..."
    return description


def display_applications_clean(
    applications: List[Dict[str, Any]], page_size: int = 15
) -> Optional[Union[Dict[str, Any], str]]:
    """Show a paginated table of applications and prompt for a choice.

    Accepted input: a 1-based application number, ``p``/``n`` to page,
    ``new`` to create an application, or ``q`` to quit. Invalid input prints
    an error, pauses one second and re-prompts.

    Args:
        applications: Application dicts, shown in the given order.
        page_size: Rows per page; must be positive.

    Returns:
        The selected application dict, the string ``"new"``, or ``None`` if
        the user quits or cancels with Ctrl-C / EOF.
    """
    total_apps = len(applications)
    total_pages = math.ceil(total_apps / page_size) if total_apps > 0 else 1
    current_page = 0

    # Only add an "Updated" column when at least one app carries a date.
    has_dates = any(app.get("updated_at") for app in applications)

    while True:
        start_idx = current_page * page_size
        page_apps = applications[start_idx:start_idx + page_size]

        title = f"📱 Your Applications ({total_apps} total)"
        if total_pages > 1:
            title = f"📱 Your Applications ({total_apps} total) - Page {current_page + 1}/{total_pages}"
        table = Table(title=title)
        table.add_column("#", style="bold blue", width=4)
        table.add_column("Name", style="bold", min_width=15)
        if has_dates:
            table.add_column("Updated", style="dim", min_width=10)
        table.add_column("Description", style="green", min_width=20)

        for app_idx, app in enumerate(page_apps, start=start_idx + 1):
            # str()/`or` keep a null or non-string name from breaking rendering.
            row = [str(app_idx), str(app.get("name") or "Unknown")]
            if has_dates:
                row.append(get_best_date(app))
            row.append(_format_app_description(app))
            table.add_row(*row)

        # Trailing pseudo-row advertising the "new" command.
        new_row = ["new", "🆕 Create New App"]
        if has_dates:
            new_row.append("-")
        new_row.append("Set up new application")
        table.add_row(*new_row)

        console.print(table)

        # Navigation instructions
        nav_options = []
        if current_page > 0:
            nav_options.append("[bold cyan]p[/bold cyan] Previous")
        if current_page < total_pages - 1:
            nav_options.append("[bold cyan]n[/bold cyan] Next")
        if total_apps:
            # Avoid advertising a meaningless "1-0" range for an empty list.
            nav_options.append(f"[bold yellow]1-{total_apps}[/bold yellow] Select app")
        nav_options.extend(
            [
                "[bold green]new[/bold green] Create new",
                "[bold red]q[/bold red] Quit",
            ]
        )
        console.print("\nOptions: " + " | ".join(nav_options))

        # Get user input; Ctrl-C/EOF anywhere in handling cancels selection.
        try:
            choice = Prompt.ask("\n[bold]Your choice[/bold]").strip().lower()

            if choice == "q":
                return None
            elif choice == "p" and current_page > 0:
                current_page -= 1
                continue
            elif choice == "n" and current_page < total_pages - 1:
                current_page += 1
                continue
            elif choice == "new":
                return "new"
            # isdecimal (not isdigit): "²".isdigit() is True but int("²") raises.
            elif choice.isdecimal():
                idx = int(choice)
                if 1 <= idx <= total_apps:
                    return applications[idx - 1]
                error(f"Please choose between 1 and {total_apps}")
                time.sleep(1)
            else:
                error("Invalid choice. Try again.")
                time.sleep(1)

        except (KeyboardInterrupt, EOFError):
            warning("\nSelection cancelled")
            return None
|
325
|
+
|
326
|
+
|
327
|
+
def get_best_date(app: Dict[str, Any]) -> str:
    """Return the app's ``updated_at`` shortened for table display.

    A missing or falsy value gives ``"-"``. For ISO timestamps such as
    ``"2025-07-29T04:35:22.093Z"`` everything from the first ``T`` onward is
    dropped (``"2025-07-29"``). Any other value is stringified and cut to at
    most 10 characters.
    """
    raw = app.get("updated_at")
    if not raw:
        return "-"
    text = str(raw)
    date_part, separator, _ = text.partition("T")
    return date_part if separator else text[:10]
|
341
|
+
|
342
|
+
|
343
|
+
def get_remote_applications() -> Tuple[bool, List[Dict[str, Any]], str]:
    """Fetch the caller's applications from the control plane.

    Returns:
        ``(ok, applications, message)``. On HTTP 200 this is the decoded JSON
        body and ``"Success"``. Otherwise it is ``False`` with an empty list and
        one of: ``"No API key found"``, ``"HTTP <status>: <body>"``, or
        ``"Connection error: ..."`` for a ``requests.RequestException``.
    """
    api_key = get_api_key()
    if not api_key:
        return False, [], "No API key found"

    url = f"{get_control_plane_url()}/api/applications"
    auth_headers = {"x-api-key": api_key}
    try:
        reply = requests.get(url, headers=auth_headers, timeout=10)
        if reply.status_code != 200:
            return False, [], f"HTTP {reply.status_code}: {reply.text}"
        payload = reply.json()
    except requests.RequestException as exc:
        return False, [], f"Connection error: {str(exc)}"
    return True, payload, "Success"
|
@@ -0,0 +1,213 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
|
4
|
+
from rich.console import Console
|
5
|
+
|
6
|
+
# Module logger; Config's getters and the rc-file migration log through it.
LOGGER = logging.getLogger(__name__)
# Rich console instance. NOTE(review): not referenced anywhere else in this
# module as shown — confirm whether it is still needed.
console = Console()
|
8
|
+
|
9
|
+
|
10
|
+
def get_rc_file_path() -> str:
    """Return ``<cwd>/.snowglobe/config.rc``, the project-local rc file path.

    The path is resolved against the current working directory on every call;
    the file is not required to exist.
    """
    config_dir = os.path.join(os.getcwd(), ".snowglobe")
    return os.path.join(config_dir, "config.rc")
|
13
|
+
|
14
|
+
|
15
|
+
def get_legacy_rc_file_path() -> str:
    """Return ``<cwd>/.snowgloberc``, the old single-file rc location.

    Used to find a file that ``migrate_rc_file_if_needed`` can copy forward.
    """
    cwd = os.getcwd()
    return os.path.join(cwd, ".snowgloberc")
|
18
|
+
|
19
|
+
|
20
|
+
def migrate_rc_file_if_needed():
    """Copy a legacy ``.snowgloberc`` to ``.snowglobe/config.rc`` once.

    Does nothing if the new file already exists or there is no legacy file.
    The legacy file is left in place.

    The new file is created with the legacy file's permission bits (further
    limited by the umask). A legacy file restricted with ``chmod 600``, which
    may hold ``SNOWGLOBE_API_KEY``, therefore does not become world-readable
    after migration.
    """
    legacy_path = get_legacy_rc_file_path()
    new_path = get_rc_file_path()

    # If new file already exists, or there is nothing to migrate, stop here.
    if os.path.exists(new_path) or not os.path.exists(legacy_path):
        return

    # Ensure .snowglobe directory exists
    os.makedirs(os.path.dirname(new_path), exist_ok=True)

    # Copy content from legacy to new location
    with open(legacy_path, "r") as legacy_file:
        content = legacy_file.read()

    # Create the file with the legacy mode bits up front, so there is no
    # window in which the secret is readable with looser permissions.
    legacy_mode = os.stat(legacy_path).st_mode & 0o777
    fd = os.open(new_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, legacy_mode)
    with os.fdopen(fd, "w") as new_file:
        new_file.write(content)

    LOGGER.info(f"Migrated {legacy_path} to {new_path}")
|
42
|
+
|
43
|
+
|
44
|
+
class Config:
    """Client settings resolved once, at construction time.

    Unless noted otherwise, each setting uses the same precedence. A non-empty
    environment variable wins. Otherwise the value from the first line of the
    rc file (``get_rc_file_path()``) that starts with ``KEY=`` is used.
    Otherwise a built-in default applies.

    Raises:
        ValueError: from ``__init__`` when no API key is configured, or when a
            numeric setting is not a valid integer.
    """

    def __init__(self):
        # Migrate legacy .snowgloberc if needed
        migrate_rc_file_if_needed()

        # Environment-only setting (never read from the rc file).
        self.SNOWGLOBE_APPLICATION_HEARTBEAT_INTERVAL_MINUTES = int(
            os.getenv("SNOWGLOBE_APPLICATION_HEARTBEAT_INTERVAL_MINUTES", "5")
        )
        self.CONCURRENT_HEARTBEATS_PER_INTERVAL = 120
        self.CONCURRENT_HEARTBEATS_INTERVAL_SECONDS = 60

        # Initialize API key - just raise exception if missing, let caller handle it
        self.API_KEY = self.get_api_key()
        self.APPLICATION_ID = self.get_application_id()
        self.CONTROL_PLANE_URL = self.get_control_plane_url()
        self.SNOWGLOBE_CLIENT_URL = self.get_snowglobe_client_url()
        self.SNOWGLOBE_CLIENT_PORT = self.get_snowglobe_client_port()
        self.CONCURRENT_COMPLETIONS_PER_INTERVAL = self.get_completions_per_interval()
        self.CONCURRENT_COMPLETIONS_INTERVAL_SECONDS = (
            self.get_completions_interval_seconds()
        )
        self.CONCURRENT_RISK_EVALUATIONS = self.get_concurrent_risk_evaluations()
        self.CONCURRENT_RISK_EVALUATIONS_INTERVAL_SECONDS = (
            self.get_concurrent_risk_evaluations_interval_seconds()
        )
        # Environment-only flag; only "true" (case-insensitive) enables it.
        self.DEBUG = os.getenv("DEBUG", "false").lower() == "true"

    @staticmethod
    def _read_rc_value(key):
        """Return the value on the first rc-file line starting with ``key=``.

        Returns None if the rc file does not exist or has no such line. The
        line must start with ``key=`` exactly (no leading whitespace). The
        value is the text after the first ``=``, with trailing whitespace and
        the newline removed.
        """
        rc_path = get_rc_file_path()
        if not os.path.exists(rc_path):
            return None
        prefix = f"{key}="
        with open(rc_path, "r") as rc_file:
            for line in rc_file:
                if line.startswith(prefix):
                    return line.strip().split("=", 1)[1]
        return None

    def _resolve(self, key, default):
        """Return env var ``key`` if non-empty, else its rc value, else ``default``.

        An rc line with an empty value (``KEY=``) yields ``""``, not ``default``.
        """
        value = os.getenv(key)
        if value:
            return value
        from_rc = self._read_rc_value(key)
        return default if from_rc is None else from_rc

    def get_snowglobe_client_port(self) -> int:
        """Return ``SNOWGLOBE_CLIENT_PORT`` as an int (default 8000)."""
        snowglobe_client_port = self._resolve("SNOWGLOBE_CLIENT_PORT", "8000")
        LOGGER.debug(f"setting SNOWGLOBE_CLIENT_PORT: {snowglobe_client_port}")
        return int(snowglobe_client_port)

    def get_snowglobe_client_url(self) -> str:
        """Return ``SNOWGLOBE_CLIENT_URL`` (default ``http://localhost:8000``)."""
        snowglobe_client_url = self._resolve(
            "SNOWGLOBE_CLIENT_URL", "http://localhost:8000"
        )
        LOGGER.debug(f"setting SNOWGLOBE_CLIENT_URL: {snowglobe_client_url}")
        return snowglobe_client_url

    def get_completions_interval_seconds(self) -> int:
        """Return ``COMPLETIONS_INTERVAL_SECONDS`` as an int (default 60)."""
        if not os.getenv("COMPLETIONS_INTERVAL_SECONDS"):
            LOGGER.debug(
                "COMPLETIONS_INTERVAL_SECONDS not found in environment variables, "
                "falling back to the rc file or the default of 60"
            )
        return int(self._resolve("COMPLETIONS_INTERVAL_SECONDS", "60"))

    def get_completions_per_interval(self) -> int:
        """Return the ``COMPLETIONS_PER_SECOND`` setting as an int (default 120).

        ``__init__`` stores it as ``CONCURRENT_COMPLETIONS_PER_INTERVAL``.
        NOTE(review): the variable name says "per second", but the stored
        attribute name suggests a count per ``COMPLETIONS_INTERVAL_SECONDS``
        window — confirm the intended unit.
        """
        if not os.getenv("COMPLETIONS_PER_SECOND"):
            LOGGER.debug(
                "COMPLETIONS_PER_SECOND not found in environment variables, "
                "falling back to the rc file or the default of 120"
            )
        return int(self._resolve("COMPLETIONS_PER_SECOND", "120"))

    def get_concurrent_risk_evaluations(self) -> int:
        """Return ``CONCURRENT_RISK_EVALUATIONS`` as an int (default 120)."""
        return int(self._resolve("CONCURRENT_RISK_EVALUATIONS", "120"))

    def get_concurrent_risk_evaluations_interval_seconds(self) -> int:
        """Return ``CONCURRENT_RISK_EVALUATIONS_INTERVAL_SECONDS`` as an int (default 60)."""
        return int(
            self._resolve("CONCURRENT_RISK_EVALUATIONS_INTERVAL_SECONDS", "60")
        )

    def get_control_plane_url(self) -> str:
        """Return ``CONTROL_PLANE_URL`` (default: the hosted Snowglobe API)."""
        # Individual rc lines are deliberately not logged: the same file holds
        # SNOWGLOBE_API_KEY, which would otherwise leak into debug logs.
        control_plane_url = self._resolve(
            "CONTROL_PLANE_URL", "https://api.snowglobe.guardrailsai.com"
        )
        LOGGER.debug(f"setting CONTROL_PLANE_URL: {control_plane_url}")
        return control_plane_url

    def get_api_key(self) -> str:
        """Return the API key from the environment or the rc file.

        ``SNOWGLOBE_API_KEY`` and then ``GUARDRAILS_API_KEY`` are checked in
        the environment, then ``SNOWGLOBE_API_KEY=`` in the rc file.

        Raises:
            ValueError: if no non-empty key is found.
        """
        api_key = os.getenv("SNOWGLOBE_API_KEY") or os.getenv("GUARDRAILS_API_KEY")
        if not api_key:
            api_key = self._read_rc_value("SNOWGLOBE_API_KEY")
        if not api_key:
            raise ValueError(
                "API key is required either passed as an argument or set as an environment variable. \n"
                "You can set the API key as an environment variable by running: \n"
                "export SNOWGLOBE_API_KEY=<your_api_key> \n"
                "or \n"
                "export GUARDRAILS_API_KEY=<your_api_key> \n"
                "Or you can create a .snowglobe/config.rc file in the current directory with the line: \n"
                "SNOWGLOBE_API_KEY=<your_api_key> \n"
            )
        return api_key

    def get_application_id(self) -> str:
        """Return ``SNOWGLOBE_APP_ID`` from the environment or the rc file.

        There is no default: this may return None (or ``""`` from an empty
        environment variable) when the ID is not configured.
        """
        application_id = os.getenv("SNOWGLOBE_APP_ID")
        if not application_id:
            from_rc = self._read_rc_value("SNOWGLOBE_APP_ID")
            if from_rc is not None:
                application_id = from_rc
        return application_id
|
211
|
+
|
212
|
+
|
213
|
+
# Module-level settings singleton, built when this module is first imported.
# NOTE(review): Config.__init__ runs the rc-file migration and raises
# ValueError when no API key is configured. Merely importing this module (e.g.
# cli_utils' `from .config import get_rc_file_path`) can therefore fail or touch
# the filesystem before the user has set a key — confirm this is intended.
config = Config()
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from typing import Dict, List, Optional
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
|
6
|
+
class SnowglobeData(BaseModel):
    """Snowglobe identifiers that can be attached to a ``SnowglobeMessage``."""

    # Conversation identifier (string; format not defined here).
    conversation_id: str
    # Test identifier (string; presumably the Snowglobe test/scenario id — confirm).
    test_id: str
|
9
|
+
|
10
|
+
|
11
|
+
class SnowglobeMessage(BaseModel):
    """A single chat message, optionally carrying Snowglobe identifiers."""

    # Speaker role. CompletionRequest.to_openai_messages passes it through
    # unchanged, so presumably OpenAI-style values ("user", "assistant", ...).
    role: str
    # Message text.
    content: str
    # Optional Snowglobe identifiers; dropped by to_openai_messages.
    snowglobe_data: Optional[SnowglobeData] = None
|
15
|
+
|
16
|
+
|
17
|
+
class CompletionFunctionOutputs(BaseModel):
    """Output of a completion function (presumably the user's agent callback — confirm)."""

    # Reply text produced for the conversation.
    response: str
|
19
|
+
|
20
|
+
|
21
|
+
class CompletionRequest(BaseModel):
    """Request model holding an ordered list of Snowglobe chat messages."""

    messages: List[SnowglobeMessage]

    def to_openai_messages(self) -> List[Dict]:
        """Convert ``messages`` into plain ``{"role": ..., "content": ...}`` dicts.

        Message order is preserved; ``snowglobe_data`` is dropped.
        """
        converted: List[Dict] = []
        for message in self.messages:
            converted.append({"role": message.role, "content": message.content})
        return converted
|
27
|
+
|
28
|
+
|
29
|
+
class RiskEvaluationRequest(BaseModel):
    """Request model carrying the conversation to evaluate for risk."""

    # Conversation messages, in order.
    messages: List[SnowglobeMessage]
|
31
|
+
|
32
|
+
|
33
|
+
class RiskEvaluationOutputs(BaseModel):
    """Result of a risk evaluation over a conversation."""

    # Whether the evaluated risk was triggered.
    triggered: bool
    # Optional string-to-string labels attached to the result.
    tags: Optional[Dict[str, str]] = None
    # Optional human-readable explanation.
    reason: Optional[str] = None
    # Optional integer severity; the scale is not defined here — confirm with the server.
    severity: Optional[int] = None
|