kailash 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1560 @@
+ """MCP server resource subscription implementation."""
+
+ import asyncio
+ import fnmatch
+ import hashlib
+ import json
+ import logging
+ import re
+ import uuid
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from datetime import datetime, timedelta
+ from typing import Any, Callable, Dict, List, Optional, Set, Union
+
+ # Optional Redis support
+ try:
+     import redis.asyncio as redis
+
+     REDIS_AVAILABLE = True
+ except ImportError:
+     redis = None
+     REDIS_AVAILABLE = False
+
+ from .auth import AuthManager
+ from .auth import PermissionError as PermissionDeniedError
+ from .protocol import ResourceChange, ResourceChangeType
+
+ logger = logging.getLogger(__name__)
+
+
+ class SubscriptionError(Exception):
+     """Raised when subscription operations fail."""
+
+     pass
+
+
+ class TransformationError(Exception):
+     """Raised when resource transformation fails."""
+
+     pass
+
+
+ class ResourceTransformer(ABC):
+     """Abstract base class for resource transformations."""
+
+     @abstractmethod
+     async def transform(
+         self, resource_data: Dict[str, Any], context: Dict[str, Any]
+     ) -> Dict[str, Any]:
+         """Transform resource data.
+
+         Args:
+             resource_data: The original resource data
+             context: Additional context (subscription info, user info, etc.)
+
+         Returns:
+             Transformed resource data
+         """
+         pass
+
+     @abstractmethod
+     def should_apply(self, uri: str, subscription: "ResourceSubscription") -> bool:
+         """Determine if this transformer should be applied.
+
+         Args:
+             uri: Resource URI
+             subscription: The subscription requesting the resource
+
+         Returns:
+             True if transformer should be applied
+         """
+         pass
+
+
+ class DataEnrichmentTransformer(ResourceTransformer):
+     """Transformer that adds computed fields and metadata."""
+
+     def __init__(
+         self, enrichment_functions: Dict[str, Callable[[Dict[str, Any]], Any]] = None
+     ):
+         self.enrichment_functions = enrichment_functions or {}
+
+     def add_enrichment(
+         self, field_name: str, function: Callable[[Dict[str, Any]], Any]
+     ):
+         """Add an enrichment function for a specific field."""
+         self.enrichment_functions[field_name] = function
+
+     async def transform(
+         self, resource_data: Dict[str, Any], context: Dict[str, Any]
+     ) -> Dict[str, Any]:
+         """Add enriched fields to resource data."""
+         enriched_data = resource_data.copy()
+
+         # Add computed fields
+         for field_name, function in self.enrichment_functions.items():
+             try:
+                 if asyncio.iscoroutinefunction(function):
+                     enriched_data[field_name] = await function(resource_data)
+                 else:
+                     enriched_data[field_name] = function(resource_data)
+             except Exception:
+                 # Log the error but continue processing remaining fields
+                 logger.warning(
+                     "Enrichment failed for field %r", field_name, exc_info=True
+                 )
+
+         # Add transformation metadata
+         enriched_data["__transformation"] = {
+             "timestamp": datetime.utcnow().isoformat(),
+             "enriched_fields": list(self.enrichment_functions.keys()),
+             "transformer": "DataEnrichmentTransformer",
+         }
+
+         return enriched_data
+
+     def should_apply(self, uri: str, subscription: "ResourceSubscription") -> bool:
+         """Apply to all resources that have enrichment functions."""
+         return len(self.enrichment_functions) > 0
+
+
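+ # A minimal usage sketch for DataEnrichmentTransformer: the "word_count"
+ # enrichment below is hypothetical and only illustrates the expected
+ # Callable[[Dict[str, Any]], Any] shape; transform() accepts both sync and
+ # async callables.
+ async def _demo_data_enrichment() -> Dict[str, Any]:
+     enricher = DataEnrichmentTransformer()
+     enricher.add_enrichment(
+         "word_count", lambda data: len(str(data.get("content", "")).split())
+     )
+     # Returns the input data plus "word_count" and "__transformation" metadata.
+     return await enricher.transform({"content": "hello world"}, context={})
+
+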
+ class FormatConverterTransformer(ResourceTransformer):
+     """Transformer that converts between data formats."""
+
+     def __init__(self, conversions: Dict[str, Callable[[Any], Any]] = None):
+         self.conversions = conversions or {}
+
+     def add_conversion(self, field_pattern: str, converter: Callable[[Any], Any]):
+         """Add a conversion function for fields matching a pattern."""
+         self.conversions[field_pattern] = converter
+
+     async def transform(
+         self, resource_data: Dict[str, Any], context: Dict[str, Any]
+     ) -> Dict[str, Any]:
+         """Apply format conversions to matching fields."""
+         converted_data = await self._apply_conversions(resource_data, "")
+
+         # Add transformation metadata
+         converted_data["__transformation"] = {
+             "timestamp": datetime.utcnow().isoformat(),
+             "conversions_applied": list(self.conversions.keys()),
+             "transformer": "FormatConverterTransformer",
+         }
+
+         return converted_data
+
+     async def _apply_conversions(self, data: Any, path: str) -> Any:
+         """Recursively apply conversions to nested data."""
+         if isinstance(data, dict):
+             result = {}
+             for key, value in data.items():
+                 field_path = f"{path}.{key}" if path else key
+
+                 # Check if this field matches any conversion pattern
+                 converted_value = value
+                 for pattern, converter in self.conversions.items():
+                     if fnmatch.fnmatch(field_path, pattern):
+                         try:
+                             if asyncio.iscoroutinefunction(converter):
+                                 converted_value = await converter(value)
+                             else:
+                                 converted_value = converter(value)
+                             break
+                         except Exception:
+                             # Keep the original value on conversion error
+                             logger.debug(
+                                 "Conversion failed for %s", field_path, exc_info=True
+                             )
+
+                 # Recursively process nested objects
+                 result[key] = await self._apply_conversions(converted_value, field_path)
+             return result
+         elif isinstance(data, list):
+             return [
+                 await self._apply_conversions(item, f"{path}[{i}]")
+                 for i, item in enumerate(data)
+             ]
+         else:
+             return data
+
+     def should_apply(self, uri: str, subscription: "ResourceSubscription") -> bool:
+         """Apply to resources that have format conversions defined."""
+         return len(self.conversions) > 0
+
+
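+ # Sketch of FormatConverterTransformer usage: patterns are fnmatch-style and
+ # matched against dotted field paths, so "metadata.*" below targets every
+ # immediate child of "metadata". The pattern and payload are illustrative.
+ async def _demo_format_conversion() -> Dict[str, Any]:
+     converter = FormatConverterTransformer()
+     converter.add_conversion("metadata.*", str)  # stringify all metadata fields
+     return await converter.transform(
+         {"metadata": {"size": 1024, "version": 2}}, context={}
+     )
+
+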
+ class AggregationTransformer(ResourceTransformer):
+     """Transformer that aggregates data from multiple sources."""
+
+     def __init__(self, data_sources: Dict[str, Callable[[str], Any]] = None):
+         self.data_sources = data_sources or {}
+
+     def add_data_source(self, source_name: str, fetcher: Callable[[str], Any]):
+         """Add a data source for aggregation."""
+         self.data_sources[source_name] = fetcher
+
+     async def transform(
+         self, resource_data: Dict[str, Any], context: Dict[str, Any]
+     ) -> Dict[str, Any]:
+         """Aggregate data from multiple sources."""
+         aggregated_data = resource_data.copy()
+
+         # Fetch data from additional sources
+         uri = resource_data.get("uri", "")
+         aggregated_sources = {}
+
+         for source_name, fetcher in self.data_sources.items():
+             try:
+                 if asyncio.iscoroutinefunction(fetcher):
+                     source_data = await fetcher(uri)
+                 else:
+                     source_data = fetcher(uri)
+                 aggregated_sources[source_name] = source_data
+             except Exception as e:
+                 # Record the error but continue with other sources
+                 aggregated_sources[source_name] = {"error": str(e)}
+
+         # Add aggregated data
+         aggregated_data["__aggregated"] = aggregated_sources
+
+         # Add transformation metadata
+         aggregated_data["__transformation"] = {
+             "timestamp": datetime.utcnow().isoformat(),
+             "sources": list(self.data_sources.keys()),
+             "transformer": "AggregationTransformer",
+         }
+
+         return aggregated_data
+
+     def should_apply(self, uri: str, subscription: "ResourceSubscription") -> bool:
+         """Apply to resources that have data sources defined."""
+         return len(self.data_sources) > 0
+
+
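+ # Sketch of AggregationTransformer usage: "audit_log" is a hypothetical data
+ # source; any callable (sync or async) taking the resource URI works. Fetch
+ # failures are recorded per source as {"error": ...} rather than raised.
+ async def _demo_aggregation() -> Dict[str, Any]:
+     aggregator = AggregationTransformer()
+     aggregator.add_data_source("audit_log", lambda uri: {"last_access": uri})
+     return await aggregator.transform(
+         {"uri": "file:///data/report.json"}, context={}
+     )
+
+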
+ class TransformationPipeline:
+     """Manages a pipeline of resource transformations."""
+
+     def __init__(self):
+         self.transformers: List[ResourceTransformer] = []
+         self._enabled = True
+
+     def add_transformer(self, transformer: ResourceTransformer):
+         """Add a transformer to the pipeline."""
+         self.transformers.append(transformer)
+
+     def remove_transformer(self, transformer: ResourceTransformer):
+         """Remove a transformer from the pipeline."""
+         if transformer in self.transformers:
+             self.transformers.remove(transformer)
+
+     def clear(self):
+         """Remove all transformers."""
+         self.transformers.clear()
+
+     def enable(self):
+         """Enable the transformation pipeline."""
+         self._enabled = True
+
+     def disable(self):
+         """Disable the transformation pipeline."""
+         self._enabled = False
+
+     @property
+     def enabled(self) -> bool:
+         """Check if pipeline is enabled."""
+         return self._enabled
+
+     async def apply(
+         self,
+         resource_data: Dict[str, Any],
+         uri: str,
+         subscription: "ResourceSubscription",
+     ) -> Dict[str, Any]:
+         """Apply all applicable transformations to resource data.
+
+         Args:
+             resource_data: Original resource data
+             uri: Resource URI
+             subscription: Subscription requesting the resource
+
+         Returns:
+             Transformed resource data
+         """
+         if not self._enabled or not self.transformers:
+             return resource_data
+
+         # Create transformation context
+         context = {
+             "uri": uri,
+             "subscription_id": subscription.id,
+             "connection_id": subscription.connection_id,
+             "uri_pattern": subscription.uri_pattern,
+             "fields": subscription.fields,
+             "fragments": subscription.fragments,
+             "timestamp": datetime.utcnow().isoformat(),
+         }
+
+         transformed_data = resource_data
+
+         # Collect pipeline-level transformation metadata
+         pipeline_metadata = {"errors": [], "applied_transformers": []}
+
+         # Apply each transformer in sequence
+         for transformer in self.transformers:
+             if transformer.should_apply(uri, subscription):
+                 try:
+                     new_transformed_data = await transformer.transform(
+                         transformed_data, context
+                     )
+
+                     # Preserve accumulated pipeline errors when a transformer
+                     # attaches its own transformation metadata
+                     if "__transformation" in new_transformed_data:
+                         new_transformation_metadata = new_transformed_data[
+                             "__transformation"
+                         ]
+
+                         # Merge errors reported by the transformer itself
+                         if "errors" in new_transformation_metadata:
+                             pipeline_metadata["errors"].extend(
+                                 new_transformation_metadata["errors"]
+                             )
+
+                         # Update the new data with preserved errors
+                         new_transformed_data["__transformation"]["errors"] = (
+                             pipeline_metadata["errors"]
+                         )
+
+                     transformed_data = new_transformed_data
+                     pipeline_metadata["applied_transformers"].append(
+                         transformer.__class__.__name__
+                     )
+
+                 except Exception as e:
+                     # Record the transformation error but continue with the pipeline
+                     transformation_error = {
+                         "transformer": transformer.__class__.__name__,
+                         "error": str(e),
+                         "timestamp": datetime.utcnow().isoformat(),
+                     }
+                     pipeline_metadata["errors"].append(transformation_error)
+
+         # Add pipeline metadata
+         if pipeline_metadata["applied_transformers"] or pipeline_metadata["errors"]:
+             if "__transformation" not in transformed_data:
+                 transformed_data["__transformation"] = {}
+
+             # Add pipeline-level errors
+             if pipeline_metadata["errors"]:
+                 transformed_data["__transformation"]["errors"] = pipeline_metadata[
+                     "errors"
+                 ]
+
+             # Add pipeline summary
+             transformed_data["__transformation"]["pipeline"] = {
+                 "applied_transformers": pipeline_metadata["applied_transformers"],
+                 "total_transformers": len(self.transformers),
+                 "enabled": self._enabled,
+                 "errors_count": len(pipeline_metadata["errors"]),
+             }
+
+         return transformed_data
+
+
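+ # Sketch of composing a pipeline: transformers run in list order, each one's
+ # should_apply() is consulted per resource, and per-transformer failures are
+ # collected under __transformation.errors instead of aborting the run. The
+ # subscription here is a stand-in built with hypothetical values.
+ async def _demo_pipeline(resource: Dict[str, Any]) -> Dict[str, Any]:
+     pipeline = TransformationPipeline()
+     enricher = DataEnrichmentTransformer()
+     enricher.add_enrichment("fetched_at", lambda data: datetime.utcnow().isoformat())
+     pipeline.add_transformer(enricher)
+     sub = ResourceSubscription(
+         id="demo-sub", connection_id="demo-conn", uri_pattern="file:///**"
+     )
+     return await pipeline.apply(
+         resource, uri="file:///data/report.json", subscription=sub
+     )
+
+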
+ @dataclass
+ class ResourceSubscription:
+     """Represents a resource subscription with GraphQL-style field selection."""
+
+     id: str
+     connection_id: str
+     uri_pattern: str
+     cursor: Optional[str] = None
+     created_at: datetime = field(default_factory=datetime.utcnow)
+     # GraphQL-style field selection
+     fields: Optional[List[str]] = None  # e.g., ["uri", "content.text", "metadata.size"]
+     fragments: Optional[Dict[str, List[str]]] = None  # e.g., {"basicInfo": ["uri", "name"]}
+
+     def matches_uri(self, uri: str) -> bool:
+         """Check if URI matches subscription pattern.
+
+         Supports:
+         - Single wildcard (*) - matches within a directory
+         - Double wildcard (**) - matches across directories
+         - Extension patterns (*.json, *.md)
+         """
+         # Convert ** to a pattern that matches across directories
+         pattern = self.uri_pattern.replace("**", "|||DOUBLESTAR|||")
+         pattern = pattern.replace("*", "|||STAR|||")
+
+         # Escape special characters except our placeholders
+         pattern = pattern.replace(".", r"\.")
+         pattern = pattern.replace("?", r"\?")
+         pattern = pattern.replace("[", r"\[")
+         pattern = pattern.replace("]", r"\]")
+
+         # Convert back to regex
+         pattern = pattern.replace("|||DOUBLESTAR|||", ".*")
+         pattern = pattern.replace("|||STAR|||", "[^/]*")
+
+         # Add anchors
+         pattern = f"^{pattern}$"
+
+         return bool(re.match(pattern, uri))
+
+     def apply_field_selection(self, resource_data: Dict[str, Any]) -> Dict[str, Any]:
+         """Apply GraphQL-style field selection to resource data.
+
+         Args:
+             resource_data: Full resource data
+
+         Returns:
+             Filtered resource data based on field selection
+         """
+         if not self.fields and not self.fragments:
+             # No field selection specified, return all data
+             return resource_data
+
+         result = {}
+
+         # Process direct field selections
+         if self.fields:
+             for field_path in self.fields:
+                 value = self._extract_field_value(resource_data, field_path)
+                 if value is not None:
+                     self._set_nested_value(result, field_path, value)
+
+         # Process fragment selections
+         if self.fragments:
+             for fragment_name, fragment_fields in self.fragments.items():
+                 fragment_data = {}
+                 for field_path in fragment_fields:
+                     value = self._extract_field_value(resource_data, field_path)
+                     if value is not None:
+                         self._set_nested_value(fragment_data, field_path, value)
+
+                 if fragment_data:
+                     result[f"__{fragment_name}"] = fragment_data
+
+         return result
+
+     def _extract_field_value(self, data: Dict[str, Any], field_path: str) -> Any:
+         """Extract field value using dot notation (e.g., 'content.text')."""
+         parts = field_path.split(".")
+         current = data
+
+         for part in parts:
+             if isinstance(current, dict) and part in current:
+                 current = current[part]
+             else:
+                 return None
+
+         return current
+
+     def _set_nested_value(self, data: Dict[str, Any], field_path: str, value: Any):
+         """Set nested value using dot notation."""
+         parts = field_path.split(".")
+         current = data
+
+         # Navigate to parent
+         for part in parts[:-1]:
+             if part not in current:
+                 current[part] = {}
+             current = current[part]
+
+         # Set final value
+         current[parts[-1]] = value
+
+
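+ # Sketch of pattern matching and field selection; the URIs and field paths
+ # are illustrative. "*" stays within one path segment while "**" crosses
+ # directories, and apply_field_selection() prunes data to the listed paths.
+ def _demo_subscription_matching() -> Dict[str, Any]:
+     sub = ResourceSubscription(
+         id="demo-sub",
+         connection_id="demo-conn",
+         uri_pattern="file:///logs/**/*.json",
+         fields=["uri", "content.text"],
+     )
+     assert sub.matches_uri("file:///logs/2024/app.json")
+     assert not sub.matches_uri("file:///logs/app.md")
+     # Keeps only "uri" and "content.text"; "content.raw" is pruned.
+     return sub.apply_field_selection(
+         {"uri": "file:///logs/2024/app.json", "content": {"text": "...", "raw": "..."}}
+     )
+
+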
+ class CursorManager:
+     """Manages cursor generation and validation for pagination."""
+
+     def __init__(self, ttl_seconds: int = 3600):
+         self.ttl_seconds = ttl_seconds
+         self._cursors: Dict[str, Dict[str, Any]] = {}
+         self._lock = asyncio.Lock()
+
+     def generate_cursor(self) -> str:
+         """Generate a unique cursor."""
+         cursor_id = str(uuid.uuid4())
+         timestamp = datetime.utcnow()
+
+         cursor_data = f"{cursor_id}:{timestamp.isoformat()}"
+         cursor = hashlib.sha256(cursor_data.encode()).hexdigest()[:16]
+
+         self._cursors[cursor] = {"created_at": timestamp, "data": {}}
+
+         return cursor
+
+     def create_cursor_for_position(self, items: List[Any], position: int) -> str:
+         """Create cursor for specific position in list."""
+         cursor = self.generate_cursor()
+         self._cursors[cursor]["data"]["position"] = position
+         self._cursors[cursor]["data"]["items_hash"] = hashlib.sha256(
+             str(items).encode()
+         ).hexdigest()[:8]
+         return cursor
+
+     def is_valid(self, cursor: str) -> bool:
+         """Check if cursor is valid and not expired."""
+         if cursor not in self._cursors:
+             return False
+
+         cursor_data = self._cursors[cursor]
+         age = datetime.utcnow() - cursor_data["created_at"]
+
+         if age > timedelta(seconds=self.ttl_seconds):
+             # Clean up expired cursor
+             del self._cursors[cursor]
+             return False
+
+         return True
+
+     def get_cursor_position(self, cursor: str) -> Optional[int]:
+         """Get position from cursor if valid."""
+         if not self.is_valid(cursor):
+             return None
+
+         return self._cursors[cursor]["data"].get("position")
+
+     async def cleanup_expired(self):
+         """Remove expired cursors."""
+         async with self._lock:
+             now = datetime.utcnow()
+             expired = []
+
+             for cursor, data in self._cursors.items():
+                 age = now - data["created_at"]
+                 if age > timedelta(seconds=self.ttl_seconds):
+                     expired.append(cursor)
+
+             for cursor in expired:
+                 del self._cursors[cursor]
+
+
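+ # Sketch of the cursor lifecycle: cursors are opaque 16-hex-char tokens that
+ # expire after ttl_seconds; a stored position survives only while its cursor
+ # is valid. The items list and position are illustrative.
+ def _demo_cursor_roundtrip(items: List[Any]) -> Optional[int]:
+     manager = CursorManager(ttl_seconds=60)
+     cursor = manager.create_cursor_for_position(items, position=10)
+     # Returns 10 while the cursor is fresh, None once it has expired.
+     return manager.get_cursor_position(cursor)
+
+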
+ class ResourceMonitor:
+     """Monitors resources for changes."""
+
+     def __init__(self):
+         self._resource_states: Dict[str, Dict[str, Any]] = {}
+         self._lock = asyncio.Lock()
+
+     def _compute_hash(self, content: Dict[str, Any]) -> str:
+         """Compute hash of resource content."""
+         # Sort keys for consistent hashing
+         sorted_content = json.dumps(content, sort_keys=True)
+         return hashlib.sha256(sorted_content.encode()).hexdigest()
+
+     async def register_resource(self, uri: str, content: Dict[str, Any]):
+         """Register resource for monitoring."""
+         async with self._lock:
+             self._resource_states[uri] = {
+                 "hash": self._compute_hash(content),
+                 "content": content,
+                 "last_checked": datetime.utcnow(),
+             }
+
+     def is_monitored(self, uri: str) -> bool:
+         """Check if resource is being monitored."""
+         return uri in self._resource_states
+
+     async def check_for_changes(
+         self, uri: str, content: Dict[str, Any]
+     ) -> Optional[ResourceChange]:
+         """Check if resource has changed."""
+         async with self._lock:
+             new_hash = self._compute_hash(content)
+
+             if uri not in self._resource_states:
+                 # New resource
+                 self._resource_states[uri] = {
+                     "hash": new_hash,
+                     "content": content,
+                     "last_checked": datetime.utcnow(),
+                 }
+                 return ResourceChange(
+                     type=ResourceChangeType.CREATED,
+                     uri=uri,
+                     timestamp=datetime.utcnow(),
+                 )
+
+             old_hash = self._resource_states[uri]["hash"]
+
+             if old_hash != new_hash:
+                 # Resource updated
+                 self._resource_states[uri] = {
+                     "hash": new_hash,
+                     "content": content,
+                     "last_checked": datetime.utcnow(),
+                 }
+                 return ResourceChange(
+                     type=ResourceChangeType.UPDATED,
+                     uri=uri,
+                     timestamp=datetime.utcnow(),
+                 )
+
+             # No change
+             self._resource_states[uri]["last_checked"] = datetime.utcnow()
+             return None
+
+     async def check_for_deletion(self, uri: str) -> Optional[ResourceChange]:
+         """Mark resource as deleted."""
+         async with self._lock:
+             if uri in self._resource_states:
+                 del self._resource_states[uri]
+                 return ResourceChange(
+                     type=ResourceChangeType.DELETED,
+                     uri=uri,
+                     timestamp=datetime.utcnow(),
+                 )
+             return None
+
+
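+ # Sketch of change detection: content is hashed with sorted-key JSON, so the
+ # first check_for_changes() call reports CREATED, an altered payload reports
+ # UPDATED, and an identical payload reports None.
+ async def _demo_resource_monitor() -> None:
+     monitor = ResourceMonitor()
+     uri = "file:///config.json"
+     created = await monitor.check_for_changes(uri, {"a": 1})
+     unchanged = await monitor.check_for_changes(uri, {"a": 1})
+     updated = await monitor.check_for_changes(uri, {"a": 2})
+     assert created.type is ResourceChangeType.CREATED
+     assert unchanged is None
+     assert updated.type is ResourceChangeType.UPDATED
+
+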
+ class ResourceSubscriptionManager:
+     """Manages resource subscriptions."""
+
+     def __init__(
+         self,
+         auth_manager: Optional[AuthManager] = None,
+         event_store=None,
+         rate_limiter=None,
+     ):
+         self.auth_manager = auth_manager
+         self.event_store = event_store
+         self.rate_limiter = rate_limiter
+
+         # Subscription tracking
+         self._subscriptions: Dict[str, ResourceSubscription] = {}
+         self._connection_subscriptions: Dict[str, Set[str]] = {}
+         self._pattern_index: Dict[str, Set[str]] = {}  # pattern -> subscription IDs
+
+         # Concurrency control
+         self._lock = asyncio.Lock()
+
+         # Notification callback
+         self._notification_callback: Optional[Callable] = None
+
+         # Resource monitoring
+         self.resource_monitor = ResourceMonitor()
+
+         # Cursor management
+         self.cursor_manager = CursorManager()
+
+         # Transformation pipeline
+         self.transformation_pipeline = TransformationPipeline()
+
+         # Cleanup task
+         self._cleanup_task = None
+
+     async def initialize(self):
+         """Initialize subscription manager."""
+         # Start periodic cleanup
+         self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
+
+     async def shutdown(self):
+         """Shutdown subscription manager."""
+         if self._cleanup_task:
+             self._cleanup_task.cancel()
+             try:
+                 await self._cleanup_task
+             except asyncio.CancelledError:
+                 pass
+
+     async def _periodic_cleanup(self):
+         """Periodically clean up expired cursors."""
+         while True:
+             try:
+                 await asyncio.sleep(300)  # Every 5 minutes
+                 await self.cursor_manager.cleanup_expired()
+             except asyncio.CancelledError:
+                 break
+             except Exception:
+                 # Log the error but keep the cleanup loop alive
+                 logger.warning("Periodic cursor cleanup failed", exc_info=True)
+
+     def set_notification_callback(self, callback: Callable):
+         """Set callback for sending notifications."""
+         self._notification_callback = callback
+
+     async def create_subscription(
+         self,
+         connection_id: str,
+         uri_pattern: str,
+         user_context: Optional[Dict[str, Any]] = None,
+         cursor: Optional[str] = None,
+         fields: Optional[List[str]] = None,
+         fragments: Optional[Dict[str, List[str]]] = None,
+     ) -> str:
+         """Create a new subscription."""
+         # Check rate limit
+         if self.rate_limiter:
+             user_id = user_context.get("user_id") if user_context else connection_id
+             if not await self.rate_limiter.check_rate_limit(user_id):
+                 raise SubscriptionError("Rate limit exceeded")
+
+         # Check permissions
+         if self.auth_manager and user_context:
+             try:
+                 # Use authenticate_and_authorize method if available
+                 if hasattr(self.auth_manager, "authenticate_and_authorize"):
+                     await self.auth_manager.authenticate_and_authorize(
+                         user_context, required_permission="subscribe"
+                     )
+                 else:
+                     # Fallback for mocked auth managers
+                     permission_check = await self.auth_manager.check_permission(
+                         user_context.get("user_id"),
+                         "subscribe",
+                         {"resource_pattern": uri_pattern},
+                     )
+                     if not permission_check.get("authorized", False):
+                         raise PermissionDeniedError(
+                             "Not authorized to subscribe to resources"
+                         )
+             except PermissionDeniedError:
+                 raise
+             except Exception as e:
+                 raise PermissionDeniedError(
+                     "Not authorized to subscribe to resources"
+                 ) from e
+
+         # Create subscription
+         sub_id = str(uuid.uuid4())
+         subscription = ResourceSubscription(
+             id=sub_id,
+             connection_id=connection_id,
+             uri_pattern=uri_pattern,
+             cursor=cursor,
+             fields=fields,
+             fragments=fragments,
+         )
+
+         async with self._lock:
+             # Store subscription
+             self._subscriptions[sub_id] = subscription
+
+             # Track by connection
+             if connection_id not in self._connection_subscriptions:
+                 self._connection_subscriptions[connection_id] = set()
+             self._connection_subscriptions[connection_id].add(sub_id)
+
+             # Index by pattern
+             if uri_pattern not in self._pattern_index:
+                 self._pattern_index[uri_pattern] = set()
+             self._pattern_index[uri_pattern].add(sub_id)
+
+         # Log to event store
+         if self.event_store:
+             from kailash.middleware.gateway.event_store import EventType
+
+             await self.event_store.append(
+                 event_type=EventType.REQUEST_COMPLETED,
+                 request_id=sub_id,
+                 data={
+                     "type": "subscription.created",
+                     "subscription_id": sub_id,
+                     "connection_id": connection_id,
+                     "uri_pattern": uri_pattern,
+                     "user_id": user_context.get("user_id") if user_context else None,
+                 },
+             )
+
+         return sub_id
+
+     async def create_batch_subscriptions(
+         self,
+         subscriptions: List[Dict[str, Any]],
+         connection_id: str,
+         user_context: Optional[Dict[str, Any]] = None,
+     ) -> Dict[str, Any]:
+         """Create multiple subscriptions in a single batch operation.
+
+         Args:
+             subscriptions: List of subscription requests, each containing:
+                 - uri_pattern: Resource URI pattern to subscribe to
+                 - cursor: Optional cursor for pagination
+                 - fields: Optional field selection
+                 - fragments: Optional fragment selection
+                 - subscription_name: Optional name for the subscription
+             connection_id: Client connection ID
+             user_context: Optional user context for authorization
+
+         Returns:
+             Dictionary with created subscription IDs and any errors
+         """
+         results = {
+             "successful": [],
+             "failed": [],
+             "total_requested": len(subscriptions),
+             "total_created": 0,
+             "total_failed": 0,
+         }
+
+         # Process each subscription request
+         for i, sub_request in enumerate(subscriptions):
+             try:
+                 # Extract subscription parameters
+                 uri_pattern = sub_request.get("uri_pattern")
+                 if not uri_pattern:
+                     results["failed"].append(
+                         {
+                             "index": i,
+                             "error": "Missing required parameter: uri_pattern",
+                             "request": sub_request,
+                         }
+                     )
+                     results["total_failed"] += 1
+                     continue
+
+                 cursor = sub_request.get("cursor")
+                 fields = sub_request.get("fields")
+                 fragments = sub_request.get("fragments")
+                 subscription_name = sub_request.get("subscription_name")
+
+                 # Create individual subscription
+                 subscription_id = await self.create_subscription(
+                     connection_id=connection_id,
+                     uri_pattern=uri_pattern,
+                     user_context=user_context,
+                     cursor=cursor,
+                     fields=fields,
+                     fragments=fragments,
+                 )
+
+                 # Record successful creation
+                 results["successful"].append(
+                     {
+                         "index": i,
+                         "subscription_id": subscription_id,
+                         "uri_pattern": uri_pattern,
+                         "subscription_name": subscription_name,
+                     }
+                 )
+                 results["total_created"] += 1
+
+             except Exception as e:
+                 # Record failed creation
+                 results["failed"].append(
+                     {"index": i, "error": str(e), "request": sub_request}
+                 )
+                 results["total_failed"] += 1
+
+         # Log batch operation to event store
+         if self.event_store:
+             from kailash.middleware.gateway.event_store import EventType
+
+             await self.event_store.append(
+                 event_type=EventType.REQUEST_COMPLETED,
+                 request_id=f"batch_subscribe_{connection_id}_{uuid.uuid4()}",
+                 data={
+                     "type": "batch_subscription_created",
+                     "connection_id": connection_id,
+                     "total_requested": results["total_requested"],
+                     "total_created": results["total_created"],
+                     "total_failed": results["total_failed"],
+                     "user_id": user_context.get("user_id") if user_context else None,
+                 },
+             )
+
+         return results
+
+     async def remove_subscription(
+         self, subscription_id: str, connection_id: str
+     ) -> bool:
+         """Remove a subscription."""
+         async with self._lock:
+             if subscription_id not in self._subscriptions:
+                 return False
+
+             subscription = self._subscriptions[subscription_id]
+
+             # Verify ownership
+             if subscription.connection_id != connection_id:
+                 return False
+
+             # Remove from indices
+             self._connection_subscriptions[connection_id].discard(subscription_id)
+             if not self._connection_subscriptions[connection_id]:
+                 del self._connection_subscriptions[connection_id]
+
+             self._pattern_index[subscription.uri_pattern].discard(subscription_id)
+             if not self._pattern_index[subscription.uri_pattern]:
+                 del self._pattern_index[subscription.uri_pattern]
+
+             # Remove subscription
+             del self._subscriptions[subscription_id]
+
+         # Log to event store
+         if self.event_store:
+             from kailash.middleware.gateway.event_store import EventType
+
+             await self.event_store.append(
+                 event_type=EventType.REQUEST_COMPLETED,
+                 request_id=subscription_id,
+                 data={
+                     "type": "subscription.removed",
+                     "subscription_id": subscription_id,
+                     "connection_id": connection_id,
+                 },
+             )
+
+         return True
+
+     async def remove_batch_subscriptions(
+         self, subscription_ids: List[str], connection_id: str
+     ) -> Dict[str, Any]:
+         """Remove multiple subscriptions in a single batch operation.
+
+         Args:
+             subscription_ids: List of subscription IDs to remove
+             connection_id: Client connection ID
+
+         Returns:
+             Dictionary with removal results and any errors
+         """
+         results = {
+             "successful": [],
+             "failed": [],
+             "total_requested": len(subscription_ids),
+             "total_removed": 0,
+             "total_failed": 0,
+         }
+
+         # Process each unsubscribe request
+         for i, subscription_id in enumerate(subscription_ids):
+             try:
+                 # Attempt to remove subscription
+                 success = await self.remove_subscription(subscription_id, connection_id)
+
+                 if success:
+                     results["successful"].append(
+                         {
+                             "index": i,
+                             "subscription_id": subscription_id,
+                             "removed": True,
+                         }
+                     )
+                     results["total_removed"] += 1
+                 else:
+                     results["failed"].append(
+                         {
+                             "index": i,
+                             "subscription_id": subscription_id,
+                             "error": "Subscription not found or not owned by connection",
+                         }
+                     )
+                     results["total_failed"] += 1
+
+             except Exception as e:
+                 # Record failed removal
+                 results["failed"].append(
+                     {"index": i, "subscription_id": subscription_id, "error": str(e)}
+                 )
+                 results["total_failed"] += 1
+
+         # Log batch operation to event store
+         if self.event_store:
+             from kailash.middleware.gateway.event_store import EventType
+
+             await self.event_store.append(
+                 event_type=EventType.REQUEST_COMPLETED,
+                 request_id=f"batch_unsubscribe_{connection_id}_{uuid.uuid4()}",
+                 data={
+                     "type": "batch_subscription_removed",
+                     "connection_id": connection_id,
+                     "total_requested": results["total_requested"],
+                     "total_removed": results["total_removed"],
+                     "total_failed": results["total_failed"],
+                 },
+             )
+
+         return results
+
+
+     def get_subscription(self, subscription_id: str) -> Optional[ResourceSubscription]:
+         """Get subscription by ID."""
+         return self._subscriptions.get(subscription_id)
+
+     def get_connection_subscriptions(self, connection_id: str) -> Set[str]:
+         """Get all subscription IDs for a connection."""
+         return self._connection_subscriptions.get(connection_id, set()).copy()
+
+     async def cleanup_connection(self, connection_id: str) -> int:
+         """Remove all subscriptions for a connection."""
+         sub_ids = self.get_connection_subscriptions(connection_id)
+         removed = 0
+
+         for sub_id in sub_ids:
+             if await self.remove_subscription(sub_id, connection_id):
+                 removed += 1
+
+         return removed
+
+     async def find_matching_subscriptions(self, uri: str) -> List[ResourceSubscription]:
+         """Find all subscriptions that match a URI."""
+         matching = []
+
+         async with self._lock:
+             for subscription in self._subscriptions.values():
+                 if subscription.matches_uri(uri):
+                     matching.append(subscription)
+
+         return matching
+
+     async def process_resource_change(
+         self, change: Union[ResourceChange, Dict[str, Any]]
+     ):
+         """Process a resource change and notify subscribers."""
+         # Convert dict to ResourceChange if needed
+         if isinstance(change, dict):
+             change = ResourceChange(
+                 type=ResourceChangeType(change["type"]),
+                 uri=change["uri"],
+                 timestamp=datetime.fromisoformat(change["timestamp"]),
+             )
+
+         # Find matching subscriptions
+         matching_subs = await self.find_matching_subscriptions(change.uri)
+
+         if not matching_subs:
+             return
+
+         # Notify per subscription, so each one gets its own transformations
+         # and field selection
+         if self._notification_callback:
+             for subscription in matching_subs:
+                 # Get full resource data for transformation and field selection
+                 resource_data = await self._get_resource_data(change.uri)
+
+                 if resource_data:
+                     # Apply transformation pipeline before field selection
+                     transformed_data = await self.transformation_pipeline.apply(
+                         resource_data, change.uri, subscription
+                     )
+
+                     # Apply field selection to transformed data
+                     filtered_data = subscription.apply_field_selection(transformed_data)
+                 else:
+                     filtered_data = {}
+
+                 # Create notification
+                 notification = {
+                     "jsonrpc": "2.0",
+                     "method": "notifications/resources/updated",
+                     "params": {
+                         "subscriptionId": subscription.id,
+                         "uri": change.uri,
+                         "type": change.type.value,
+                         "timestamp": change.timestamp.isoformat(),
+                         "data": filtered_data,
+                     },
+                 }
+
+                 # Send to specific connection
+                 conn_id = subscription.connection_id
+
+                 # Check if callback is async
+                 if asyncio.iscoroutinefunction(self._notification_callback):
+                     await self._notification_callback(conn_id, notification)
+                 else:
+                     self._notification_callback(conn_id, notification)
+
+         # Log to event store
+         if self.event_store:
+             from kailash.middleware.gateway.event_store import EventType
+
+             await self.event_store.append(
+                 event_type=EventType.REQUEST_COMPLETED,
+                 request_id=f"resource_change_{change.uri}",
+                 data={
+                     "type": "resource.changed",
+                     "uri": change.uri,
+                     "change_type": change.type.value,
+                     "timestamp": change.timestamp.isoformat(),
+                     "notified_subscriptions": len(matching_subs),
+                 },
+             )
+
+     async def _get_resource_data(self, uri: str) -> Optional[Dict[str, Any]]:
+         """Get full resource data for field selection.
+
+         This method should be overridden or configured to fetch actual resource data.
+         For now, it returns basic resource information from the monitored state.
+         """
+         async with self.resource_monitor._lock:
+             if uri in self.resource_monitor._resource_states:
+                 state = self.resource_monitor._resource_states[uri]
+                 return {
+                     "uri": uri,
+                     "content": state.get("content", {}),
+                     "metadata": {
+                         "hash": state.get("hash"),
+                         "last_checked": (
+                             state.get("last_checked").isoformat()
+                             if state.get("last_checked")
+                             else None
+                         ),
+                         "size": len(str(state.get("content", ""))),
+                     },
+                 }
+
+         # Fallback: return basic URI info
+         return {"uri": uri, "content": {}, "metadata": {"available": False}}
+
+
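+ # End-to-end sketch for a single-instance deployment: the wiring is
+ # hypothetical (no auth manager, no rate limiter, an in-memory notification
+ # callback), and notifications arrive through process_resource_change().
+ async def _demo_subscription_flow() -> List[Dict[str, Any]]:
+     received: List[Dict[str, Any]] = []
+     manager = ResourceSubscriptionManager()
+     manager.set_notification_callback(
+         lambda conn_id, notification: received.append(notification)
+     )
+     await manager.initialize()
+     try:
+         await manager.create_subscription(
+             connection_id="conn-1", uri_pattern="file:///data/**"
+         )
+         # First observation of the URI yields a CREATED change
+         change = await manager.resource_monitor.check_for_changes(
+             "file:///data/report.json", {"rows": 3}
+         )
+         await manager.process_resource_change(change)
+     finally:
+         await manager.shutdown()
+     return received
+
+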
+ class DistributedSubscriptionManager(ResourceSubscriptionManager):
+     """Redis-backed distributed subscription manager for multi-instance MCP servers.
+
+     This manager extends the base ResourceSubscriptionManager to support distributed
+     deployments where multiple MCP server instances need to coordinate subscriptions
+     and resource notifications across the cluster.
+
+     Key Features:
+     - Shared subscription state across server instances
+     - Distributed resource change notifications
+     - Automatic failover when instances go down
+     - Subscription replication and consistency
+     - Cross-instance notification routing
+     """
+
+     def __init__(
+         self,
+         redis_url: str = "redis://localhost:6379",
+         redis_config: Optional[Dict[str, Any]] = None,
+         server_instance_id: Optional[str] = None,
+         subscription_key_prefix: str = "mcp:subs:",
+         notification_channel_prefix: str = "mcp:notify:",
+         heartbeat_interval: int = 30,
+         instance_timeout: int = 90,
+         **kwargs,
+     ):
+         """Initialize distributed subscription manager.
+
+         Args:
+             redis_url: Redis connection URL
+             redis_config: Additional Redis configuration
+             server_instance_id: Unique ID for this server instance
+             subscription_key_prefix: Redis key prefix for subscriptions
+             notification_channel_prefix: Redis channel prefix for notifications
+             heartbeat_interval: How often to send heartbeats (seconds)
+             instance_timeout: When to consider an instance dead (seconds)
+             **kwargs: Arguments passed to parent ResourceSubscriptionManager
+         """
+         super().__init__(**kwargs)
+
+         if not REDIS_AVAILABLE:
+             raise ImportError(
+                 "Redis support not available. Install with: pip install redis"
+             )
+
+         self.redis_url = redis_url
+         self.redis_config = redis_config or {}
+         self.server_instance_id = (
+             server_instance_id or f"mcp_server_{uuid.uuid4().hex[:8]}"
+         )
+         self.subscription_key_prefix = subscription_key_prefix
+         self.notification_channel_prefix = notification_channel_prefix
+         self.heartbeat_interval = heartbeat_interval
+         self.instance_timeout = instance_timeout
+
+         # Redis connections
+         self.redis_client: Optional[redis.Redis] = None
+         self.redis_pubsub: Optional[redis.Redis] = None
+
+         # Instance management
+         self._heartbeat_task: Optional[asyncio.Task] = None
+         self._notification_listener_task: Optional[asyncio.Task] = None
+         self._instance_monitor_task: Optional[asyncio.Task] = None
+
+         # Distributed state
+         self._other_instances: Set[str] = set()
+         self._instance_subscriptions: Dict[str, Set[str]] = {}  # instance_id -> subscription IDs
+
+         self.logger = logging.getLogger(__name__)
+
+     async def initialize(self):
+         """Initialize Redis connections and distributed state."""
+         await super().initialize()
+
+         # Connect to Redis
+         self.redis_client = redis.Redis.from_url(
+             self.redis_url, decode_responses=True, **self.redis_config
+         )
+
+         # Separate connection for pub/sub
+         self.redis_pubsub = redis.Redis.from_url(
+             self.redis_url, decode_responses=True, **self.redis_config
+         )
+
+         # Test connections
+         try:
+             await self.redis_client.ping()
+             await self.redis_pubsub.ping()
+             self.logger.info(f"Connected to Redis at {self.redis_url}")
+         except Exception as e:
+             self.logger.error(f"Failed to connect to Redis: {e}")
+             raise
+
+         # Register this instance
+         await self._register_instance()
+
+         # Start background tasks
+         self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
+         self._notification_listener_task = asyncio.create_task(
+             self._notification_listener()
+         )
+         self._instance_monitor_task = asyncio.create_task(self._instance_monitor())
+
+         # Load existing distributed subscriptions
+         await self._load_distributed_subscriptions()
+
+         self.logger.info(
+             f"Distributed subscription manager initialized (instance: {self.server_instance_id})"
+         )
+
+     async def shutdown(self):
+         """Shutdown distributed subscription manager."""
+         # Cancel background tasks
+         for task in [
+             self._heartbeat_task,
+             self._notification_listener_task,
+             self._instance_monitor_task,
+         ]:
+             if task and not task.done():
+                 task.cancel()
+                 try:
+                     await task
+                 except asyncio.CancelledError:
+                     pass
+
+         # Unregister instance
+         await self._unregister_instance()
+
+         # Close Redis connections
+         if self.redis_client:
+             await self.redis_client.aclose()
+         if self.redis_pubsub:
+             await self.redis_pubsub.aclose()
+
+         await super().shutdown()
+         self.logger.info(
+             f"Distributed subscription manager shutdown (instance: {self.server_instance_id})"
+         )
+
+     async def create_subscription(
+         self, connection_id: str, uri_pattern: str, **kwargs
+     ) -> str:
+         """Create subscription and replicate to Redis."""
+         # Create local subscription
+         subscription_id = await super().create_subscription(
+             connection_id, uri_pattern, **kwargs
+         )
+
+         # Replicate to Redis
+         await self._replicate_subscription_to_redis(subscription_id)
+
+         return subscription_id
+
+     async def remove_subscription(
+         self, subscription_id: str, connection_id: str
+     ) -> bool:
+         """Remove subscription and update Redis."""
+         # Remove local subscription
+         success = await super().remove_subscription(subscription_id, connection_id)
+
+         if success:
+             # Remove from Redis
+             await self._remove_subscription_from_redis(subscription_id)
+
+         return success
+
+     async def process_resource_change(
+         self, change: Union[ResourceChange, Dict[str, Any]]
+     ):
+         """Process resource change and distribute notifications across instances."""
+         # Process locally first
+         await super().process_resource_change(change)
+
+         # Distribute to other instances via Redis
+         await self._distribute_resource_change(change)
+
+     async def _register_instance(self):
+         """Register this server instance in Redis."""
+         instance_key = f"mcp:instances:{self.server_instance_id}"
+         instance_data = {
+             "id": self.server_instance_id,
+             "registered_at": datetime.utcnow().isoformat(),
+             "last_heartbeat": datetime.utcnow().isoformat(),
+             "subscriptions": 0,
+         }
+
+         # Set with expiration
+         await self.redis_client.hset(instance_key, mapping=instance_data)
+         await self.redis_client.expire(instance_key, self.instance_timeout)
+
+         self.logger.info(f"Registered instance {self.server_instance_id}")
+
+     async def _unregister_instance(self):
+         """Unregister this server instance from Redis."""
+         instance_key = f"mcp:instances:{self.server_instance_id}"
+         await self.redis_client.delete(instance_key)
+
+         # Clean up instance subscriptions
+         await self.redis_client.delete(f"mcp:instance_subs:{self.server_instance_id}")
+
+         self.logger.info(f"Unregistered instance {self.server_instance_id}")
+
+     async def _heartbeat_loop(self):
+         """Send periodic heartbeats to indicate this instance is alive."""
+         while True:
+             try:
+                 await asyncio.sleep(self.heartbeat_interval)
+
+                 instance_key = f"mcp:instances:{self.server_instance_id}"
+                 await self.redis_client.hset(
+                     instance_key, "last_heartbeat", datetime.utcnow().isoformat()
+                 )
+                 await self.redis_client.expire(instance_key, self.instance_timeout)
+
+                 # Update subscription count
+                 sub_count = len(self._subscriptions)
+                 await self.redis_client.hset(instance_key, "subscriptions", sub_count)
+
+                 self.logger.debug(
+                     f"Heartbeat sent (instance: {self.server_instance_id}, subscriptions: {sub_count})"
+                 )
+
+             except asyncio.CancelledError:
+                 break
+             except Exception as e:
+                 self.logger.error(f"Heartbeat error: {e}")
+
+     async def _instance_monitor(self):
+         """Monitor other instances and handle failures."""
+         while True:
+             try:
+                 await asyncio.sleep(self.heartbeat_interval)
+
+                 # Get all instances
+                 instance_keys = await self.redis_client.keys("mcp:instances:*")
+                 current_instances = set()
+
+                 for key in instance_keys:
+                     instance_data = await self.redis_client.hgetall(key)
+                     if not instance_data:
+                         continue
+
+                     instance_id = instance_data.get("id")
+                     if instance_id == self.server_instance_id:
+                         continue
+
+                     last_heartbeat = instance_data.get("last_heartbeat")
+                     if last_heartbeat:
+                         try:
+                             heartbeat_time = datetime.fromisoformat(last_heartbeat)
+                             age = (datetime.utcnow() - heartbeat_time).total_seconds()
+
+                             if age < self.instance_timeout:
+                                 current_instances.add(instance_id)
+                             else:
+                                 # Instance is dead, clean up
+                                 await self._cleanup_dead_instance(instance_id)
+                         except ValueError:
+                             pass
+
+                 # Update known instances
+                 new_instances = current_instances - self._other_instances
+                 dead_instances = self._other_instances - current_instances
+
+                 if new_instances:
+                     self.logger.info(f"New instances detected: {new_instances}")
+
+                 if dead_instances:
+                     self.logger.info(f"Dead instances detected: {dead_instances}")
+
+                 self._other_instances = current_instances
+
+             except asyncio.CancelledError:
+                 break
+             except Exception as e:
+                 self.logger.error(f"Instance monitor error: {e}")
+
+     async def _cleanup_dead_instance(self, instance_id: str):
+         """Clean up subscriptions from a dead instance."""
+         try:
+             # Get subscriptions for dead instance
+             instance_subs_key = f"mcp:instance_subs:{instance_id}"
+             dead_subscriptions = await self.redis_client.smembers(instance_subs_key)
+
+             # Remove subscription data
+             if dead_subscriptions:
+                 pipeline = self.redis_client.pipeline()
+                 for sub_id in dead_subscriptions:
+                     sub_key = f"{self.subscription_key_prefix}{sub_id}"
+                     pipeline.delete(sub_key)
+
+                 pipeline.delete(instance_subs_key)
+                 await pipeline.execute()
+
+                 self.logger.info(
+                     f"Cleaned up {len(dead_subscriptions)} subscriptions from dead instance {instance_id}"
+                 )
+
+             # Remove instance record
+             await self.redis_client.delete(f"mcp:instances:{instance_id}")
+
+         except Exception as e:
+             self.logger.error(f"Error cleaning up dead instance {instance_id}: {e}")
+
+     async def _replicate_subscription_to_redis(self, subscription_id: str):
+         """Replicate subscription data to Redis."""
+         subscription = self._subscriptions.get(subscription_id)
+         if not subscription:
+             return
+
+         # Serialize subscription data
+         sub_data = {
+             "id": subscription.id,
+             "connection_id": subscription.connection_id,
+             "uri_pattern": subscription.uri_pattern,
+             "cursor": subscription.cursor or "",
+             "created_at": subscription.created_at.isoformat(),
+             "fields": json.dumps(subscription.fields or []),
+             "fragments": json.dumps(subscription.fragments or {}),
+             "server_instance": self.server_instance_id,
+         }
+
+         # Store in Redis
+         sub_key = f"{self.subscription_key_prefix}{subscription_id}"
+         await self.redis_client.hset(sub_key, mapping=sub_data)
+
+         # Track instance subscriptions
+         instance_subs_key = f"mcp:instance_subs:{self.server_instance_id}"
+         await self.redis_client.sadd(instance_subs_key, subscription_id)
+
+         self.logger.debug(f"Replicated subscription {subscription_id} to Redis")
+
+     async def _remove_subscription_from_redis(self, subscription_id: str):
+         """Remove subscription data from Redis."""
+         sub_key = f"{self.subscription_key_prefix}{subscription_id}"
+         await self.redis_client.delete(sub_key)
+
+         # Remove from instance tracking
+         instance_subs_key = f"mcp:instance_subs:{self.server_instance_id}"
+         await self.redis_client.srem(instance_subs_key, subscription_id)
+
+         self.logger.debug(f"Removed subscription {subscription_id} from Redis")
+
+     async def _load_distributed_subscriptions(self):
+         """Load subscription data for other instances from Redis."""
+         # This is for awareness only - we don't process other instances'
+         # subscriptions locally, but we may need this information for coordination
+         try:
+             instance_keys = await self.redis_client.keys("mcp:instances:*")
+
+             for key in instance_keys:
+                 instance_data = await self.redis_client.hgetall(key)
+                 instance_id = instance_data.get("id")
+
+                 if instance_id and instance_id != self.server_instance_id:
+                     # Load subscription IDs for this instance
+                     instance_subs_key = f"mcp:instance_subs:{instance_id}"
+                     sub_ids = await self.redis_client.smembers(instance_subs_key)
+                     self._instance_subscriptions[instance_id] = set(sub_ids)
+
+             total_distributed_subs = sum(
+                 len(subs) for subs in self._instance_subscriptions.values()
+             )
+             self.logger.info(
+                 f"Loaded {total_distributed_subs} distributed subscriptions from {len(self._instance_subscriptions)} instances"
+             )
+
+         except Exception as e:
+             self.logger.error(f"Error loading distributed subscriptions: {e}")
+
+     async def _distribute_resource_change(
+         self, change: Union[ResourceChange, Dict[str, Any]]
+     ):
+         """Distribute resource change notification to other instances."""
+         if not self._other_instances:
+             return  # No other instances to notify
+
+         # Convert to dict if needed
+         if isinstance(change, ResourceChange):
+             change_data = {
+                 "type": change.type.value,
+                 "uri": change.uri,
+                 "timestamp": change.timestamp.isoformat(),
+                 "source_instance": self.server_instance_id,
+             }
+         else:
+             change_data = dict(change)
+             change_data["source_instance"] = self.server_instance_id
+
+         # Publish to notification channel
+         channel = f"{self.notification_channel_prefix}resource_changes"
+         try:
+             await self.redis_client.publish(channel, json.dumps(change_data))
+             self.logger.debug(
+                 f"Distributed resource change for {change_data['uri']} to {len(self._other_instances)} instances"
+             )
+         except Exception as e:
+             self.logger.error(f"Error distributing resource change: {e}")
+
+     async def _notification_listener(self):
+         """Listen for distributed resource change notifications."""
+         try:
+             pubsub = self.redis_pubsub.pubsub()
+             channel = f"{self.notification_channel_prefix}resource_changes"
+             await pubsub.subscribe(channel)
+
+             self.logger.info(
+                 f"Listening for distributed notifications on channel: {channel}"
+             )
+
+             async for message in pubsub.listen():
+                 if message["type"] == "message":
+                     try:
+                         change_data = json.loads(message["data"])
+                         source_instance = change_data.get("source_instance")
+
+                         # Ignore notifications from ourselves
+                         if source_instance == self.server_instance_id:
+                             continue
+
+                         # Process the resource change locally. This checks
+                         # local subscriptions and sends notifications.
+                         change = ResourceChange(
+                             type=ResourceChangeType(change_data["type"]),
+                             uri=change_data["uri"],
+                             timestamp=datetime.fromisoformat(change_data["timestamp"]),
+                         )
+
+                         # Process without re-distributing (to avoid loops)
+                         await super().process_resource_change(change)
+
+                         self.logger.debug(
+                             f"Processed distributed resource change from {source_instance}: {change_data['uri']}"
+                         )
+
+                     except Exception as e:
+                         self.logger.error(
+                             f"Error processing distributed notification: {e}"
+                         )
+
+         except asyncio.CancelledError:
+             pass
+         except Exception as e:
+             self.logger.error(f"Notification listener error: {e}")
+
+     def get_distributed_stats(self) -> Dict[str, Any]:
+         """Get statistics about the distributed subscription system."""
+         return {
+             "instance_id": self.server_instance_id,
+             "local_subscriptions": len(self._subscriptions),
+             "other_instances": len(self._other_instances),
+             "distributed_subscriptions": {
+                 instance_id: len(subs)
+                 for instance_id, subs in self._instance_subscriptions.items()
+             },
+             "total_distributed_subscriptions": sum(
+                 len(subs) for subs in self._instance_subscriptions.values()
+             ),
+             "redis_url": self.redis_url,
+         }
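+
+
+ # Deployment sketch for the distributed manager: the Redis URL and instance
+ # ID are placeholders and a reachable Redis is assumed. Each server instance
+ # runs its own manager; changes processed on one instance are re-published to
+ # the others via the shared notification channel.
+ async def _demo_distributed_setup() -> Dict[str, Any]:
+     manager = DistributedSubscriptionManager(
+         redis_url="redis://localhost:6379",
+         server_instance_id="mcp-demo-1",
+     )
+     await manager.initialize()
+     try:
+         await manager.create_subscription(
+             connection_id="conn-1", uri_pattern="file:///shared/**"
+         )
+         return manager.get_distributed_stats()
+     finally:
+         await manager.shutdown()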