flwr-nightly 1.14.0.dev20241210__py3-none-any.whl → 1.14.0.dev20241212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (36)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/cli_user_auth_interceptor.py +86 -0
  3. flwr/cli/config_utils.py +18 -2
  4. flwr/cli/log.py +10 -31
  5. flwr/cli/login/__init__.py +21 -0
  6. flwr/cli/login/login.py +82 -0
  7. flwr/cli/ls.py +10 -40
  8. flwr/cli/new/new.py +1 -1
  9. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
  10. flwr/cli/run/run.py +14 -25
  11. flwr/cli/stop.py +9 -39
  12. flwr/cli/utils.py +108 -1
  13. flwr/client/mod/localdp_mod.py +1 -1
  14. flwr/common/config.py +2 -1
  15. flwr/common/constant.py +4 -1
  16. flwr/common/object_ref.py +57 -54
  17. flwr/common/retry_invoker.py +75 -0
  18. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  19. flwr/common/telemetry.py +2 -1
  20. flwr/common/typing.py +4 -0
  21. flwr/proto/fab_pb2.py +4 -4
  22. flwr/proto/fab_pb2.pyi +4 -1
  23. flwr/proto/serverappio_pb2.py +18 -18
  24. flwr/proto/serverappio_pb2.pyi +8 -2
  25. flwr/proto/serverappio_pb2_grpc.py +34 -0
  26. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  27. flwr/server/compat/app_utils.py +7 -1
  28. flwr/server/driver/grpc_driver.py +10 -63
  29. flwr/server/serverapp/app.py +8 -2
  30. flwr/server/superlink/driver/serverappio_servicer.py +68 -6
  31. flwr/server/superlink/utils.py +65 -0
  32. {flwr_nightly-1.14.0.dev20241210.dist-info → flwr_nightly-1.14.0.dev20241212.dist-info}/METADATA +6 -6
  33. {flwr_nightly-1.14.0.dev20241210.dist-info → flwr_nightly-1.14.0.dev20241212.dist-info}/RECORD +36 -32
  34. {flwr_nightly-1.14.0.dev20241210.dist-info → flwr_nightly-1.14.0.dev20241212.dist-info}/LICENSE +0 -0
  35. {flwr_nightly-1.14.0.dev20241210.dist-info → flwr_nightly-1.14.0.dev20241212.dist-info}/WHEEL +0 -0
  36. {flwr_nightly-1.14.0.dev20241210.dist-info → flwr_nightly-1.14.0.dev20241212.dist-info}/entry_points.txt +0 -0
flwr/cli/utils.py CHANGED
@@ -15,12 +15,32 @@
15
15
  """Flower command line interface utils."""
16
16
 
17
17
  import hashlib
18
+ import json
18
19
  import re
20
+ from logging import DEBUG
19
21
  from pathlib import Path
20
- from typing import Callable, Optional, cast
22
+ from typing import Any, Callable, Optional, cast
21
23
 
24
+ import grpc
22
25
  import typer
23
26
 
27
+ from flwr.cli.cli_user_auth_interceptor import CliUserAuthInterceptor
28
+ from flwr.common.address import parse_address
29
+ from flwr.common.auth_plugin import CliAuthPlugin
30
+ from flwr.common.constant import AUTH_TYPE, CREDENTIALS_DIR, FLWR_DIR
31
+ from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
32
+ from flwr.common.logger import log
33
+
34
+ from .config_utils import validate_certificate_in_federation_config
35
+
36
+ try:
37
+ from flwr.ee import get_cli_auth_plugins
38
+ except ImportError:
39
+
40
+ def get_cli_auth_plugins() -> dict[str, type[CliAuthPlugin]]:
41
+ """Return all CLI authentication plugins."""
42
+ raise NotImplementedError("No authentication plugins are currently supported.")
43
+
24
44
 
