xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xoscar/__init__.py +61 -0
- xoscar/_utils.cpython-312-darwin.so +0 -0
- xoscar/_utils.pxd +36 -0
- xoscar/_utils.pyx +246 -0
- xoscar/_version.py +693 -0
- xoscar/aio/__init__.py +16 -0
- xoscar/aio/base.py +86 -0
- xoscar/aio/file.py +59 -0
- xoscar/aio/lru.py +228 -0
- xoscar/aio/parallelism.py +39 -0
- xoscar/api.py +527 -0
- xoscar/backend.py +67 -0
- xoscar/backends/__init__.py +14 -0
- xoscar/backends/allocate_strategy.py +160 -0
- xoscar/backends/communication/__init__.py +30 -0
- xoscar/backends/communication/base.py +315 -0
- xoscar/backends/communication/core.py +69 -0
- xoscar/backends/communication/dummy.py +253 -0
- xoscar/backends/communication/errors.py +20 -0
- xoscar/backends/communication/socket.py +444 -0
- xoscar/backends/communication/ucx.py +538 -0
- xoscar/backends/communication/utils.py +97 -0
- xoscar/backends/config.py +157 -0
- xoscar/backends/context.py +437 -0
- xoscar/backends/core.py +352 -0
- xoscar/backends/indigen/__init__.py +16 -0
- xoscar/backends/indigen/__main__.py +19 -0
- xoscar/backends/indigen/backend.py +51 -0
- xoscar/backends/indigen/driver.py +26 -0
- xoscar/backends/indigen/fate_sharing.py +221 -0
- xoscar/backends/indigen/pool.py +515 -0
- xoscar/backends/indigen/shared_memory.py +548 -0
- xoscar/backends/message.cpython-312-darwin.so +0 -0
- xoscar/backends/message.pyi +255 -0
- xoscar/backends/message.pyx +646 -0
- xoscar/backends/pool.py +1630 -0
- xoscar/backends/router.py +285 -0
- xoscar/backends/test/__init__.py +16 -0
- xoscar/backends/test/backend.py +38 -0
- xoscar/backends/test/pool.py +233 -0
- xoscar/batch.py +256 -0
- xoscar/collective/__init__.py +27 -0
- xoscar/collective/backend/__init__.py +13 -0
- xoscar/collective/backend/nccl_backend.py +160 -0
- xoscar/collective/common.py +102 -0
- xoscar/collective/core.py +737 -0
- xoscar/collective/process_group.py +687 -0
- xoscar/collective/utils.py +41 -0
- xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
- xoscar/collective/xoscar_pygloo.pyi +239 -0
- xoscar/constants.py +23 -0
- xoscar/context.cpython-312-darwin.so +0 -0
- xoscar/context.pxd +21 -0
- xoscar/context.pyx +368 -0
- xoscar/core.cpython-312-darwin.so +0 -0
- xoscar/core.pxd +51 -0
- xoscar/core.pyx +664 -0
- xoscar/debug.py +188 -0
- xoscar/driver.py +42 -0
- xoscar/errors.py +63 -0
- xoscar/libcpp.pxd +31 -0
- xoscar/metrics/__init__.py +21 -0
- xoscar/metrics/api.py +288 -0
- xoscar/metrics/backends/__init__.py +13 -0
- xoscar/metrics/backends/console/__init__.py +13 -0
- xoscar/metrics/backends/console/console_metric.py +82 -0
- xoscar/metrics/backends/metric.py +149 -0
- xoscar/metrics/backends/prometheus/__init__.py +13 -0
- xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
- xoscar/nvutils.py +717 -0
- xoscar/profiling.py +260 -0
- xoscar/serialization/__init__.py +20 -0
- xoscar/serialization/aio.py +141 -0
- xoscar/serialization/core.cpython-312-darwin.so +0 -0
- xoscar/serialization/core.pxd +28 -0
- xoscar/serialization/core.pyi +57 -0
- xoscar/serialization/core.pyx +944 -0
- xoscar/serialization/cuda.py +111 -0
- xoscar/serialization/exception.py +48 -0
- xoscar/serialization/mlx.py +67 -0
- xoscar/serialization/numpy.py +82 -0
- xoscar/serialization/pyfury.py +37 -0
- xoscar/serialization/scipy.py +72 -0
- xoscar/serialization/torch.py +180 -0
- xoscar/utils.py +522 -0
- xoscar/virtualenv/__init__.py +34 -0
- xoscar/virtualenv/core.py +268 -0
- xoscar/virtualenv/platform.py +56 -0
- xoscar/virtualenv/utils.py +100 -0
- xoscar/virtualenv/uv.py +321 -0
- xoscar-0.9.0.dist-info/METADATA +230 -0
- xoscar-0.9.0.dist-info/RECORD +94 -0
- xoscar-0.9.0.dist-info/WHEEL +6 -0
- xoscar-0.9.0.dist-info/top_level.txt +2 -0
xoscar/profiling.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import copy
|
|
18
|
+
import heapq
|
|
19
|
+
import json
|
|
20
|
+
import logging
|
|
21
|
+
import operator
|
|
22
|
+
import os
|
|
23
|
+
from collections import Counter
|
|
24
|
+
from collections.abc import Mapping
|
|
25
|
+
|
|
26
|
+
from .backends.message import SendMessage, TellMessage
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
# Read once at import time; non-zero presumably switches profiling on
# globally — confirm against the call sites that check this flag.
XOSCAR_ENABLE_PROFILING = int(os.environ.get("XOSCAR_ENABLE_PROFILING", 0))
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class _ProfilingOptionDescriptor:
|
|
34
|
+
def __init__(self, _type, default):
|
|
35
|
+
self._name = None
|
|
36
|
+
self._type = _type
|
|
37
|
+
self._default = default
|
|
38
|
+
|
|
39
|
+
def __get__(self, obj, cls):
|
|
40
|
+
if obj is None:
|
|
41
|
+
return self
|
|
42
|
+
v = obj._options.get(self._name)
|
|
43
|
+
if v is None:
|
|
44
|
+
v = os.environ.get(f"XOSCAR_PROFILING_{self._name.upper()}", self._default)
|
|
45
|
+
if v is not None:
|
|
46
|
+
v = self._type(v)
|
|
47
|
+
# Cache the value.
|
|
48
|
+
obj.__dict__[self._name] = v
|
|
49
|
+
return v
|
|
50
|
+
|
|
51
|
+
def set_name(self, name: str):
|
|
52
|
+
self._name = name
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _ProfilingOptionsMeta(type):
    """Metaclass that tells each option descriptor its attribute name."""

    def __init__(cls, name, bases, classdict):
        super().__init__(name, bases, classdict)
        # Bind every _ProfilingOptionDescriptor in the class body to the
        # attribute name it was assigned to.
        for attr_name, attr_value in classdict.items():
            if isinstance(attr_value, _ProfilingOptionDescriptor):
                attr_value.set_name(attr_name)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class _ProfilingOptions(metaclass=_ProfilingOptionsMeta):
    """Profiling configuration resolved from a mapping or the environment.

    ``options`` may be a mapping of option name to value, or one of
    True/False/None meaning "no explicit options" (environment variables and
    defaults still apply via the descriptors).
    """

    debug_interval_seconds = _ProfilingOptionDescriptor(float, default=None)
    slow_calls_duration_threshold = _ProfilingOptionDescriptor(int, default=1)

    def __init__(self, options):
        if isinstance(options, Mapping):
            # Reject keys that do not correspond to a declared option.
            invalid_keys = options.keys() - type(self).__dict__.keys()
            if invalid_keys:
                raise ValueError(f"Invalid profiling options: {invalid_keys}")
            self._options = options
        elif options in (True, False, None):
            self._options = {}
        else:
            raise ValueError(f"Invalid profiling options: {options}")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class DummyOperator:
    """No-op stand-in for ProfilingDataOperator.

    Returned when profiling is disabled or a requested key does not exist,
    so callers can mutate unconditionally without branching.
    """

    @staticmethod
    def set(key, value):
        """Discard the assignment."""

    @staticmethod
    def inc(key, value):
        """Discard the increment."""

    @staticmethod
    def nest(key):
        """Return the dummy itself so chained calls remain no-ops."""
        return DummyOperator

    @staticmethod
    def values():
        """A dummy never holds values."""
        return []

    @staticmethod
    def empty():
        """A dummy is always empty."""
        return True
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ProfilingDataOperator:
    """Mutator over one nested dict inside the profiling data tree."""

    __slots__ = ("_target",)

    def __init__(self, target):
        self._target = target

    def set(self, key, value):
        """Overwrite ``key`` with ``value``."""
        self._target[key] = value

    def inc(self, key, value):
        """Add ``value`` to the number stored at ``key`` (0 when absent)."""
        self._target[key] = self._target.get(key, 0) + value

    def nest(self, key):
        """Return an operator over the sub-dict at ``key``, creating it if needed.

        Raises TypeError if ``key`` already holds a non-dict value.
        """
        v = self._target.setdefault(key, {})
        if not isinstance(v, dict):
            raise TypeError(
                f"The value type of key {key} is {type(v)}, but a dict is expected."
            )
        return ProfilingDataOperator(v)

    def values(self):
        """Expose the underlying dict's values view."""
        return self._target.values()

    def empty(self):
        """True when the underlying dict has no entries."""
        return len(self._target) == 0
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class _CallStats:
|
|
130
|
+
_call_counter: Counter
|
|
131
|
+
_slow_calls: list
|
|
132
|
+
|
|
133
|
+
def __init__(self, options: _ProfilingOptions):
|
|
134
|
+
self._options = options
|
|
135
|
+
self._call_counter = Counter()
|
|
136
|
+
self._slow_calls = []
|
|
137
|
+
|
|
138
|
+
def collect(self, message, duration: float):
|
|
139
|
+
key = (message.actor_ref.uid, message.content[0])
|
|
140
|
+
self._call_counter[key] += 1
|
|
141
|
+
if duration < self._options.slow_calls_duration_threshold:
|
|
142
|
+
return
|
|
143
|
+
slow_call_key = (
|
|
144
|
+
duration,
|
|
145
|
+
message.actor_ref.uid,
|
|
146
|
+
message.actor_ref.address,
|
|
147
|
+
message.content,
|
|
148
|
+
)
|
|
149
|
+
try:
|
|
150
|
+
if len(self._slow_calls) < 10:
|
|
151
|
+
heapq.heappush(self._slow_calls, slow_call_key)
|
|
152
|
+
else:
|
|
153
|
+
heapq.heapreplace(self._slow_calls, slow_call_key)
|
|
154
|
+
except TypeError:
|
|
155
|
+
pass
|
|
156
|
+
|
|
157
|
+
def to_dict(self) -> dict:
|
|
158
|
+
most_calls = {}
|
|
159
|
+
for name_tuple, count in self._call_counter.most_common(10):
|
|
160
|
+
uid, method_name = name_tuple
|
|
161
|
+
most_calls[f"{uid.decode('utf-8')}.{method_name}"] = count
|
|
162
|
+
slow_calls = {}
|
|
163
|
+
for duration, uid, address, content in sorted(
|
|
164
|
+
self._slow_calls, key=operator.itemgetter(0), reverse=True
|
|
165
|
+
):
|
|
166
|
+
method_name, _batch, args, kwargs = content
|
|
167
|
+
slow_calls[
|
|
168
|
+
f"[{address}]{uid.decode('utf-8')}.{method_name}(args={args}, kwargs={kwargs})"
|
|
169
|
+
] = duration
|
|
170
|
+
return {"most_calls": most_calls, "slow_calls": slow_calls}
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class _ProfilingData:
    """Process-level container of profiling results, keyed by task id.

    Holds a per-task result dict, per-task ``_CallStats``, and (optionally)
    a periodic asyncio debug-logging task per task id.  Normally accessed
    through the ``get_instance`` singleton.
    """

    # Singleton slot used by get_instance()/set_instance().
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the singleton, creating it on first access."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def set_instance(cls, inst: "_ProfilingData"):
        """Replace the singleton instance."""
        cls._instance = inst

    def __init__(self):
        # task_id -> nested result dict ("general", "serialization", ...).
        self._data = {}
        # task_id -> _CallStats collecting call counters / slow calls.
        self._call_stats = {}
        # task_id -> asyncio.Task doing periodic debug logging.
        self._debug_task = {}

    def init(self, task_id: str, options=None):
        """Start collecting profiling data for ``task_id``.

        Creates the per-task result dict and call stats.  When
        ``debug_interval_seconds`` is configured, also spawns an asyncio task
        that periodically logs a snapshot until the task's data is popped.
        Must be called from within a running event loop in that case.
        """
        options = _ProfilingOptions(options)
        logger.info(
            "Init profiling data for task %s with debug interval seconds %s.",
            task_id,
            options.debug_interval_seconds,
        )
        self._data[task_id] = {
            "general": {},
            "serialization": {},
            "most_calls": {},
            "slow_calls": {},
        }
        self._call_stats[task_id] = _CallStats(options)

        async def _debug_profiling_log():
            # Loop until the task's data is popped, logging a merged snapshot
            # of the result dict and the current call stats each interval.
            while True:
                try:
                    r = self._data.get(task_id, None)
                    if r is None:
                        # pop() removed the data; stop logging.
                        logger.info("Profiling debug log break.")
                        break
                    r = copy.copy(r)  # shadow copy is enough.
                    r.update(self._call_stats.get(task_id).to_dict())
                    logger.warning("Profiling debug:\n%s", json.dumps(r, indent=4))
                except Exception:
                    # Keep the logger loop alive on unexpected errors.
                    logger.exception("Profiling debug log failed.")
                await asyncio.sleep(options.debug_interval_seconds)

        if options.debug_interval_seconds is not None:
            self._debug_task[task_id] = task = asyncio.create_task(
                _debug_profiling_log()
            )
            # Drop the bookkeeping entry once the task finishes or is cancelled.
            task.add_done_callback(lambda _: self._debug_task.pop(task_id, None))

    def pop(self, task_id: str):
        """Stop profiling ``task_id`` and return its accumulated data (or None).

        Cancels the debug-log task if one is running and merges the final
        call stats into the returned dict.
        """
        logger.info("Pop profiling data of task %s.", task_id)
        debug_task = self._debug_task.pop(task_id, None)
        if debug_task is not None:
            debug_task.cancel()
        r = self._data.pop(task_id, None)
        if r is not None:
            r.update(self._call_stats.pop(task_id).to_dict())
        return r

    def collect_actor_call(self, message, duration: float):
        """Record an actor call in every active task's call stats.

        Only Send/Tell messages are counted; other message types are ignored.
        """
        if self._call_stats:
            message_type = type(message)
            if message_type is SendMessage or message_type is TellMessage:
                for stats in self._call_stats.values():
                    stats.collect(message, duration)

    def __getitem__(self, item):
        """Navigate the nested data by key or key tuple.

        Returns a ProfilingDataOperator over the found sub-dict, or
        DummyOperator when any key along the path is missing.
        """
        key = item if isinstance(item, tuple) else (item,)
        v = None
        d = self._data
        for k in key:
            v = d.get(k, None)
            if v is None:
                break
            else:
                d = v
        return DummyOperator if v is None else ProfilingDataOperator(v)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def get_profiling_data() -> _ProfilingData:
    """Return the process-wide profiling data singleton."""
    return _ProfilingData.get_instance()


