flwr-nightly 1.12.0.dev20240918__py3-none-any.whl → 1.12.0.dev20241006__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (47)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/log.py +234 -0
  3. flwr/cli/new/new.py +1 -1
  4. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/run/run.py +17 -1
  7. flwr/client/grpc_rere_client/client_interceptor.py +3 -0
  8. flwr/client/grpc_rere_client/connection.py +3 -3
  9. flwr/client/grpc_rere_client/grpc_adapter.py +14 -3
  10. flwr/client/rest_client/connection.py +3 -3
  11. flwr/client/supernode/app.py +1 -0
  12. flwr/common/constant.py +6 -3
  13. flwr/common/secure_aggregation/secaggplus_utils.py +4 -4
  14. flwr/common/serde.py +22 -7
  15. flwr/proto/control_pb2.py +27 -0
  16. flwr/proto/control_pb2.pyi +7 -0
  17. flwr/proto/control_pb2_grpc.py +135 -0
  18. flwr/proto/control_pb2_grpc.pyi +53 -0
  19. flwr/proto/driver_pb2.py +15 -24
  20. flwr/proto/driver_pb2.pyi +0 -52
  21. flwr/proto/driver_pb2_grpc.py +6 -6
  22. flwr/proto/driver_pb2_grpc.pyi +4 -4
  23. flwr/proto/fab_pb2.py +8 -7
  24. flwr/proto/fab_pb2.pyi +7 -1
  25. flwr/proto/fleet_pb2.py +10 -10
  26. flwr/proto/fleet_pb2.pyi +6 -1
  27. flwr/proto/recordset_pb2.py +35 -33
  28. flwr/proto/recordset_pb2.pyi +40 -14
  29. flwr/proto/run_pb2.py +33 -9
  30. flwr/proto/run_pb2.pyi +150 -1
  31. flwr/proto/transport_pb2.py +8 -8
  32. flwr/proto/transport_pb2.pyi +9 -6
  33. flwr/server/run_serverapp.py +2 -2
  34. flwr/server/superlink/driver/driver_servicer.py +2 -2
  35. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -2
  36. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +4 -0
  37. flwr/server/superlink/state/in_memory_state.py +17 -0
  38. flwr/server/superlink/state/sqlite_state.py +44 -6
  39. flwr/server/utils/validator.py +6 -0
  40. flwr/superexec/deployment.py +3 -1
  41. flwr/superexec/exec_servicer.py +68 -3
  42. flwr/superexec/executor.py +2 -1
  43. {flwr_nightly-1.12.0.dev20240918.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/METADATA +4 -2
  44. {flwr_nightly-1.12.0.dev20240918.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/RECORD +47 -42
  45. {flwr_nightly-1.12.0.dev20240918.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/LICENSE +0 -0
  46. {flwr_nightly-1.12.0.dev20240918.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/WHEEL +0 -0
  47. {flwr_nightly-1.12.0.dev20240918.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/entry_points.txt +0 -0
flwr/cli/app.py CHANGED
@@ -19,6 +19,7 @@ from typer.main import get_command
19
19
 
20
20
  from .build import build
21
21
  from .install import install
22
+ from .log import log
22
23
  from .new import new
23
24
  from .run import run
24
25
 
@@ -35,6 +36,7 @@ app.command()(new)
35
36
  app.command()(run)
36
37
  app.command()(build)
37
38
  app.command()(install)
39
+ app.command()(log)
38
40
 
39
41
  typer_click_object = get_command(app)
40
42
 
flwr/cli/log.py ADDED
@@ -0,0 +1,234 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower command line interface `log` command."""
16
+
17
+ import sys
18
+ import time
19
+ from logging import DEBUG, ERROR, INFO
20
+ from pathlib import Path
21
+ from typing import Annotated, Optional
22
+
23
+ import grpc
24
+ import typer
25
+
26
+ from flwr.cli.config_utils import load_and_validate
27
+ from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
28
+ from flwr.common.logger import log as logger
29
+ from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
30
+ from flwr.proto.exec_pb2_grpc import ExecStub
31
+
32
+ CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
33
+
34
+
35
def start_stream(
    run_id: int, channel: grpc.Channel, refresh_period: int = CONN_REFRESH_PERIOD
) -> None:
    """Start log streaming for a given run ID.

    Repeatedly calls `stream_logs` on the given channel, pausing two seconds
    between calls, until the user interrupts or a gRPC error is raised.
    The channel is always closed before returning.

    Parameters
    ----------
    run_id : int
        ID of the run whose logs are streamed.
    channel : grpc.Channel
        Channel used for the `StreamLogs` calls; closed on exit.
    refresh_period : int
        Passed to `stream_logs` as the duration (seconds) after which one
        streaming call is abandoned so that a new one can be started.
    """
    try:
        while True:
            logger(INFO, "Starting logstream for run_id `%s`", run_id)
            stream_logs(run_id, channel, refresh_period)
            time.sleep(2)
            logger(DEBUG, "Reconnecting to logstream")
    except KeyboardInterrupt:
        logger(INFO, "Exiting logstream")
    except grpc.RpcError as e:
        # pylint: disable=E1101
        status = e.code()
        if status == grpc.StatusCode.NOT_FOUND:
            logger(ERROR, "Invalid run_id `%s`, exiting", run_id)
        elif status != grpc.StatusCode.CANCELLED:
            # CANCELLED is an expected, quiet way for the stream to end; any
            # other status used to be swallowed silently, leaving the user
            # without a hint why the command stopped.
            logger(ERROR, "Log streaming failed with gRPC status `%s`, exiting", status)
    finally:
        channel.close()
55
+
56
+
57
def stream_logs(run_id: int, channel: grpc.Channel, duration: int) -> None:
    """Stream logs from the beginning of a run with connection refresh.

    Prints every `log_output` received from the `StreamLogs` call and stops
    consuming the stream once more than `duration` seconds have elapsed.
    The elapsed time is only checked after a response has been printed.

    Parameters
    ----------
    run_id : int
        ID of the run whose logs are requested.
    channel : grpc.Channel
        Channel on which the `ExecStub` is created.
    duration : int
        Number of seconds after which the loop stops reading the stream.
    """
    # Use a monotonic clock for the elapsed-time check: `time.time()` follows
    # the wall clock and can jump when the system time is adjusted.
    started_at = time.monotonic()
    stub = ExecStub(channel)
    req = StreamLogsRequest(run_id=run_id)

    for res in stub.StreamLogs(req):
        print(res.log_output)
        if time.monotonic() - started_at > duration:
            break
67
+
68
+
69
def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
    """Print logs from the beginning of a run.

    Issues a single `StreamLogs` call bounded by `timeout` and prints every
    `log_output` received. The channel is always closed before returning.

    Parameters
    ----------
    run_id : int
        ID of the run whose logs are requested.
    channel : grpc.Channel
        Channel on which the `ExecStub` is created; closed on exit.
    timeout : int
        Deadline (seconds) passed to the `StreamLogs` call.
    """
    stub = ExecStub(channel)
    req = StreamLogsRequest(run_id=run_id)

    try:
        # Enforce timeout for graceful exit
        for res in stub.StreamLogs(req, timeout=timeout):
            print(res.log_output)
    except grpc.RpcError as e:
        # pylint: disable=E1101
        status = e.code()
        if status == grpc.StatusCode.NOT_FOUND:
            logger(ERROR, "Invalid run_id `%s`, exiting", run_id)
        elif status not in (
            grpc.StatusCode.DEADLINE_EXCEEDED,
            grpc.StatusCode.CANCELLED,
        ):
            # Previously any other status (e.g. UNAVAILABLE) fell through and
            # the request was re-issued in an endless loop; report and stop.
            logger(ERROR, "Fetching logs failed with gRPC status `%s`, exiting", status)
    except KeyboardInterrupt:
        logger(DEBUG, "Stream interrupted by user")
    finally:
        channel.close()
        logger(DEBUG, "Channel closed")
94
+
95
+
96
def on_channel_state_change(channel_connectivity: str) -> None:
    """Write the connectivity value reported for the channel to the DEBUG log."""
    reported_state = channel_connectivity
    logger(DEBUG, reported_state)
99
+
100
+
101
def log(
    run_id: Annotated[
        int,
        typer.Argument(help="The Flower run ID to query"),
    ],
    app: Annotated[
        Path,
        typer.Argument(help="Path of the Flower project to run"),
    ] = Path("."),
    federation: Annotated[
        Optional[str],
        typer.Argument(help="Name of the federation to run the app on"),
    ] = None,
    stream: Annotated[
        bool,
        typer.Option(
            "--stream/--show",
            help="Flag to stream or print logs from the Flower run",
        ),
    ] = True,
) -> None:
    """Get logs from a Flower project run."""
    # NOTE: the docstring above is kept verbatim because this function is
    # registered as a Typer command. Further documentation lives in comments.
    #
    # Flow: load `<app>/pyproject.toml`, resolve the federation (explicit
    # argument or the `default` entry), check that it declares an `address`,
    # then hand over to `_log_with_superexec`. Every failure path prints a
    # message and exits with status 1.
    typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

    pyproject_path = app / "pyproject.toml" if app else None
    config, errors, warnings = load_and_validate(path=pyproject_path)

    if config is None:
        typer.secho(
            "Project configuration could not be loaded.\n"
            "pyproject.toml is invalid:\n"
            + "\n".join([f"- {line}" for line in errors]),
            fg=typer.colors.RED,
            bold=True,
        )
        # Was `sys.exit()`, which reports success (status 0) for an invalid
        # configuration; use the same non-zero exit as the other failures.
        raise typer.Exit(code=1)

    if warnings:
        typer.secho(
            "Project configuration is missing the following "
            "recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]),
            fg=typer.colors.RED,
            bold=True,
        )

    typer.secho("Success", fg=typer.colors.GREEN)

    federations = config["tool"]["flwr"]["federations"]
    federation = federation or federations.get("default")

    if federation is None:
        typer.secho(
            "❌ No federation name was provided and the project's `pyproject.toml` "
            "doesn't declare a default federation (with a SuperExec address or an "
            "`options.num-supernodes` value).",
            fg=typer.colors.RED,
            bold=True,
        )
        raise typer.Exit(code=1)

    # Validate the federation exists in the configuration
    federation_config = federations.get(federation)
    if federation_config is None:
        # Sorted so the listing is stable between invocations; `default` is
        # the pointer to a federation name, not a federation itself.
        available_feds = sorted(fed for fed in federations if fed != "default")
        typer.secho(
            f"❌ There is no `{federation}` federation declared in the "
            "`pyproject.toml`.\n The following federations were found:\n\n"
            + "\n".join(available_feds),
            fg=typer.colors.RED,
            bold=True,
        )
        raise typer.Exit(code=1)

    if "address" not in federation_config:
        typer.secho(
            # Trailing space added: the two literals are concatenated and
            # previously rendered as "correct`SuperExec`".
            "❌ `flwr log` currently works with `SuperExec`. Ensure that the correct "
            "`SuperExec` address is provided in the `pyproject.toml`.",
            fg=typer.colors.RED,
            bold=True,
        )
        raise typer.Exit(code=1)

    _log_with_superexec(federation_config, run_id, stream)
185
+
186
+
187
+ # pylint: disable-next=too-many-branches
188
def _log_with_superexec(
    federation_config: dict[str, str],
    run_id: int,
    stream: bool,
) -> None:
    """Open a channel to the federation's address and show the run's logs.

    Validates the TLS-related settings of `federation_config`, creates the
    gRPC channel and then either streams the logs (`stream=True`) or prints
    them once with a 5 second timeout (`stream=False`). Any inconsistent TLS
    configuration prints an error and exits with status 1.

    Parameters
    ----------
    federation_config : dict[str, str]
        Federation entry from `pyproject.toml`. Reads the keys `address`
        (required), `insecure` and `root-certificates` (both optional).
    run_id : int
        ID of the run whose logs are requested.
    stream : bool
        Whether to keep streaming (`start_stream`) or print once
        (`print_logs`).
    """
    # NOTE(review): `insecure` is coerced with `bool()`, so any non-empty
    # string (even "false") counts as true — presumably the value is a TOML
    # boolean; confirm against the configuration loader.
    insecure_str = federation_config.get("insecure")
    if root_certificates := federation_config.get("root-certificates"):
        root_certificates_bytes = Path(root_certificates).read_bytes()
        if insecure := bool(insecure_str):
            typer.secho(
                # Message fixed: names the key that is actually read and adds
                # the space that was missing between the two literals.
                "❌ `root-certificates` were provided but the `insecure` parameter "
                "is set to `True`.",
                fg=typer.colors.RED,
                bold=True,
            )
            raise typer.Exit(code=1)
    else:
        root_certificates_bytes = None
        if insecure_str is None:
            typer.secho(
                "❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
                fg=typer.colors.RED,
                bold=True,
            )
            raise typer.Exit(code=1)
        if not (insecure := bool(insecure_str)):
            typer.secho(
                "❌ No certificate were given yet `insecure` is set to `False`.",
                fg=typer.colors.RED,
                bold=True,
            )
            raise typer.Exit(code=1)

    channel = create_channel(
        server_address=federation_config["address"],
        insecure=insecure,
        root_certificates=root_certificates_bytes,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        interceptors=None,
    )
    channel.subscribe(on_channel_state_change)

    if stream:
        start_stream(run_id, channel, CONN_REFRESH_PERIOD)
    else:
        logger(INFO, "Printing logstream for run_id `%s`", run_id)
        print_logs(run_id, channel, timeout=5)
flwr/cli/new/new.py CHANGED
@@ -275,7 +275,7 @@ def new(
275
275
  )
276
276
  )
277
277
 
278
- _add = " huggingface-cli login\n" if framework_str == "flowertune" else ""
278
+ _add = " huggingface-cli login\n" if llm_challenge_str else ""
279
279
  print(
280
280
  typer.style(
281
281
  f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
@@ -55,7 +55,7 @@ We use Mistral-7B model with 4-bit quantization as default. The estimated VRAM c
55
55
  | :--------: | :--------: | :--------: | :--------: | :--------: |
56
56
  | VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |
57
57
 
58
- You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which are specified with `options.backend.clientapp-cpus` and `options.backend.clientapp-gpus` under `[tool.flwr.federations.local-simulation]` entry in `pyproject.toml`.
58
+ You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which are specified with `options.backend.client-resources.num-cpus` and `options.backend.client-resources.num-gpus` under `[tool.flwr.federations.local-simulation]` entry in `pyproject.toml`.
59
59
 
60
60
 
61
61
  ## Model saving
@@ -17,6 +17,7 @@ dependencies = [
17
17
  "transformers==4.39.3",
18
18
  "sentencepiece==0.2.0",
19
19
  "omegaconf==2.3.0",
20
+ "hf_transfer==0.1.8",
20
21
  ]
21
22
 
22
23
  [tool.hatch.build.targets.wheel]
flwr/cli/run/run.py CHANGED
@@ -34,6 +34,10 @@ from flwr.common.typing import Fab
34
34
  from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
35
35
  from flwr.proto.exec_pb2_grpc import ExecStub
36
36
 
37
+ from ..log import start_stream
38
+
39
+ CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
40
+
37
41
 
38
42
  def on_channel_state_change(channel_connectivity: str) -> None:
39
43
  """Log channel connectivity."""
@@ -62,6 +66,14 @@ def run(
62
66
  "inside the `pyproject.toml` in order to be properly overriden.",
63
67
  ),
64
68
  ] = None,
69
+ stream: Annotated[
70
+ bool,
71
+ typer.Option(
72
+ "--stream",
73
+ help="Use `--stream` with `flwr run` to display logs;\n "
74
+ "logs are not streamed by default.",
75
+ ),
76
+ ] = False,
65
77
  ) -> None:
66
78
  """Run Flower App."""
67
79
  typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
@@ -117,7 +129,7 @@ def run(
117
129
  raise typer.Exit(code=1)
118
130
 
119
131
  if "address" in federation_config:
120
- _run_with_superexec(app, federation_config, config_overrides)
132
+ _run_with_superexec(app, federation_config, config_overrides, stream)
121
133
  else:
122
134
  _run_without_superexec(app, federation_config, config_overrides, federation)
123
135
 
@@ -126,6 +138,7 @@ def _run_with_superexec(
126
138
  app: Path,
127
139
  federation_config: dict[str, Any],
128
140
  config_overrides: Optional[list[str]],
141
+ stream: bool,
129
142
  ) -> None:
130
143
 
131
144
  insecure_str = federation_config.get("insecure")
@@ -183,6 +196,9 @@ def _run_with_superexec(
183
196
  fab_path.unlink()
184
197
  typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
185
198
 
199
+ if stream:
200
+ start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)
201
+
186
202
 
187
203
  def _run_without_superexec(
188
204
  app: Optional[Path],
@@ -31,6 +31,7 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
31
31
  generate_shared_key,
32
32
  public_key_to_bytes,
33
33
  )
34
+ from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
34
35
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
35
36
  CreateNodeRequest,
36
37
  DeleteNodeRequest,
@@ -50,6 +51,7 @@ Request = Union[
50
51
  PushTaskResRequest,
51
52
  GetRunRequest,
52
53
  PingRequest,
54
+ GetFabRequest,
53
55
  ]
54
56
 
55
57
 
@@ -126,6 +128,7 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
126
128
  PushTaskResRequest,
127
129
  GetRunRequest,
128
130
  PingRequest,
131
+ GetFabRequest,
129
132
  ),
130
133
  ):
131
134
  if self.shared_secret is None:
@@ -269,7 +269,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
269
269
  task_res = message_to_taskres(message)
270
270
 
271
271
  # Serialize ProtoBuf to bytes
272
- request = PushTaskResRequest(task_res_list=[task_res])
272
+ request = PushTaskResRequest(node=node, task_res_list=[task_res])
273
273
  _ = retry_invoker.invoke(stub.PushTaskRes, request)
274
274
 
275
275
  # Cleanup
@@ -277,7 +277,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
277
277
 
278
278
  def get_run(run_id: int) -> Run:
279
279
  # Call FleetAPI
280
- get_run_request = GetRunRequest(run_id=run_id)
280
+ get_run_request = GetRunRequest(node=node, run_id=run_id)
281
281
  get_run_response: GetRunResponse = retry_invoker.invoke(
282
282
  stub.GetRun,
283
283
  request=get_run_request,
@@ -294,7 +294,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
294
294
 
295
295
  def get_fab(fab_hash: str) -> Fab:
296
296
  # Call FleetAPI
297
- get_fab_request = GetFabRequest(hash_str=fab_hash)
297
+ get_fab_request = GetFabRequest(node=node, hash_str=fab_hash)
298
298
  get_fab_response: GetFabResponse = retry_invoker.invoke(
299
299
  stub.GetFab,
300
300
  request=get_fab_request,
@@ -24,10 +24,14 @@ from google.protobuf.message import Message as GrpcMessage
24
24
 
25
25
  from flwr.common import log
26
26
  from flwr.common.constant import (
27
+ GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY,
28
+ GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY,
27
29
  GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY,
30
+ GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY,
31
+ GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY,
28
32
  GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY,
29
33
  )
30
- from flwr.common.version import package_version
34
+ from flwr.common.version import package_name, package_version
31
35
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
32
36
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
37
  CreateNodeRequest,
@@ -62,9 +66,16 @@ class GrpcAdapter:
62
66
  self, request: GrpcMessage, response_type: type[T], **kwargs: Any
63
67
  ) -> T:
64
68
  # Serialize request
69
+ req_cls = request.__class__
65
70
  container_req = MessageContainer(
66
- metadata={GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version},
67
- grpc_message_name=request.__class__.__qualname__,
71
+ metadata={
72
+ GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY: package_name,
73
+ GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY: package_version,
74
+ GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version,
75
+ GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY: req_cls.__module__,
76
+ GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY: req_cls.__qualname__,
77
+ },
78
+ grpc_message_name=req_cls.__qualname__,
68
79
  grpc_message_content=request.SerializeToString(),
69
80
  )
70
81
 
@@ -340,7 +340,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
340
340
  task_res = message_to_taskres(message)
341
341
 
342
342
  # Serialize ProtoBuf to bytes
343
- req = PushTaskResRequest(task_res_list=[task_res])
343
+ req = PushTaskResRequest(node=node, task_res_list=[task_res])
344
344
 
345
345
  # Send the request
346
346
  res = _request(req, PushTaskResResponse, PATH_PUSH_TASK_RES)
@@ -356,7 +356,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
356
356
 
357
357
  def get_run(run_id: int) -> Run:
358
358
  # Construct the request
359
- req = GetRunRequest(run_id=run_id)
359
+ req = GetRunRequest(node=node, run_id=run_id)
360
360
 
361
361
  # Send the request
362
362
  res = _request(req, GetRunResponse, PATH_GET_RUN)
@@ -373,7 +373,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
373
373
 
374
374
  def get_fab(fab_hash: str) -> Fab:
375
375
  # Construct the request
376
- req = GetFabRequest(hash_str=fab_hash)
376
+ req = GetFabRequest(node=node, hash_str=fab_hash)
377
377
 
378
378
  # Send the request
379
379
  res = _request(req, GetFabResponse, PATH_GET_FAB)
@@ -79,6 +79,7 @@ def run_supernode() -> None:
79
79
  node_config=parse_config_args(
80
80
  [args.node_config] if args.node_config else args.node_config
81
81
  ),
82
+ flwr_path=args.flwr_dir,
82
83
  isolation=args.isolation,
83
84
  supernode_address=args.supernode_address,
84
85
  )
flwr/common/constant.py CHANGED
@@ -60,8 +60,6 @@ PING_MAX_INTERVAL = 1e300
60
60
  # IDs
61
61
  RUN_ID_NUM_BYTES = 8
62
62
  NODE_ID_NUM_BYTES = 8
63
- GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
64
- GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
65
63
 
66
64
  # Constants for FAB
67
65
  APP_DIR = "apps"
@@ -72,8 +70,13 @@ FLWR_HOME = "FLWR_HOME"
72
70
  PARTITION_ID_KEY = "partition-id"
73
71
  NUM_PARTITIONS_KEY = "num-partitions"
74
72
 
75
- GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
73
+ # Constants for keys in `metadata` of `MessageContainer` in `grpc-adapter`
74
+ GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY = "flower-package-name"
75
+ GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY = "flower-package-version"
76
+ GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" # Deprecated
76
77
  GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
78
+ GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY = "grpc-message-module"
79
+ GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY = "grpc-message-qualname"
77
80
 
78
81
 
79
82
  class MessageType:
@@ -43,8 +43,8 @@ def share_keys_plaintext_concat(
43
43
  """
44
44
  return b"".join(
45
45
  [
46
- int.to_bytes(src_node_id, 8, "little", signed=True),
47
- int.to_bytes(dst_node_id, 8, "little", signed=True),
46
+ int.to_bytes(src_node_id, 8, "little", signed=False),
47
+ int.to_bytes(dst_node_id, 8, "little", signed=False),
48
48
  int.to_bytes(len(b_share), 4, "little"),
49
49
  b_share,
50
50
  sk_share,
@@ -72,8 +72,8 @@ def share_keys_plaintext_separate(plaintext: bytes) -> tuple[int, int, bytes, by
72
72
  the secret key share of the source sent to the destination.
73
73
  """
74
74
  src, dst, mark = (
75
- int.from_bytes(plaintext[:8], "little", signed=True),
76
- int.from_bytes(plaintext[8:16], "little", signed=True),
75
+ int.from_bytes(plaintext[:8], "little", signed=False),
76
+ int.from_bytes(plaintext[8:16], "little", signed=False),
77
77
  int.from_bytes(plaintext[16:20], "little"),
78
78
  )
79
79
  ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :])
flwr/common/serde.py CHANGED
@@ -38,7 +38,7 @@ from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
38
38
  from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
39
39
  from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
40
40
  from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
41
- from flwr.proto.recordset_pb2 import Sint64List, StringList
41
+ from flwr.proto.recordset_pb2 import SintList, StringList, UintList
42
42
  from flwr.proto.run_pb2 import Run as ProtoRun
43
43
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
44
44
  from flwr.proto.transport_pb2 import (
@@ -340,6 +340,7 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
340
340
 
341
341
 
342
342
  # === Scalar messages ===
343
+ INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
343
344
 
344
345
 
345
346
  def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
@@ -354,6 +355,9 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
354
355
  return Scalar(double=scalar)
355
356
 
356
357
  if isinstance(scalar, int):
358
+ # Use uint64 for integers larger than the maximum value of sint64
359
+ if scalar > INT64_MAX_VALUE:
360
+ return Scalar(uint64=scalar)
357
361
  return Scalar(sint64=scalar)
358
362
 
359
363
  if isinstance(scalar, str):
@@ -374,16 +378,16 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
374
378
  # === Record messages ===
375
379
 
376
380
 
377
- _type_to_field = {
381
+ _type_to_field: dict[type, str] = {
378
382
  float: "double",
379
383
  int: "sint64",
380
384
  bool: "bool",
381
385
  str: "string",
382
386
  bytes: "bytes",
383
387
  }
384
- _list_type_to_class_and_field = {
388
+ _list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
385
389
  float: (DoubleList, "double_list"),
386
- int: (Sint64List, "sint64_list"),
390
+ int: (SintList, "sint_list"),
387
391
  bool: (BoolList, "bool_list"),
388
392
  str: (StringList, "string_list"),
389
393
  bytes: (BytesList, "bytes_list"),
@@ -391,6 +395,11 @@ _list_type_to_class_and_field = {
391
395
  T = TypeVar("T")
392
396
 
393
397
 
398
+ def _is_uint64(value: Any) -> bool:
399
+ """Check if a value is uint64."""
400
+ return isinstance(value, int) and value > INT64_MAX_VALUE
401
+
402
+
394
403
  def _record_value_to_proto(
395
404
  value: Any, allowed_types: list[type], proto_class: type[T]
396
405
  ) -> T:
@@ -403,12 +412,18 @@ def _record_value_to_proto(
403
412
  # Single element
404
413
  # Note: `isinstance(False, int) == True`.
405
414
  if isinstance(value, t):
406
- arg[_type_to_field[t]] = value
415
+ fld = _type_to_field[t]
416
+ if t is int and _is_uint64(value):
417
+ fld = "uint64"
418
+ arg[fld] = value
407
419
  return proto_class(**arg)
408
420
  # List
409
421
  if isinstance(value, list) and all(isinstance(item, t) for item in value):
410
- list_class, field_name = _list_type_to_class_and_field[t]
411
- arg[field_name] = list_class(vals=value)
422
+ list_class, fld = _list_type_to_class_and_field[t]
423
+ # Use UintList if any element is of type `uint64`.
424
+ if t is int and any(_is_uint64(v) for v in value):
425
+ list_class, fld = UintList, "uint_list"
426
+ arg[fld] = list_class(vals=value)
412
427
  return proto_class(**arg)
413
428
  # Invalid types
414
429
  raise TypeError(
@@ -0,0 +1,27 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: flwr/proto/control.proto
4
+ # Protobuf Python Version: 4.25.0
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
16
+
17
+
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2\x88\x02\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x62\x06proto3')
19
+
20
+ _globals = globals()
21
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.control_pb2', _globals)
23
+ if _descriptor._USE_C_DESCRIPTORS == False:
24
+ DESCRIPTOR._options = None
25
+ _globals['_CONTROL']._serialized_start=63
26
+ _globals['_CONTROL']._serialized_end=327
27
+ # @@protoc_insertion_point(module_scope)
@@ -0,0 +1,7 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
5
+ import google.protobuf.descriptor
6
+
7
+ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor