wirehead 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wirehead-0.1.0/LICENSE +21 -0
- wirehead-0.1.0/PKG-INFO +4 -0
- wirehead-0.1.0/README.md +201 -0
- wirehead-0.1.0/setup.cfg +4 -0
- wirehead-0.1.0/setup.py +21 -0
- wirehead-0.1.0/wirehead/__init__.py +4 -0
- wirehead-0.1.0/wirehead/dataset.py +253 -0
- wirehead-0.1.0/wirehead/generator.py +127 -0
- wirehead-0.1.0/wirehead/manager.py +117 -0
- wirehead-0.1.0/wirehead.egg-info/PKG-INFO +4 -0
- wirehead-0.1.0/wirehead.egg-info/SOURCES.txt +12 -0
- wirehead-0.1.0/wirehead.egg-info/dependency_links.txt +1 -0
- wirehead-0.1.0/wirehead.egg-info/requires.txt +4 -0
- wirehead-0.1.0/wirehead.egg-info/top_level.txt +1 -0
wirehead-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2021 neuroneural/wirehead
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
wirehead-0.1.0/PKG-INFO
ADDED
wirehead-0.1.0/README.md
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
# Wirehead #
|
|
2
|
+
|
|
3
|
+
Caching system for scaling of synthetic data generators using MongoDB
|
|
4
|
+
|
|
5
|
+

|
|
6
|
+
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
# I. Installation
|
|
10
|
+
|
|
11
|
+
## 1. MongoDB Setup (For Development/Testing Only)
|
|
12
|
+
|
|
13
|
+
- [Ubuntu Setup](#a-quick-mongodb-setup-ubuntu)
|
|
14
|
+
- [macOS Setup](#b-quick-mongodb-setup-macos)
|
|
15
|
+
|
|
16
|
+
**Important Note:** The following instructions are for development and testing purposes only. For production deployments, please refer to the [official MongoDB documentation](https://www.mongodb.com/docs/manual/administration/install-community/) for secure and proper installation guidelines.
|
|
17
|
+
|
|
18
|
+
#### a. Quick MongoDB Setup ([Ubuntu](https://www.mongodb.com/docs/manual/tutorial/install-mongodb-on-ubuntu/)):
|
|
19
|
+
|
|
20
|
+
```bash
|
|
21
|
+
sudo apt-get install gnupg curl
|
|
22
|
+
curl -fsSL https://www.mongodb.org/static/pgp/server-7.0.asc | \
|
|
23
|
+
sudo gpg -o /usr/share/keyrings/mongodb-server-7.0.gpg \
|
|
24
|
+
--dearmor
|
|
25
|
+
echo "deb [ arch=amd64,arm64 signed-by=/usr/share/keyrings/mongodb-server-7.0.gpg ] https://repo.mongodb.org/apt/ubuntu jammy/mongodb-org/7.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-7.0.list
|
|
26
|
+
sudo apt-get update
|
|
27
|
+
sudo apt-get install -y mongodb-org
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
# Run MongoDB
|
|
32
|
+
sudo systemctl start mongod
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
# Stop MongoDB
|
|
37
|
+
sudo systemctl stop mongod
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
#### b. Quick MongoDB Setup ([MacOS](https://www.mongodb.com/docs/manual/tutorial/install-mongodb-on-os-x/)):
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
brew tap mongodb/brew
|
|
44
|
+
brew update
|
|
45
|
+
brew install mongodb-community@7.0
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
# Run MongoDB
|
|
50
|
+
brew services start mongodb-community@7.0
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
# Stop MongoDB
|
|
55
|
+
brew services stop mongodb-community@7.0
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
**Warning:** These instructions are simplified for ease of setup in a development environment. For production use, ensure proper security measures, authentication, and follow best practices as outlined in the official MongoDB documentation.
|
|
59
|
+
Installing and deploying MongoDB
|
|
60
|
+
|
|
61
|
+
## 2. Create virtual environment
|
|
62
|
+
```bash
|
|
63
|
+
# Note:
|
|
64
|
+
# python version doesn't necessarily have to be 3.10
|
|
65
|
+
# but this gives better support for some generation pipelines
|
|
66
|
+
|
|
67
|
+
# Conda
|
|
68
|
+
conda create -n wirehead python=3.10
|
|
69
|
+
conda activate wirehead
|
|
70
|
+
|
|
71
|
+
# venv
|
|
72
|
+
python3.10 -m venv wirehead
|
|
73
|
+
source wirehead/bin/activate
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
## 3. Install wirehead:
|
|
77
|
+
```bash
|
|
78
|
+
git clone git@github.com:neuroneural/wirehead.git
|
|
79
|
+
cd wirehead
|
|
80
|
+
pip install -e .
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
## 4. Run the test
|
|
84
|
+
```bash
|
|
85
|
+
cd examples/unit
|
|
86
|
+
chmod +x test.sh
|
|
87
|
+
./test.sh
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
# II. Usage
|
|
91
|
+
|
|
92
|
+
See examples/unit for a minimal example
|
|
93
|
+
|
|
94
|
+
## 1. Manager
|
|
95
|
+
```python
|
|
96
|
+
from wirehead import WireheadManager
|
|
97
|
+
|
|
98
|
+
if __name__ == "__main__":
|
|
99
|
+
wirehead_runtime = WireheadManager(config_path="config.yaml")
|
|
100
|
+
wirehead_runtime.run_manager()
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
## 2. Generator
|
|
104
|
+
|
|
105
|
+
```python
|
|
106
|
+
import numpy as np
|
|
107
|
+
from wirehead import WireheadGenerator
|
|
108
|
+
|
|
109
|
+
def create_generator():
|
|
110
|
+
while True:
|
|
111
|
+
img = np.random.rand(256,256,256)
|
|
112
|
+
lab = np.random.rand(256,256,256)
|
|
113
|
+
yield (img, lab)
|
|
114
|
+
|
|
115
|
+
if __name__ == "__main__":
|
|
116
|
+
brain_generator = create_generator()
|
|
117
|
+
wirehead_runtime = WireheadGenerator(
|
|
118
|
+
generator = brain_generator,
|
|
119
|
+
config_path = "config.yaml"
|
|
120
|
+
)
|
|
121
|
+
wirehead_runtime.run_generator()
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
## 3. Dataset
|
|
125
|
+
```python
|
|
126
|
+
import torch
|
|
127
|
+
from wirehead import MongoheadDataset
|
|
128
|
+
|
|
129
|
+
dataset = MongoheadDataset(config_path = "config.yaml")
|
|
130
|
+
|
|
131
|
+
idx = [0]
|
|
132
|
+
data = dataset[idx]
|
|
133
|
+
sample, label = data[0]['input'], data[0]['label']
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
# III. Config guide
|
|
137
|
+
|
|
138
|
+
All wirehead configs live inside yaml files, and must be specified when declaring wirehead manager, generator and dataset objects. For the system to work, all components must use the __same__ configs.
|
|
139
|
+
|
|
140
|
+
## 1. Basic configs:
|
|
141
|
+
```yaml
|
|
142
|
+
MONGOHOST -- IP address or hostname for machine running MongoDB instance
|
|
143
|
+
DBNAME -- MongoDB database name
|
|
144
|
+
PORT -- Port for MongoDB instance. Defaults to 27017
|
|
145
|
+
SWAP_CAP -- Size cap for read and write collections. bigger means bigger cache, and less frequent swaps. The total memory used by wirehead can be calculated with:
|
|
146
|
+
SWAP_CAP * SIZE OF YIELDED TUPLE * 2
|
|
147
|
+
```
|
|
148
|
+
|
|
149
|
+
## 2. Advanced configs:
|
|
150
|
+
```yaml
|
|
151
|
+
SAMPLE -- Array of strings denoting name of samples in data tuple.
|
|
152
|
+
WRITE_COLLECTION -- Name of write collection (generators push to this)
|
|
153
|
+
READ_COLLECTION -- Name of read collection (dataset reads from this)
|
|
154
|
+
COUNTER_COLLECTION -- Name of counter collection for manager metrics
|
|
155
|
+
TEMP_COLLECTION -- Name of temporary collection used for moving data during swap
|
|
156
|
+
CHUNKSIZE -- Number of megabytes used for chunking data
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
---
|
|
160
|
+
|
|
161
|
+
# IV. Generator example
|
|
162
|
+
|
|
163
|
+
See a simple example in [examples/unit/generator.py](examples/unit/generator.py) or a Synthseg example in [examples/synthseg/generator.py](examples/synthseg/generator.py)
|
|
164
|
+
|
|
165
|
+
Wirehead's [WireheadGenerator](https://github.com/neuroneural/wirehead/blob/main/wirehead/generator.py) object takes in a generator, which is a python generator function. This function yields a tuple containing numpy arrays. The number of samples in this tuple should match the number of strings specified in SAMPLE in config.yaml
|
|
166
|
+
|
|
167
|
+
## 1. Set SAMPLE in "config.yaml" (note the number of keys)
|
|
168
|
+
```yaml
|
|
169
|
+
SAMPLE: ["a", "b"]
|
|
170
|
+
```
|
|
171
|
+
|
|
172
|
+
## 2. Create a generator function, which yields the same number of objects
|
|
173
|
+
```python
|
|
174
|
+
def create_generator():
|
|
175
|
+
while True:
|
|
176
|
+
a = np.random.rand(256,256,256)
|
|
177
|
+
b = np.random.rand(256,256,256)
|
|
178
|
+
yield (a, b)
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
## 3. Insert config file path and generator function into WireheadGenerator
|
|
182
|
+
```python
|
|
183
|
+
generator = create_generator()
|
|
184
|
+
runtime = WireheadGenerator(
|
|
185
|
+
generator = generator,
|
|
186
|
+
config_path = "config.yaml"
|
|
187
|
+
)
|
|
188
|
+
```
|
|
189
|
+
|
|
190
|
+
## 4. Press play
|
|
191
|
+
```python
|
|
192
|
+
runtime.run_generator() # runs an infinite loop
|
|
193
|
+
```
|
|
194
|
+
|
|
195
|
+
---
|
|
196
|
+
|
|
197
|
+
# Citation/Contact
|
|
198
|
+
|
|
199
|
+
This code is under [MIT](https://github.com/neuroneural/wirehead/blob/main/LICENSE) licensing
|
|
200
|
+
|
|
201
|
+
If you have any questions specific to the Wirehead pipeline, please raise an issue or contact us at mdoan4@gsu.edu
|
wirehead-0.1.0/setup.cfg
ADDED
wirehead-0.1.0/setup.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from setuptools import setup, find_packages

setup(
    name='wirehead',
    version='0.1.0',
    packages=find_packages(),
    # Runtime dependencies of wirehead.dataset / generator / manager.
    install_requires=[
        'pymongo',
        'torch',
        'numpy',
        'PyYaml',
    ],
    entry_points={
        'console_scripts': [
            # No command-line entry points are exposed yet.
        ],
    },
)
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
""" Wirehead dataset class for MongoDB """
|
|
2
|
+
|
|
3
|
+
import io
import os
import sys
import time

import yaml
import torch
from torch.utils.data import Dataset
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, OperationFailure
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def unit_interval_normalize(img):
    """Rescale *img* linearly so its values span [0, 1].

    :param img: tensor (or ndarray) supporting ``.min()``/``.max()``
    :returns: normalized image; a constant input (max == min) returns an
        all-zero image instead of dividing by zero and producing NaN/inf
    """
    lo = img.min()
    span = img.max() - lo
    if span == 0:  # constant image: avoid 0/0 -> NaN
        return img - lo
    return (img - lo) / span
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def quantile_normalize(img, qmin=0.01, qmax=0.99):
    """Rescale *img* so its [qmin, qmax] quantile range maps onto [0, 1].

    :param img: tensor supporting ``.quantile``
    :param qmin: lower quantile mapped to 0
    :param qmax: upper quantile mapped to 1
    :returns: the rescaled image (values outside the quantile range fall
        outside [0, 1])
    """
    q_low = img.quantile(qmin)
    q_high = img.quantile(qmax)
    return (img - q_low) / (q_high - q_low)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def binary_to_tensor(tensor_binary):
    """Deserialize a torch tensor from its raw serialized bytes.

    :param tensor_binary: bytes produced by ``torch.save`` into a buffer
    :returns: the reconstructed torch tensor
    """
    with io.BytesIO(tensor_binary) as stream:
        return torch.load(stream)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MongoheadDataset(Dataset):
    """
    A map-style torch Dataset that reassembles samples from chunked binary
    records stored in a MongoDB read collection populated by a wirehead
    generator/manager pair.
    """

    def __init__(self,
                 config_path="",
                 timeout=60,
                 collection=None,
                 sample=("data", "label"),
                 transform=binary_to_tensor,
                 normalize=lambda x: x,
                 id="id",
                 keeptrying=True):
        """Constructor
        :param config_path: path to wirehead config .yaml file; when given and
            the file exists, connection settings are loaded from it and the
            `collection`/`sample` arguments are ignored
        :param timeout: fetch timeout in seconds (stored; not enforced by
            __getitem__ itself)
        :param collection: pymongo collection to be used when no config file
            is given
        :param sample: a pair of fields to be fetched as `input` and `label`
            , e.g. (`T1`, `label104`)
        :param transform: a function applied to each record's reassembled bytes
        :param normalize: a function applied to the transformed `input` tensor
        :param id: the field to be used as an index. The `indices` are values
            of this field
        :param keeptrying: whether to attempt a refetch if first attempt fails
        :returns: an object of MongoheadDataset class
        """
        self.id = id
        self.normalize = normalize
        self.transform = transform
        self.keeptrying = keeptrying  # retries if fetch fails
        # Projection passed to find(): only the chunk bookkeeping fields.
        # NOTE(review): "id" is hard-coded here even though self.id is
        # configurable -- confirm before using a non-default id field.
        self.fields = {"id": 1, "chunk": 1, "kind": 1, "chunk_id": 1}
        self.timeout = timeout

        if config_path != "" and os.path.exists(config_path):
            self.load_from_yaml(config_path)

        else:
            self.collection = collection
            self.sample = sample

        self.indices = self.get_indeces()

    def load_from_yaml(self, config_path):
        """
        Loads MongoDB connection settings and sample field names from
        config_path, blocking until the manager has published data.
        """
        print("Dataset: config loaded from " + config_path)
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        dbname = config.get('DBNAME')
        mongohost = config.get('MONGOHOST')
        # PORT is optional in the config; default MongoDB port is 27017
        port = config.get('PORT') if config.get('PORT') is not None else 27017
        client = MongoClient("mongodb://" + mongohost + ":" + str(port))

        db = client[dbname]
        # Block until at least one swap has completed (see wait_for_data)
        self.wait_for_data(db)
        read_collection = config.get("READ_COLLECTION")
        self.collection = db[read_collection]
        self.sample = tuple(config.get("SAMPLE"))

    def wait_for_data(self, db, timeout=600, check_interval=10):
        """
        Prevents data object from reading before data is ready.
        Polls the `status` collection every `check_interval` seconds until the
        most recent status document has swapped=True (written by the manager
        after its first swap).
        Raises:
            TimeoutError: If the wait time exceeds the timeout
        """
        status_collection = db["status"]
        start_time = time.time()

        def check_status():
            # Returns True once the newest status document reports a swap.
            try:
                latest_status = status_collection.find_one(sort=[("_id", -1)])
                if latest_status is None:
                    print("Dataset: Database is empty, waiting for data...")
                    return False
                if not latest_status.get("swapped", False):
                    print("Dataset: Swap has not happened, waiting...")
                    return False
                return True
            except (ConnectionFailure, OperationFailure) as e:
                # Transient DB errors are treated as "not ready yet"
                print(f"Dataset: Error accessing database: {e}")
                return False

        while time.time() - start_time < timeout:
            if check_status():
                print("Dataset: Data is ready")
                return
            time.sleep(check_interval)

        raise TimeoutError("Dataset: Waited too long for data to be ready")

    def get_indeces(self):
        """
        Retrieve the index array of samples in read collection.
        (Spelling kept as-is: the name is part of the public interface.)
        """
        # The highest id present determines how many samples exist,
        # assuming ids are contiguous from 0 (enforced by the manager).
        last_post = self.collection['bin'].find_one(sort=[(self.id, -1)])

        if last_post is None:
            print("Empty collection, exiting")
            sys.exit()
        num_examples = int(last_post[self.id] + 1)
        return range(num_examples)

    def __len__(self):
        # Number of addressable samples in the read collection
        return len(self.indices)

    def make_serial(self, samples_for_id, kind):
        """
        Converts collection chunks into a contiguous byte sequence by
        concatenating this id's chunks of the given kind in chunk_id order.
        """
        return b"".join([
            sample["chunk"] for sample in sorted(
                (sample for sample in samples_for_id if sample["kind"] == kind),
                key=lambda x: x["chunk_id"],
            )
        ])

    def retry_on_eof_error(retry_count, verbose=False):
        """
        Decorator factory: retries a fetch that fails mid-swap with
        EOFError/OperationFailure, up to retry_count attempts (1 s apart).
        Note: deliberately defined without `self` -- it is applied inside the
        class body and reused as MongoheadDataset.retry_on_eof_error by
        subclasses.
        """

        def decorator(func):

            def wrapper(self, batch, *args, **kwargs):
                myException = Exception  # Default Exception if not overwritten
                for attempt in range(retry_count):
                    try:
                        return func(self, batch, *args, **kwargs)
                    except (
                            EOFError,
                            OperationFailure,
                    ) as exception:  # Specifically catching EOFError
                        if self.keeptrying:
                            if verbose:
                                print(
                                    f"EOFError caught. Retrying {attempt+1}/{retry_count}"
                                )
                            time.sleep(1)
                            continue
                        else:
                            raise exception
                raise myException("Failed after multiple retries.")

            return wrapper

        return decorator

    @retry_on_eof_error(retry_count=3, verbose=True)
    def __getitem__(self, batch):
        """
        Fetch all samples for ids in the batch and where 'kind' is either
        data or label as specified by the sample parameter.
        :param batch: iterable of positions into self.indices
        :returns: dict mapping each batch position to
            {"input": normalized float tensor, "label": tensor}
        """
        # Single round trip: fetch every chunk for every requested id
        samples = list(self.collection["bin"].find(
            {
                self.id: {
                    "$in": [self.indices[_] for _ in batch]
                },
                "kind": {
                    "$in": self.sample
                },
            },
            self.fields,
        ))
        results = {}
        for id in batch:
            # Separate samples for this id
            samples_for_id = [
                sample for sample in samples if sample[self.id] == self.indices[id]
            ]

            # Separate processing for each 'kind'
            data = self.make_serial(samples_for_id, self.sample[0])
            label = self.make_serial(samples_for_id, self.sample[1])

            # Add to results
            results[id] = {
                "input": self.normalize(self.transform(data).float()),
                "label": self.transform(label),
            }
        return results
|
|
214
|
+
|
|
215
|
+
class MongoTupleheadDataset(MongoheadDataset):
    """
    Variant of MongoheadDataset whose __getitem__ yields (input, label)
    tuples instead of {"input": ..., "label": ...} dicts.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @MongoheadDataset.retry_on_eof_error(retry_count=3, verbose=True)
    def __getitem__(self, batch):
        """
        Fetch every chunk belonging to the requested batch positions and
        reassemble each id into a (data, label) tuple.
        """
        wanted_ids = [self.indices[pos] for pos in batch]
        query = {
            self.id: {
                "$in": wanted_ids
            },
            "kind": {
                "$in": self.sample
            },
        }
        # One round trip fetches the chunks for the whole batch
        fetched = list(self.collection["bin"].find(query, self.fields))

        results = []
        for pos in batch:
            target = self.indices[pos]
            per_id = [doc for doc in fetched if doc[self.id] == target]

            # Reassemble the chunked bytes for each configured kind
            raw_data = self.make_serial(per_id, self.sample[0])
            raw_label = self.make_serial(per_id, self.sample[1])

            results.append((
                self.normalize(self.transform(raw_data).float()),
                self.transform(raw_label),
            ))
        return results
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
""" Wirehead Generator Class """
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
import yaml
|
|
7
|
+
import bson
|
|
8
|
+
import torch
|
|
9
|
+
from pymongo import MongoClient, ReturnDocument
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class WireheadGenerator():
    """
    Wirehead generator runtime: wraps a user-supplied Python generator and
    pushes each yielded tuple of numpy arrays into the MongoDB write
    collection as chunked, serialized documents.
    """

    def __init__(self, generator, config_path, n_samples = 1000):
        # :param generator: python generator yielding a tuple of numpy arrays,
        #     one array per name in the SAMPLE config entry
        # :param config_path: path to the wirehead yaml config
        # :param n_samples: number of samples run_generator() will push
        # NOTE(review): on a missing config this returns a half-constructed
        # object instead of raising -- later method calls will fail.
        if config_path is None or os.path.exists(config_path) is False:
            print("No valid config specified, exiting")
            return
        self.load_from_yaml(config_path)
        self.generator = generator
        self.n_samples = n_samples

    def load_from_yaml(self, config_path):
        """ Loads generator configs (DB connection, collections, chunking)
        from config_path """
        with open(config_path, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        dbname = config.get('DBNAME')
        mongohost = config.get('MONGOHOST')
        # PORT is optional; default MongoDB port is 27017
        port = config.get('PORT') if config.get('PORT') is not None else 27017
        client = MongoClient("mongodb://" + mongohost + ":" + str(port))

        self.db = client[dbname]
        self.swap_cap = config.get('SWAP_CAP')
        self.sample = tuple(config.get("SAMPLE"))
        self.chunksize = config.get("CHUNKSIZE")
        self.collectionw = config.get("WRITE_COLLECTION") + ".bin"
        self.collectionc = config.get("COUNTER_COLLECTION")

    def chunkify(self, data, index):
        """
        Converts a tuple of tensors and their labels into
        a list of chunks of serialized objects for mongodb
        """

        def chunk_binobj(tensor_compressed, idx, kind, chunksize):
            """
            Yield mongodb documents covering tensor_compressed, each holding
            at most `chunksize` megabytes of the serialized bytes.
            """
            # Convert chunksize from megabytes to bytes
            chunksize_bytes = chunksize * 1024 * 1024
            # Calculate the number of chunks (round up for the remainder)
            num_chunks = len(tensor_compressed) // chunksize_bytes
            if len(tensor_compressed) % chunksize_bytes != 0:
                num_chunks += 1
            # Yield chunks
            for i in range(num_chunks):
                start = i * chunksize_bytes
                end = min((i + 1) * chunksize_bytes, len(tensor_compressed))
                chunk = tensor_compressed[start:end]
                yield {
                    "id": idx,
                    "chunk_id": i,
                    "kind": kind,
                    "chunk": bson.Binary(chunk),
                }

        def tensor2bin(tensor):
            """
            Serialize a torch tensor into raw bytes.
            NOTE(review): the tensor is cast to uint8 before saving -- the
            dataset side must expect uint8 payloads; confirm against callers.
            """
            tensor_1d = tensor.to(torch.uint8)
            buffer = io.BytesIO()
            torch.save(tensor_1d, buffer)
            tensor_binary = buffer.getvalue()
            return tensor_binary

        chunks = []
        binobj = data
        kinds = self.sample
        # One chunk stream per configured sample kind, all sharing `index`
        for i, kind in enumerate(kinds):
            chunks += list(
                chunk_binobj(tensor2bin(torch.from_numpy(binobj[i])), index, kind,
                             self.chunksize))
        return chunks

    def push_chunks(self, chunks):
        """ Pushes chunkified tensors to mongodb, with error handling"""
        collection_bin = self.db[self.collectionw]
        try:
            collection_bin.insert_many(chunks)
        except Exception as exception:
            # Best-effort: a failed insert (e.g. during a manager swap) is
            # logged and dropped rather than crashing the generator loop
            print(f"Generator: An error occurred: {exception}, are you swapping?")
            time.sleep(1)

    def get_current_idx(self):
        """ Atomically claim the next sample index from the shared counter
        document (returns the value BEFORE the increment). """
        dbc = self.db[self.collectionc]
        counter_doc = dbc.find_one_and_update(
            {"_id": "uniqueFieldCounter"},
            {"$inc": {
                "sequence_value": 1
            }},
            return_document=ReturnDocument.BEFORE,
        )
        return counter_doc["sequence_value"]

    def generate_and_insert(self):
        """ Fetch one sample from the generator and insert it into mongodb """
        # 0. Fetch data from generator
        data = next(self.generator)
        # 1. Get the correct index for this current sample
        index = self.get_current_idx()
        # 2. Turn the data into a list of serialized chunks
        chunks = self.chunkify(data, index)
        # 3. Push to mongodb + error handling
        # Samples claimed past the cap are silently discarded; the manager
        # resets the counter on swap
        if index < self.swap_cap:
            self.push_chunks(chunks)

    def run_generator(self):
        """ Run the generate-and-insert loop for self.n_samples iterations,
        pushing each yielded sample to MongoDB. """
        print("Generator: Initialized")
        n_samples = self.n_samples
        for _ in range(n_samples):
            self.generate_and_insert()
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
""" Wirehead Manager Class """
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import yaml
|
|
6
|
+
from pymongo import MongoClient, ASCENDING
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class WireheadManager():
    """
    Manages the state of the mongo collections in Wirehead: watches the
    write collection's counter and swaps write -> read when SWAP_CAP is hit.

    :param config_path: path to yaml file containing wirehead configs
    """

    def __init__(self, config_path):
        # NOTE(review): on a missing config this returns a half-constructed
        # object instead of raising -- later method calls will fail.
        if config_path is None or os.path.exists(config_path) is False:
            print("No valid config specified, exiting")
            return
        self.load_from_yaml(config_path)

    def load_from_yaml(self, config_path):
        """
        Loads manager configs (DB connection, collection names, swap cap)
        from config_path.
        """
        with open(config_path, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)

        dbname = config.get('DBNAME')
        mongohost = config.get('MONGOHOST')
        # PORT is optional; default MongoDB port is 27017
        port = config.get('PORT') if config.get('PORT') is not None else 27017
        client = MongoClient("mongodb://" + mongohost + ":" + str(port))

        self.db = client[dbname]
        self.swap_cap = config.get('SWAP_CAP')
        self.collectionw = config.get("WRITE_COLLECTION") + ".bin"
        self.collectionr = config.get("READ_COLLECTION") + ".bin"
        self.collectionc = config.get("COUNTER_COLLECTION")
        self.collectiont = config.get("TEMP_COLLECTION") + ".bin"

    def run_manager(self):
        """
        Initializes the database manager, swaps and cleans the database
        whenever swap_cap is hit. Runs forever (busy-polls the counter).
        """
        print("Manager: Initialized")
        # Datasets poll this status collection; swapped=False means not ready
        self.db["status"].insert_one({"swapped": False})
        self.reset_counter_and_collection()
        generated = 0
        while True:
            generated = self.watch_and_swap(generated)

    def verify_collection_integrity(self, collection):
        """
        Verifies collection contains contiguous elements with id 0..swap_cap-1.
        NOTE(review): uses assert, which is stripped under `python -O`.
        """
        unique_ids_count = len(collection.distinct("id"))
        assert (
            unique_ids_count == self.swap_cap
        ), f"Manager: Expected {self.swap_cap} unique ids, found {unique_ids_count}"
        expected_ids_set = set(range(self.swap_cap))
        actual_ids_set = set(collection.distinct("id"))
        assert (expected_ids_set == actual_ids_set
                ), "Manager: The ids aren't continuous from 0 to self.swap_cap - 1"

    def reset_counter_and_collection(self):
        """
        Delete all documents in the main collection that have creeped in
        between the renaming and now, and reset the shared counter to zero.
        """
        dbw = self.db[self.collectionw]
        dbc = self.db[self.collectionc]
        dbw.delete_many({})  # wipe the write collection
        # Reset the counter to zero
        _result = dbc.update_one(
            {"_id": "uniqueFieldCounter"},  # Query part: the document to match
            {"$set": {
                "sequence_value": 0
            }},  # Update part: what to set if the document is matched/found
            upsert=True,
        )
        # Second wipe: clears records written between the first wipe and the
        # counter reset above
        dbw.delete_many({})
        dbw.create_index([("id", ASCENDING)], background=True)

    def swap(self, generated):
        """
        Moves data from write collection to read collection
        Deletes old write collection
        Maintains data integrity in between
        """
        time.sleep(2)  # Buffer for incomplete ops
        generated += self.swap_cap
        print("\n----swap----")
        print(f"Manager: Generated samples so far {generated}")
        # Stash the full write collection under the temp name
        self.db[self.collectionw].rename(self.collectiont, dropTarget=True)
        # Now atomically reset the counter to 0 and delete whatever records
        # may have been written between the execution of the previous line
        # and the next
        self.reset_counter_and_collection()  # this is atomic
        # Drop any over-cap samples generators pushed before the reset
        result = self.db[self.collectiont].delete_many(
            {"id": {
                "$gt": self.swap_cap - 1
            }})
        # Print the result of the deletion
        print(f"Manager: Documents deleted: {result.deleted_count}")
        self.verify_collection_integrity(self.db[self.collectiont])
        # Publish: temp becomes the new read collection, then flag readiness
        self.db[self.collectiont].rename(self.collectionr, dropTarget=True)
        self.db["status"].insert_one({"swapped": True})
        return generated

    def watch_and_swap(self, generated):
        """
        Watch the write collection and swap when full
        """
        counter_doc = self.db[self.collectionc].find_one({"_id": "uniqueFieldCounter"})
        if counter_doc["sequence_value"] >= self.swap_cap:  # watch
            return self.swap(generated)  # swap
        return generated
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
setup.py
|
|
4
|
+
wirehead/__init__.py
|
|
5
|
+
wirehead/dataset.py
|
|
6
|
+
wirehead/generator.py
|
|
7
|
+
wirehead/manager.py
|
|
8
|
+
wirehead.egg-info/PKG-INFO
|
|
9
|
+
wirehead.egg-info/SOURCES.txt
|
|
10
|
+
wirehead.egg-info/dependency_links.txt
|
|
11
|
+
wirehead.egg-info/requires.txt
|
|
12
|
+
wirehead.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
wirehead
|