flock-core 0.4.512__py3-none-any.whl → 0.4.513__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flock-core might be problematic. (The original page linked to more details at this point; that link was not preserved in this copy.)

@@ -1,7 +1,9 @@
1
1
  """Base Config for MCP Clients."""
2
2
 
3
- from typing import Literal, TypeVar
3
+ import importlib
4
+ from typing import Any, Literal, TypeVar
4
5
 
6
+ import httpx
5
7
  from pydantic import BaseModel, ConfigDict, Field, create_model
6
8
 
7
9
  from flock.core.mcp.types.types import (
@@ -11,6 +13,15 @@ from flock.core.mcp.types.types import (
11
13
  FlockSamplingMCPCallback,
12
14
  MCPRoot,
13
15
  ServerParameters,
16
+ SseServerParameters,
17
+ StdioServerParameters,
18
+ StreamableHttpServerParameters,
19
+ WebsocketServerParameters,
20
+ )
21
+ from flock.core.serialization.serializable import Serializable
22
+ from flock.core.serialization.serialization_utils import (
23
+ deserialize_item,
24
+ serialize_item,
14
25
  )
15
26
 
16
27
  LoggingLevel = Literal[
@@ -32,7 +43,7 @@ D = TypeVar("D", bound="FlockMCPCachingConfigurationBase")
32
43
  E = TypeVar("E", bound="FlockMCPFeatureConfigurationBase")
33
44
 
34
45
 
35
- class FlockMCPCachingConfigurationBase(BaseModel):
46
+ class FlockMCPCachingConfigurationBase(BaseModel, Serializable):
36
47
  """Configuration for Caching in Clients."""
37
48
 
38
49
  tool_cache_max_size: float = Field(
@@ -79,6 +90,18 @@ class FlockMCPCachingConfigurationBase(BaseModel):
79
90
  extra="allow",
80
91
  )
81
92
 
93
+ def to_dict(self, path_type: str = "relative"):
94
+ """Serialize the config object."""
95
+ return self.model_dump(
96
+ exclude_none=True,
97
+ mode="json",
98
+ )
99
+
100
+ @classmethod
101
+ def from_dict(cls: type[D], data: dict[str, Any]) -> D:
102
+ """Deserialize from a dict."""
103
+ return cls(**{k: v for k, v in data.items()})
104
+
82
105
  @classmethod
83
106
  def with_fields(cls: type[D], **field_definitions) -> type[D]:
84
107
  """Create a new config class with additional fields."""
@@ -87,7 +110,7 @@ class FlockMCPCachingConfigurationBase(BaseModel):
87
110
  )
88
111
 
89
112
 
90
- class FlockMCPCallbackConfigurationBase(BaseModel):
113
+ class FlockMCPCallbackConfigurationBase(BaseModel, Serializable):
91
114
  """Base Configuration Class for Callbacks for Clients."""
92
115
 
93
116
  sampling_callback: FlockSamplingMCPCallback | None = Field(
@@ -111,6 +134,52 @@ class FlockMCPCallbackConfigurationBase(BaseModel):
111
134
 
112
135
  model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
113
136
 
137
+ def to_dict(self, path_type: str = "relative"):
138
+ """Serialize the object."""
139
+ # we need to register callables.
140
+ data: dict[str, Any] = {}
141
+ if self.sampling_callback:
142
+ sampling_callback_data = serialize_item(self.sampling_callback)
143
+ data["sampling_callback"] = sampling_callback_data
144
+
145
+ if self.list_roots_callback:
146
+ list_roots_callback_data = serialize_item(self.list_roots_callback)
147
+ data["list_roots_callback"] = list_roots_callback_data
148
+
149
+ if self.logging_callback:
150
+ logging_callback_data = serialize_item(self.logging_callback)
151
+ data["logging_callback"] = logging_callback_data
152
+
153
+ if self.message_handler:
154
+ message_handler_data = serialize_item(self.message_handler)
155
+ data["message_handler"] = message_handler_data
156
+
157
+ return data
158
+
159
+ @classmethod
160
+ def from_dict(cls: type[A], data: dict[str, Any]) -> A:
161
+ """Deserialize from a dict."""
162
+ instance = cls()
163
+ if data:
164
+ if "sampling_callback" in data:
165
+ instance.sampling_callback = deserialize_item(
166
+ data["sampling_callback"]
167
+ )
168
+ if "list_roots_callback" in data:
169
+ instance.list_roots_callback = deserialize_item(
170
+ data["list_roots_callback"]
171
+ )
172
+ if "logging_callback" in data:
173
+ instance.logging_callback = deserialize_item(
174
+ data["logging_callback"]
175
+ )
176
+ if "message_handler" in data:
177
+ instance.message_handler = deserialize_item(
178
+ data["message_handler"]
179
+ )
180
+
181
+ return instance
182
+
114
183
  @classmethod
115
184
  def with_fields(cls: type[A], **field_definitions) -> type[A]:
116
185
  """Create a new config class with additional fields."""
@@ -119,7 +188,7 @@ class FlockMCPCallbackConfigurationBase(BaseModel):
119
188
  )
120
189
 
121
190
 
122
- class FlockMCPConnectionConfigurationBase(BaseModel):
191
+ class FlockMCPConnectionConfigurationBase(BaseModel, Serializable):
123
192
  """Base Configuration Class for Connection Parameters for a client."""
124
193
 
125
194
  max_retries: int = Field(
@@ -131,9 +200,9 @@ class FlockMCPConnectionConfigurationBase(BaseModel):
131
200
  ..., description="Connection parameters for the server."
132
201
  )
133
202
 
134
- transport_type: Literal["stdio", "websockets", "sse", "custom"] = Field(
135
- ..., description="Type of transport to use."
136
- )
203
+ transport_type: Literal[
204
+ "stdio", "websockets", "sse", "streamable_http", "custom"
205
+ ] = Field(..., description="Type of transport to use.")
137
206
 
138
207
  mount_points: list[MCPRoot] | None = Field(
139
208
  default=None, description="Initial Mountpoints to operate under."
@@ -150,6 +219,77 @@ class FlockMCPConnectionConfigurationBase(BaseModel):
150
219
 
151
220
  model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
152
221
 
222
+ def to_dict(self, path_type: str = "relative") -> dict[str, Any]:
223
+ """Serialize object to a dict."""
224
+ exclude = ["connection_parameters"]
225
+
226
+ data = self.model_dump(
227
+ exclude=exclude,
228
+ exclude_defaults=False,
229
+ exclude_none=True,
230
+ mode="json",
231
+ )
232
+
233
+ data["connection_parameters"] = self.connection_parameters.to_dict(
234
+ path_type=path_type
235
+ )
236
+
237
+ return data
238
+
239
+ @classmethod
240
+ def from_dict(cls: type[B], data: dict[str, Any]) -> B:
241
+ """Deserialize from dict."""
242
+ connection_params = data.get("connection_parameters")
243
+ connection_params_obj = None
244
+ auth_obj: httpx.Auth | None = None
245
+ if connection_params:
246
+ kind = connection_params.get("transport_type", None)
247
+ auth_spec = connection_params.get("auth", None)
248
+ if auth_spec:
249
+ # find the concrete implementation and
250
+ # instantiate it.
251
+ # find the concrete implementation for auth and instatiate it.
252
+ impl = auth_spec.get("implementation", None)
253
+ params = auth_spec.get("params", None)
254
+ if impl and params:
255
+ mod = importlib.import_module(impl["module_path"])
256
+ real_cls = getattr(mod, impl["class_name"])
257
+ auth_obj = real_cls(**{k: v for k, v in params.items()})
258
+
259
+ if auth_obj:
260
+ connection_params["auth"] = auth_obj
261
+ else:
262
+ # just to be sure
263
+ connection_params.pop("auth", None)
264
+ match kind:
265
+ case "stdio":
266
+ connection_params_obj = StdioServerParameters(
267
+ **{k: v for k, v in connection_params.items()}
268
+ )
269
+ case "websockets":
270
+ connection_params_obj = WebsocketServerParameters(
271
+ **{k: v for k, v in connection_params.items()}
272
+ )
273
+ case "streamable_http":
274
+ connection_params_obj = StreamableHttpServerParameters(
275
+ **{k: v for k, v in connection_params.items()}
276
+ )
277
+ case "sse":
278
+ connection_params_obj = SseServerParameters(
279
+ **{k: v for k, v in connection_params.items()}
280
+ )
281
+ case _:
282
+ # handle custom server params
283
+ connection_params_obj = ServerParameters(
284
+ **{k: v for k, v in connection_params.items()}
285
+ )
286
+
287
+ if connection_params_obj:
288
+ data["connection_parameters"] = connection_params_obj
289
+ return cls(**{k: v for k, v in data.items()})
290
+ else:
291
+ raise ValueError("No connection parameters provided.")
292
+
153
293
  @classmethod
154
294
  def with_fields(cls: type[B], **field_definitions) -> type[B]:
155
295
  """Create a new config class with additional fields."""
@@ -158,7 +298,7 @@ class FlockMCPConnectionConfigurationBase(BaseModel):
158
298
  )
159
299
 
160
300
 
161
- class FlockMCPFeatureConfigurationBase(BaseModel):
301
+ class FlockMCPFeatureConfigurationBase(BaseModel, Serializable):
162
302
  """Base Configuration Class for switching MCP Features on and off."""
163
303
 
164
304
  roots_enabled: bool = Field(
@@ -186,6 +326,18 @@ class FlockMCPFeatureConfigurationBase(BaseModel):
186
326
  extra="allow",
187
327
  )
188
328
 
329
+ def to_dict(self, path_type: str = "relative"):
330
+ """Serialize the object."""
331
+ return self.model_dump(
332
+ mode="json",
333
+ exclude_none=True,
334
+ )
335
+
336
+ @classmethod
337
+ def from_dict(cls, data: dict[str, Any]):
338
+ """Deserialize from a dict."""
339
+ return cls(**{k: v for k, v in data.items()})
340
+
189
341
  @classmethod
190
342
  def with_fields(cls: type[E], **field_definitions) -> type[E]:
191
343
  """Create a new config class with additional fields."""
@@ -194,7 +346,7 @@ class FlockMCPFeatureConfigurationBase(BaseModel):
194
346
  )
