flwr-nightly 1.19.0.dev20250601__py3-none-any.whl → 1.19.0.dev20250603__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. flwr/client/clientapp/__init__.py +0 -7
  2. flwr/client/grpc_rere_client/connection.py +9 -18
  3. flwr/common/inflatable.py +8 -2
  4. flwr/common/inflatable_grpc_utils.py +9 -5
  5. flwr/common/record/configrecord.py +9 -8
  6. flwr/common/record/metricrecord.py +6 -5
  7. flwr/common/retry_invoker.py +5 -1
  8. flwr/common/serde.py +39 -28
  9. flwr/common/serde_utils.py +2 -0
  10. flwr/proto/message_pb2.py +8 -8
  11. flwr/proto/message_pb2.pyi +20 -3
  12. flwr/proto/recorddict_pb2.py +16 -28
  13. flwr/proto/recorddict_pb2.pyi +46 -64
  14. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +62 -5
  15. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  16. flwr/server/superlink/serverappio/serverappio_servicer.py +58 -3
  17. flwr/supernode/cli/__init__.py +5 -1
  18. flwr/supernode/cli/flower_supernode.py +1 -2
  19. flwr/supernode/cli/flwr_clientapp.py +73 -0
  20. flwr/supernode/nodestate/in_memory_nodestate.py +112 -0
  21. flwr/supernode/nodestate/nodestate.py +132 -6
  22. flwr/supernode/runtime/__init__.py +15 -0
  23. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +2 -54
  24. flwr/supernode/servicer/__init__.py +15 -0
  25. flwr/supernode/servicer/clientappio/__init__.py +24 -0
  26. flwr/supernode/start_client_internal.py +24 -20
  27. {flwr_nightly-1.19.0.dev20250601.dist-info → flwr_nightly-1.19.0.dev20250603.dist-info}/METADATA +1 -1
  28. {flwr_nightly-1.19.0.dev20250601.dist-info → flwr_nightly-1.19.0.dev20250603.dist-info}/RECORD +31 -27
  29. {flwr_nightly-1.19.0.dev20250601.dist-info → flwr_nightly-1.19.0.dev20250603.dist-info}/entry_points.txt +1 -1
  30. /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
  31. {flwr_nightly-1.19.0.dev20250601.dist-info → flwr_nightly-1.19.0.dev20250603.dist-info}/WHEEL +0 -0
@@ -197,23 +197,34 @@ global___ConfigRecordValue = ConfigRecordValue
197
197
 
198
198
  class ArrayRecord(google.protobuf.message.Message):
199
199
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
200
- DATA_KEYS_FIELD_NUMBER: builtins.int
201
- DATA_VALUES_FIELD_NUMBER: builtins.int
202
- @property
203
- def data_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
200
+ class Item(google.protobuf.message.Message):
201
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
202
+ KEY_FIELD_NUMBER: builtins.int
203
+ VALUE_FIELD_NUMBER: builtins.int
204
+ key: typing.Text
205
+ @property
206
+ def value(self) -> global___Array: ...
207
+ def __init__(self,
208
+ *,
209
+ key: typing.Text = ...,
210
+ value: typing.Optional[global___Array] = ...,
211
+ ) -> None: ...
212
+ def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
213
+ def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
214
+
215
+ ITEMS_FIELD_NUMBER: builtins.int
204
216
  @property
205
- def data_values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Array]: ...
217
+ def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ArrayRecord.Item]: ...
206
218
  def __init__(self,
207
219
  *,
208
- data_keys: typing.Optional[typing.Iterable[typing.Text]] = ...,
209
- data_values: typing.Optional[typing.Iterable[global___Array]] = ...,
220
+ items: typing.Optional[typing.Iterable[global___ArrayRecord.Item]] = ...,
210
221
  ) -> None: ...
211
- def ClearField(self, field_name: typing_extensions.Literal["data_keys",b"data_keys","data_values",b"data_values"]) -> None: ...
222
+ def ClearField(self, field_name: typing_extensions.Literal["items",b"items"]) -> None: ...
212
223
  global___ArrayRecord = ArrayRecord
213
224
 
214
225
  class MetricRecord(google.protobuf.message.Message):
215
226
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
216
- class DataEntry(google.protobuf.message.Message):
227
+ class Item(google.protobuf.message.Message):
217
228
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
218
229
  KEY_FIELD_NUMBER: builtins.int
219
230
  VALUE_FIELD_NUMBER: builtins.int
@@ -228,19 +239,19 @@ class MetricRecord(google.protobuf.message.Message):
228
239
  def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
229
240
  def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
230
241
 
231
- DATA_FIELD_NUMBER: builtins.int
242
+ ITEMS_FIELD_NUMBER: builtins.int
232
243
  @property
233
- def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricRecordValue]: ...
244
+ def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___MetricRecord.Item]: ...
234
245
  def __init__(self,
235
246
  *,
236
- data: typing.Optional[typing.Mapping[typing.Text, global___MetricRecordValue]] = ...,
247
+ items: typing.Optional[typing.Iterable[global___MetricRecord.Item]] = ...,
237
248
  ) -> None: ...
238
- def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ...
249
+ def ClearField(self, field_name: typing_extensions.Literal["items",b"items"]) -> None: ...
239
250
  global___MetricRecord = MetricRecord
240
251
 
241
252
  class ConfigRecord(google.protobuf.message.Message):
242
253
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
243
- class DataEntry(google.protobuf.message.Message):
254
+ class Item(google.protobuf.message.Message):
244
255
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
245
256
  KEY_FIELD_NUMBER: builtins.int
246
257
  VALUE_FIELD_NUMBER: builtins.int
@@ -255,77 +266,48 @@ class ConfigRecord(google.protobuf.message.Message):
255
266
  def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
256
267
  def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
257
268
 
258
- DATA_FIELD_NUMBER: builtins.int
269
+ ITEMS_FIELD_NUMBER: builtins.int
259
270
  @property
260
- def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigRecordValue]: ...
271
+ def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ConfigRecord.Item]: ...
261
272
  def __init__(self,
262
273
  *,
263
- data: typing.Optional[typing.Mapping[typing.Text, global___ConfigRecordValue]] = ...,
274
+ items: typing.Optional[typing.Iterable[global___ConfigRecord.Item]] = ...,
264
275
  ) -> None: ...
265
- def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ...
276
+ def ClearField(self, field_name: typing_extensions.Literal["items",b"items"]) -> None: ...
266
277
  global___ConfigRecord = ConfigRecord
267
278
 
268
279
  class RecordDict(google.protobuf.message.Message):
269
280
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
270
- class ArraysEntry(google.protobuf.message.Message):
281
+ class Item(google.protobuf.message.Message):
271
282
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
272
283
  KEY_FIELD_NUMBER: builtins.int
273
- VALUE_FIELD_NUMBER: builtins.int
284
+ ARRAY_RECORD_FIELD_NUMBER: builtins.int
285
+ METRIC_RECORD_FIELD_NUMBER: builtins.int
286
+ CONFIG_RECORD_FIELD_NUMBER: builtins.int
274
287
  key: typing.Text
275
288
  @property
276
- def value(self) -> global___ArrayRecord: ...
277
- def __init__(self,
278
- *,
279
- key: typing.Text = ...,
280
- value: typing.Optional[global___ArrayRecord] = ...,
281
- ) -> None: ...
282
- def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
283
- def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
284
-
285
- class MetricsEntry(google.protobuf.message.Message):
286
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
287
- KEY_FIELD_NUMBER: builtins.int
288
- VALUE_FIELD_NUMBER: builtins.int
289
- key: typing.Text
289
+ def array_record(self) -> global___ArrayRecord: ...
290
290
  @property
291
- def value(self) -> global___MetricRecord: ...
292
- def __init__(self,
293
- *,
294
- key: typing.Text = ...,
295
- value: typing.Optional[global___MetricRecord] = ...,
296
- ) -> None: ...
297
- def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
298
- def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
299
-
300
- class ConfigsEntry(google.protobuf.message.Message):
301
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
302
- KEY_FIELD_NUMBER: builtins.int
303
- VALUE_FIELD_NUMBER: builtins.int
304
- key: typing.Text
291
+ def metric_record(self) -> global___MetricRecord: ...
305
292
  @property
