genxai-framework 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +3 -0
- cli/commands/__init__.py +6 -0
- cli/commands/approval.py +85 -0
- cli/commands/audit.py +127 -0
- cli/commands/metrics.py +25 -0
- cli/commands/tool.py +389 -0
- cli/main.py +32 -0
- genxai/__init__.py +81 -0
- genxai/api/__init__.py +5 -0
- genxai/api/app.py +21 -0
- genxai/config/__init__.py +5 -0
- genxai/config/settings.py +37 -0
- genxai/connectors/__init__.py +19 -0
- genxai/connectors/base.py +122 -0
- genxai/connectors/kafka.py +92 -0
- genxai/connectors/postgres_cdc.py +95 -0
- genxai/connectors/registry.py +44 -0
- genxai/connectors/sqs.py +94 -0
- genxai/connectors/webhook.py +73 -0
- genxai/core/__init__.py +37 -0
- genxai/core/agent/__init__.py +32 -0
- genxai/core/agent/base.py +206 -0
- genxai/core/agent/config_io.py +59 -0
- genxai/core/agent/registry.py +98 -0
- genxai/core/agent/runtime.py +970 -0
- genxai/core/communication/__init__.py +6 -0
- genxai/core/communication/collaboration.py +44 -0
- genxai/core/communication/message_bus.py +192 -0
- genxai/core/communication/protocols.py +35 -0
- genxai/core/execution/__init__.py +22 -0
- genxai/core/execution/metadata.py +181 -0
- genxai/core/execution/queue.py +201 -0
- genxai/core/graph/__init__.py +30 -0
- genxai/core/graph/checkpoints.py +77 -0
- genxai/core/graph/edges.py +131 -0
- genxai/core/graph/engine.py +813 -0
- genxai/core/graph/executor.py +516 -0
- genxai/core/graph/nodes.py +161 -0
- genxai/core/graph/trigger_runner.py +40 -0
- genxai/core/memory/__init__.py +19 -0
- genxai/core/memory/base.py +72 -0
- genxai/core/memory/embedding.py +327 -0
- genxai/core/memory/episodic.py +448 -0
- genxai/core/memory/long_term.py +467 -0
- genxai/core/memory/manager.py +543 -0
- genxai/core/memory/persistence.py +297 -0
- genxai/core/memory/procedural.py +461 -0
- genxai/core/memory/semantic.py +526 -0
- genxai/core/memory/shared.py +62 -0
- genxai/core/memory/short_term.py +303 -0
- genxai/core/memory/vector_store.py +508 -0
- genxai/core/memory/working.py +211 -0
- genxai/core/state/__init__.py +6 -0
- genxai/core/state/manager.py +293 -0
- genxai/core/state/schema.py +115 -0
- genxai/llm/__init__.py +14 -0
- genxai/llm/base.py +150 -0
- genxai/llm/factory.py +329 -0
- genxai/llm/providers/__init__.py +1 -0
- genxai/llm/providers/anthropic.py +249 -0
- genxai/llm/providers/cohere.py +274 -0
- genxai/llm/providers/google.py +334 -0
- genxai/llm/providers/ollama.py +147 -0
- genxai/llm/providers/openai.py +257 -0
- genxai/llm/routing.py +83 -0
- genxai/observability/__init__.py +6 -0
- genxai/observability/logging.py +327 -0
- genxai/observability/metrics.py +494 -0
- genxai/observability/tracing.py +372 -0
- genxai/performance/__init__.py +39 -0
- genxai/performance/cache.py +256 -0
- genxai/performance/pooling.py +289 -0
- genxai/security/audit.py +304 -0
- genxai/security/auth.py +315 -0
- genxai/security/cost_control.py +528 -0
- genxai/security/default_policies.py +44 -0
- genxai/security/jwt.py +142 -0
- genxai/security/oauth.py +226 -0
- genxai/security/pii.py +366 -0
- genxai/security/policy_engine.py +82 -0
- genxai/security/rate_limit.py +341 -0
- genxai/security/rbac.py +247 -0
- genxai/security/validation.py +218 -0
- genxai/tools/__init__.py +21 -0
- genxai/tools/base.py +383 -0
- genxai/tools/builtin/__init__.py +131 -0
- genxai/tools/builtin/communication/__init__.py +15 -0
- genxai/tools/builtin/communication/email_sender.py +159 -0
- genxai/tools/builtin/communication/notification_manager.py +167 -0
- genxai/tools/builtin/communication/slack_notifier.py +118 -0
- genxai/tools/builtin/communication/sms_sender.py +118 -0
- genxai/tools/builtin/communication/webhook_caller.py +136 -0
- genxai/tools/builtin/computation/__init__.py +15 -0
- genxai/tools/builtin/computation/calculator.py +101 -0
- genxai/tools/builtin/computation/code_executor.py +183 -0
- genxai/tools/builtin/computation/data_validator.py +259 -0
- genxai/tools/builtin/computation/hash_generator.py +129 -0
- genxai/tools/builtin/computation/regex_matcher.py +201 -0
- genxai/tools/builtin/data/__init__.py +15 -0
- genxai/tools/builtin/data/csv_processor.py +213 -0
- genxai/tools/builtin/data/data_transformer.py +299 -0
- genxai/tools/builtin/data/json_processor.py +233 -0
- genxai/tools/builtin/data/text_analyzer.py +288 -0
- genxai/tools/builtin/data/xml_processor.py +175 -0
- genxai/tools/builtin/database/__init__.py +15 -0
- genxai/tools/builtin/database/database_inspector.py +157 -0
- genxai/tools/builtin/database/mongodb_query.py +196 -0
- genxai/tools/builtin/database/redis_cache.py +167 -0
- genxai/tools/builtin/database/sql_query.py +145 -0
- genxai/tools/builtin/database/vector_search.py +163 -0
- genxai/tools/builtin/file/__init__.py +17 -0
- genxai/tools/builtin/file/directory_scanner.py +214 -0
- genxai/tools/builtin/file/file_compressor.py +237 -0
- genxai/tools/builtin/file/file_reader.py +102 -0
- genxai/tools/builtin/file/file_writer.py +122 -0
- genxai/tools/builtin/file/image_processor.py +186 -0
- genxai/tools/builtin/file/pdf_parser.py +144 -0
- genxai/tools/builtin/test/__init__.py +15 -0
- genxai/tools/builtin/test/async_simulator.py +62 -0
- genxai/tools/builtin/test/data_transformer.py +99 -0
- genxai/tools/builtin/test/error_generator.py +82 -0
- genxai/tools/builtin/test/simple_math.py +94 -0
- genxai/tools/builtin/test/string_processor.py +72 -0
- genxai/tools/builtin/web/__init__.py +15 -0
- genxai/tools/builtin/web/api_caller.py +161 -0
- genxai/tools/builtin/web/html_parser.py +330 -0
- genxai/tools/builtin/web/http_client.py +187 -0
- genxai/tools/builtin/web/url_validator.py +162 -0
- genxai/tools/builtin/web/web_scraper.py +170 -0
- genxai/tools/custom/my_test_tool_2.py +9 -0
- genxai/tools/dynamic.py +105 -0
- genxai/tools/mcp_server.py +167 -0
- genxai/tools/persistence/__init__.py +6 -0
- genxai/tools/persistence/models.py +55 -0
- genxai/tools/persistence/service.py +322 -0
- genxai/tools/registry.py +227 -0
- genxai/tools/security/__init__.py +11 -0
- genxai/tools/security/limits.py +214 -0
- genxai/tools/security/policy.py +20 -0
- genxai/tools/security/sandbox.py +248 -0
- genxai/tools/templates.py +435 -0
- genxai/triggers/__init__.py +19 -0
- genxai/triggers/base.py +104 -0
- genxai/triggers/file_watcher.py +75 -0
- genxai/triggers/queue.py +68 -0
- genxai/triggers/registry.py +82 -0
- genxai/triggers/schedule.py +66 -0
- genxai/triggers/webhook.py +68 -0
- genxai/utils/__init__.py +1 -0
- genxai/utils/tokens.py +295 -0
- genxai_framework-0.1.0.dist-info/METADATA +495 -0
- genxai_framework-0.1.0.dist-info/RECORD +156 -0
- genxai_framework-0.1.0.dist-info/WHEEL +5 -0
- genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
- genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
- genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,528 @@
|
|
|
1
|
+
"""Cost control and budget management for GenXAI."""
|
|
2
|
+
|
|
3
|
+
import sqlite3
|
|
4
|
+
from typing import Dict, Any, Optional
|
|
5
|
+
from datetime import datetime, timedelta
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Token costs per 1K tokens (as of 2026)
# Prices are in USD per 1,000 tokens, keyed as TOKEN_COSTS[provider][model][kind],
# where kind is "prompt" or "completion". CostEstimator.estimate_cost divides the
# token counts by 1000 before multiplying by these rates.
# NOTE(review): provider pricing changes frequently — confirm these rates against
# each provider's current price list before relying on the estimates.
TOKEN_COSTS = {
    "openai": {
        "gpt-4": {"prompt": 0.03, "completion": 0.06},
        "gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
        "gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.002},
    },
    "anthropic": {
        "claude-3-opus": {"prompt": 0.015, "completion": 0.075},
        "claude-3-sonnet": {"prompt": 0.003, "completion": 0.015},
        "claude-3-haiku": {"prompt": 0.00025, "completion": 0.00125},
    },
    "google": {
        "gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
        "gemini-ultra": {"prompt": 0.01, "completion": 0.02},
    },
    "cohere": {
        "command": {"prompt": 0.001, "completion": 0.002},
        "command-light": {"prompt": 0.0003, "completion": 0.0006},
    },
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class UsageRecord:
    """One LLM call's token consumption and cost, attributed to a user.

    Fields appear in constructor (positional) order.
    """

    # Identifier of the user the usage is attributed to.
    user_id: str
    # LLM provider name, e.g. "openai".
    provider: str
    # Model name within that provider.
    model: str
    # Number of prompt (input) tokens consumed.
    prompt_tokens: int
    # Number of completion (output) tokens produced.
    completion_tokens: int
    # Cost of the call; presumably USD like TokenUsageTracker.record_usage — confirm.
    cost: float
    # When the usage happened.
    timestamp: datetime
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TokenUsageTracker:
    """Track LLM token usage in a SQLite database.

    Every :meth:`record_usage` call appends one row to the ``token_usage``
    table; :meth:`get_usage` aggregates a user's rows over a trailing time
    window. Timestamps are naive UTC (``datetime.utcnow()``) stored as text.
    """

    # Trailing windows understood by get_usage(); any other name falls back to "day".
    _PERIODS = {
        "day": timedelta(days=1),
        "week": timedelta(weeks=1),
        "month": timedelta(days=30),
    }

    def __init__(self, db_path: str = "genxai_usage.db"):
        """Initialize usage tracker and create the schema if it is missing.

        Args:
            db_path: Path to SQLite database
        """
        self.db_path = db_path
        self._init_db()

    @staticmethod
    def _to_db_time(moment: datetime) -> str:
        """Serialize *moment* exactly as sqlite3's legacy datetime adapter did.

        Passing ``datetime`` objects straight to sqlite3 relies on the implicit
        default adapter, deprecated since Python 3.12. That adapter produced
        ``isoformat(" ")``, so writing it explicitly keeps new rows comparable
        (as strings) with rows written earlier.
        """
        return moment.isoformat(" ")

    def _init_db(self):
        """Create the ``token_usage`` table and its (user_id, timestamp) index."""
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on error
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS token_usage (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id TEXT NOT NULL,
                        provider TEXT NOT NULL,
                        model TEXT NOT NULL,
                        prompt_tokens INTEGER NOT NULL,
                        completion_tokens INTEGER NOT NULL,
                        cost REAL NOT NULL,
                        timestamp TIMESTAMP NOT NULL
                    )
                """)
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_user_timestamp
                    ON token_usage(user_id, timestamp)
                """)
        finally:
            # The connection context manager does not close the handle; close it
            # explicitly so a failing statement cannot leak it.
            conn.close()

    def record_usage(
        self,
        user_id: str,
        provider: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        cost: float
    ):
        """Record token usage, timestamped with the current UTC time.

        Args:
            user_id: User ID
            provider: LLM provider
            model: Model name
            prompt_tokens: Prompt tokens used
            completion_tokens: Completion tokens used
            cost: Cost in USD
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                conn.execute(
                    """
                    INSERT INTO token_usage
                    (user_id, provider, model, prompt_tokens, completion_tokens, cost, timestamp)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        user_id,
                        provider,
                        model,
                        prompt_tokens,
                        completion_tokens,
                        cost,
                        self._to_db_time(datetime.utcnow()),
                    ),
                )
        finally:
            conn.close()

    def get_usage(
        self,
        user_id: str,
        period: str = "day"
    ) -> Dict[str, Any]:
        """Get usage statistics for a user over a trailing window.

        Args:
            user_id: User ID
            period: Time period (day, week, month); unknown values are treated as "day"

        Returns:
            Dict with ``period``, ``total_prompt_tokens``, ``total_completion_tokens``,
            ``total_cost``, ``request_count`` and ``breakdown`` (cost per
            provider -> model). Totals are 0 when the user has no rows.
        """
        window = self._PERIODS.get(period, self._PERIODS["day"])
        start_time = self._to_db_time(datetime.utcnow() - window)

        conn = sqlite3.connect(self.db_path)
        try:
            totals = conn.execute("""
                SELECT
                    SUM(prompt_tokens) as total_prompt_tokens,
                    SUM(completion_tokens) as total_completion_tokens,
                    SUM(cost) as total_cost,
                    COUNT(*) as request_count
                FROM token_usage
                WHERE user_id = ? AND timestamp >= ?
            """, (user_id, start_time)).fetchone()

            # Cost breakdown by provider/model.
            breakdown: Dict[str, Dict[str, float]] = {}
            for provider, model, cost in conn.execute("""
                SELECT provider, model, SUM(cost) as cost
                FROM token_usage
                WHERE user_id = ? AND timestamp >= ?
                GROUP BY provider, model
            """, (user_id, start_time)):
                breakdown.setdefault(provider, {})[model] = cost
        finally:
            conn.close()

        # SUM() yields NULL (None) when no rows match, hence the `or` defaults.
        return {
            "period": period,
            "total_prompt_tokens": totals[0] or 0,
            "total_completion_tokens": totals[1] or 0,
            "total_cost": totals[2] or 0.0,
            "request_count": totals[3] or 0,
            "breakdown": breakdown
        }
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class BudgetManager:
    """Manage per-user spending budgets persisted in SQLite.

    Each user has at most one row in ``budgets`` (``user_id`` is the primary
    key) holding a budget amount and the period name that usage is measured
    over when the budget is checked.
    """

    def __init__(self, db_path: str = "genxai_budgets.db"):
        """Initialize budget manager and create the schema if it is missing.

        Args:
            db_path: Path to SQLite database
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``budgets`` table if it does not exist."""
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on error
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS budgets (
                        user_id TEXT PRIMARY KEY,
                        amount REAL NOT NULL,
                        period TEXT NOT NULL,
                        created_at TIMESTAMP NOT NULL,
                        updated_at TIMESTAMP NOT NULL
                    )
                """)
        finally:
            conn.close()

    def set_budget(
        self,
        user_id: str,
        amount: float,
        period: str = "month"
    ):
        """Create or update a user's budget limit.

        Args:
            user_id: User ID
            amount: Budget amount in USD
            period: Budget period (day, week, month)
        """
        # Same text form sqlite3's deprecated implicit datetime adapter produced.
        now = datetime.utcnow().isoformat(" ")

        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                # An upsert (instead of INSERT OR REPLACE, which rewrites the whole
                # row) keeps created_at from the first insert while updating the rest.
                conn.execute("""
                    INSERT INTO budgets (user_id, amount, period, created_at, updated_at)
                    VALUES (?, ?, ?, ?, ?)
                    ON CONFLICT(user_id) DO UPDATE SET
                        amount = excluded.amount,
                        period = excluded.period,
                        updated_at = excluded.updated_at
                """, (user_id, amount, period, now, now))
        finally:
            conn.close()

    def _get_budget(self, user_id: str) -> Optional[tuple]:
        """Return ``(amount, period)`` for *user_id*, or ``None`` if no budget is set."""
        conn = sqlite3.connect(self.db_path)
        try:
            return conn.execute(
                "SELECT amount, period FROM budgets WHERE user_id = ?",
                (user_id,),
            ).fetchone()
        finally:
            conn.close()

    def check_budget(self, user_id: str, usage_tracker: "TokenUsageTracker") -> bool:
        """Check if user is within budget.

        Args:
            user_id: User ID
            usage_tracker: TokenUsageTracker instance; its ``get_usage`` is queried
                over the budget's stored period

        Returns:
            True if spending is strictly below the budget or no budget is set,
            False otherwise
        """
        budget = self._get_budget(user_id)
        if budget is None:
            # No budget set, allow.
            return True

        amount, period = budget
        usage = usage_tracker.get_usage(user_id, period)
        return usage["total_cost"] < amount

    def get_remaining(self, user_id: str, usage_tracker: "TokenUsageTracker") -> float:
        """Get remaining budget.

        Args:
            user_id: User ID
            usage_tracker: TokenUsageTracker instance; its ``get_usage`` is queried
                over the budget's stored period

        Returns:
            Remaining budget in USD (never negative), or ``inf`` if no budget is set
        """
        budget = self._get_budget(user_id)
        if budget is None:
            return float('inf')

        amount, period = budget
        usage = usage_tracker.get_usage(user_id, period)
        # 0.0 rather than int 0 so the declared float return type holds when overspent.
        return max(0.0, amount - usage["total_cost"])
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class CostEstimator:
    """Estimate LLM costs before execution from a per-1K-token price table."""

    def __init__(self, costs: Optional[Dict[str, Any]] = None):
        """Initialize cost estimator.

        Args:
            costs: Custom cost table shaped like ``TOKEN_COSTS``
                (``{provider: {model: {"prompt": rate, "completion": rate}}}``,
                rates per 1K tokens). ``None`` (default) uses the built-in
                ``TOKEN_COSTS``; an explicitly passed table, even an empty one,
                is used as given.
        """
        # `is None` rather than `or`: an empty custom table must not silently
        # fall back to the built-in prices.
        self.costs = TOKEN_COSTS if costs is None else costs

    def estimate_cost(
        self,
        provider: str,
        model: str,
        prompt_tokens: int,
        estimated_completion_tokens: int
    ) -> float:
        """Estimate execution cost.

        Args:
            provider: LLM provider
            model: Model name
            prompt_tokens: Number of prompt tokens
            estimated_completion_tokens: Estimated completion tokens

        Returns:
            Estimated cost in USD, or 0.0 when the provider/model pair is not
            in the cost table
        """
        model_costs = self.costs.get(provider, {}).get(model)
        if model_costs is None:
            return 0.0

        prompt_cost = (prompt_tokens / 1000) * model_costs["prompt"]
        completion_cost = (estimated_completion_tokens / 1000) * model_costs["completion"]

        return prompt_cost + completion_cost

    def estimate_workflow_cost(
        self,
        steps: list[Dict[str, Any]]
    ) -> float:
        """Estimate workflow cost.

        Args:
            steps: List of workflow steps, each a dict with ``provider``,
                ``model``, ``prompt_tokens`` and ``estimated_completion_tokens``

        Returns:
            Estimated total cost (0.0 for an empty workflow)
        """
        return sum(
            (
                self.estimate_cost(
                    step["provider"],
                    step["model"],
                    step["prompt_tokens"],
                    step["estimated_completion_tokens"],
                )
                for step in steps
            ),
            0.0,
        )
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class CostAlertManager:
    """Send alerts when a user's cost reaches a configured threshold.

    After an alert fires, further alerts for the same user are suppressed for
    one hour, tracked via ``cost_alerts.last_alert``.
    """

    def __init__(self, db_path: str = "genxai_alerts.db"):
        """Initialize alert manager and create the schema if it is missing.

        Args:
            db_path: Path to SQLite database
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``cost_alerts`` table if it does not exist."""
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on error
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS cost_alerts (
                        user_id TEXT PRIMARY KEY,
                        threshold REAL NOT NULL,
                        notification_method TEXT NOT NULL,
                        last_alert TIMESTAMP
                    )
                """)
        finally:
            conn.close()

    def set_alert(
        self,
        user_id: str,
        threshold: float,
        notification_method: str = "email"
    ):
        """Set (or replace) a user's cost alert.

        ``INSERT OR REPLACE`` rewrites the whole row, so ``last_alert`` is reset
        to NULL and the next threshold crossing notifies immediately.

        Args:
            user_id: User ID
            threshold: Cost threshold in USD
            notification_method: Notification method (email, slack, webhook)
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                conn.execute("""
                    INSERT OR REPLACE INTO cost_alerts (user_id, threshold, notification_method)
                    VALUES (?, ?, ?)
                """, (user_id, threshold, notification_method))
        finally:
            conn.close()

    def check_and_notify(self, user_id: str, current_cost: float):
        """Check threshold and send notification.

        Does nothing when the user has no alert configured, when
        ``current_cost`` is below the threshold, or when an alert was already
        sent within the last hour.

        Args:
            user_id: User ID
            current_cost: Current cost
        """
        conn = sqlite3.connect(self.db_path)
        try:
            row = conn.execute("""
                SELECT threshold, notification_method, last_alert
                FROM cost_alerts
                WHERE user_id = ?
            """, (user_id,)).fetchone()

            if row is None:
                return

            threshold, notification_method, last_alert = row

            if current_cost < threshold:
                return

            # Throttle: skip if an alert was already sent within the last hour.
            if last_alert:
                last_alert_time = datetime.fromisoformat(last_alert)
                if datetime.utcnow() - last_alert_time < timedelta(hours=1):
                    return

            self._send_notification(
                user_id,
                notification_method,
                current_cost,
                threshold
            )

            with conn:
                # isoformat(" ") matches what sqlite3's deprecated implicit
                # datetime adapter stored, and round-trips via fromisoformat above.
                conn.execute("""
                    UPDATE cost_alerts SET last_alert = ? WHERE user_id = ?
                """, (datetime.utcnow().isoformat(" "), user_id))
        finally:
            # Always release the handle, including when the notification raises.
            conn.close()

    def _send_notification(
        self,
        user_id: str,
        method: str,
        current_cost: float,
        threshold: float
    ):
        """Send notification.

        Args:
            user_id: User ID
            method: Notification method
            current_cost: Current cost
            threshold: Threshold
        """
        message = f"Cost alert: User {user_id} has exceeded threshold ${threshold:.2f}. Current cost: ${current_cost:.2f}"

        # Placeholder for actual notification implementation: the alert is only printed.
        print(f"[{method.upper()}] {message}")
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
# Module-level singletons, created lazily by the get_* accessors below.
_usage_tracker = None
_budget_manager = None
_cost_estimator = None
_alert_manager = None


def get_usage_tracker() -> TokenUsageTracker:
    """Return the shared TokenUsageTracker, constructing it on first use."""
    global _usage_tracker
    if _usage_tracker is not None:
        return _usage_tracker
    _usage_tracker = TokenUsageTracker()
    return _usage_tracker


def get_budget_manager() -> BudgetManager:
    """Return the shared BudgetManager, constructing it on first use."""
    global _budget_manager
    if _budget_manager is not None:
        return _budget_manager
    _budget_manager = BudgetManager()
    return _budget_manager


def get_cost_estimator() -> CostEstimator:
    """Return the shared CostEstimator, constructing it on first use."""
    global _cost_estimator
    if _cost_estimator is not None:
        return _cost_estimator
    _cost_estimator = CostEstimator()
    return _cost_estimator


def get_alert_manager() -> CostAlertManager:
    """Return the shared CostAlertManager, constructing it on first use."""
    global _alert_manager
    if _alert_manager is not None:
        return _alert_manager
    _alert_manager = CostAlertManager()
    return _alert_manager
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Default policy setup with explicit approval request IDs."""
|
|
2
|
+
|
|
3
|
+
from genxai.security.audit import get_approval_service
|
|
4
|
+
from genxai.security.policy_engine import AccessRule, get_policy_engine
|
|
5
|
+
from genxai.security.rbac import Permission
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def register_default_policies() -> None:
    """Register default policies with explicit approval IDs.

    For each (request type, resource, permission) entry below, an approval
    request is submitted on behalf of ``"system"``. An admin-only,
    approval-gated ``AccessRule`` referencing that request's ID is then
    added to the policy engine under the same resource name.
    """
    policy = get_policy_engine()
    approvals = get_approval_service()

    defaults = (
        ("tool.execute", "tool:calculator", Permission.TOOL_EXECUTE),
        ("agent.execute", "agent:finance_agent", Permission.AGENT_EXECUTE),
        ("workflow.execute", "workflow:workflow", Permission.WORKFLOW_EXECUTE),
    )

    for request_type, resource, permission in defaults:
        # Submit first so the rule can reference the freshly issued request ID.
        request = approvals.submit(request_type, resource, "system")
        policy.add_rule(
            resource,
            AccessRule(
                permissions={permission},
                allowed_users={"admin"},
                requires_approval=True,
                approval_request_id=request.request_id,
            ),
        )
|
genxai/security/jwt.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""JWT token management for GenXAI."""
|
|
2
|
+
|
|
3
|
+
import jwt
|
|
4
|
+
import os
|
|
5
|
+
from datetime import datetime, timedelta
|
|
6
|
+
from typing import Dict, Any, Optional
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class TokenPayload:
    """Typed view of the claims carried by a GenXAI JWT.

    NOTE(review): JWTManager.create_token stores the user id under the
    standard ``"sub"`` claim, not ``"user_id"``, and nothing in this module
    builds a TokenPayload — confirm how callers map between the two.
    """

    # Identifier of the authenticated user.
    user_id: str
    # Role name assigned to the user.
    role: str
    # Permission names granted to the token holder.
    permissions: list[str]
    # Expiry as an integer Unix timestamp.
    exp: int
    # Issued-at time as an integer Unix timestamp.
    iat: int
    # Token issuer.
    iss: str = "genxai"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class JWTManager:
    """Create, verify and refresh JWTs issued by ``"genxai"``."""

    # Fallback used when neither ``secret_key`` nor GENXAI_JWT_SECRET is provided.
    # This value is public, so anyone can forge tokens signed with it.
    _INSECURE_DEFAULT_SECRET = "change-me-in-production"

    def __init__(self, secret_key: Optional[str] = None, algorithm: str = "HS256"):
        """Initialize JWT manager.

        Args:
            secret_key: Secret key for signing tokens. If omitted, the
                ``GENXAI_JWT_SECRET`` environment variable is used; if that is
                unset or empty, an insecure built-in default is used and a
                ``RuntimeWarning`` is emitted.
            algorithm: JWT algorithm (default: HS256)
        """
        self.secret_key = (
            secret_key
            or os.getenv("GENXAI_JWT_SECRET")
            or self._INSECURE_DEFAULT_SECRET
        )
        if self.secret_key == self._INSECURE_DEFAULT_SECRET:
            import warnings

            warnings.warn(
                "JWTManager is using the built-in default secret; set "
                "GENXAI_JWT_SECRET or pass secret_key, otherwise tokens can be forged.",
                RuntimeWarning,
                stacklevel=2,
            )
        self.algorithm = algorithm
        self.issuer = "genxai"

    def create_token(
        self,
        user_id: str,
        role: str,
        permissions: list[str],
        expires_in: int = 3600
    ) -> str:
        """Create JWT token.

        Args:
            user_id: User ID (stored in the ``sub`` claim)
            role: User role
            permissions: List of permissions
            expires_in: Token expiration in seconds (default: 1 hour)

        Returns:
            JWT token string
        """
        from datetime import timezone

        # Use an aware UTC "now". Calling .timestamp() on the naive value from
        # datetime.utcnow() interprets it as *local* time, which skews exp/iat by
        # the host's UTC offset on any machine not running in UTC.
        now = datetime.now(timezone.utc)
        exp = now + timedelta(seconds=expires_in)

        payload = {
            "sub": user_id,
            "role": role,
            "permissions": permissions,
            "exp": int(exp.timestamp()),
            "iat": int(now.timestamp()),
            "iss": self.issuer
        }

        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Dict[str, Any]:
        """Verify and decode JWT token.

        The signature, issuer and expiry are checked, and the ``exp``, ``iat``
        and ``sub`` claims must be present.

        Args:
            token: JWT token string

        Returns:
            Decoded token payload

        Raises:
            ValueError: Token has expired, or is otherwise invalid (bad
                signature, wrong issuer, missing required claim, malformed)
        """
        try:
            return jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                issuer=self.issuer,
                # Without this, a validly signed token lacking "exp" would never expire.
                options={"require": ["exp", "iat", "sub"]},
            )
        except jwt.ExpiredSignatureError as e:
            raise ValueError("Token has expired") from e
        except jwt.InvalidTokenError as e:
            raise ValueError(f"Invalid token: {str(e)}") from e

    def refresh_token(self, token: str, expires_in: int = 3600) -> str:
        """Refresh JWT token.

        Args:
            token: Existing JWT token (must still verify)
            expires_in: New token expiration in seconds

        Returns:
            New JWT token carrying the same sub/role/permissions

        Raises:
            ValueError: The existing token fails verification
        """
        payload = self.verify_token(token)

        return self.create_token(
            user_id=payload["sub"],
            role=payload["role"],
            permissions=payload["permissions"],
            expires_in=expires_in
        )

    def decode_token_unsafe(self, token: str) -> Dict[str, Any]:
        """Decode token without verification (for debugging).

        The signature and claims are NOT checked; never use the result for
        authentication or authorization decisions.

        Args:
            token: JWT token string

        Returns:
            Decoded token payload
        """
        return jwt.decode(token, options={"verify_signature": False})
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Shared JWTManager, created lazily by get_jwt_manager().
_jwt_manager = None


def get_jwt_manager() -> JWTManager:
    """Return the shared JWTManager, building it with default settings on first use.

    Returns:
        JWTManager instance
    """
    global _jwt_manager
    if _jwt_manager is not None:
        return _jwt_manager
    _jwt_manager = JWTManager()
    return _jwt_manager
|