flwr-1.13.1-py3-none-any.whl → flwr-1.15.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/auth_plugin/__init__.py +31 -0
  3. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  4. flwr/cli/build.py +1 -0
  5. flwr/cli/cli_user_auth_interceptor.py +90 -0
  6. flwr/cli/config_utils.py +43 -149
  7. flwr/cli/constant.py +27 -0
  8. flwr/cli/example.py +1 -0
  9. flwr/cli/install.py +2 -1
  10. flwr/cli/log.py +34 -37
  11. flwr/cli/login/__init__.py +22 -0
  12. flwr/cli/login/login.py +116 -0
  13. flwr/cli/ls.py +214 -106
  14. flwr/cli/new/__init__.py +1 -0
  15. flwr/cli/new/new.py +2 -1
  16. flwr/cli/new/templates/app/.gitignore.tpl +3 -0
  17. flwr/cli/new/templates/app/README.md.tpl +3 -2
  18. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  19. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  20. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  21. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  22. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -4
  23. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  24. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  25. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  26. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  27. flwr/cli/run/__init__.py +1 -0
  28. flwr/cli/run/run.py +103 -43
  29. flwr/cli/stop.py +139 -0
  30. flwr/cli/utils.py +186 -8
  31. flwr/client/app.py +49 -50
  32. flwr/client/client.py +1 -32
  33. flwr/client/clientapp/app.py +23 -26
  34. flwr/client/clientapp/utils.py +2 -1
  35. flwr/client/grpc_adapter_client/connection.py +1 -1
  36. flwr/client/grpc_client/connection.py +2 -13
  37. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  38. flwr/client/grpc_rere_client/connection.py +59 -43
  39. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  40. flwr/client/message_handler/message_handler.py +1 -2
  41. flwr/client/message_handler/task_handler.py +0 -17
  42. flwr/client/mod/comms_mods.py +1 -0
  43. flwr/client/mod/localdp_mod.py +1 -1
  44. flwr/client/nodestate/__init__.py +1 -0
  45. flwr/client/nodestate/nodestate.py +1 -0
  46. flwr/client/nodestate/nodestate_factory.py +1 -0
  47. flwr/client/numpy_client.py +0 -44
  48. flwr/client/rest_client/connection.py +37 -29
  49. flwr/client/supernode/app.py +20 -74
  50. flwr/common/address.py +1 -0
  51. flwr/common/args.py +26 -47
  52. flwr/common/auth_plugin/__init__.py +24 -0
  53. flwr/common/auth_plugin/auth_plugin.py +122 -0
  54. flwr/common/config.py +169 -17
  55. flwr/common/constant.py +38 -9
  56. flwr/common/differential_privacy.py +2 -1
  57. flwr/common/exit/__init__.py +24 -0
  58. flwr/common/exit/exit.py +99 -0
  59. flwr/common/exit/exit_code.py +93 -0
  60. flwr/common/exit_handlers.py +24 -10
  61. flwr/common/grpc.py +167 -4
  62. flwr/common/logger.py +66 -7
  63. flwr/common/message.py +1 -0
  64. flwr/common/object_ref.py +57 -54
  65. flwr/common/pyproject.py +1 -0
  66. flwr/common/record/__init__.py +1 -0
  67. flwr/common/record/parametersrecord.py +1 -0
  68. flwr/common/record/recordset.py +1 -1
  69. flwr/common/retry_invoker.py +77 -0
  70. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  71. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  72. flwr/common/serde.py +6 -4
  73. flwr/common/telemetry.py +15 -4
  74. flwr/common/typing.py +32 -0
  75. flwr/common/version.py +1 -0
  76. flwr/proto/clientappio_pb2.py +1 -1
  77. flwr/proto/error_pb2.py +1 -1
  78. flwr/proto/exec_pb2.py +27 -15
  79. flwr/proto/exec_pb2.pyi +80 -2
  80. flwr/proto/exec_pb2_grpc.py +102 -0
  81. flwr/proto/exec_pb2_grpc.pyi +39 -0
  82. flwr/proto/fab_pb2.py +5 -5
  83. flwr/proto/fab_pb2.pyi +4 -1
  84. flwr/proto/fleet_pb2.py +31 -31
  85. flwr/proto/fleet_pb2.pyi +23 -23
  86. flwr/proto/fleet_pb2_grpc.py +30 -30
  87. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  88. flwr/proto/grpcadapter_pb2.py +1 -1
  89. flwr/proto/log_pb2.py +1 -1
  90. flwr/proto/message_pb2.py +1 -1
  91. flwr/proto/node_pb2.py +3 -3
  92. flwr/proto/node_pb2.pyi +1 -4
  93. flwr/proto/recordset_pb2.py +1 -1
  94. flwr/proto/run_pb2.py +1 -1
  95. flwr/proto/serverappio_pb2.py +24 -25
  96. flwr/proto/serverappio_pb2.pyi +32 -32
  97. flwr/proto/serverappio_pb2_grpc.py +62 -28
  98. flwr/proto/serverappio_pb2_grpc.pyi +29 -16
  99. flwr/proto/simulationio_pb2.py +3 -3
  100. flwr/proto/simulationio_pb2_grpc.py +34 -0
  101. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  102. flwr/proto/task_pb2.py +1 -1
  103. flwr/proto/transport_pb2.py +1 -1
  104. flwr/server/app.py +152 -112
  105. flwr/server/compat/app_utils.py +7 -2
  106. flwr/server/compat/driver_client_proxy.py +1 -2
  107. flwr/server/driver/grpc_driver.py +38 -85
  108. flwr/server/driver/inmemory_driver.py +7 -2
  109. flwr/server/run_serverapp.py +8 -9
  110. flwr/server/serverapp/app.py +37 -13
  111. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  112. flwr/server/superlink/driver/serverappio_grpc.py +2 -1
  113. flwr/server/superlink/driver/serverappio_servicer.py +148 -63
  114. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  115. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -87
  116. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  117. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  118. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +56 -35
  119. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +99 -169
  120. flwr/server/superlink/fleet/message_handler/message_handler.py +69 -29
  121. flwr/server/superlink/fleet/rest_rere/rest_api.py +20 -19
  122. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  123. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  124. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  125. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  126. flwr/server/superlink/linkstate/in_memory_linkstate.py +60 -99
  127. flwr/server/superlink/linkstate/linkstate.py +30 -36
  128. flwr/server/superlink/linkstate/sqlite_linkstate.py +105 -188
  129. flwr/server/superlink/linkstate/utils.py +18 -8
  130. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  131. flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
  132. flwr/server/superlink/utils.py +65 -0
  133. flwr/server/utils/validator.py +9 -34
  134. flwr/simulation/app.py +20 -10
  135. flwr/simulation/legacy_app.py +4 -2
  136. flwr/simulation/ray_transport/ray_actor.py +1 -0
  137. flwr/simulation/ray_transport/utils.py +1 -0
  138. flwr/simulation/run_simulation.py +36 -22
  139. flwr/simulation/simulationio_connection.py +5 -1
  140. flwr/superexec/app.py +1 -0
  141. flwr/superexec/deployment.py +1 -0
  142. flwr/superexec/exec_grpc.py +20 -2
  143. flwr/superexec/exec_servicer.py +97 -2
  144. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  145. flwr/superexec/executor.py +1 -0
  146. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/METADATA +14 -13
  147. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/RECORD +150 -144
  148. flwr/proto/common_pb2.py +0 -36
  149. flwr/proto/common_pb2.pyi +0 -121
  150. flwr/proto/common_pb2_grpc.py +0 -4
  151. flwr/proto/common_pb2_grpc.pyi +0 -4
  152. flwr/proto/control_pb2.py +0 -27
  153. flwr/proto/control_pb2.pyi +0 -7
  154. flwr/proto/control_pb2_grpc.py +0 -135
  155. flwr/proto/control_pb2_grpc.pyi +0 -53
  156. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
  157. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
  158. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