25
45
  def prompt_text(
26
46
  text: str,
@@ -136,3 +156,90 @@ def get_sha256_hash(file_path: Path) -> str:
136
156
  break
137
157
  sha256.update(data)
138
158
  return sha256.hexdigest()
159
+
160
+
161
+ def get_user_auth_config_path(
162
+ root_dir: Path, federation: str, server_address: str
163
+ ) -> Path:
164
+ """Return the path to the user auth config file."""
165
+ # Parse the server address
166
+ parsed_addr = parse_address(server_address)
167
+ if parsed_addr is None:
168
+ raise ValueError(f"Invalid server address: {server_address}")
169
+ host, port, is_v6 = parsed_addr
170
+ formatted_addr = f"[{host}]_{port}" if is_v6 else f"{host}_{port}"
171
+
172
+ # Locate the credentials directory
173
+ credentials_dir = root_dir.absolute() / FLWR_DIR / CREDENTIALS_DIR
174
+ credentials_dir.mkdir(parents=True, exist_ok=True)
175
+ return credentials_dir / f"{federation}_{formatted_addr}.json"
176
+
177
+
178
+ def try_obtain_cli_auth_plugin(
179
+ root_dir: Path,
180
+ federation: str,
181
+ federation_config: dict[str, Any],
182
+ auth_type: Optional[str] = None,
183
+ ) -> Optional[CliAuthPlugin]:
184
+ """Load the CLI-side user auth plugin for the given auth type."""
185
+ config_path = get_user_auth_config_path(
186
+ root_dir, federation, federation_config["address"]
187
+ )
188
+
189
+ # Load the config file if it exists
190
+ config: dict[str, Any] = {}
191
+ if config_path.exists():
192
+ with config_path.open("r", encoding="utf-8") as file:
193
+ config = json.load(file)
194
+ # This is the case when the user auth is not enabled
195
+ elif auth_type is None:
196
+ return None
197
+
198
+ # Get the auth type from the config if not provided
199
+ if auth_type is None:
200
+ if AUTH_TYPE not in config:
201
+ return None
202
+ auth_type = config[AUTH_TYPE]
203
+
204
+ # Retrieve auth plugin class and instantiate it
205
+ try:
206
+ all_plugins: dict[str, type[CliAuthPlugin]] = get_cli_auth_plugins()
207
+ auth_plugin_class = all_plugins[auth_type]
208
+ return auth_plugin_class(config_path)
209
+ except KeyError:
210
+ typer.echo(f"❌ Unknown user authentication type: {auth_type}")
211
+ raise typer.Exit(code=1) from None
212
+ except ImportError:
213
+ typer.echo("❌ No authentication plugins are currently supported.")
214
+ raise typer.Exit(code=1) from None
215
+
216
+
217
+ def init_channel(
218
+ app: Path, federation_config: dict[str, Any], auth_plugin: Optional[CliAuthPlugin]
219
+ ) -> grpc.Channel:
220
+ """Initialize gRPC channel to the Exec API."""
221
+
222
+ def on_channel_state_change(channel_connectivity: str) -> None:
223
+ """Log channel connectivity."""
224
+ log(DEBUG, channel_connectivity)
225
+
226
+ insecure, root_certificates_bytes = validate_certificate_in_federation_config(
227
+ app, federation_config
228
+ )
229
+
230
+ # Initialize the CLI-side user auth interceptor
231
+ interceptors: list[grpc.UnaryUnaryClientInterceptor] = []
232
+ if auth_plugin is not None:
233
+ auth_plugin.load_tokens()
234
+ interceptors = CliUserAuthInterceptor(auth_plugin)
235
+
236
+ # Create the gRPC channel
237
+ channel = create_channel(
238
+ server_address=federation_config["address"],
239
+ insecure=insecure,
240
+ root_certificates=root_certificates_bytes,
241
+ max_message_length=GRPC_MAX_MESSAGE_LENGTH,
242
+ interceptors=interceptors or None,
243
+ )
244
+ channel.subscribe(on_channel_state_change)
245
+ return channel
flwr/client/mod/localdp_mod.py CHANGED
@@ -136,7 +136,7 @@ class LocalDpMod:
136
136
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
137
137
 
138
138
  # Add noise to model params
139
- add_localdp_gaussian_noise_to_params(
139
+ fit_res.parameters = add_localdp_gaussian_noise_to_params(
140
140
  fit_res.parameters, self.sensitivity, self.epsilon, self.delta
141
141
  )
142
142
 
flwr/common/config.py CHANGED
@@ -27,6 +27,7 @@ from flwr.common.constant import (
27
27
  APP_DIR,
28
28
  FAB_CONFIG_FILE,
29
29
  FAB_HASH_TRUNCATION,
30
+ FLWR_DIR,
30
31
  FLWR_HOME,
31
32
  )
32
33
  from flwr.common.typing import Run, UserConfig, UserConfigValue
@@ -38,7 +39,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
38
39
  return Path(
39
40
  os.getenv(
40
41
  FLWR_HOME,
41
- Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr",
42
+ Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / FLWR_DIR,
42
43
  )
43
44
  )
44
45
  return Path(provided_path).absolute()
flwr/common/constant.py CHANGED
@@ -80,7 +80,8 @@ FAB_ALLOWED_EXTENSIONS = {".py", ".toml", ".md"}
80
80
  FAB_CONFIG_FILE = "pyproject.toml"
81
81
  FAB_DATE = (2024, 10, 1, 0, 0, 0)
82
82
  FAB_HASH_TRUNCATION = 8
83
- FLWR_HOME = "FLWR_HOME"
83
+ FLWR_DIR = ".flwr" # The default Flower directory: ~/.flwr/
84
+ FLWR_HOME = "FLWR_HOME" # If set, override the default Flower directory
84
85
 
85
86
  # Constants entries in Node config for Simulation
86
87
  PARTITION_ID_KEY = "partition-id"
@@ -110,6 +111,8 @@ LOG_UPLOAD_INTERVAL = 0.2 # Minimum interval between two log uploads
110
111
  # Retry configurations
111
112
  MAX_RETRY_DELAY = 20 # Maximum delay duration between two consecutive retries.
112
113
 
114
+ # Constants for user authentication
115
+ CREDENTIALS_DIR = ".credentials"
113
116
  AUTH_TYPE = "auth_type"
114
117
 
115
118
 
flwr/common/object_ref.py CHANGED
@@ -21,6 +21,7 @@ import sys
21
21
  from importlib.util import find_spec
22
22
  from logging import WARN
23
23
  from pathlib import Path
24
+ from threading import Lock
24
25
  from typing import Any, Optional, Union
25
26
 
26
27
  from .logger import log
@@ -34,6 +35,7 @@ attribute.
34
35
 
35
36
 
36
37
  _current_sys_path: Optional[str] = None
38
+ _import_lock = Lock()
37
39
 
38
40
 
39
41
  def validate(
@@ -146,60 +148,61 @@ def load_app( # pylint: disable= too-many-branches
146
148
  - This function will modify `sys.path` by inserting the provided `project_dir`
147
149
  and removing the previously inserted `project_dir`.
148
150
  """
149
- valid, error_msg = validate(module_attribute_str, check_module=False)
150
- if not valid and error_msg:
151
- raise error_type(error_msg) from None
152
-
153
- module_str, _, attributes_str = module_attribute_str.partition(":")
154
-
155
- try:
156
- # Initialize project path
157
- if project_dir is None:
158
- project_dir = Path.cwd()
159
- project_dir = Path(project_dir).absolute()
160
-
161
- # Unload modules if the project directory has changed
162
- if _current_sys_path and _current_sys_path != str(project_dir):
163
- _unload_modules(Path(_current_sys_path))
164
-
165
- # Set the system path
166
- _set_sys_path(project_dir)
167
-
168
- # Import the module
169
- if module_str not in sys.modules:
170
- module = importlib.import_module(module_str)
171
- # Hack: `tabnet` does not work with `importlib.reload`
172
- elif "tabnet" in sys.modules:
173
- log(
174
- WARN,
175
- "Cannot reload module `%s` from disk due to compatibility issues "
176
- "with the `tabnet` library. The module will be loaded from the "
177
- "cache instead. If you experience issues, consider restarting "
178
- "the application.",
179
- module_str,
180
- )
181
- module = sys.modules[module_str]
182
- else:
183
- module = sys.modules[module_str]
184
- _reload_modules(project_dir)
185
-
186
- except ModuleNotFoundError as err:
187
- raise error_type(
188
- f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
189
- ) from err
190
-
191
- # Recursively load attribute
192
- attribute = module
193
- try:
194
- for attribute_str in attributes_str.split("."):
195
- attribute = getattr(attribute, attribute_str)
196
- except AttributeError as err:
197
- raise error_type(
198
- f"Unable to load attribute {attributes_str} from module {module_str}"
199
- f"{OBJECT_REF_HELP_STR}",
200
- ) from err
201
-
202
- return attribute
151
+ with _import_lock:
152
+ valid, error_msg = validate(module_attribute_str, check_module=False)
153
+ if not valid and error_msg:
154
+ raise error_type(error_msg) from None
155
+
156
+ module_str, _, attributes_str = module_attribute_str.partition(":")
157
+
158
+ try:
159
+ # Initialize project path
160
+ if project_dir is None:
161
+ project_dir = Path.cwd()
162
+ project_dir = Path(project_dir).absolute()
163
+
164
+ # Unload modules if the project directory has changed
165
+ if _current_sys_path and _current_sys_path != str(project_dir):
166
+ _unload_modules(Path(_current_sys_path))
167
+
168
+ # Set the system path
169
+ _set_sys_path(project_dir)
170
+
171
+ # Import the module
172
+ if module_str not in sys.modules:
173
+ module = importlib.import_module(module_str)
174
+ # Hack: `tabnet` does not work with `importlib.reload`
175
+ elif "tabnet" in sys.modules:
176
+ log(
177
+ WARN,
178
+ "Cannot reload module `%s` from disk due to compatibility issues "
179
+ "with the `tabnet` library. The module will be loaded from the "
180
+ "cache instead. If you experience issues, consider restarting "
181
+ "the application.",
182
+ module_str,
183
+ )
184
+ module = sys.modules[module_str]
185
+ else:
186
+ module = sys.modules[module_str]
187
+ _reload_modules(project_dir)
188
+
189
+ except ModuleNotFoundError as err:
190
+ raise error_type(
191
+ f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
192
+ ) from err
193
+
194
+ # Recursively load attribute
195
+ attribute = module
196
+ try:
197
+ for attribute_str in attributes_str.split("."):
198
+ attribute = getattr(attribute, attribute_str)
199
+ except AttributeError as err:
200
+ raise error_type(
201
+ f"Unable to load attribute {attributes_str} from module {module_str}"
202
+ f"{OBJECT_REF_HELP_STR}",
203
+ ) from err
204
+
205
+ return attribute
203
206
 
204
207
 
205
208
  def _unload_modules(project_dir: Path) -> None:
flwr/common/retry_invoker.py CHANGED
@@ -20,8 +20,17 @@ import random
20
20
  import time
21
21
  from collections.abc import Generator, Iterable
22
22
  from dataclasses import dataclass
23
+ from logging import INFO, WARN
23
24
  from typing import Any, Callable, Optional, Union, cast
24
25
 
26
+ import grpc
27
+
28
+ from flwr.common.constant import MAX_RETRY_DELAY
29
+ from flwr.common.logger import log
30
+ from flwr.common.typing import RunNotRunningException
31
+ from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub
32
+ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
33
+
25
34
 
26
35
  def exponential(
27
36
  base_delay: float = 1,
@@ -303,3 +312,69 @@ class RetryInvoker:
303
312
  # Trigger success event
304
313
  try_call_event_handler(self.on_success)
305
314
  return ret
315
+
316
+
317
+ def _make_simple_grpc_retry_invoker() -> RetryInvoker:
318
+ """Create a simple gRPC retry invoker."""
319
+
320
+ def _on_sucess(retry_state: RetryState) -> None:
321
+ if retry_state.tries > 1:
322
+ log(
323
+ INFO,
324
+ "Connection successful after %.2f seconds and %s tries.",
325
+ retry_state.elapsed_time,
326
+ retry_state.tries,
327
+ )
328
+
329
+ def _on_backoff(retry_state: RetryState) -> None:
330
+ if retry_state.tries == 1:
331
+ log(WARN, "Connection attempt failed, retrying...")
332
+ else:
333
+ log(
334
+ WARN,
335
+ "Connection attempt failed, retrying in %.2f seconds",
336
+ retry_state.actual_wait,
337
+ )
338
+
339
+ def _on_giveup(retry_state: RetryState) -> None:
340
+ if retry_state.tries > 1:
341
+ log(
342
+ WARN,
343
+ "Giving up reconnection after %.2f seconds and %s tries.",
344
+ retry_state.elapsed_time,
345
+ retry_state.tries,
346
+ )
347
+
348
+ def _should_giveup_fn(e: Exception) -> bool:
349
+ if e.code() == grpc.StatusCode.PERMISSION_DENIED: # type: ignore
350
+ raise RunNotRunningException
351
+ if e.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
352
+ return False
353
+ return True
354
+
355
+ return RetryInvoker(
356
+ wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
357
+ recoverable_exceptions=grpc.RpcError,
358
+ max_tries=None,
359
+ max_time=None,
360
+ on_success=_on_sucess,
361
+ on_backoff=_on_backoff,
362
+ on_giveup=_on_giveup,
363
+ should_giveup=_should_giveup_fn,
364
+ )
365
+
366
+
367
+ def _wrap_stub(
368
+ stub: Union[ServerAppIoStub, ClientAppIoStub], retry_invoker: RetryInvoker
369
+ ) -> None:
370
+ """Wrap a gRPC stub with a retry invoker."""
371
+
372
+ def make_lambda(original_method: Any) -> Any:
373
+ return lambda *args, **kwargs: retry_invoker.invoke(
374
+ original_method, *args, **kwargs
375
+ )
376
+
377
+ for method_name in vars(stub):
378
+ method = getattr(stub, method_name)
379
+ if callable(method):
380
+ setattr(stub, method_name, make_lambda(method))
flwr/common/secure_aggregation/secaggplus_utils.py CHANGED
@@ -93,8 +93,8 @@ def pseudo_rand_gen(
93
93
  output = []
94
94
  for dimension in dimensions_list:
95
95
  if len(dimension) == 0:
96
- arr = np.array(gen.randint(0, num_range - 1), dtype=int)
96
+ arr = np.array(gen.randint(0, num_range - 1), dtype=np.int64)
97
97
  else:
98
- arr = gen.randint(0, num_range - 1, dimension)
98
+ arr = gen.randint(0, num_range - 1, dimension, dtype=np.int64)
99
99
  output.append(arr)
100
100
  return output
flwr/common/telemetry.py CHANGED
@@ -27,6 +27,7 @@ from enum import Enum, auto
27
27
  from pathlib import Path
28
28
  from typing import Any, Optional, Union, cast
29
29
 
30
+ from flwr.common.constant import FLWR_DIR
30
31
  from flwr.common.version import package_name, package_version
31
32
 
32
33
  FLWR_TELEMETRY_ENABLED = os.getenv("FLWR_TELEMETRY_ENABLED", "1")
@@ -86,7 +87,7 @@ def _get_source_id() -> str:
86
87
  # If the home directory can’t be resolved, RuntimeError is raised.
87
88
  return source_id
88
89
 
89
- flwr_dir = home.joinpath(".flwr")
90
+ flwr_dir = home.joinpath(FLWR_DIR)
90
91
  # Create .flwr directory if it does not exist yet.
91
92
  try:
92
93
  flwr_dir.mkdir(parents=True, exist_ok=True)
flwr/common/typing.py CHANGED
@@ -254,3 +254,7 @@ class Fab:
254
254
 
255
255
  hash_str: str
256
256
  content: bytes
257
+
258
+
259
+ class RunNotRunningException(BaseException):
260
+ """Raised when a run is not running."""
flwr/proto/fab_pb2.py CHANGED
@@ -15,7 +15,7 @@ _sym_db = _symbol_database.Default()
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
 
17
17
 
18
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"A\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3')
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"Q\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3')
19
19
 
20
20
  _globals = globals()
21
21
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -25,7 +25,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
25
25
  _globals['_FAB']._serialized_start=59
26
26
  _globals['_FAB']._serialized_end=99
27
27
  _globals['_GETFABREQUEST']._serialized_start=101
28
- _globals['_GETFABREQUEST']._serialized_end=166
29
- _globals['_GETFABRESPONSE']._serialized_start=168
30
- _globals['_GETFABRESPONSE']._serialized_end=214
28
+ _globals['_GETFABREQUEST']._serialized_end=182
29
+ _globals['_GETFABRESPONSE']._serialized_start=184
30
+ _globals['_GETFABRESPONSE']._serialized_end=230
31
31
  # @@protoc_insertion_point(module_scope)
flwr/proto/fab_pb2.pyi CHANGED
@@ -36,16 +36,19 @@ class GetFabRequest(google.protobuf.message.Message):
36
36
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
37
37
  NODE_FIELD_NUMBER: builtins.int
38
38
  HASH_STR_FIELD_NUMBER: builtins.int
39
+ RUN_ID_FIELD_NUMBER: builtins.int
39
40
  @property
40
41
  def node(self) -> flwr.proto.node_pb2.Node: ...
41
42
  hash_str: typing.Text
43
+ run_id: builtins.int
42
44
  def __init__(self,
43
45
  *,
44
46
  node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
45
47
  hash_str: typing.Text = ...,
48
+ run_id: builtins.int = ...,
46
49
  ) -> None: ...
47
50
  def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
48
- def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ...
51
+ def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node","run_id",b"run_id"]) -> None: ...
49
52
  global___GetFabRequest = GetFabRequest
50
53
 
51
54
  class GetFabResponse(google.protobuf.message.Message):
flwr/proto/serverappio_pb2.py CHANGED
@@ -20,7 +20,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
20
20
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
21
21
 
22
22
 
23
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 
\x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xca\x06\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
23
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"P\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"V\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 
\x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
24
24
 
25
25
  _globals = globals()
26
26
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -32,21 +32,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
32
32
  _globals['_GETNODESRESPONSE']._serialized_start=217
33
33
  _globals['_GETNODESRESPONSE']._serialized_end=268
34
34
  _globals['_PUSHTASKINSREQUEST']._serialized_start=270
35
- _globals['_PUSHTASKINSREQUEST']._serialized_end=334
36
- _globals['_PUSHTASKINSRESPONSE']._serialized_start=336
37
- _globals['_PUSHTASKINSRESPONSE']._serialized_end=375
38
- _globals['_PULLTASKRESREQUEST']._serialized_start=377
39
- _globals['_PULLTASKRESREQUEST']._serialized_end=447
40
- _globals['_PULLTASKRESRESPONSE']._serialized_start=449
41
- _globals['_PULLTASKRESRESPONSE']._serialized_end=514
42
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=516
43
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=544
44
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=546
45
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=673
46
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=675
47
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=758
48
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=760
49
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=790
50
- _globals['_SERVERAPPIO']._serialized_start=793
51
- _globals['_SERVERAPPIO']._serialized_end=1635
35
+ _globals['_PUSHTASKINSREQUEST']._serialized_end=350
36
+ _globals['_PUSHTASKINSRESPONSE']._serialized_start=352
37
+ _globals['_PUSHTASKINSRESPONSE']._serialized_end=391
38
+ _globals['_PULLTASKRESREQUEST']._serialized_start=393
39
+ _globals['_PULLTASKRESREQUEST']._serialized_end=479
40
+ _globals['_PULLTASKRESRESPONSE']._serialized_start=481
41
+ _globals['_PULLTASKRESRESPONSE']._serialized_end=546
42
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=548
43
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=576
44
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=578
45
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=705
46
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=707
47
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=790
48
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=792
49
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=822
50
+ _globals['_SERVERAPPIO']._serialized_start=825
51
+ _globals['_SERVERAPPIO']._serialized_end=1752
52
52
  # @@protoc_insertion_point(module_scope)
flwr/proto/serverappio_pb2.pyi CHANGED
@@ -44,13 +44,16 @@ class PushTaskInsRequest(google.protobuf.message.Message):
44
44
  """PushTaskIns messages"""
45
45
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
46
46
  TASK_INS_LIST_FIELD_NUMBER: builtins.int
47
+ RUN_ID_FIELD_NUMBER: builtins.int
47
48
  @property
48
49
  def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ...
50
+ run_id: builtins.int
49
51
  def __init__(self,
50
52
  *,
51
53
  task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ...,
54
+ run_id: builtins.int = ...,
52
55
  ) -> None: ...
53
- def ClearField(self, field_name: typing_extensions.Literal["task_ins_list",b"task_ins_list"]) -> None: ...
56
+ def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","task_ins_list",b"task_ins_list"]) -> None: ...
54
57
  global___PushTaskInsRequest = PushTaskInsRequest
55
58
 
56
59
  class PushTaskInsResponse(google.protobuf.message.Message):
@@ -70,17 +73,20 @@ class PullTaskResRequest(google.protobuf.message.Message):
70
73
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
71
74
  NODE_FIELD_NUMBER: builtins.int
72
75
  TASK_IDS_FIELD_NUMBER: builtins.int
76
+ RUN_ID_FIELD_NUMBER: builtins.int
73
77
  @property
74
78
  def node(self) -> flwr.proto.node_pb2.Node: ...
75
79
  @property
76
80
  def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
81
+ run_id: builtins.int
77
82
  def __init__(self,
78
83
  *,
79
84
  node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
80
85
  task_ids: typing.Optional[typing.Iterable[typing.Text]] = ...,
86
+ run_id: builtins.int = ...,
81
87
  ) -> None: ...
82
88
  def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
83
- def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ...
89
+ def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id","task_ids",b"task_ids"]) -> None: ...
84
90
  global___PullTaskResRequest = PullTaskResRequest
85
91
 
86
92
  class PullTaskResResponse(google.protobuf.message.Message):
flwr/proto/serverappio_pb2_grpc.py CHANGED
@@ -62,6 +62,11 @@ class ServerAppIoStub(object):
62
62
  request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
63
63
  response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
64
64
  )
65
+ self.GetRunStatus = channel.unary_unary(
66
+ '/flwr.proto.ServerAppIo/GetRunStatus',
67
+ request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
68
+ response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
69
+ )
65
70
  self.PushLogs = channel.unary_unary(
66
71
  '/flwr.proto.ServerAppIo/PushLogs',
67
72
  request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString,
@@ -135,6 +140,13 @@ class ServerAppIoServicer(object):
135
140
  context.set_details('Method not implemented!')
136
141
  raise NotImplementedError('Method not implemented!')
137
142
 
143
+ def GetRunStatus(self, request, context):
144
+ """Get the status of a given run
145
+ """
146
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
147
+ context.set_details('Method not implemented!')
148
+ raise NotImplementedError('Method not implemented!')
149
+
138
150
  def PushLogs(self, request, context):
139
151
  """Push ServerApp logs
140
152
  """
@@ -190,6 +202,11 @@ def add_ServerAppIoServicer_to_server(servicer, server):
190
202
  request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
191
203
  response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
192
204
  ),
205
+ 'GetRunStatus': grpc.unary_unary_rpc_method_handler(
206
+ servicer.GetRunStatus,
207
+ request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
208
+ response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
209
+ ),
193
210
  'PushLogs': grpc.unary_unary_rpc_method_handler(
194
211
  servicer.PushLogs,
195
212
  request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString,
@@ -358,6 +375,23 @@ class ServerAppIo(object):
358
375
  options, channel_credentials,
359
376
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
360
377
 
378
+ @staticmethod
379
+ def GetRunStatus(request,
380
+ target,
381
+ options=(),
382
+ channel_credentials=None,
383
+ call_credentials=None,
384
+ insecure=False,
385
+ compression=None,
386
+ wait_for_ready=None,
387
+ timeout=None,
388
+ metadata=None):
389
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/GetRunStatus',
390
+ flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
391
+ flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
392
+ options, channel_credentials,
393
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
394
+
361
395
  @staticmethod
362
396
  def PushLogs(request,
363
397
  target,