xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94):
  1. xoscar/__init__.py +61 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +246 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +527 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +253 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +444 -0
  21. xoscar/backends/communication/ucx.py +538 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +157 -0
  24. xoscar/backends/context.py +437 -0
  25. xoscar/backends/core.py +352 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/__main__.py +19 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/fate_sharing.py +221 -0
  31. xoscar/backends/indigen/pool.py +515 -0
  32. xoscar/backends/indigen/shared_memory.py +548 -0
  33. xoscar/backends/message.cpython-312-darwin.so +0 -0
  34. xoscar/backends/message.pyi +255 -0
  35. xoscar/backends/message.pyx +646 -0
  36. xoscar/backends/pool.py +1630 -0
  37. xoscar/backends/router.py +285 -0
  38. xoscar/backends/test/__init__.py +16 -0
  39. xoscar/backends/test/backend.py +38 -0
  40. xoscar/backends/test/pool.py +233 -0
  41. xoscar/batch.py +256 -0
  42. xoscar/collective/__init__.py +27 -0
  43. xoscar/collective/backend/__init__.py +13 -0
  44. xoscar/collective/backend/nccl_backend.py +160 -0
  45. xoscar/collective/common.py +102 -0
  46. xoscar/collective/core.py +737 -0
  47. xoscar/collective/process_group.py +687 -0
  48. xoscar/collective/utils.py +41 -0
  49. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  50. xoscar/collective/xoscar_pygloo.pyi +239 -0
  51. xoscar/constants.py +23 -0
  52. xoscar/context.cpython-312-darwin.so +0 -0
  53. xoscar/context.pxd +21 -0
  54. xoscar/context.pyx +368 -0
  55. xoscar/core.cpython-312-darwin.so +0 -0
  56. xoscar/core.pxd +51 -0
  57. xoscar/core.pyx +664 -0
  58. xoscar/debug.py +188 -0
  59. xoscar/driver.py +42 -0
  60. xoscar/errors.py +63 -0
  61. xoscar/libcpp.pxd +31 -0
  62. xoscar/metrics/__init__.py +21 -0
  63. xoscar/metrics/api.py +288 -0
  64. xoscar/metrics/backends/__init__.py +13 -0
  65. xoscar/metrics/backends/console/__init__.py +13 -0
  66. xoscar/metrics/backends/console/console_metric.py +82 -0
  67. xoscar/metrics/backends/metric.py +149 -0
  68. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  69. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  70. xoscar/nvutils.py +717 -0
  71. xoscar/profiling.py +260 -0
  72. xoscar/serialization/__init__.py +20 -0
  73. xoscar/serialization/aio.py +141 -0
  74. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  75. xoscar/serialization/core.pxd +28 -0
  76. xoscar/serialization/core.pyi +57 -0
  77. xoscar/serialization/core.pyx +944 -0
  78. xoscar/serialization/cuda.py +111 -0
  79. xoscar/serialization/exception.py +48 -0
  80. xoscar/serialization/mlx.py +67 -0
  81. xoscar/serialization/numpy.py +82 -0
  82. xoscar/serialization/pyfury.py +37 -0
  83. xoscar/serialization/scipy.py +72 -0
  84. xoscar/serialization/torch.py +180 -0
  85. xoscar/utils.py +522 -0
  86. xoscar/virtualenv/__init__.py +34 -0
  87. xoscar/virtualenv/core.py +268 -0
  88. xoscar/virtualenv/platform.py +56 -0
  89. xoscar/virtualenv/utils.py +100 -0
  90. xoscar/virtualenv/uv.py +321 -0
  91. xoscar-0.9.0.dist-info/METADATA +230 -0
  92. xoscar-0.9.0.dist-info/RECORD +94 -0
  93. xoscar-0.9.0.dist-info/WHEEL +6 -0
  94. xoscar-0.9.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,111 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, List, Tuple