@@ -16,14 +16,13 @@
16
16
 
17
17
 
18
18
  import threading
19
- import time
20
19
  from logging import DEBUG, INFO
21
20
  from typing import Optional
22
21
  from uuid import UUID
23
22
 
24
23
  import grpc
25
24
 
26
- from flwr.common import ConfigsRecord
25
+ from flwr.common import ConfigsRecord, now
27
26
  from flwr.common.constant import Status
28
27
  from flwr.common.logger import log
29
28
  from flwr.common.serde import (
@@ -31,7 +30,12 @@ from flwr.common.serde import (
31
30
  context_to_proto,
32
31
  fab_from_proto,
33
32
  fab_to_proto,
33
+ message_from_proto,
34
+ message_from_taskres,
35
+ message_to_proto,
36
+ message_to_taskins,
34
37
  run_status_from_proto,
38
+ run_status_to_proto,
35
39
  run_to_proto,
36
40
  user_config_from_proto,
37
41
  )
@@ -48,25 +52,28 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
48
52
  CreateRunResponse,
49
53
  GetRunRequest,
50
54
  GetRunResponse,
55
+ GetRunStatusRequest,
56
+ GetRunStatusResponse,
51
57
  UpdateRunStatusRequest,
52
58
  UpdateRunStatusResponse,
53
59
  )