# Module-level handle to the singleton for convenient access.
ProfilingData = get_profiling_data()
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from . import cuda, exception, mlx, numpy, scipy, torch
|
|
17
|
+
from .aio import AioDeserializer, AioSerializer
|
|
18
|
+
from .core import Serializer, deserialize, serialize, serialize_with_spawn
|
|
19
|
+
|
|
20
|
+
del cuda, numpy, scipy, mlx, exception, torch
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import struct
|
|
18
|
+
from io import BytesIO
|
|
19
|
+
from typing import Any, BinaryIO, Union
|
|
20
|
+
|
|
21
|
+
import cloudpickle
|
|
22
|
+
|
|
23
|
+
from ..utils import lazy_import
|
|
24
|
+
from .core import deserialize, serialize_with_spawn
|
|
25
|
+
|
|
26
|
+
rmm = lazy_import("rmm")
|
|
27
|
+
|
|
28
|
+
# Wire-format version written as the first byte of every serialized payload.
DEFAULT_SERIALIZATION_VERSION = 1
# Object-count threshold beyond which (de)serialization is off-loaded
# (spawned to an executor/thread) instead of running inline.
DEFAULT_SPAWN_THRESHOLD = 100
# Header key under which per-buffer byte sizes are recorded.
BUFFER_SIZES_NAME = "buf_sizes"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AioSerializer:
    """Asynchronously serialize an object into a list of wire buffers.

    The produced list is: an 11-byte preamble (version, header length,
    compression flag), the cloudpickled header bytes, then the raw payload
    buffers.
    """

    def __init__(self, obj: Any, compress=0):
        self._obj = obj
        self._compress = compress

    async def _get_buffers(self):
        headers, buffers = await serialize_with_spawn(
            self._obj, spawn_threshold=DEFAULT_SPAWN_THRESHOLD
        )

        # Flag which payload buffers live on the GPU (CUDA array protocol).
        headers[0]["is_cuda_buffers"] = [
            hasattr(buf, "__cuda_array_interface__") for buf in buffers
        ]
        # Record each buffer's byte size so the receiver knows how much to read.
        headers[0][BUFFER_SIZES_NAME] = [
            buf.nbytes if hasattr(buf, "nbytes") else len(buf) for buf in buffers
        ]
        header = cloudpickle.dumps(headers)

        # Build the fixed 11-byte preamble:
        # version (1 byte) + header length (8 bytes LE) + compression (2 bytes LE).
        preamble = BytesIO()
        preamble.write(struct.pack("B", DEFAULT_SERIALIZATION_VERSION))
        preamble.write(struct.pack("<Q", len(header)))
        preamble.write(struct.pack("<H", self._compress))

        return [preamble.getbuffer(), header, *buffers]

    async def run(self):
        """Serialize and return the list of outgoing buffers."""
        return await self._get_buffers()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