17
+
18
+ import pandas as pd
19
+
20
+ from ..utils import lazy_import
21
+ from .core import Serializer, buffered
22
+
23
+ cupy = lazy_import("cupy")
24
+ cudf = lazy_import("cudf")
25
+
26
+
27
class CupySerializer(Serializer):
    """Serialize ``cupy.ndarray`` objects as one flat device byte buffer.

    The header is a copy of the array's ``__cuda_array_interface__`` with
    explicit ``strides`` and the payload byte length (``lengths``) added.
    The single payload buffer is a 1-D ``uint8`` cupy array over the same
    device memory as the (possibly copied) input array.
    """

    @buffered
    def serial(self, obj: Any, context: Dict):
        is_contiguous = obj.flags["C_CONTIGUOUS"] or obj.flags["F_CONTIGUOUS"]
        if not is_contiguous:
            # Neither C- nor F-ordered: work on a fresh copy instead.
            obj = cupy.array(obj, copy=True)

        header = obj.__cuda_array_interface__.copy()
        header["strides"] = tuple(obj.strides)
        payload_size = obj.nbytes
        header["lengths"] = [payload_size]

        # Reinterpret the array's device allocation as a run of bytes.
        byte_view = cupy.ndarray(
            shape=(payload_size,),
            dtype=cupy.dtype("u1"),
            memptr=obj.data,
            strides=(1,),
        )
        return (header,), [byte_view], True

    def deserial(self, serialized: Tuple, context: Dict, subs: List):
        (header,) = serialized
        shape = header["shape"]
        typestr = header["typestr"]
        device_ptr = cupy.asarray(subs[0]).data
        # Rebuild the array as a view over the received device buffer.
        return cupy.ndarray(
            shape=shape,
            dtype=typestr,
            memptr=device_ptr,
            strides=header["strides"],
        )
49
+
50
+
51
class CudfSerializer(Serializer):
    """Serialize cudf ``DataFrame``, ``Series`` and ``Index`` objects.

    Uses cudf's ``device_serialize`` / ``Serializable.device_deserialize``.
    The cudf header is extended with ``_ext_columns`` / ``_ext_index``
    entries recording whether ``columns`` / ``index`` was a pandas or cudf
    ``MultiIndex`` (plus its level names). After deserialization, that
    attribute is rebuilt as a MultiIndex if it is not already one.
    """

    @staticmethod
    def _get_ext_index_type(index_obj):
        """Return MultiIndex metadata for *index_obj*, or ``None``.

        For a pandas or cudf ``MultiIndex`` the result is
        ``{"index_type": "pandas" | "cudf", "names": [...level names]}``;
        for any other object it is ``None``.
        """
        import cudf

        multi_index_type = None
        if isinstance(index_obj, pd.MultiIndex):
            multi_index_type = "pandas"
        elif isinstance(index_obj, cudf.MultiIndex):
            multi_index_type = "cudf"

        if multi_index_type is None:
            return None
        return {
            "index_type": multi_index_type,
            "names": list(index_obj.names),
        }

    @staticmethod
    def _apply_index_type(obj, attr, header):
        """Rebuild ``obj.<attr>`` as the MultiIndex described by *header*.

        Leaves the attribute alone when it is already a pandas/cudf
        MultiIndex. Otherwise it builds one via ``from_tuples`` on the class
        named by ``header["index_type"]``, using the recorded level names.
        """
        import cudf

        multi_index_cls = (
            pd.MultiIndex if header["index_type"] == "pandas" else cudf.MultiIndex
        )
        original_index = getattr(obj, attr)
        if isinstance(original_index, (pd.MultiIndex, cudf.MultiIndex)):
            return
        new_index = multi_index_cls.from_tuples(original_index, names=header["names"])
        setattr(obj, attr, new_index)

    def serial(self, obj: Any, context: Dict):
        header, buffers = obj.device_serialize()
        # Record MultiIndex details (or None) next to cudf's own header.
        if hasattr(obj, "columns"):
            header["_ext_columns"] = self._get_ext_index_type(obj.columns)
        if hasattr(obj, "index"):
            header["_ext_index"] = self._get_ext_index_type(obj.index)
        return (header,), buffers, True

    def deserial(self, serialized: Tuple, context: Dict, buffers: List):
        from cudf.core.abc import Serializable

        (header,) = serialized
        # Pop from a shallow copy. Popping from the caller's dict would strip
        # the ``_ext_*`` entries, so a second deserialization of the same
        # header would silently lose the MultiIndex information.
        header = dict(header)
        col_header = header.pop("_ext_columns", None)
        index_header = header.pop("_ext_index", None)

        result = Serializable.device_deserialize(header, buffers)

        if col_header is not None:
            self._apply_index_type(result, "columns", col_header)
        if index_header is not None:
            self._apply_index_type(result, "index", index_header)
        return result
104
+
105
+
106
# Register the GPU serializers only for libraries that were resolved;
# ``lazy_import`` presumably yields None when a package is unavailable.
if cupy is not None:
    CupySerializer.register("cupy.ndarray")