54
60
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
55
61
  GetNodesRequest,
56
62
  GetNodesResponse,
63
+ PullResMessagesRequest,
64
+ PullResMessagesResponse,
57
65
  PullServerAppInputsRequest,
58
66
  PullServerAppInputsResponse,
59
- PullTaskResRequest,
60
- PullTaskResResponse,
67
+ PushInsMessagesRequest,
68
+ PushInsMessagesResponse,
61
69
  PushServerAppOutputsRequest,
62
70
  PushServerAppOutputsResponse,
63
- PushTaskInsRequest,
64
- PushTaskInsResponse,
65
71
  )
66
72
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
67
73
  from flwr.server.superlink.ffs.ffs import Ffs
68
74
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
69
75
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
76
+ from flwr.server.superlink.utils import abort_if
70
77
  from flwr.server.utils.validator import validate_task_ins_or_res
71
78
 
72
79
 
@@ -85,11 +92,20 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
85
92
  ) -> GetNodesResponse:
86
93
  """Get available nodes."""
87
94
  log(DEBUG, "ServerAppIoServicer.GetNodes")
95
+
96
+ # Init state
88
97
  state: LinkState = self.state_factory.state()
98
+
99
+ # Abort if the run is not running
100
+ abort_if(
101
+ request.run_id,
102
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
103
+ state,
104
+ context,
105
+ )
106
+
89
107
  all_ids: set[int] = state.get_nodes(request.run_id)
90
- nodes: list[Node] = [
91
- Node(node_id=node_id, anonymous=False) for node_id in all_ids
92
- ]
108
+ nodes: list[Node] = [Node(node_id=node_id) for node_id in all_ids]
93
109
  return GetNodesResponse(nodes=nodes)
94
110
 
95
111
  def CreateRun(
@@ -103,8 +119,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
103
119
  ffs: Ffs = self.ffs_factory.ffs()
104
120
  fab_hash = ffs.put(fab.content, {})
105
121
  _raise_if(
106
- fab_hash != fab.hash_str,
107
- f"FAB ({fab.hash_str}) hash from request doesn't match contents",
122
+ validation_error=fab_hash != fab.hash_str,
123
+ request_name="CreateRun",
124
+ detail=f"FAB ({fab.hash_str}) hash from request doesn't match contents",
108
125
  )
109
126
  else:
110
127
  fab_hash = ""
@@ -117,70 +134,104 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
117
134
  )
118
135
  return CreateRunResponse(run_id=run_id)
119
136
 
120
- def PushTaskIns(
121
- self, request: PushTaskInsRequest, context: grpc.ServicerContext
122
- ) -> PushTaskInsResponse:
123
- """Push a set of TaskIns."""
124
- log(DEBUG, "ServerAppIoServicer.PushTaskIns")
137
+ def PushMessages(
138
+ self, request: PushInsMessagesRequest, context: grpc.ServicerContext
139
+ ) -> PushInsMessagesResponse:
140
+ """Push a set of Messages."""
141
+ log(DEBUG, "ServerAppIoServicer.PushMessages")
142
+
143
+ # Init state
144
+ state: LinkState = self.state_factory.state()
145
+
146
+ # Abort if the run is not running
147
+ abort_if(
148
+ request.run_id,
149
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
150
+ state,
151
+ context,
152
+ )
125
153
 
126
154
  # Set pushed_at (timestamp in seconds)
127
- pushed_at = time.time()
128
- for task_ins in request.task_ins_list:
129
- task_ins.task.pushed_at = pushed_at
155
+ pushed_at = now().timestamp()
130
156
 
131
- # Validate request
132
- _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
133
- for task_ins in request.task_ins_list:
157
+ # Validate request and insert in State
158
+ _raise_if(
159
+ validation_error=len(request.messages_list) == 0,
160
+ request_name="PushMessages",
161
+ detail="`messages_list` must not be empty",
162
+ )
163
+ message_ids: list[Optional[UUID]] = []
164
+ while request.messages_list:
165
+ message_proto = request.messages_list.pop(0)
166
+ message = message_from_proto(message_proto=message_proto)
167
+ task_ins = message_to_taskins(message=message)
168
+ task_ins.task.pushed_at = pushed_at
134
169
  validation_errors = validate_task_ins_or_res(task_ins)
135
- _raise_if(bool(validation_errors), ", ".join(validation_errors))
170
+ _raise_if(
171
+ validation_error=bool(validation_errors),
172
+ request_name="PushMessages",
173
+ detail=", ".join(validation_errors),
174
+ )
175
+ _raise_if(
176
+ validation_error=request.run_id != task_ins.run_id,
177
+ request_name="PushMessages",
178
+ detail="`task_ins` has mismatched `run_id`",
179
+ )
180
+ # Store
181
+ message_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
182
+ message_ids.append(message_id)
183
+
184
+ return PushInsMessagesResponse(
185
+ message_ids=[
186
+ str(message_id) if message_id else "" for message_id in message_ids
187
+ ]
188
+ )
189
+
190
+ def PullMessages(
191
+ self, request: PullResMessagesRequest, context: grpc.ServicerContext
192
+ ) -> PullResMessagesResponse:
193
+ """Pull a set of Messages."""
194
+ log(DEBUG, "ServerAppIoServicer.PullMessages")
136
195
 
137
196
  # Init state
138
197
  state: LinkState = self.state_factory.state()
139
198
 
140
- # Store each TaskIns
141
- task_ids: list[Optional[UUID]] = []
142
- for task_ins in request.task_ins_list:
143
- task_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
144
- task_ids.append(task_id)
145
-
146
- return PushTaskInsResponse(
147
- task_ids=[str(task_id) if task_id else "" for task_id in task_ids]
199
+ # Abort if the run is not running
200
+ abort_if(
201
+ request.run_id,
202
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
203
+ state,
204
+ context,
148
205
  )
149
206
 
150
- def PullTaskRes(
151
- self, request: PullTaskResRequest, context: grpc.ServicerContext
152
- ) -> PullTaskResResponse:
153
- """Pull a set of TaskRes."""
154
- log(DEBUG, "ServerAppIoServicer.PullTaskRes")
155
-
156
207
  # Convert each task_id str to UUID
157
- task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
208
+ message_ids: set[UUID] = {
209
+ UUID(message_id) for message_id in request.message_ids
210
+ }
158
211
 
159
- # Init state
160
- state: LinkState = self.state_factory.state()
212
+ # Read from state
213
+ task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids)
161
214
 
162
- # Register callback
163
- def on_rpc_done() -> None:
164
- log(
165
- DEBUG,
166
- "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
215
+ # Convert to Messages
216
+ messages_list = []
217
+ while task_res_list:
218
+ task_res = task_res_list.pop(0)
219
+ _raise_if(
220
+ validation_error=request.run_id != task_res.run_id,
221
+ request_name="PullMessages",
222
+ detail="`task_res` has mismatched `run_id`",
167
223
  )
224
+ message = message_from_taskres(taskres=task_res)
225
+ messages_list.append(message_to_proto(message))
168
226
 
169
- if context.is_active():
170
- return
171
- if context.code() != grpc.StatusCode.OK:
172
- return
227
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
228
+ task_ins_ids_to_delete = {
229
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
230
+ }
173
231
 
174
- # Delete delivered TaskIns and TaskRes
175
- state.delete_tasks(task_ids=task_ids)
232
+ state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
176
233
 
177
- context.add_callback(on_rpc_done)
178
-
179
- # Read from state
180
- task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
181
-
182
- context.set_code(grpc.StatusCode.OK)
183
- return PullTaskResResponse(task_res_list=task_res_list)
234
+ return PullResMessagesResponse(messages_list=messages_list)
184
235
 
185
236
  def GetRun(
186
237
  self, request: GetRunRequest, context: grpc.ServicerContext
@@ -217,9 +268,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
217
268
  ) -> PullServerAppInputsResponse:
218
269
  """Pull ServerApp process inputs."""
219
270
  log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
220
- # Init access to LinkState and Ffs
271
+ # Init access to LinkState
221
272
  state = self.state_factory.state()
222
- ffs = self.ffs_factory.ffs()
223
273
 
224
274
  # Lock access to LinkState, preventing obtaining the same pending run_id
225
275
  with self.lock:
@@ -229,6 +279,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
229
279
  if run_id is None:
230
280
  return PullServerAppInputsResponse()
231
281
 
282
+ # Init access to Ffs
283
+ ffs = self.ffs_factory.ffs()
284
+
232
285
  # Retrieve Context, Run and Fab for the run_id
233
286
  serverapp_ctxt = state.get_serverapp_context(run_id)
234
287
  run = state.get_run(run_id)
@@ -255,7 +308,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
255
308
  ) -> PushServerAppOutputsResponse:
