comfy-env 0.0.64__py3-none-any.whl → 0.0.66__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
Files changed (55)
  1. comfy_env/__init__.py +70 -122
  2. comfy_env/cli.py +78 -7
  3. comfy_env/config/__init__.py +19 -0
  4. comfy_env/config/parser.py +151 -0
  5. comfy_env/config/types.py +64 -0
  6. comfy_env/install.py +83 -361
  7. comfy_env/isolation/__init__.py +9 -0
  8. comfy_env/isolation/wrap.py +351 -0
  9. comfy_env/nodes.py +2 -2
  10. comfy_env/pixi/__init__.py +48 -0
  11. comfy_env/pixi/core.py +356 -0
  12. comfy_env/{resolver.py → pixi/resolver.py} +1 -14
  13. comfy_env/prestartup.py +60 -0
  14. comfy_env/templates/comfy-env-instructions.txt +30 -87
  15. comfy_env/templates/comfy-env.toml +68 -136
  16. comfy_env/workers/__init__.py +21 -32
  17. comfy_env/workers/base.py +1 -1
  18. comfy_env/workers/{torch_mp.py → mp.py} +47 -14
  19. comfy_env/workers/{venv.py → subprocess.py} +405 -441
  20. {comfy_env-0.0.64.dist-info → comfy_env-0.0.66.dist-info}/METADATA +2 -1
  21. comfy_env-0.0.66.dist-info/RECORD +34 -0
  22. comfy_env/decorator.py +0 -700
  23. comfy_env/env/__init__.py +0 -47
  24. comfy_env/env/config.py +0 -201
  25. comfy_env/env/config_file.py +0 -740
  26. comfy_env/env/manager.py +0 -636
  27. comfy_env/env/security.py +0 -267
  28. comfy_env/ipc/__init__.py +0 -55
  29. comfy_env/ipc/bridge.py +0 -476
  30. comfy_env/ipc/protocol.py +0 -265
  31. comfy_env/ipc/tensor.py +0 -371
  32. comfy_env/ipc/torch_bridge.py +0 -401
  33. comfy_env/ipc/transport.py +0 -318
  34. comfy_env/ipc/worker.py +0 -221
  35. comfy_env/isolation.py +0 -310
  36. comfy_env/pixi.py +0 -760
  37. comfy_env/stub_imports.py +0 -270
  38. comfy_env/stubs/__init__.py +0 -1
  39. comfy_env/stubs/comfy/__init__.py +0 -6
  40. comfy_env/stubs/comfy/model_management.py +0 -58
  41. comfy_env/stubs/comfy/utils.py +0 -29
  42. comfy_env/stubs/folder_paths.py +0 -71
  43. comfy_env/workers/pool.py +0 -241
  44. comfy_env-0.0.64.dist-info/RECORD +0 -48
  45. /comfy_env/{env/cuda_gpu_detection.py → pixi/cuda_detection.py} +0 -0
  46. /comfy_env/{env → pixi}/platform/__init__.py +0 -0
  47. /comfy_env/{env → pixi}/platform/base.py +0 -0
  48. /comfy_env/{env → pixi}/platform/darwin.py +0 -0
  49. /comfy_env/{env → pixi}/platform/linux.py +0 -0
  50. /comfy_env/{env → pixi}/platform/windows.py +0 -0
  51. /comfy_env/{registry.py → pixi/registry.py} +0 -0
  52. /comfy_env/{wheel_sources.yml → pixi/wheel_sources.yml} +0 -0
  53. {comfy_env-0.0.64.dist-info → comfy_env-0.0.66.dist-info}/WHEEL +0 -0
  54. {comfy_env-0.0.64.dist-info → comfy_env-0.0.66.dist-info}/entry_points.txt +0 -0
  55. {comfy_env-0.0.64.dist-info → comfy_env-0.0.66.dist-info}/licenses/LICENSE +0 -0
comfy_env/ipc/protocol.py DELETED
@@ -1,265 +0,0 @@
- """
- IPC Protocol - Message format for bridge-worker communication.
-
- Uses JSON for simplicity and debuggability. Large binary data (images, tensors)
- is serialized efficiently:
- - Tensors: Zero-copy via CUDA IPC or shared memory (see tensor.py)
- - Images: PNG encoded + base64
- - Other: pickle + base64 fallback
- """
-
- import json
- import base64
- import pickle
- import logging
- from dataclasses import dataclass, field, asdict
- from typing import Any, Dict, Optional
-
- logger = logging.getLogger(__name__)
-
- # Flag to enable/disable IPC tensor sharing (set based on process context)
- _use_tensor_ipc = True
-
-
- def set_tensor_ipc_enabled(enabled: bool) -> None:
-     """Enable or disable IPC tensor sharing."""
-     global _use_tensor_ipc
-     _use_tensor_ipc = enabled
-
-
- def get_tensor_ipc_enabled() -> bool:
-     """Check if IPC tensor sharing is enabled."""
-     return _use_tensor_ipc
-
-
- @dataclass
- class Request:
-     """
-     Request message from bridge to worker.
-
-     Attributes:
-         id: Unique request ID for matching responses
-         method: Method name to call on worker
-         args: Keyword arguments for the method
-     """
-     id: str
-     method: str
-     args: Dict[str, Any] = field(default_factory=dict)
-
-     def to_json(self) -> str:
-         """Serialize to JSON string."""
-         return json.dumps(asdict(self))
-
-     @classmethod
-     def from_json(cls, data: str) -> "Request":
-         """Deserialize from JSON string."""
-         d = json.loads(data)
-         return cls(**d)
-
-
- @dataclass
- class Response:
-     """
-     Response message from worker to bridge.
-
-     Attributes:
-         id: Request ID this is responding to
-         result: Result value (None if error)
-         error: Error message (None if success)
-         traceback: Full traceback string (only if error)
-     """
-     id: str
-     result: Any = None
-     error: Optional[str] = None
-     traceback: Optional[str] = None
-
-     @property
-     def success(self) -> bool:
-         """Check if response indicates success."""
-         return self.error is None
-
-     def to_json(self) -> str:
-         """Serialize to JSON string."""
-         return json.dumps(asdict(self))
-
-     @classmethod
-     def from_json(cls, data: str) -> "Response":
-         """Deserialize from JSON string."""
-         d = json.loads(data)
-         return cls(**d)
-
-
- def encode_binary(data: bytes) -> str:
-     """Encode binary data as base64 string."""
-     return base64.b64encode(data).decode('utf-8')
-
-
- def decode_binary(encoded: str) -> bytes:
-     """Decode base64 string to binary data."""
-     return base64.b64decode(encoded)
-
-
- def encode_object(obj: Any) -> Dict[str, Any]:
-     """
-     Encode a Python object for JSON serialization.
-
-     Returns a dict with _type and _data keys for special types,
-     or the original object if it's JSON-serializable.
-
-     Special handling:
-     - PyTorch tensors: Zero-copy via CUDA IPC or shared memory
-     - PIL Images: PNG encoded
-     - Complex objects: pickle fallback
-     """
-     if obj is None:
-         return None
-
-     # Handle torch tensors - try zero-copy IPC first
-     if hasattr(obj, 'cpu') and hasattr(obj, 'numpy'):
-         try:
-             import torch
-             if isinstance(obj, torch.Tensor) and _use_tensor_ipc:
-                 try:
-                     from .tensor import serialize_tensor
-                     return serialize_tensor(obj)
-                 except Exception as e:
-                     # Fall back to pickle method if IPC fails
-                     logger.debug(f"Tensor IPC failed, using pickle: {e}")
-         except ImportError:
-             pass
-
-         # Fallback: pickle the numpy array
-         arr = obj.cpu().numpy()
-         return {
-             "_type": "tensor_pickle",
-             "_dtype": str(arr.dtype),
-             "_shape": list(arr.shape),
-             "_data": encode_binary(pickle.dumps(arr)),
-         }
-
-     # Handle numpy arrays
-     if hasattr(obj, '__array__'):
-         import numpy as np
-         arr = np.asarray(obj)
-         return {
-             "_type": "numpy",
-             "_dtype": str(arr.dtype),
-             "_shape": list(arr.shape),
-             "_data": encode_binary(pickle.dumps(arr)),
-         }
-
-     # Handle PIL Images
-     if hasattr(obj, 'save') and hasattr(obj, 'mode'):
-         import io
-         buffer = io.BytesIO()
-         obj.save(buffer, format="PNG")
-         return {
-             "_type": "image",
-             "_format": "PNG",
-             "_data": encode_binary(buffer.getvalue()),
-         }
-
-     # Handle bytes
-     if isinstance(obj, bytes):
-         return {
-             "_type": "bytes",
-             "_data": encode_binary(obj),
-         }
-
-     # Handle lists/tuples recursively
-     if isinstance(obj, (list, tuple)):
-         encoded = [encode_object(item) for item in obj]
-         return {
-             "_type": "list" if isinstance(obj, list) else "tuple",
-             "_data": encoded,
-         }
-
-     # Handle dicts recursively
-     if isinstance(obj, dict):
-         return {k: encode_object(v) for k, v in obj.items()}
-
-     # For simple objects with __dict__, serialize as dict
-     # This avoids pickle module path issues across process boundaries
-     if hasattr(obj, '__dict__') and not hasattr(obj, '__slots__'):
-         return {
-             "_type": "object",
-             "_class": obj.__class__.__name__,
-             "_data": {k: encode_object(v) for k, v in obj.__dict__.items()},
-         }
-
-     # For complex objects that can't be JSON serialized, use pickle
-     try:
-         json.dumps(obj)
-         return obj  # JSON-serializable, return as-is
-     except (TypeError, ValueError):
-         return {
-             "_type": "pickle",
-             "_data": encode_binary(pickle.dumps(obj)),
-         }
-
-
- def decode_object(obj: Any) -> Any:
-     """
-     Decode a JSON-deserialized object back to Python types.
-
-     Reverses the encoding done by encode_object.
-     """
-     if obj is None:
-         return None
-
-     if not isinstance(obj, dict):
-         return obj
-
-     # Check for special encoded types
-     obj_type = obj.get("_type")
-
-     # Handle zero-copy tensor IPC
-     if obj_type == "tensor_ipc":
-         try:
-             from .tensor import deserialize_tensor
-             return deserialize_tensor(obj)
-         except Exception as e:
-             logger.error(f"Failed to deserialize tensor via IPC: {e}")
-             raise
-
-     # Handle pickle fallback for tensors
-     if obj_type == "tensor_pickle":
-         import torch
-         arr = pickle.loads(decode_binary(obj["_data"]))
-         return torch.from_numpy(arr)
-
-     # Legacy types for backwards compatibility
-     if obj_type == "numpy":
-         return pickle.loads(decode_binary(obj["_data"]))
-
-     if obj_type in ("tensor", "comfyui_image", "comfyui_mask"):
-         import torch
-         arr = pickle.loads(decode_binary(obj["_data"]))
-         return torch.from_numpy(arr)
-
-     if obj_type == "image":
-         import io
-         from PIL import Image
-         buffer = io.BytesIO(decode_binary(obj["_data"]))
-         return Image.open(buffer)
-
-     if obj_type == "bytes":
-         return decode_binary(obj["_data"])
-
-     if obj_type == "pickle":
-         return pickle.loads(decode_binary(obj["_data"]))
-
-     # Simple object serialized as dict - restore as SimpleNamespace
-     if obj_type == "object":
-         from types import SimpleNamespace
-         data = {k: decode_object(v) for k, v in obj["_data"].items()}
-         ns = SimpleNamespace(**data)
-         ns._class_name = obj.get("_class", "unknown")
-         return ns
-
-     if obj_type in ("list", "tuple"):
-         decoded = [decode_object(item) for item in obj["_data"]]
-         return decoded if obj_type == "list" else tuple(decoded)
-
-     # Regular dict - decode values recursively
-     return {k: decode_object(v) for k, v in obj.items()}
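The module removed above framed every bridge-to-worker exchange as a JSON Request/Response pair, with encode_object/decode_object making tensors, images, and raw bytes survive the text transport. For reference, a minimal sketch of that round trip, runnable only against the deleted 0.0.64 module (the method name "upscale" and the byte payload are illustrative, not part of the package):

    import uuid
    from comfy_env.ipc.protocol import Request, Response, encode_object, decode_object

    # Bridge side: encode args so non-JSON values survive the transport.
    req = Request(id=str(uuid.uuid4()), method="upscale",
                  args={"scale": 2, "payload": encode_object(b"\x00\x01")})
    wire = req.to_json()  # this JSON string is what crosses the pipe

    # Worker side: parse, decode args, run the method, reply with the same id.
    incoming = Request.from_json(wire)
    args = {k: decode_object(v) for k, v in incoming.args.items()}
    resp = Response(id=incoming.id, result=encode_object(args["payload"]))

    # Bridge side: match on id; `success` is True whenever no error was set.
    parsed = Response.from_json(resp.to_json())
    assert parsed.success and decode_object(parsed.result) == b"\x00\x01"

0.0.66 removes the whole comfy_env/ipc package; judging by the file list, the new comfy_env/isolation/ package and the reworked workers/ modules appear to take over this role.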
comfy_env/ipc/tensor.py DELETED
@@ -1,371 +0,0 @@
- """
- Tensor Serialization - Zero-copy tensor sharing via CUDA IPC and shared memory.
-
- This module provides efficient tensor transfer between processes:
- - CUDA tensors: Use CUDA IPC handles (zero-copy, ~0ms for any size)
- - CPU tensors: Use shared memory via file_system strategy (zero-copy)
-
- Based on patterns from pyisolate's tensor_serializer.py.
- """
-
- import base64
- import collections
- import logging
- import sys
- import threading
- import time
- from typing import Any, Dict, Optional
-
- logger = logging.getLogger(__name__)
-
-
- # ---------------------------------------------------------------------------
- # TensorKeeper - Prevents GC race conditions
- # ---------------------------------------------------------------------------
-
- class TensorKeeper:
-     """
-     Keeps strong references to serialized tensors to prevent premature GC.
-
-     When we serialize a tensor for IPC, we return a handle/path to shared memory.
-     If the original tensor is garbage collected before the receiver opens the
-     shared memory, the data is lost. TensorKeeper holds references for a short
-     window to prevent this race condition.
-
-     Based on pyisolate's TensorKeeper pattern.
-     """
-
-     def __init__(self, retention_seconds: float = 30.0):
-         """
-         Args:
-             retention_seconds: How long to keep references (default 30s)
-         """
-         self.retention_seconds = retention_seconds
-         self._keeper: collections.deque = collections.deque()
-         self._lock = threading.Lock()
-
-     def keep(self, tensor: Any) -> None:
-         """
-         Keep a reference to a tensor.
-
-         Args:
-             tensor: The tensor to keep alive
-         """
-         now = time.time()
-         with self._lock:
-             self._keeper.append((now, tensor))
-
-             # Cleanup old references
-             while self._keeper:
-                 timestamp, _ = self._keeper[0]
-                 if now - timestamp > self.retention_seconds:
-                     self._keeper.popleft()
-                 else:
-                     break
-
-
- # Global tensor keeper instance
- _tensor_keeper = TensorKeeper()
-
-
- # ---------------------------------------------------------------------------
- # Tensor Serialization
- # ---------------------------------------------------------------------------
-
- def serialize_tensor(tensor: Any) -> Dict[str, Any]:
-     """
-     Serialize a PyTorch tensor for IPC transfer.
-
-     Uses zero-copy methods when possible:
-     - CUDA tensors: CUDA IPC handles
-     - CPU tensors: Shared memory (file_system strategy)
-
-     Args:
-         tensor: PyTorch tensor to serialize
-
-     Returns:
-         Dict with tensor metadata and IPC handle/path
-     """
-     import torch
-
-     if not isinstance(tensor, torch.Tensor):
-         raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")
-
-     if tensor.is_cuda:
-         return _serialize_cuda_tensor(tensor)
-     else:
-         return _serialize_cpu_tensor(tensor)
-
-
- def _serialize_cpu_tensor(tensor: Any) -> Dict[str, Any]:
-     """
-     Serialize CPU tensor using shared memory (file_system strategy).
-
-     The tensor is moved to shared memory, and we return the path/key
-     that the receiver can use to access it.
-     """
-     import torch
-     import torch.multiprocessing.reductions as reductions
-
-     # Keep tensor alive until receiver opens shared memory
-     _tensor_keeper.keep(tensor)
-
-     # Move to shared memory if not already
-     if not tensor.is_shared():
-         tensor.share_memory_()
-
-     # Get storage reduction info
-     storage = tensor.untyped_storage()
-     sfunc, sargs = reductions.reduce_storage(storage)
-
-     if sfunc.__name__ == 'rebuild_storage_filename':
-         # file_system strategy - sargs: (cls, manager_path, storage_key, size)
-         return {
-             "_type": "tensor_ipc",
-             "device": "cpu",
-             "strategy": "file_system",
-             "manager_path": sargs[1].decode('utf-8') if isinstance(sargs[1], bytes) else sargs[1],
-             "storage_key": sargs[2].decode('utf-8') if isinstance(sargs[2], bytes) else sargs[2],
-             "storage_size": sargs[3],
-             "dtype": str(tensor.dtype),
-             "shape": list(tensor.shape),
-             "stride": list(tensor.stride()),
-             "offset": tensor.storage_offset(),
-             "requires_grad": tensor.requires_grad,
-         }
-     elif sfunc.__name__ == 'rebuild_storage_fd':
-         # Force file_system strategy for compatibility
-         import torch.multiprocessing as mp
-         mp.set_sharing_strategy('file_system')
-         tensor.share_memory_()
-         return _serialize_cpu_tensor(tensor)
-     else:
-         # Fallback: pickle the tensor data (slow path)
-         logger.warning(f"Unknown storage reduction: {sfunc.__name__}, using pickle fallback")
-         return _serialize_tensor_fallback(tensor)
-
-
- def _serialize_cuda_tensor(tensor: Any) -> Dict[str, Any]:
-     """
-     Serialize CUDA tensor using CUDA IPC handles.
-
-     This is zero-copy - we only transfer the IPC handle, not the data.
-     The receiver uses the handle to map the same GPU memory.
-     """
-     import torch
-     import torch.multiprocessing.reductions as reductions
-
-     try:
-         func, args = reductions.reduce_tensor(tensor)
-     except RuntimeError as e:
-         if "received from another process" in str(e):
-             # Tensor was received via IPC and can't be re-shared
-             # Need to clone it (expensive but necessary)
-             tensor_size_mb = tensor.numel() * tensor.element_size() / (1024 * 1024)
-             if tensor_size_mb > 100:
-                 logger.warning(
-                     f"Cloning large CUDA tensor ({tensor_size_mb:.1f}MB) - "
-                     "consider avoiding returning unmodified input tensors"
-                 )
-             tensor = tensor.clone()
-             func, args = reductions.reduce_tensor(tensor)
-         else:
-             raise
-
-     # Keep tensor alive until receiver maps it
-     _tensor_keeper.keep(tensor)
-
-     # args structure for CUDA tensor:
-     # (cls, size, stride, offset, storage_type, dtype, device_idx, handle,
-     # storage_size, storage_offset, requires_grad, ref_counter_handle,
-     # ref_counter_offset, event_handle, event_sync_required)
-     return {
-         "_type": "tensor_ipc",
-         "device": "cuda",
-         "device_idx": args[6],
-         "shape": list(args[1]),
-         "stride": list(args[2]),
-         "offset": args[3],
-         "dtype": str(args[5]),
-         "handle": base64.b64encode(args[7]).decode('ascii'),
-         "storage_size": args[8],
-         "storage_offset": args[9],
-         "requires_grad": args[10],
-         "ref_counter_handle": base64.b64encode(args[11]).decode('ascii'),
-         "ref_counter_offset": args[12],
-         "event_handle": base64.b64encode(args[13]).decode('ascii') if args[13] else None,
-         "event_sync_required": args[14],
-     }
-
-
- def _serialize_tensor_fallback(tensor: Any) -> Dict[str, Any]:
-     """
-     Fallback serialization using pickle (slow, copies data).
-
-     Used when zero-copy methods aren't available.
-     """
-     import pickle
-
-     arr = tensor.cpu().numpy()
-     return {
-         "_type": "tensor_pickle",
-         "dtype": str(tensor.dtype),
-         "shape": list(tensor.shape),
-         "device": str(tensor.device),
-         "data": base64.b64encode(pickle.dumps(arr)).decode('ascii'),
-     }
-
-
- # ---------------------------------------------------------------------------
- # Tensor Deserialization
- # ---------------------------------------------------------------------------
-
- def deserialize_tensor(data: Dict[str, Any]) -> Any:
-     """
-     Deserialize a tensor from IPC format.
-
-     Args:
-         data: Dict with tensor metadata from serialize_tensor
-
-     Returns:
-         PyTorch tensor
-     """
-     import torch
-
-     # Already a tensor (shouldn't happen, but handle gracefully)
-     if isinstance(data, torch.Tensor):
-         return data
-
-     obj_type = data.get("_type")
-
-     if obj_type == "tensor_ipc":
-         device = data.get("device", "cpu")
-         if device == "cuda":
-             return _deserialize_cuda_tensor(data)
-         else:
-             return _deserialize_cpu_tensor(data)
-     elif obj_type == "tensor_pickle":
-         return _deserialize_tensor_fallback(data)
-     else:
-         raise ValueError(f"Unknown tensor type: {obj_type}")
-
-
- def _deserialize_cpu_tensor(data: Dict[str, Any]) -> Any:
-     """Deserialize CPU tensor from shared memory."""
-     import torch
-     import torch.multiprocessing.reductions as reductions
-
-     strategy = data.get("strategy")
-     if strategy != "file_system":
-         raise RuntimeError(f"Unsupported CPU tensor strategy: {strategy}")
-
-     dtype_str = data["dtype"]
-     dtype = getattr(torch, dtype_str.split(".")[-1])
-
-     manager_path = data["manager_path"]
-     storage_key = data["storage_key"]
-     storage_size = data["storage_size"]
-
-     # Convert to bytes if needed
-     if isinstance(manager_path, str):
-         manager_path = manager_path.encode('utf-8')
-     if isinstance(storage_key, str):
-         storage_key = storage_key.encode('utf-8')
-
-     # Rebuild storage
-     rebuilt_storage = reductions.rebuild_storage_filename(
-         torch.UntypedStorage, manager_path, storage_key, storage_size
-     )
-
-     # Wrap in typed storage
-     typed_storage = torch.storage.TypedStorage(
-         wrap_storage=rebuilt_storage, dtype=dtype, _internal=True
-     )
-
-     # Rebuild tensor
-     metadata = (
-         data["offset"],
-         tuple(data["shape"]),
-         tuple(data["stride"]),
-         data["requires_grad"],
-     )
-     tensor = reductions.rebuild_tensor(torch.Tensor, typed_storage, metadata)
-     return tensor
-
-
- def _deserialize_cuda_tensor(data: Dict[str, Any]) -> Any:
-     """Deserialize CUDA tensor from IPC handle."""
-     import torch
-     import torch.multiprocessing.reductions as reductions
-
-     dtype_str = data["dtype"]
-     dtype = getattr(torch, dtype_str.split(".")[-1])
-
-     handle = base64.b64decode(data["handle"])
-     ref_counter_handle = base64.b64decode(data["ref_counter_handle"])
-     event_handle = base64.b64decode(data["event_handle"]) if data.get("event_handle") else None
-     device_idx = data.get("device_idx", 0)
-
-     tensor = reductions.rebuild_cuda_tensor(
-         torch.Tensor,
-         tuple(data["shape"]),
-         tuple(data["stride"]),
-         data["offset"],
-         torch.storage.TypedStorage,
-         dtype,
-         device_idx,
-         handle,
-         data["storage_size"],
-         data["storage_offset"],
-         data["requires_grad"],
-         ref_counter_handle,
-         data["ref_counter_offset"],
-         event_handle,
-         data["event_sync_required"],
-     )
-     return tensor
-
-
- def _deserialize_tensor_fallback(data: Dict[str, Any]) -> Any:
-     """Deserialize tensor from pickle fallback format."""
-     import pickle
-     import torch
-
-     arr = pickle.loads(base64.b64decode(data["data"]))
-     tensor = torch.from_numpy(arr)
-
-     # Move to original device if CUDA
-     device = data.get("device", "cpu")
-     if "cuda" in device:
-         tensor = tensor.to(device)
-
-     return tensor
-
-
- # ---------------------------------------------------------------------------
- # Integration with protocol.py
- # ---------------------------------------------------------------------------
-
- def is_tensor(obj: Any) -> bool:
-     """Check if object is a PyTorch tensor."""
-     try:
-         import torch
-         return isinstance(obj, torch.Tensor)
-     except ImportError:
-         return False
-
-
- def can_use_ipc() -> bool:
-     """
-     Check if IPC tensor sharing is available.
-
-     Requires:
-     - PyTorch installed
-     - multiprocessing spawn context works
-     """
-     try:
-         import torch
-         import torch.multiprocessing as mp
-         return True
-     except ImportError:
-         return False
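The tensor module deleted above carried the zero-copy transport itself: CUDA tensors travel as IPC handles, CPU tensors as shared-memory (file_system) storage keys, and TensorKeeper pins the sender's reference until the receiver has mapped the storage. A minimal sketch of the CPU path, again runnable only against the deleted 0.0.64 module (assumes PyTorch on Linux/macOS; in the real flow the metadata dict crosses a process boundary as JSON, here it is round-tripped in-process for brevity):

    import torch
    import torch.multiprocessing as mp
    from comfy_env.ipc.tensor import serialize_tensor, deserialize_tensor, can_use_ipc

    # Match the module's expected CPU sharing strategy up front.
    mp.set_sharing_strategy("file_system")

    if can_use_ipc():
        t = torch.randn(4, 4)                # CPU tensor -> shared-memory path
        meta = serialize_tensor(t)           # JSON-safe dict: {"_type": "tensor_ipc", ...}
        restored = deserialize_tensor(meta)  # attaches the same shared storage
        assert torch.equal(t, restored)      # same values, no bulk copy of the data

The 30-second TensorKeeper retention window is what makes this safe across processes: the sender keeps t alive long enough for the receiver's rebuild_storage_filename call to attach before the storage can be garbage collected.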