if cudf is not None:
    for _cudf_type_name in ("cudf.DataFrame", "cudf.Series", "cudf.Index"):
        CudfSerializer.register(_cudf_type_name)
    del _cudf_type_name
@@ -0,0 +1,48 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import pickle # nosec # pylint: disable=import_pickle
19
+
20
+ from .core import Serializer, buffered, pickle_buffers, unpickle_buffers
21
+
22
+
23
+ class UnpickleableError(Exception):
24
+ def __init__(self, raw_error: str | Exception):
25
+ if isinstance(raw_error, str):
26
+ super().__init__(raw_error)
27
+ else:
28
+ super().__init__(
29
+ f"Error cannot be pickled, "
30
+ f"error type: {type(raw_error)}, "
31
+ f"raw error:\n{raw_error}"
32
+ )
33
+
34
+
35
class ExceptionSerializer(Serializer):
    """Serialize exceptions by pickling them into buffers.

    The header is empty. If pickling the exception raises ``TypeError`` or
    ``pickle.PicklingError``, an ``UnpickleableError`` describing it is
    pickled in its place.
    """

    @buffered
    def serial(self, obj: Exception, context: dict):
        try:
            payload = pickle_buffers(obj)
        except (TypeError, pickle.PicklingError):
            # Fall back to a picklable summary of the original error.
            payload = pickle_buffers(UnpickleableError(obj))
        return (), payload, True

    def deserial(self, serialized: tuple, context: dict, subs: list):
        # NOTE(review): unpickle_buffers is presumably pickle-based. Unpickling
        # untrusted data can execute arbitrary code, so confirm peers are trusted.
        return unpickle_buffers(subs)


# Route ``Exception`` objects to ExceptionSerializer.
ExceptionSerializer.register(Exception)
@@ -0,0 +1,67 @@
1
+ # Copyright 2022-2025 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, List
16
+
17
+ import numpy as np
18
+
19
+ from ..utils import lazy_import
20
+ from .core import Serializer, buffered
21
+
22
+ mx = lazy_import("mlx.core")
23
+
24
+
25
# Buffer-protocol (``struct``-style) format character -> NumPy scalar type.
# Consulted when a received buffer's format differs from the format
# recorded at serialization time.
dtype_map = dict(
    b=np.int8,
    B=np.uint8,
    h=np.int16,
    H=np.uint16,
    i=np.int32,
    I=np.uint32,
    q=np.int64,
    Q=np.uint64,
    e=np.float16,
    f=np.float32,
    d=np.float64,
)
38
+
39
+
40
class MLXSerislizer(Serializer):
    """Serialize ``mlx.core.array`` objects as a flat byte buffer.

    The header stores the original ``shape``, the buffer ``format`` of the
    serialized byte view, and the MLX dtype name (e.g. ``"float32"``). The
    payload is the array flattened and reinterpreted as ``uint8``. (The
    misspelled class name is kept because other code may refer to it.)
    """

    @buffered
    def serial(self, obj: "mx.array", context: dict):  # type: ignore
        # Flatten and view as bytes so every dtype travels as raw uint8.
        ravel_obj = obj.reshape(-1).view(mx.uint8)
        mv = memoryview(ravel_obj)
        header = dict(
            shape=obj.shape, format=mv.format, dtype=str(obj.dtype).rsplit(".", 1)[-1]
        )
        if not mv.c_contiguous:
            # NOTE: we only consider c contiguous here,
            # MLX has no way to create f contiguous arrays.
            mv = memoryview(bytes(mv))
        return (header,), [mv], True

    def deserial(self, serialized: tuple, context: dict, subs: List[Any]):
        header = serialized[0]
        shape = header["shape"]
        buf_format = header["format"]
        dtype_name = header["dtype"]

        mv = memoryview(subs[0])
        if mv.format != buf_format:
            # The transport changed the element format. Reinterpret the raw
            # bytes with the format recorded at serialization time. Keep the
            # result flat: the payload is a byte view, so reshaping it to the
            # element ``shape`` here fails for dtypes wider than one byte.
            # Use a separate name so ``dtype_name`` stays the MLX dtype name
            # needed below (overwriting it made ``getattr(mx, ...)`` fail).
            np_dtype = dtype_map.get(buf_format, np.uint8)
            mv = memoryview(np.frombuffer(mv, dtype=np_dtype))

        ravel_array = mx.array(mv)
        return ravel_array.view(getattr(mx, dtype_name)).reshape(shape)


if mx is not None:
    MLXSerislizer.register(mx.array)
