syft-flwr 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syft_flwr/__init__.py +15 -0
- syft_flwr/bootstrap.py +102 -0
- syft_flwr/cli.py +116 -0
- syft_flwr/config.py +36 -0
- syft_flwr/consts.py +2 -0
- syft_flwr/flower_client.py +199 -0
- syft_flwr/flower_server.py +50 -0
- syft_flwr/grid.py +580 -0
- syft_flwr/mounts.py +62 -0
- syft_flwr/run.py +63 -0
- syft_flwr/run_simulation.py +328 -0
- syft_flwr/serde.py +15 -0
- syft_flwr/strategy/__init__.py +3 -0
- syft_flwr/strategy/fedavg.py +38 -0
- syft_flwr/templates/main.py.tpl +31 -0
- syft_flwr/utils.py +126 -0
- syft_flwr-0.4.3.dist-info/METADATA +31 -0
- syft_flwr-0.4.3.dist-info/RECORD +21 -0
- syft_flwr-0.4.3.dist-info/WHEEL +4 -0
- syft_flwr-0.4.3.dist-info/entry_points.txt +2 -0
- syft_flwr-0.4.3.dist-info/licenses/LICENSE +201 -0
syft_flwr/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
__version__ = "0.4.3"
|
|
2
|
+
|
|
3
|
+
from syft_flwr.bootstrap import bootstrap
|
|
4
|
+
from syft_flwr.run import run
|
|
5
|
+
|
|
6
|
+
__all__ = ["bootstrap", "run"]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Register the mount provider for syft_rds when syft_flwr is initialized
|
|
10
|
+
from syft_rds.syft_runtime.mounts import register_mount_provider
|
|
11
|
+
|
|
12
|
+
from .mounts import SyftFlwrMountProvider
|
|
13
|
+
|
|
14
|
+
# Import-time side effect: register a SyftFlwrMountProvider instance with
# syft_rds under the provider name "syft_flwr".
register_mount_provider("syft_flwr", SyftFlwrMountProvider())
|
syft_flwr/bootstrap.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from loguru import logger
|
|
5
|
+
from typing_extensions import List, Union
|
|
6
|
+
|
|
7
|
+
from syft_flwr import __version__
|
|
8
|
+
from syft_flwr.config import load_flwr_pyproject, write_toml
|
|
9
|
+
from syft_flwr.utils import is_valid_datasite
|
|
10
|
+
|
|
11
|
+
__all__ = ["bootstrap"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Template for the `main.py` entry point written into bootstrapped projects.
MAIN_TEMPLATE_PATH = Path(__file__).parent / "templates" / "main.py.tpl"
# Read as UTF-8 explicitly so the result does not depend on the platform locale.
MAIN_TEMPLATE_CONTENT = MAIN_TEMPLATE_PATH.read_text(encoding="utf-8")
# Fail loudly at import time if the packaged template is empty. An explicit
# check is used instead of `assert`, which is stripped under `python -O`.
if not MAIN_TEMPLATE_CONTENT:
    raise RuntimeError(f"Template file '{MAIN_TEMPLATE_PATH}' is empty")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def __copy_main_py(flwr_project_dir: Path) -> None:
    """Write the bundled `main.py` template into the syft-flwr project.

    Args:
        flwr_project_dir: Root directory of the Flower project (a `str` path
            is accepted as well).

    Raises:
        FileExistsError: If `main.py` already exists, so existing user code is
            never overwritten.
    """
    main_py_path = Path(flwr_project_dir) / "main.py"

    if main_py_path.exists():
        # FileExistsError (a subclass of Exception) matches the error raised
        # by the project-directory validation for the same condition.
        raise FileExistsError(f"The file '{main_py_path}' already exists")

    main_py_path.write_text(MAIN_TEMPLATE_CONTENT, encoding="utf-8")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _is_syft_flwr_requirement(requirement: str) -> bool:
    """Return True if a PEP 508 requirement string names the syft-flwr package.

    Names are compared after PEP 503 normalization, so `syft_flwr`,
    `syft-flwr` and `Syft.Flwr` all match. Unrelated packages that merely
    share the prefix (e.g. `syft_flwr_extra`) do not match.
    """
    import re

    match = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", requirement)
    if match is None:
        return False
    return re.sub(r"[-_.]+", "-", match.group(1)).lower() == "syft-flwr"


def __update_pyproject_toml(
    flwr_project_dir: Union[str, Path],
    aggregator: str,
    datasites: List[str],
) -> None:
    """Update the `pyproject.toml` file of the Flower project for syft-flwr.

    Changes made in place:

    * `partition-id` / `num-partitions` in `[tool.flwr.app.config]` are
      pinned to 0 / 1 (temporary workaround, see TODO below);
    * any existing syft-flwr requirement is replaced by `syft_flwr==<version>`;
    * `[tool.syft_flwr]` is reset with a unique `app_name`, the `datasites`
      and the `aggregator`.

    Args:
        flwr_project_dir: Root directory of the Flower project.
        aggregator: Datasite email of the Flower server.
        datasites: Datasite emails of the Flower clients.
    """
    flwr_pyproject = Path(flwr_project_dir, "pyproject.toml")
    pyproject_conf = load_flwr_pyproject(flwr_pyproject, check_module=False)

    # TODO: remove this after we find out how to pass the right context to the clients
    # `setdefault` avoids a KeyError when the project has no
    # [tool.flwr.app.config] table.
    app_config = pyproject_conf["tool"]["flwr"]["app"].setdefault("config", {})
    app_config["partition-id"] = 0
    app_config["num-partitions"] = 1
    # TODO end

    # Add syft_flwr as a pinned dependency. Any previous requirement on it is
    # dropped, however its name is spelled (syft_flwr, syft-flwr, ...).
    deps = [
        dep
        for dep in pyproject_conf["project"].get("dependencies", [])
        if not _is_syft_flwr_requirement(dep)
    ]
    deps.append(f"syft_flwr=={__version__}")
    pyproject_conf["project"]["dependencies"] = deps

    # Use a unique app name for each syft_flwr run, and always override the
    # datasites and the aggregator.
    base_app_name = pyproject_conf["project"]["name"]
    pyproject_conf["tool"]["syft_flwr"] = {
        "app_name": f"{aggregator}_{base_app_name}_{int(time.time())}",
        "datasites": datasites,
        "aggregator": aggregator,
    }

    write_toml(flwr_pyproject, pyproject_conf)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def __validate_flwr_project_dir(flwr_project_dir: Union[str, Path]) -> Path:
    """Check that `flwr_project_dir` is a Flower project that can be bootstrapped.

    Args:
        flwr_project_dir: Path (or string path) to the Flower project root.

    Returns:
        The project directory as a `Path`.

    Raises:
        FileExistsError: If the project already contains a `main.py`.
        FileNotFoundError: If the directory or its `pyproject.toml` is missing.
    """
    # Accept plain strings as the signature advertises (`str / str` would fail).
    flwr_project_dir = Path(flwr_project_dir)
    flwr_pyproject = flwr_project_dir / "pyproject.toml"
    flwr_main_py = flwr_project_dir / "main.py"

    if flwr_main_py.exists():
        raise FileExistsError(f"File '{flwr_main_py}' already exists")

    if not flwr_project_dir.exists():
        raise FileNotFoundError(f"Directory '{flwr_project_dir}' not found")

    if not flwr_pyproject.exists():
        raise FileNotFoundError(f"File '{flwr_pyproject}' not found")

    return flwr_project_dir
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def bootstrap(
    flwr_project_dir: Union[str, Path], aggregator: str, datasites: List[str]
) -> None:
    """Convert the Flower project at `flwr_project_dir` into a syft-flwr project.

    The aggregator and every client datasite are validated first. Then the
    project directory is checked, its `pyproject.toml` is rewritten and a
    `main.py` entry point is written.

    Raises:
        ValueError: If the aggregator or any datasite is not a valid datasite.
        FileExistsError: If the project already has a `main.py`.
        FileNotFoundError: If the project directory or its `pyproject.toml`
            is missing.
    """
    project_dir = Path(flwr_project_dir)

    if not is_valid_datasite(aggregator):
        raise ValueError(f"'{aggregator}' is not a valid datasite")

    # Stop at the first invalid client datasite.
    for datasite in datasites:
        if not is_valid_datasite(datasite):
            raise ValueError(f"{datasite} is not a valid datasite")

    __validate_flwr_project_dir(project_dir)
    __update_pyproject_toml(project_dir, aggregator, datasites)
    __copy_main_py(project_dir)

    logger.info(
        f"Successfully bootstrapped syft-flwr project at {project_dir} with datasites {datasites} and aggregator '{aggregator}' ✅"
    )
|
syft_flwr/cli.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import typer
|
|
4
|
+
from rich import print as rprint
|
|
5
|
+
from typing_extensions import Annotated, List, Tuple
|
|
6
|
+
|
|
7
|
+
# Root Typer application for the `syft_flwr` command line.
app = typer.Typer(
    name="syft_flwr",
    # Accept `-h` in addition to `--help`.
    context_settings={"help_option_names": ["-h", "--help"]},
    # Invoking the CLI without arguments prints the help text.
    no_args_is_help=True,
    # Disable Typer's pretty (rich-formatted) exception output.
    pretty_exceptions_enable=False,
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@app.command()
def version() -> None:
    """Print syft_flwr version"""
    # Imported lazily so the CLI module can be loaded without importing the
    # full package up front.
    from syft_flwr import __version__ as installed_version

    print(installed_version)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Reusable parameter declarations shared by the CLI commands below.

# Positional argument: the Flower project directory.
PROJECT_DIR_OPTS = typer.Argument(help="Path to a Flower project")
# `-a/--aggregator` and `-s/--server` are aliases for the same option: the
# datasite that runs the Flower server.
AGGREGATOR_OPTS = typer.Option(
    "-a",
    "--aggregator",
    "-s",
    "--server",
    help="Datasite email of the Flower Server",
)
# Client datasites; declared with a `List[str]` annotation in `bootstrap`.
DATASITES_OPTS = typer.Option(
    "-d",
    "--datasites",
    help="Datasites addresses",
)
# Mock dataset paths; declared with a `List[str]` annotation in `run`.
MOCK_DATASET_PATHS_OPTS = typer.Option(
    "-m",
    "--mock-dataset-paths",
    help="Mock dataset paths",
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def prompt_for_missing_args(
    aggregator: str, datasites: List[str]
) -> Tuple[str, List[str]]:
    """Interactively ask for any bootstrap argument not given on the command line.

    Datasites may be given as repeated options, as comma-separated values, or
    as a mix of both. Every entry is split on commas and stripped of
    surrounding whitespace, and empty entries are dropped.

    Args:
        aggregator: Datasite email of the aggregator, or a falsy value to prompt.
        datasites: Client datasite emails, or a falsy value to prompt.

    Returns:
        A ``(aggregator, datasites)`` tuple.

    Raises:
        typer.BadParameter: If no datasite email remains after normalization.
    """
    if not aggregator:
        aggregator = typer.prompt(
            "Enter the datasite email of the Aggregator (Flower Server)"
        )
    if not datasites:
        datasites = [
            typer.prompt(
                "Enter a comma-separated email of datasites of the Flower Clients"
            )
        ]

    # Normalize e.g. ["a@x.org, b@y.org"] -> ["a@x.org", "b@y.org"] so stray
    # spaces do not make otherwise valid emails fail datasite validation.
    datasites = [
        email.strip()
        for entry in datasites
        for email in entry.split(",")
        if email.strip()
    ]
    if not datasites:
        # Without this, input such as "," would bootstrap with zero clients.
        raise typer.BadParameter(
            "At least one datasite email is required", param_hint="--datasites"
        )

    return aggregator.strip(), datasites
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def prompt_for_missing_mock_paths(mock_dataset_paths: List[str]) -> List[str]:
    """Return the mock dataset paths, prompting for them if none were given.

    Prompted input is split on commas, and each path is stripped of
    surrounding whitespace, so "a, b" yields ["a", "b"]. Paths passed on the
    command line are returned unchanged.
    """
    if mock_dataset_paths:
        return mock_dataset_paths

    answer = typer.prompt("Enter comma-separated paths to mock datasets")
    # Without stripping, the space after a comma would become part of the path.
    return [path.strip() for path in answer.split(",")]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@app.command()
def bootstrap(
    project_dir: Annotated[Path, PROJECT_DIR_OPTS],
    aggregator: Annotated[str, AGGREGATOR_OPTS] = None,
    datasites: Annotated[List[str], DATASITES_OPTS] = None,
) -> None:
    """Bootstrap a new syft_flwr project from a flwr project"""
    # Aliased so the library function does not shadow this command's name.
    from syft_flwr.bootstrap import bootstrap as bootstrap_project

    aggregator, datasites = prompt_for_missing_args(aggregator, datasites)

    try:
        # Expand "~" and resolve the path the same way the `run` command does.
        project_dir = Path(project_dir).expanduser().resolve()
        rprint(f"[cyan]Bootstrapping project at '{project_dir}'[/cyan]")
        rprint(f"[cyan]Aggregator: {aggregator}[/cyan]")
        rprint(f"[cyan]Datasites: {datasites}[/cyan]")
        bootstrap_project(project_dir, aggregator, datasites)
        rprint(f"[green]Bootstrapped project at '{project_dir}'[/green]")
    except Exception as e:
        # Report the error briefly and exit with status 1 instead of
        # printing a traceback.
        rprint(f"[red]Error[/red]: {e}")
        raise typer.Exit(1)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@app.command()
def run(
    project_dir: Annotated[Path, PROJECT_DIR_OPTS],
    mock_dataset_paths: Annotated[List[str], MOCK_DATASET_PATHS_OPTS] = None,
) -> None:
    """Run a syft_flwr project in simulation mode over mock data"""
    # Aliased so the simulation runner does not shadow this command's name.
    from syft_flwr.run_simulation import run as run_simulation

    try:
        dataset_paths = prompt_for_missing_mock_paths(mock_dataset_paths)
        resolved_dir = Path(project_dir).expanduser().resolve()
        rprint(f"[cyan]Running syft_flwr project at '{resolved_dir}'[/cyan]")
        rprint(f"[cyan]Mock dataset paths: {dataset_paths}[/cyan]")
        run_simulation(resolved_dir, dataset_paths)
    except Exception as err:
        # Report the error briefly and exit with status 1.
        rprint(f"[red]Error[/red]: {err}")
        raise typer.Exit(1)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def main() -> None:
    """Run the `syft_flwr` Typer command-line application."""
    app()


if __name__ == "__main__":
    main()
|
syft_flwr/config.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import tomli
|
|
6
|
+
import tomli_w
|
|
7
|
+
from flwr.common.config import validate_config
|
|
8
|
+
from loguru import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_toml(path: str):
    """Parse the TOML file at `path` and return the resulting document."""
    # tomli requires the file to be opened in binary mode.
    with open(path, mode="rb") as toml_file:
        document = tomli.load(toml_file)
    return document
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def write_toml(path: str, val: dict):
    """Serialize `val` as TOML and write it to `path`, replacing any content."""
    # tomli_w writes bytes, so the file is opened in binary mode.
    with open(path, mode="wb") as toml_file:
        tomli_w.dump(val, toml_file)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def load_flwr_pyproject(path: Path, check_module: bool = True) -> dict:
    """Load the flower's pyproject.toml file and validate it.

    Args:
        path: Either the ``pyproject.toml`` file itself or the project
            directory that contains it. A plain string path is accepted too.
        check_module: Forwarded to Flower's ``validate_config``. It presumably
            controls whether the app's modules are also checked; see Flower's
            documentation.

    Returns:
        The parsed ``pyproject.toml`` document.

    Raises:
        ValueError: If validation fails. The message lists every error
            reported by ``validate_config``.
    """
    # Accept str paths as well; `.name` and `/` need a Path.
    path = Path(path)
    if path.name != "pyproject.toml":
        path = path / "pyproject.toml"

    pyproject = load_toml(path)
    is_valid, errors, warnings = validate_config(pyproject, check_module, path.parent)

    if not is_valid:
        # ValueError is still an Exception, so existing broad handlers keep
        # working; the message is now readable instead of a raw list repr.
        raise ValueError(
            f"Invalid Flower project config '{path}': "
            + "; ".join(str(error) for error in errors)
        )

    for warning in warnings:
        logger.warning(warning)

    return pyproject
|
syft_flwr/flower_client.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import sys
|
|
3
|
+
import traceback
|
|
4
|
+
|
|
5
|
+
from flwr.client import ClientApp
|
|
6
|
+
from flwr.common import Context
|
|
7
|
+
from flwr.common.constant import ErrorCode, MessageType
|
|
8
|
+
from flwr.common.message import Error, Message
|
|
9
|
+
from flwr.common.record import RecordDict
|
|
10
|
+
from loguru import logger
|
|
11
|
+
from syft_event import SyftEvents
|
|
12
|
+
from syft_event.types import Request
|
|
13
|
+
from typing_extensions import Optional, Union
|
|
14
|
+
|
|
15
|
+
from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
|
|
16
|
+
from syft_flwr.utils import create_flwr_message, setup_client
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MessageHandler:
    """Run the wrapped Flower ``ClientApp`` and serialize its replies.

    When encryption is enabled, serialized replies are additionally
    base64-encoded into a ``str``. Otherwise the raw bytes are returned.
    """

    def __init__(
        self, client_app: ClientApp, context: Context, encryption_enabled: bool
    ):
        self.client_app = client_app
        self.context = context
        self.encryption_enabled = encryption_enabled

    def prepare_reply(self, data: bytes) -> Union[str, bytes]:
        """Return ``data`` base64-encoded as text when encrypting, else unchanged."""
        size_mb = len(data) / 2**20
        if not self.encryption_enabled:
            logger.info(f"📤 Preparing PLAINTEXT reply, size: {size_mb:.2f} MB")
            return data
        logger.info(f"🔒 Preparing ENCRYPTED reply, size: {size_mb:.2f} MB")
        return base64.b64encode(data).decode("utf-8")

    def process_message(self, message: Message) -> Union[str, bytes]:
        """Run the ClientApp on ``message`` and return the prepared reply."""
        logger.info(f"Processing message with metadata: {message.metadata}")
        reply = self.client_app(message=message, context=self.context)
        return self.prepare_reply(flower_message_to_bytes(reply))

    def create_error_reply(
        self, message: Optional[Message], error: Error
    ) -> Union[str, bytes]:
        """Build a reply that carries ``error`` back to ``message``'s sender.

        Without a message, the reply is a SYSTEM message addressed to node 0
        with an empty group id.
        """
        if message:
            message_type = message.metadata.message_type
            dst_node_id = message.metadata.src_node_id
            group_id = message.metadata.group_id
        else:
            message_type = MessageType.SYSTEM
            dst_node_id = 0
            group_id = ""

        error_reply = create_flwr_message(
            content=RecordDict(),
            reply_to=message,
            message_type=message_type,
            dst_node_id=dst_node_id,
            group_id=group_id,
            error=error,
        )
        error_bytes = flower_message_to_bytes(error_reply)
        logger.info(f"Error reply size: {len(error_bytes)/2**20:.2f} MB")
        return self.prepare_reply(error_bytes)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class RequestProcessor:
    """Processes incoming requests and handles encryption/decryption.

    Each request body is decoded (base64 when encryption is enabled) and
    deserialized into a Flower ``Message``. A stop signal sets the SyftEvents
    stop event. Any other message is passed to the ``MessageHandler``, and
    exceptions raised while handling it are returned as Flower error replies.
    """

    def __init__(
        self, message_handler: MessageHandler, box: SyftEvents, client_email: str
    ):
        self.message_handler = message_handler
        self.box = box
        self.client_email = client_email

    def decode_request_body(self, request_body: Union[bytes, str]) -> bytes:
        """Decode request body, handling base64 if encrypted.

        With encryption disabled the body is returned unchanged. Otherwise it
        is base64-decoded. If it is not UTF-8 text or not valid base64, it is
        returned as-is so the caller can still try to deserialize it.
        """
        if not self.message_handler.encryption_enabled:
            return request_body

        try:
            # Convert to string if bytes
            if isinstance(request_body, bytes):
                request_body_str = request_body.decode("utf-8")
            else:
                request_body_str = request_body
            # Whitespace is still tolerated, but any other non-alphabet
            # character now makes decoding fail. Lenient decoding silently
            # drops such characters and can "succeed" on data that is not
            # base64 at all, which bypasses the as-is fallback below.
            compact = "".join(request_body_str.split())
            decoded = base64.b64decode(compact, validate=True)
        except Exception:
            # Not base64 or decoding failed, use as-is
            return request_body

        logger.debug("🔓 Decoded base64 message")
        return decoded

    def is_stop_signal(self, message: Message) -> bool:
        """Check if message is a stop signal.

        Only SYSTEM messages qualify. A ``config.action`` of ``"stop"`` decides
        when present; otherwise a group id of ``"final"`` marks a stop.
        """
        if message.metadata.message_type != MessageType.SYSTEM:
            return False

        # Check for stop action in config
        if "config" in message.content and "action" in message.content["config"]:
            return message.content["config"]["action"] == "stop"

        # Alternative stop signal format
        return message.metadata.group_id == "final"

    def process(self, request: Request) -> Optional[Union[str, bytes]]:
        """Process incoming request and return response.

        Returns ``None`` (no reply) for stop signals and for bodies that
        cannot be deserialized. Otherwise returns the handler's reply, or an
        error reply if handling the message raised.
        """
        original_sender = request.headers.get("X-Syft-Original-Sender", "unknown")
        encryption_status = (
            "🔐 ENCRYPTED"
            if self.message_handler.encryption_enabled
            else "📥 PLAINTEXT"
        )

        logger.info(
            f"{encryption_status} Received request from {original_sender}, "
            f"id: {request.id}, size: {len(request.body) / 1024 / 1024:.2f} MB"
        )

        # Parse message
        try:
            request_body = self.decode_request_body(request.body)
            message = bytes_to_flower_message(request_body)

            if self.message_handler.encryption_enabled:
                logger.debug(
                    f"🔓 Successfully decrypted message from {original_sender}"
                )
        except Exception as e:
            logger.error(
                f"❌ Failed to deserialize message from {original_sender}: {e}"
            )
            logger.debug(
                f"Request body preview (first 200 bytes): {str(request.body[:200])}"
            )

            # Can't create error reply without valid message - skip response
            logger.warning(
                "Skipping error reply (cannot create without valid parsed message)"
            )
            return None

        # Handle message
        try:
            # Check for stop signal
            if self.is_stop_signal(message):
                logger.info("Received stop signal")
                self.box._stop_event.set()
                return None

            # Process normal message
            return self.message_handler.process_message(message)

        except Exception as e:
            error_message = f"Client: '{self.client_email}'. Error: {str(e)}. Traceback: {traceback.format_exc()}"
            logger.error(error_message)

            error = Error(
                code=ErrorCode.CLIENT_APP_RAISED_EXCEPTION, reason=error_message
            )
            return self.message_handler.create_error_reply(message, error)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def syftbox_flwr_client(client_app: ClientApp, context: Context, app_name: str):
    """Serve a Flower ``ClientApp`` over SyftBox.

    Sets up a SyftBox client for ``app_name`` and registers a ``/messages``
    request handler. The handler's decryption and reply encryption follow the
    client's encryption setting. The function then blocks in ``run_forever``;
    the request processor sets the box's stop event when a stop signal
    arrives. An exception escaping the event loop is logged and the process
    exits with status 1.
    """
    client, encryption_enabled, syft_flwr_app_name = setup_client(app_name)
    events = SyftEvents(
        app_name=syft_flwr_app_name,
        client=client,
        cleanup_expiry="1d",  # Keep request/response files for 1 days
        cleanup_interval="1d",  # Run cleanup daily
    )

    own_email = events.client.email
    logger.info(f"Started SyftBox Flower Client on: {own_email}")
    logger.info(f"syft_flwr app name: {syft_flwr_app_name}")

    if events.is_cleanup_running():
        logger.info("Cleanup service is active")

    processor = RequestProcessor(
        MessageHandler(client_app, context, encryption_enabled),
        events,
        own_email,
    )

    @events.on_request(
        "/messages", auto_decrypt=encryption_enabled, encrypt_reply=encryption_enabled
    )
    def handle_messages(request: Request) -> Optional[Union[str, bytes]]:
        return processor.process(request)

    try:
        events.run_forever()
    except Exception as exc:
        logger.error(
            f"Fatal error in syftbox_flwr_client: {str(exc)}\n{traceback.format_exc()}"
        )
        sys.exit(1)
|
|
syft_flwr/flower_server.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import traceback
|
|
2
|
+
from random import randint
|
|
3
|
+
|
|
4
|
+
from flwr.common import Context
|
|
5
|
+
from flwr.server import ServerApp
|
|
6
|
+
from flwr.server.run_serverapp import run as run_server
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
|
+
from syft_flwr.grid import SyftGrid
|
|
10
|
+
from syft_flwr.utils import setup_client
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def syftbox_flwr_server(
    server_app: ServerApp,
    context: Context,
    datasites: list[str],
    app_name: str,
) -> Context:
    """Run the Flower ServerApp with SyftBox.

    Builds a ``SyftGrid`` over ``datasites`` and runs ``server_app`` on it.
    Afterwards it always tries to send a final stop signal to the clients.

    Returns:
        The context returned by the server run. If the run raised an
        ``Exception``, the input ``context`` is returned instead; the error is
        logged, not re-raised.
    """
    client, _, syft_flwr_app_name = setup_client(app_name)

    # Construct the SyftGrid
    syft_grid = SyftGrid(
        app_name=syft_flwr_app_name, datasites=datasites, client=client
    )

    # Set the run id (random for now)
    run_id = randint(0, 1000)
    syft_grid.set_run(run_id)

    logger.info(f"Started SyftBox Flower Server on: {syft_grid._client.email}")
    logger.info(f"syft_flwr app name: {syft_flwr_app_name}")

    updated_context = context
    try:
        updated_context = run_server(
            syft_grid,
            context=context,
            loaded_server_app=server_app,
            server_app_dir="",
        )
        logger.info(f"Server completed with context: {updated_context}")
    except Exception as e:
        logger.error(f"Server encountered an error: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
    finally:
        logger.info("Sending stop signals to the clients")
        # Best-effort shutdown: a failure to notify the clients is logged and
        # must not discard the server's result.
        try:
            syft_grid.send_stop_signal(group_id="final", reason="Server stopped")
        except Exception as e:
            logger.error(f"Failed to send stop signals: {str(e)}")
            logger.error(f"Traceback: {traceback.format_exc()}")

    return updated_context
|