256
309
  """Push ServerApp process outputs."""
257
310
  log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
311
+
312
+ # Init state
258
313
  state = self.state_factory.state()
314
+
315
+ # Abort if the run is not running
316
+ abort_if(
317
+ request.run_id,
318
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
319
+ state,
320
+ context,
321
+ )
322
+
259
323
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
260
324
  return PushServerAppOutputsResponse()
261
325
 
@@ -263,9 +327,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
263
327
  self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
264
328
  ) -> UpdateRunStatusResponse:
265
329
  """Update the status of a run."""
266
- log(DEBUG, "ControlServicer.UpdateRunStatus")
330
+ log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
331
+
332
+ # Init state
267
333
  state = self.state_factory.state()
268
334
 
335
+ # Abort if the run is finished
336
+ abort_if(request.run_id, [Status.FINISHED], state, context)
337
+
269
338
  # Update the run status
270
339
  state.update_run_status(
271
340
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
@@ -284,7 +353,23 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
284
353
  state.add_serverapp_log(request.run_id, merged_logs)
285
354
  return PushLogsResponse()
286
355
 
356
+ def GetRunStatus(
357
+ self, request: GetRunStatusRequest, context: grpc.ServicerContext
358
+ ) -> GetRunStatusResponse:
359
+ """Get the status of a run."""
360
+ log(DEBUG, "ServerAppIoServicer.GetRunStatus")
361
+ state = self.state_factory.state()
362
+
363
+ # Get run status from LinkState
364
+ run_statuses = state.get_run_status(set(request.run_ids))
365
+ run_status_dict = {
366
+ run_id: run_status_to_proto(run_status)
367
+ for run_id, run_status in run_statuses.items()
368
+ }
369
+ return GetRunStatusResponse(run_status_dict=run_status_dict)
370
+
287
371
 
288
- def _raise_if(validation_error: bool, detail: str) -> None:
372
+ def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
373
+ """Raise a `ValueError` with a detailed message if a validation error occurs."""
289
374
  if validation_error:
290
- raise ValueError(f"Malformed PushTaskInsRequest: {detail}")
375
+ raise ValueError(f"Malformed {request_name}: {detail}")
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Disk based Flower File Storage."""
16
16
 
17
+
17
18
  import hashlib
18
19
  import json
19
20
  from pathlib import Path
@@ -15,7 +15,7 @@
15
15
  """Fleet API gRPC adapter servicer."""
16
16
 
17
17
 
18
- from logging import DEBUG, INFO
18
+ from logging import DEBUG
19
19
  from typing import Callable, TypeVar
20
20
 
21
21
  import grpc
@@ -31,35 +31,30 @@ from flwr.common.constant import (
31
31
  from flwr.common.logger import log
32
32
  from flwr.common.version import package_name, package_version
33
33
  from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
34
- from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
34
+ from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
35
35
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
36
36
  CreateNodeRequest,
37
- CreateNodeResponse,
38
37
  DeleteNodeRequest,
39
- DeleteNodeResponse,
40
38
  PingRequest,
41
- PingResponse,
42
- PullTaskInsRequest,
43
- PullTaskInsResponse,
44
- PushTaskResRequest,
45
- PushTaskResResponse,
39
+ PullMessagesRequest,
40
+ PushMessagesRequest,
46
41
  )
47
42
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
48
- from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
49
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
50
- from flwr.server.superlink.fleet.message_handler import message_handler
51
- from flwr.server.superlink.linkstate import LinkStateFactory
43
+ from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
44
+
45
+ from ..grpc_rere.fleet_servicer import FleetServicer
52
46
 
53
47
  T = TypeVar("T", bound=GrpcMessage)
54
48
 
55
49
 
56
50
  def _handle(
57
51
  msg_container: MessageContainer,
52
+ context: grpc.ServicerContext,
58
53
  request_type: type[T],
59
- handler: Callable[[T], GrpcMessage],
54
+ handler: Callable[[T, grpc.ServicerContext], GrpcMessage],
60
55
  ) -> MessageContainer:
61
56
  req = request_type.FromString(msg_container.grpc_message_content)
62
- res = handler(req)
57
+ res = handler(req, context)
63
58
  res_cls = res.__class__
64
59
  return MessageContainer(
65
60
  metadata={
@@ -74,88 +69,26 @@ def _handle(
74
69
  )
75
70
 
76
71
 
77
- class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
72
+ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetServicer):
78
73
  """Fleet API via GrpcAdapter servicer."""
79
74
 
80
- def __init__(
81
- self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
82
- ) -> None:
83
- self.state_factory = state_factory
84
- self.ffs_factory = ffs_factory
85
-
86
75
  def SendReceive( # pylint: disable=too-many-return-statements
87
76
  self, request: MessageContainer, context: grpc.ServicerContext
88
77
  ) -> MessageContainer:
89
78
  """."""
90
79
  log(DEBUG, "GrpcAdapterServicer.SendReceive")
91
80
  if request.grpc_message_name == CreateNodeRequest.__qualname__:
92
- return _handle(request, CreateNodeRequest, self._create_node)
81
+ return _handle(request, context, CreateNodeRequest, self.CreateNode)
93
82
  if request.grpc_message_name == DeleteNodeRequest.__qualname__:
94
- return _handle(request, DeleteNodeRequest, self._delete_node)
83
+ return _handle(request, context, DeleteNodeRequest, self.DeleteNode)
95
84
  if request.grpc_message_name == PingRequest.__qualname__:
96
- return _handle(request, PingRequest, self._ping)
97
- if request.grpc_message_name == PullTaskInsRequest.__qualname__:
98
- return _handle(request, PullTaskInsRequest, self._pull_task_ins)
99
- if request.grpc_message_name == PushTaskResRequest.__qualname__:
100
- return _handle(request, PushTaskResRequest, self._push_task_res)
85
+ return _handle(request, context, PingRequest, self.Ping)
101
86
  if request.grpc_message_name == GetRunRequest.__qualname__:
102
- return _handle(request, GetRunRequest, self._get_run)
87
+ return _handle(request, context, GetRunRequest, self.GetRun)
103
88
  if request.grpc_message_name == GetFabRequest.__qualname__:
104
- return _handle(request, GetFabRequest, self._get_fab)
89
+ return _handle(request, context, GetFabRequest, self.GetFab)
90
+ if request.grpc_message_name == PullMessagesRequest.__qualname__:
91
+ return _handle(request, context, PullMessagesRequest, self.PullMessages)
92
+ if request.grpc_message_name == PushMessagesRequest.__qualname__:
93
+ return _handle(request, context, PushMessagesRequest, self.PushMessages)
105
94
  raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
106
-
107
- def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
108
- """."""
109
- log(INFO, "GrpcAdapter.CreateNode")
110
- return message_handler.create_node(
111
- request=request,
112
- state=self.state_factory.state(),
113
- )
114
-
115
- def _delete_node(self, request: DeleteNodeRequest) -> DeleteNodeResponse:
116
- """."""
117
- log(INFO, "GrpcAdapter.DeleteNode")
118
- return message_handler.delete_node(
119
- request=request,
120
- state=self.state_factory.state(),
121
- )
122
-
123
- def _ping(self, request: PingRequest) -> PingResponse:
124
- """."""
125
- log(DEBUG, "GrpcAdapter.Ping")
126
- return message_handler.ping(
127
- request=request,
128
- state=self.state_factory.state(),
129
- )
130
-
131
- def _pull_task_ins(self, request: PullTaskInsRequest) -> PullTaskInsResponse:
132
- """Pull TaskIns."""
133
- log(INFO, "GrpcAdapter.PullTaskIns")
134
- return message_handler.pull_task_ins(
135
- request=request,
136
- state=self.state_factory.state(),
137
- )
138
-
139
- def _push_task_res(self, request: PushTaskResRequest) -> PushTaskResResponse:
140
- """Push TaskRes."""
141
- log(INFO, "GrpcAdapter.PushTaskRes")
142
- return message_handler.push_task_res(
143
- request=request,
144
- state=self.state_factory.state(),
145
- )
146
-
147
- def _get_run(self, request: GetRunRequest) -> GetRunResponse:
148
- """Get run information."""
149
- log(INFO, "GrpcAdapter.GetRun")
150
- return message_handler.get_run(
151
- request=request,
152
- state=self.state_factory.state(),
153
- )
154
-
155
- def _get_fab(self, request: GetFabRequest) -> GetFabResponse:
156
- """Get FAB."""
157
- log(INFO, "GrpcAdapter.GetFab")
158
- return message_handler.get_fab(
159
- request=request,
160
- ffs=self.ffs_factory.ffs(),
161
- )
@@ -18,6 +18,7 @@ Relevant knowledge for reading this modules code:
18
18
  - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
19
19
  """
20
20
 
21
+
21
22
  import uuid
22
23
  from collections.abc import Iterator
23
24
  from typing import Callable
@@ -15,49 +15,19 @@
15
15
  """Implements utility function to create a gRPC server."""
16
16
 
17
17
 
18
- import concurrent.futures
19
- import sys
20
- from collections.abc import Sequence
21
- from logging import ERROR
22
- from typing import Any, Callable, Optional, Union
18
+ from typing import Optional
23
19
 
24
20
  import grpc
25
21
 
26
22
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
- from flwr.common.address import is_port_in_use
28
- from flwr.common.logger import log
23
+ from flwr.common.grpc import generic_create_grpc_server
29
24
  from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
30
25
  add_FlowerServiceServicer_to_server,
31
26
  )
32
27
  from flwr.server.client_manager import ClientManager
33
- from flwr.server.superlink.driver.serverappio_servicer import ServerAppIoServicer
34
- from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
35
- GrpcAdapterServicer,
36
- )
37
28
  from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import (
38
29
  FlowerServiceServicer,
39
30
  )
40
- from flwr.server.superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
41
-
42
- INVALID_CERTIFICATES_ERR_MSG = """
43
- When setting any of root_certificate, certificate, or private_key,
44
- all of them need to be set.
45
- """
46
-
47
- AddServicerToServerFn = Callable[..., Any]
48
-
49
-
50
- def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool:
51
- """Validate certificates tuple."""
52
- is_valid = (
53
- all(isinstance(certificate, bytes) for certificate in certificates)
54
- and len(certificates) == 3
55
- )
56
-
57
- if not is_valid:
58
- log(ERROR, INVALID_CERTIFICATES_ERR_MSG)
59
-
60
- return is_valid
61
31
 
62
32
 
63
33
  def start_grpc_server( # pylint: disable=too-many-arguments,R0917
@@ -154,136 +124,3 @@ def start_grpc_server( # pylint: disable=too-many-arguments,R0917
154
124
  server.start()
155
125
 
156
126
  return server
157
-
158
-
159
- def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
160
- servicer_and_add_fn: Union[
161
- tuple[FleetServicer, AddServicerToServerFn],
162
- tuple[GrpcAdapterServicer, AddServicerToServerFn],
163
- tuple[FlowerServiceServicer, AddServicerToServerFn],
164
- tuple[ServerAppIoServicer, AddServicerToServerFn],
165
- ],
166
- server_address: str,
167
- max_concurrent_workers: int = 1000,
168
- max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
169
- keepalive_time_ms: int = 210000,
170
- certificates: Optional[tuple[bytes, bytes, bytes]] = None,
171
- interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
172
- ) -> grpc.Server:
173
- """Create a gRPC server with a single servicer.
174
-
175
- Parameters
176
- ----------
177
- servicer_and_add_fn : tuple
178
- A tuple holding a servicer implementation and a matching
179
- add_Servicer_to_server function.
180
- server_address : str
181
- Server address in the form of HOST:PORT e.g. "[::]:8080"
182
- max_concurrent_workers : int
183
- Maximum number of clients the server can process before returning
184
- RESOURCE_EXHAUSTED status (default: 1000)
185
- max_message_length : int
186
- Maximum message length that the server can send or receive.
187
- Int valued in bytes. -1 means unlimited. (default: GRPC_MAX_MESSAGE_LENGTH)
188
- keepalive_time_ms : int
189
- Flower uses a default gRPC keepalive time of 210000ms (3 minutes 30 seconds)
190
- because some cloud providers (for example, Azure) agressively clean up idle
191
- TCP connections by terminating them after some time (4 minutes in the case
192
- of Azure). Flower does not use application-level keepalive signals and relies
193
- on the assumption that the transport layer will fail in cases where the
194
- connection is no longer active. `keepalive_time_ms` can be used to customize
195
- the keepalive interval for specific environments. The default Flower gRPC
196
- keepalive of 210000 ms (3 minutes 30 seconds) ensures that Flower can keep
197
- the long running streaming connection alive in most environments. The actual
198
- gRPC default of this setting is 7200000 (2 hours), which results in dropped
199
- connections in some cloud environments.
200
-
201
- These settings are related to the issue described here:
202
- - https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md
203
- - https://github.com/grpc/grpc/blob/master/doc/keepalive.md
204
- - https://grpc.io/docs/guides/performance/
205
-
206
- Mobile Flower clients may choose to increase this value if their server
207
- environment allows long-running idle TCP connections.
208
- (default: 210000)
209
- certificates : Tuple[bytes, bytes, bytes] (default: None)
210
- Tuple containing root certificate, server certificate, and private key to
211
- start a secure SSL-enabled server. The tuple is expected to have three bytes
212
- elements in the following order:
213
-
214
- * CA certificate.
215
- * server certificate.
216
- * server private key.
217
- interceptors : Optional[Sequence[grpc.ServerInterceptor]] (default: None)
218
- A list of gRPC interceptors.
219
-
220
- Returns
221
- -------
222
- server : grpc.Server
223
- A non-running instance of a gRPC server.
224
- """
225
- # Check if port is in use
226
- if is_port_in_use(server_address):
227
- sys.exit(f"Port in server address {server_address} is already in use.")
228
-
229
- # Deconstruct tuple into servicer and function
230
- servicer, add_servicer_to_server_fn = servicer_and_add_fn
231
-
232
- # Possible options:
233
- # https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h
234
- options = [
235
- # Maximum number of concurrent incoming streams to allow on a http2
236
- # connection. Int valued.
237
- ("grpc.max_concurrent_streams", max(100, max_concurrent_workers)),
238
- # Maximum message length that the channel can send.
239
- # Int valued, bytes. -1 means unlimited.
240
- ("grpc.max_send_message_length", max_message_length),
241
- # Maximum message length that the channel can receive.
242
- # Int valued, bytes. -1 means unlimited.
243
- ("grpc.max_receive_message_length", max_message_length),
244
- # The gRPC default for this setting is 7200000 (2 hours). Flower uses a
245
- # customized default of 210000 (3 minutes and 30 seconds) to improve
246
- # compatibility with popular cloud providers. Mobile Flower clients may
247
- # choose to increase this value if their server environment allows
248
- # long-running idle TCP connections.
249
- ("grpc.keepalive_time_ms", keepalive_time_ms),
250
- # Setting this to zero will allow sending unlimited keepalive pings in between
251
- # sending actual data frames.
252
- ("grpc.http2.max_pings_without_data", 0),
253
- # Is it permissible to send keepalive pings from the client without
254
- # any outstanding streams. More explanation here:
255
- # https://github.com/adap/flower/pull/2197
256
- ("grpc.keepalive_permit_without_calls", 0),
257
- ]
258
-
259
- server = grpc.server(
260
- concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_workers),
261
- # Set the maximum number of concurrent RPCs this server will service before
262
- # returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
263
- maximum_concurrent_rpcs=max_concurrent_workers,
264
- options=options,
265
- interceptors=interceptors,
266
- )
267
- add_servicer_to_server_fn(servicer, server)
268
-
269
- if certificates is not None:
270
- if not valid_certificates(certificates):
271
- sys.exit(1)
272
-
273
- root_certificate_b, certificate_b, private_key_b = certificates
274
-
275
- server_credentials = grpc.ssl_server_credentials(
276
- ((private_key_b, certificate_b),),
277
- root_certificates=root_certificate_b,
278
- # A boolean indicating whether or not to require clients to be
279
- # authenticated. May only be True if root_certificates is not None.
280
- # We are explicitly setting the current gRPC default to document
281
- # the option. For further reference see:
282
- # https://grpc.github.io/grpc/python/grpc.html#create-server-credentials
283
- require_client_auth=False,
284
- )
285
- server.add_secure_port(server_address, server_credentials)
286
- else:
287
- server.add_insecure_port(server_address)
288
-
289
- return server