syft-flwr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of syft-flwr might be problematic. Click here for more details.

README.md ADDED
@@ -0,0 +1,6 @@
1
+ # syft_flwr
2
+
3
+ `syft_flwr` is an open-source framework that facilitates federated learning projects using [Flower](https://github.com/adap/flower) over the [SyftBox](https://github.com/OpenMined/syftbox) protocol.
4
+
5
+ ## Installation
6
+ `pip install syft_flwr`
pyproject.toml ADDED
@@ -0,0 +1,52 @@
1
+ [project]
2
+ name = "syft-flwr"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.9.2"
7
+ dependencies = [
8
+ "flwr[simulation]==1.17.0",
9
+ "flwr-datasets[vision]>=0.5.0",
10
+ "loguru>=0.7.3",
11
+ "safetensors>=0.5.0",
12
+ "typing-extensions>=4.13.0",
13
+ "tomli>=2.2.1",
14
+ "tomli-w>=1.2.0",
15
+ "syft-core==0.2.3",
16
+ "syft-event==0.2.0",
17
+ "syft-rpc==0.2.0",
18
+ "syft-rds==0.1.0",
19
+ ]
20
+
21
+ [project.scripts]
22
+ syft_flwr = "syft_flwr.cli:main"
23
+
24
+ [build-system]
25
+ requires = ["hatchling"]
26
+ build-backend = "hatchling.build"
27
+
28
+ [tool.uv]
29
+ dev-dependencies = [
30
+ "ipykernel>=6.29.5",
31
+ "ipywidgets>=8.1.6",
32
+ "pytest>=8.3.4",
33
+ "pre-commit>=4.0.1",
34
+ ]
35
+
36
+ [tool.hatch.build.targets.wheel]
37
+ packages = ["src/syft_flwr"]
38
+ only-include = ["src", "pyproject.toml", "/README.md"]
39
+ exclude = ["src/**/__pycache__"]
40
+
41
+ [tool.hatch.build.targets.sdist]
42
+ only-include = ["src", "pyproject.toml", "/README.md"]
43
+ exclude = ["src/**/__pycache__", "examples", "notebooks", "justfile"]
44
+
45
+ [tool.ruff]
46
+ exclude = [".archive"]
47
+
48
+ [tool.ruff.lint]
49
+ extend-select = ["I"]
50
+
51
+ [tool.ruff.lint.per-file-ignores]
52
+ "**/__init__.py" = ["F401"]
syft_flwr/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ __version__ = "0.1.0"
2
+
3
+ from syft_flwr.bootstrap import bootstrap
4
+ from syft_flwr.run import run
5
+
6
+ __all__ = ["bootstrap", "run"]
syft_flwr/bootstrap.py ADDED
@@ -0,0 +1,97 @@
1
+ from pathlib import Path
2
+
3
+ from loguru import logger
4
+ from typing_extensions import List, Union
5
+
6
+ from syft_flwr import __version__
7
+ from syft_flwr.config import load_flwr_pyproject, write_toml
8
+ from syft_flwr.utils import is_valid_datasite
9
+
10
+ __all__ = ["bootstrap"]
11
+
12
+
13
+ MAIN_TEMPLATE_PATH = Path(__file__).parent / "templates" / "main.py.tpl"
14
+ MAIN_TEMPLATE_CONTENT = MAIN_TEMPLATE_PATH.read_text()
15
+ assert MAIN_TEMPLATE_CONTENT
16
+
17
+
18
def __copy_main_py(flwr_project_dir: Path) -> None:
    """Write the bundled `main.py` template into the given project directory.

    Args:
        flwr_project_dir: Directory that will receive the new `main.py`.

    Raises:
        Exception: If `main.py` is already present in `flwr_project_dir`.
    """
    target = flwr_project_dir / "main.py"

    if not target.exists():
        target.write_text(MAIN_TEMPLATE_CONTENT)
        return

    # Never overwrite an existing entry point.
    raise Exception(f"The file '{target}' already exists")
27
+
28
+
29
def __update_pyproject_toml(
    flwr_project_dir: Union[str, Path],
    aggregator: str,
    datasites: List[str],
) -> None:
    """Rewrite the project's `pyproject.toml` so it can run under syft-flwr.

    The file is loaded (and validated) with `load_flwr_pyproject`, modified in
    memory, and written back with `write_toml`. Three things are changed:

    * `tool.flwr.app.config` gets fixed `partition-id` / `num-partitions`
      values (temporary, see the TODO below).
    * `project.dependencies` gets exactly one pin of this package at the
      running `__version__`; any previous requirement on it is dropped.
    * `tool.syft_flwr` is replaced with the given datasites and aggregator.

    Args:
        flwr_project_dir: Directory containing the `pyproject.toml` to update.
        aggregator: Datasite that is stored under `tool.syft_flwr.aggregator`.
        datasites: Datasites that are stored under `tool.syft_flwr.datasites`.
    """
    import re  # local import: only needed for requirement-name matching

    def _is_syft_flwr_requirement(requirement: str) -> bool:
        # Distribution names compare case-insensitively, with runs of "-",
        # "_" and "." treated as equal, so "syft-flwr" and "syft_flwr" are the
        # same package. A plain startswith("syft_flwr") misses the hyphenated
        # spelling and would leave a second, conflicting pin behind.
        match = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", requirement)
        if match is None:
            return False
        return re.sub(r"[-_.]+", "-", match.group(1)).lower() == "syft-flwr"

    flwr_pyproject = Path(flwr_project_dir) / "pyproject.toml"
    pyproject_conf = load_flwr_pyproject(flwr_pyproject, check_module=False)

    # TODO: remove this after we find out how to pass the right context to the clients
    # setdefault: tolerate a project that has no [tool.flwr.app.config] table yet.
    app_config = pyproject_conf["tool"]["flwr"]["app"].setdefault("config", {})
    app_config["partition-id"] = 0
    app_config["num-partitions"] = 1
    # TODO end

    # add syft_flwr as a dependency, replacing any earlier requirement on it
    existing_deps = pyproject_conf["project"].get("dependencies", [])
    deps = [dep for dep in existing_deps if not _is_syft_flwr_requirement(dep)]
    deps.append(f"syft_flwr=={__version__}")
    pyproject_conf["project"]["dependencies"] = deps

    # always override the datasites and aggregator
    pyproject_conf["tool"]["syft_flwr"] = {
        "datasites": datasites,
        "aggregator": aggregator,
    }

    write_toml(flwr_pyproject, pyproject_conf)
59
+
60
+
61
def __validate_flwr_project_dir(flwr_project_dir: Union[str, Path]) -> Path:
    """Check that a directory is a Flower project that can be bootstrapped.

    Args:
        flwr_project_dir: Project directory, given as a string or a `Path`.

    Returns:
        The project directory as a `Path`.

    Raises:
        FileNotFoundError: If the directory or its `pyproject.toml` is missing.
        FileExistsError: If the directory already contains a `main.py`.
    """
    # Accept plain strings as the signature promises; `/` needs a Path.
    flwr_project_dir = Path(flwr_project_dir)

    if not flwr_project_dir.exists():
        raise FileNotFoundError(f"Directory '{flwr_project_dir}' not found")

    flwr_main_py = flwr_project_dir / "main.py"
    if flwr_main_py.exists():
        raise FileExistsError(f"File '{flwr_main_py}' already exists")

    flwr_pyproject = flwr_project_dir / "pyproject.toml"
    if not flwr_pyproject.exists():
        raise FileNotFoundError(f"File '{flwr_pyproject}' not found")

    return flwr_project_dir
73
+
74
+
75
def bootstrap(
    flwr_project_dir: Union[str, Path], aggregator: str, datasites: List[str]
) -> None:
    """Turn the Flower project at `flwr_project_dir` into a syft-flwr project.

    The aggregator and every datasite are checked with `is_valid_datasite`
    first; then the project directory is validated, its `pyproject.toml` is
    updated, and a `main.py` is written into it.

    Args:
        flwr_project_dir: Path of the existing Flower project.
        aggregator: Datasite acting as the aggregator.
        datasites: Datasites acting as Flower clients (at least two).

    Raises:
        ValueError: If the aggregator or any datasite is invalid, or fewer
            than two datasites are given.
    """
    project_path = Path(flwr_project_dir)

    if not is_valid_datasite(aggregator):
        raise ValueError(f"'{aggregator}' is not a valid datasite")

    too_few_clients = len(datasites) < 2
    if too_few_clients:
        raise ValueError("You must provide at least two datasites as Flower clients")

    for candidate in datasites:
        if is_valid_datasite(candidate):
            continue
        raise ValueError(f"{candidate} is not a valid datasite")

    # Validate first so nothing is modified when the project is unusable.
    __validate_flwr_project_dir(project_path)
    __update_pyproject_toml(project_path, aggregator, datasites)
    __copy_main_py(project_path)

    logger.info(
        f"Successfully bootstrapped syft-flwr project at {project_path} with datasites {datasites} and aggregator {aggregator} ✅"
    )
syft_flwr/cli.py ADDED
@@ -0,0 +1,116 @@
1
+ from pathlib import Path
2
+
3
+ import typer
4
+ from rich import print as rprint
5
+ from typing_extensions import Annotated, List, Tuple
6
+
7
+ app = typer.Typer(
8
+ name="syft_flwr",
9
+ no_args_is_help=True,
10
+ pretty_exceptions_enable=False,
11
+ context_settings={"help_option_names": ["-h", "--help"]},
12
+ )
13
+
14
+
15
@app.command()
def version() -> None:
    """Print syft_flwr version"""
    # Imported lazily, inside the command, rather than at module import time.
    import syft_flwr

    print(syft_flwr.__version__)
21
+
22
+
23
+ PROJECT_DIR_OPTS = typer.Argument(help="Path to a Flower project")
24
+ AGGREGATOR_OPTS = typer.Option(
25
+ "-a",
26
+ "--aggregator",
27
+ "-s",
28
+ "--server",
29
+ help="Datasite email of the Flower Server",
30
+ )
31
+ DATASITES_OPTS = typer.Option(
32
+ "-d",
33
+ "--datasites",
34
+ help="Datasites addresses",
35
+ )
36
+ MOCK_DATASET_PATHS_OPTS = typer.Option(
37
+ "-m",
38
+ "--mock-dataset-paths",
39
+ help="Mock dataset paths",
40
+ )
41
+
42
+
43
def prompt_for_missing_args(
    aggregator: str, datasites: List[str]
) -> Tuple[str, List[str]]:
    """Interactively ask for the aggregator and/or datasites when not supplied.

    Values that were already provided are returned unchanged. Prompted values
    are stripped of surrounding whitespace, and empty entries (for example
    from a trailing comma) are dropped, so "a@x.org, b@x.org," yields
    ["a@x.org", "b@x.org"].

    Args:
        aggregator: Aggregator datasite, or a falsy value to prompt for it.
        datasites: Client datasites, or a falsy value to prompt for them.

    Returns:
        A `(aggregator, datasites)` tuple.
    """
    if not aggregator:
        aggregator = typer.prompt(
            "Enter the datasite email of the Aggregator (Flower Server)"
        ).strip()

    if not datasites:
        raw_datasites = typer.prompt(
            "Enter a comma-separated email of datasites of the Flower Clients"
        )
        # Users naturally type "a, b"; stray blanks would otherwise end up
        # inside the datasite addresses.
        datasites = [
            entry.strip() for entry in raw_datasites.split(",") if entry.strip()
        ]

    return aggregator, datasites
57
+
58
+
59
def prompt_for_missing_mock_paths(mock_dataset_paths: List[str]) -> List[str]:
    """Interactively ask for the mock dataset paths when none were supplied.

    Paths that were already provided are returned unchanged. Prompted paths
    are stripped of surrounding whitespace, and empty entries (for example
    from a trailing comma) are dropped.

    Args:
        mock_dataset_paths: Paths given by the caller, or a falsy value to
            prompt for them.

    Returns:
        The list of mock dataset paths.
    """
    if not mock_dataset_paths:
        raw_paths = typer.prompt("Enter comma-separated paths to mock datasets")
        # "a, b" must not produce a path that starts with a space.
        mock_dataset_paths = [
            entry.strip() for entry in raw_paths.split(",") if entry.strip()
        ]

    return mock_dataset_paths
65
+
66
+
67
@app.command()
def bootstrap(
    project_dir: Annotated[Path, PROJECT_DIR_OPTS],
    aggregator: Annotated[str, AGGREGATOR_OPTS] = None,
    datasites: Annotated[List[str], DATASITES_OPTS] = None,
) -> None:
    """Bootstrap a new syft_flwr project from a flwr project"""
    # Imported lazily and under an alias so it does not shadow this command.
    from syft_flwr.bootstrap import bootstrap as bootstrap_project

    # Ask interactively for anything not given on the command line.
    aggregator, datasites = prompt_for_missing_args(aggregator, datasites)

    try:
        target_dir = project_dir.absolute()
        rprint(f"[cyan]Bootstrapping project at '{target_dir}'[/cyan]")
        rprint(f"[cyan]Aggregator: {aggregator}[/cyan]")
        rprint(f"[cyan]Datasites: {datasites}[/cyan]")

        bootstrap_project(target_dir, aggregator, datasites)

        rprint(f"[green]Bootstrapped project at '{target_dir}'[/green]")
    except Exception as e:
        # Report the failure in one line and exit with a non-zero status.
        rprint(f"[red]Error[/red]: {e}")
        raise typer.Exit(1)
88
+
89
+
90
@app.command()
def run(
    project_dir: Annotated[Path, PROJECT_DIR_OPTS],
    mock_dataset_paths: Annotated[List[str], MOCK_DATASET_PATHS_OPTS] = None,
) -> None:
    """Run a syft_flwr project in simulation mode over mock data"""
    # Imported lazily and under an alias so it does not shadow this command.
    from syft_flwr.run_simulation import run as run_simulation

    try:
        # Ask interactively when no paths were given on the command line.
        mock_dataset_paths = prompt_for_missing_mock_paths(mock_dataset_paths)

        resolved_dir = Path(project_dir).expanduser().resolve()
        rprint(f"[cyan]Running syft_flwr project at '{resolved_dir}'[/cyan]")
        rprint(f"[cyan]Mock dataset paths: {mock_dataset_paths}[/cyan]")

        run_simulation(resolved_dir, mock_dataset_paths)
    except Exception as e:
        # Report the failure in one line and exit with a non-zero status.
        rprint(f"[red]Error[/red]: {e}")
        raise typer.Exit(1)
109
+
110
+
111
def main() -> None:
    """Entry point: invoke the Typer application defined in this module."""
    app()
113
+
114
+
115
+ if __name__ == "__main__":
116
+ main()
syft_flwr/config.py ADDED
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import tomli
6
+ import tomli_w
7
+ from flwr.common.config import validate_config
8
+ from loguru import logger
9
+
10
+
11
def load_toml(path: str):
    """Parse the TOML file at `path` and return the result of `tomli.load`."""
    # tomli requires the file to be opened in binary mode.
    with open(path, mode="rb") as stream:
        parsed = tomli.load(stream)
    return parsed
14
+
15
+
16
def write_toml(path: str, val: dict):
    """Serialize `val` as TOML into the file at `path`, replacing its content."""
    # tomli_w.dump writes bytes, hence binary mode.
    with open(path, mode="wb") as stream:
        tomli_w.dump(val, stream)
19
+
20
+
21
def load_flwr_pyproject(path: Path, check_module: bool = True) -> dict:
    """Load a Flower project's `pyproject.toml` and validate it.

    Args:
        path: Either the `pyproject.toml` file itself or the directory that
            contains it.
        check_module: Forwarded to Flower's `validate_config`.

    Returns:
        The parsed `pyproject.toml` content.

    Raises:
        Exception: If `validate_config` reports the configuration as invalid;
            the exception carries the reported errors.
    """
    # A directory was given: point at the pyproject.toml inside it.
    toml_path = path if path.name == "pyproject.toml" else path / "pyproject.toml"

    config = load_toml(toml_path)
    valid, errors, warnings = validate_config(config, check_module, toml_path.parent)

    if not valid:
        raise Exception(errors)

    # Warnings do not block loading; they are only logged.
    if warnings:
        logger.warning(warnings)

    return config
@@ -0,0 +1,65 @@
1
+ import sys
2
+ import traceback
3
+
4
+ from flwr.client import ClientApp
5
+ from flwr.common import Context
6
+ from flwr.common.constant import ErrorCode
7
+ from flwr.common.message import Error, Message
8
+ from loguru import logger
9
+ from syft_event import SyftEvents
10
+ from syft_event.types import Request
11
+
12
+ from syft_flwr.flwr_compatibility import RecordDict, create_flwr_message
13
+ from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
14
+
15
+
16
def syftbox_flwr_client(client_app: ClientApp, context: Context):
    """Run the Flower ClientApp with SyftBox.

    Creates a `SyftEvents("flwr")` instance, registers a handler for the
    `/messages` endpoint that feeds incoming Flower messages to `client_app`,
    and then serves requests via `run_forever()`.

    Args:
        client_app: Flower client application invoked for every message.
        context: Flower context passed to `client_app` on each call.

    If `run_forever()` raises, the error is logged and the process exits with
    status 1.
    """

    box = SyftEvents("flwr")
    client_email = box.client.email
    logger.info(f"Started SyftBox Flower Client on: {client_email}")

    @box.on_request("/messages")
    def handle_messages(request: Request) -> bytes:
        """Process one request and return the serialized Flower reply.

        When `client_app` (or serializing its reply) raises, a serialized
        error reply carrying the traceback is returned instead.
        """
        logger.info(
            f"Received request id: {request.id}, size: {len(request.body) / 1024 / 1024} (MB)"
        )
        # Deserialization stays outside the try block: the error reply below
        # is built from the parsed message, so a failure here simply
        # propagates out of the handler.
        message: Message = bytes_to_flower_message(request.body)

        try:
            reply_message: Message = client_app(message=message, context=context)
            res_bytes: bytes = flower_message_to_bytes(reply_message)
            logger.info(f"Reply message size: {len(res_bytes)/2**20} MB")
            return res_bytes

        except Exception as e:
            error_traceback = traceback.format_exc()
            error_message = (
                f"Client: '{client_email}'. Error: {e}. Traceback: {error_traceback}"
            )
            logger.error(error_message)

            error = Error(
                code=ErrorCode.CLIENT_APP_RAISED_EXCEPTION, reason=error_message
            )

            # Reply to the sender: source and destination node ids are swapped.
            error_reply: Message = create_flwr_message(
                content=RecordDict(),
                reply_to=message,
                message_type=message.metadata.message_type,
                src_node_id=message.metadata.dst_node_id,
                dst_node_id=message.metadata.src_node_id,
                group_id=message.metadata.group_id,
                run_id=message.metadata.run_id,
                error=error,
            )
            error_bytes: bytes = flower_message_to_bytes(error_reply)
            logger.info(f"Error reply message size: {len(error_bytes)/2**20} MB")
            return error_bytes

    try:
        box.run_forever()
    except Exception as e:
        logger.error(
            f"Fatal error in syftbox_flwr_client: {e}\n{traceback.format_exc()}"
        )
        sys.exit(1)
@@ -0,0 +1,25 @@
1
+ from random import randint
2
+
3
+ from flwr.common import Context
4
+ from flwr.server import ServerApp
5
+ from flwr.server.run_serverapp import run as run_server
6
+ from loguru import logger
7
+
8
+ from syft_flwr.grid import SyftGrid
9
+
10
+
11
def syftbox_flwr_server(server_app: ServerApp, context: Context, datasites: list[str]):
    """Run the Flower ServerApp with SyftBox.

    Builds a `SyftGrid` over `datasites`, assigns it a random run id between
    0 and 1000 (inclusive), and hands it to Flower's `run_serverapp.run`.

    Args:
        server_app: Flower server application to execute.
        context: Flower context passed to the server run.
        datasites: Datasites the grid is created with.

    Returns:
        The context returned by the server run.
    """
    grid = SyftGrid(datasites=datasites)
    grid.set_run(randint(0, 1000))
    logger.info(f"Started SyftBox Flower Server on: {grid._client.email}")

    final_context = run_server(
        grid,
        context=context,
        loaded_server_app=server_app,
        server_app_dir="",
    )

    logger.info(f"Server completed with context: {final_context}")
    return final_context
@@ -0,0 +1,121 @@
1
+ import flwr
2
+ from flwr.common import Metadata
3
+ from flwr.common.message import Error, Message
4
+ from packaging.version import Version
5
+ from typing_extensions import Optional
6
+
7
+
8
def flwr_later_than_1_17():
    """Return True when the installed Flower version is 1.17.0 or newer."""
    installed = Version(flwr.__version__)
    threshold = Version("1.17.0")
    return installed >= threshold
10
+
11
+
12
+ # Version-dependent imports
13
+ if flwr_later_than_1_17():
14
+ from flwr.common.record import RecordDict
15
+ from flwr.server.grid import Grid
16
+ else:
17
+ from flwr.common.record import RecordSet as RecordDict
18
+ from flwr.server.driver import Driver as Grid
19
+
20
+
21
+ __all__ = ["Grid", "RecordDict"]
22
+
23
+
24
def check_reply_to_field(metadata: Metadata) -> bool:
    """Return True when the metadata's reply-to field is the empty string.

    The field is read from `reply_to_message_id` on Flower 1.17+ and from
    `reply_to_message` on earlier versions.
    """
    field_name = (
        "reply_to_message_id" if flwr_later_than_1_17() else "reply_to_message"
    )
    return getattr(metadata, field_name) == ""
30
+
31
+
32
def create_flwr_message(
    content: RecordDict,
    message_type: str,
    src_node_id: int,
    dst_node_id: int,
    group_id: str,
    run_id: int,
    ttl: Optional[float] = None,
    error: Optional[Error] = None,
    reply_to: Optional[Message] = None,
) -> Message:
    """Create a Flower message using the constructor of the installed version.

    On Flower 1.17+ the message is built without `src_node_id` and `run_id`;
    on earlier versions it is built without `reply_to`.
    """
    if not flwr_later_than_1_17():
        # Older API: everything travels in an explicit Metadata object.
        return _create_message_pre_v1_17(
            content=content,
            message_type=message_type,
            src_node_id=src_node_id,
            dst_node_id=dst_node_id,
            group_id=group_id,
            run_id=run_id,
            ttl=ttl,
            error=error,
        )

    return _create_message_v1_17_plus(
        content=content,
        message_type=message_type,
        dst_node_id=dst_node_id,
        group_id=group_id,
        ttl=ttl,
        error=error,
        reply_to=reply_to,
    )
65
+
66
+
67
def _create_message_v1_17_plus(
    content: RecordDict,
    message_type: str,
    dst_node_id: int,
    group_id: str,
    ttl: Optional[float],
    error: Optional[Error],
    reply_to: Optional[Message],
) -> Message:
    """Create message for Flower version 1.17+.

    Three shapes are produced:

    * `reply_to` and `error` given: an error reply to `reply_to`.
    * only `reply_to` given: a content reply to `reply_to`.
    * neither given: a new message addressed to `dst_node_id`.

    Raises:
        ValueError: If `error` is given without `reply_to`; an error can only
            be sent as a reply.
    """
    if reply_to is None:
        if error is not None:
            # The previous text ("cannot both be None") described the opposite
            # of the condition that triggers this branch.
            raise ValueError(
                "An error message must be a reply: `reply_to` cannot be None "
                "when `error` is given"
            )
        return Message(
            content=content,
            dst_node_id=dst_node_id,
            message_type=message_type,
            ttl=ttl,
            group_id=group_id,
        )

    if error is not None:
        return Message(reply_to=reply_to, error=error)
    return Message(content=content, reply_to=reply_to)
91
+
92
+
93
def _create_message_pre_v1_17(
    content: RecordDict,
    message_type: str,
    src_node_id: int,
    dst_node_id: int,
    group_id: str,
    run_id: int,
    ttl: Optional[float],
    error: Optional[Error],
) -> Message:
    """Create message for Flower versions before 1.17.

    Builds an explicit `Metadata` object (falling back to Flower's
    `DEFAULT_TTL` when `ttl` is None) and attaches either `error` or
    `content` to the message, with `error` taking precedence.
    """
    from flwr.common import DEFAULT_TTL

    metadata = Metadata(
        run_id=run_id,
        message_id="",  # Will be set when saving to file
        src_node_id=src_node_id,
        dst_node_id=dst_node_id,
        reply_to_message="",
        group_id=group_id,
        ttl=DEFAULT_TTL if ttl is None else ttl,
        message_type=message_type,
    )

    # A message carries either an error or content, never both.
    payload = {"content": content} if error is None else {"error": error}
    return Message(metadata=metadata, **payload)