pylance-mcp-server 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +213 -0
- package/bin/pylance-mcp.js +68 -0
- package/mcp_server/__init__.py +13 -0
- package/mcp_server/__pycache__/__init__.cpython-312.pyc +0 -0
- package/mcp_server/__pycache__/__init__.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/__init__.cpython-314.pyc +0 -0
- package/mcp_server/__pycache__/ai_features.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/api_routes.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/auth.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/cloud_sync.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/logging_db.cpython-312.pyc +0 -0
- package/mcp_server/__pycache__/logging_db.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/pylance_bridge.cpython-312.pyc +0 -0
- package/mcp_server/__pycache__/pylance_bridge.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/pylance_bridge.cpython-314.pyc +0 -0
- package/mcp_server/__pycache__/resources.cpython-312.pyc +0 -0
- package/mcp_server/__pycache__/resources.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/tools.cpython-312.pyc +0 -0
- package/mcp_server/__pycache__/tools.cpython-313.pyc +0 -0
- package/mcp_server/__pycache__/tracing.cpython-313.pyc +0 -0
- package/mcp_server/ai_features.py +274 -0
- package/mcp_server/api_routes.py +429 -0
- package/mcp_server/auth.py +275 -0
- package/mcp_server/cloud_sync.py +427 -0
- package/mcp_server/logging_db.py +403 -0
- package/mcp_server/pylance_bridge.py +579 -0
- package/mcp_server/resources.py +174 -0
- package/mcp_server/tools.py +642 -0
- package/mcp_server/tracing.py +84 -0
- package/package.json +53 -0
- package/requirements.txt +29 -0
- package/scripts/check-python.js +57 -0
- package/server.py +1228 -0
package/mcp_server/api_routes.py
@@ -0,0 +1,429 @@

"""
REST API Routes

HTTP API endpoints wrapping MCP tools for web/mobile clients.
"""

import logging
from typing import List, Dict, Any, Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field

from mcp_server.auth import verify_api_key_header, auth_service
from mcp_server import ai_features

logger = logging.getLogger(__name__)

# Create API router
api_router = APIRouter(prefix="/api/v1", tags=["API"])


# ==================== REQUEST/RESPONSE MODELS ====================

class AnalyzeCodeRequest(BaseModel):
    """Request model for code analysis."""
    file_path: str = Field(..., description="Path to Python file (relative to workspace)")
    content: str = Field(..., description="Python code content to analyze")
    include_type_check: bool = Field(True, description="Run full type checking")


class AnalyzeCodeResponse(BaseModel):
    """Response model for code analysis."""
    diagnostics: List[Dict[str, Any]]
    summary: Dict[str, Any]
    file_path: str


class CompletionRequest(BaseModel):
    """Request model for code completions."""
    file_path: str
    content: str
    line: int = Field(..., description="Line number (0-indexed)")
    character: int = Field(..., description="Character position (0-indexed)")


class CompletionResponse(BaseModel):
    """Response model for completions."""
    completions: List[Dict[str, Any]]
    count: int


class HoverRequest(BaseModel):
    """Request model for hover information."""
    file_path: str
    content: str
    line: int
    character: int


class HoverResponse(BaseModel):
    """Response model for hover info."""
    documentation: str
    type_info: Optional[str] = None


class DiagnosticsRequest(BaseModel):
    """Request model for diagnostics."""
    file_path: str
    content: str


class DiagnosticsResponse(BaseModel):
    """Response model for diagnostics."""
    errors: List[Dict[str, Any]]
    warnings: List[Dict[str, Any]]
    information: List[Dict[str, Any]]
    summary: Dict[str, int]


# ==================== API ENDPOINTS ====================

@api_router.post("/analyze", response_model=AnalyzeCodeResponse)
async def analyze_code(
    request: AnalyzeCodeRequest,
    user: dict = Depends(verify_api_key_header)
):
    """
    Analyze Python code for errors, warnings, and type issues.

    Returns comprehensive diagnostics with severity levels.
    """
    try:
        from server import tools

        if not tools:
            raise HTTPException(status_code=503, detail="Server not initialized")

        # Get diagnostics
        diagnostics = tools.get_diagnostics(request.file_path, request.content)

        # If type checking requested, add full type check
        if request.include_type_check:
            type_check_result = tools.type_check(
                file_path=request.file_path,
                content=request.content
            )
            diagnostics.extend(type_check_result.get("diagnostics", []))

        # Summarize by severity
        summary = {
            "total_errors": sum(1 for d in diagnostics if d.get("severity") == "error"),
            "total_warnings": sum(1 for d in diagnostics if d.get("severity") == "warning"),
            "total_info": sum(1 for d in diagnostics if d.get("severity") in ["information", "hint"])
        }

        # Log usage
        await auth_service.log_usage(user["api_key"], "analyze", tokens=len(request.content))

        return AnalyzeCodeResponse(
            diagnostics=diagnostics,
            summary=summary,
            file_path=request.file_path
        )

    except Exception as e:
        logger.error(f"Analyze error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@api_router.post("/complete", response_model=CompletionResponse)
async def get_completions(
    request: CompletionRequest,
    user: dict = Depends(verify_api_key_header)
):
    """
    Get code completions at a specific position.

    Returns intelligent suggestions with type information.
    """
    try:
        from server import tools

        if not tools:
            raise HTTPException(status_code=503, detail="Server not initialized")

        completions = tools.get_completions(
            file_path=request.file_path,
            line=request.line,
            character=request.character,
            content=request.content
        )

        # Log usage
        await auth_service.log_usage(user["api_key"], "complete", tokens=len(request.content))

        return CompletionResponse(
            completions=completions,
            count=len(completions)
        )

    except Exception as e:
        logger.error(f"Completion error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@api_router.post("/hover", response_model=HoverResponse)
async def get_hover_info(
    request: HoverRequest,
    user: dict = Depends(verify_api_key_header)
):
    """
    Get type information and documentation for a symbol.

    Returns hover information like in VS Code.
    """
    try:
        from server import tools

        if not tools:
            raise HTTPException(status_code=503, detail="Server not initialized")

        hover_text = tools.get_hover(
            file_path=request.file_path,
            line=request.line,
            character=request.character,
            content=request.content
        )

        # Log usage
        await auth_service.log_usage(user["api_key"], "hover")

        return HoverResponse(
            documentation=hover_text,
            type_info=hover_text  # Can be parsed further
        )

    except Exception as e:
        logger.error(f"Hover error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@api_router.post("/diagnostics", response_model=DiagnosticsResponse)
async def get_diagnostics(
    request: DiagnosticsRequest,
    user: dict = Depends(verify_api_key_header)
):
    """
    Get all diagnostics (errors, warnings, info) for a file.

    Returns categorized diagnostics with line numbers.
    """
    try:
        from server import tools

        if not tools:
            raise HTTPException(status_code=503, detail="Server not initialized")

        all_diagnostics = tools.get_diagnostics(request.file_path, request.content)

        # Categorize by severity
        errors = [d for d in all_diagnostics if d.get("severity") == "error"]
        warnings = [d for d in all_diagnostics if d.get("severity") == "warning"]
        information = [d for d in all_diagnostics if d.get("severity") in ["information", "hint"]]

        summary = {
            "errors": len(errors),
            "warnings": len(warnings),
            "information": len(information)
        }

        # Log usage
        await auth_service.log_usage(user["api_key"], "diagnostics", tokens=len(request.content))

        return DiagnosticsResponse(
            errors=errors,
            warnings=warnings,
            information=information,
            summary=summary
        )

    except Exception as e:
        logger.error(f"Diagnostics error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@api_router.get("/usage")
async def get_usage_stats(user: dict = Depends(verify_api_key_header)):
    """
    Get API usage statistics for the current user.

    Returns request counts, rate limit status, and billing info.
    """
    try:
        # TODO: Query Supabase usage_logs table
        # For now, return mock data
        return {
            "user_id": user["user_id"],
            "tier": user["tier"],
            "current_period": {
                "requests_made": 150,
                "requests_limit": 5000 if user["tier"] == "pro" else 20,
                "tokens_used": 45000,
                "reset_at": "2025-12-19T00:00:00Z"
            },
            "rate_limit": {
                "hourly_limit": 5000 if user["tier"] == "pro" else 20,
                "daily_limit": 50000 if user["tier"] == "pro" else 100
            }
        }

    except Exception as e:
        logger.error(f"Usage stats error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# ==================== AI FEATURES (PRO TIER ONLY) ====================

class ExplainErrorRequest(BaseModel):
    """Request model for AI type error explanations."""
    code: str
    error_message: str
    file_path: str
    line_number: int


class ExplainErrorResponse(BaseModel):
    """Response model for error explanations."""
    explanation: str
    model: str
    tokens_used: int


class EnhanceTypesRequest(BaseModel):
    """Request model for AI type enhancement."""
    code: str
    file_path: str
    existing_types: Optional[List[Dict[str, Any]]] = []


class EnhanceTypesResponse(BaseModel):
    """Response model for type enhancements."""
    suggestions: str
    model: str
    tokens_used: int


@api_router.post("/explain-error", response_model=ExplainErrorResponse)
async def explain_type_error(
    request: ExplainErrorRequest,
    user: dict = Depends(verify_api_key_header)
):
    """
    AI-powered type error explanation (PRO tier only).

    Uses Claude to explain type errors in beginner-friendly language with fix suggestions.
    """
    try:
        # Check PRO tier
        if user["tier"] != "pro":
            raise HTTPException(
                status_code=403,
                detail="AI features require PRO tier. Upgrade at pylancemcp.com/pricing"
            )

        result = await ai_features.explain_type_error(
            code=request.code,
            error_message=request.error_message,
            file_path=request.file_path,
            line_number=request.line_number
        )

        if not result["success"]:
            raise HTTPException(status_code=500, detail=result.get("error", "AI processing failed"))

        # Log usage
        await auth_service.log_usage(
            user["api_key"],
            "explain-error",
            tokens=result["tokens_used"]
        )

        return ExplainErrorResponse(
            explanation=result["explanation"],
            model=result["model"],
            tokens_used=result["tokens_used"]
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Explain error failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@api_router.post("/enhance-types", response_model=EnhanceTypesResponse)
async def enhance_type_annotations(
    request: EnhanceTypesRequest,
    user: dict = Depends(verify_api_key_header)
):
    """
    AI-powered type annotation suggestions (PRO tier only).

    Uses Claude to suggest improved type annotations for Python code.
    """
    try:
        # Check PRO tier
        if user["tier"] != "pro":
            raise HTTPException(
                status_code=403,
                detail="AI features require PRO tier. Upgrade at pylancemcp.com/pricing"
            )

        result = await ai_features.enhance_type_annotations(
            code=request.code,
            file_path=request.file_path,
            existing_types=request.existing_types or []
        )

        if not result["success"]:
            raise HTTPException(status_code=500, detail=result.get("error", "AI processing failed"))

        # Log usage
        await auth_service.log_usage(
            user["api_key"],
            "enhance-types",
            tokens=result["tokens_used"]
        )

        return EnhanceTypesResponse(
            suggestions=result["suggestions"],
            model=result["model"],
            tokens_used=result["tokens_used"]
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Enhance types failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@api_router.post("/analyze-patterns")
async def analyze_code_patterns_endpoint(
    code: str,
    diagnostics: List[Dict[str, Any]],
    user: dict = Depends(verify_api_key_header)
):
    """
    AI-powered code pattern analysis (PRO tier only).

    Identifies anti-patterns and suggests improvements.
    """
    try:
        if user["tier"] != "pro":
            raise HTTPException(status_code=403, detail="AI features require PRO tier")

        result = await ai_features.analyze_code_patterns(code, diagnostics)

        if not result["success"]:
            raise HTTPException(status_code=500, detail=result.get("error"))

        await auth_service.log_usage(user["api_key"], "analyze-patterns", tokens=result["tokens_used"])

        return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Pattern analysis failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
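For orientation, the sketch below shows what a client call against the `/api/v1/analyze` route defined above might look like. It is illustrative only: the base URL is an assumption (the actual host and port are configured in `package/server.py`, which is not shown in this diff), the `X-API-Key` header name follows the `x_api_key` header dependency defined in `mcp_server/auth.py`, and keys prefixed with `test_` are accepted there without an Appwrite lookup.

```python
# Illustrative only: base URL and key value are assumptions, not part of the package.
import requests

resp = requests.post(
    "http://localhost:8000/api/v1/analyze",  # assumed local deployment
    headers={"X-API-Key": "test_demo_key"},  # "test_"-prefixed keys skip the Appwrite lookup
    json={
        "file_path": "example.py",
        "content": "def add(a, b):\n    return a + b\n",
        "include_type_check": True,
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["summary"])  # e.g. {"total_errors": 0, "total_warnings": 0, "total_info": 0}
```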
package/mcp_server/auth.py
@@ -0,0 +1,275 @@

"""
Authentication Middleware

Handles API key verification, subscription tier checks, and rate limiting.
"""

import logging
import os
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from fastapi import Header, HTTPException, Request
from collections import defaultdict
import asyncio
import hashlib
import requests

logger = logging.getLogger(__name__)

# In-memory rate limiting (use Redis in production)
rate_limit_store: Dict[str, list] = defaultdict(list)


class AuthService:
    """Authentication and authorization service."""

    def __init__(self) -> None:
        """Initialize auth service."""
        self.appwrite_endpoint = os.getenv("APPWRITE_ENDPOINT") or os.getenv("VITE_APPWRITE_ENDPOINT")
        self.appwrite_project_id = os.getenv("APPWRITE_PROJECT_ID") or os.getenv("VITE_APPWRITE_PROJECT_ID")
        self.appwrite_api_key = os.getenv("APPWRITE_API_KEY")
        self.appwrite_database_id = os.getenv("APPWRITE_DATABASE_ID") or os.getenv("VITE_APPWRITE_DATABASE_ID")
        self.appwrite_user_profiles_collection_id = (
            os.getenv("APPWRITE_USER_PROFILES_COLLECTION_ID")
            or os.getenv("VITE_APPWRITE_USERS_COLLECTION_ID")
        )
        self.appwrite_api_keys_collection_id = (
            os.getenv("APPWRITE_API_KEYS_COLLECTION_ID")
            or os.getenv("VITE_APPWRITE_API_KEYS_COLLECTION_ID")
        )

        # Simple in-memory cache to reduce Appwrite calls: api_key_hash -> (expires_at_utc, user_info)
        self._cache: Dict[str, Any] = {}
        self._cache_ttl_seconds = int(os.getenv("AUTH_CACHE_TTL_SECONDS", "60"))

        # Rate limits per tier
        self.rate_limits = {
            "free": {"requests_per_hour": 20, "requests_per_day": 100},
            "pro": {"requests_per_hour": 5000, "requests_per_day": 50000},
        }

    def _appwrite_headers(self) -> Dict[str, str]:
        if not self.appwrite_project_id or not self.appwrite_api_key:
            raise HTTPException(status_code=500, detail="Appwrite auth is not configured")

        return {
            "X-Appwrite-Project": self.appwrite_project_id,
            "X-Appwrite-Key": self.appwrite_api_key,
            "Content-Type": "application/json",
        }

    def _require_appwrite_config(self) -> None:
        missing = []
        if not self.appwrite_endpoint:
            missing.append("APPWRITE_ENDPOINT")
        if not self.appwrite_project_id:
            missing.append("APPWRITE_PROJECT_ID")
        if not self.appwrite_api_key:
            missing.append("APPWRITE_API_KEY")
        if not self.appwrite_database_id:
            missing.append("APPWRITE_DATABASE_ID")
        if not self.appwrite_user_profiles_collection_id:
            missing.append("APPWRITE_USER_PROFILES_COLLECTION_ID")
        if not self.appwrite_api_keys_collection_id:
            missing.append("APPWRITE_API_KEYS_COLLECTION_ID")

        if missing:
            raise HTTPException(
                status_code=500,
                detail=f"Appwrite configuration missing: {', '.join(missing)}",
            )

    def _sha256_hex(self, value: str) -> str:
        return hashlib.sha256(value.encode("utf-8")).hexdigest()

    def _appwrite_query_equal(self, field: str, value: Any) -> str:
        # Appwrite expects query strings like:
        # - equal("field", ["value"]) for strings
        # - equal("active", [true]) for booleans
        if isinstance(value, bool):
            return f'equal("{field}", [{"true" if value else "false"}])'

        escaped = str(value).replace('"', '\\"')
        return f'equal("{field}", ["{escaped}"])'

    def _get(self, url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        headers = self._appwrite_headers()
        resp = requests.get(url, headers=headers, params=params, timeout=10)

        if resp.status_code == 400:
            logger.error("Appwrite request failed (400): %s", resp.text)
            raise HTTPException(status_code=500, detail="Auth backend misconfigured")

        if resp.status_code in (401, 403, 404):
            raise HTTPException(status_code=401, detail="Invalid API key")

        if resp.status_code >= 500:
            logger.error("Appwrite request failed (%s): %s", resp.status_code, resp.text)
            raise HTTPException(status_code=503, detail="Auth backend unavailable")

        if resp.status_code >= 400:
            logger.error("Appwrite request failed (%s): %s", resp.status_code, resp.text)
            raise HTTPException(status_code=500, detail="Auth backend error")
        return resp.json()

    def _get_api_key_record_by_hash(self, api_key_hash: str) -> Optional[Dict[str, Any]]:
        base = self.appwrite_endpoint.rstrip("/")
        url = (
            f"{base}/databases/{self.appwrite_database_id}/collections/{self.appwrite_api_keys_collection_id}/documents"
        )

        params = {
            "queries[]": [
                self._appwrite_query_equal("keyHash", api_key_hash),
                self._appwrite_query_equal("active", True),
                "limit(1)",
            ]
        }

        data = self._get(url, params=params)
        docs = data.get("documents", [])
        if not docs:
            return None
        return docs[0]

    def _get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
        base = self.appwrite_endpoint.rstrip("/")
        url = (
            f"{base}/databases/{self.appwrite_database_id}/collections/{self.appwrite_user_profiles_collection_id}/documents/{user_id}"
        )
        try:
            return self._get(url)
        except HTTPException:
            return None

    async def verify_api_key(self, api_key: str) -> Dict[str, Any]:
        """
        Verify API key and return user info with subscription tier.

        Args:
            api_key: API key from request header

        Returns:
            Dictionary with user_id, email, tier, and permissions

        Raises:
            HTTPException: If API key is invalid or expired
        """
        if not api_key:
            raise HTTPException(status_code=401, detail="API key required")

        # Allow test keys
        if api_key.startswith("test_"):
            return {
                "user_id": "test_user",
                "email": "test@example.com",
                "tier": "pro",
                "api_key": api_key,
                "active": True
            }

        self._require_appwrite_config()

        api_key_hash = self._sha256_hex(api_key)
        now = datetime.utcnow()

        cached = self._cache.get(api_key_hash)
        if cached:
            expires_at, user_info = cached
            if expires_at > now:
                return user_info
            self._cache.pop(api_key_hash, None)

        def lookup() -> Dict[str, Any]:
            record = self._get_api_key_record_by_hash(api_key_hash)
            if not record:
                raise HTTPException(status_code=401, detail="Invalid API key")

            user_id = record.get("userId") or record.get("user_id")
            if not user_id:
                raise HTTPException(status_code=401, detail="Invalid API key")

            profile = self._get_user_profile(user_id) or {}
            tier = profile.get("tier") or "free"
            email = profile.get("email") or ""

            user_info = {
                "user_id": user_id,
                "email": email,
                "tier": tier,
                "api_key": api_key,
                "active": True,
            }

            return user_info

        user_info = await asyncio.to_thread(lookup)
        self._cache[api_key_hash] = (now + timedelta(seconds=self._cache_ttl_seconds), user_info)
        return user_info

    async def check_rate_limit(self, api_key: str, tier: str) -> bool:
        """
        Check if request is within rate limits for the tier.

        Args:
            api_key: User's API key
            tier: Subscription tier (free/pro)

        Returns:
            True if within limits, raises HTTPException otherwise
        """
        now = datetime.utcnow()
        hour_ago = now - timedelta(hours=1)

        # Clean old entries
        rate_limit_store[api_key] = [
            ts for ts in rate_limit_store[api_key] if ts > hour_ago
        ]

        # Check limit
        current_count = len(rate_limit_store[api_key])
        limit = self.rate_limits.get(tier, self.rate_limits["free"])["requests_per_hour"]

        if current_count >= limit:
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Limit: {limit} requests/hour. "
                       f"Upgrade to Pro for higher limits."
            )

        # Add current request
        rate_limit_store[api_key].append(now)
        return True

    async def log_usage(self, api_key: str, endpoint: str, tokens: int = 0) -> None:
        """
        Log API usage for analytics and billing.

        Args:
            api_key: User's API key
            endpoint: API endpoint called
            tokens: Number of tokens processed (if applicable)
        """
        # TODO: Insert into Supabase usage_logs table
        logger.info(f"Usage: {api_key[:8]}... -> {endpoint} ({tokens} tokens)")


# Global auth service instance
auth_service = AuthService()


async def verify_api_key_header(
    x_api_key: str = Header(..., description="Your API key")
) -> Dict[str, Any]:
    """
    FastAPI dependency for API key verification.

    Usage:
        @app.post("/api/v1/analyze")
        async def analyze(user: dict = Depends(verify_api_key_header)):
            # user has: user_id, email, tier, api_key, active
            pass
    """
    user = await auth_service.verify_api_key(x_api_key)
    await auth_service.check_rate_limit(x_api_key, user["tier"])
    return user
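Finally, a minimal sketch of how `api_router` and the `verify_api_key_header` dependency above would typically be mounted. The actual wiring lives in `package/server.py` (+1228 lines, not included in this excerpt), so the app title and run command below are assumptions rather than the package's real entry point.

```python
# Sketch only; the real entry point is package/server.py, which this excerpt does not show.
from fastapi import FastAPI
from mcp_server.api_routes import api_router  # every route already depends on verify_api_key_header

app = FastAPI(title="pylance-mcp-server REST API")  # title is an assumption
app.include_router(api_router)  # exposes POST /api/v1/analyze, /complete, /hover, /diagnostics, ...

# Assumed invocation (module and app names may differ in the real package):
#   uvicorn server:app --host 0.0.0.0 --port 8000
```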