xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. xoscar/__init__.py +61 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +246 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +527 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +253 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +444 -0
  21. xoscar/backends/communication/ucx.py +538 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +157 -0
  24. xoscar/backends/context.py +437 -0
  25. xoscar/backends/core.py +352 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/__main__.py +19 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/fate_sharing.py +221 -0
  31. xoscar/backends/indigen/pool.py +515 -0
  32. xoscar/backends/indigen/shared_memory.py +548 -0
  33. xoscar/backends/message.cpython-312-darwin.so +0 -0
  34. xoscar/backends/message.pyi +255 -0
  35. xoscar/backends/message.pyx +646 -0
  36. xoscar/backends/pool.py +1630 -0
  37. xoscar/backends/router.py +285 -0
  38. xoscar/backends/test/__init__.py +16 -0
  39. xoscar/backends/test/backend.py +38 -0
  40. xoscar/backends/test/pool.py +233 -0
  41. xoscar/batch.py +256 -0
  42. xoscar/collective/__init__.py +27 -0
  43. xoscar/collective/backend/__init__.py +13 -0
  44. xoscar/collective/backend/nccl_backend.py +160 -0
  45. xoscar/collective/common.py +102 -0
  46. xoscar/collective/core.py +737 -0
  47. xoscar/collective/process_group.py +687 -0
  48. xoscar/collective/utils.py +41 -0
  49. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  50. xoscar/collective/xoscar_pygloo.pyi +239 -0
  51. xoscar/constants.py +23 -0
  52. xoscar/context.cpython-312-darwin.so +0 -0
  53. xoscar/context.pxd +21 -0
  54. xoscar/context.pyx +368 -0
  55. xoscar/core.cpython-312-darwin.so +0 -0
  56. xoscar/core.pxd +51 -0
  57. xoscar/core.pyx +664 -0
  58. xoscar/debug.py +188 -0
  59. xoscar/driver.py +42 -0
  60. xoscar/errors.py +63 -0
  61. xoscar/libcpp.pxd +31 -0
  62. xoscar/metrics/__init__.py +21 -0
  63. xoscar/metrics/api.py +288 -0
  64. xoscar/metrics/backends/__init__.py +13 -0
  65. xoscar/metrics/backends/console/__init__.py +13 -0
  66. xoscar/metrics/backends/console/console_metric.py +82 -0
  67. xoscar/metrics/backends/metric.py +149 -0
  68. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  69. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  70. xoscar/nvutils.py +717 -0
  71. xoscar/profiling.py +260 -0
  72. xoscar/serialization/__init__.py +20 -0
  73. xoscar/serialization/aio.py +141 -0
  74. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  75. xoscar/serialization/core.pxd +28 -0
  76. xoscar/serialization/core.pyi +57 -0
  77. xoscar/serialization/core.pyx +944 -0
  78. xoscar/serialization/cuda.py +111 -0
  79. xoscar/serialization/exception.py +48 -0
  80. xoscar/serialization/mlx.py +67 -0
  81. xoscar/serialization/numpy.py +82 -0
  82. xoscar/serialization/pyfury.py +37 -0
  83. xoscar/serialization/scipy.py +72 -0
  84. xoscar/serialization/torch.py +180 -0
  85. xoscar/utils.py +522 -0
  86. xoscar/virtualenv/__init__.py +34 -0
  87. xoscar/virtualenv/core.py +268 -0
  88. xoscar/virtualenv/platform.py +56 -0
  89. xoscar/virtualenv/utils.py +100 -0
  90. xoscar/virtualenv/uv.py +321 -0
  91. xoscar-0.9.0.dist-info/METADATA +230 -0
  92. xoscar-0.9.0.dist-info/RECORD +94 -0
  93. xoscar-0.9.0.dist-info/WHEEL +6 -0
  94. xoscar-0.9.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,157 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Any
