vajra-streamer 0.0.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vajra_streamer-0.0.13/LICENSE +21 -0
- vajra_streamer-0.0.13/MANIFEST.in +3 -0
- vajra_streamer-0.0.13/PKG-INFO +43 -0
- vajra_streamer-0.0.13/README.md +32 -0
- vajra_streamer-0.0.13/example/test.py +52 -0
- vajra_streamer-0.0.13/pyproject.toml +25 -0
- vajra_streamer-0.0.13/setup.cfg +4 -0
- vajra_streamer-0.0.13/setup.py +36 -0
- vajra_streamer-0.0.13/vajra/__init__.py +249 -0
- vajra_streamer-0.0.13/vajra/libvajra.so +0 -0
- vajra_streamer-0.0.13/vajra_streamer.egg-info/PKG-INFO +43 -0
- vajra_streamer-0.0.13/vajra_streamer.egg-info/SOURCES.txt +13 -0
- vajra_streamer-0.0.13/vajra_streamer.egg-info/dependency_links.txt +1 -0
- vajra_streamer-0.0.13/vajra_streamer.egg-info/requires.txt +1 -0
- vajra_streamer-0.0.13/vajra_streamer.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Tvashta Labs
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: vajra-streamer
|
|
3
|
+
Version: 0.0.13
|
|
4
|
+
Summary: A fast zero-copy PyTorch tensor streamer
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
Requires-Python: >=3.8
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: torch>=2.0.0
|
|
10
|
+
Dynamic: license-file
|
|
11
|
+
|
|
12
|
+
# Vajra
|
|
13
|
+
|
|
14
|
+
A fast zero-copy PyTorch tensor streamer powered by Dlang.
|
|
15
|
+
|
|
16
|
+
## Usage
|
|
17
|
+
|
|
18
|
+
Once installed, you can use Vajra in any Python script or Jupyter Notebook:
|
|
19
|
+
|
|
20
|
+
```python
|
|
21
|
+
import torch
|
|
22
|
+
from vajra import VajraStreamer, StreamConfig
|
|
23
|
+
|
|
24
|
+
# Configure the streamer (optional, uses defaults otherwise)
|
|
25
|
+
config = StreamConfig(
|
|
26
|
+
auth_token="hf_YOUR_TOKEN", # Required for gated models like Llama 3
|
|
27
|
+
chunk_size_mb=64,
|
|
28
|
+
chunk_workers=16,
|
|
29
|
+
gpu_workers=3,
|
|
30
|
+
disable_cache=False
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Context manager ensures VRAM is freed when done
|
|
34
|
+
with VajraStreamer(config) as streamer:
|
|
35
|
+
# Pass the URL or a Hugging Face repo ID
|
|
36
|
+
# (e.g. "meta-llama/Meta-Llama-3-8B")
|
|
37
|
+
tensors = streamer.load("meta-llama/Meta-Llama-3-8B")
|
|
38
|
+
|
|
39
|
+
# 'tensors' is a dictionary mapping tensor names to zero-copy PyTorch tensors
|
|
40
|
+
# that are backed directly by the downloaded GPU memory.
|
|
41
|
+
for name, tensor in tensors.items():
|
|
42
|
+
print(f"Tensor: {name}, Shape: {tensor.shape}, Dtype: {tensor.dtype}")
|
|
43
|
+
```
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Vajra
|
|
2
|
+
|
|
3
|
+
A fast, zero-copy streamer that loads model weights straight into GPU memory as PyTorch tensors. Powered by D (Dlang).
|
|
4
|
+
|
|
5
|
+
## Usage
|
|
6
|
+
|
|
7
|
+
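Install it with `pip install vajra-streamer` (a prebuilt library is shipped for Linux x86_64 only, and a CUDA-capable GPU is required). Once installed, you can use Vajra in any Python script or Jupyter Notebook: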
|
|
8
|
+
|
|
9
|
+
```python
|
|
10
|
+
import torch
|
|
11
|
+
from vajra import VajraStreamer, StreamConfig
|
|
12
|
+
|
|
13
|
+
# Configure the streamer (optional, uses defaults otherwise)
|
|
14
|
+
config = StreamConfig(
|
|
15
|
+
auth_token="hf_YOUR_TOKEN", # Required for gated models like Llama 3
|
|
16
|
+
chunk_size_mb=64,
|
|
17
|
+
chunk_workers=16,
|
|
18
|
+
gpu_workers=3,
|
|
19
|
+
disable_cache=False
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Context manager ensures VRAM is freed when done
|
|
23
|
+
with VajraStreamer(config) as streamer:
|
|
24
|
+
# Pass the URL or a Hugging Face repo ID
|
|
25
|
+
# (e.g. "meta-llama/Meta-Llama-3-8B")
|
|
26
|
+
tensors = streamer.load("meta-llama/Meta-Llama-3-8B")
|
|
27
|
+
|
|
28
|
+
# 'tensors' is a dictionary mapping tensor names to zero-copy PyTorch tensors
|
|
29
|
+
# that are backed directly by the downloaded GPU memory. They are only valid
# inside this `with` block: the VRAM is freed when the block exits.
|
|
30
|
+
for name, tensor in tensors.items():
|
|
31
|
+
print(f"Tensor: {name}, Shape: {tensor.shape}, Dtype: {tensor.dtype}")
|
|
32
|
+
```
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
# Make the package root (the directory that contains the `vajra` package)
# importable so this example runs from a source checkout without installing.
# abspath() matters on Python 3.8, where the main script's __file__ can be a
# relative path (e.g. "test.py" when run from inside example/); without it the
# inserted entry would be "" (the current directory) instead of the root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
6
|
+
|
|
7
|
+
from vajra import VajraStreamer, StreamConfig
|
|
8
|
+
|
|
9
|
+
def main():
    """Stream one safetensors shard into GPU memory and print a few of its tensors.

    The shard belongs to Meta-Llama-3-8B, a gated Hugging Face repository, so
    an access token is read from the ``HF_TOKEN`` environment variable (an
    empty token is sent when the variable is unset).
    """
    # First of four shards of a gated 8B-parameter model (not a small download).
    url = "https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/model-00001-of-00004.safetensors"

    config = StreamConfig(
        auth_token=os.environ.get("HF_TOKEN", ""),  # gated repo: a token is required
        chunk_size_mb=16,  # smaller chunks than the 64 MB default
        chunk_workers=4,
        gpu_workers=2,
        disable_cache=True,
        log_level=4  # 4 = info, 3 = diagnostic, 2 = debug, 1 = debugV, 0 = trace
    )

    # Number of tensors to print before eliding the rest with " - ...".
    preview_count = 5

    print(f"Streaming {url} into GPU memory (zero-copy)...")

    print("[Python] Entering VajraStreamer context")
    with VajraStreamer(config) as streamer:
        # Load the model directly to VRAM
        print("[Python] Calling streamer.load(url)")
        tensors = streamer.load(url)
        print("[Python] streamer.load(url) completed")

        print("\nSuccessfully loaded model!")
        print(f"Total tensors loaded: {len(tensors)}")

        if tensors:
            print("\nSample Tensors:")
            for index, (name, tensor) in enumerate(tensors.items()):
                if index == preview_count:
                    # Only reached when tensors remain beyond the preview.
                    print(" - ...")
                    break
                if hasattr(tensor, 'dtype'):
                    # Adopted as a torch.Tensor.
                    print(f" - {name}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}")
                else:
                    # torch conversion failed, so load() returned the raw wrapper.
                    print(f" - {name}: shape={tensor.shape}, dtype={tensor.typestr}, device=cuda (raw VajraGPUTensor)")

        print("\nExiting with block. All VRAM will be freed automatically.")
    print("[Python] Context manager exited successfully")
|
|
50
|
+
|
|
51
|
+
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=80.0", "setuptools-scm[simple]>=9.2", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "vajra-streamer"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = "A fast zero-copy PyTorch tensor streamer"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = "MIT"
|
|
11
|
+
license-files = ["LICENSE"]
|
|
12
|
+
requires-python = ">=3.8"
|
|
13
|
+
dependencies = [
|
|
14
|
+
"torch>=2.0.0"
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[tool.setuptools.packages.find]
|
|
18
|
+
where = ["."]
|
|
19
|
+
include = ["vajra"]
|
|
20
|
+
|
|
21
|
+
[tool.setuptools.package-data]
|
|
22
|
+
"vajra" = ["*.so", "*.dylib", "*.dll"]
|
|
23
|
+
|
|
24
|
+
[tool.setuptools_scm]
|
|
25
|
+
root = ".."
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import platform
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
from setuptools import Distribution, setup
|
|
5
|
+
|
|
6
|
+
# Guard must run before setuptools does any work. This sdist intentionally
# bundles a prebuilt Linux x86_64 binary (libvajra.so). It cannot be built
# on any other platform. Fail loudly here rather than with a confusing
# linker or import error later.
if not (sys.platform == "linux" and platform.machine() == "x86_64"):
    raise RuntimeError(
        "vajra ships a prebuilt Linux x86_64 shared library (libvajra.so). "
        "This sdist cannot be built on your platform. "
        "Install the pre-built wheel instead:\n\n"
        "    pip install vajra-streamer\n\n"
        "If no wheel is available for your platform, "
        "please report it to the vajra-streamer maintainers "
        "(see https://pypi.org/project/vajra-streamer/)"
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BinaryDistribution(Distribution):
    """Distribution that always reports compiled extension modules.

    The wheel's only native code is a prebuilt ``libvajra.so`` that is added
    outside setuptools' own build system, so setuptools cannot detect it, and
    pyproject.toml offers no declarative switch for this case. Claiming
    extension modules makes the wheel build emit a platform-specific tag
    rather than a pure-Python ``*-any.whl``, which would otherwise be offered
    to Windows/macOS users who cannot load a Linux ``.so``.
    """

    # NOTE(review): this presumably also makes the wheel tag interpreter-
    # specific (one wheel per Python version); confirm that is intended.
    def has_ext_modules(self):
        """Report native extensions unconditionally (always ``True``)."""
        return True
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Build with the platform-specific distribution class; every other piece of
# package metadata comes from pyproject.toml.
setup(distclass=BinaryDistribution)
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
import ctypes
import logging
import os
import platform
from dataclasses import dataclass
from typing import Dict, Optional, Any, Callable

import torch

# Load the shared library
|
|
9
|
+
|
|
10
|
+
_current_dir = os.path.dirname(os.path.abspath(__file__))

_ext = ".dylib" if platform.system() == "Darwin" else ".so"
_lib_name = f"libvajra{_ext}"

# Candidate locations, in priority order:
#   1. inside the installed package directory (where the wheel puts it)
#   2. the parent directory and
#   3. the grandparent directory (local development, before the library has
#      been copied into the package)
_candidate_paths = [
    os.path.join(_current_dir, _lib_name),
    os.path.join(os.path.dirname(_current_dir), _lib_name),
    os.path.join(os.path.dirname(os.path.dirname(_current_dir)), _lib_name),
]
# First candidate that exists; if none does, the last one is still tried.
_lib_path = next((p for p in _candidate_paths if os.path.exists(p)), _candidate_paths[-1])

try:
    _lib = ctypes.CDLL(_lib_path, mode=ctypes.RTLD_GLOBAL)
except OSError as _local_error:
    # Not loadable from the package/development locations: fall back to the
    # dynamic loader's standard search paths. If that also fails, report both
    # errors, because the first one can name the real cause (for example a
    # missing dependency of the library) that the bare-name retry would hide.
    try:
        _lib = ctypes.CDLL(_lib_name, mode=ctypes.RTLD_GLOBAL)
    except OSError as _search_error:
        raise OSError(
            f"Could not load {_lib_name}: loading {_lib_path!r} failed ({_local_error}); "
            f"loading from the system library search path failed ({_search_error})"
        ) from _search_error
|
|
27
|
+
|
|
28
|
+
# 1. Define the C-ABI Structs
class CTensor(ctypes.Structure):
    """ctypes mirror of one tensor descriptor in the array filled by ``stream_model``.

    Field order and types define the memory layout and must match the native
    struct exactly; do not reorder or retype fields without changing the
    native side as well.
    """

    _fields_ = [
        ("name", ctypes.c_char_p),  # tensor name as bytes; NULL reads back as None
        ("gpu_ptr", ctypes.c_void_p),  # address of the tensor data; NULL reads back as None
        # NOTE(review): c_long is 64-bit on Linux x86_64 and macOS; this assumes
        # the native side stores each dimension as a 64-bit signed int — confirm.
        ("shape", ctypes.POINTER(ctypes.c_long)),  # array of `ndim` dimension sizes
        ("ndim", ctypes.c_size_t),  # number of entries in `shape`
        ("dtype", ctypes.c_char_p)  # safetensors dtype name, e.g. b"F32" or b"BF16"
    ]
|
|
37
|
+
|
|
38
|
+
class CStreamConfig(ctypes.Structure):
    """ctypes mirror of the native stream configuration handed to ``stream_model``.

    Built from the Python ``StreamConfig`` dataclass by ``to_c_struct`` and
    passed by pointer. Field order and types must match the native struct.
    """

    _fields_ = [
        ("auth_token", ctypes.c_char_p),  # UTF-8 token bytes, or NULL when no token is set
        ("chunk_size", ctypes.c_size_t),  # bytes per chunk (StreamConfig converts from MB)
        # Unsigned: ctypes performs no overflow check, so a negative Python int
        # assigned here wraps to a large positive value.
        ("chunk_workers", ctypes.c_uint),
        ("gpu_workers", ctypes.c_int),
        ("disable_cache", ctypes.c_bool),
        ("log_level", ctypes.c_int)  # vibe.d log level; StreamConfig defaults to 4 (info)
    ]
|
|
47
|
+
|
|
48
|
+
# Callback typedef: typedef void (*ProgressCallback)(size_t bytes_downloaded, size_t total_bytes)
# NOTE(review): not used by any prototype declared below; presumably reserved
# for a future progress-reporting entry point.
ProgressCallbackType = ctypes.CFUNCTYPE(None, ctypes.c_size_t, ctypes.c_size_t)

# Configure the main entry point:
# int stream_model(const char* url, StreamConfig config, CTensor** out_tensors, size_t* num_tensors)
# NOTE(review): the C signature above shows `config` by value, but the
# argtypes below (and VajraStreamer.load's ctypes.byref) pass a pointer —
# confirm against the native export; a by-value struct would need a different
# declaration.
_lib.stream_model.argtypes = [
    ctypes.c_char_p,  # url (UTF-8 bytes)
    ctypes.POINTER(CStreamConfig),  # config
    ctypes.POINTER(ctypes.POINTER(CTensor)),  # out_tensors: receives the CTensor array
    ctypes.POINTER(ctypes.c_size_t)  # num_tensors: receives the array length
]
# 0 on success; VajraStreamer.load maps 1 to ConnectionError, 2 to
# MemoryError and any other non-zero value to RuntimeError.
_lib.stream_model.restype = ctypes.c_int

# Tracker control functions (wrapped by VajraStreamer.pause_vram_tracking /
# resume_vram_tracking); both take no arguments and return nothing.
_lib.vajra_tracker_pause.argtypes = []
_lib.vajra_tracker_pause.restype = None

_lib.vajra_tracker_resume.argtypes = []
_lib.vajra_tracker_resume.restype = None

# Configure cleanup function: void free_model_memory(CTensor* tensors_array, size_t num_tensors)
# Called from VajraStreamer.__exit__ with the array and count produced by stream_model.
_lib.free_model_memory.argtypes = [ctypes.POINTER(CTensor), ctypes.c_size_t]
_lib.free_model_memory.restype = None

# Configure teardown function: void vajra_teardown()
# NOTE(review): declared but not called anywhere in this module.
_lib.vajra_teardown.argtypes = []
_lib.vajra_teardown.restype = None
|
|
75
|
+
|
|
76
|
+
import atexit
import os
import sys

# Terminate with os._exit() at interpreter shutdown.
#
# This prevents glibc from running Dlang's .fini_array (rt_term), which segfaults
# due to vibe.d daemon threads sleeping in epoll_wait when the process exits.
#
# atexit runs handlers in LIFO order: this one runs after every handler
# registered once this module is imported, while handlers registered earlier
# (by any module imported before this one) never run at all.
#
# Because os._exit() skips normal interpreter shutdown, two things that path
# would do are done by hand here: flushing Python's buffered stdio, and
# reporting a failing exit status after an uncaught exception.
# NOTE(review): an explicit sys.exit(<non-zero>) still ends with status 0.

_exit_status = 0
_previous_excepthook = sys.excepthook


def _record_uncaught_exception(exc_type, exc_value, exc_tb):
    """sys.excepthook wrapper: note a fatal uncaught exception, then defer to the previous hook."""
    global _exit_status
    # The interactive REPL routes every failed statement through excepthook
    # without ending the process, so only non-interactive runs record failure.
    if not hasattr(sys, "ps1"):
        _exit_status = 1
    _previous_excepthook(exc_type, exc_value, exc_tb)


def _hard_exit():
    """Flush stdout/stderr, then end the process immediately with the recorded status."""
    for stream in (sys.stdout, sys.stderr):
        if stream is None:  # no console attached (e.g. pythonw)
            continue
        try:
            stream.flush()
        except (OSError, ValueError):  # closed or broken stream: nothing left to save
            pass
    os._exit(_exit_status)


sys.excepthook = _record_uncaught_exception
atexit.register(_hard_exit)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# 2. Python Config Dataclass
|
|
86
|
+
@dataclass
class StreamConfig:
    """User-facing streaming options, converted to ``CStreamConfig`` for the native call.

    Attributes:
        auth_token: Hugging Face access token (needed for gated models); an
            empty string is sent to the native side as NULL.
        chunk_size_mb: Chunk size in MB; multiplied by 1024 * 1024 into bytes.
        chunk_workers: Number of chunk workers (unsigned on the native side).
        gpu_workers: Number of GPU workers.
        disable_cache: Passed through to the native side unchanged.
        log_level: vibe.d log level (4 = info, 3 = diagnostic, 2 = debug,
            1 = debugV, 0 = trace).
    """

    auth_token: str = ""
    chunk_size_mb: int = 64
    chunk_workers: int = 16
    gpu_workers: int = 3
    disable_cache: bool = False
    log_level: int = 4  # default to info (vibe-d log level 4)

    def to_c_struct(self) -> "CStreamConfig":
        """Build the ctypes struct that ``stream_model`` receives.

        Raises:
            ValueError: ``chunk_size_mb`` or ``chunk_workers`` is negative. Both
                land in unsigned C fields, where ctypes would silently wrap a
                negative value into an enormous positive one.
        """
        if self.chunk_size_mb < 0:
            raise ValueError(f"chunk_size_mb must be non-negative, got {self.chunk_size_mb}")
        if self.chunk_workers < 0:
            raise ValueError(f"chunk_workers must be non-negative, got {self.chunk_workers}")
        return CStreamConfig(
            auth_token=self.auth_token.encode('utf-8') if self.auth_token else None,
            chunk_size=self.chunk_size_mb * 1024 * 1024,
            chunk_workers=self.chunk_workers,
            gpu_workers=self.gpu_workers,
            disable_cache=self.disable_cache,
            log_level=self.log_level
        )
|
|
104
|
+
|
|
105
|
+
# 3. Zero-Copy Tensor Wrapper
|
|
106
|
+
class VajraGPUTensor:
    """Zero-copy view of one native tensor, exposed via ``__cuda_array_interface__``.

    Wraps a ``CTensor`` produced by ``stream_model``. The wrapper does not own
    or free the memory; the buffer is only valid while the native allocation
    is alive.
    """

    # Safetensors dtype names -> __cuda_array_interface__ typestrings.
    # BF16 has no standard typestr: "<V2" is an opaque 2-byte placeholder,
    # and VajraStreamer.load swaps in "<i2" before handing the buffer to torch.
    _typestr_map = {
        "F64": "<f8",
        "F32": "<f4",
        "F16": "<f2",
        "BF16": "<V2",
        "I64": "<i8",
        "I32": "<i4",
        "I16": "<i2",
        "I8": "|i1",
        "U64": "<u8",
        "U32": "<u4",
        "U16": "<u2",
        "U8": "|u1",
        "BOOL": "|b1",
    }

    def __init__(self, c_tensor: "CTensor"):
        """Copy the descriptor fields out of ``c_tensor`` into Python values.

        Attributes set:
            name: Decoded tensor name ("unknown_tensor" when the native name is NULL).
            shape: Tuple of the ``ndim`` dimension sizes.
            typestr: Interface typestring, or ``None`` for an unrecognised dtype.
            original_dtype: The native dtype name as a string.
        """
        self._gpu_ptr = c_tensor.gpu_ptr

        raw_name = c_tensor.name
        self.name = "unknown_tensor" if raw_name is None else raw_name.decode('utf-8', errors='ignore')

        self.shape = tuple(c_tensor.shape[i] for i in range(c_tensor.ndim))

        raw_dtype = c_tensor.dtype
        dtype_str = "" if raw_dtype is None else raw_dtype.decode('utf-8', errors='replace')
        # Unrecognised dtypes (e.g. the 8-bit float variants) get no typestr
        # instead of a guessed one: describing a buffer with the wrong element
        # size would make consumers read the wrong bytes, or past its end.
        self.typestr = self._typestr_map.get(dtype_str)
        self.original_dtype = dtype_str

    @property
    def __cuda_array_interface__(self) -> Dict[str, Any]:
        """CUDA Array Interface (version 3) description, used by torch.as_tensor for zero-copy adoption.

        Raises:
            TypeError: The native dtype has no known typestr, so the buffer
                cannot be described safely.
        """
        if self.typestr is None:
            raise TypeError(
                f"Unsupported tensor dtype {self.original_dtype!r} for tensor {self.name!r}"
            )
        return {
            "shape": self.shape,
            "typestr": self.typestr,
            # ctypes returns a NULL c_void_p as None; the interface requires an
            # int pointer (0 is the documented value for zero-size arrays).
            "data": (self._gpu_ptr or 0, False),  # (ptr, read_only=False)
            "version": 3
        }
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# 4. Main Context Manager Streamer
|
|
148
|
+
class VajraStreamer:
    """Context manager that streams model tensors into GPU memory via the native library.

    ``load`` returns tensors that alias native VRAM allocations (zero-copy).
    Every allocation made by ``load`` is released with ``free_model_memory``
    when the ``with`` block exits, so tensors obtained from ``load`` must not
    be used after the block ends.
    """

    # Library diagnostics go through logging instead of being printed to stdout.
    _logger = logging.getLogger(__name__)

    def __init__(self, config: Optional["StreamConfig"] = None):
        """Create a streamer.

        Args:
            config: Streaming options; ``StreamConfig()`` defaults are used
                when omitted (or falsy).
        """
        self.config = config or StreamConfig()
        self._loaded_tensors: list[VajraGPUTensor] = []
        # One (CTensor array pointer, tensor count) pair per successful load();
        # all of them are freed in __exit__.
        self._allocations: list = []

    @staticmethod
    def pause_vram_tracking():
        """Pause the VRAM tracker to stop logging allocations"""
        _lib.vajra_tracker_pause()

    @staticmethod
    def resume_vram_tracking():
        """Resume the VRAM tracker to start logging allocations again"""
        _lib.vajra_tracker_resume()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Free every native allocation made by ``load``; exceptions from the block propagate."""
        allocations = getattr(self, "_allocations", None) or []
        if not allocations:
            self._logger.debug("No native tensor allocations to free")
            return
        # Pop before freeing so a repeated __exit__ can never double-free.
        while allocations:
            ptr, count = allocations.pop()
            self._logger.debug("Calling free_model_memory for %d tensors", count)
            _lib.free_model_memory(ptr, count)

    def _default_progress(self, downloaded: int, total: int):
        # Basic progress callback if none provided
        if total > 0:
            percent = (downloaded / total) * 100
            print(f"\rStreaming: {downloaded / 1024 / 1024:.2f}MB / {total / 1024 / 1024:.2f}MB ({percent:.1f}%)", end="")
            if downloaded == total:
                print()  # New line when done

    def _adopt(self, wrapper: "VajraGPUTensor"):
        """Adopt one native tensor into torch without copying; on failure return the raw wrapper."""
        try:
            if wrapper.original_dtype == "BF16":
                # The CUDA array interface has no bfloat16 typestr: expose the
                # buffer as int16 (same 2-byte width) and reinterpret in torch.
                wrapper.typestr = "<i2"
                return torch.as_tensor(wrapper, device='cuda').view(torch.bfloat16)
            return torch.as_tensor(wrapper, device='cuda')
        except Exception as exc:  # deliberate best-effort: keep the raw wrapper
            self._logger.warning(
                "Skipping torch conversion for %s due to error: %s", wrapper.name, exc
            )
            return wrapper

    def load(self, url: str) -> Dict[str, torch.Tensor]:
        """Stream a model and return its tensors keyed by name.

        Args:
            url: Model location, passed UTF-8 encoded to the native ``stream_model``.

        Returns:
            Mapping of tensor name to a zero-copy CUDA ``torch.Tensor``. A tensor
            torch cannot adopt is returned as its raw ``VajraGPUTensor`` wrapper
            (a warning is logged). May be called more than once; every result
            stays valid until the ``with`` block exits.

        Raises:
            ConnectionError: Native error code 1 (network error).
            MemoryError: Native error code 2 (CUDA out of memory).
            RuntimeError: Any other non-zero native error code.
        """
        c_config = self.config.to_c_struct()

        # Output parameters filled in by the native call.
        out_tensors_ptr = ctypes.POINTER(CTensor)()
        num_tensors = ctypes.c_size_t(0)

        error_code = _lib.stream_model(
            url.encode('utf-8'),
            ctypes.byref(c_config),
            ctypes.byref(out_tensors_ptr),
            ctypes.byref(num_tensors)
        )

        if error_code == 1:
            raise ConnectionError(f"Network error while streaming {url}")
        if error_code == 2:
            raise MemoryError(f"CUDA Out Of Memory while allocating VRAM for {url}")
        if error_code != 0:
            raise RuntimeError(f"VajraStreamer failed with generic error code {error_code}")

        # Record the allocation before wrapping so __exit__ frees it even if
        # wrapping fails part-way. Earlier allocations are kept, not
        # overwritten, so repeated load() calls no longer leak VRAM.
        # NOTE(review): assumes free_model_memory releases only the array it
        # is given, as its (tensors_array, num_tensors) signature suggests.
        count = num_tensors.value
        self._allocations.append((out_tensors_ptr, count))

        torch_tensors = {}
        for i in range(count):
            wrapper = VajraGPUTensor(out_tensors_ptr[i])
            torch_tensors[wrapper.name] = self._adopt(wrapper)
        return torch_tensors
|
|
Binary file
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: vajra-streamer
|
|
3
|
+
Version: 0.0.13
|
|
4
|
+
Summary: A fast zero-copy PyTorch tensor streamer
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
Requires-Python: >=3.8
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: torch>=2.0.0
|
|
10
|
+
Dynamic: license-file
|
|
11
|
+
|
|
12
|
+
# Vajra
|
|
13
|
+
|
|
14
|
+
A fast zero-copy PyTorch tensor streamer powered by Dlang.
|
|
15
|
+
|
|
16
|
+
## Usage
|
|
17
|
+
|
|
18
|
+
Once installed, you can use Vajra in any Python script or Jupyter Notebook:
|
|
19
|
+
|
|
20
|
+
```python
|
|
21
|
+
import torch
|
|
22
|
+
from vajra import VajraStreamer, StreamConfig
|
|
23
|
+
|
|
24
|
+
# Configure the streamer (optional, uses defaults otherwise)
|
|
25
|
+
config = StreamConfig(
|
|
26
|
+
auth_token="hf_YOUR_TOKEN", # Required for gated models like Llama 3
|
|
27
|
+
chunk_size_mb=64,
|
|
28
|
+
chunk_workers=16,
|
|
29
|
+
gpu_workers=3,
|
|
30
|
+
disable_cache=False
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Context manager ensures VRAM is freed when done
|
|
34
|
+
with VajraStreamer(config) as streamer:
|
|
35
|
+
# Pass the URL or a Hugging Face repo ID
|
|
36
|
+
# (e.g. "meta-llama/Meta-Llama-3-8B")
|
|
37
|
+
tensors = streamer.load("meta-llama/Meta-Llama-3-8B")
|
|
38
|
+
|
|
39
|
+
# 'tensors' is a dictionary mapping tensor names to zero-copy PyTorch tensors
|
|
40
|
+
# that are backed directly by the downloaded GPU memory.
|
|
41
|
+
for name, tensor in tensors.items():
|
|
42
|
+
print(f"Tensor: {name}, Shape: {tensor.shape}, Dtype: {tensor.dtype}")
|
|
43
|
+
```
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
MANIFEST.in
|
|
3
|
+
README.md
|
|
4
|
+
pyproject.toml
|
|
5
|
+
setup.py
|
|
6
|
+
example/test.py
|
|
7
|
+
vajra/__init__.py
|
|
8
|
+
vajra/libvajra.so
|
|
9
|
+
vajra_streamer.egg-info/PKG-INFO
|
|
10
|
+
vajra_streamer.egg-info/SOURCES.txt
|
|
11
|
+
vajra_streamer.egg-info/dependency_links.txt
|
|
12
|
+
vajra_streamer.egg-info/requires.txt
|
|
13
|
+
vajra_streamer.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torch>=2.0.0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
vajra
|