nia-mcp-server 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nia-mcp-server might be problematic. Click here for more details.
- nia_mcp_server/__init__.py +5 -0
- nia_mcp_server/__main__.py +11 -0
- nia_mcp_server/api_client.py +477 -0
- nia_mcp_server/server.py +804 -0
- nia_mcp_server-1.0.0.dist-info/METADATA +200 -0
- nia_mcp_server-1.0.0.dist-info/RECORD +9 -0
- nia_mcp_server-1.0.0.dist-info/WHEEL +4 -0
- nia_mcp_server-1.0.0.dist-info/entry_points.txt +2 -0
- nia_mcp_server-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NIA API Client for communicating with production NIA API
|
|
3
|
+
"""
|
|
4
|
+
import os
|
|
5
|
+
import httpx
|
|
6
|
+
import asyncio
|
|
7
|
+
from typing import Dict, Any, List, Optional, AsyncIterator
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
from urllib.parse import quote
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
class APIError(Exception):
    """Exception raised by the client when an API call fails.

    The human-readable message is passed to ``Exception``; the HTTP status
    code and the raw error detail are kept as attributes so callers can
    branch on them.
    """

    def __init__(
        self,
        message: str,
        status_code: Optional[int] = None,
        detail: Optional[str] = None,
    ):
        """Create the error.

        Args:
            message: Text used as the exception message.
            status_code: HTTP status code associated with the failure, or
                ``None`` when the error is not tied to an HTTP response.
            detail: Error detail taken from the response, or ``None``.
        """
        super().__init__(message)
        # HTTP status code of the failed response (None if not applicable).
        self.status_code = status_code
        # Error detail as reported by the server (None if unavailable).
        self.detail = detail
|
|
20
|
+
|
|
21
|
+
class NIAApiClient:
|
|
22
|
+
"""Client for interacting with NIA's production API."""
|
|
23
|
+
|
|
24
|
+
def __init__(self, api_key: str, base_url: str = None):
    """Store the credentials and build the shared HTTP client.

    Args:
        api_key: API key sent as a Bearer token on every request.
        base_url: Root URL of the API. When falsy, the ``NIA_API_URL``
            environment variable is used, falling back to
            ``https://api.trynia.ai``.
    """
    resolved_url = base_url or os.getenv("NIA_API_URL", "https://api.trynia.ai")
    default_headers = {
        "Authorization": f"Bearer {api_key}",
        "User-Agent": "nia-mcp-server/1.0.0",
        "Content-Type": "application/json"
    }

    self.api_key = api_key
    # Trailing slashes are stripped so that joined paths never contain "//".
    self.base_url = resolved_url.rstrip('/')
    # 300 seconds: long-running operations need a generous timeout.
    self.client = httpx.AsyncClient(headers=default_headers, timeout=300.0)
|
|
36
|
+
|
|
37
|
+
async def close(self):
    """Release the underlying HTTP client by awaiting its ``aclose()``."""
    http_client = self.client
    await http_client.aclose()
|
|
40
|
+
|
|
41
|
+
async def __aenter__(self):
    """Enter an ``async with`` block; the client itself is the bound value."""
    entered = self
    return entered
|
|
43
|
+
|
|
44
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave an ``async with`` block by closing the client.

    The exception arguments are ignored and nothing is returned, so an
    exception raised inside the block is not suppressed.
    """
    shutdown = self.close()
    await shutdown
|
|
46
|
+
|
|
47
|
+
def _handle_api_error(self, e: httpx.HTTPStatusError) -> APIError:
    """Build an ``APIError`` describing a failed HTTP response.

    The error detail is the ``detail`` field of a JSON-object body when
    present, otherwise the raw response text. The message depends on the
    status code (401, 403, 404, 429, 500, anything else). A 500 whose
    detail mentions a usage limit is reported with status code 403.

    Args:
        e: The ``httpx.HTTPStatusError`` raised by ``raise_for_status()``.

    Returns:
        The ``APIError`` to raise; this method never raises it itself.
    """
    response = e.response
    status_code = response.status_code

    try:
        error_detail = response.text
    except httpx.ResponseNotRead:
        # A streamed response whose body was never read has no text;
        # degrade to an empty detail instead of failing while reporting.
        error_detail = ""
    else:
        try:
            error_json = response.json()
        except (json.JSONDecodeError, ValueError):
            # Body is not JSON: keep the raw text.
            pass
        else:
            # Only a JSON object can carry a "detail" field.
            if isinstance(error_json, dict):
                detail = error_json.get("detail")
                if isinstance(detail, str):
                    error_detail = detail
                elif detail is not None:
                    # Structured detail (list/dict/number): serialise it so
                    # the string operations below keep working.
                    error_detail = json.dumps(detail)

    # Log the full error for debugging
    logger.error("API error - Status: %s, Response: %s", status_code, error_detail)

    # Phrases that identify a usage-limit error in the detail text.
    limit_phrases = ("lifetime limit", "free api requests", "25 free", "usage limit")
    error_lower = error_detail.lower()

    if status_code == 401:
        return APIError(
            "Invalid or missing API key. Please check your API key at https://trynia.ai/api-keys",
            status_code,
            error_detail,
        )

    if status_code == 403:
        if any(phrase in error_lower for phrase in limit_phrases + ("no chat credits",)):
            # Use the exact error message from the API for clarity
            return APIError(error_detail, status_code, error_detail)
        return APIError(f"Access forbidden: {error_detail}", status_code, error_detail)

    if status_code == 429:
        return APIError(f"Rate limit exceeded: {error_detail}", status_code, error_detail)

    if status_code == 404:
        return APIError(f"Resource not found: {error_detail}", status_code, error_detail)

    if status_code == 500:
        if not error_detail:
            return APIError(
                "Internal server error. Please try again later.",
                status_code,
                error_detail,
            )
        # A usage-limit error wrapped in a 500 is surfaced as a 403.
        if any(phrase in error_lower for phrase in limit_phrases):
            return APIError(error_detail, 403, error_detail)
        return APIError(f"Server error: {error_detail}", status_code, error_detail)

    return APIError(
        f"API error (status {status_code}): {error_detail}",
        status_code,
        error_detail,
    )
|
|
121
|
+
async def validate_api_key(self) -> bool:
    """Check the API key by requesting ``/v2/repositories``.

    Returns:
        ``True`` only when the response status is 200. Any other status,
        or any exception raised while sending the request, is logged and
        yields ``False``; this method does not raise.
    """
    try:
        response = await self.client.get(f"{self.base_url}/v2/repositories")
    except Exception as e:
        # Best-effort check: transport failures count as "not valid".
        logger.error(f"API key validation failed: {e}")
        return False

    if response.status_code == 200:
        return True

    # raise_for_status() is never called here, so a bad status arrives as
    # a normal response and must be reported explicitly.
    logger.error("API key validation failed: HTTP status %s", response.status_code)
    return False
|
|
134
|
+
|
|
135
|
+
async def list_repositories(self) -> List[Dict[str, Any]]:
    """Fetch ``/v2/repositories`` and return the decoded JSON body.

    Returns:
        The parsed JSON response.

    Raises:
        APIError: On an HTTP error status (built by ``_handle_api_error``)
            or on any other failure; the original exception is attached
            as ``__cause__``.
    """
    try:
        response = await self.client.get(f"{self.base_url}/v2/repositories")
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as e:
        logger.error(
            "Caught HTTPStatusError in list_repositories: status=%s, detail=%s",
            e.response.status_code,
            e.response.text,
        )
        raise self._handle_api_error(e) from e
    except Exception as e:
        logger.error("Failed to list repositories: %s", e)
        raise APIError(f"Failed to list repositories: {str(e)}") from e
|
|
147
|
+
|
|
148
|
+
async def index_repository(self, repo_url: str, branch: Optional[str] = None) -> Dict[str, Any]:
    """Submit a repository to ``POST /v2/repositories``.

    Args:
        repo_url: A GitHub URL or an ``owner/repo`` string. GitHub URLs are
            reduced to ``owner/repo``; other input is sent unchanged.
        branch: Branch name sent in the payload (``None`` is sent as JSON
            null).

    Returns:
        The parsed JSON response.

    Raises:
        APIError: On an HTTP error status, an unparseable GitHub URL, or
            any other failure; the original exception is attached as
            ``__cause__``.
    """
    try:
        payload = {
            "repository": self._parse_owner_repo(repo_url),
            "branch": branch
        }

        response = await self.client.post(
            f"{self.base_url}/v2/repositories",
            json=payload
        )
        response.raise_for_status()
        return response.json()

    except httpx.HTTPStatusError as e:
        raise self._handle_api_error(e) from e
    except Exception as e:
        raise APIError(f"Failed to index repository: {str(e)}") from e

@staticmethod
def _parse_owner_repo(repo_url: str) -> str:
    """Reduce a GitHub URL to ``owner/repo``; return other input unchanged."""
    _, marker, tail = repo_url.partition("github.com")
    if not marker or tail[:1] not in ("/", ":"):
        # No "github.com/" or "github.com:" host part: the caller is
        # assumed to have passed "owner/repo" (or another identifier).
        return repo_url

    # Drop query string and fragment, then keep the first two path
    # segments so ".../owner/repo/tree/main" still yields "owner/repo".
    path = tail.lstrip("/:").split("?", 1)[0].split("#", 1)[0]
    segments = [segment for segment in path.split("/") if segment]
    if len(segments) < 2:
        raise ValueError(f"cannot extract owner/repo from {repo_url!r}")

    owner, repo = segments[0], segments[1]
    if repo.endswith(".git"):
        # Clone URLs carry a ".git" suffix that is not part of the name.
        repo = repo[:-len(".git")]
    return f"{owner}/{repo}"
|
|
174
|
+
|
|
175
|
+
async def get_repository_status(self, owner_repo: str) -> Optional[Dict[str, Any]]:
    """Fetch the status record of one repository.

    Args:
        owner_repo: Either ``owner/repo`` (contains ``/``), which is first
            resolved to an ID through ``list_repositories()``, or a value
            without ``/`` that is used directly as the repository ID.

    Returns:
        The parsed JSON of ``GET /v2/repositories/{id}``, with a
        ``repository`` key added when the lookup was by ``owner/repo`` and
        the response lacks one. ``None`` when the repository or its ID is
        not found, on a 404, or on any non-HTTP-status failure (logged).

    Raises:
        httpx.HTTPStatusError: For HTTP error statuses other than 404 on
            the status request itself.
    """
    try:
        if '/' not in owner_repo:
            # Assume it's already a repository ID. Escape it so characters
            # such as "?" or "#" cannot alter the request target.
            repo_path = quote(owner_repo, safe="")
            response = await self.client.get(f"{self.base_url}/v2/repositories/{repo_path}")
            response.raise_for_status()
            return response.json()

        # owner/repo form: find the matching entry in the repository list.
        repos = await self.list_repositories()
        matching_repo = next(
            (repo for repo in repos if repo.get("repository") == owner_repo),
            None,
        )
        if not matching_repo:
            logger.warning(f"Repository {owner_repo} not found in list")
            return None

        repo_id = matching_repo.get("repository_id") or matching_repo.get("id")
        if not repo_id:
            logger.error(f"No repository ID found for {owner_repo}")
            return None

        # Now get the status using the ID
        response = await self.client.get(f"{self.base_url}/v2/repositories/{repo_id}")
        response.raise_for_status()

        status = response.json()
        # Ensure repository field is included for consistency
        if "repository" not in status:
            status["repository"] = owner_repo
        return status

    except httpx.HTTPStatusError as e:
        if e.response.status_code == 404:
            return None
        raise
    except Exception as e:
        logger.error(f"Failed to get repository status: {e}")
        return None
|
|
224
|
+
|
|
225
|
+
async def query_repositories(
    self,
    messages: List[Dict[str, str]],
    repositories: List[str],
    stream: bool = True,
    include_sources: bool = True
) -> AsyncIterator[str]:
    """Send a query to ``POST /v2/query`` for the given repositories.

    Args:
        messages: Messages forwarded unchanged in the payload.
        repositories: Repository identifiers; each is sent as
            ``{"repository": <value>}``.
        stream: When true, the response is consumed line by line.
        include_sources: Forwarded unchanged in the payload.

    Yields:
        With ``stream=True``: the text after the ``"data: "`` prefix of
        each such line, stopping at ``[DONE]``. Otherwise: one string, the
        JSON-encoded response body.

    Raises:
        APIError: On an HTTP error status or any other failure; the
            original exception is attached as ``__cause__``.
    """
    try:
        payload = {
            "messages": messages,
            "repositories": [{"repository": repo} for repo in repositories],
            "stream": stream,
            "include_sources": include_sources
        }
        query_url = f"{self.base_url}/v2/query"

        if stream:
            async with self.client.stream("POST", query_url, json=payload) as response:
                try:
                    response.raise_for_status()
                except httpx.HTTPStatusError:
                    # Load the body while the stream is still open so the
                    # error handler below can read the server's detail.
                    await response.aread()
                    raise

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[6:]  # Remove "data: " prefix
                    if data == "[DONE]":
                        break
                    yield data
        else:
            response = await self.client.post(query_url, json=payload)
            response.raise_for_status()
            yield json.dumps(response.json())

    except httpx.HTTPStatusError as e:
        raise self._handle_api_error(e) from e
    except Exception as e:
        raise APIError(f"Query failed: {str(e)}") from e
|
|
277
|
+
|
|
278
|
+
async def wait_for_indexing(
    self,
    owner_repo: str,
    timeout: int = 600,
    poll_interval: float = 2.0,
) -> Dict[str, Any]:
    """Poll ``get_repository_status`` until indexing completes or fails.

    Args:
        owner_repo: Identifier passed to ``get_repository_status``.
        timeout: Maximum time to keep polling, in seconds.
        poll_interval: Seconds to sleep between two status checks.

    Returns:
        The status record whose ``status`` field is ``"completed"``.

    Raises:
        Exception: When the repository is not found, when its status is
            ``"failed"``, or when ``timeout`` seconds have elapsed. A plain
            ``Exception`` is kept so existing callers' handlers still match.
    """
    # get_running_loop() is the supported call inside a coroutine.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout

    while True:
        status = await self.get_repository_status(owner_repo)
        if not status:
            raise Exception(f"Repository {owner_repo} not found")

        # A record without a "status" key is treated as still in progress.
        state = status.get("status")
        if state == "completed":
            return status
        if state == "failed":
            raise Exception(f"Indexing failed: {status.get('error', 'Unknown error')}")

        if loop.time() > deadline:
            raise Exception(f"Indexing timeout after {timeout} seconds")

        await asyncio.sleep(poll_interval)
|
|
299
|
+
|
|
300
|
+
async def delete_repository(self, owner_repo: str) -> bool:
    """Remove a repository via ``DELETE /v2/repositories/{id}``.

    A value without ``/`` is used directly as the ID. An ``owner/repo``
    value is resolved to an ID through ``get_repository_status`` and, if
    that record has no ID, through ``list_repositories``.

    Returns:
        ``True`` when the delete request succeeded; ``False`` when the
        repository or its ID could not be found or when any exception
        occurred (logged, never raised).
    """
    try:
        if '/' not in owner_repo:
            # No slash: the value is taken to be a repository ID already.
            response = await self.client.delete(f"{self.base_url}/v2/repositories/{owner_repo}")
            response.raise_for_status()
            return True

        status = await self.get_repository_status(owner_repo)
        if not status:
            logger.warning(f"Repository {owner_repo} not found")
            return False

        repo_id = status.get("repository_id") or status.get("id")
        if not repo_id:
            # The status record carried no ID; fall back to the listing.
            listing = await self.list_repositories()
            repo_id = next(
                (
                    entry.get("repository_id") or entry.get("id")
                    for entry in listing
                    if entry.get("repository") == owner_repo
                ),
                None,
            )

        if not repo_id:
            logger.error(f"No repository ID found for {owner_repo}")
            return False

        response = await self.client.delete(f"{self.base_url}/v2/repositories/{repo_id}")
        response.raise_for_status()
        return True

    except Exception as e:
        logger.error(f"Failed to delete repository: {e}")
        return False
|
|
338
|
+
|
|
339
|
+
# Data Source methods
|
|
340
|
+
|
|
341
|
+
async def create_data_source(
    self,
    url: str,
    url_patterns: Optional[List[str]] = None,
    max_age: Optional[int] = None,
    only_main_content: bool = True
) -> Dict[str, Any]:
    """Create a data source via ``POST /v2/data-sources``.

    Args:
        url: Sent as ``url`` in the payload.
        url_patterns: Sent as ``url_patterns``; a falsy value becomes ``[]``.
        max_age: Sent as ``max_age`` only when not ``None``.
        only_main_content: Sent as ``only_main_content`` unless ``None``.

    Returns:
        The parsed JSON response.

    Raises:
        APIError: On an HTTP error status or any other failure; the
            original exception is attached as ``__cause__``.
    """
    try:
        payload: Dict[str, Any] = {
            "url": url,
            "url_patterns": url_patterns or []
        }

        # Add optional parameters
        if max_age is not None:
            payload["max_age"] = max_age
        # Always use markdown for documentation scraping
        payload["formats"] = ["markdown"]
        if only_main_content is not None:
            payload["only_main_content"] = only_main_content

        response = await self.client.post(
            f"{self.base_url}/v2/data-sources",
            json=payload
        )
        response.raise_for_status()
        return response.json()

    except httpx.HTTPStatusError as e:
        raise self._handle_api_error(e) from e
    except Exception as e:
        raise APIError(f"Failed to create data source: {str(e)}") from e
|
|
374
|
+
|
|
375
|
+
async def list_data_sources(self) -> List[Dict[str, Any]]:
    """Fetch ``/v2/data-sources`` and return the decoded JSON body.

    Returns:
        The parsed JSON response.

    Raises:
        APIError: On an HTTP error status (built by ``_handle_api_error``)
            or on any other failure; the original exception is attached
            as ``__cause__``.
    """
    try:
        response = await self.client.get(f"{self.base_url}/v2/data-sources")
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as e:
        raise self._handle_api_error(e) from e
    except Exception as e:
        logger.error("Failed to list data sources: %s", e)
        raise APIError(f"Failed to list data sources: {str(e)}") from e
|
|
386
|
+
|
|
387
|
+
async def get_data_source_status(self, source_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one data source via ``GET /v2/data-sources/{source_id}``.

    Args:
        source_id: Identifier placed in the URL path (percent-encoded).

    Returns:
        The parsed JSON response, or ``None`` on a 404 or on any
        non-HTTP-status failure (which is logged).

    Raises:
        httpx.HTTPStatusError: For HTTP error statuses other than 404.
    """
    try:
        # Escape the ID so "/", "?" or "#" cannot redirect the request to
        # a different path.
        source_path = quote(str(source_id), safe="")
        response = await self.client.get(f"{self.base_url}/v2/data-sources/{source_path}")
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 404:
            return None
        raise
    except Exception as e:
        logger.error(f"Failed to get data source status: {e}")
        return None