@@ -0,0 +1,82 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, List, Tuple
17
+
18
+ import numpy as np
19
+
20
+ from .core import Serializer, buffered, pickle_buffers, unpickle_buffers
21
+
22
+
23
class NDArraySerializer(Serializer):
    """Serialize ``numpy.ndarray`` objects.

    Arrays whose dtype contains Python objects are pickled. Any other array
    is sent as a single flat ``uint8`` view of its data. The header carries
    the dtype description, the field order (for reordered structured views),
    the shape, the strides and the memory order. The receiver wraps the
    buffer with ``np.ndarray`` without copying.
    """

    @buffered
    def serial(self, obj: np.ndarray, context: Dict):
        if obj.dtype.hasobject:
            # Object references are not raw bytes; fall back to pickling.
            return ({"pickle": True},), pickle_buffers(obj), True

        if obj.flags.f_contiguous:
            order = "F"
        else:
            order = "C"
            if not obj.flags.c_contiguous:
                # Neither layout applies: compact into a C-ordered copy.
                obj = np.ascontiguousarray(obj)

        try:
            descr = np.lib.format.dtype_to_descr(obj.dtype)
        except ValueError:
            # A structured view such as ``arr[["f2", "f1"]]`` lists its fields
            # out of offset order and cannot be described directly. Describe
            # the offset-sorted layout and remember the requested field order.
            fields = obj.dtype.fields
            by_offset = sorted(fields, key=lambda name: fields[name][1])
            descr = np.lib.format.dtype_to_descr(obj.dtype[by_offset])
            field_order = list(fields)
        else:
            field_order = None

        header = dict(
            pickle=False,
            descr=descr,
            dtype_new_order=field_order,
            shape=list(obj.shape),
            strides=list(obj.strides),
            order=order,
        )
        flat_bytes = obj.ravel(order=order).view("uint8")
        return (header,), [memoryview(flat_bytes.data)], True  # type: ignore

    def deserial(self, serialized: Tuple, context: Dict, subs: List[Any]):
        header = serialized[0]
        if header["pickle"]:
            return unpickle_buffers(subs)

        descr = header["descr"]
        try:
            dtype = np.lib.format.descr_to_dtype(descr)
        except AttributeError:  # pragma: no cover
            # Older NumPy releases do not provide ``descr_to_dtype``.
            dtype = np.dtype(descr)

        field_order = header["dtype_new_order"]
        if field_order:
            # Re-select the fields in the order the sender's view used.
            dtype = dtype[field_order]

        # Wrap the received buffer directly (no copy).
        return np.ndarray(
            shape=tuple(header["shape"]),
            dtype=dtype,
            buffer=subs[0],
            strides=tuple(header["strides"]),
            order=header["order"],
        )


# Route ``numpy.ndarray`` objects to NDArraySerializer.
NDArraySerializer.register(np.ndarray)
@@ -0,0 +1,37 @@
1
+ import os
2
+ import threading
3
+
4
+ _fury = threading.local()
5
+ _fury_not_installed = object()
6
+ _register_class_list = set()
7
+
8
+
9
def register_classes(*args):
    """Register classes for pyfury serialization.

    The classes are always added to ``_register_class_list``, which
    ``get_fury`` registers on every Fury instance it creates. If a Fury
    instance is available for the calling thread, every class in that list
    is also registered on it right away.

    Parameters
    ----------
    *args
        Classes to register.
    """
    instance = get_fury()
    # Record the classes even when no instance exists yet (e.g. USE_FURY is
    # not enabled at this point). Previously they were only stored when an
    # instance already existed, so an instance created later by get_fury()
    # never learned about them.
    _register_class_list.update(args)
    if instance is not None:
        for c in _register_class_list:
            instance.register_class(c)
15
+
16
+
17
def get_fury():
    """Return the calling thread's pyfury ``Fury`` instance, or ``None``.

    Returns ``None`` unless the ``USE_FURY`` environment variable is ``"1"``,
    ``"true"`` or ``"True"``. The instance is created lazily and cached per
    thread; every class in ``_register_class_list`` is registered on it at
    creation. If importing pyfury fails, a sentinel is cached for the thread
    and ``None`` is returned from then on.
    """
    if os.environ.get("USE_FURY") not in ("1", "true", "True"):
        return None

    cached = getattr(_fury, "instance", None)
    if cached is _fury_not_installed:  # pragma: no cover
        return None
    if cached is not None:
        return cached

    try:
        import pyfury

        cached = pyfury.Fury(
            language=pyfury.Language.PYTHON, require_class_registration=False
        )
        _fury.instance = cached
        for registered_cls in _register_class_list:  # pragma: no cover
            cached.register_class(registered_cls)
        print("pyfury is enabled.")
    except ImportError:  # pragma: no cover
        print("pyfury is not installed.")
        _fury.instance = _fury_not_installed
    return cached