195
347
 
196
348
 
197
- class FlockMCPConfigurationBase(BaseModel):
349
+ class FlockMCPConfigurationBase(BaseModel, Serializable):
198
350
  """Base Configuration Class for MCP Clients.
199
351
 
200
352
  Each Client should implement their own config
@@ -229,6 +381,90 @@ class FlockMCPConfigurationBase(BaseModel):
229
381
  extra="allow",
230
382
  )
231
383
 
384
+ def to_dict(self, path_type: str = "relative") -> dict[str, Any]:
385
+ """Serialize the object to a dict."""
386
+ # each built-in type should serialize, deserialize it self.
387
+ exclude = [
388
+ "connection_config",
389
+ "caching_config",
390
+ "callback_config",
391
+ "feature_config",
392
+ ]
393
+
394
+ data = self.model_dump(
395
+ exclude=exclude,
396
+ exclude_defaults=False,
397
+ exclude_none=True,
398
+ mode="json",
399
+ )
400
+
401
+ # add the core properties
402
+ data["connection_config"] = self.connection_config.to_dict(path_type)
403
+ data["caching_config"] = self.caching_config.to_dict(path_type)
404
+ data["callback_config"] = self.callback_config.to_dict(path_type)
405
+ data["feature_config"] = self.feature_config.to_dict(path_type)
406
+
407
+ return data
408
+
409
+ @classmethod
410
+ def from_dict(cls: type[C], data: dict[str, Any]) -> C:
411
+ """Deserialize the class."""
412
+ connection_config = data.pop("connection_config", None)
413
+ caching_config = data.pop("caching_config", None)
414
+ feature_config = data.pop("feature_config", None)
415
+ callback_config = data.pop("callback_config", None)
416
+
417
+ instance_data: dict[str, Any] = {
418
+ "name": data["name"]
419
+ }
420
+
421
+ if connection_config:
422
+ # Forcing a square into a round hole
423
+ try:
424
+ config_field = cls.model_fields["connection_config"]
425
+ config_cls = config_field.annotation
426
+ except (AttributeError, KeyError):
427
+ # fallback
428
+ config_cls = FlockMCPConnectionConfigurationBase
429
+ instance_data["connection_config"] = config_cls.from_dict(connection_config)
430
+ else:
431
+ raise ValueError(f"connection_config MUST be specified for '{data.get('name', 'unknown_server')}")
432
+
433
+ if caching_config:
434
+ try:
435
+ config_field = cls.model_fields["caching_config"]
436
+ config_cls = config_field.annotation
437
+ except (AttributeError, KeyError):
438
+ # fallback
439
+ config_cls = FlockMCPCachingConfigurationBase
440
+ instance_data["caching_config"] = config_cls.from_dict(caching_config)
441
+ else:
442
+ instance_data["caching_config"] = FlockMCPCachingConfigurationBase()
443
+
444
+ if feature_config:
445
+ try:
446
+ config_field = cls.model_fields["feature_config"]
447
+ config_cls = config_field.annotation
448
+ except (AttributeError, KeyError):
449
+ # fallback
450
+ config_cls = FlockMCPFeatureConfigurationBase
451
+ instance_data["feature_config"] = config_cls.from_dict(feature_config)
452
+ else:
453
+ instance_data["feature_config"] = FlockMCPFeatureConfigurationBase()
454
+
455
+ if callback_config:
456
+ try:
457
+ config_field = cls.model_fields["callback_config"]
458
+ config_cls = config_field.annotation
459
+ except (AttributeError, KeyError):
460
+ # fallback
461
+ config_cls = FlockMCPCallbackConfigurationBase
462
+ instance_data["callback_config"] = config_cls.from_dict(callback_config)
463
+ else:
464
+ instance_data["callback_config"] = FlockMCPCallbackConfigurationBase()
465
+
466
+ return cls(**{k: v for k, v in instance_data.items()})
467
+
232
468
  @classmethod
233
469
  def with_fields(cls: type[C], **field_definitions) -> type[C]:
234
470
  """Create a new config class with additional fields."""
@@ -8,6 +8,8 @@ from mcp.types import (
8
8
  CreateMessageRequestParams,
9
9
  ErrorData,
10
10
  ListRootsResult,
11
+ LoggingMessageNotificationParams,
12
+ ServerNotification,
11
13
  ServerRequest,
12
14
  )
13
15
 
@@ -19,10 +21,6 @@ from flock.core.mcp.types.handlers import (
19
21
  handle_incoming_server_notification,
20
22
  handle_logging_message,
21
23
  )
22
- from flock.core.mcp.types.types import (
23
- FlockLoggingMessageNotificationParams,
24
- ServerNotification,
25
- )
26
24
 
27
25
 
28
26
  async def default_sampling_callback(
@@ -76,7 +74,7 @@ async def default_list_roots_callback(
76
74
 
77
75
 
78
76
  async def default_logging_callback(
79
- params: FlockLoggingMessageNotificationParams,
77
+ params: LoggingMessageNotificationParams,
80
78
  logger: FlockLogger,
81
79
  server_name: str,
82
80
  ) -> None:
@@ -1,6 +1,6 @@
1
1
  """Factories for default MCP Callbacks."""
2
2
 
3
- from typing import TYPE_CHECKING, Any
3
+ from typing import Any
4
4
 
5
5
  from mcp.shared.context import RequestContext
6
6
  from mcp.types import (
@@ -8,6 +8,12 @@ from mcp.types import (
8
8
  )
9
9
 
10
10
  from flock.core.logging.logging import FlockLogger, get_logger
11
+ from flock.core.mcp.types.callbacks import (
12
+ default_list_roots_callback,
13
+ default_logging_callback,
14
+ default_message_handler,
15
+ default_sampling_callback,
16
+ )
11
17
  from flock.core.mcp.types.types import (
12
18
  FlockListRootsMCPCallback,
13
19
  FlockLoggingMCPCallback,
@@ -17,18 +23,10 @@ from flock.core.mcp.types.types import (
17
23
  ServerNotification,
18
24
  )
19
25
 
20
- if TYPE_CHECKING:
21
- from flock.core.mcp.types.callbacks import (
22
- default_list_roots_callback,
23
- default_logging_callback,
24
- default_message_handler,
25
- default_sampling_callback,
26
- )
27
-
28
- default_logging_callback_logger = get_logger("core.mcp.callback.logging")
29
- default_sampling_callback_logger = get_logger("core.mcp.callback.sampling")
30
- default_list_roots_callback_logger = get_logger("core.mcp.callback.sampling")
31
- default_message_handler_logger = get_logger("core.mcp.callback.message")
26
+ default_logging_callback_logger = get_logger("mcp.callback.logging")
27
+ default_sampling_callback_logger = get_logger("mcp.callback.sampling")
28
+ default_list_roots_callback_logger = get_logger("mcp.callback.roots")
29
+ default_message_handler_logger = get_logger("mcp.callback.message")
32
30
 
33
31
 
34
32
  def default_flock_mcp_logging_callback_factory(
@@ -88,7 +86,7 @@ def default_flock_mcp_message_handler_callback_factory(
88
86
  ) -> None:
89
87
  await default_message_handler(
90
88
  req=n,
91
- logger_to_use=logger_to_use,
89
+ logger=logger_to_use,
92
90
  associated_client=associated_client,
93
91
  )
94
92
 
@@ -8,26 +8,23 @@ from mcp.shared.context import RequestContext
8
8
  from mcp.shared.session import RequestResponder
9
9
  from mcp.types import (
10
10
  INTERNAL_ERROR,
11
+ CancelledNotification,
11
12
  ClientResult,
12
13
  ErrorData,
13
14
  ListRootsRequest,
14
- ServerNotification as _MCPServerNotification,
15
- ServerRequest,
16
- )
17
-
18
- from flock.core.logging.logging import FlockLogger
19
- from flock.core.mcp.mcp_client import Any
20
- from flock.core.mcp.types.types import (
21
- CancelledNotification,
22
- FlockLoggingMessageNotificationParams,
23
15
  LoggingMessageNotification,
16
+ LoggingMessageNotificationParams,
24
17
  ProgressNotification,
25
18
  ResourceListChangedNotification,
26
19
  ResourceUpdatedNotification,
27
20
  ServerNotification,
21
+ ServerRequest,
28
22
  ToolListChangedNotification,
29
23
  )
30
24
 
25
+ from flock.core.logging.logging import FlockLogger
26
+ from flock.core.mcp.mcp_client import Any
27
+
31
28
 
32
29
  async def handle_incoming_exception(
33
30
  e: Exception,
@@ -129,11 +126,11 @@ async def handle_tool_list_changed_notification(
129
126
  await associated_client.invalidate_tool_cache()
130
127
 
131
128
 
132
- _SERVER_NOTIFICATION_MAP: dict[type[_MCPServerNotification], Callable] = {
129
+ _SERVER_NOTIFICATION_MAP: dict[type[ServerNotification], Callable] = {
133
130
  ResourceListChangedNotification: handle_resource_list_changed_notification,
134
131
  ResourceUpdatedNotification: handle_resource_update_notification,
135
132
  LoggingMessageNotification: lambda n, log, client: handle_logging_message(
136
- params=n,
133
+ params=n.params,
137
134
  logger=log,
138
135
  server_name=client.config.name,
139
136
  ),
@@ -154,7 +151,7 @@ async def handle_incoming_server_notification(
154
151
 
155
152
 
156
153
  async def handle_logging_message(
157
- params: FlockLoggingMessageNotificationParams,
154
+ params: LoggingMessageNotificationParams,
158
155
  logger: FlockLogger,
159
156
  server_name: str,
160
157
  ) -> None: