flwr-nightly 1.10.0.dev20240624__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Note: the registry flags this version of flwr-nightly as potentially problematic.

Files changed (95)
  1. flwr/cli/build.py +18 -4
  2. flwr/cli/config_utils.py +36 -14
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +31 -20
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  25. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  32. flwr/cli/run/run.py +135 -51
  33. flwr/client/__init__.py +2 -0
  34. flwr/client/app.py +63 -26
  35. flwr/client/client_app.py +49 -4
  36. flwr/client/grpc_adapter_client/connection.py +3 -2
  37. flwr/client/grpc_client/connection.py +3 -2
  38. flwr/client/grpc_rere_client/connection.py +17 -6
  39. flwr/client/message_handler/message_handler.py +3 -4
  40. flwr/client/node_state.py +60 -10
  41. flwr/client/node_state_tests.py +4 -3
  42. flwr/client/rest_client/connection.py +19 -8
  43. flwr/client/supernode/app.py +60 -21
  44. flwr/client/typing.py +1 -0
  45. flwr/common/config.py +87 -2
  46. flwr/common/constant.py +6 -0
  47. flwr/common/context.py +26 -1
  48. flwr/common/logger.py +38 -0
  49. flwr/common/message.py +0 -17
  50. flwr/common/serde.py +45 -0
  51. flwr/common/telemetry.py +17 -0
  52. flwr/common/typing.py +5 -0
  53. flwr/proto/common_pb2.py +36 -0
  54. flwr/proto/common_pb2.pyi +121 -0
  55. flwr/proto/common_pb2_grpc.py +4 -0
  56. flwr/proto/common_pb2_grpc.pyi +4 -0
  57. flwr/proto/driver_pb2.py +24 -19
  58. flwr/proto/driver_pb2.pyi +21 -1
  59. flwr/proto/exec_pb2.py +16 -11
  60. flwr/proto/exec_pb2.pyi +22 -1
  61. flwr/proto/run_pb2.py +12 -7
  62. flwr/proto/run_pb2.pyi +22 -1
  63. flwr/proto/task_pb2.py +7 -8
  64. flwr/server/__init__.py +2 -0
  65. flwr/server/compat/legacy_context.py +5 -4
  66. flwr/server/driver/grpc_driver.py +82 -140
  67. flwr/server/run_serverapp.py +40 -15
  68. flwr/server/server_app.py +56 -10
  69. flwr/server/serverapp_components.py +52 -0
  70. flwr/server/superlink/driver/driver_servicer.py +18 -3
  71. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  72. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  73. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  74. flwr/server/superlink/fleet/vce/vce_api.py +149 -122
  75. flwr/server/superlink/state/in_memory_state.py +15 -7
  76. flwr/server/superlink/state/sqlite_state.py +27 -12
  77. flwr/server/superlink/state/state.py +7 -2
  78. flwr/server/superlink/state/utils.py +6 -0
  79. flwr/server/typing.py +2 -0
  80. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  81. flwr/simulation/app.py +52 -36
  82. flwr/simulation/ray_transport/ray_actor.py +15 -19
  83. flwr/simulation/ray_transport/ray_client_proxy.py +33 -13
  84. flwr/simulation/run_simulation.py +237 -66
  85. flwr/superexec/app.py +14 -7
  86. flwr/superexec/deployment.py +186 -0
  87. flwr/superexec/exec_grpc.py +5 -1
  88. flwr/superexec/exec_servicer.py +4 -1
  89. flwr/superexec/executor.py +18 -0
  90. flwr/superexec/simulation.py +151 -0
  91. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
  92. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +95 -88
  93. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
  94. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
  95. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/server/typing.py CHANGED
@@ -20,6 +20,8 @@ from typing import Callable
 from flwr.common import Context

 from .driver import Driver
+from .serverapp_components import ServerAppComponents

 ServerAppCallable = Callable[[Driver, Context], None]
 Workflow = Callable[[Driver, Context], None]
+ServerFn = Callable[[Context], ServerAppComponents]
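
The new `ServerFn` type pairs with the `ServerAppComponents` container added in `flwr/server/serverapp_components.py` (see the file list above). A minimal sketch of a `server_fn` written against this release; the exact `ServerAppComponents` fields (`strategy`, `config`) and the `ServerApp(server_fn=...)` argument are assumptions based on the other files touched in this diff, not code shown here:

from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg


def server_fn(context: Context) -> ServerAppComponents:
    # Build the strategy and run config per run, optionally using values
    # from `context`.
    strategy = FedAvg()
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


app = ServerApp(server_fn=server_fn)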
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py CHANGED
@@ -81,6 +81,7 @@ class WorkflowState:  # pylint: disable=R0902
     forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
     aggregate_ndarrays: NDArrays = field(default_factory=list)
     legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
+    failures: List[Exception] = field(default_factory=list)


 class SecAggPlusWorkflow:
@@ -394,6 +395,7 @@
 
         for msg in msgs:
             if msg.has_error():
+                state.failures.append(Exception(msg.error))
                 continue
             key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             node_id = msg.metadata.src_node_id
@@ -451,6 +453,9 @@
             nid: [] for nid in state.active_node_ids
         }  # dest node ID -> list of src node IDs
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             node_id = msg.metadata.src_node_id
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
@@ -515,6 +520,9 @@
         # Sum collected masked vectors and compute active/dead node IDs
         masked_vector = None
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
             client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
@@ -528,6 +536,9 @@
 
         # Backward compatibility with Strategy
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             fitres = compat.recordset_to_fitres(msg.content, True)
             proxy = state.nid_to_proxies[msg.metadata.src_node_id]
             state.legacy_results.append((proxy, fitres))
@@ -584,6 +595,9 @@
         for nid in state.sampled_node_ids:
             collected_shares_dict[nid] = []
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
             shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
@@ -652,9 +666,11 @@
             INFO,
             "aggregate_fit: received %s results and %s failures",
             len(results),
-            0,
+            len(state.failures),
+        )
+        aggregated_result = context.strategy.aggregate_fit(
+            current_round, results, state.failures  # type: ignore
         )
-        aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
         parameters_aggregated, metrics_aggregated = aggregated_result
 
         # Update the parameters and write history
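
Every per-stage loop above now follows the same pattern: replies that carry an error are wrapped in `Exception` objects, collected in `state.failures`, and finally handed to `Strategy.aggregate_fit` instead of the empty list passed before. An illustrative sketch of that pattern (not code from the diff; `Message.has_error()` and `Message.error` are the existing Flower message API):

from typing import List, Tuple

from flwr.common import Message


def split_replies(msgs: List[Message]) -> Tuple[List[Message], List[Exception]]:
    """Split replies into successful results and collected failures."""
    results: List[Message] = []
    failures: List[Exception] = []
    for msg in msgs:
        if msg.has_error():
            # Mirror the workflow above: record the failure and keep going.
            failures.append(Exception(msg.error))
        else:
            results.append(msg)
    return results, failures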
flwr/simulation/app.py CHANGED
@@ -27,14 +27,16 @@ from typing import Any, Dict, List, Optional, Type, Union
 import ray
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

-from flwr.client import ClientFn
+from flwr.client import ClientFnExt
 from flwr.common import EventType, event
-from flwr.common.logger import log, set_logger_propagation
+from flwr.common.constant import NODE_ID_NUM_BYTES
+from flwr.common.logger import log, set_logger_propagation, warn_unsupported_feature
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History
 from flwr.server.server import Server, init_defaults, run_fl
 from flwr.server.server_config import ServerConfig
 from flwr.server.strategy import Strategy
+from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
 from flwr.simulation.ray_transport.ray_actor import (
     ClientAppActor,
     VirtualClientEngineActor,
@@ -51,7 +53,7 @@ Invalid Arguments in method:
 `start_simulation(
     *,
     client_fn: ClientFn,
-    num_clients: Optional[int] = None,
+    num_clients: int,
     clients_ids: Optional[List[str]] = None,
     client_resources: Optional[Dict[str, float]] = None,
     server: Optional[Server] = None,
@@ -70,13 +72,29 @@ REASON:

 """

+NodeToPartitionMapping = Dict[int, int]
+
+
+def _create_node_id_to_partition_mapping(
+    num_clients: int,
+) -> NodeToPartitionMapping:
+    """Generate a node_id:partition_id mapping."""
+    nodes_mapping: NodeToPartitionMapping = {}  # {node-id; partition-id}
+    for i in range(num_clients):
+        while True:
+            node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
+            if node_id not in nodes_mapping:
+                break
+        nodes_mapping[node_id] = i
+    return nodes_mapping
+

 # pylint: disable=too-many-arguments,too-many-statements,too-many-branches
 def start_simulation(
     *,
-    client_fn: ClientFn,
-    num_clients: Optional[int] = None,
-    clients_ids: Optional[List[str]] = None,
+    client_fn: ClientFnExt,
+    num_clients: int,
+    clients_ids: Optional[List[str]] = None,  # UNSUPPORTED, WILL BE REMOVED
     client_resources: Optional[Dict[str, float]] = None,
     server: Optional[Server] = None,
     config: Optional[ServerConfig] = None,
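
An aside on the helper just added: it assigns every simulated client a random, globally unique node ID while keeping a dense partition index. Hypothetical usage, relying only on the function body shown above (the helper is private to flwr/simulation/app.py):

from flwr.simulation.app import _create_node_id_to_partition_mapping

mapping = _create_node_id_to_partition_mapping(num_clients=10)
assert len(mapping) == 10                           # one unique node ID per client
assert sorted(mapping.values()) == list(range(10))  # partition IDs stay dense: 0..9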
@@ -92,23 +110,24 @@ def start_simulation(

     Parameters
     ----------
-    client_fn : ClientFn
-        A function creating client instances. The function must take a single
-        `str` argument called `cid`. It should return a single client instance
-        of type Client. Note that the created client instances are ephemeral
-        and will often be destroyed after a single method invocation. Since client
-        instances are not long-lived, they should not attempt to carry state over
-        method invocations. Any state required by the instance (model, dataset,
-        hyperparameters, ...) should be (re-)created in either the call to `client_fn`
-        or the call to any of the client methods (e.g., load evaluation data in the
-        `evaluate` method itself).
-    num_clients : Optional[int]
-        The total number of clients in this simulation. This must be set if
-        `clients_ids` is not set and vice-versa.
+    client_fn : ClientFnExt
+        A function creating `Client` instances. The function must have the signature
+        `client_fn(context: Context). It should return
+        a single client instance of type `Client`. Note that the created client
+        instances are ephemeral and will often be destroyed after a single method
+        invocation. Since client instances are not long-lived, they should not attempt
+        to carry state over method invocations. Any state required by the instance
+        (model, dataset, hyperparameters, ...) should be (re-)created in either the
+        call to `client_fn` or the call to any of the client methods (e.g., load
+        evaluation data in the `evaluate` method itself).
+    num_clients : int
+        The total number of clients in this simulation.
     clients_ids : Optional[List[str]]
+        UNSUPPORTED, WILL BE REMOVED. USE `num_clients` INSTEAD.
         List `client_id`s for each client. This is only required if
         `num_clients` is not set. Setting both `num_clients` and `clients_ids`
        with `len(clients_ids)` not equal to `num_clients` generates an error.
+        Using this argument will raise an error.
     client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`)
         CPU and GPU resources for a single client. Supported keys
         are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by
@@ -158,7 +177,6 @@
         is an advanced feature. For all details, please refer to the Ray documentation:
         https://docs.ray.io/en/latest/ray-core/scheduling/index.html

-
     Returns
     -------
     hist : flwr.server.history.History
@@ -170,6 +188,14 @@
         {"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
     )

+    if clients_ids is not None:
+        warn_unsupported_feature(
+            "Passing `clients_ids` to `start_simulation` is deprecated and not longer "
+            "used by `start_simulation`. Use `num_clients` exclusively instead."
+        )
+        log(ERROR, "`clients_ids` argument used.")
+        sys.exit()
+
     # Set logger propagation
     loop: Optional[asyncio.AbstractEventLoop] = None
     try:
@@ -196,20 +222,8 @@
         initialized_config,
     )

-    # clients_ids takes precedence
-    cids: List[str]
-    if clients_ids is not None:
-        if (num_clients is not None) and (len(clients_ids) != num_clients):
-            log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
-            sys.exit()
-        else:
-            cids = clients_ids
-    else:
-        if num_clients is None:
-            log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
-            sys.exit()
-        else:
-            cids = [str(x) for x in range(num_clients)]
+    # Create node-id to partition-id mapping
+    nodes_mapping = _create_node_id_to_partition_mapping(num_clients)

     # Default arguments for Ray initialization
     if not ray_init_args:
@@ -308,10 +322,12 @@
     )

     # Register one RayClientProxy object for each client with the ClientManager
-    for cid in cids:
+    for node_id, partition_id in nodes_mapping.items():
         client_proxy = RayActorClientProxy(
             client_fn=client_fn,
-            cid=cid,
+            node_id=node_id,
+            partition_id=partition_id,
+            num_partitions=num_clients,
             actor_pool=pool,
         )
         initialized_server.client_manager().register(client=client_proxy)
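
Taken together, `start_simulation` now requires `num_clients`, rejects `clients_ids`, and expects a `client_fn` that receives a `Context` (the `ClientFnExt` signature documented above). A minimal sketch of a call against the new signature; the dummy client and its NumPy parameters are placeholders, not code from this release:

import numpy as np

from flwr.client import NumPyClient
from flwr.common import Context
from flwr.simulation import start_simulation


class DummyClient(NumPyClient):
    def get_parameters(self, config):
        return [np.zeros(3)]

    def fit(self, parameters, config):
        return parameters, 1, {}

    def evaluate(self, parameters, config):
        return 0.0, 1, {}


def client_fn(context: Context):
    # `context` replaces the former `cid: str` argument.
    return DummyClient().to_client()


history = start_simulation(client_fn=client_fn, num_clients=10)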
flwr/simulation/ray_transport/ray_actor.py CHANGED
@@ -14,7 +14,6 @@
 # ==============================================================================
 """Ray-based Flower Actor and ActorPool implementation."""

-import asyncio
 import threading
 from abc import ABC
 from logging import DEBUG, ERROR, WARNING
@@ -411,9 +410,7 @@ class BasicActorPool:
         self.client_resources = client_resources

         # Queue of idle actors
-        self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue(
-            maxsize=1024
-        )
+        self.pool: List[VirtualClientEngineActor] = []
         self.num_actors = 0

         # Resolve arguments to pass during actor init
@@ -427,38 +424,37 @@
         # Figure out how many actors can be created given the cluster resources
         # and the resources the user indicates each VirtualClient will need
         self.actors_capacity = pool_size_from_resources(client_resources)
-        self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {}
+        self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {}

     def is_actor_available(self) -> bool:
         """Return true if there is an idle actor."""
-        return self.pool.qsize() > 0
+        return len(self.pool) > 0

-    async def add_actors_to_pool(self, num_actors: int) -> None:
+    def add_actors_to_pool(self, num_actors: int) -> None:
         """Add actors to the pool.

         This method may be executed also if new resources are added to your Ray cluster
         (e.g. you add a new node).
         """
         for _ in range(num_actors):
-            await self.pool.put(self.create_actor_fn())  # type: ignore
+            self.pool.append(self.create_actor_fn())  # type: ignore
         self.num_actors += num_actors

-    async def terminate_all_actors(self) -> None:
+    def terminate_all_actors(self) -> None:
         """Terminate actors in pool."""
         num_terminated = 0
-        while self.pool.qsize():
-            actor = await self.pool.get()
+        for actor in self.pool:
             actor.terminate.remote()  # type: ignore
             num_terminated += 1

         log(DEBUG, "Terminated %i actors", num_terminated)

-    async def submit(
+    def submit(
         self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
     ) -> Any:
         """On idle actor, submit job and return future."""
         # Remove idle actor from pool
-        actor = await self.pool.get()
+        actor = self.pool.pop()
         # Submit job to actor
         app_fn, mssg, cid, context = job
         future = actor_fn(actor, app_fn, mssg, cid, context)
@@ -467,18 +463,18 @@
         self._future_to_actor[future] = actor
         return future

-    async def add_actor_back_to_pool(self, future: Any) -> None:
+    def add_actor_back_to_pool(self, future: Any) -> None:
         """Ad actor assigned to run future back into the pool."""
         actor = self._future_to_actor.pop(future)
-        await self.pool.put(actor)
+        self.pool.append(actor)

-    async def fetch_result_and_return_actor_to_pool(
+    def fetch_result_and_return_actor_to_pool(
         self, future: Any
     ) -> Tuple[Message, Context]:
         """Pull result given a future and add actor back to pool."""
-        # Get actor that ran job
-        await self.add_actor_back_to_pool(future)
         # Retrieve result for object store
         # Instead of doing ray.get(future) we await it
-        _, out_mssg, updated_context = await future
+        _, out_mssg, updated_context = ray.get(future)
+        # Get actor that ran job
+        self.add_actor_back_to_pool(future)
         return out_mssg, updated_context
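
The pool thus drops its `asyncio.Queue` and becomes a plain synchronous list: `pop()` hands an idle actor to a job, and the actor is only appended back after its result has been fetched with `ray.get`. A generic sketch of that pattern (illustrative only, not the Flower API):

from typing import Any, Callable, Dict, List


class ListBackedPool:
    """Minimal list-backed worker pool mirroring the pattern above."""

    def __init__(self, workers: List[Any]) -> None:
        self.idle: List[Any] = list(workers)
        self._future_to_worker: Dict[Any, Any] = {}

    def submit(self, run: Callable[[Any], Any]) -> Any:
        worker = self.idle.pop()             # take an idle worker
        future = run(worker)                 # launch the job (e.g. a Ray task)
        self._future_to_worker[future] = worker
        return future

    def fetch(self, future: Any, get: Callable[[Any], Any]) -> Any:
        result = get(future)                 # e.g. ray.get(future)
        self.idle.append(self._future_to_worker.pop(future))  # return worker
        return result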
flwr/simulation/ray_transport/ray_client_proxy.py CHANGED
@@ -20,11 +20,16 @@ from logging import ERROR
 from typing import Optional

 from flwr import common
-from flwr.client import ClientFn
+from flwr.client import ClientFnExt
 from flwr.client.client_app import ClientApp
 from flwr.client.node_state import NodeState
 from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
-from flwr.common.constant import MessageType, MessageTypeLegacy
+from flwr.common.constant import (
+    NUM_PARTITIONS_KEY,
+    PARTITION_ID_KEY,
+    MessageType,
+    MessageTypeLegacy,
+)
 from flwr.common.logger import log
 from flwr.common.recordset_compat import (
     evaluateins_to_recordset,
@@ -43,17 +48,30 @@ from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
 class RayActorClientProxy(ClientProxy):
     """Flower client proxy which delegates work using Ray."""

-    def __init__(
-        self, client_fn: ClientFn, cid: str, actor_pool: VirtualClientEngineActorPool
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        client_fn: ClientFnExt,
+        node_id: int,
+        partition_id: int,
+        num_partitions: int,
+        actor_pool: VirtualClientEngineActorPool,
     ):
-        super().__init__(cid)
+        super().__init__(cid=str(node_id))
+        self.node_id = node_id
+        self.partition_id = partition_id

         def _load_app() -> ClientApp:
             return ClientApp(client_fn=client_fn)

         self.app_fn = _load_app
         self.actor_pool = actor_pool
-        self.proxy_state = NodeState()
+        self.proxy_state = NodeState(
+            node_id=node_id,
+            node_config={
+                PARTITION_ID_KEY: str(partition_id),
+                NUM_PARTITIONS_KEY: str(num_partitions),
+            },
+        )

     def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
         """Sumbit a message to the ActorPool."""
@@ -62,16 +80,19 @@ class RayActorClientProxy(ClientProxy):
         # Register state
         self.proxy_state.register_context(run_id=run_id)

-        # Retrieve state
-        state = self.proxy_state.retrieve_context(run_id=run_id)
+        # Retrieve context
+        context = self.proxy_state.retrieve_context(run_id=run_id)
+        partition_id_str = str(context.node_config[PARTITION_ID_KEY])

         try:
             self.actor_pool.submit_client_job(
-                lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
-                (self.app_fn, message, self.cid, state),
+                lambda a, a_fn, mssg, partition_id, context: a.run.remote(
+                    a_fn, mssg, partition_id, context
+                ),
+                (self.app_fn, message, partition_id_str, context),
             )
             out_mssg, updated_context = self.actor_pool.get_client_result(
-                self.cid, timeout
+                partition_id_str, timeout
             )

             # Update state
@@ -103,11 +124,10 @@
                 message_id="",
                 group_id=str(group_id) if group_id is not None else "",
                 src_node_id=0,
-                dst_node_id=int(self.cid),
+                dst_node_id=self.node_id,
                 reply_to_message="",
                 ttl=timeout if timeout else DEFAULT_TTL,
                 message_type=message_type,
-                partition_id=int(self.cid),
             ),
         )
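
With the proxy seeding `NodeState` this way, partition information reaches the client through `context.node_config` instead of a `cid` string. A sketch of how a `client_fn` could consume it; the key constants come from `flwr.common.constant` as imported above, and the placeholder client is not part of this release:

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from flwr.common.constant import NUM_PARTITIONS_KEY, PARTITION_ID_KEY


def client_fn(context: Context):
    # Values are stored as strings by the proxy above, hence the int() casts.
    partition_id = int(context.node_config[PARTITION_ID_KEY])
    num_partitions = int(context.node_config[NUM_PARTITIONS_KEY])
    # ... load the data shard `partition_id` out of `num_partitions` here ...
    return NumPyClient().to_client()


app = ClientApp(client_fn=client_fn)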