@@ -0,0 +1,72 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, List, Tuple
17
+
18
+ import numpy as np
19
+
20
+ try:
21
+ import scipy.sparse as sps
22
+ except ImportError: # pragma: no cover
23
+ sps = None
24
+
25
+ from .core import Serializer, buffered, deserialize, serialize
26
+
27
+
28
class CsrMatrixSerializer(Serializer):
    """Serialize ``scipy.sparse.csr_matrix`` via its three component arrays.

    ``data``, ``indices`` and ``indptr`` are each passed through
    ``serialize``. The header is the 6-tuple ``(data_header, data_buf_num,
    idx_header, idx_buf_num, indptr_header, shape)``. The buffer list is the
    three components' buffers concatenated in that order; the two counts
    are used to split it back apart.
    """

    @buffered
    def serial(self, obj: Any, context: Dict):
        data_hdr, data_bufs = serialize(obj.data)
        indices_hdr, indices_bufs = serialize(obj.indices)
        indptr_hdr, indptr_bufs = serialize(obj.indptr)
        all_buffers = data_bufs + indices_bufs + indptr_bufs
        header = (
            data_hdr,
            len(data_bufs),
            indices_hdr,
            len(indices_bufs),
            indptr_hdr,
            obj.shape,
        )
        return header, all_buffers, True

    def deserial(self, serialized: Tuple, context: Dict, subs: List):
        (
            data_hdr,
            n_data_bufs,
            indices_hdr,
            n_indices_bufs,
            indptr_hdr,
            shape,
        ) = serialized
        indices_start = n_data_bufs
        indptr_start = n_data_bufs + n_indices_bufs

        data = deserialize(data_hdr, subs[:indices_start])
        indices = deserialize(indices_hdr, subs[indices_start:indptr_start])
        indptr = deserialize(indptr_hdr, subs[indptr_start:])
        shape = tuple(shape)

        # Build an empty CSR matrix of the right shape and dtype, then attach
        # the received arrays directly (no validation or copying).
        no_entries = np.zeros(0, dtype=data.dtype)
        target_csr = sps.coo_matrix(
            (no_entries, (no_entries, no_entries)), dtype=data.dtype, shape=shape
        ).tocsr()
        target_csr.data = data
        target_csr.indices = indices
        target_csr.indptr = indptr
        return target_csr


# scipy is optional: register only when scipy.sparse was importable.
if sps is not None:  # pragma: no branch
    CsrMatrixSerializer.register(sps.csr_matrix)
@@ -0,0 +1,180 @@
1
+ # Copyright 2024 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, List, Tuple
16
+
17
+ import numpy as np
18
+
19
+ from ..utils import lazy_import
20
+ from .core import Serializer, buffered
21
+
22
+ # lazy import PyTorch to avoid enforced dependency
23
+ torch = lazy_import("torch")
24
+ cupy = lazy_import("cupy")
25
+
26
+
27
def rmm_to_torch(buf):
    """Expose a CUDA device buffer as a ``torch`` tensor without copying.

    *buf* is wrapped with ``cupy.asarray`` and handed to torch through a
    DLPack capsule, so both steps share memory rather than copy it.
    Presumably *buf* is an RMM device buffer (per the function name) —
    confirm with callers.
    """
    dlpack_capsule = cupy.asarray(buf).toDlpack()
    return torch.utils.dlpack.from_dlpack(dlpack_capsule)
31
+
32
+
33
class TorchTensorSerializer(Serializer):
    """Serialize ``torch.Tensor`` objects that live on the CPU or a CUDA device.

    CPU tensors are made contiguous and sent as a host buffer: a NumPy view
    when ``Tensor.numpy()`` succeeds (``format="numpy"``), otherwise a copied
    raw-byte buffer (``format="bytes"``). CUDA tensors are sent as a
    ``uint8`` tensor sharing the tensor's untyped storage. The header records
    shape, strides and storage offset so the receiver can rebuild the view.
    ``requires_grad`` is recorded and restored in both cases.
    """

    @buffered
    def serial(self, obj: "torch.Tensor", context: Dict):  # type: ignore
        # Record autograd state, then drop the graph. ``numpy()`` refuses
        # tensors that require grad, and only the data is transmitted anyway.
        # ``detach()`` shares storage, so this does not copy.
        requires_grad = obj.requires_grad
        obj = obj.detach()

        if obj.device.type == "cpu":
            # make sure tensor is contiguous
            if not obj.is_contiguous():
                obj = obj.contiguous()
            header = {
                "shape": tuple(obj.shape),
                "dtype": str(obj.dtype),
                "device": obj.device.type,
                "requires_grad": requires_grad,
                "strides": tuple(obj.stride()),
            }
            try:
                # Zero-copy path: torch -> numpy.
                buffer = memoryview(obj.numpy())
                header["format"] = "numpy"
            except Exception:
                # Fallback for dtypes NumPy cannot represent: copy raw bytes.
                # Flatten first, because ``view(dtype)`` with a different
                # element size requires at least one dimension (0-d failed).
                byte_tensor = obj.reshape(-1).view(torch.uint8).clone()
                buffer = memoryview(byte_tensor.numpy())
                header["format"] = "bytes"
            return (header,), [buffer], True

        elif obj.device.type == "cuda":
            if not (
                obj.is_contiguous()
                or obj.is_contiguous(memory_format=torch.channels_last)
            ):
                obj = obj.contiguous()

            header = {
                "shape": tuple(obj.shape),
                "dtype": str(obj.dtype),
                "device": obj.device.type,
                "device_index": obj.device.index,
                "requires_grad": requires_grad,
                "strides": tuple(obj.stride()),
                # The whole storage is shipped, and a contiguous slice such as
                # ``t[1:]`` does not start at its beginning. The receiver needs
                # the element offset, or the data comes back shifted.
                "storage_offset": obj.storage_offset(),
            }

            # Expose the underlying CUDA memory as a uint8 tensor (zero-copy).
            # Start from an empty tensor and point it at the storage. The
            # former ``torch.empty(nbytes)`` allocated a full-size GPU buffer
            # only to discard it on ``set_``.
            storage = obj.untyped_storage()
            buffer = torch.empty(0, dtype=torch.uint8, device=obj.device)
            buffer.set_(  # type: ignore[attr-defined]
                storage,
                storage_offset=0,
                size=(storage.nbytes(),),
                stride=(1,),
            )
            return (header,), [buffer], True

        else:
            # for unsupported device
            raise NotImplementedError(f"Unsupported device type: {obj.device.type}")

    def deserial(self, serialized: Tuple, context: Dict, subs: List[Any]):
        header = serialized[0]
        device = header["device"]
        data_buffer = subs[0]

        if device == "cpu":
            # e.g. "torch.float32" -> "float32"
            dtype_name = header["dtype"].split(".")[-1]
            dtype = getattr(torch, dtype_name)

            if header.get("format", "numpy") == "numpy":
                # zero-copy path
                np_array = np.frombuffer(data_buffer, dtype=np.dtype(dtype_name))
                tensor = torch.from_numpy(np_array).view(header["shape"])
            else:
                # Raw bytes: reinterpret the byte storage with the recorded
                # dtype, shape and strides.
                byte_np = np.frombuffer(data_buffer, dtype=np.uint8)
                byte_tensor = torch.from_numpy(byte_np)
                tensor = torch.empty(0, dtype=dtype).set_(
                    byte_tensor.untyped_storage(),
                    storage_offset=0,
                    size=tuple(header["shape"]),
                    stride=tuple(header["strides"]),
                )

        elif device == "cuda":
            buffer = data_buffer
            if not isinstance(buffer, torch.Tensor):
                buffer = rmm_to_torch(buffer)
            assert buffer.is_cuda, "buffer must be a CUDA tensor"
            assert buffer.dtype == torch.uint8, "buffer must be uint8"

            dtype = getattr(torch, header["dtype"].split(".")[-1])
            # Empty wrapper with the right dtype; set_ binds it to the shared
            # storage with the original offset/shape/strides (no data movement).
            tensor = torch.empty(0, device=buffer.device, dtype=dtype)
            tensor.set_(
                buffer.untyped_storage(),
                # Headers from older senders carry no offset; they always
                # used 0.
                storage_offset=header.get("storage_offset", 0),
                size=tuple(header["shape"]),
                stride=tuple(header["strides"]),
            )

        else:
            raise NotImplementedError(f"Unsupported device type: {device}")

        # recover requires_grad attributes
        tensor.requires_grad = bool(header.get("requires_grad", False))
        return tensor


# only when torch is available, we register module
if torch is not None:
    TorchTensorSerializer.register("torch.Tensor")