flwr-nightly 1.19.0.dev20250601__py3-none-any.whl → 1.19.0.dev20250603__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_rere_client/connection.py +9 -18
- flwr/common/inflatable.py +8 -2
- flwr/common/inflatable_grpc_utils.py +9 -5
- flwr/common/record/configrecord.py +9 -8
- flwr/common/record/metricrecord.py +6 -5
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +39 -28
- flwr/common/serde_utils.py +2 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +20 -3
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +62 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +58 -3
- flwr/supernode/cli/__init__.py +5 -1
- flwr/supernode/cli/flower_supernode.py +1 -2
- flwr/supernode/cli/flwr_clientapp.py +73 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +112 -0
- flwr/supernode/nodestate/nodestate.py +132 -6
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +2 -54
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/supernode/start_client_internal.py +24 -20
- {flwr_nightly-1.19.0.dev20250601.dist-info → flwr_nightly-1.19.0.dev20250603.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250601.dist-info → flwr_nightly-1.19.0.dev20250603.dist-info}/RECORD +31 -27
- {flwr_nightly-1.19.0.dev20250601.dist-info → flwr_nightly-1.19.0.dev20250603.dist-info}/entry_points.txt +1 -1
- /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
- {flwr_nightly-1.19.0.dev20250601.dist-info → flwr_nightly-1.19.0.dev20250603.dist-info}/WHEEL +0 -0
flwr/proto/recorddict_pb2.pyi
CHANGED
@@ -197,23 +197,34 @@ global___ConfigRecordValue = ConfigRecordValue
|
|
197
197
|
|
198
198
|
class ArrayRecord(google.protobuf.message.Message):
|
199
199
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
200
|
+
class Item(google.protobuf.message.Message):
|
201
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
202
|
+
KEY_FIELD_NUMBER: builtins.int
|
203
|
+
VALUE_FIELD_NUMBER: builtins.int
|
204
|
+
key: typing.Text
|
205
|
+
@property
|
206
|
+
def value(self) -> global___Array: ...
|
207
|
+
def __init__(self,
|
208
|
+
*,
|
209
|
+
key: typing.Text = ...,
|
210
|
+
value: typing.Optional[global___Array] = ...,
|
211
|
+
) -> None: ...
|
212
|
+
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
213
|
+
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
214
|
+
|
215
|
+
ITEMS_FIELD_NUMBER: builtins.int
|
204
216
|
@property
|
205
|
-
def
|
217
|
+
def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ArrayRecord.Item]: ...
|
206
218
|
def __init__(self,
|
207
219
|
*,
|
208
|
-
|
209
|
-
data_values: typing.Optional[typing.Iterable[global___Array]] = ...,
|
220
|
+
items: typing.Optional[typing.Iterable[global___ArrayRecord.Item]] = ...,
|
210
221
|
) -> None: ...
|
211
|
-
def ClearField(self, field_name: typing_extensions.Literal["
|
222
|
+
def ClearField(self, field_name: typing_extensions.Literal["items",b"items"]) -> None: ...
|
212
223
|
global___ArrayRecord = ArrayRecord
|
213
224
|
|
214
225
|
class MetricRecord(google.protobuf.message.Message):
|
215
226
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
216
|
-
class
|
227
|
+
class Item(google.protobuf.message.Message):
|
217
228
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
218
229
|
KEY_FIELD_NUMBER: builtins.int
|
219
230
|
VALUE_FIELD_NUMBER: builtins.int
|
@@ -228,19 +239,19 @@ class MetricRecord(google.protobuf.message.Message):
|
|
228
239
|
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
229
240
|
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
230
241
|
|
231
|
-
|
242
|
+
ITEMS_FIELD_NUMBER: builtins.int
|
232
243
|
@property
|
233
|
-
def
|
244
|
+
def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___MetricRecord.Item]: ...
|
234
245
|
def __init__(self,
|
235
246
|
*,
|
236
|
-
|
247
|
+
items: typing.Optional[typing.Iterable[global___MetricRecord.Item]] = ...,
|
237
248
|
) -> None: ...
|
238
|
-
def ClearField(self, field_name: typing_extensions.Literal["
|
249
|
+
def ClearField(self, field_name: typing_extensions.Literal["items",b"items"]) -> None: ...
|
239
250
|
global___MetricRecord = MetricRecord
|
240
251
|
|
241
252
|
class ConfigRecord(google.protobuf.message.Message):
|
242
253
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
243
|
-
class
|
254
|
+
class Item(google.protobuf.message.Message):
|
244
255
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
245
256
|
KEY_FIELD_NUMBER: builtins.int
|
246
257
|
VALUE_FIELD_NUMBER: builtins.int
|
@@ -255,77 +266,48 @@ class ConfigRecord(google.protobuf.message.Message):
|
|
255
266
|
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
256
267
|
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
257
268
|
|
258
|
-
|
269
|
+
ITEMS_FIELD_NUMBER: builtins.int
|
259
270
|
@property
|
260
|
-
def
|
271
|
+
def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ConfigRecord.Item]: ...
|
261
272
|
def __init__(self,
|
262
273
|
*,
|
263
|
-
|
274
|
+
items: typing.Optional[typing.Iterable[global___ConfigRecord.Item]] = ...,
|
264
275
|
) -> None: ...
|
265
|
-
def ClearField(self, field_name: typing_extensions.Literal["
|
276
|
+
def ClearField(self, field_name: typing_extensions.Literal["items",b"items"]) -> None: ...
|
266
277
|
global___ConfigRecord = ConfigRecord
|
267
278
|
|
268
279
|
class RecordDict(google.protobuf.message.Message):
|
269
280
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
270
|
-
class
|
281
|
+
class Item(google.protobuf.message.Message):
|
271
282
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
272
283
|
KEY_FIELD_NUMBER: builtins.int
|
273
|
-
|
284
|
+
ARRAY_RECORD_FIELD_NUMBER: builtins.int
|
285
|
+
METRIC_RECORD_FIELD_NUMBER: builtins.int
|
286
|
+
CONFIG_RECORD_FIELD_NUMBER: builtins.int
|
274
287
|
key: typing.Text
|
275
288
|
@property
|
276
|
-
def
|
277
|
-
def __init__(self,
|
278
|
-
*,
|
279
|
-
key: typing.Text = ...,
|
280
|
-
value: typing.Optional[global___ArrayRecord] = ...,
|
281
|
-
) -> None: ...
|
282
|
-
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
283
|
-
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
284
|
-
|
285
|
-
class MetricsEntry(google.protobuf.message.Message):
|
286
|
-
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
287
|
-
KEY_FIELD_NUMBER: builtins.int
|
288
|
-
VALUE_FIELD_NUMBER: builtins.int
|
289
|
-
key: typing.Text
|
289
|
+
def array_record(self) -> global___ArrayRecord: ...
|
290
290
|
@property
|
291
|
-
def
|
292
|
-
def __init__(self,
|
293
|
-
*,
|
294
|
-
key: typing.Text = ...,
|
295
|
-
value: typing.Optional[global___MetricRecord] = ...,
|
296
|
-
) -> None: ...
|
297
|
-
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
298
|
-
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
299
|
-
|
300
|
-
class ConfigsEntry(google.protobuf.message.Message):
|
301
|
-
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
302
|
-
KEY_FIELD_NUMBER: builtins.int
|
303
|
-
VALUE_FIELD_NUMBER: builtins.int
|
304
|
-
key: typing.Text
|
291
|
+
def metric_record(self) -> global___MetricRecord: ...
|
305
292
|
@property
|
306
|
-
def
|
293
|
+
def config_record(self) -> global___ConfigRecord: ...
|
307
294
|
def __init__(self,
|
308
295
|
*,
|
309
296
|
key: typing.Text = ...,
|
310
|
-
|
297
|
+
array_record: typing.Optional[global___ArrayRecord] = ...,
|
298
|
+
metric_record: typing.Optional[global___MetricRecord] = ...,
|
299
|
+
config_record: typing.Optional[global___ConfigRecord] = ...,
|
311
300
|
) -> None: ...
|
312
|
-
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
313
|
-
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
301
|
+
def HasField(self, field_name: typing_extensions.Literal["array_record",b"array_record","config_record",b"config_record","metric_record",b"metric_record","value",b"value"]) -> builtins.bool: ...
|
302
|
+
def ClearField(self, field_name: typing_extensions.Literal["array_record",b"array_record","config_record",b"config_record","key",b"key","metric_record",b"metric_record","value",b"value"]) -> None: ...
|
303
|
+
def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["array_record","metric_record","config_record"]]: ...
|
314
304
|
|
315
|
-
|
316
|
-
METRICS_FIELD_NUMBER: builtins.int
|
317
|
-
CONFIGS_FIELD_NUMBER: builtins.int
|
318
|
-
@property
|
319
|
-
def arrays(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ArrayRecord]: ...
|
320
|
-
@property
|
321
|
-
def metrics(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricRecord]: ...
|
305
|
+
ITEMS_FIELD_NUMBER: builtins.int
|
322
306
|
@property
|
323
|
-
def
|
307
|
+
def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RecordDict.Item]: ...
|
324
308
|
def __init__(self,
|
325
309
|
*,
|
326
|
-
|
327
|
-
metrics: typing.Optional[typing.Mapping[typing.Text, global___MetricRecord]] = ...,
|
328
|
-
configs: typing.Optional[typing.Mapping[typing.Text, global___ConfigRecord]] = ...,
|
310
|
+
items: typing.Optional[typing.Iterable[global___RecordDict.Item]] = ...,
|
329
311
|
) -> None: ...
|
330
|
-
def ClearField(self, field_name: typing_extensions.Literal["
|
312
|
+
def ClearField(self, field_name: typing_extensions.Literal["items",b"items"]) -> None: ...
|
331
313
|
global___RecordDict = RecordDict
|
flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
CHANGED
@@ -15,11 +15,12 @@
|
|
15
15
|
"""Fleet API gRPC request-response servicer."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import DEBUG, INFO
|
18
|
+
from logging import DEBUG, ERROR, INFO
|
19
19
|
|
20
20
|
import grpc
|
21
21
|
from google.protobuf.json_format import MessageToDict
|
22
22
|
|
23
|
+
from flwr.common.constant import Status
|
23
24
|
from flwr.common.inflatable import check_body_len_consistency
|
24
25
|
from flwr.common.logger import log
|
25
26
|
from flwr.common.typing import InvalidRunStatusException
|
@@ -49,8 +50,9 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
|
|
49
50
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
50
51
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
51
52
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
52
|
-
from flwr.server.superlink.utils import abort_grpc_context
|
53
|
+
from flwr.server.superlink.utils import abort_grpc_context, check_abort
|
53
54
|
from flwr.supercore.object_store import ObjectStoreFactory
|
55
|
+
from flwr.supercore.object_store.object_store import NoObjectInStoreError
|
54
56
|
|
55
57
|
|
56
58
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
@@ -183,11 +185,39 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
183
185
|
request.object_id,
|
184
186
|
)
|
185
187
|
|
188
|
+
state = self.state_factory.state()
|
189
|
+
|
190
|
+
# Abort if the run is not running
|
191
|
+
abort_msg = check_abort(
|
192
|
+
request.run_id,
|
193
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
194
|
+
state,
|
195
|
+
)
|
196
|
+
if abort_msg:
|
197
|
+
abort_grpc_context(abort_msg, context)
|
198
|
+
|
199
|
+
if request.node.node_id not in state.get_nodes(run_id=request.run_id):
|
200
|
+
# Cancel insertion in ObjectStore
|
201
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
202
|
+
|
186
203
|
if not check_body_len_consistency(request.object_content):
|
187
204
|
# Cancel insertion in ObjectStore
|
188
|
-
context.abort(
|
205
|
+
context.abort(
|
206
|
+
grpc.StatusCode.FAILED_PRECONDITION, "Unexpected object length"
|
207
|
+
)
|
208
|
+
|
209
|
+
# Init store
|
210
|
+
store = self.objectstore_factory.store()
|
211
|
+
|
212
|
+
# Insert in store
|
213
|
+
stored = False
|
214
|
+
try:
|
215
|
+
store.put(request.object_id, request.object_content)
|
216
|
+
stored = True
|
217
|
+
except (NoObjectInStoreError, ValueError) as e:
|
218
|
+
log(ERROR, str(e))
|
189
219
|
|
190
|
-
return PushObjectResponse()
|
220
|
+
return PushObjectResponse(stored=stored)
|
191
221
|
|
192
222
|
def PullObject(
|
193
223
|
self, request: PullObjectRequest, context: grpc.ServicerContext
|
@@ -199,4 +229,31 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
199
229
|
request.object_id,
|
200
230
|
)
|
201
231
|
|
202
|
-
|
232
|
+
state = self.state_factory.state()
|
233
|
+
|
234
|
+
# Abort if the run is not running
|
235
|
+
abort_msg = check_abort(
|
236
|
+
request.run_id,
|
237
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
238
|
+
state,
|
239
|
+
)
|
240
|
+
if abort_msg:
|
241
|
+
abort_grpc_context(abort_msg, context)
|
242
|
+
|
243
|
+
if request.node.node_id not in state.get_nodes(run_id=request.run_id):
|
244
|
+
# Cancel insertion in ObjectStore
|
245
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
246
|
+
|
247
|
+
# Init store
|
248
|
+
store = self.objectstore_factory.store()
|
249
|
+
|
250
|
+
# Fetch from store
|
251
|
+
content = store.get(request.object_id)
|
252
|
+
if content is not None:
|
253
|
+
object_available = content != b""
|
254
|
+
return PullObjectResponse(
|
255
|
+
object_found=True,
|
256
|
+
object_available=object_available,
|
257
|
+
object_content=content,
|
258
|
+
)
|
259
|
+
return PullObjectResponse(object_found=False, object_available=False)
|
flwr/server/superlink/fleet/rest_rere/rest_api.py
CHANGED
@@ -114,7 +114,7 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
|
114
114
|
"""Pull PullMessages."""
|
115
115
|
# Get state from app
|
116
116
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
117
|
-
store: ObjectStore = cast(ObjectStoreFactory, app.state.
|
117
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
118
118
|
|
119
119
|
# Handle message
|
120
120
|
return message_handler.pull_messages(request=request, state=state, store=store)
|
@@ -125,7 +125,7 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
|
125
125
|
"""Pull PushMessages."""
|
126
126
|
# Get state from app
|
127
127
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
128
|
-
store: ObjectStore = cast(ObjectStoreFactory, app.state.
|
128
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
129
129
|
|
130
130
|
# Handle message
|
131
131
|
return message_handler.push_messages(request=request, state=state, store=store)
|
flwr/server/superlink/serverappio/serverappio_servicer.py
CHANGED
@@ -409,11 +409,39 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
409
409
|
"""Push an object to the ObjectStore."""
|
410
410
|
log(DEBUG, "ServerAppIoServicer.PushObject")
|
411
411
|
|
412
|
+
# Init state
|
413
|
+
state: LinkState = self.state_factory.state()
|
414
|
+
|
415
|
+
# Abort if the run is not running
|
416
|
+
abort_if(
|
417
|
+
request.run_id,
|
418
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
419
|
+
state,
|
420
|
+
context,
|
421
|
+
)
|
422
|
+
|
423
|
+
if request.node.node_id != SUPERLINK_NODE_ID:
|
424
|
+
# Cancel insertion in ObjectStore
|
425
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
426
|
+
|
412
427
|
if not check_body_len_consistency(request.object_content):
|
413
428
|
# Cancel insertion in ObjectStore
|
414
|
-
context.abort(
|
429
|
+
context.abort(
|
430
|
+
grpc.StatusCode.FAILED_PRECONDITION, "Unexpected object length."
|
431
|
+
)
|
432
|
+
|
433
|
+
# Init store
|
434
|
+
store = self.objectstore_factory.store()
|
435
|
+
|
436
|
+
# Insert in store
|
437
|
+
stored = False
|
438
|
+
try:
|
439
|
+
store.put(request.object_id, request.object_content)
|
440
|
+
stored = True
|
441
|
+
except (NoObjectInStoreError, ValueError) as e:
|
442
|
+
log(ERROR, str(e))
|
415
443
|
|
416
|
-
return PushObjectResponse()
|
444
|
+
return PushObjectResponse(stored=stored)
|
417
445
|
|
418
446
|
def PullObject(
|
419
447
|
self, request: PullObjectRequest, context: grpc.ServicerContext
|
@@ -421,7 +449,34 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
421
449
|
"""Pull an object from the ObjectStore."""
|
422
450
|
log(DEBUG, "ServerAppIoServicer.PullObject")
|
423
451
|
|
424
|
-
|
452
|
+
# Init state
|
453
|
+
state: LinkState = self.state_factory.state()
|
454
|
+
|
455
|
+
# Abort if the run is not running
|
456
|
+
abort_if(
|
457
|
+
request.run_id,
|
458
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
459
|
+
state,
|
460
|
+
context,
|
461
|
+
)
|
462
|
+
|
463
|
+
if request.node.node_id != SUPERLINK_NODE_ID:
|
464
|
+
# Cancel insertion in ObjectStore
|
465
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
466
|
+
|
467
|
+
# Init store
|
468
|
+
store = self.objectstore_factory.store()
|
469
|
+
|
470
|
+
# Fetch from store
|
471
|
+
content = store.get(request.object_id)
|
472
|
+
if content is not None:
|
473
|
+
object_available = content != b""
|
474
|
+
return PullObjectResponse(
|
475
|
+
object_found=True,
|
476
|
+
object_available=object_available,
|
477
|
+
object_content=content,
|
478
|
+
)
|
479
|
+
return PullObjectResponse(object_found=False, object_available=False)
|
425
480
|
|
426
481
|
|
427
482
|
def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
|
flwr/supernode/cli/__init__.py
CHANGED
flwr/supernode/cli/flower_supernode.py
CHANGED
@@ -42,8 +42,7 @@ from flwr.common.constant import (
|
|
42
42
|
from flwr.common.exit import ExitCode, flwr_exit
|
43
43
|
from flwr.common.exit_handlers import register_exit_handlers
|
44
44
|
from flwr.common.logger import log
|
45
|
-
|
46
|
-
from ..start_client_internal import start_client_internal
|
45
|
+
from flwr.supernode.start_client_internal import start_client_internal
|
47
46
|
|
48
47
|
|
49
48
|
def flower_supernode() -> None:
|
flwr/supernode/cli/flwr_clientapp.py
ADDED
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""`flwr-clientapp` command."""
|
16
|
+
|
17
|
+
|
18
|
+
import argparse
|
19
|
+
from logging import DEBUG, INFO
|
20
|
+
|
21
|
+
from flwr.common.args import add_args_flwr_app_common
|
22
|
+
from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS
|
23
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
24
|
+
from flwr.common.logger import log
|
25
|
+
from flwr.supernode.runtime.run_clientapp import run_clientapp
|
26
|
+
|
27
|
+
|
28
|
+
def flwr_clientapp() -> None:
|
29
|
+
"""Run process-isolated Flower ClientApp."""
|
30
|
+
args = _parse_args_run_flwr_clientapp().parse_args()
|
31
|
+
if not args.insecure:
|
32
|
+
flwr_exit(
|
33
|
+
ExitCode.COMMON_TLS_NOT_SUPPORTED,
|
34
|
+
"flwr-clientapp does not support TLS yet.",
|
35
|
+
)
|
36
|
+
|
37
|
+
log(INFO, "Start `flwr-clientapp` process")
|
38
|
+
log(
|
39
|
+
DEBUG,
|
40
|
+
"`flwr-clientapp` will attempt to connect to SuperNode's "
|
41
|
+
"ClientAppIo API at %s with token %s",
|
42
|
+
args.clientappio_api_address,
|
43
|
+
args.token,
|
44
|
+
)
|
45
|
+
run_clientapp(
|
46
|
+
clientappio_api_address=args.clientappio_api_address,
|
47
|
+
run_once=(args.token is not None),
|
48
|
+
token=args.token,
|
49
|
+
flwr_dir=args.flwr_dir,
|
50
|
+
certificates=None,
|
51
|
+
)
|
52
|
+
|
53
|
+
|
54
|
+
def _parse_args_run_flwr_clientapp() -> argparse.ArgumentParser:
|
55
|
+
"""Parse flwr-clientapp command line arguments."""
|
56
|
+
parser = argparse.ArgumentParser(
|
57
|
+
description="Run a Flower ClientApp",
|
58
|
+
)
|
59
|
+
parser.add_argument(
|
60
|
+
"--clientappio-api-address",
|
61
|
+
default=CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
62
|
+
type=str,
|
63
|
+
help="Address of SuperNode's ClientAppIo API (IPv4, IPv6, or a domain name)."
|
64
|
+
f"By default, it is set to {CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS}.",
|
65
|
+
)
|
66
|
+
parser.add_argument(
|
67
|
+
"--token",
|
68
|
+
type=int,
|
69
|
+
required=False,
|
70
|
+
help="Unique token generated by SuperNode for each ClientApp execution",
|
71
|
+
)
|
72
|
+
add_args_flwr_app_common(parser=parser)
|
73
|
+
return parser
|
flwr/supernode/nodestate/in_memory_nodestate.py
CHANGED
@@ -15,17 +15,40 @@
|
|
15
15
|
"""In-memory NodeState implementation."""
|
16
16
|
|
17
17
|
|
18
|
+
from collections.abc import Sequence
|
19
|
+
from dataclasses import dataclass
|
20
|
+
from threading import Lock
|
18
21
|
from typing import Optional
|
19
22
|
|
23
|
+
from flwr.common import Context, Message
|
24
|
+
from flwr.common.typing import Run
|
25
|
+
|
20
26
|
from .nodestate import NodeState
|
21
27
|
|
22
28
|
|
29
|
+
@dataclass
|
30
|
+
class MessageEntry:
|
31
|
+
"""Data class to represent a message entry."""
|
32
|
+
|
33
|
+
message: Message
|
34
|
+
is_retrieved: bool = False
|
35
|
+
|
36
|
+
|
23
37
|
class InMemoryNodeState(NodeState):
|
24
38
|
"""In-memory NodeState implementation."""
|
25
39
|
|
26
40
|
def __init__(self) -> None:
|
27
41
|
# Store node_id
|
28
42
|
self.node_id: Optional[int] = None
|
43
|
+
# Store Object ID to MessageEntry mapping
|
44
|
+
self.msg_store: dict[str, MessageEntry] = {}
|
45
|
+
self.lock_msg_store = Lock()
|
46
|
+
# Store run ID to Run mapping
|
47
|
+
self.run_store: dict[int, Run] = {}
|
48
|
+
self.lock_run_store = Lock()
|
49
|
+
# Store run ID to Context mapping
|
50
|
+
self.ctx_store: dict[int, Context] = {}
|
51
|
+
self.lock_ctx_store = Lock()
|
29
52
|
|
30
53
|
def set_node_id(self, node_id: Optional[int]) -> None:
|
31
54
|
"""Set the node ID."""
|
@@ -36,3 +59,92 @@ class InMemoryNodeState(NodeState):
|
|
36
59
|
if self.node_id is None:
|
37
60
|
raise ValueError("Node ID not set")
|
38
61
|
return self.node_id
|
62
|
+
|
63
|
+
def store_message(self, message: Message) -> Optional[str]:
|
64
|
+
"""Store a message."""
|
65
|
+
with self.lock_msg_store:
|
66
|
+
msg_id = message.metadata.message_id
|
67
|
+
if msg_id == "" or msg_id in self.msg_store:
|
68
|
+
return None
|
69
|
+
self.msg_store[msg_id] = MessageEntry(message=message)
|
70
|
+
return msg_id
|
71
|
+
|
72
|
+
def get_messages(
|
73
|
+
self,
|
74
|
+
*,
|
75
|
+
run_ids: Optional[Sequence[int]] = None,
|
76
|
+
is_reply: Optional[bool] = None,
|
77
|
+
limit: Optional[int] = None,
|
78
|
+
) -> Sequence[Message]:
|
79
|
+
"""Retrieve messages based on the specified filters."""
|
80
|
+
selected_messages: list[Message] = []
|
81
|
+
|
82
|
+
with self.lock_msg_store:
|
83
|
+
# Iterate through all messages in the store
|
84
|
+
for object_id in list(self.msg_store.keys()):
|
85
|
+
entry = self.msg_store[object_id]
|
86
|
+
message = entry.message
|
87
|
+
|
88
|
+
# Skip messages that have already been retrieved
|
89
|
+
if entry.is_retrieved:
|
90
|
+
continue
|
91
|
+
|
92
|
+
# Skip messages whose run_id doesn't match the filter
|
93
|
+
if run_ids is not None:
|
94
|
+
if message.metadata.run_id not in run_ids:
|
95
|
+
continue
|
96
|
+
|
97
|
+
# If is_reply filter is set, filter for reply/non-reply messages
|
98
|
+
if is_reply is not None:
|
99
|
+
is_reply_message = message.metadata.reply_to_message_id != ""
|
100
|
+
# XOR logic to filter mismatched types (reply vs non-reply)
|
101
|
+
if is_reply ^ is_reply_message:
|
102
|
+
continue
|
103
|
+
|
104
|
+
# Add the message to the result set
|
105
|
+
selected_messages.append(message)
|
106
|
+
|
107
|
+
# Mark the message as retrieved
|
108
|
+
entry.is_retrieved = True
|
109
|
+
|
110
|
+
# Stop if the number of collected messages reaches the limit
|
111
|
+
if limit is not None and len(selected_messages) >= limit:
|
112
|
+
break
|
113
|
+
|
114
|
+
return selected_messages
|
115
|
+
|
116
|
+
def delete_messages(
|
117
|
+
self,
|
118
|
+
*,
|
119
|
+
message_ids: Optional[Sequence[str]] = None,
|
120
|
+
) -> None:
|
121
|
+
"""Delete messages based on the specified filters."""
|
122
|
+
with self.lock_msg_store:
|
123
|
+
if message_ids is None:
|
124
|
+
# If no message IDs are provided, clear the entire store
|
125
|
+
self.msg_store.clear()
|
126
|
+
return
|
127
|
+
|
128
|
+
# Remove specified messages from the store
|
129
|
+
for msg_id in message_ids:
|
130
|
+
self.msg_store.pop(msg_id, None)
|
131
|
+
|
132
|
+
def store_run(self, run: Run) -> None:
|
133
|
+
"""Store a run."""
|
134
|
+
with self.lock_run_store:
|
135
|
+
self.run_store[run.run_id] = run
|
136
|
+
|
137
|
+
def get_run(self, run_id: int) -> Optional[Run]:
|
138
|
+
"""Retrieve a run by its ID."""
|
139
|
+
with self.lock_run_store:
|
140
|
+
return self.run_store.get(run_id)
|
141
|
+
|
142
|
+
def store_context(self, context: Context) -> None:
|
143
|
+
"""Store a context."""
|
144
|
+
with self.lock_ctx_store:
|
145
|
+
self.ctx_store[context.run_id] = context
|
146
|
+
|
147
|
+
def get_context(self, run_id: int) -> Optional[Context]:
|
148
|
+
"""Retrieve a context by its run ID."""
|
149
|
+
with self.lock_ctx_store:
|
150
|
+
return self.ctx_store.get(run_id)
|