19
+
20
+
21
+ class ActorPoolConfig:
22
+ __slots__ = ("_conf",)
23
+
24
+ def __init__(self, conf: dict | None = None):
25
+ if conf is None:
26
+ conf = dict()
27
+ self._conf = conf
28
+ if "pools" not in self._conf:
29
+ self._conf["pools"] = dict()
30
+ if "mapping" not in self._conf:
31
+ self._conf["mapping"] = dict()
32
+ if "metrics" not in self._conf:
33
+ self._conf["metrics"] = dict()
34
+ if "comm" not in self._conf:
35
+ self._conf["comm"] = dict()
36
+ if "proxy" not in self._conf:
37
+ self._conf["proxy"] = dict()
38
+
39
+ @property
40
+ def n_pool(self):
41
+ return len(self._conf["pools"])
42
+
43
+ def add_pool_conf(
44
+ self,
45
+ process_index: int,
46
+ label: str | None,
47
+ internal_address: str | None,
48
+ external_address: str | list[str],
49
+ env: dict | None = None,
50
+ modules: list[str] | None = None,
51
+ suspend_sigint: bool | None = False,
52
+ use_uvloop: bool | None = False,
53
+ logging_conf: dict | None = None,
54
+ kwargs: dict | None = None,
55
+ ):
56
+ pools: dict = self._conf["pools"]
57
+ if not isinstance(external_address, list):
58
+ external_address = [external_address]
59
+ pools[process_index] = {
60
+ "label": label,
61
+ "internal_address": internal_address,
62
+ "external_address": external_address,
63
+ "env": env,
64
+ "modules": modules,
65
+ "suspend_sigint": suspend_sigint,
66
+ "use_uvloop": use_uvloop,
67
+ "logging_conf": logging_conf,
68
+ "kwargs": kwargs or {},
69
+ }
70
+
71
+ mapping: dict = self._conf["mapping"]
72
+ for addr in external_address:
73
+ mapping[addr] = internal_address
74
+
75
+ def remove_pool_config(self, process_index: int):
76
+ addr = self.get_external_address(process_index)
77
+ del self._conf["pools"][process_index]
78
+ del self._conf["mapping"][addr]
79
+
80
+ def get_pool_config(self, process_index: int):
81
+ return self._conf["pools"][process_index]
82
+
83
+ def get_external_address(self, process_index: int) -> str:
84
+ return self._conf["pools"][process_index]["external_address"][0]
85
+
86
+ def get_process_indexes(self):
87
+ return list(self._conf["pools"])
88
+
89
+ def get_process_index(self, external_address: str):
90
+ for process_index, conf in self._conf["pools"].items():
91
+ if external_address in conf["external_address"]:
92
+ return process_index
93
+ raise ValueError(
94
+ f"Cannot get process_index for {external_address}"
95
+ ) # pragma: no cover
96
+
97
+ def reset_pool_external_address(
98
+ self,
99
+ process_index: int,
100
+ external_address: str | list[str],
101
+ ):
102
+ if not isinstance(external_address, list):
103
+ external_address = [external_address]
104
+ cur_pool_config = self._conf["pools"][process_index]
105
+ internal_address = cur_pool_config["internal_address"]
106
+
107
+ mapping: dict = self._conf["mapping"]
108
+ for addr in cur_pool_config["external_address"]:
109
+ if internal_address == addr:
110
+ # internal address may be the same as external address in Windows
111
+ internal_address = external_address[0]
112
+ mapping.pop(addr, None)
113
+
114
+ cur_pool_config["external_address"] = external_address
115
+ for addr in external_address:
116
+ mapping[addr] = internal_address
117
+
118
+ def get_external_addresses(self, label=None) -> list[str]:
119
+ result = []
120
+ for c in self._conf["pools"].values():
121
+ if label is not None:
122
+ if label == c["label"]:
123
+ result.append(c["external_address"][0])
124
+ else:
125
+ result.append(c["external_address"][0])
126
+ return result
127
+
128
+ @property
129
+ def external_to_internal_address_map(self) -> dict[str, str]:
130
+ return self._conf["mapping"]
131
+
132
+ def as_dict(self):
133
+ return self._conf
134
+
135
+ def add_metric_configs(self, metrics: dict[str, Any]):
136
+ if metrics:
137
+ self._conf["metrics"].update(metrics)
138
+
139
+ def get_metric_configs(self):
140
+ return self._conf["metrics"]
141
+
142
+ def add_comm_config(self, comm_config: dict[str, Any] | None):
143
+ if comm_config:
144
+ self._conf["comm"].update(comm_config)
145
+
146
+ def get_comm_config(self) -> dict:
147
+ return self._conf["comm"]
148
+
149
+ def get_proxy_config(self) -> dict:
150
+ return self._conf["proxy"]
151
+
152
+ def add_proxy_config(self, proxy_config: dict[str, str] | None):
153
+ if proxy_config:
154
+ self._conf["proxy"].update(proxy_config)
155
+
156
+ def remove_proxy(self, from_addr: str):
157
+ del self._conf["proxy"][from_addr]
@@ -0,0 +1,437 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ from dataclasses import dataclass
20
+ from typing import Any, List, Optional, Tuple, Type, Union
21
+
22
+ from .._utils import create_actor_ref, to_binary
23
+ from ..aio import AioFileObject
24
+ from ..api import Actor
25
+ from ..context import BaseActorContext
26
+ from ..core import ActorRef, BufferRef, FileObjectRef, create_local_actor_ref
27
+ from ..debug import debug_async_timeout, detect_cycle_send
28
+ from ..errors import CannotCancelTask
29
+ from ..utils import dataslots, fix_all_zero_ip
30
+ from .allocate_strategy import AddressSpecified, AllocateStrategy
31
+ from .communication import Client, DummyClient, UCXClient
32
+ from .core import ActorCaller
33
+ from .message import (
34
+ DEFAULT_PROTOCOL,
35
+ ActorRefMessage,
36
+ CancelMessage,
37
+ ControlMessage,
38
+ ControlMessageType,
39
+ CopyToBuffersMessage,
40
+ CopyToFileObjectsMessage,
41
+ CreateActorMessage,
42
+ DestroyActorMessage,
43
+ ErrorMessage,
44
+ HasActorMessage,
45
+ ResultMessage,
46
+ SendMessage,
47
+ _MessageBase,
48
+ new_message_id,
49
+ )
50
+ from .router import Router
51
+
52
# Chunk size (4 MiB) used by ``copy_to_buffers`` / ``copy_to_fileobjs`` when
# the caller does not pass an explicit ``block_size``.
DEFAULT_TRANSFER_BLOCK_SIZE = 4 * 1024 * 1024
53
+
54
+
55
@dataslots
@dataclass
class ProfilingContext:
    """Profiling information that ``IndigenActorContext.send`` attaches to a ``SendMessage``."""

    task_id: str  # task identifier carried with the message; its consumer is not visible here
59
+
60
+
61
class IndigenActorContext(BaseActorContext):
    """Actor context of the indigen backend.

    Each public operation builds a message (from ``.message``) and delivers it
    through an :class:`ActorCaller`, which uses the process-wide
    :class:`Router` (or an empty one) to pick a client for the target address.
    """

    __slots__ = ("_caller", "_lock")

    support_allocate_strategy = True

    def __init__(self, address: str | None = None):
        BaseActorContext.__init__(self, address)
        self._caller = ActorCaller()
        self._lock = asyncio.Lock()

    def __del__(self):
        # ``_caller`` is unset when ``__init__`` failed before assigning it;
        # there is nothing to cancel then, so don't raise from the finalizer.
        caller = getattr(self, "_caller", None)
        if caller is not None:
            caller.cancel_tasks()

    async def _call(
        self,
        address: str,
        message: _MessageBase,
        wait: bool = True,
        proxy_addresses: list[str] | None = None,
    ) -> Union[ResultMessage, ErrorMessage, asyncio.Future]:
        """Send ``message`` to ``address`` through the caller.

        With ``wait=False`` a future is returned instead of the response
        message; callers resolve it with :meth:`_wait`.
        """
        return await self._caller.call(
            Router.get_instance_or_empty(),
            address,
            message,
            wait=wait,
            proxy_addresses=proxy_addresses,
        )

    async def _call_with_client(
        self, client: Client, message: _MessageBase, wait: bool = True
    ) -> Union[ResultMessage, ErrorMessage, asyncio.Future]:
        """Send ``message`` over an already-resolved ``client``."""
        # NOTE: used by copyto, cannot support proxy
        return await self._caller.call_with_client(client, message, wait)

    async def _call_send_buffers(
        self,
        client: UCXClient,
        local_buffers: list,
        meta_message: _MessageBase,
        wait: bool = True,
    ) -> Union[ResultMessage, ErrorMessage, asyncio.Future]:
        """Send ``local_buffers`` plus ``meta_message`` over a UCX client."""
        return await self._caller.call_send_buffers(
            client, local_buffers, meta_message, wait
        )

    @staticmethod
    def _process_result_message(message: Union[ResultMessage, ErrorMessage]):
        """Return the payload of a ``ResultMessage``.

        For any other message, raise ``message.as_instanceof_cause()``.
        """
        if isinstance(message, ResultMessage):
            return message.result
        else:
            raise message.as_instanceof_cause()

    async def _wait(self, future: asyncio.Future, address: str, message: _MessageBase):
        """Await ``future`` and, if this wait is cancelled, cancel remotely too.

        The future is shielded, so cancelling the waiter does not cancel it
        directly. Instead a ``CancelMessage`` for ``message`` is sent to
        ``address``. If the remote side reports ``CannotCancelTask``,
        ``CancelledError`` is raised. Any other failure of the cancel request
        is ignored. In the remaining cases the outcome of ``future`` is
        awaited and returned.
        """
        try:
            await asyncio.shield(future)
        except asyncio.CancelledError:
            try:
                await self.cancel(address, message.message_id)
            except CannotCancelTask:
                # cancel failed, already finished
                raise asyncio.CancelledError
            except:  # noqa: E722 # nosec # pylint: disable=bare-except
                # best-effort: cancellation failures must not mask the result
                pass
        return await future

    async def create_actor(
        self,
        actor_cls: Type[Actor],
        *args,
        uid=None,
        address: str | None = None,
        **kwargs,
    ) -> ActorRef:
        """Create an ``actor_cls`` actor and return a reference to it.

        The target address defaults to this context's address, then to the
        router's external address. An ``allocate_strategy`` keyword that is an
        :class:`AllocateStrategy` is consumed here. Otherwise
        ``AddressSpecified(address)`` is used, and a non-strategy
        ``allocate_strategy`` value is left in ``kwargs`` for the actor.
        """
        router = Router.get_instance_or_empty()
        address = address or self._address or router.external_address
        allocate_strategy = kwargs.get("allocate_strategy", None)
        if isinstance(allocate_strategy, AllocateStrategy):
            allocate_strategy = kwargs.pop("allocate_strategy")
        else:
            allocate_strategy = AddressSpecified(address)
        create_actor_message = CreateActorMessage(
            new_message_id(),
            actor_cls,
            to_binary(uid),
            args,
            kwargs,
            allocate_strategy,
            protocol=DEFAULT_PROTOCOL,
        )
        future = await self._call(address, create_actor_message, wait=False)
        result = await self._wait(future, address, create_actor_message)  # type: ignore
        return self._process_result_message(result)

    async def has_actor(self, actor_ref: ActorRef) -> bool:
        """Ask the pool at ``actor_ref.address`` whether the actor exists."""
        message = HasActorMessage(
            new_message_id(), actor_ref, protocol=DEFAULT_PROTOCOL
        )
        future = await self._call(
            actor_ref.address,
            message,
            wait=False,
            proxy_addresses=actor_ref.proxy_addresses,
        )
        result = await self._wait(future, actor_ref.address, message)  # type: ignore
        return self._process_result_message(result)

    async def destroy_actor(self, actor_ref: ActorRef):
        """Ask the pool at ``actor_ref.address`` to destroy the actor."""
        message = DestroyActorMessage(
            new_message_id(), actor_ref, protocol=DEFAULT_PROTOCOL
        )
        future = await self._call(
            actor_ref.address,
            message,
            wait=False,
            proxy_addresses=actor_ref.proxy_addresses,
        )
        result = await self._wait(future, actor_ref.address, message)  # type: ignore
        return self._process_result_message(result)

    async def kill_actor(self, actor_ref: ActorRef, force: bool = True):
        """Stop the sub pool hosting ``actor_ref`` via its main pool.

        Raises ``ValueError`` if the actor lives in the main pool itself.
        """
        # get main_pool_address
        control_message = ControlMessage(
            new_message_id(),
            actor_ref.address,
            ControlMessageType.get_config,
            "main_pool_address",
            protocol=DEFAULT_PROTOCOL,
        )
        main_address = self._process_result_message(
            await self._call(  # type: ignore
                actor_ref.address,
                control_message,
                proxy_addresses=actor_ref.proxy_addresses,
            )
        )
        real_actor_ref = await self.actor_ref(actor_ref)
        if real_actor_ref.address == main_address:
            raise ValueError("Cannot kill actor on main pool")
        stop_message = ControlMessage(
            new_message_id(),
            real_actor_ref.address,
            ControlMessageType.stop,
            # default timeout (3 secs) and force
            (3.0, force),
            protocol=DEFAULT_PROTOCOL,
        )
        # stop server
        result = await self._call(
            main_address, stop_message, proxy_addresses=actor_ref.proxy_addresses
        )
        return self._process_result_message(result)  # type: ignore

    async def actor_ref(self, *args, **kwargs):
        """Resolve an actor reference.

        A local reference is returned when ``create_local_actor_ref`` provides
        one. Otherwise the remote pool is asked. If the address it reports
        differs from the one connected to, it is passed through
        ``fix_all_zero_ip``, presumably to replace an all-zero bind IP with
        the connected host.
        """
        actor_ref = create_actor_ref(*args, **kwargs)
        connect_addr = actor_ref.address
        local_actor_ref = create_local_actor_ref(actor_ref.address, actor_ref.uid)
        if local_actor_ref is not None:
            return local_actor_ref
        message = ActorRefMessage(
            new_message_id(), actor_ref, protocol=DEFAULT_PROTOCOL
        )
        future = await self._call(
            actor_ref.address,
            message,
            wait=False,
            proxy_addresses=actor_ref.proxy_addresses,
        )
        result = await self._wait(future, actor_ref.address, message)
        res = self._process_result_message(result)
        if res.address != connect_addr:
            res.address = fix_all_zero_ip(res.address, connect_addr)
        return res

    async def send(
        self,
        actor_ref: ActorRef,
        message: Tuple,
        wait_response: bool = True,
        profiling_context: ProfilingContext | None = None,
    ):
        """Send ``message`` to the actor.

        With ``wait_response`` the actor's result is returned, or its error
        is raised. Otherwise the pending future is returned unawaited.
        """
        send_message = SendMessage(
            new_message_id(),
            actor_ref,
            message,
            protocol=DEFAULT_PROTOCOL,
            profiling_context=profiling_context,
        )

        # use `%.500` to avoid print too long messages
        with debug_async_timeout(
            "actor_call_timeout",
            "Calling %.500r on %s at %s timed out",
            send_message.content,
            actor_ref.uid,
            actor_ref.address,
        ):
            detect_cycle_send(send_message, wait_response)
            future = await self._call(
                actor_ref.address,
                send_message,
                wait=False,
                proxy_addresses=actor_ref.proxy_addresses,
            )
            if wait_response:
                result = await self._wait(future, actor_ref.address, send_message)  # type: ignore
                return self._process_result_message(result)
            else:
                return future

    async def cancel(self, address: str, cancel_message_id: bytes):
        """Ask ``address`` to cancel the message with ``cancel_message_id``."""
        message = CancelMessage(
            new_message_id(), address, cancel_message_id, protocol=DEFAULT_PROTOCOL
        )
        result = await self._call(address, message)
        return self._process_result_message(result)  # type: ignore

    async def wait_actor_pool_recovered(
        self, address: str, main_address: str | None = None
    ):
        """Wait, via the main pool, until the pool at ``address`` is recovered.

        ``main_address`` is looked up from ``address`` when not given. It
        returns immediately when ``address`` is the main pool.
        """
        if main_address is None:
            # get main_pool_address
            control_message = ControlMessage(
                new_message_id(),
                address,
                ControlMessageType.get_config,
                "main_pool_address",
                protocol=DEFAULT_PROTOCOL,
            )
            main_address = self._process_result_message(
                await self._call(address, control_message)  # type: ignore
            )

        # if address is main pool, it is never recovered
        if address == main_address:
            return

        control_message = ControlMessage(
            new_message_id(),
            address,
            ControlMessageType.wait_pool_recovered,
            None,
            protocol=DEFAULT_PROTOCOL,
        )
        self._process_result_message(await self._call(main_address, control_message))  # type: ignore

    async def get_pool_config(self, address: str):
        """Fetch the pool config of ``address`` (a ``get_config`` with no key)."""
        control_message = ControlMessage(
            new_message_id(),
            address,
            ControlMessageType.get_config,
            None,
            protocol=DEFAULT_PROTOCOL,
        )
        return self._process_result_message(await self._call(address, control_message))  # type: ignore

    @staticmethod
    def _gen_switch_to_copy_to_control_message(content: Any):
        """Build the control message that precedes a UCX buffer transfer."""
        return ControlMessage(
            message_id=new_message_id(),
            control_message_type=ControlMessageType.switch_to_copy_to,
            content=content,
        )

    @staticmethod
    def _gen_copy_to_buffers_message(content: Any):
        """Build one ``CopyToBuffersMessage`` carrying ``content`` chunks."""
        return CopyToBuffersMessage(message_id=new_message_id(), content=content)  # type: ignore

    @staticmethod
    def _gen_copy_to_fileobjs_message(content: Any):
        """Build one ``CopyToFileObjectsMessage`` carrying ``content`` chunks."""
        return CopyToFileObjectsMessage(message_id=new_message_id(), content=content)  # type: ignore

    async def _get_copy_to_client(self, router, address) -> Client:
        """Return a client for ``address``, preferring one with ``send_buffers``."""
        client = await self._caller.get_client(router, address)
        if isinstance(client, DummyClient) or hasattr(client, "send_buffers"):
            return client
        # For inter-process communication, the ``self._caller.get_client`` interface
        # would not look for UCX Client, we still try to find UCXClient for this case.
        client_type = next(
            (
                candidate
                for candidate in router.get_all_client_types(address)
                if hasattr(candidate, "send_buffers")
            ),
            None,
        )
        if client_type is None:
            return client
        return await self._caller.get_client_via_type(router, address, client_type)

    async def _get_client(self, address: str) -> Client:
        """Resolve the client used by ``copy_to_*``; proxies are unsupported."""
        router = Router.get_instance()
        assert router is not None, "`copy_to` can only be used inside pools"
        if router.get_proxy(address):
            raise RuntimeError("Cannot run `copy_to` when enabling proxy")
        return await self._get_copy_to_client(router, address)

    async def copy_to_buffers(
        self,
        local_buffers: list,
        remote_buffer_refs: List[BufferRef],
        block_size: Optional[int] = None,
    ):
        """Copy ``local_buffers[i]`` into ``remote_buffer_refs[i]``.

        All refs are assumed to live at the address of the first one.
        UCX clients get every buffer in a single ``send_buffers`` call.
        Otherwise data is sent in ``CopyToBuffersMessage`` blocks of at most
        ``block_size`` bytes (default ``DEFAULT_TRANSFER_BLOCK_SIZE``). Each
        entry is ``(address, uid, offset, size, data)``, where ``offset`` is
        the chunk's position within its source buffer. Small buffers are
        packed together and large ones are split across blocks.
        """
        if not remote_buffer_refs:
            return
        address = remote_buffer_refs[0].address
        client = await self._get_client(address)
        if isinstance(client, UCXClient):
            message = [(buf.address, buf.uid) for buf in remote_buffer_refs]
            await self._call_send_buffers(
                client,
                local_buffers,
                self._gen_switch_to_copy_to_control_message(message),
            )
            return

        block_size = block_size or DEFAULT_TRANSFER_BLOCK_SIZE
        if block_size < 0:
            # a negative size would make the splitting loop below never end
            raise ValueError(f"block_size must be positive, got {block_size}")

        one_block_data: list = []
        current_buf_size = 0
        # Invariant at the top of each iteration: current_buf_size < block_size,
        # so every split below moves at least one byte.
        for l_buf, r_buf in zip(local_buffers, remote_buffer_refs):
            start = 0
            while current_buf_size + len(l_buf) > block_size:
                # fill the rest of the current block with a prefix of l_buf
                room = block_size - current_buf_size
                one_block_data.append(
                    (r_buf.address, r_buf.uid, start, room, l_buf[:room])
                )
                await self._call_with_client(
                    client, self._gen_copy_to_buffers_message(one_block_data)
                )
                one_block_data = []
                current_buf_size = 0
                start += room
                l_buf = l_buf[room:]

            # The remainder fits. Zero-length buffers still get an entry, but
            # an empty tail left by an exact split does not.
            if start == 0 or len(l_buf) > 0:
                one_block_data.append(
                    (r_buf.address, r_buf.uid, start, len(l_buf), l_buf)
                )
                current_buf_size += len(l_buf)

            if current_buf_size >= block_size:
                # block exactly full: send it so the next buffer starts fresh
                await self._call_with_client(
                    client, self._gen_copy_to_buffers_message(one_block_data)
                )
                one_block_data = []
                current_buf_size = 0

        if one_block_data:
            await self._call_with_client(
                client, self._gen_copy_to_buffers_message(one_block_data)
            )

    async def copy_to_fileobjs(
        self,
        local_fileobjs: List[AioFileObject],
        remote_fileobj_refs: List[FileObjectRef],
        block_size: Optional[int] = None,
    ):
        """Stream each local file object into the matching remote file ref.

        All refs are assumed to live at the address of the first one. Files
        are read in ``block_size`` reads (default
        ``DEFAULT_TRANSFER_BLOCK_SIZE``). Entries ``(address, uid, data)`` are
        batched until at least ``block_size`` bytes are pending, then sent as
        one ``CopyToFileObjectsMessage``.
        """
        if not remote_fileobj_refs:
            return
        address = remote_fileobj_refs[0].address
        client = await self._get_client(address)
        block_size = block_size or DEFAULT_TRANSFER_BLOCK_SIZE
        one_block_data: list = []
        current_file_size = 0
        for file_obj, remote_ref in zip(local_fileobjs, remote_fileobj_refs):
            while True:
                file_data = await file_obj.read(block_size)  # type: ignore
                if not file_data:
                    break
                one_block_data.append((remote_ref.address, remote_ref.uid, file_data))
                current_file_size += len(file_data)
                if current_file_size >= block_size:
                    message = self._gen_copy_to_fileobjs_message(one_block_data)
                    await self._call_with_client(client, message)
                    # rebind rather than clear(): the sent message holds this list
                    one_block_data = []
                    current_file_size = 0

        if current_file_size > 0:
            message = self._gen_copy_to_fileobjs_message(one_block_data)
            await self._call_with_client(client, message)