MALFORMED_MSG = "Received malformed data, please check Xoscar version on both side."


def get_header_length(header_bytes: bytes):
    """Validate the 11-byte preamble and return the pickled header's length.

    Raises EOFError on a serialization-version mismatch so the channel
    immediately aborts the connection.
    """
    (version,) = struct.unpack("B", header_bytes[:1])
    # Only the default version exists today; anything else means the peer
    # runs an incompatible Xoscar.
    if version != DEFAULT_SERIALIZATION_VERSION:
        raise EOFError(MALFORMED_MSG)
    (header_length,) = struct.unpack("<Q", header_bytes[1:9])
    # The compression flag is parsed for validation but unused here.
    struct.unpack("<H", header_bytes[9:])
    return header_length
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class AioDeserializer:
    """Asynchronously deserialize an object from a file-like or stream reader.

    Counterpart of AioSerializer: reads the 11-byte preamble, the pickled
    header, then the payload buffers in order.
    """

    def __init__(self, file):
        self._file = file

    def _readexactly(self, n: int):
        """Read exactly ``n`` bytes (returns an awaitable from the underlying file).

        asyncio StreamReader may not guarantee to read n bytes from ``read``;
        for it we need to call ``readexactly`` instead.
        """
        read = (
            self._file.readexactly
            if hasattr(self._file, "readexactly")
            else self._file.read
        )
        return read(n)

    async def _get_obj_header_bytes(self):
        """Read the preamble plus header and return the raw header bytes."""
        try:
            # 11-byte preamble: version (1B) + header length (8B) + compress (2B).
            header_bytes = bytes(await self._readexactly(11))
        except ConnectionResetError:
            raise EOFError("Server may be closed")
        if len(header_bytes) == 0:
            raise EOFError("Received empty bytes")
        header_length = get_header_length(header_bytes)
        return await self._readexactly(header_length)

    async def _get_obj(self):
        """Read header and payload buffers, then reconstruct the object."""
        header = cloudpickle.loads(await self._get_obj_header_bytes())
        # Buffer sizes recorded by the serializer tell us how much to read.
        buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
        # Read payload buffers sequentially, in the serializer's order.
        buffers = [await self._readexactly(size) for size in buffer_sizes]
        # "_N" carries the number of serialized sub-objects.
        num_objs = header[0].get("_N", 0)

        if num_objs <= DEFAULT_SPAWN_THRESHOLD:
            return deserialize(header, buffers)
        else:
            # Large graphs: run deserialization off the event loop.
            return await asyncio.to_thread(deserialize, header, buffers)

    async def run(self):
        """Deserialize and return the object."""
        return await self._get_obj()

    async def get_size(self):
        """Consume preamble+header and return the total message size in bytes."""
        # extract header
        header_bytes = await self._get_obj_header_bytes()
        header = cloudpickle.loads(header_bytes)
        # get buffer size
        buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
        # 11 = preamble size written by AioSerializer.
        return 11 + len(header_bytes) + sum(buffer_sizes)

    async def get_header(self):
        """Consume preamble+header and return the unpickled header only."""
        return cloudpickle.loads(await self._get_obj_header_bytes())
