flwr 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app_cmd/review.py +13 -3
- flwr/cli/federation/show.py +4 -3
- flwr/cli/ls.py +44 -3
- flwr/cli/new/new.py +106 -297
- flwr/cli/run/run.py +12 -17
- flwr/cli/run_utils.py +23 -5
- flwr/cli/stop.py +1 -1
- flwr/cli/supernode/ls.py +10 -5
- flwr/cli/utils.py +0 -137
- flwr/client/grpc_adapter_client/connection.py +2 -2
- flwr/client/grpc_rere_client/connection.py +6 -3
- flwr/client/rest_client/connection.py +6 -4
- flwr/common/serde.py +6 -0
- flwr/common/typing.py +6 -0
- flwr/proto/fleet_pb2.py +10 -10
- flwr/proto/fleet_pb2.pyi +5 -1
- flwr/proto/run_pb2.py +24 -24
- flwr/proto/run_pb2.pyi +10 -1
- flwr/server/app.py +1 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +41 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +32 -0
- flwr/server/superlink/linkstate/sqlite_linkstate.py +60 -3
- flwr/supercore/constant.py +3 -0
- flwr/supercore/utils.py +190 -0
- flwr/superlink/servicer/control/control_grpc.py +2 -0
- flwr/superlink/servicer/control/control_servicer.py +88 -5
- flwr/supernode/nodestate/in_memory_nodestate.py +62 -1
- flwr/supernode/nodestate/nodestate.py +45 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +7 -1
- flwr/supernode/start_client_internal.py +7 -4
- {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/METADATA +2 -4
- {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/RECORD +35 -96
- flwr/cli/new/templates/__init__.py +0 -15
- flwr/cli/new/templates/app/.gitignore.tpl +0 -163
- flwr/cli/new/templates/app/LICENSE.tpl +0 -202
- flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
- flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
- flwr/cli/new/templates/app/README.md.tpl +0 -37
- flwr/cli/new/templates/app/__init__.py +0 -15
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.py +0 -15
- flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
- flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +0 -15
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
- flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -99
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
- {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
- {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
|
@@ -127,7 +127,7 @@ def send_node_heartbeat(
|
|
|
127
127
|
return SendNodeHeartbeatResponse(success=res)
|
|
128
128
|
|
|
129
129
|
|
|
130
|
-
def pull_messages(
|
|
130
|
+
def pull_messages( # pylint: disable=too-many-locals
|
|
131
131
|
request: PullMessagesRequest,
|
|
132
132
|
state: LinkState,
|
|
133
133
|
store: ObjectStore,
|
|
@@ -143,6 +143,8 @@ def pull_messages(
|
|
|
143
143
|
# Convert to Messages
|
|
144
144
|
msg_proto = []
|
|
145
145
|
trees = []
|
|
146
|
+
run_id_to_record: int | None = None
|
|
147
|
+
|
|
146
148
|
for msg in message_list:
|
|
147
149
|
try:
|
|
148
150
|
# Retrieve Message object tree from ObjectStore
|
|
@@ -152,12 +154,30 @@ def pull_messages(
|
|
|
152
154
|
# Add Message and its object tree to the response
|
|
153
155
|
msg_proto.append(message_to_proto(msg))
|
|
154
156
|
trees.append(obj_tree)
|
|
157
|
+
|
|
158
|
+
# Track run_id for traffic recording
|
|
159
|
+
run_id_to_record = msg.metadata.run_id
|
|
160
|
+
|
|
155
161
|
except NoObjectInStoreError as e:
|
|
156
162
|
log(ERROR, e.message)
|
|
157
163
|
# Delete message ins from state
|
|
158
164
|
state.delete_messages(message_ins_ids={msg_object_id})
|
|
159
165
|
|
|
160
|
-
|
|
166
|
+
response = PullMessagesResponse(messages_list=msg_proto, message_object_trees=trees)
|
|
167
|
+
|
|
168
|
+
# Record incoming traffic size
|
|
169
|
+
bytes_recv = len(request.SerializeToString())
|
|
170
|
+
|
|
171
|
+
# Record traffic only if message was successfully processed
|
|
172
|
+
# All messages in this request are assumed to belong to the same run
|
|
173
|
+
if run_id_to_record is not None:
|
|
174
|
+
# Record outgoing traffic size
|
|
175
|
+
bytes_sent = len(response.SerializeToString())
|
|
176
|
+
state.store_traffic(
|
|
177
|
+
run_id_to_record, bytes_sent=bytes_sent, bytes_recv=bytes_recv
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
return response
|
|
161
181
|
|
|
162
182
|
|
|
163
183
|
def push_messages(
|
|
@@ -170,6 +190,9 @@ def push_messages(
|
|
|
170
190
|
msg = message_from_proto(message_proto=request.messages_list[0])
|
|
171
191
|
run_id = msg.metadata.run_id
|
|
172
192
|
|
|
193
|
+
# Record incoming traffic size
|
|
194
|
+
bytes_recv = len(request.SerializeToString())
|
|
195
|
+
|
|
173
196
|
# Abort if the run is not running
|
|
174
197
|
abort_msg = check_abort(
|
|
175
198
|
run_id,
|
|
@@ -193,6 +216,16 @@ def push_messages(
|
|
|
193
216
|
results={str(message_id): 0},
|
|
194
217
|
objects_to_push=objects_to_push,
|
|
195
218
|
)
|
|
219
|
+
|
|
220
|
+
# Record outgoing traffic size
|
|
221
|
+
bytes_sent = len(response.SerializeToString())
|
|
222
|
+
|
|
223
|
+
# Record traffic only if message was successfully processed
|
|
224
|
+
# Only one message is processed per request
|
|
225
|
+
state.store_traffic(run_id, bytes_sent=bytes_sent, bytes_recv=bytes_recv)
|
|
226
|
+
if request.clientapp_runtime_list:
|
|
227
|
+
state.add_clientapp_runtime(run_id, request.clientapp_runtime_list[0])
|
|
228
|
+
|
|
196
229
|
return response
|
|
197
230
|
|
|
198
231
|
|
|
@@ -257,6 +290,10 @@ def push_object(
|
|
|
257
290
|
try:
|
|
258
291
|
store.put(request.object_id, request.object_content)
|
|
259
292
|
stored = True
|
|
293
|
+
# Record bytes traffic pushed from SuperNode
|
|
294
|
+
state.store_traffic(
|
|
295
|
+
request.run_id, bytes_sent=0, bytes_recv=len(request.object_content)
|
|
296
|
+
)
|
|
260
297
|
except (NoObjectInStoreError, ValueError) as e:
|
|
261
298
|
log(ERROR, str(e))
|
|
262
299
|
except UnexpectedObjectContentError as e:
|
|
@@ -283,6 +320,8 @@ def pull_object(
|
|
|
283
320
|
content = store.get(request.object_id)
|
|
284
321
|
if content is not None:
|
|
285
322
|
object_available = content != b""
|
|
323
|
+
# Record bytes traffic pulled by SuperNode
|
|
324
|
+
state.store_traffic(request.run_id, bytes_sent=len(content), bytes_recv=0)
|
|
286
325
|
return PullObjectResponse(
|
|
287
326
|
object_found=True,
|
|
288
327
|
object_available=object_available,
|
|
@@ -576,6 +576,9 @@ class InMemoryLinkState(LinkState, InMemoryCoreState): # pylint: disable=R0902,
|
|
|
576
576
|
),
|
|
577
577
|
flwr_aid=flwr_aid if flwr_aid else "",
|
|
578
578
|
federation=federation,
|
|
579
|
+
bytes_sent=0,
|
|
580
|
+
bytes_recv=0,
|
|
581
|
+
clientapp_runtime=0.0,
|
|
579
582
|
),
|
|
580
583
|
)
|
|
581
584
|
self.run_ids[run_id] = run_record
|
|
@@ -771,3 +774,34 @@ class InMemoryLinkState(LinkState, InMemoryCoreState): # pylint: disable=R0902,
|
|
|
771
774
|
index = bisect_right(run.logs, (after_timestamp, ""))
|
|
772
775
|
latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
|
|
773
776
|
return "".join(log for _, log in run.logs[index:]), latest_timestamp
|
|
777
|
+
|
|
778
|
+
def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
    """Accumulate traffic counters on the in-memory record of `run_id`.

    Parameters
    ----------
    run_id : int
        Identifier of the run whose counters are updated.
    bytes_sent : int
        Number of bytes to add to the run's `bytes_sent` total.
    bytes_recv : int
        Number of bytes to add to the run's `bytes_recv` total.

    Raises
    ------
    ValueError
        If either value is negative, if both values are zero, or if `run_id`
        is not a known run.
    """
    # Counters may only grow, so a negative increment is a caller error
    if min(bytes_sent, bytes_recv) < 0:
        raise ValueError(
            f"Negative traffic values for run {run_id}: "
            f"bytes_sent={bytes_sent}, bytes_recv={bytes_recv}"
        )

    # A call that records nothing at all is rejected as well
    # NOTE(review): callers that pass a payload length (e.g. `len(content)`)
    # reach this branch when the payload is empty -- confirm they guard it.
    if not (bytes_sent or bytes_recv):
        raise ValueError(
            f"Both bytes_sent and bytes_recv cannot be zero for run {run_id}"
        )

    # Resolve the record while holding the state-wide lock ...
    with self.lock:
        record = self.run_ids.get(run_id)
        if record is None:
            raise ValueError(f"Run {run_id} not found")

    # ... and apply the increments under the record's own lock
    with record.lock:
        record.run.bytes_sent += bytes_sent
        record.run.bytes_recv += bytes_recv
|
|
801
|
+
|
|
802
|
+
def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
    """Add `runtime` to the cumulative ClientApp runtime stored for `run_id`.

    Parameters
    ----------
    run_id : int
        Identifier of the run whose runtime total is increased.
    runtime : float
        Amount to add to the run's `clientapp_runtime` total.

    Raises
    ------
    ValueError
        If `run_id` is not a known run.
    """
    # Lookup and update both happen under the state-wide lock
    with self.lock:
        record = self.run_ids.get(run_id)
        if record is None:
            raise ValueError(f"Run {run_id} not found")
        record.run.clientapp_runtime += runtime
|
|
@@ -480,3 +480,35 @@ class LinkState(CoreState): # pylint: disable=R0904
|
|
|
480
480
|
- The timestamp of the latest log entry in the returned logs.
|
|
481
481
|
Returns `0` if no logs are returned.
|
|
482
482
|
"""
|
|
483
|
+
|
|
484
|
+
@abc.abstractmethod
def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
    """Add traffic counts to the totals kept for `run_id`.

    Both counts are increments: they are added to whatever the run has
    accumulated so far.

    Parameters
    ----------
    run_id : int
        Identifier of the run the traffic is attributed to.
    bytes_sent : int
        Bytes pulled by SuperNodes from the SuperLink, added to the run's
        total.
    bytes_recv : int
        Bytes received by the SuperLink from SuperNodes, added to the run's
        total.
    """
|
|
499
|
+
|
|
500
|
+
@abc.abstractmethod
def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
    """Increase the cumulative ClientApp runtime recorded for `run_id`.

    The given value is added to the run's existing total rather than
    replacing it, so several ClientApps can each contribute to the same
    run's runtime.

    Parameters
    ----------
    run_id : int
        Identifier of the run whose total ClientApp runtime is increased.
    runtime : float
        Runtime in seconds to add to the cumulative total of `run_id`.
    """
|
|
@@ -119,7 +119,10 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
119
119
|
details TEXT,
|
|
120
120
|
federation TEXT,
|
|
121
121
|
federation_options BLOB,
|
|
122
|
-
flwr_aid TEXT
|
|
122
|
+
flwr_aid TEXT,
|
|
123
|
+
bytes_sent INTEGER DEFAULT 0,
|
|
124
|
+
bytes_recv INTEGER DEFAULT 0,
|
|
125
|
+
clientapp_runtime REAL DEFAULT 0.0
|
|
123
126
|
);
|
|
124
127
|
"""
|
|
125
128
|
|
|
@@ -905,8 +908,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
905
908
|
(run_id, fab_id, fab_version,
|
|
906
909
|
fab_hash, override_config, federation, federation_options,
|
|
907
910
|
pending_at, starting_at, running_at, finished_at, sub_status,
|
|
908
|
-
details, flwr_aid)
|
|
909
|
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
|
911
|
+
details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
|
|
912
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
|
910
913
|
"""
|
|
911
914
|
override_config_json = json.dumps(override_config)
|
|
912
915
|
data = [
|
|
@@ -924,6 +927,9 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
924
927
|
"", # sub_status
|
|
925
928
|
"", # details
|
|
926
929
|
flwr_aid or "", # flwr_aid
|
|
930
|
+
0, # bytes_sent
|
|
931
|
+
0, # bytes_recv
|
|
932
|
+
0, # clientapp_runtime
|
|
927
933
|
]
|
|
928
934
|
self.conn.execute(query, tuple(data))
|
|
929
935
|
return uint64_run_id
|
|
@@ -972,6 +978,9 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
972
978
|
),
|
|
973
979
|
flwr_aid=row["flwr_aid"],
|
|
974
980
|
federation=row["federation"],
|
|
981
|
+
bytes_sent=row["bytes_sent"],
|
|
982
|
+
bytes_recv=row["bytes_recv"],
|
|
983
|
+
clientapp_runtime=row["clientapp_runtime"],
|
|
975
984
|
)
|
|
976
985
|
log(ERROR, "`run_id` does not exist.")
|
|
977
986
|
return None
|
|
@@ -1255,6 +1264,54 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1255
1264
|
|
|
1256
1265
|
return rows[0]
|
|
1257
1266
|
|
|
1267
|
+
def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
    """Accumulate traffic counters on the `run` table row of `run_id`.

    Parameters
    ----------
    run_id : int
        Identifier of the run whose counters are updated.
    bytes_sent : int
        Number of bytes to add to the row's `bytes_sent` column.
    bytes_recv : int
        Number of bytes to add to the row's `bytes_recv` column.

    Raises
    ------
    ValueError
        If either value is negative, if both values are zero, or if no row
        exists for `run_id`.
    """
    # Counters may only grow, so a negative increment is a caller error
    if min(bytes_sent, bytes_recv) < 0:
        raise ValueError(
            f"Negative traffic values for run {run_id}: "
            f"bytes_sent={bytes_sent}, bytes_recv={bytes_recv}"
        )

    # A call that records nothing at all is rejected as well
    if not (bytes_sent or bytes_recv):
        raise ValueError(
            f"Both bytes_sent and bytes_recv cannot be zero for run {run_id}"
        )

    # The UPDATE doubles as the existence check: RETURNING yields a row only
    # when a run with this ID was actually updated.
    query = """
        UPDATE run
        SET bytes_sent = bytes_sent + ?,
            bytes_recv = bytes_recv + ?
        WHERE run_id = ?
        RETURNING run_id;
    """
    params = (bytes_sent, bytes_recv, uint64_to_int64(run_id))

    with self.conn:
        # fetchall() consumes the RETURNING result set completely
        updated = self.conn.execute(query, params).fetchall()

        if not updated:
            raise ValueError(f"Run {run_id} not found")
|
|
1298
|
+
|
|
1299
|
+
def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
    """Add `runtime` to the `clientapp_runtime` column of the row of `run_id`.

    Parameters
    ----------
    run_id : int
        Identifier of the run whose runtime total is increased.
    runtime : float
        Amount to add to the stored cumulative runtime.

    Raises
    ------
    ValueError
        If no row exists for `run_id`.
    """
    # The UPDATE doubles as the existence check: RETURNING yields a row only
    # when a run with this ID was actually updated.
    query = """
        UPDATE run
        SET clientapp_runtime = clientapp_runtime + ?
        WHERE run_id = ?
        RETURNING run_id;
    """
    params = (runtime, uint64_to_int64(run_id))

    with self.conn:
        # fetchall() consumes the RETURNING result set completely
        updated = self.conn.execute(query, params).fetchall()

        if not updated:
            raise ValueError(f"Run {run_id} not found")
|
|
1314
|
+
|
|
1258
1315
|
|
|
1259
1316
|
def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1260
1317
|
"""Transform Message to dict."""
|
flwr/supercore/constant.py
CHANGED
|
@@ -57,6 +57,9 @@ NOOP_FEDERATION = "default"
|
|
|
57
57
|
# Constants for exit handling
|
|
58
58
|
FORCE_EXIT_TIMEOUT_SECONDS = 5 # Used in `flwr_exit` function
|
|
59
59
|
|
|
60
|
+
# Constants for message processing timing
|
|
61
|
+
MESSAGE_TIME_ENTRY_MAX_AGE_SECONDS = 3600
|
|
62
|
+
|
|
60
63
|
|
|
61
64
|
class NodeStatus:
|
|
62
65
|
"""Event log writer types."""
|
flwr/supercore/utils.py
CHANGED
|
@@ -15,6 +15,16 @@
|
|
|
15
15
|
"""Utility functions for the infrastructure."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import json
|
|
19
|
+
import re
|
|
20
|
+
|
|
21
|
+
import requests
|
|
22
|
+
|
|
23
|
+
from flwr.common.version import package_version as flwr_version
|
|
24
|
+
|
|
25
|
+
from .constant import APP_ID_PATTERN, APP_VERSION_PATTERN
|
|
26
|
+
|
|
27
|
+
|
|
18
28
|
def mask_string(value: str, head: int = 4, tail: int = 4) -> str:
|
|
19
29
|
"""Mask a string by preserving only the head and tail characters.
|
|
20
30
|
|
|
@@ -50,3 +60,183 @@ def int64_to_uint64(signed: int) -> int:
|
|
|
50
60
|
if signed < 0:
|
|
51
61
|
return signed + (1 << 64)
|
|
52
62
|
return signed
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def parse_app_spec(app_spec: str) -> tuple[str, str | None]:
    """Split an app specification into its app ID and optional version.

    Parameters
    ----------
    app_spec : str
        Specification of the form '@account/app' or '@account/app==x.y.z'
        (digits only).

    Returns
    -------
    tuple[str, str | None]
        The app ID and the version, the latter being `None` when the
        specification carries no '==' part.

    Raises
    ------
    ValueError
        If the version or the app ID does not match its expected pattern.
    """
    # Only the first "==" separates ID and version; the rest stays in the version
    app_id, separator, version_part = app_spec.partition("==")
    app_version = version_part if separator else None

    # The version, when given, is validated before the app ID
    if app_version is not None and not re.match(APP_VERSION_PATTERN, app_version):
        raise ValueError("Invalid app version. Expected format: x.y.z (digits only).")

    if not re.match(APP_ID_PATTERN, app_id):
        raise ValueError(
            "Invalid remote app ID. Expected format: '@account_name/app_name'."
        )

    return app_id, app_version
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def request_download_link(
    app_id: str, app_version: str | None, in_url: str, out_url: str
) -> tuple[str, list[dict[str, str]] | None]:
    """Request a download link for the given app from the Flower Platform API.

    Parameters
    ----------
    app_id : str
        The application identifier in the format '@account/app'.
    app_version : str | None
        The application version (e.g., '1.2.3'), or None to request the latest
        version.
    in_url : str
        The Platform API endpoint URL to query.
    out_url : str
        The key name in the response that contains the download URL.

    Returns
    -------
    tuple[str, list[dict[str, str]] | None]
        A tuple containing:
        - The download URL for the application.
        - The value of the response's "verifications" entry if present,
          otherwise None.

    Raises
    ------
    ValueError
        If the API connection fails, the application or version is not found,
        the API returns a non-2xx response, or the response body is not a JSON
        object containing `out_url`.
    """
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    body = {
        "app_id": app_id,  # send raw string of app_id
        "app_version": app_version,
        "flwr_version": flwr_version,
    }

    try:
        resp = requests.post(in_url, headers=headers, data=json.dumps(body), timeout=20)
    except requests.RequestException as e:
        raise ValueError(f"Unable to connect to Platform API: {e}") from e

    if resp.status_code == 404:
        # Expecting a JSON body with a "detail" field
        try:
            payload = resp.json()
        except ValueError:
            # JSON parsing failed
            raise ValueError(f"{app_id} not found in Platform API.") from None

        # Only a JSON object can carry "detail"; any other JSON value (list,
        # string, null, ...) falls through to the generic not-found error
        # instead of failing on a missing `.get`.
        error_message = payload.get("detail") if isinstance(payload, dict) else None

        if isinstance(error_message, dict):
            available_app_versions = error_message.get("available_app_versions", [])
            available_versions_str = (
                ", ".join(map(str, available_app_versions))
                if available_app_versions
                else "None"
            )
            raise ValueError(
                f"{app_id}=={app_version} not found in Platform API. "
                f"Available app versions for {app_id}: {available_versions_str}"
            )

        raise ValueError(f"{app_id} not found in Platform API.")

    if not resp.ok:
        raise ValueError(
            f"Platform API request failed with status {resp.status_code}. "
            f"Details: {resp.text}"
        )

    # A successful status does not guarantee a usable body: report undecodable
    # or non-object JSON with the same error as a missing `out_url` key.
    try:
        data = resp.json()
    except ValueError as e:
        raise ValueError("Invalid response from Platform API") from e
    if not isinstance(data, dict) or out_url not in data:
        raise ValueError("Invalid response from Platform API")

    return str(data[out_url]), data.get("verifications")
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def humanize_duration(seconds: float) -> str:
    """Render a duration given in seconds as a short human-readable string.

    The fractional part is discarded first; then the output unit(s) depend on
    the magnitude:

    - below 90 seconds: seconds only
    - below 1 hour: minutes and seconds
    - below 1 day: hours and minutes
    - 1 day or more: days and hours
    """
    whole_seconds = int(seconds)

    # Short durations stay in plain seconds (up to 89s)
    if whole_seconds < 90:
        return f"{whole_seconds}s"

    whole_minutes, leftover_seconds = divmod(whole_seconds, 60)
    if whole_minutes < 60:
        return f"{whole_minutes}m {leftover_seconds}s"

    # From one hour on, the seconds are dropped from the output
    whole_hours, leftover_minutes = divmod(whole_minutes, 60)
    if whole_hours < 24:
        return f"{whole_hours}h {leftover_minutes}m"

    # From one day on, the minutes are dropped as well
    whole_days, leftover_hours = divmod(whole_hours, 24)
    return f"{whole_days}d {leftover_hours}h"
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def humanize_bytes(num_bytes: int) -> str:
    """Convert a number of bytes to a human-friendly string.

    Uses 1024-based units and 0-1 decimal precision.
    Rules:
    - < 1 KB: bytes
    - < 1 MB: KB
    - < 1 GB: MB
    - < 1 TB: GB
    - otherwise: TB (no larger unit is used)

    Within a unit, one decimal is shown while the displayed value is below 10
    (e.g. "1.5 KB"); from 10 upwards the value is truncated to an integer
    (e.g. "10 KB").
    """
    value = float(num_bytes)

    for suffix in ("B", "KB", "MB", "GB", "TB"):
        if value < 1024 or suffix == "TB":
            # Bytes → no decimals
            if suffix == "B":
                return f"{int(value)} B"

            # Decide precision on the *rounded* value: something like 9.96
            # rounds to "10.0" and must therefore be shown without a decimal.
            one_decimal = f"{value:.1f}"
            if float(one_decimal) < 10:
                return f"{one_decimal} {suffix}"

            # max() only matters for values in [9.95, 10), where truncation
            # alone would display 9 although the value rounds to 10.
            return f"{max(int(value), 10)} {suffix}"

        value /= 1024

    raise RuntimeError("Unreachable code")  # Make mypy happy
|
|
@@ -61,6 +61,7 @@ def run_control_api_grpc(
|
|
|
61
61
|
authz_plugin: ControlAuthzPlugin,
|
|
62
62
|
event_log_plugin: EventLogWriterPlugin | None = None,
|
|
63
63
|
artifact_provider: ArtifactProvider | None = None,
|
|
64
|
+
fleet_api_type: str | None = None,
|
|
64
65
|
) -> grpc.Server:
|
|
65
66
|
"""Run Control API (gRPC, request-response)."""
|
|
66
67
|
license_plugin: LicensePlugin | None = get_license_plugin()
|
|
@@ -74,6 +75,7 @@ def run_control_api_grpc(
|
|
|
74
75
|
is_simulation=is_simulation,
|
|
75
76
|
authn_plugin=authn_plugin,
|
|
76
77
|
artifact_provider=artifact_provider,
|
|
78
|
+
fleet_api_type=fleet_api_type,
|
|
77
79
|
)
|
|
78
80
|
interceptors = [ControlAccountAuthInterceptor(authn_plugin, authz_plugin)]
|
|
79
81
|
if license_plugin is not None:
|
|
@@ -16,12 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import hashlib
|
|
19
|
+
import json
|
|
19
20
|
import time
|
|
20
21
|
from collections.abc import Generator, Sequence
|
|
21
22
|
from logging import ERROR, INFO
|
|
22
23
|
from typing import Any, cast
|
|
23
24
|
|
|
24
25
|
import grpc
|
|
26
|
+
import requests
|
|
25
27
|
|
|
26
28
|
from flwr.cli.config_utils import get_fab_metadata
|
|
27
29
|
from flwr.common import Context, RecordDict, now
|
|
@@ -36,6 +38,7 @@ from flwr.common.constant import (
|
|
|
36
38
|
PUBLIC_KEY_NOT_VALID,
|
|
37
39
|
PULL_UNFINISHED_RUN_MESSAGE,
|
|
38
40
|
RUN_ID_NOT_FOUND_MESSAGE,
|
|
41
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
39
42
|
Status,
|
|
40
43
|
SubStatus,
|
|
41
44
|
)
|
|
@@ -76,9 +79,11 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
|
76
79
|
from flwr.proto.federation_pb2 import Federation # pylint: disable=E0611
|
|
77
80
|
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
|
78
81
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
82
|
+
from flwr.supercore.constant import PLATFORM_API_URL
|
|
79
83
|
from flwr.supercore.ffs import FfsFactory
|
|
80
84
|
from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
|
|
81
85
|
from flwr.supercore.primitives.asymmetric import bytes_to_public_key, uses_nist_ec_curve
|
|
86
|
+
from flwr.supercore.utils import parse_app_spec, request_download_link
|
|
82
87
|
from flwr.superlink.artifact_provider import ArtifactProvider
|
|
83
88
|
from flwr.superlink.auth_plugin import ControlAuthnPlugin
|
|
84
89
|
|
|
@@ -96,6 +101,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
96
101
|
is_simulation: bool,
|
|
97
102
|
authn_plugin: ControlAuthnPlugin,
|
|
98
103
|
artifact_provider: ArtifactProvider | None = None,
|
|
104
|
+
fleet_api_type: str | None = None,
|
|
99
105
|
) -> None:
|
|
100
106
|
self.linkstate_factory = linkstate_factory
|
|
101
107
|
self.ffs_factory = ffs_factory
|
|
@@ -103,8 +109,9 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
103
109
|
self.is_simulation = is_simulation
|
|
104
110
|
self.authn_plugin = authn_plugin
|
|
105
111
|
self.artifact_provider = artifact_provider
|
|
112
|
+
self.fleet_api_type = fleet_api_type
|
|
106
113
|
|
|
107
|
-
def StartRun( # pylint: disable=too-many-locals
|
|
114
|
+
def StartRun( # pylint: disable=too-many-locals, too-many-branches, too-many-statements
|
|
108
115
|
self, request: StartRunRequest, context: grpc.ServicerContext
|
|
109
116
|
) -> StartRunResponse:
|
|
110
117
|
"""Create run ID."""
|
|
@@ -112,7 +119,15 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
112
119
|
state = self.linkstate_factory.state()
|
|
113
120
|
ffs = self.ffs_factory.ffs()
|
|
114
121
|
|
|
115
|
-
|
|
122
|
+
verification_dict: dict[str, str] = {}
|
|
123
|
+
if request.app_spec:
|
|
124
|
+
fab_file, verification_dict = _get_remote_fab(
|
|
125
|
+
self.fleet_api_type, request.app_spec, context
|
|
126
|
+
)
|
|
127
|
+
else:
|
|
128
|
+
fab_file = request.fab.content
|
|
129
|
+
|
|
130
|
+
if len(fab_file) > FAB_MAX_SIZE:
|
|
116
131
|
log(
|
|
117
132
|
ERROR,
|
|
118
133
|
"FAB size exceeds maximum allowed size of %d bytes.",
|
|
@@ -124,7 +139,6 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
124
139
|
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
|
|
125
140
|
override_config = user_config_from_proto(request.override_config)
|
|
126
141
|
federation_options = config_record_from_proto(request.federation_options)
|
|
127
|
-
fab_file = request.fab.content
|
|
128
142
|
|
|
129
143
|
try:
|
|
130
144
|
# Check that num-supernodes is set
|
|
@@ -150,9 +164,10 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
150
164
|
fab = Fab(
|
|
151
165
|
hashlib.sha256(fab_file).hexdigest(),
|
|
152
166
|
fab_file,
|
|
153
|
-
|
|
167
|
+
verification_dict,
|
|
154
168
|
)
|
|
155
|
-
fab_hash = ffs.put(fab.content,
|
|
169
|
+
fab_hash = ffs.put(fab.content, fab.verifications)
|
|
170
|
+
|
|
156
171
|
if fab_hash != fab.hash_str:
|
|
157
172
|
raise RuntimeError(
|
|
158
173
|
f"FAB ({fab.hash_str}) hash from request doesn't match contents"
|
|
@@ -612,3 +627,71 @@ def _check_flwr_aid_in_run(
|
|
|
612
627
|
grpc.StatusCode.PERMISSION_DENIED,
|
|
613
628
|
"⛔️ Run ID does not belong to the account",
|
|
614
629
|
)
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
def _format_verification(verifications: list[dict[str, str]]) -> dict[str, str]:
|
|
633
|
+
"""Format verification information for FAB."""
|
|
634
|
+
# Convert verifications to dict[str, str] type
|
|
635
|
+
verification_dict = {
|
|
636
|
+
item["public_key_id"]: json.dumps(
|
|
637
|
+
{k: v for k, v in item.items() if k != "public_key_id"}
|
|
638
|
+
)
|
|
639
|
+
for item in verifications
|
|
640
|
+
}
|
|
641
|
+
verification_dict.update({"valid_license": "Valid"})
|
|
642
|
+
|
|
643
|
+
return verification_dict
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def _get_remote_fab(
    fleet_api_type: str | None,
    app_spec: str,
    context: grpc.ServicerContext,
) -> tuple[bytes, dict[str, str]]:
    """Fetch a FAB and its verification info via the Flower Platform API.

    Parameters
    ----------
    fleet_api_type : str | None
        Transport type of the Fleet API; the gRPC adapter transport is rejected.
    app_spec : str
        App specification string, parsed with `parse_app_spec`.
    context : grpc.ServicerContext
        Context of the current RPC, used to abort it on any failure.

    Returns
    -------
    tuple[bytes, dict[str, str]]
        The downloaded FAB content and the formatted verification dict. When
        the API supplies no verifications, the dict is `{"valid_license": ""}`.

    Notes
    -----
    Every failure path calls `context.abort` with `FAILED_PRECONDITION`. The
    code after each call relies on `abort` not returning (gRPC documents that
    it raises to terminate the RPC).
    """
    failed = grpc.StatusCode.FAILED_PRECONDITION

    if fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
        context.abort(
            failed,
            "The selected SuperLink transport type is not "
            "supported for connecting to Flower Platform.",
        )

    # Step 1: split and validate the app specification
    try:
        app_id, app_version = parse_app_spec(app_spec)
    except ValueError as err:
        context.abort(failed, str(err))

    # Step 2: ask the platform for a download link plus verification data
    fetch_endpoint = f"{PLATFORM_API_URL}/hub/fetch-fab"
    try:
        presigned_url, verifications = request_download_link(
            app_id, app_version, fetch_endpoint, "fab_url"
        )
    except ValueError as err:
        context.abort(failed, str(err))

    # Step 3: normalise the verification data into dict[str, str]
    if verifications is None:
        verification_dict = {"valid_license": ""}
    else:
        verification_dict = _format_verification(verifications)

    # Step 4: download the FAB itself from the returned link
    try:
        response = requests.get(presigned_url, timeout=60)
        response.raise_for_status()
    except requests.RequestException as err:
        context.abort(failed, f"FAB download failed: {err}")

    return response.content, verification_dict
|