|
|
400
|
+
|
|
401
|
+
async def delete_data_source(self, source_id: str) -> bool:
    """Delete one data source via ``DELETE /v2/data-sources/{source_id}``.

    Args:
        source_id: Identifier placed in the URL path (percent-encoded).

    Returns:
        ``True`` when the request succeeded; ``False`` when any exception
        occurred (logged, never raised).
    """
    try:
        # Escape the ID so a value containing "/", "?" or "#" cannot make
        # this DELETE hit a different path.
        source_path = quote(str(source_id), safe="")
        response = await self.client.delete(f"{self.base_url}/v2/data-sources/{source_path}")
        response.raise_for_status()
        return True
    except Exception as e:
        logger.error(f"Failed to delete data source: {e}")
        return False
|
|
410
|
+
|
|
411
|
+
async def query_unified(
    self,
    messages: List[Dict[str, str]],
    repositories: Optional[List[str]] = None,
    data_sources: Optional[List[str]] = None,
    search_mode: str = "unified",
    stream: bool = True,
    include_sources: bool = True
) -> AsyncIterator[str]:
    """Send a query to ``POST /v2/query`` over repositories and/or data sources.

    Args:
        messages: Messages forwarded unchanged in the payload.
        repositories: Repository identifiers; each is sent as
            ``{"repository": <value>}``.
        data_sources: Data source entries. A dict entry is sent as-is;
            any other entry is sent as ``{"source_id": <value>}``.
        search_mode: Forwarded unchanged in the payload.
        stream: When true, the response is consumed line by line.
        include_sources: Forwarded unchanged in the payload.

    Yields:
        With ``stream=True``: the text after the ``"data: "`` prefix of
        each such line, stopping at ``[DONE]``. Otherwise: one string, the
        JSON-encoded response body.

    Raises:
        APIError: When neither repositories nor data sources are given, on
            an HTTP error status, or on any other failure; the original
            exception is attached as ``__cause__``.
    """
    try:
        repo_list = [{"repository": repo} for repo in repositories or []]

        # Handle both list of IDs and list of dicts
        source_list = [
            entry if isinstance(entry, dict) else {"source_id": entry}
            for entry in data_sources or []
        ]

        # Validate at least one source. Raised inside the try so it is
        # reported as an APIError with the "Query failed: " prefix.
        if not repo_list and not source_list:
            raise ValueError("No repositories or data sources specified")

        payload = {
            "messages": messages,
            "repositories": repo_list,
            "data_sources": source_list,
            "search_mode": search_mode,
            "stream": stream,
            "include_sources": include_sources
        }
        query_url = f"{self.base_url}/v2/query"

        if stream:
            async with self.client.stream("POST", query_url, json=payload) as response:
                try:
                    response.raise_for_status()
                except httpx.HTTPStatusError:
                    # Load the body while the stream is still open so the
                    # error handler below can read the server's detail.
                    await response.aread()
                    raise

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[6:]  # Remove "data: " prefix
                    if data == "[DONE]":
                        break
                    yield data
        else:
            response = await self.client.post(query_url, json=payload)
            response.raise_for_status()
            yield json.dumps(response.json())

    except httpx.HTTPStatusError as e:
        raise self._handle_api_error(e) from e
    except Exception as e:
        raise APIError(f"Query failed: {str(e)}") from e
|