306
- def value(self) -> global___ConfigRecord: ...
293
+ def config_record(self) -> global___ConfigRecord: ...
307
294
  def __init__(self,
308
295
  *,
309
296
  key: typing.Text = ...,
310
- value: typing.Optional[global___ConfigRecord] = ...,
297
+ array_record: typing.Optional[global___ArrayRecord] = ...,
298
+ metric_record: typing.Optional[global___MetricRecord] = ...,
299
+ config_record: typing.Optional[global___ConfigRecord] = ...,
311
300
  ) -> None: ...
312
- def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
313
- def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
301
+ def HasField(self, field_name: typing_extensions.Literal["array_record",b"array_record","config_record",b"config_record","metric_record",b"metric_record","value",b"value"]) -> builtins.bool: ...
302
+ def ClearField(self, field_name: typing_extensions.Literal["array_record",b"array_record","config_record",b"config_record","key",b"key","metric_record",b"metric_record","value",b"value"]) -> None: ...
303
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["array_record","metric_record","config_record"]]: ...
314
304
 
315
- ARRAYS_FIELD_NUMBER: builtins.int
316
- METRICS_FIELD_NUMBER: builtins.int
317
- CONFIGS_FIELD_NUMBER: builtins.int
318
- @property
319
- def arrays(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ArrayRecord]: ...
320
- @property
321
- def metrics(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricRecord]: ...
305
+ ITEMS_FIELD_NUMBER: builtins.int
322
306
  @property
323
- def configs(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigRecord]: ...
307
+ def items(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RecordDict.Item]: ...
324
308
  def __init__(self,
325
309
  *,
326
- arrays: typing.Optional[typing.Mapping[typing.Text, global___ArrayRecord]] = ...,
327
- metrics: typing.Optional[typing.Mapping[typing.Text, global___MetricRecord]] = ...,
328
- configs: typing.Optional[typing.Mapping[typing.Text, global___ConfigRecord]] = ...,
310
+ items: typing.Optional[typing.Iterable[global___RecordDict.Item]] = ...,
329
311
  ) -> None: ...
330
- def ClearField(self, field_name: typing_extensions.Literal["arrays",b"arrays","configs",b"configs","metrics",b"metrics"]) -> None: ...
312
+ def ClearField(self, field_name: typing_extensions.Literal["items",b"items"]) -> None: ...
331
313
  global___RecordDict = RecordDict
@@ -15,11 +15,12 @@
15
15
  """Fleet API gRPC request-response servicer."""
16
16
 
17
17
 
18
- from logging import DEBUG, INFO
18
+ from logging import DEBUG, ERROR, INFO
19
19
 
20
20
  import grpc
21
21
  from google.protobuf.json_format import MessageToDict
22
22
 
23
+ from flwr.common.constant import Status
23
24
  from flwr.common.inflatable import check_body_len_consistency
24
25
  from flwr.common.logger import log
25
26
  from flwr.common.typing import InvalidRunStatusException
@@ -49,8 +50,9 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
49
50
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
50
51
  from flwr.server.superlink.fleet.message_handler import message_handler
51
52
  from flwr.server.superlink.linkstate import LinkStateFactory
52
- from flwr.server.superlink.utils import abort_grpc_context
53
+ from flwr.server.superlink.utils import abort_grpc_context, check_abort
53
54
  from flwr.supercore.object_store import ObjectStoreFactory
55
+ from flwr.supercore.object_store.object_store import NoObjectInStoreError
54
56
 
55
57
 
56
58
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
@@ -183,11 +185,39 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
183
185
  request.object_id,
184
186
  )
185
187
 
188
+ state = self.state_factory.state()
189
+
190
+ # Abort if the run is not running
191
+ abort_msg = check_abort(
192
+ request.run_id,
193
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
194
+ state,
195
+ )
196
+ if abort_msg:
197
+ abort_grpc_context(abort_msg, context)
198
+
199
+ if request.node.node_id not in state.get_nodes(run_id=request.run_id):
200
+ # Cancel insertion in ObjectStore
201
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
202
+
186
203
  if not check_body_len_consistency(request.object_content):
187
204
  # Cancel insertion in ObjectStore
188
- context.abort(grpc.StatusCode.PERMISSION_DENIED, "Unexpected object length")
205
+ context.abort(
206
+ grpc.StatusCode.FAILED_PRECONDITION, "Unexpected object length"
207
+ )
208
+
209
+ # Init store
210
+ store = self.objectstore_factory.store()
211
+
212
+ # Insert in store
213
+ stored = False
214
+ try:
215
+ store.put(request.object_id, request.object_content)
216
+ stored = True
217
+ except (NoObjectInStoreError, ValueError) as e:
218
+ log(ERROR, str(e))
189
219
 
190
- return PushObjectResponse()
220
+ return PushObjectResponse(stored=stored)
191
221
 
192
222
  def PullObject(
193
223
  self, request: PullObjectRequest, context: grpc.ServicerContext
@@ -199,4 +229,31 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
199
229
  request.object_id,
200
230
  )
201
231
 
202
- return PullObjectResponse()
232
+ state = self.state_factory.state()
233
+
234
+ # Abort if the run is not running
235
+ abort_msg = check_abort(
236
+ request.run_id,
237
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
238
+ state,
239
+ )
240
+ if abort_msg:
241
+ abort_grpc_context(abort_msg, context)
242
+
243
+ if request.node.node_id not in state.get_nodes(run_id=request.run_id):
244
+ # Cancel insertion in ObjectStore
245
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
246
+
247
+ # Init store
248
+ store = self.objectstore_factory.store()
249
+
250
+ # Fetch from store
251
+ content = store.get(request.object_id)
252
+ if content is not None:
253
+ object_available = content != b""
254
+ return PullObjectResponse(
255
+ object_found=True,
256
+ object_available=object_available,
257
+ object_content=content,
258
+ )
259
+ return PullObjectResponse(object_found=False, object_available=False)
@@ -114,7 +114,7 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
114
114
  """Pull PullMessages."""
115
115
  # Get state from app
116
116
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
117
- store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
117
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
118
118
 
119
119
  # Handle message
120
120
  return message_handler.pull_messages(request=request, state=state, store=store)
@@ -125,7 +125,7 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
125
125
  """Pull PushMessages."""
126
126
  # Get state from app
127
127
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
128
- store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
128
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
129
129
 
130
130
  # Handle message
131
131
  return message_handler.push_messages(request=request, state=state, store=store)
@@ -409,11 +409,39 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
409
409
  """Push an object to the ObjectStore."""
410
410
  log(DEBUG, "ServerAppIoServicer.PushObject")
411
411
 
412
+ # Init state
413
+ state: LinkState = self.state_factory.state()
414
+
415
+ # Abort if the run is not running
416
+ abort_if(
417
+ request.run_id,
418
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
419
+ state,
420
+ context,
421
+ )
422
+
423
+ if request.node.node_id != SUPERLINK_NODE_ID:
424
+ # Cancel insertion in ObjectStore
425
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
426
+
412
427
  if not check_body_len_consistency(request.object_content):
413
428
  # Cancel insertion in ObjectStore
414
- context.abort(grpc.StatusCode.PERMISSION_DENIED, "Unexpected object length")
429
+ context.abort(
430
+ grpc.StatusCode.FAILED_PRECONDITION, "Unexpected object length."
431
+ )
432
+
433
+ # Init store
434
+ store = self.objectstore_factory.store()
435
+
436
+ # Insert in store
437
+ stored = False
438
+ try:
439
+ store.put(request.object_id, request.object_content)
440
+ stored = True
441
+ except (NoObjectInStoreError, ValueError) as e:
442
+ log(ERROR, str(e))
415
443
 
416
- return PushObjectResponse()
444
+ return PushObjectResponse(stored=stored)
417
445
 
418
446
  def PullObject(
419
447
  self, request: PullObjectRequest, context: grpc.ServicerContext
@@ -421,7 +449,34 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
421
449
  """Pull an object from the ObjectStore."""
422
450
  log(DEBUG, "ServerAppIoServicer.PullObject")
423
451
 
424
- return PullObjectResponse()
452
+ # Init state
453
+ state: LinkState = self.state_factory.state()
454
+
455
+ # Abort if the run is not running
456
+ abort_if(
457
+ request.run_id,
458
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
459
+ state,
460
+ context,
461
+ )
462
+
463
+ if request.node.node_id != SUPERLINK_NODE_ID:
464
+ # Cancel insertion in ObjectStore
465
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
466
+
467
+ # Init store
468
+ store = self.objectstore_factory.store()
469
+
470
+ # Fetch from store
471
+ content = store.get(request.object_id)
472
+ if content is not None:
473
+ object_available = content != b""
474
+ return PullObjectResponse(
475
+ object_found=True,
476
+ object_available=object_available,
477
+ object_content=content,
478
+ )
479
+ return PullObjectResponse(object_found=False, object_available=False)
425
480
 
426
481
 
427
482
  def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
@@ -16,5 +16,9 @@
16
16
 
17
17
 
18
18
  from .flower_supernode import flower_supernode
19
+ from .flwr_clientapp import flwr_clientapp
19
20
 
20
- __all__ = ["flower_supernode"]
21
+ __all__ = [
22
+ "flower_supernode",
23
+ "flwr_clientapp",
24
+ ]
@@ -42,8 +42,7 @@ from flwr.common.constant import (
42
42
  from flwr.common.exit import ExitCode, flwr_exit
43
43
  from flwr.common.exit_handlers import register_exit_handlers
44
44
  from flwr.common.logger import log
45
-
46
- from ..start_client_internal import start_client_internal
45
+ from flwr.supernode.start_client_internal import start_client_internal
47
46
 
48
47
 
49
48
  def flower_supernode() -> None:
@@ -0,0 +1,73 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """`flwr-clientapp` command."""
16
+
17
+
18
+ import argparse
19
+ from logging import DEBUG, INFO
20
+
21
+ from flwr.common.args import add_args_flwr_app_common
22
+ from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS
23
+ from flwr.common.exit import ExitCode, flwr_exit
24
+ from flwr.common.logger import log
25
+ from flwr.supernode.runtime.run_clientapp import run_clientapp
26
+
27
+
28
+ def flwr_clientapp() -> None:
29
+ """Run process-isolated Flower ClientApp."""
30
+ args = _parse_args_run_flwr_clientapp().parse_args()
31
+ if not args.insecure:
32
+ flwr_exit(
33
+ ExitCode.COMMON_TLS_NOT_SUPPORTED,
34
+ "flwr-clientapp does not support TLS yet.",
35
+ )
36
+
37
+ log(INFO, "Start `flwr-clientapp` process")
38
+ log(
39
+ DEBUG,
40
+ "`flwr-clientapp` will attempt to connect to SuperNode's "
41
+ "ClientAppIo API at %s with token %s",
42
+ args.clientappio_api_address,
43
+ args.token,
44
+ )
45
+ run_clientapp(
46
+ clientappio_api_address=args.clientappio_api_address,
47
+ run_once=(args.token is not None),
48
+ token=args.token,
49
+ flwr_dir=args.flwr_dir,
50
+ certificates=None,
51
+ )
52
+
53
+
54
+ def _parse_args_run_flwr_clientapp() -> argparse.ArgumentParser:
55
+ """Parse flwr-clientapp command line arguments."""
56
+ parser = argparse.ArgumentParser(
57
+ description="Run a Flower ClientApp",
58
+ )
59
+ parser.add_argument(
60
+ "--clientappio-api-address",
61
+ default=CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS,
62
+ type=str,
63
+ help="Address of SuperNode's ClientAppIo API (IPv4, IPv6, or a domain name)."
64
+ f"By default, it is set to {CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS}.",
65
+ )
66
+ parser.add_argument(
67
+ "--token",
68
+ type=int,
69
+ required=False,
70
+ help="Unique token generated by SuperNode for each ClientApp execution",
71
+ )
72
+ add_args_flwr_app_common(parser=parser)
73
+ return parser
@@ -15,17 +15,40 @@
15
15
  """In-memory NodeState implementation."""
16
16
 
17
17
 
18
+ from collections.abc import Sequence
19
+ from dataclasses import dataclass
20
+ from threading import Lock
18
21
  from typing import Optional
19
22
 
23
+ from flwr.common import Context, Message
24
+ from flwr.common.typing import Run
25
+
20
26
  from .nodestate import NodeState
21
27
 
22
28
 
29
+ @dataclass
30
+ class MessageEntry:
31
+ """Data class to represent a message entry."""
32
+
33
+ message: Message
34
+ is_retrieved: bool = False
35
+
36
+
23
37
  class InMemoryNodeState(NodeState):
24
38
  """In-memory NodeState implementation."""
25
39
 
26
40
  def __init__(self) -> None:
27
41
  # Store node_id
28
42
  self.node_id: Optional[int] = None
43
+ # Store Object ID to MessageEntry mapping
44
+ self.msg_store: dict[str, MessageEntry] = {}
45
+ self.lock_msg_store = Lock()
46
+ # Store run ID to Run mapping
47
+ self.run_store: dict[int, Run] = {}
48
+ self.lock_run_store = Lock()
49
+ # Store run ID to Context mapping
50
+ self.ctx_store: dict[int, Context] = {}
51
+ self.lock_ctx_store = Lock()
29
52
 
30
53
  def set_node_id(self, node_id: Optional[int]) -> None:
31
54
  """Set the node ID."""
@@ -36,3 +59,92 @@ class InMemoryNodeState(NodeState):
36
59
  if self.node_id is None:
37
60
  raise ValueError("Node ID not set")
38
61
  return self.node_id
62
+
63
+ def store_message(self, message: Message) -> Optional[str]:
64
+ """Store a message."""
65
+ with self.lock_msg_store:
66
+ msg_id = message.metadata.message_id
67
+ if msg_id == "" or msg_id in self.msg_store:
68
+ return None
69
+ self.msg_store[msg_id] = MessageEntry(message=message)
70
+ return msg_id
71
+
72
+ def get_messages(
73
+ self,
74
+ *,
75
+ run_ids: Optional[Sequence[int]] = None,
76
+ is_reply: Optional[bool] = None,
77
+ limit: Optional[int] = None,
78
+ ) -> Sequence[Message]:
79
+ """Retrieve messages based on the specified filters."""
80
+ selected_messages: list[Message] = []
81
+
82
+ with self.lock_msg_store:
83
+ # Iterate through all messages in the store
84
+ for object_id in list(self.msg_store.keys()):
85
+ entry = self.msg_store[object_id]
86
+ message = entry.message
87
+
88
+ # Skip messages that have already been retrieved
89
+ if entry.is_retrieved:
90
+ continue
91
+
92
+ # Skip messages whose run_id doesn't match the filter
93
+ if run_ids is not None:
94
+ if message.metadata.run_id not in run_ids:
95
+ continue
96
+
97
+ # If is_reply filter is set, filter for reply/non-reply messages
98
+ if is_reply is not None:
99
+ is_reply_message = message.metadata.reply_to_message_id != ""
100
+ # XOR logic to filter mismatched types (reply vs non-reply)
101
+ if is_reply ^ is_reply_message:
102
+ continue
103
+
104
+ # Add the message to the result set
105
+ selected_messages.append(message)
106
+
107
+ # Mark the message as retrieved
108
+ entry.is_retrieved = True
109
+
110
+ # Stop if the number of collected messages reaches the limit
111
+ if limit is not None and len(selected_messages) >= limit:
112
+ break
113
+
114
+ return selected_messages
115
+
116
+ def delete_messages(
117
+ self,
118
+ *,
119
+ message_ids: Optional[Sequence[str]] = None,
120
+ ) -> None:
121
+ """Delete messages based on the specified filters."""
122
+ with self.lock_msg_store:
123
+ if message_ids is None:
124
+ # If no message IDs are provided, clear the entire store
125
+ self.msg_store.clear()
126
+ return
127
+
128
+ # Remove specified messages from the store
129
+ for msg_id in message_ids:
130
+ self.msg_store.pop(msg_id, None)
131
+
132
+ def store_run(self, run: Run) -> None:
133
+ """Store a run."""
134
+ with self.lock_run_store:
135
+ self.run_store[run.run_id] = run
136
+
137
+ def get_run(self, run_id: int) -> Optional[Run]:
138
+ """Retrieve a run by its ID."""
139
+ with self.lock_run_store:
140
+ return self.run_store.get(run_id)
141
+
142
+ def store_context(self, context: Context) -> None:
143
+ """Store a context."""
144
+ with self.lock_ctx_store:
145
+ self.ctx_store[context.run_id] = context
146
+
147
+ def get_context(self, run_id: int) -> Optional[Context]:
148
+ """Retrieve a context by its run ID."""
149
+ with self.lock_ctx_store:
150
+ return self.ctx_store.get(run_id)