flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
flwr/cli/flower_toml.py CHANGED
@@ -14,12 +14,13 @@
14
14
  # ==============================================================================
15
15
  """Utility to validate the `flower.toml` file."""
16
16
 
17
- import importlib
18
17
  import os
19
18
  from typing import Any, Dict, List, Optional, Tuple
20
19
 
21
20
  import tomli
22
21
 
22
+ from flwr.common.object_ref import validate
23
+
23
24
 
24
25
  def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
25
26
  """Load flower.toml and return as dict."""
@@ -71,47 +72,6 @@ def validate_flower_toml_fields(
71
72
  return len(errors) == 0, errors, warnings
72
73
 
73
74
 
74
- def validate_object_reference(ref: str) -> Tuple[bool, Optional[str]]:
75
- """Validate object reference.
76
-
77
- Returns
78
- -------
79
- Tuple[bool, Optional[str]]
80
- A boolean indicating whether an object reference is valid and
81
- the reason why it might not be.
82
- """
83
- module_str, _, attributes_str = ref.partition(":")
84
- if not module_str:
85
- return (
86
- False,
87
- f"Missing module in {ref}",
88
- )
89
- if not attributes_str:
90
- return (
91
- False,
92
- f"Missing attribute in {ref}",
93
- )
94
-
95
- # Load module
96
- try:
97
- module = importlib.import_module(module_str)
98
- except ModuleNotFoundError:
99
- return False, f"Unable to load module {module_str}"
100
-
101
- # Recursively load attribute
102
- attribute = module
103
- try:
104
- for attribute_str in attributes_str.split("."):
105
- attribute = getattr(attribute, attribute_str)
106
- except AttributeError:
107
- return (
108
- False,
109
- f"Unable to load attribute {attributes_str} from module {module_str}",
110
- )
111
-
112
- return (True, None)
113
-
114
-
115
75
  def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
116
76
  """Validate flower.toml."""
117
77
  is_valid, errors, warnings = validate_flower_toml_fields(config)
@@ -120,16 +80,12 @@ def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str], List[
120
80
  return False, errors, warnings
121
81
 
122
82
  # Validate serverapp
123
- is_valid, reason = validate_object_reference(
124
- config["flower"]["components"]["serverapp"]
125
- )
83
+ is_valid, reason = validate(config["flower"]["components"]["serverapp"])
126
84
  if not is_valid and isinstance(reason, str):
127
85
  return False, [reason], []
128
86
 
129
87
  # Validate clientapp
130
- is_valid, reason = validate_object_reference(
131
- config["flower"]["components"]["clientapp"]
132
- )
88
+ is_valid, reason = validate(config["flower"]["components"]["clientapp"])
133
89
 
134
90
  if not is_valid and isinstance(reason, str):
135
91
  return False, [reason], []
flwr/cli/new/new.py CHANGED
@@ -22,7 +22,7 @@ from typing import Dict, Optional
22
22
  import typer
23
23
  from typing_extensions import Annotated
24
24
 
25
- from ..utils import prompt_options
25
+ from ..utils import prompt_options, prompt_text
26
26
 
27
27
 
28
28
  class MlFramework(str, Enum):
@@ -72,9 +72,9 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
72
72
 
73
73
  def new(
74
74
  project_name: Annotated[
75
- str,
75
+ Optional[str],
76
76
  typer.Argument(metavar="project_name", help="The name of the project"),
77
- ],
77
+ ] = None,
78
78
  framework: Annotated[
79
79
  Optional[MlFramework],
80
80
  typer.Option(case_sensitive=False, help="The ML framework to use"),
@@ -83,6 +83,9 @@ def new(
83
83
  """Create new Flower project."""
84
84
  print(f"Creating Flower project {project_name}...")
85
85
 
86
+ if project_name is None:
87
+ project_name = prompt_text("Please provide project name")
88
+
86
89
  if framework is not None:
87
90
  framework_str = str(framework.value)
88
91
  else:
@@ -1,6 +1,5 @@
1
1
  """$project_name: A Flower / PyTorch app."""
2
2
 
3
- import warnings
4
3
  from collections import OrderedDict
5
4
 
6
5
  import torch
@@ -9,7 +8,6 @@ import torch.nn.functional as F
9
8
  from torch.utils.data import DataLoader
10
9
  from torchvision.datasets import CIFAR10
11
10
  from torchvision.transforms import Compose, Normalize, ToTensor
12
- from tqdm import tqdm
13
11
 
14
12
 
15
13
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -77,7 +75,7 @@ def test(net, testloader):
77
75
  criterion = torch.nn.CrossEntropyLoss()
78
76
  correct, loss = 0, 0.0
79
77
  with torch.no_grad():
80
- for images, labels in tqdm(testloader):
78
+ for images, labels in testloader:
81
79
  outputs = net(images.to(DEVICE))
82
80
  labels = labels.to(DEVICE)
83
81
  loss += criterion(outputs, labels).item()
@@ -15,7 +15,7 @@ readme = "README.md"
15
15
  [tool.poetry.dependencies]
16
16
  python = "^3.9"
17
17
  # Mandatory dependencies
18
- flwr-nightly = { version = "1.8.0.dev20240308", extras = ["simulation"] }
18
+ flwr-nightly = { version = "1.8.0.dev20240309", extras = ["simulation"] }
19
19
  flwr-datasets = { version = "^0.0.2", extras = ["vision"] }
20
20
  torch = "2.2.1"
21
21
  torchvision = "0.17.1"
@@ -1,4 +1,4 @@
1
- flwr>=1.8, <2.0
2
- flwr-datasets[vision]>=0.0.2, <1.0.0
1
+ flwr-nightly[simulation]==1.8.0.dev20240309
2
+ flwr-datasets[vision]==0.0.2
3
3
  torch==2.2.1
4
4
  torchvision==0.17.1
flwr/cli/utils.py CHANGED
@@ -14,11 +14,24 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface utils."""
16
16
 
17
- from typing import List
17
+ from typing import List, cast
18
18
 
19
19
  import typer
20
20
 
21
21
 
22
+ def prompt_text(text: str) -> str:
23
+ """Ask user to enter text input."""
24
+ while True:
25
+ result = typer.prompt(
26
+ typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True)
27
+ )
28
+ if len(result) > 0:
29
+ break
30
+ print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True))
31
+
32
+ return cast(str, result)
33
+
34
+
22
35
  def prompt_options(text: str, options: List[str]) -> str:
23
36
  """Ask user to select one of the given options and return the selected item."""
24
37
  # Turn options into a list with index as in " [ 0] quickstart-pytorch"
flwr/client/app.py CHANGED
@@ -25,7 +25,7 @@ from typing import Callable, ContextManager, Optional, Tuple, Type, Union
25
25
  from grpc import RpcError
26
26
 
27
27
  from flwr.client.client import Client
28
- from flwr.client.client_app import ClientApp
28
+ from flwr.client.client_app import ClientApp, LoadClientAppError
29
29
  from flwr.client.typing import ClientFn
30
30
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
31
31
  from flwr.common.address import parse_address
@@ -38,9 +38,9 @@ from flwr.common.constant import (
38
38
  )
39
39
  from flwr.common.exit_handlers import register_exit_handlers
40
40
  from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
41
+ from flwr.common.object_ref import load_app, validate
41
42
  from flwr.common.retry_invoker import RetryInvoker, exponential
42
43
 
43
- from .client_app import load_client_app
44
44
  from .grpc_client.connection import grpc_connection
45
45
  from .grpc_rere_client.connection import grpc_request_response
46
46
  from .message_handler.message_handler import handle_control_message
@@ -97,8 +97,19 @@ def run_client_app() -> None:
97
97
  if client_app_dir is not None:
98
98
  sys.path.insert(0, client_app_dir)
99
99
 
100
+ app_ref: str = getattr(args, "client-app")
101
+ valid, error_msg = validate(app_ref)
102
+ if not valid and error_msg:
103
+ raise LoadClientAppError(error_msg) from None
104
+
100
105
  def _load() -> ClientApp:
101
- client_app: ClientApp = load_client_app(getattr(args, "client-app"))
106
+ client_app = load_app(app_ref, LoadClientAppError)
107
+
108
+ if not isinstance(client_app, ClientApp):
109
+ raise LoadClientAppError(
110
+ f"Attribute {app_ref} is not of type {ClientApp}",
111
+ ) from None
112
+
102
113
  return client_app
103
114
 
104
115
  _start_client_internal(
@@ -445,7 +456,19 @@ def _start_client_internal(
445
456
  time.sleep(3) # Wait for 3s before asking again
446
457
  continue
447
458
 
448
- log(INFO, "Received message")
459
+ log(INFO, "")
460
+ log(
461
+ INFO,
462
+ "[RUN %s, ROUND %s]",
463
+ message.metadata.run_id,
464
+ message.metadata.group_id,
465
+ )
466
+ log(
467
+ INFO,
468
+ "Received: %s message %s",
469
+ message.metadata.message_type,
470
+ message.metadata.message_id,
471
+ )
449
472
 
450
473
  # Handle control message
451
474
  out_message, sleep_duration = handle_control_message(message)
@@ -473,7 +496,18 @@ def _start_client_internal(
473
496
 
474
497
  # Send
475
498
  send(out_message)
476
- log(INFO, "Sent reply")
499
+ log(
500
+ INFO,
501
+ "[RUN %s, ROUND %s]",
502
+ out_message.metadata.run_id,
503
+ out_message.metadata.group_id,
504
+ )
505
+ log(
506
+ INFO,
507
+ "Sent: %s reply to message %s",
508
+ out_message.metadata.message_type,
509
+ message.metadata.message_id,
510
+ )
477
511
 
478
512
  # Unregister node
479
513
  if delete_node is not None:
flwr/client/client_app.py CHANGED
@@ -15,8 +15,7 @@
15
15
  """Flower ClientApp."""
16
16
 
17
17
 
18
- import importlib
19
- from typing import Callable, List, Optional, cast
18
+ from typing import Callable, List, Optional
20
19
 
21
20
  from flwr.client.message_handler.message_handler import (
22
21
  handle_legacy_message_from_msgtype,
@@ -194,51 +193,6 @@ class LoadClientAppError(Exception):
194
193
  """Error when trying to load `ClientApp`."""
195
194
 
196
195
 
197
- def load_client_app(module_attribute_str: str) -> ClientApp:
198
- """Load the `ClientApp` object specified in a module attribute string.
199
-
200
- The module/attribute string should have the form <module>:<attribute>. Valid
201
- examples include `client:app` and `project.package.module:wrapper.app`. It
202
- must refer to a module on the PYTHONPATH, the module needs to have the specified
203
- attribute, and the attribute must be of type `ClientApp`.
204
- """
205
- module_str, _, attributes_str = module_attribute_str.partition(":")
206
- if not module_str:
207
- raise LoadClientAppError(
208
- f"Missing module in {module_attribute_str}",
209
- ) from None
210
- if not attributes_str:
211
- raise LoadClientAppError(
212
- f"Missing attribute in {module_attribute_str}",
213
- ) from None
214
-
215
- # Load module
216
- try:
217
- module = importlib.import_module(module_str)
218
- except ModuleNotFoundError:
219
- raise LoadClientAppError(
220
- f"Unable to load module {module_str}",
221
- ) from None
222
-
223
- # Recursively load attribute
224
- attribute = module
225
- try:
226
- for attribute_str in attributes_str.split("."):
227
- attribute = getattr(attribute, attribute_str)
228
- except AttributeError:
229
- raise LoadClientAppError(
230
- f"Unable to load attribute {attributes_str} from module {module_str}",
231
- ) from None
232
-
233
- # Check type
234
- if not isinstance(attribute, ClientApp):
235
- raise LoadClientAppError(
236
- f"Attribute {attributes_str} is not of type {ClientApp}",
237
- ) from None
238
-
239
- return cast(ClientApp, attribute)
240
-
241
-
242
196
  def _registration_error(fn_name: str) -> ValueError:
243
197
  return ValueError(
244
198
  f"""Use either `@app.{fn_name}()` or `client_fn`, but not both.
@@ -17,7 +17,7 @@
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
19
19
  from .localdp_mod import LocalDpMod
20
- from .secure_aggregation.secaggplus_mod import secaggplus_mod
20
+ from .secure_aggregation import secagg_mod, secaggplus_mod
21
21
  from .utils import make_ffn
22
22
 
23
23
  __all__ = [
@@ -25,5 +25,6 @@ __all__ = [
25
25
  "fixedclipping_mod",
26
26
  "LocalDpMod",
27
27
  "make_ffn",
28
+ "secagg_mod",
28
29
  "secaggplus_mod",
29
30
  ]
@@ -15,8 +15,10 @@
15
15
  """Secure Aggregation mods."""
16
16
 
17
17
 
18
+ from .secagg_mod import secagg_mod
18
19
  from .secaggplus_mod import secaggplus_mod
19
20
 
20
21
  __all__ = [
22
+ "secagg_mod",
21
23
  "secaggplus_mod",
22
24
  ]
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Modifier for the SecAgg protocol."""
16
+
17
+
18
+ from flwr.client.typing import ClientAppCallable
19
+ from flwr.common import Context, Message
20
+
21
+ from .secaggplus_mod import secaggplus_mod
22
+
23
+
24
+ def secagg_mod(
25
+ msg: Message,
26
+ ctxt: Context,
27
+ call_next: ClientAppCallable,
28
+ ) -> Message:
29
+ """Handle incoming message and return results, following the SecAgg protocol."""
30
+ return secaggplus_mod(msg, ctxt, call_next)