flwr 1.22.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. flwr/cli/app.py +15 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +187 -35
  12. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  14. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  16. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  19. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  20. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  23. flwr/cli/pull.py +2 -2
  24. flwr/cli/run/run.py +11 -7
  25. flwr/cli/stop.py +2 -2
  26. flwr/cli/supernode/__init__.py +25 -0
  27. flwr/cli/supernode/ls.py +260 -0
  28. flwr/cli/supernode/register.py +185 -0
  29. flwr/cli/supernode/unregister.py +138 -0
  30. flwr/cli/utils.py +92 -69
  31. flwr/client/__init__.py +2 -1
  32. flwr/client/grpc_adapter_client/connection.py +6 -8
  33. flwr/client/grpc_rere_client/connection.py +59 -31
  34. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  35. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  36. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  37. flwr/client/rest_client/connection.py +82 -37
  38. flwr/clientapp/__init__.py +1 -2
  39. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  40. flwr/common/constant.py +53 -13
  41. flwr/common/exit/exit_code.py +20 -10
  42. flwr/common/inflatable_utils.py +10 -10
  43. flwr/common/record/array.py +3 -3
  44. flwr/common/record/arrayrecord.py +10 -1
  45. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  46. flwr/common/serde.py +4 -2
  47. flwr/common/typing.py +7 -6
  48. flwr/compat/client/app.py +1 -1
  49. flwr/compat/client/grpc_client/connection.py +2 -2
  50. flwr/proto/control_pb2.py +48 -35
  51. flwr/proto/control_pb2.pyi +71 -5
  52. flwr/proto/control_pb2_grpc.py +102 -0
  53. flwr/proto/control_pb2_grpc.pyi +39 -0
  54. flwr/proto/fab_pb2.py +11 -7
  55. flwr/proto/fab_pb2.pyi +21 -1
  56. flwr/proto/fleet_pb2.py +31 -23
  57. flwr/proto/fleet_pb2.pyi +63 -23
  58. flwr/proto/fleet_pb2_grpc.py +98 -28
  59. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  60. flwr/proto/node_pb2.py +3 -1
  61. flwr/proto/node_pb2.pyi +48 -0
  62. flwr/server/app.py +139 -114
  63. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  64. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  65. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  66. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  67. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  68. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  69. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  70. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  71. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  72. flwr/server/superlink/linkstate/linkstate.py +107 -24
  73. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  74. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  75. flwr/server/superlink/linkstate/utils.py +3 -54
  76. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  77. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  78. flwr/server/utils/validator.py +2 -3
  79. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  80. flwr/simulation/ray_transport/ray_actor.py +1 -1
  81. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  82. flwr/simulation/run_simulation.py +3 -2
  83. flwr/supercore/constant.py +22 -0
  84. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  85. flwr/supercore/object_store/object_store_factory.py +26 -6
  86. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  87. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  88. flwr/supercore/primitives/asymmetric.py +117 -0
  89. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  90. flwr/supercore/sqlite_mixin.py +156 -0
  91. flwr/supercore/utils.py +20 -0
  92. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  93. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  94. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  95. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  96. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  97. flwr/superlink/servicer/control/control_grpc.py +13 -11
  98. flwr/superlink/servicer/control/control_servicer.py +152 -60
  99. flwr/supernode/cli/flower_supernode.py +19 -26
  100. flwr/supernode/runtime/run_clientapp.py +2 -2
  101. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  102. flwr/supernode/start_client_internal.py +17 -9
  103. {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/METADATA +1 -1
  104. {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/RECORD +107 -96
  105. flwr/common/auth_plugin/auth_plugin.py +0 -149
  106. flwr/{client → clientapp}/client_app.py +0 -0
  107. {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  108. {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
@@ -36,18 +36,27 @@ from flwr.common.inflatable_protobuf_utils import (
36
36
  from flwr.common.logger import log
37
37
  from flwr.common.message import Message, remove_content_from_message
38
38
  from flwr.common.retry_invoker import RetryInvoker
39
- from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
39
+ from flwr.common.serde import (
40
+ fab_from_proto,
41
+ message_from_proto,
42
+ message_to_proto,
43
+ run_from_proto,
44
+ )
40
45
  from flwr.common.typing import Fab, Run
41
46
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
42
47
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
43
- CreateNodeRequest,
44
- CreateNodeResponse,
45
- DeleteNodeRequest,
46
- DeleteNodeResponse,
48
+ ActivateNodeRequest,
49
+ ActivateNodeResponse,
50
+ DeactivateNodeRequest,
51
+ DeactivateNodeResponse,
47
52
  PullMessagesRequest,
48
53
  PullMessagesResponse,
49
54
  PushMessagesRequest,
50
55
  PushMessagesResponse,
56
+ RegisterNodeFleetRequest,
57
+ RegisterNodeFleetResponse,
58
+ UnregisterNodeFleetRequest,
59
+ UnregisterNodeFleetResponse,
51
60
  )
52
61
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
53
62
  SendNodeHeartbeatRequest,
@@ -64,6 +73,7 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
64
73
  )
65
74
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
66
75
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
76
+ from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
67
77
 
68
78
  try:
69
79
  import requests
@@ -71,8 +81,10 @@ except ModuleNotFoundError:
71
81
  flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
72
82
 
73
83
 
74
- PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
75
- PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
84
+ PATH_REGISTER_NODE: str = "/api/v0/fleet/register-node"
85
+ PATH_ACTIVATE_NODE: str = "/api/v0/fleet/activate-node"
86
+ PATH_DEACTIVATE_NODE: str = "/api/v0/fleet/deactivate-node"
87
+ PATH_UNREGISTER_NODE: str = "/api/v0/fleet/unregister-node"
76
88
  PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
77
89
  PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
78
90
  PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
@@ -99,10 +111,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
99
111
  ] = None,
100
112
  ) -> Iterator[
101
113
  tuple[
114
+ int,
102
115
  Callable[[], Optional[tuple[Message, ObjectTree]]],
103
116
  Callable[[Message, ObjectTree], set[str]],
104
- Callable[[], Optional[int]],
105
- Callable[[], None],
106
117
  Callable[[int], Run],
107
118
  Callable[[str, int], Fab],
108
119
  Callable[[int, str], bytes],
@@ -134,15 +145,15 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
134
145
  connection using the certificates will be established to an SSL-enabled
135
146
  Flower server. Bytes won't work for the REST API.
136
147
  authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
137
- Client authentication is not supported for this transport type.
148
+ SuperNode authentication is not supported for this transport type.
138
149
 
139
150
  Returns
140
151
  -------
141
- receive : Callable
142
- send : Callable
143
- create_node : Optional[Callable]
144
- delete_node : Optional[Callable]
145
- get_run : Optional[Callable]
152
+ node_id : int
153
+ receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
154
+ send : Callable[[Message, ObjectTree], set[str]]
155
+ get_run : Callable[[int], Run]
156
+ get_fab : Callable[[str, int], Fab]
146
157
  pull_object : Callable[[str], bytes]
147
158
  push_object : Callable[[str, bytes], None]
148
159
  confirm_message_received : Callable[[str], None]
@@ -171,7 +182,14 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
171
182
  "must be provided as a string path to the client.",
172
183
  )
173
184
  if authentication_keys is not None:
174
- log(ERROR, "Client authentication is not supported for this transport type.")
185
+ log(ERROR, "SuperNode authentication is not supported for this transport type.")
186
+
187
+ # REST does NOT support node authentication
188
+ self_registered = False
189
+ if authentication_keys is None:
190
+ self_registered = True
191
+ authentication_keys = generate_key_pairs()
192
+ node_pk = public_key_to_bytes(authentication_keys[1])
175
193
 
176
194
  # Shared variables for inner functions
177
195
  node: Optional[Node] = None
@@ -180,7 +198,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
180
198
  retry_invoker.should_giveup = None
181
199
 
182
200
  ###########################################################################
183
- # heartbeat/create_node/delete_node/receive/send/get_run functions
201
+ # SuperNode functions
184
202
  ###########################################################################
185
203
 
186
204
  def _request(
@@ -290,23 +308,35 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
290
308
 
291
309
  heartbeat_sender = HeartbeatSender(send_node_heartbeat)
292
310
 
293
- def create_node() -> Optional[int]:
294
- """Set create_node."""
295
- req = CreateNodeRequest(heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL)
311
+ def register_node() -> None:
312
+ """Register node with SuperLink."""
313
+ req = RegisterNodeFleetRequest(public_key=node_pk)
296
314
 
297
315
  # Send the request
298
- res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
316
+ res = _request(req, RegisterNodeFleetResponse, PATH_REGISTER_NODE)
299
317
  if res is None:
300
- return None
318
+ raise RuntimeError("Failed to register node")
319
+
320
+ def activate_node() -> int:
321
+ """Activate node and start heartbeat."""
322
+ req = ActivateNodeRequest(
323
+ public_key=node_pk,
324
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
325
+ )
326
+
327
+ # Send the request
328
+ res = _request(req, ActivateNodeResponse, PATH_ACTIVATE_NODE)
329
+ if res is None:
330
+ raise RuntimeError("Failed to activate node")
301
331
 
302
332
  # Remember the node and start the heartbeat sender
303
333
  nonlocal node
304
- node = res.node
334
+ node = Node(node_id=res.node_id)
305
335
  heartbeat_sender.start()
306
336
  return node.node_id
307
337
 
308
- def delete_node() -> None:
309
- """Set delete_node."""
338
+ def deactivate_node() -> None:
339
+ """Deactivate node and stop heartbeat."""
310
340
  nonlocal node
311
341
  if node is None:
312
342
  raise RuntimeError("Node instance missing")
@@ -314,13 +344,27 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
314
344
  # Stop the heartbeat sender
315
345
  heartbeat_sender.stop()
316
346
 
317
- # Send DeleteNode request
318
- req = DeleteNodeRequest(node=node)
347
+ # Send DeactivateNode request
348
+ req = DeactivateNodeRequest(node_id=node.node_id)
319
349
 
320
350
  # Send the request
321
- res = _request(req, DeleteNodeResponse, PATH_DELETE_NODE)
351
+ res = _request(req, DeactivateNodeResponse, PATH_DEACTIVATE_NODE)
322
352
  if res is None:
323
- return
353
+ raise RuntimeError("Failed to deactivate node")
354
+
355
+ def unregister_node() -> None:
356
+ """Unregister node from SuperLink."""
357
+ nonlocal node
358
+ if node is None:
359
+ raise RuntimeError("Node instance missing")
360
+
361
+ # Send UnregisterNode request
362
+ req = UnregisterNodeFleetRequest(node_id=node.node_id)
363
+
364
+ # Send the request
365
+ res = _request(req, UnregisterNodeFleetResponse, PATH_UNREGISTER_NODE)
366
+ if res is None:
367
+ raise RuntimeError("Failed to unregister node")
324
368
 
325
369
  # Cleanup
326
370
  node = None
@@ -392,12 +436,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
392
436
  # Send the request
393
437
  res = _request(req, GetFabResponse, PATH_GET_FAB)
394
438
  if res is None:
395
- return Fab("", b"")
439
+ return Fab("", b"", {})
396
440
 
397
- return Fab(
398
- res.fab.hash_str,
399
- res.fab.content,
400
- )
441
+ return fab_from_proto(res.fab)
401
442
 
402
443
  def pull_object(run_id: int, object_id: str) -> bytes:
403
444
  """Pull the object from the SuperLink."""
@@ -439,12 +480,14 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
439
480
  fn(object_id)
440
481
 
441
482
  try:
483
+ if self_registered:
484
+ register_node()
485
+ node_id = activate_node()
442
486
  # Yield methods
443
487
  yield (
488
+ node_id,
444
489
  receive,
445
490
  send,
446
- create_node,
447
- delete_node,
448
491
  get_run,
449
492
  get_fab,
450
493
  pull_object,
@@ -459,6 +502,8 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
459
502
  if node is not None:
460
503
  # Disable retrying
461
504
  retry_invoker.max_tries = 1
462
- delete_node()
505
+ deactivate_node()
506
+ if self_registered:
507
+ unregister_node()
463
508
  except RequestsConnectionError:
464
509
  pass
@@ -15,9 +15,8 @@
15
15
  """Public Flower ClientApp APIs."""
16
16
 
17
17
 
18
- from flwr.client.client_app import ClientApp
19
-
20
18
  from . import mod
19
+ from .client_app import ClientApp as ClientApp
21
20
 
22
21
  __all__ = [
23
22
  "ClientApp",
@@ -19,7 +19,7 @@ from logging import DEBUG
19
19
  from pathlib import Path
20
20
  from typing import Callable, Optional
21
21
 
22
- from flwr.client.client_app import ClientApp, LoadClientAppError
22
+ from flwr.clientapp.client_app import ClientApp, LoadClientAppError
23
23
  from flwr.common.config import (
24
24
  get_flwr_dir,
25
25
  get_metadata_from_config,
flwr/common/constant.py CHANGED
@@ -17,6 +17,8 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
+ import os
21
+
20
22
  TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi"
21
23
  TRANSPORT_TYPE_GRPC_RERE = "grpc-rere"
22
24
  TRANSPORT_TYPE_GRPC_ADAPTER = "grpc-adapter"
@@ -60,7 +62,9 @@ HEARTBEAT_DEFAULT_INTERVAL = 30
60
62
  HEARTBEAT_CALL_TIMEOUT = 5
61
63
  HEARTBEAT_BASE_MULTIPLIER = 0.8
62
64
  HEARTBEAT_RANDOM_RANGE = (-0.1, 0.1)
63
- HEARTBEAT_MAX_INTERVAL = 1e300
65
+ HEARTBEAT_MIN_INTERVAL = 10
66
+ HEARTBEAT_MAX_INTERVAL = 1800 # 30 minutes
67
+ HEARTBEAT_INTERVAL_INF = 1e300 # Large value, disabling heartbeats
64
68
  HEARTBEAT_PATIENCE = 2
65
69
  RUN_FAILURE_DETAILS_NO_HEARTBEAT = "No heartbeat received from the run."
66
70
 
@@ -70,13 +74,23 @@ NODE_ID_NUM_BYTES = 8
70
74
 
71
75
  # Constants for FAB
72
76
  APP_DIR = "apps"
73
- FAB_ALLOWED_EXTENSIONS = {".py", ".toml", ".md"}
74
77
  FAB_CONFIG_FILE = "pyproject.toml"
75
78
  FAB_DATE = (2024, 10, 1, 0, 0, 0)
76
79
  FAB_HASH_TRUNCATION = 8
77
80
  FAB_MAX_SIZE = 10 * 1024 * 1024 # 10 MB
78
81
  FLWR_DIR = ".flwr" # The default Flower directory: ~/.flwr/
79
82
  FLWR_HOME = "FLWR_HOME" # If set, override the default Flower directory
83
+ # FAB file include patterns (gitignore-style patterns)
84
+ FAB_INCLUDE_PATTERNS = (
85
+ "**/*.py",
86
+ "**/*.toml",
87
+ "**/*.md",
88
+ )
89
+ # FAB file exclude patterns (gitignore-style patterns)
90
+ FAB_EXCLUDE_PATTERNS = (
91
+ "**/__pycache__/**",
92
+ FAB_CONFIG_FILE, # Exclude the original pyproject.toml
93
+ )
80
94
 
81
95
  # Constant for SuperLink
82
96
  SUPERLINK_NODE_ID = 1
@@ -109,14 +123,14 @@ LOG_UPLOAD_INTERVAL = 0.2 # Minimum interval between two log uploads
109
123
  # Retry configurations
110
124
  MAX_RETRY_DELAY = 20 # Maximum delay duration between two consecutive retries.
111
125
 
112
- # Constants for user authentication
126
+ # Constants for account authentication
113
127
  CREDENTIALS_DIR = ".credentials"
114
- AUTH_TYPE_JSON_KEY = "auth-type" # For key name in JSON file
115
- AUTH_TYPE_YAML_KEY = "auth_type" # For key name in YAML file
128
+ AUTHN_TYPE_JSON_KEY = "authn-type" # For key name in JSON file
129
+ AUTHN_TYPE_YAML_KEY = "authn_type" # For key name in YAML file
116
130
  ACCESS_TOKEN_KEY = "flwr-oidc-access-token"
117
131
  REFRESH_TOKEN_KEY = "flwr-oidc-refresh-token"
118
132
 
119
- # Constants for user authorization
133
+ # Constants for account authorization
120
134
  AUTHZ_TYPE_YAML_KEY = "authz_type" # For key name in YAML file
121
135
 
122
136
  # Constants for node authentication
@@ -135,7 +149,9 @@ GC_THRESHOLD = 200_000_000 # 200 MB
135
149
  # Constants for Inflatable
136
150
  HEAD_BODY_DIVIDER = b"\x00"
137
151
  HEAD_VALUE_DIVIDER = " "
138
- MAX_ARRAY_CHUNK_SIZE = 20_971_520 # 20 MB
152
+ FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE = int(
153
+ os.getenv("FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE", "5242880")
154
+ ) # 5 MB
139
155
 
140
156
  # Constants for serialization
141
157
  INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
@@ -144,8 +160,12 @@ INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
144
160
  FLWR_APP_TOKEN_LENGTH = 128 # Length of the token used
145
161
 
146
162
  # Constants for object pushing and pulling
147
- MAX_CONCURRENT_PUSHES = 8 # Default maximum number of concurrent pushes
148
- MAX_CONCURRENT_PULLS = 8 # Default maximum number of concurrent pulls
163
+ FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES = int(
164
+ os.getenv("FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES", "2")
165
+ ) # Default maximum number of concurrent pushes
166
+ FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS = int(
167
+ os.getenv("FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS", "2")
168
+ ) # Default maximum number of concurrent pulls
149
169
  PULL_MAX_TIME = 7200 # Default maximum time to wait for pulling objects
150
170
  PULL_MAX_TRIES_PER_OBJECT = 500 # Default maximum number of tries to pull an object
151
171
  PULL_INITIAL_BACKOFF = 1 # Initial backoff time for pulling objects
@@ -154,9 +174,13 @@ PULL_BACKOFF_CAP = 10 # Maximum backoff time for pulling objects
154
174
 
155
175
  # ControlServicer constants
156
176
  RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
157
- NO_USER_AUTH_MESSAGE = "ControlServicer initialized without user authentication"
177
+ NO_ACCOUNT_AUTH_MESSAGE = "ControlServicer initialized without account authentication"
158
178
  NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
159
179
  PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
180
+ SUPERNODE_NOT_CREATED_FROM_CLI_MESSAGE = "Invalid SuperNode credentials"
181
+ PUBLIC_KEY_ALREADY_IN_USE_MESSAGE = "Public key already in use"
182
+ PUBLIC_KEY_NOT_VALID = "The provided public key is not valid"
183
+ NODE_NOT_FOUND_MESSAGE = "Node ID not found for account"
160
184
 
161
185
 
162
186
  class MessageType:
@@ -245,12 +269,23 @@ class CliOutputFormat:
245
269
  raise TypeError(f"{cls.__name__} cannot be instantiated.")
246
270
 
247
271
 
248
- class AuthType:
249
- """User authentication types."""
272
+ class AuthnType:
273
+ """Account authentication types."""
250
274
 
275
+ NOOP = "noop"
251
276
  OIDC = "oidc"
252
277
 
253
- def __new__(cls) -> AuthType:
278
+ def __new__(cls) -> AuthnType:
279
+ """Prevent instantiation."""
280
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
281
+
282
+
283
+ class AuthzType:
284
+ """Account authorization types."""
285
+
286
+ NOOP = "noop"
287
+
288
+ def __new__(cls) -> AuthzType:
254
289
  """Prevent instantiation."""
255
290
  raise TypeError(f"{cls.__name__} cannot be instantiated.")
256
291
 
@@ -281,3 +316,8 @@ class ExecPluginType:
281
316
  """Return all SuperExec plugin types."""
282
317
  # Filter all constants (uppercase) of the class
283
318
  return [v for k, v in vars(ExecPluginType).items() if k.isupper()]
319
+
320
+
321
+ # Constants for No-op auth plugins
322
+ NOOP_FLWR_AID = "<none>"
323
+ NOOP_ACCOUNT_NAME = "sys_noauth"
@@ -41,12 +41,16 @@ class ExitCode:
41
41
 
42
42
  # SuperNode-specific exit codes (300-399)
43
43
  SUPERNODE_REST_ADDRESS_INVALID = 300
44
- SUPERNODE_NODE_AUTH_KEYS_REQUIRED = 301
45
- SUPERNODE_NODE_AUTH_KEYS_INVALID = 302
44
+ # SUPERNODE_NODE_AUTH_KEYS_REQUIRED = 301 --- DELETED ---
45
+ SUPERNODE_NODE_AUTH_KEY_INVALID = 302
46
+ SUPERNODE_STARTED_WITHOUT_TLS_BUT_NODE_AUTH_ENABLED = 303
46
47
 
47
48
  # SuperExec-specific exit codes (400-499)
48
49
  SUPEREXEC_INVALID_PLUGIN_CONFIG = 400
49
50
 
51
+ # FlowerCLI-specific exit codes (500-599)
52
+ FLWRCLI_NODE_AUTH_PUBLIC_KEY_INVALID = 500
53
+
50
54
  # Common exit codes (600-699)
51
55
  COMMON_ADDRESS_INVALID = 600
52
56
  COMMON_MISSING_EXTRA_REST = 601
@@ -102,20 +106,26 @@ EXIT_CODE_HELP = {
102
106
  "When using the REST API, please provide `https://` or "
103
107
  "`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
104
108
  ),
105
- ExitCode.SUPERNODE_NODE_AUTH_KEYS_REQUIRED: (
106
- "Node authentication requires file paths to both "
107
- "'--auth-supernode-private-key' and '--auth-supernode-public-key' "
108
- "to be provided (providing only one of them is not sufficient)."
109
- ),
110
- ExitCode.SUPERNODE_NODE_AUTH_KEYS_INVALID: (
111
- "Node authentication requires elliptic curve private and public key pair. "
112
- "Please ensure that the file path points to a valid private/public key "
109
+ ExitCode.SUPERNODE_NODE_AUTH_KEY_INVALID: (
110
+ "Node authentication requires elliptic curve private key. "
111
+ "Please ensure that the file path points to a valid private key "
113
112
  "file and try again."
114
113
  ),
114
+ ExitCode.SUPERNODE_STARTED_WITHOUT_TLS_BUT_NODE_AUTH_ENABLED: (
115
+ "The private key for SuperNode authentication was provided, but TLS is not "
116
+ "enabled. Node authentication can only be used when TLS is enabled."
117
+ ),
115
118
  # SuperExec-specific exit codes (400-499)
116
119
  ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG: (
117
120
  "The YAML configuration for the SuperExec plugin is invalid."
118
121
  ),
122
+ # FlowerCLI-specific exit codes (500-599)
123
+ ExitCode.FLWRCLI_NODE_AUTH_PUBLIC_KEY_INVALID: (
124
+ "Node authentication requires a valid elliptic curve public key in the "
125
+ "SSH format and following a NIST standard elliptic curve (e.g. SECP384R1). "
126
+ "Please ensure that the file path points to a valid public key "
127
+ "file and try again."
128
+ ),
119
129
  # Common exit codes (600-699)
120
130
  ExitCode.COMMON_ADDRESS_INVALID: (
121
131
  "Please provide a valid URL, IPv4 or IPv6 address."
@@ -25,10 +25,10 @@ from typing import Callable, Optional, TypeVar
25
25
  from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
26
26
 
27
27
  from .constant import (
28
+ FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
29
+ FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
28
30
  HEAD_BODY_DIVIDER,
29
31
  HEAD_VALUE_DIVIDER,
30
- MAX_CONCURRENT_PULLS,
31
- MAX_CONCURRENT_PUSHES,
32
32
  PULL_BACKOFF_CAP,
33
33
  PULL_INITIAL_BACKOFF,
34
34
  PULL_MAX_TIME,
@@ -118,7 +118,7 @@ def push_objects(
118
118
  *,
119
119
  object_ids_to_push: Optional[set[str]] = None,
120
120
  keep_objects: bool = False,
121
- max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
121
+ max_concurrent_pushes: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
122
122
  ) -> None:
123
123
  """Push multiple objects to the servicer.
124
124
 
@@ -137,7 +137,7 @@ def push_objects(
137
137
  If `True`, the original objects will be kept in the `objects` dictionary
138
138
  after pushing. If `False`, they will be removed from the dictionary to avoid
139
139
  high memory usage.
140
- max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
140
+ max_concurrent_pushes : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES)
141
141
  The maximum number of concurrent pushes to perform.
142
142
  """
143
143
  lock = threading.Lock()
@@ -168,7 +168,7 @@ def push_object_contents_from_iterable(
168
168
  object_contents: Iterable[tuple[str, bytes]],
169
169
  push_object_fn: Callable[[str, bytes], None],
170
170
  *,
171
- max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
171
+ max_concurrent_pushes: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
172
172
  ) -> None:
173
173
  """Push multiple object contents to the servicer.
174
174
 
@@ -181,7 +181,7 @@ def push_object_contents_from_iterable(
181
181
  A function that takes an object ID and its content as bytes, and pushes
182
182
  it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
183
183
  if the object ID is not pre-registered.
184
- max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
184
+ max_concurrent_pushes : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES)
185
185
  The maximum number of concurrent pushes to perform.
186
186
  """
187
187
 
@@ -210,7 +210,7 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
210
210
  object_ids: list[str],
211
211
  pull_object_fn: Callable[[str], bytes],
212
212
  *,
213
- max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
213
+ max_concurrent_pulls: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
214
214
  max_time: Optional[float] = PULL_MAX_TIME,
215
215
  max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
216
216
  initial_backoff: float = PULL_INITIAL_BACKOFF,
@@ -227,7 +227,7 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
227
227
  The function should raise `ObjectUnavailableError` if the object is not yet
228
228
  available, or `ObjectIdNotPreregisteredError` if the object ID is not
229
229
  pre-registered.
230
- max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
230
+ max_concurrent_pulls : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS)
231
231
  The maximum number of concurrent pulls to perform.
232
232
  max_time : Optional[float] (default: PULL_MAX_TIME)
233
233
  The maximum time to wait for all pulls to complete. If `None`, waits
@@ -442,7 +442,7 @@ def pull_and_inflate_object_from_tree( # pylint: disable=R0913
442
442
  confirm_object_received_fn: Callable[[str], None],
443
443
  *,
444
444
  return_type: type[T] = InflatableObject, # type: ignore
445
- max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
445
+ max_concurrent_pulls: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
446
446
  max_time: Optional[float] = PULL_MAX_TIME,
447
447
  max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
448
448
  initial_backoff: float = PULL_INITIAL_BACKOFF,
@@ -460,7 +460,7 @@ def pull_and_inflate_object_from_tree( # pylint: disable=R0913
460
460
  A function to confirm that the object has been received.
461
461
  return_type : type[T] (default: InflatableObject)
462
462
  The type of the object to return. Must be a subclass of `InflatableObject`.
463
- max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
463
+ max_concurrent_pulls : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS)
464
464
  The maximum number of concurrent pulls to perform.
465
465
  max_time : Optional[float] (default: PULL_MAX_TIME)
466
466
  The maximum time to wait for all pulls to complete. If `None`, waits
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, cast, overload
25
25
 
26
26
  import numpy as np
27
27
 
28
- from ..constant import MAX_ARRAY_CHUNK_SIZE, SType
28
+ from ..constant import FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE, SType
29
29
  from ..inflatable import (
30
30
  InflatableObject,
31
31
  add_header_to_object_body,
@@ -272,8 +272,8 @@ class Array(InflatableObject):
272
272
  chunks: list[tuple[str, InflatableObject]] = []
273
273
  # memoryview allows for zero-copy slicing
274
274
  data_view = memoryview(self.data)
275
- for start in range(0, len(data_view), MAX_ARRAY_CHUNK_SIZE):
276
- end = min(start + MAX_ARRAY_CHUNK_SIZE, len(data_view))
275
+ for start in range(0, len(data_view), FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE):
276
+ end = min(start + FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE, len(data_view))
277
277
  ac = ArrayChunk(data_view[start:end])
278
278
  chunks.append((ac.object_id, ac))
279
279
 
@@ -147,11 +147,20 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
147
147
  keep_input: bool = True,
148
148
  ) -> None: ...
149
149
 
150
+ # This is also required for PyTorch state dict because they are not strongly typed
151
+ @overload
152
+ def __init__( # noqa: E704
153
+ self,
154
+ torch_state_dict: dict[str, Any],
155
+ *,
156
+ keep_input: bool = True,
157
+ ) -> None: ...
158
+
150
159
  def __init__( # pylint: disable=too-many-arguments
151
160
  self,
152
161
  *args: Any,
153
162
  numpy_ndarrays: list[NDArray] | None = None,
154
- torch_state_dict: OrderedDict[str, torch.Tensor] | None = None,
163
+ torch_state_dict: OrderedDict[str, torch.Tensor] | dict[str, Any] | None = None,
155
164
  array_dict: OrderedDict[str, Array] | None = None,
156
165
  keep_input: bool = True,
157
166
  ) -> None:
@@ -16,57 +16,14 @@
16
16
 
17
17
 
18
18
  import base64
19
- from typing import cast
20
19
 
21
20
  from cryptography.exceptions import InvalidSignature
22
21
  from cryptography.fernet import Fernet
23
- from cryptography.hazmat.primitives import hashes, hmac, serialization
22
+ from cryptography.hazmat.primitives import hashes, hmac
24
23
  from cryptography.hazmat.primitives.asymmetric import ec
25
24
  from cryptography.hazmat.primitives.kdf.hkdf import HKDF
26
25
 
27
26
 
28
- def generate_key_pairs() -> (
29
- tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
30
- ):
31
- """Generate private and public key pairs with Cryptography."""
32
- private_key = ec.generate_private_key(ec.SECP384R1())
33
- public_key = private_key.public_key()
34
- return private_key, public_key
35
-
36
-
37
- def private_key_to_bytes(private_key: ec.EllipticCurvePrivateKey) -> bytes:
38
- """Serialize private key to bytes."""
39
- return private_key.private_bytes(
40
- encoding=serialization.Encoding.PEM,
41
- format=serialization.PrivateFormat.PKCS8,
42
- encryption_algorithm=serialization.NoEncryption(),
43
- )
44
-
45
-
46
- def bytes_to_private_key(private_key_bytes: bytes) -> ec.EllipticCurvePrivateKey:
47
- """Deserialize private key from bytes."""
48
- return cast(
49
- ec.EllipticCurvePrivateKey,
50
- serialization.load_pem_private_key(data=private_key_bytes, password=None),
51
- )
52
-
53
-
54
- def public_key_to_bytes(public_key: ec.EllipticCurvePublicKey) -> bytes:
55
- """Serialize public key to bytes."""
56
- return public_key.public_bytes(
57
- encoding=serialization.Encoding.PEM,
58
- format=serialization.PublicFormat.SubjectPublicKeyInfo,
59
- )
60
-
61
-
62
- def bytes_to_public_key(public_key_bytes: bytes) -> ec.EllipticCurvePublicKey:
63
- """Deserialize public key from bytes."""
64
- return cast(
65
- ec.EllipticCurvePublicKey,
66
- serialization.load_pem_public_key(data=public_key_bytes),
67
- )
68
-
69
-
70
27
  def generate_shared_key(
71
28
  private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey
72
29
  ) -> bytes:
@@ -117,48 +74,3 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
117
74
  return True
118
75
  except InvalidSignature:
119
76
  return False
120
-
121
-
122
- def sign_message(private_key: ec.EllipticCurvePrivateKey, message: bytes) -> bytes:
123
- """Sign a message using the provided EC private key.
124
-
125
- Parameters
126
- ----------
127
- private_key : ec.EllipticCurvePrivateKey
128
- The EC private key to sign the message with.
129
- message : bytes
130
- The message to be signed.
131
-
132
- Returns
133
- -------
134
- bytes
135
- The signature of the message.
136
- """
137
- signature = private_key.sign(message, ec.ECDSA(hashes.SHA256()))
138
- return signature
139
-
140
-
141
- def verify_signature(
142
- public_key: ec.EllipticCurvePublicKey, message: bytes, signature: bytes
143
- ) -> bool:
144
- """Verify a signature against a message using the provided EC public key.
145
-
146
- Parameters
147
- ----------
148
- public_key : ec.EllipticCurvePublicKey
149
- The EC public key to verify the signature.
150
- message : bytes
151
- The original message.
152
- signature : bytes
153
- The signature to verify.
154
-
155
- Returns
156
- -------
157
- bool
158
- True if the signature is valid, False otherwise.
159
- """
160
- try:
161
- public_key.verify(signature, message, ec.ECDSA(hashes.SHA256()))
162
- return True
163
- except InvalidSignature:
164
- return False