flwr-nightly 1.9.0.dev20240531__py3-none-any.whl → 1.10.0.dev20240612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +2 -15
- flwr/cli/config_utils.py +11 -4
- flwr/cli/install.py +196 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/utils.py +14 -0
- flwr/client/__init__.py +1 -0
- flwr/client/app.py +135 -97
- flwr/client/client_app.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -4
- flwr/client/mod/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -2
- flwr/client/supernode/app.py +29 -5
- flwr/common/object_ref.py +13 -9
- flwr/proto/driver_pb2.py +20 -19
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/fleet_pb2.py +28 -33
- flwr/proto/fleet_pb2.pyi +0 -42
- flwr/proto/fleet_pb2_grpc.py +7 -6
- flwr/proto/fleet_pb2_grpc.pyi +5 -4
- flwr/proto/run_pb2.py +30 -0
- flwr/proto/run_pb2.pyi +52 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +57 -214
- flwr/server/run_serverapp.py +29 -5
- flwr/server/server_app.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/RECORD +46 -41
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/entry_points.txt +0 -2
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/WHEEL +0 -0
flwr/client/app.py
CHANGED
|
@@ -14,8 +14,10 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower client app."""
|
|
16
16
|
|
|
17
|
+
import signal
|
|
17
18
|
import sys
|
|
18
19
|
import time
|
|
20
|
+
from dataclasses import dataclass
|
|
19
21
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
20
22
|
from typing import Callable, ContextManager, Optional, Tuple, Type, Union
|
|
21
23
|
|
|
@@ -37,7 +39,7 @@ from flwr.common.constant import (
|
|
|
37
39
|
)
|
|
38
40
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
39
41
|
from flwr.common.message import Error
|
|
40
|
-
from flwr.common.retry_invoker import RetryInvoker, exponential
|
|
42
|
+
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
41
43
|
|
|
42
44
|
from .grpc_client.connection import grpc_connection
|
|
43
45
|
from .grpc_rere_client.connection import grpc_request_response
|
|
@@ -263,6 +265,29 @@ def _start_client_internal(
|
|
|
263
265
|
transport, server_address
|
|
264
266
|
)
|
|
265
267
|
|
|
268
|
+
app_state_tracker = _AppStateTracker()
|
|
269
|
+
|
|
270
|
+
def _on_sucess(retry_state: RetryState) -> None:
|
|
271
|
+
app_state_tracker.is_connected = True
|
|
272
|
+
if retry_state.tries > 1:
|
|
273
|
+
log(
|
|
274
|
+
INFO,
|
|
275
|
+
"Connection successful after %.2f seconds and %s tries.",
|
|
276
|
+
retry_state.elapsed_time,
|
|
277
|
+
retry_state.tries,
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
def _on_backoff(retry_state: RetryState) -> None:
|
|
281
|
+
app_state_tracker.is_connected = False
|
|
282
|
+
if retry_state.tries == 1:
|
|
283
|
+
log(WARN, "Connection attempt failed, retrying...")
|
|
284
|
+
else:
|
|
285
|
+
log(
|
|
286
|
+
DEBUG,
|
|
287
|
+
"Connection attempt failed, retrying in %.2f seconds",
|
|
288
|
+
retry_state.actual_wait,
|
|
289
|
+
)
|
|
290
|
+
|
|
266
291
|
retry_invoker = RetryInvoker(
|
|
267
292
|
wait_gen_factory=exponential,
|
|
268
293
|
recoverable_exceptions=connection_error_type,
|
|
@@ -278,30 +303,13 @@ def _start_client_internal(
|
|
|
278
303
|
if retry_state.tries > 1
|
|
279
304
|
else None
|
|
280
305
|
),
|
|
281
|
-
on_success=
|
|
282
|
-
|
|
283
|
-
INFO,
|
|
284
|
-
"Connection successful after %.2f seconds and %s tries.",
|
|
285
|
-
retry_state.elapsed_time,
|
|
286
|
-
retry_state.tries,
|
|
287
|
-
)
|
|
288
|
-
if retry_state.tries > 1
|
|
289
|
-
else None
|
|
290
|
-
),
|
|
291
|
-
on_backoff=lambda retry_state: (
|
|
292
|
-
log(WARN, "Connection attempt failed, retrying...")
|
|
293
|
-
if retry_state.tries == 1
|
|
294
|
-
else log(
|
|
295
|
-
DEBUG,
|
|
296
|
-
"Connection attempt failed, retrying in %.2f seconds",
|
|
297
|
-
retry_state.actual_wait,
|
|
298
|
-
)
|
|
299
|
-
),
|
|
306
|
+
on_success=_on_sucess,
|
|
307
|
+
on_backoff=_on_backoff,
|
|
300
308
|
)
|
|
301
309
|
|
|
302
310
|
node_state = NodeState()
|
|
303
311
|
|
|
304
|
-
while
|
|
312
|
+
while not app_state_tracker.interrupt:
|
|
305
313
|
sleep_duration: int = 0
|
|
306
314
|
with connection(
|
|
307
315
|
address,
|
|
@@ -318,99 +326,112 @@ def _start_client_internal(
|
|
|
318
326
|
if create_node is not None:
|
|
319
327
|
create_node() # pylint: disable=not-callable
|
|
320
328
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
329
|
+
app_state_tracker.register_signal_handler()
|
|
330
|
+
while not app_state_tracker.interrupt:
|
|
331
|
+
try:
|
|
332
|
+
# Receive
|
|
333
|
+
message = receive()
|
|
334
|
+
if message is None:
|
|
335
|
+
time.sleep(3) # Wait for 3s before asking again
|
|
336
|
+
continue
|
|
337
|
+
|
|
338
|
+
log(INFO, "")
|
|
339
|
+
if len(message.metadata.group_id) > 0:
|
|
340
|
+
log(
|
|
341
|
+
INFO,
|
|
342
|
+
"[RUN %s, ROUND %s]",
|
|
343
|
+
message.metadata.run_id,
|
|
344
|
+
message.metadata.group_id,
|
|
345
|
+
)
|
|
330
346
|
log(
|
|
331
347
|
INFO,
|
|
332
|
-
"
|
|
333
|
-
message.metadata.
|
|
334
|
-
message.metadata.
|
|
348
|
+
"Received: %s message %s",
|
|
349
|
+
message.metadata.message_type,
|
|
350
|
+
message.metadata.message_id,
|
|
335
351
|
)
|
|
336
|
-
log(
|
|
337
|
-
INFO,
|
|
338
|
-
"Received: %s message %s",
|
|
339
|
-
message.metadata.message_type,
|
|
340
|
-
message.metadata.message_id,
|
|
341
|
-
)
|
|
342
|
-
|
|
343
|
-
# Handle control message
|
|
344
|
-
out_message, sleep_duration = handle_control_message(message)
|
|
345
|
-
if out_message:
|
|
346
|
-
send(out_message)
|
|
347
|
-
break
|
|
348
|
-
|
|
349
|
-
# Register context for this run
|
|
350
|
-
node_state.register_context(run_id=message.metadata.run_id)
|
|
351
|
-
|
|
352
|
-
# Retrieve context for this run
|
|
353
|
-
context = node_state.retrieve_context(run_id=message.metadata.run_id)
|
|
354
352
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
353
|
+
# Handle control message
|
|
354
|
+
out_message, sleep_duration = handle_control_message(message)
|
|
355
|
+
if out_message:
|
|
356
|
+
send(out_message)
|
|
357
|
+
break
|
|
360
358
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
# Load ClientApp instance
|
|
364
|
-
client_app: ClientApp = load_client_app_fn()
|
|
365
|
-
|
|
366
|
-
# Execute ClientApp
|
|
367
|
-
reply_message = client_app(message=message, context=context)
|
|
368
|
-
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
369
|
-
|
|
370
|
-
# Legacy grpc-bidi
|
|
371
|
-
if transport in ["grpc-bidi", None]:
|
|
372
|
-
log(ERROR, "Client raised an exception.", exc_info=ex)
|
|
373
|
-
# Raise exception, crash process
|
|
374
|
-
raise ex
|
|
375
|
-
|
|
376
|
-
# Don't update/change NodeState
|
|
377
|
-
|
|
378
|
-
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
379
|
-
# Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
|
|
380
|
-
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
381
|
-
exc_entity = "ClientApp"
|
|
382
|
-
if isinstance(ex, LoadClientAppError):
|
|
383
|
-
reason = (
|
|
384
|
-
"An exception was raised when attempting to load "
|
|
385
|
-
"`ClientApp`"
|
|
386
|
-
)
|
|
387
|
-
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
388
|
-
exc_entity = "SuperNode"
|
|
359
|
+
# Register context for this run
|
|
360
|
+
node_state.register_context(run_id=message.metadata.run_id)
|
|
389
361
|
|
|
390
|
-
|
|
362
|
+
# Retrieve context for this run
|
|
363
|
+
context = node_state.retrieve_context(
|
|
364
|
+
run_id=message.metadata.run_id
|
|
365
|
+
)
|
|
391
366
|
|
|
392
|
-
# Create error message
|
|
367
|
+
# Create an error reply message that will never be used to prevent
|
|
368
|
+
# the used-before-assignment linting error
|
|
393
369
|
reply_message = message.create_error_reply(
|
|
394
|
-
error=Error(code=
|
|
395
|
-
)
|
|
396
|
-
else:
|
|
397
|
-
# No exception, update node state
|
|
398
|
-
node_state.update_context(
|
|
399
|
-
run_id=message.metadata.run_id,
|
|
400
|
-
context=context,
|
|
370
|
+
error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
|
|
401
371
|
)
|
|
402
372
|
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
373
|
+
# Handle app loading and task message
|
|
374
|
+
try:
|
|
375
|
+
# Load ClientApp instance
|
|
376
|
+
client_app: ClientApp = load_client_app_fn()
|
|
377
|
+
|
|
378
|
+
# Execute ClientApp
|
|
379
|
+
reply_message = client_app(message=message, context=context)
|
|
380
|
+
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
381
|
+
|
|
382
|
+
# Legacy grpc-bidi
|
|
383
|
+
if transport in ["grpc-bidi", None]:
|
|
384
|
+
log(ERROR, "Client raised an exception.", exc_info=ex)
|
|
385
|
+
# Raise exception, crash process
|
|
386
|
+
raise ex
|
|
387
|
+
|
|
388
|
+
# Don't update/change NodeState
|
|
389
|
+
|
|
390
|
+
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
391
|
+
# Ex fmt: "<class 'ZeroDivisionError'>:<'division by zero'>"
|
|
392
|
+
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
393
|
+
exc_entity = "ClientApp"
|
|
394
|
+
if isinstance(ex, LoadClientAppError):
|
|
395
|
+
reason = (
|
|
396
|
+
"An exception was raised when attempting to load "
|
|
397
|
+
"`ClientApp`"
|
|
398
|
+
)
|
|
399
|
+
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
400
|
+
exc_entity = "SuperNode"
|
|
401
|
+
|
|
402
|
+
if not app_state_tracker.interrupt:
|
|
403
|
+
log(
|
|
404
|
+
ERROR, "%s raised an exception", exc_entity, exc_info=ex
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
# Create error message
|
|
408
|
+
reply_message = message.create_error_reply(
|
|
409
|
+
error=Error(code=e_code, reason=reason)
|
|
410
|
+
)
|
|
411
|
+
else:
|
|
412
|
+
# No exception, update node state
|
|
413
|
+
node_state.update_context(
|
|
414
|
+
run_id=message.metadata.run_id,
|
|
415
|
+
context=context,
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
# Send
|
|
419
|
+
send(reply_message)
|
|
420
|
+
log(INFO, "Sent reply")
|
|
421
|
+
|
|
422
|
+
except StopIteration:
|
|
423
|
+
sleep_duration = 0
|
|
424
|
+
break
|
|
406
425
|
|
|
407
426
|
# Unregister node
|
|
408
|
-
if delete_node is not None:
|
|
427
|
+
if delete_node is not None and app_state_tracker.is_connected:
|
|
409
428
|
delete_node() # pylint: disable=not-callable
|
|
410
429
|
|
|
411
430
|
if sleep_duration == 0:
|
|
412
431
|
log(INFO, "Disconnect and shut down")
|
|
432
|
+
del app_state_tracker
|
|
413
433
|
break
|
|
434
|
+
|
|
414
435
|
# Sleep and reconnect afterwards
|
|
415
436
|
log(
|
|
416
437
|
INFO,
|
|
@@ -579,3 +600,20 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
|
579
600
|
)
|
|
580
601
|
|
|
581
602
|
return connection, address, error_type
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
@dataclass
|
|
606
|
+
class _AppStateTracker:
|
|
607
|
+
interrupt: bool = False
|
|
608
|
+
is_connected: bool = False
|
|
609
|
+
|
|
610
|
+
def register_signal_handler(self) -> None:
|
|
611
|
+
"""Register handlers for exit signals."""
|
|
612
|
+
|
|
613
|
+
def signal_handler(sig, frame): # type: ignore
|
|
614
|
+
# pylint: disable=unused-argument
|
|
615
|
+
self.interrupt = True
|
|
616
|
+
raise StopIteration from None
|
|
617
|
+
|
|
618
|
+
signal.signal(signal.SIGINT, signal_handler)
|
|
619
|
+
signal.signal(signal.SIGTERM, signal_handler)
|
flwr/client/client_app.py
CHANGED
|
@@ -31,11 +31,11 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
|
31
31
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
32
|
CreateNodeRequest,
|
|
33
33
|
DeleteNodeRequest,
|
|
34
|
-
GetRunRequest,
|
|
35
34
|
PingRequest,
|
|
36
35
|
PullTaskInsRequest,
|
|
37
36
|
PushTaskResRequest,
|
|
38
37
|
)
|
|
38
|
+
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
39
39
|
|
|
40
40
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
41
41
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -44,8 +44,6 @@ from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
|
44
44
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
45
45
|
CreateNodeRequest,
|
|
46
46
|
DeleteNodeRequest,
|
|
47
|
-
GetRunRequest,
|
|
48
|
-
GetRunResponse,
|
|
49
47
|
PingRequest,
|
|
50
48
|
PingResponse,
|
|
51
49
|
PullTaskInsRequest,
|
|
@@ -53,6 +51,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
53
51
|
)
|
|
54
52
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
55
53
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
54
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
56
55
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
57
56
|
|
|
58
57
|
from .client_interceptor import AuthenticateClientInterceptor
|
|
@@ -193,8 +192,6 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
193
192
|
|
|
194
193
|
# Stop the ping-loop thread
|
|
195
194
|
ping_stop_event.set()
|
|
196
|
-
if ping_thread is not None:
|
|
197
|
-
ping_thread.join()
|
|
198
195
|
|
|
199
196
|
# Call FleetAPI
|
|
200
197
|
delete_node_request = DeleteNodeRequest(node=node)
|
flwr/client/mod/__init__.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Mods."""
|
|
15
|
+
"""Flower Built-in Mods."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
|
@@ -46,8 +46,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
46
46
|
CreateNodeResponse,
|
|
47
47
|
DeleteNodeRequest,
|
|
48
48
|
DeleteNodeResponse,
|
|
49
|
-
GetRunRequest,
|
|
50
|
-
GetRunResponse,
|
|
51
49
|
PingRequest,
|
|
52
50
|
PingResponse,
|
|
53
51
|
PullTaskInsRequest,
|
|
@@ -56,6 +54,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
56
54
|
PushTaskResResponse,
|
|
57
55
|
)
|
|
58
56
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
57
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
59
58
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
60
59
|
|
|
61
60
|
try:
|
flwr/client/supernode/app.py
CHANGED
|
@@ -30,11 +30,13 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
30
30
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
31
31
|
from flwr.common import EventType, event
|
|
32
32
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
33
|
-
from flwr.common.logger import log
|
|
33
|
+
from flwr.common.logger import log, warn_deprecated_feature
|
|
34
34
|
from flwr.common.object_ref import load_app, validate
|
|
35
35
|
|
|
36
36
|
from ..app import _start_client_internal
|
|
37
37
|
|
|
38
|
+
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
|
|
39
|
+
|
|
38
40
|
|
|
39
41
|
def run_supernode() -> None:
|
|
40
42
|
"""Run Flower SuperNode."""
|
|
@@ -63,6 +65,23 @@ def run_client_app() -> None:
|
|
|
63
65
|
|
|
64
66
|
args = _parse_args_run_client_app().parse_args()
|
|
65
67
|
|
|
68
|
+
if args.server != ADDRESS_FLEET_API_GRPC_RERE:
|
|
69
|
+
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
70
|
+
warn_deprecated_feature(warn)
|
|
71
|
+
|
|
72
|
+
if args.superlink != ADDRESS_FLEET_API_GRPC_RERE:
|
|
73
|
+
# if `--superlink` also passed, then
|
|
74
|
+
# warn user that this argument overrides what was passed with `--server`
|
|
75
|
+
log(
|
|
76
|
+
WARN,
|
|
77
|
+
"Both `--server` and `--superlink` were passed. "
|
|
78
|
+
"`--server` will be ignored. Connecting to the Superlink Fleet API "
|
|
79
|
+
"at %s.",
|
|
80
|
+
args.superlink,
|
|
81
|
+
)
|
|
82
|
+
else:
|
|
83
|
+
args.superlink = args.server
|
|
84
|
+
|
|
66
85
|
root_certificates = _get_certificates(args)
|
|
67
86
|
log(
|
|
68
87
|
DEBUG,
|
|
@@ -73,7 +92,7 @@ def run_client_app() -> None:
|
|
|
73
92
|
authentication_keys = _try_setup_client_authentication(args)
|
|
74
93
|
|
|
75
94
|
_start_client_internal(
|
|
76
|
-
server_address=args.
|
|
95
|
+
server_address=args.superlink,
|
|
77
96
|
load_client_app_fn=load_fn,
|
|
78
97
|
transport="rest" if args.rest else "grpc-rere",
|
|
79
98
|
root_certificates=root_certificates,
|
|
@@ -100,7 +119,7 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
100
119
|
WARN,
|
|
101
120
|
"Option `--insecure` was set. "
|
|
102
121
|
"Starting insecure HTTP client connected to %s.",
|
|
103
|
-
args.
|
|
122
|
+
args.superlink,
|
|
104
123
|
)
|
|
105
124
|
root_certificates = None
|
|
106
125
|
else:
|
|
@@ -114,7 +133,7 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
114
133
|
DEBUG,
|
|
115
134
|
"Starting secure HTTPS client connected to %s "
|
|
116
135
|
"with the following certificates: %s.",
|
|
117
|
-
args.
|
|
136
|
+
args.superlink,
|
|
118
137
|
cert_path,
|
|
119
138
|
)
|
|
120
139
|
return root_certificates
|
|
@@ -213,9 +232,14 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
213
232
|
)
|
|
214
233
|
parser.add_argument(
|
|
215
234
|
"--server",
|
|
216
|
-
default=
|
|
235
|
+
default=ADDRESS_FLEET_API_GRPC_RERE,
|
|
217
236
|
help="Server address",
|
|
218
237
|
)
|
|
238
|
+
parser.add_argument(
|
|
239
|
+
"--superlink",
|
|
240
|
+
default=ADDRESS_FLEET_API_GRPC_RERE,
|
|
241
|
+
help="SuperLink Fleet API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
242
|
+
)
|
|
219
243
|
parser.add_argument(
|
|
220
244
|
"--max-retries",
|
|
221
245
|
type=int,
|
flwr/common/object_ref.py
CHANGED
|
@@ -30,6 +30,7 @@ attribute.
|
|
|
30
30
|
|
|
31
31
|
def validate(
|
|
32
32
|
module_attribute_str: str,
|
|
33
|
+
check_module: bool = True,
|
|
33
34
|
) -> Tuple[bool, Optional[str]]:
|
|
34
35
|
"""Validate object reference.
|
|
35
36
|
|
|
@@ -56,15 +57,18 @@ def validate(
|
|
|
56
57
|
f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}",
|
|
57
58
|
)
|
|
58
59
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
if
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
60
|
+
if check_module:
|
|
61
|
+
# Load module
|
|
62
|
+
module = find_spec(module_str)
|
|
63
|
+
if module and module.origin:
|
|
64
|
+
if not _find_attribute_in_module(module.origin, attributes_str):
|
|
65
|
+
return (
|
|
66
|
+
False,
|
|
67
|
+
f"Unable to find attribute {attributes_str} in module {module_str}"
|
|
68
|
+
f"{OBJECT_REF_HELP_STR}",
|
|
69
|
+
)
|
|
70
|
+
return (True, None)
|
|
71
|
+
else:
|
|
68
72
|
return (True, None)
|
|
69
73
|
|
|
70
74
|
return (
|
flwr/proto/driver_pb2.py
CHANGED
|
@@ -14,31 +14,32 @@ _sym_db = _symbol_database.Default()
|
|
|
14
14
|
|
|
15
15
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
|
16
16
|
from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
|
|
17
|
+
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
17
18
|
|
|
18
19
|
|
|
19
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\
|
|
20
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3')
|
|
20
21
|
|
|
21
22
|
_globals = globals()
|
|
22
23
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
23
24
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals)
|
|
24
25
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
|
25
26
|
DESCRIPTOR._options = None
|
|
26
|
-
_globals['_CREATERUNREQUEST']._serialized_start=
|
|
27
|
-
_globals['_CREATERUNREQUEST']._serialized_end=
|
|
28
|
-
_globals['_CREATERUNRESPONSE']._serialized_start=
|
|
29
|
-
_globals['_CREATERUNRESPONSE']._serialized_end=
|
|
30
|
-
_globals['_GETNODESREQUEST']._serialized_start=
|
|
31
|
-
_globals['_GETNODESREQUEST']._serialized_end=
|
|
32
|
-
_globals['_GETNODESRESPONSE']._serialized_start=
|
|
33
|
-
_globals['_GETNODESRESPONSE']._serialized_end=
|
|
34
|
-
_globals['_PUSHTASKINSREQUEST']._serialized_start=
|
|
35
|
-
_globals['_PUSHTASKINSREQUEST']._serialized_end=
|
|
36
|
-
_globals['_PUSHTASKINSRESPONSE']._serialized_start=
|
|
37
|
-
_globals['_PUSHTASKINSRESPONSE']._serialized_end=
|
|
38
|
-
_globals['_PULLTASKRESREQUEST']._serialized_start=
|
|
39
|
-
_globals['_PULLTASKRESREQUEST']._serialized_end=
|
|
40
|
-
_globals['_PULLTASKRESRESPONSE']._serialized_start=
|
|
41
|
-
_globals['_PULLTASKRESRESPONSE']._serialized_end=
|
|
42
|
-
_globals['_DRIVER']._serialized_start=
|
|
43
|
-
_globals['_DRIVER']._serialized_end=
|
|
27
|
+
_globals['_CREATERUNREQUEST']._serialized_start=107
|
|
28
|
+
_globals['_CREATERUNREQUEST']._serialized_end=162
|
|
29
|
+
_globals['_CREATERUNRESPONSE']._serialized_start=164
|
|
30
|
+
_globals['_CREATERUNRESPONSE']._serialized_end=199
|
|
31
|
+
_globals['_GETNODESREQUEST']._serialized_start=201
|
|
32
|
+
_globals['_GETNODESREQUEST']._serialized_end=234
|
|
33
|
+
_globals['_GETNODESRESPONSE']._serialized_start=236
|
|
34
|
+
_globals['_GETNODESRESPONSE']._serialized_end=287
|
|
35
|
+
_globals['_PUSHTASKINSREQUEST']._serialized_start=289
|
|
36
|
+
_globals['_PUSHTASKINSREQUEST']._serialized_end=353
|
|
37
|
+
_globals['_PUSHTASKINSRESPONSE']._serialized_start=355
|
|
38
|
+
_globals['_PUSHTASKINSRESPONSE']._serialized_end=394
|
|
39
|
+
_globals['_PULLTASKRESREQUEST']._serialized_start=396
|
|
40
|
+
_globals['_PULLTASKRESREQUEST']._serialized_end=466
|
|
41
|
+
_globals['_PULLTASKRESRESPONSE']._serialized_start=468
|
|
42
|
+
_globals['_PULLTASKRESRESPONSE']._serialized_end=533
|
|
43
|
+
_globals['_DRIVER']._serialized_start=536
|
|
44
|
+
_globals['_DRIVER']._serialized_end=924
|
|
44
45
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/driver_pb2_grpc.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
import grpc
|
|
4
4
|
|
|
5
5
|
from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2
|
|
6
|
+
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class DriverStub(object):
|
|
@@ -34,6 +35,11 @@ class DriverStub(object):
|
|
|
34
35
|
request_serializer=flwr_dot_proto_dot_driver__pb2.PullTaskResRequest.SerializeToString,
|
|
35
36
|
response_deserializer=flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.FromString,
|
|
36
37
|
)
|
|
38
|
+
self.GetRun = channel.unary_unary(
|
|
39
|
+
'/flwr.proto.Driver/GetRun',
|
|
40
|
+
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString,
|
|
41
|
+
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString,
|
|
42
|
+
)
|
|
37
43
|
|
|
38
44
|
|
|
39
45
|
class DriverServicer(object):
|
|
@@ -67,6 +73,13 @@ class DriverServicer(object):
|
|
|
67
73
|
context.set_details('Method not implemented!')
|
|
68
74
|
raise NotImplementedError('Method not implemented!')
|
|
69
75
|
|
|
76
|
+
def GetRun(self, request, context):
|
|
77
|
+
"""Get run details
|
|
78
|
+
"""
|
|
79
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
80
|
+
context.set_details('Method not implemented!')
|
|
81
|
+
raise NotImplementedError('Method not implemented!')
|
|
82
|
+
|
|
70
83
|
|
|
71
84
|
def add_DriverServicer_to_server(servicer, server):
|
|
72
85
|
rpc_method_handlers = {
|
|
@@ -90,6 +103,11 @@ def add_DriverServicer_to_server(servicer, server):
|
|
|
90
103
|
request_deserializer=flwr_dot_proto_dot_driver__pb2.PullTaskResRequest.FromString,
|
|
91
104
|
response_serializer=flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.SerializeToString,
|
|
92
105
|
),
|
|
106
|
+
'GetRun': grpc.unary_unary_rpc_method_handler(
|
|
107
|
+
servicer.GetRun,
|
|
108
|
+
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.FromString,
|
|
109
|
+
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.SerializeToString,
|
|
110
|
+
),
|
|
93
111
|
}
|
|
94
112
|
generic_handler = grpc.method_handlers_generic_handler(
|
|
95
113
|
'flwr.proto.Driver', rpc_method_handlers)
|
|
@@ -167,3 +185,20 @@ class Driver(object):
|
|
|
167
185
|
flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.FromString,
|
|
168
186
|
options, channel_credentials,
|
|
169
187
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
188
|
+
|
|
189
|
+
@staticmethod
|
|
190
|
+
def GetRun(request,
|
|
191
|
+
target,
|
|
192
|
+
options=(),
|
|
193
|
+
channel_credentials=None,
|
|
194
|
+
call_credentials=None,
|
|
195
|
+
insecure=False,
|
|
196
|
+
compression=None,
|
|
197
|
+
wait_for_ready=None,
|
|
198
|
+
timeout=None,
|
|
199
|
+
metadata=None):
|
|
200
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/GetRun',
|
|
201
|
+
flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString,
|
|
202
|
+
flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString,
|
|
203
|
+
options, channel_credentials,
|
|
204
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
flwr/proto/driver_pb2_grpc.pyi
CHANGED
|
@@ -4,6 +4,7 @@ isort:skip_file
|
|
|
4
4
|
"""
|
|
5
5
|
import abc
|
|
6
6
|
import flwr.proto.driver_pb2
|
|
7
|
+
import flwr.proto.run_pb2
|
|
7
8
|
import grpc
|
|
8
9
|
|
|
9
10
|
class DriverStub:
|
|
@@ -28,6 +29,11 @@ class DriverStub:
|
|
|
28
29
|
flwr.proto.driver_pb2.PullTaskResResponse]
|
|
29
30
|
"""Get task results"""
|
|
30
31
|
|
|
32
|
+
GetRun: grpc.UnaryUnaryMultiCallable[
|
|
33
|
+
flwr.proto.run_pb2.GetRunRequest,
|
|
34
|
+
flwr.proto.run_pb2.GetRunResponse]
|
|
35
|
+
"""Get run details"""
|
|
36
|
+
|
|
31
37
|
|
|
32
38
|
class DriverServicer(metaclass=abc.ABCMeta):
|
|
33
39
|
@abc.abstractmethod
|
|
@@ -62,5 +68,13 @@ class DriverServicer(metaclass=abc.ABCMeta):
|
|
|
62
68
|
"""Get task results"""
|
|
63
69
|
pass
|
|
64
70
|
|
|
71
|
+
@abc.abstractmethod
|
|
72
|
+
def GetRun(self,
|
|
73
|
+
request: flwr.proto.run_pb2.GetRunRequest,
|
|
74
|
+
context: grpc.ServicerContext,
|
|
75
|
+
) -> flwr.proto.run_pb2.GetRunResponse:
|
|
76
|
+
"""Get run details"""
|
|
77
|
+
pass
|
|
78
|
+
|
|
65
79
|
|
|
66
80
|
def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ...
|