|
|
Binary file
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
# derived from copyright 1999-2022 Alibaba Group Holding Ltd.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
cdef class Serializer:
    # Numeric id identifying this serializer in serialized headers.
    cdef int _serializer_id

    # Serialize ``obj``; ``context`` carries shared (de)serialization state.
    cpdef serial(self, object obj, dict context)
    # Rebuild an object from ``serialized`` plus already-deserialized
    # sub-objects ``subs``.
    cpdef deserial(self, tuple serialized, dict context, list subs)
    # Hook invoked when deserializing sub-object ``error_index`` raised ``exc``.
    cpdef on_deserial_error(
        self,
        tuple serialized,
        dict context,
        list subs_serialized,
        int error_index,
        object exc,
    )
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from concurrent.futures import Executor
|
|
19
|
+
from typing import Any, Callable
|
|
20
|
+
|
|
21
|
+
# NOTE(review): decorator defined in core.pyx — presumably wraps a serial
# function with buffering; confirm semantics there.
def buffered(func: Callable) -> Callable: ...

# NOTE(review): implemented in core.pyx — presumably a cheap object-identity
# key; confirm semantics there.
def fast_id(obj: Any) -> int: ...
|
|
23
|
+
|
|
24
|
+
# Type stub mirroring the Cython Serializer in core.pyx/core.pxd.
class Serializer:
    # Id written into serialized headers to locate this serializer again.
    serializer_id: int
    def serial(self, obj: Any, context: dict): ...
    def deserial(self, serialized: tuple, context: dict, subs: list[Any]): ...
    # Hook called when deserializing sub-object ``error_index`` raised ``exc``.
    def on_deserial_error(
        self,
        serialized: tuple,
        context: dict,
        subs_serialized: list,
        error_index: int,
        exc: BaseException,
    ): ...
    # Register/unregister this serializer for a given object type.
    @classmethod
    def register(cls, obj_type, name: str | None = None): ...
    @classmethod
    def unregister(cls, obj_type): ...
|
|
40
|
+
|
|
41
|
+
# Stub for the Placeholder defined in core.pyx.  Hashable and comparable by
# its integer id.
class Placeholder:
    # Identity of the object this placeholder stands for.
    id: int
    # NOTE(review): presumably fired when the real object becomes available —
    # confirm in core.pyx.
    callbacks: list[Callable]
    def __init__(self, id_: int): ...
    def __hash__(self): ...
    def __eq__(self, other): ...
|
|
47
|
+
|
|
48
|
+
# Synchronous serialization entry point; returns headers and buffers
# (see aio.py, which unpacks ``headers, buffers`` from the async variant).
def serialize(obj: Any, context: dict | None = None): ...

# Like serialize(), but may off-load work to ``executor`` once more than
# ``spawn_threshold`` objects are involved.
async def serialize_with_spawn(
    obj: Any,
    context: dict | None = None,
    spawn_threshold: int = 100,
    executor: Executor | None = None,
): ...

# Inverse of serialize(): rebuild the object from headers plus buffers.
def deserialize(headers: list, buffers: list, context: dict | None = None): ...
def pickle_buffers(obj: Any) -> list: ...
def unpickle_buffers(buffers: list) -> Any: ...
|